mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-14 17:47:57 -05:00
Compare commits
63 Commits
dev
...
feat/backf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
06b07604b4 | ||
|
|
9f0c8c06c5 | ||
|
|
3ba374286c | ||
|
|
f4da46cb57 | ||
|
|
10e385612e | ||
|
|
0db134fdd9 | ||
|
|
461bf25bc1 | ||
|
|
f45ef091e2 | ||
|
|
83f46d373d | ||
|
|
07153d5536 | ||
|
|
f3c747027b | ||
|
|
764e1026e5 | ||
|
|
0890ce00b5 | ||
|
|
7f952900ae | ||
|
|
dc5da41703 | ||
|
|
1f3a9d0922 | ||
|
|
c5c1d8d605 | ||
|
|
9ae54e2975 | ||
|
|
8063bb4503 | ||
|
|
2b28023266 | ||
|
|
1b8d8e3772 | ||
|
|
34eb6bdca1 | ||
|
|
44610bb778 | ||
|
|
9afa8a739b | ||
|
|
a76fa0f0a9 | ||
|
|
b0b556e24e | ||
|
|
60ba50431d | ||
|
|
4b8332a14f | ||
|
|
7097cedc1d | ||
|
|
5a60618c2d | ||
|
|
547c6f93d4 | ||
|
|
6dbd45eaf0 | ||
|
|
ca398f3cc5 | ||
|
|
16a14ca09e | ||
|
|
704b8a9207 | ||
|
|
1a5abcc36a | ||
|
|
419b966db1 | ||
|
|
9b8d917d99 | ||
|
|
6432d35db2 | ||
|
|
7d46a5c1dc | ||
|
|
a63370bc30 | ||
|
|
6a86f2e3ea | ||
|
|
679c7806f2 | ||
|
|
5c7391fcd7 | ||
|
|
faf9ad9b57 | ||
|
|
f5899acac0 | ||
|
|
72783dcc02 | ||
|
|
af13badf8f | ||
|
|
b491610ebf | ||
|
|
0b022073eb | ||
|
|
01eef83809 | ||
|
|
4644c09b9e | ||
|
|
374860ff2c | ||
|
|
e7e09ef4e1 | ||
|
|
5e691661a8 | ||
|
|
b0e8c17419 | ||
|
|
5a7c1e39dd | ||
|
|
53b03e746a | ||
|
|
5aaf07fbaf | ||
|
|
0d2996e501 | ||
|
|
9e37a66bca | ||
|
|
429a074848 | ||
|
|
7f1245dc42 |
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -176,7 +176,7 @@ jobs:
|
||||
}
|
||||
|
||||
- name: Run Database Migrations
|
||||
run: poetry run prisma migrate dev --name updates
|
||||
run: poetry run prisma migrate deploy
|
||||
env:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
1
autogpt_platform/backend/.gitignore
vendored
1
autogpt_platform/backend/.gitignore
vendored
@@ -18,3 +18,4 @@ load-tests/results/
|
||||
load-tests/*.json
|
||||
load-tests/*.log
|
||||
load-tests/node_modules/*
|
||||
migrations/*/rollback*.sql
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
@@ -17,6 +18,17 @@ setup_test_data = setup_test_data
|
||||
setup_firecrawl_test_data = setup_firecrawl_test_data
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def mock_embedding_functions():
|
||||
"""Mock embedding functions for all tests to avoid database/API dependencies."""
|
||||
with patch(
|
||||
"backend.api.features.store.db.ensure_embedding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_run_agent(setup_test_data):
|
||||
"""Test that the run_agent tool successfully executes an approved agent"""
|
||||
|
||||
@@ -0,0 +1,417 @@
|
||||
"""
|
||||
Content Type Handlers for Unified Embeddings
|
||||
|
||||
Pluggable system for different content sources (store agents, blocks, docs).
|
||||
Each handler knows how to fetch and process its content type for embedding.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.data.db import query_raw_with_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentItem:
|
||||
"""Represents a piece of content to be embedded."""
|
||||
|
||||
content_id: str # Unique identifier (DB ID or file path)
|
||||
content_type: ContentType
|
||||
searchable_text: str # Combined text for embedding
|
||||
metadata: dict[str, Any] # Content-specific metadata
|
||||
user_id: str | None = None # For user-scoped content
|
||||
|
||||
|
||||
class ContentHandler(ABC):
|
||||
"""Base handler for fetching and processing content for embeddings."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def content_type(self) -> ContentType:
|
||||
"""The ContentType this handler manages."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""
|
||||
Fetch items that don't have embeddings yet.
|
||||
|
||||
Args:
|
||||
batch_size: Maximum number of items to return
|
||||
|
||||
Returns:
|
||||
List of ContentItem objects ready for embedding
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""
|
||||
Get statistics about embedding coverage.
|
||||
|
||||
Returns:
|
||||
Dict with keys: total, with_embeddings, without_embeddings
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StoreAgentHandler(ContentHandler):
|
||||
"""Handler for marketplace store agent listings."""
|
||||
|
||||
@property
|
||||
def content_type(self) -> ContentType:
|
||||
return ContentType.STORE_AGENT
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch approved store listings without embeddings."""
|
||||
from backend.api.features.store.embeddings import build_searchable_text
|
||||
|
||||
missing = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT
|
||||
slv.id,
|
||||
slv.name,
|
||||
slv.description,
|
||||
slv."subHeading",
|
||||
slv.categories
|
||||
FROM {schema_prefix}"StoreListingVersion" slv
|
||||
LEFT JOIN {schema_prefix}"UnifiedContentEmbedding" uce
|
||||
ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
AND uce."contentId" IS NULL
|
||||
LIMIT $1
|
||||
""",
|
||||
batch_size,
|
||||
)
|
||||
|
||||
return [
|
||||
ContentItem(
|
||||
content_id=row["id"],
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
searchable_text=build_searchable_text(
|
||||
name=row["name"],
|
||||
description=row["description"],
|
||||
sub_heading=row["subHeading"],
|
||||
categories=row["categories"] or [],
|
||||
),
|
||||
metadata={
|
||||
"name": row["name"],
|
||||
"categories": row["categories"] or [],
|
||||
},
|
||||
user_id=None, # Store agents are public
|
||||
)
|
||||
for row in missing
|
||||
]
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about store agent embedding coverage."""
|
||||
# Count approved versions
|
||||
approved_result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {schema_prefix}"StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
AND "isDeleted" = false
|
||||
"""
|
||||
)
|
||||
total_approved = approved_result[0]["count"] if approved_result else 0
|
||||
|
||||
# Count versions with embeddings
|
||||
embedded_result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {schema_prefix}"StoreListingVersion" slv
|
||||
JOIN {schema_prefix}"UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
"""
|
||||
)
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total": total_approved,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_approved - with_embeddings,
|
||||
}
|
||||
|
||||
|
||||
class BlockHandler(ContentHandler):
|
||||
"""Handler for block definitions (Python classes)."""
|
||||
|
||||
@property
|
||||
def content_type(self) -> ContentType:
|
||||
return ContentType.BLOCK
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch blocks without embeddings."""
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
# Get all available blocks
|
||||
all_blocks = get_blocks()
|
||||
|
||||
# Check which ones have embeddings
|
||||
if not all_blocks:
|
||||
return []
|
||||
|
||||
block_ids = list(all_blocks.keys())
|
||||
|
||||
# Query for existing embeddings
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
existing_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT "contentId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*block_ids,
|
||||
)
|
||||
|
||||
existing_ids = {row["contentId"] for row in existing_result}
|
||||
missing_blocks = [
|
||||
(block_id, block_cls)
|
||||
for block_id, block_cls in all_blocks.items()
|
||||
if block_id not in existing_ids
|
||||
]
|
||||
|
||||
# Convert to ContentItem
|
||||
items = []
|
||||
for block_id, block_cls in missing_blocks[:batch_size]:
|
||||
try:
|
||||
block_instance = block_cls()
|
||||
|
||||
# Build searchable text from block metadata
|
||||
parts = []
|
||||
if hasattr(block_instance, "name") and block_instance.name:
|
||||
parts.append(block_instance.name)
|
||||
if (
|
||||
hasattr(block_instance, "description")
|
||||
and block_instance.description
|
||||
):
|
||||
parts.append(block_instance.description)
|
||||
if hasattr(block_instance, "categories") and block_instance.categories:
|
||||
# Convert BlockCategory enum to strings
|
||||
parts.append(
|
||||
" ".join(str(cat.value) for cat in block_instance.categories)
|
||||
)
|
||||
|
||||
# Add input/output schema info
|
||||
if hasattr(block_instance, "input_schema"):
|
||||
schema = block_instance.input_schema
|
||||
if hasattr(schema, "model_json_schema"):
|
||||
schema_dict = schema.model_json_schema()
|
||||
if "properties" in schema_dict:
|
||||
for prop_name, prop_info in schema_dict[
|
||||
"properties"
|
||||
].items():
|
||||
if "description" in prop_info:
|
||||
parts.append(
|
||||
f"{prop_name}: {prop_info['description']}"
|
||||
)
|
||||
|
||||
searchable_text = " ".join(parts)
|
||||
|
||||
items.append(
|
||||
ContentItem(
|
||||
content_id=block_id,
|
||||
content_type=ContentType.BLOCK,
|
||||
searchable_text=searchable_text,
|
||||
metadata={
|
||||
"name": getattr(block_instance, "name", ""),
|
||||
"categories": getattr(block_instance, "categories", []),
|
||||
},
|
||||
user_id=None, # Blocks are public
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process block {block_id}: {e}")
|
||||
continue
|
||||
|
||||
return items
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about block embedding coverage."""
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
all_blocks = get_blocks()
|
||||
total_blocks = len(all_blocks)
|
||||
|
||||
if total_blocks == 0:
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
block_ids = list(all_blocks.keys())
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
|
||||
embedded_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*block_ids,
|
||||
)
|
||||
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total": total_blocks,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_blocks - with_embeddings,
|
||||
}
|
||||
|
||||
|
||||
class DocumentationHandler(ContentHandler):
|
||||
"""Handler for documentation files (.md/.mdx)."""
|
||||
|
||||
@property
|
||||
def content_type(self) -> ContentType:
|
||||
return ContentType.DOCUMENTATION
|
||||
|
||||
def _get_docs_root(self) -> Path:
|
||||
"""Get the documentation root directory."""
|
||||
# Assuming docs are in /docs relative to project root
|
||||
backend_root = Path(__file__).parent.parent.parent.parent
|
||||
docs_root = backend_root.parent.parent / "docs"
|
||||
return docs_root
|
||||
|
||||
def _extract_title_and_content(self, file_path: Path) -> tuple[str, str]:
|
||||
"""Extract title and content from markdown file."""
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
|
||||
# Try to extract title from first # heading
|
||||
lines = content.split("\n")
|
||||
title = ""
|
||||
body_lines = []
|
||||
|
||||
for line in lines:
|
||||
if line.startswith("# ") and not title:
|
||||
title = line[2:].strip()
|
||||
else:
|
||||
body_lines.append(line)
|
||||
|
||||
# If no title found, use filename
|
||||
if not title:
|
||||
title = file_path.stem.replace("-", " ").replace("_", " ").title()
|
||||
|
||||
body = "\n".join(body_lines)
|
||||
|
||||
return title, body
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read {file_path}: {e}")
|
||||
return file_path.stem, ""
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch documentation files without embeddings."""
|
||||
docs_root = self._get_docs_root()
|
||||
|
||||
if not docs_root.exists():
|
||||
logger.warning(f"Documentation root not found: {docs_root}")
|
||||
return []
|
||||
|
||||
# Find all .md and .mdx files
|
||||
all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
|
||||
|
||||
# Get relative paths for content IDs
|
||||
doc_paths = [str(doc.relative_to(docs_root)) for doc in all_docs]
|
||||
|
||||
if not doc_paths:
|
||||
return []
|
||||
|
||||
# Check which ones have embeddings
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(doc_paths))])
|
||||
existing_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT "contentId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'DOCUMENTATION'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*doc_paths,
|
||||
)
|
||||
|
||||
existing_ids = {row["contentId"] for row in existing_result}
|
||||
missing_docs = [
|
||||
(doc_path, doc_file)
|
||||
for doc_path, doc_file in zip(doc_paths, all_docs)
|
||||
if doc_path not in existing_ids
|
||||
]
|
||||
|
||||
# Convert to ContentItem
|
||||
items = []
|
||||
for doc_path, doc_file in missing_docs[:batch_size]:
|
||||
try:
|
||||
title, content = self._extract_title_and_content(doc_file)
|
||||
|
||||
# Build searchable text
|
||||
searchable_text = f"{title} {content}"
|
||||
|
||||
items.append(
|
||||
ContentItem(
|
||||
content_id=doc_path,
|
||||
content_type=ContentType.DOCUMENTATION,
|
||||
searchable_text=searchable_text,
|
||||
metadata={
|
||||
"title": title,
|
||||
"path": doc_path,
|
||||
},
|
||||
user_id=None, # Documentation is public
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process doc {doc_path}: {e}")
|
||||
continue
|
||||
|
||||
return items
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about documentation embedding coverage."""
|
||||
docs_root = self._get_docs_root()
|
||||
|
||||
if not docs_root.exists():
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
# Count all .md and .mdx files
|
||||
all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
|
||||
total_docs = len(all_docs)
|
||||
|
||||
if total_docs == 0:
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
doc_paths = [str(doc.relative_to(docs_root)) for doc in all_docs]
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(doc_paths))])
|
||||
|
||||
embedded_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'DOCUMENTATION'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*doc_paths,
|
||||
)
|
||||
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total": total_docs,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_docs - with_embeddings,
|
||||
}
|
||||
|
||||
|
||||
# Content handler registry
|
||||
CONTENT_HANDLERS: dict[ContentType, ContentHandler] = {
|
||||
ContentType.STORE_AGENT: StoreAgentHandler(),
|
||||
ContentType.BLOCK: BlockHandler(),
|
||||
ContentType.DOCUMENTATION: DocumentationHandler(),
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
Integration tests for content handlers using real DB.
|
||||
|
||||
Run with: poetry run pytest backend/api/features/store/content_handlers_integration_test.py -xvs
|
||||
|
||||
These tests use the real database but mock OpenAI calls.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.store.content_handlers import (
|
||||
CONTENT_HANDLERS,
|
||||
BlockHandler,
|
||||
DocumentationHandler,
|
||||
StoreAgentHandler,
|
||||
)
|
||||
from backend.api.features.store.embeddings import (
|
||||
backfill_all_content_types,
|
||||
ensure_content_embedding,
|
||||
get_embedding_stats,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_real_db():
|
||||
"""Test StoreAgentHandler with real database queries."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Get stats from real DB
|
||||
stats = await handler.get_stats()
|
||||
|
||||
# Stats should have correct structure
|
||||
assert "total" in stats
|
||||
assert "with_embeddings" in stats
|
||||
assert "without_embeddings" in stats
|
||||
assert stats["total"] >= 0
|
||||
assert stats["with_embeddings"] >= 0
|
||||
assert stats["without_embeddings"] >= 0
|
||||
|
||||
# Get missing items (max 1 to keep test fast)
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
# Items should be list (may be empty if all have embeddings)
|
||||
assert isinstance(items, list)
|
||||
|
||||
if items:
|
||||
item = items[0]
|
||||
assert item.content_id is not None
|
||||
assert item.content_type.value == "STORE_AGENT"
|
||||
assert item.searchable_text != ""
|
||||
assert item.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_real_db():
|
||||
"""Test BlockHandler with real database queries."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Get stats from real DB
|
||||
stats = await handler.get_stats()
|
||||
|
||||
# Stats should have correct structure
|
||||
assert "total" in stats
|
||||
assert "with_embeddings" in stats
|
||||
assert "without_embeddings" in stats
|
||||
assert stats["total"] >= 0 # Should have at least some blocks
|
||||
assert stats["with_embeddings"] >= 0
|
||||
assert stats["without_embeddings"] >= 0
|
||||
|
||||
# Get missing items (max 1 to keep test fast)
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
# Items should be list
|
||||
assert isinstance(items, list)
|
||||
|
||||
if items:
|
||||
item = items[0]
|
||||
assert item.content_id is not None # Should be block UUID
|
||||
assert item.content_type.value == "BLOCK"
|
||||
assert item.searchable_text != ""
|
||||
assert item.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_real_fs():
|
||||
"""Test DocumentationHandler with real filesystem."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Get stats from real filesystem
|
||||
stats = await handler.get_stats()
|
||||
|
||||
# Stats should have correct structure
|
||||
assert "total" in stats
|
||||
assert "with_embeddings" in stats
|
||||
assert "without_embeddings" in stats
|
||||
assert stats["total"] >= 0
|
||||
assert stats["with_embeddings"] >= 0
|
||||
assert stats["without_embeddings"] >= 0
|
||||
|
||||
# Get missing items (max 1 to keep test fast)
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
# Items should be list
|
||||
assert isinstance(items, list)
|
||||
|
||||
if items:
|
||||
item = items[0]
|
||||
assert item.content_id is not None # Should be relative path
|
||||
assert item.content_type.value == "DOCUMENTATION"
|
||||
assert item.searchable_text != ""
|
||||
assert item.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_embedding_stats_all_types():
|
||||
"""Test get_embedding_stats aggregates all content types."""
|
||||
stats = await get_embedding_stats()
|
||||
|
||||
# Should have structure with by_type and totals
|
||||
assert "by_type" in stats
|
||||
assert "totals" in stats
|
||||
|
||||
# Check each content type is present
|
||||
by_type = stats["by_type"]
|
||||
assert "STORE_AGENT" in by_type
|
||||
assert "BLOCK" in by_type
|
||||
assert "DOCUMENTATION" in by_type
|
||||
|
||||
# Check totals are aggregated
|
||||
totals = stats["totals"]
|
||||
assert totals["total"] >= 0
|
||||
assert totals["with_embeddings"] >= 0
|
||||
assert totals["without_embeddings"] >= 0
|
||||
assert "coverage_percent" in totals
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
async def test_ensure_content_embedding_blocks(mock_generate):
|
||||
"""Test creating embeddings for blocks (mocked OpenAI)."""
|
||||
# Mock OpenAI to return fake embedding
|
||||
mock_generate.return_value = [0.1] * 1536
|
||||
|
||||
# Get one block without embedding
|
||||
handler = BlockHandler()
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
if not items:
|
||||
pytest.skip("No blocks without embeddings")
|
||||
|
||||
item = items[0]
|
||||
|
||||
# Try to create embedding (OpenAI mocked)
|
||||
result = await ensure_content_embedding(
|
||||
content_type=item.content_type,
|
||||
content_id=item.content_id,
|
||||
searchable_text=item.searchable_text,
|
||||
metadata=item.metadata,
|
||||
user_id=item.user_id,
|
||||
)
|
||||
|
||||
# Should succeed with mocked OpenAI
|
||||
assert result is True
|
||||
mock_generate.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
async def test_backfill_all_content_types_dry_run(mock_generate):
|
||||
"""Test backfill_all_content_types processes all handlers in order."""
|
||||
# Mock OpenAI to return fake embedding
|
||||
mock_generate.return_value = [0.1] * 1536
|
||||
|
||||
# Run backfill with batch_size=1 to process max 1 per type
|
||||
result = await backfill_all_content_types(batch_size=1)
|
||||
|
||||
# Should have results for all content types
|
||||
assert "by_type" in result
|
||||
assert "totals" in result
|
||||
|
||||
by_type = result["by_type"]
|
||||
assert "BLOCK" in by_type
|
||||
assert "STORE_AGENT" in by_type
|
||||
assert "DOCUMENTATION" in by_type
|
||||
|
||||
# Each type should have correct structure
|
||||
for content_type, type_result in by_type.items():
|
||||
assert "processed" in type_result
|
||||
assert "success" in type_result
|
||||
assert "failed" in type_result
|
||||
|
||||
# Totals should aggregate
|
||||
totals = result["totals"]
|
||||
assert totals["processed"] >= 0
|
||||
assert totals["success"] >= 0
|
||||
assert totals["failed"] >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_content_handler_registry():
|
||||
"""Test all handlers are registered in correct order."""
|
||||
from prisma.enums import ContentType
|
||||
|
||||
# All three types should be registered
|
||||
assert ContentType.STORE_AGENT in CONTENT_HANDLERS
|
||||
assert ContentType.BLOCK in CONTENT_HANDLERS
|
||||
assert ContentType.DOCUMENTATION in CONTENT_HANDLERS
|
||||
|
||||
# Check handler types
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
|
||||
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
E2E tests for content handlers (blocks, store agents, documentation).
|
||||
|
||||
Tests the full flow: discovering content → generating embeddings → storing.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store.content_handlers import (
|
||||
CONTENT_HANDLERS,
|
||||
BlockHandler,
|
||||
DocumentationHandler,
|
||||
StoreAgentHandler,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_get_missing_items(mocker):
|
||||
"""Test StoreAgentHandler fetches approved agents without embeddings."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock database query
|
||||
mock_missing = [
|
||||
{
|
||||
"id": "agent-1",
|
||||
"name": "Test Agent",
|
||||
"description": "A test agent",
|
||||
"subHeading": "Test heading",
|
||||
"categories": ["AI", "Testing"],
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_missing,
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "agent-1"
|
||||
assert items[0].content_type == ContentType.STORE_AGENT
|
||||
assert "Test Agent" in items[0].searchable_text
|
||||
assert "A test agent" in items[0].searchable_text
|
||||
assert items[0].metadata["name"] == "Test Agent"
|
||||
assert items[0].user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_get_stats(mocker):
|
||||
"""Test StoreAgentHandler returns correct stats."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock approved count query
|
||||
mock_approved = [{"count": 50}]
|
||||
# Mock embedded count query
|
||||
mock_embedded = [{"count": 30}]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
side_effect=[mock_approved, mock_embedded],
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 50
|
||||
assert stats["with_embeddings"] == 30
|
||||
assert stats["without_embeddings"] == 20
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_missing_items(mocker):
|
||||
"""Test BlockHandler discovers blocks without embeddings."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock get_blocks to return test blocks
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Calculator Block"
|
||||
mock_block_instance.description = "Performs calculations"
|
||||
mock_block_instance.categories = [MagicMock(value="MATH")]
|
||||
mock_block_instance.input_schema.model_json_schema.return_value = {
|
||||
"properties": {"expression": {"description": "Math expression to evaluate"}}
|
||||
}
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-uuid-1": mock_block_class}
|
||||
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
mock_existing = []
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_existing,
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "block-uuid-1"
|
||||
assert items[0].content_type == ContentType.BLOCK
|
||||
assert "Calculator Block" in items[0].searchable_text
|
||||
assert "Performs calculations" in items[0].searchable_text
|
||||
assert "MATH" in items[0].searchable_text
|
||||
assert "expression: Math expression" in items[0].searchable_text
|
||||
assert items[0].user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_stats(mocker):
|
||||
"""Test BlockHandler returns correct stats."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock get_blocks
|
||||
mock_blocks = {
|
||||
"block-1": MagicMock(),
|
||||
"block-2": MagicMock(),
|
||||
"block-3": MagicMock(),
|
||||
}
|
||||
|
||||
# Mock embedded count query (2 blocks have embeddings)
|
||||
mock_embedded = [{"count": 2}]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_embedded,
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 3
|
||||
assert stats["with_embeddings"] == 2
|
||||
assert stats["without_embeddings"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
"""Test DocumentationHandler discovers docs without embeddings."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory with test files
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
|
||||
(docs_root / "guide.md").write_text("# Getting Started\n\nThis is a guide.")
|
||||
(docs_root / "api.mdx").write_text("# API Reference\n\nAPI documentation.")
|
||||
|
||||
# Mock _get_docs_root to return temp dir
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 2
|
||||
|
||||
# Check guide.md
|
||||
guide_item = next(
|
||||
(item for item in items if item.content_id == "guide.md"), None
|
||||
)
|
||||
assert guide_item is not None
|
||||
assert guide_item.content_type == ContentType.DOCUMENTATION
|
||||
assert "Getting Started" in guide_item.searchable_text
|
||||
assert "This is a guide" in guide_item.searchable_text
|
||||
assert guide_item.metadata["title"] == "Getting Started"
|
||||
assert guide_item.user_id is None
|
||||
|
||||
# Check api.mdx
|
||||
api_item = next(
|
||||
(item for item in items if item.content_id == "api.mdx"), None
|
||||
)
|
||||
assert api_item is not None
|
||||
assert "API Reference" in api_item.searchable_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_get_stats(tmp_path, mocker):
|
||||
"""Test DocumentationHandler returns correct stats."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
(docs_root / "doc1.md").write_text("# Doc 1")
|
||||
(docs_root / "doc2.md").write_text("# Doc 2")
|
||||
(docs_root / "doc3.mdx").write_text("# Doc 3")
|
||||
|
||||
# Mock embedded count query (1 doc has embedding)
|
||||
mock_embedded = [{"count": 1}]
|
||||
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_embedded,
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 3
|
||||
assert stats["with_embeddings"] == 1
|
||||
assert stats["without_embeddings"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_title_extraction(tmp_path):
|
||||
"""Test DocumentationHandler extracts title from markdown heading."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test with heading
|
||||
doc_with_heading = tmp_path / "with_heading.md"
|
||||
doc_with_heading.write_text("# My Title\n\nContent here")
|
||||
title, content = handler._extract_title_and_content(doc_with_heading)
|
||||
assert title == "My Title"
|
||||
assert "# My Title" not in content
|
||||
assert "Content here" in content
|
||||
|
||||
# Test without heading
|
||||
doc_without_heading = tmp_path / "no-heading.md"
|
||||
doc_without_heading.write_text("Just content, no heading")
|
||||
title, content = handler._extract_title_and_content(doc_without_heading)
|
||||
assert title == "No Heading" # Uses filename
|
||||
assert "Just content" in content
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_content_handlers_registry():
|
||||
"""Test all content types are registered."""
|
||||
assert ContentType.STORE_AGENT in CONTENT_HANDLERS
|
||||
assert ContentType.BLOCK in CONTENT_HANDLERS
|
||||
assert ContentType.DOCUMENTATION in CONTENT_HANDLERS
|
||||
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_missing_attributes():
|
||||
"""Test BlockHandler gracefully handles blocks with missing attributes."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock block with minimal attributes
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Minimal Block"
|
||||
# No description, categories, or schema
|
||||
del mock_block_instance.description
|
||||
del mock_block_instance.categories
|
||||
del mock_block_instance.input_schema
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-minimal": mock_block_class}
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].searchable_text == "Minimal Block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_skips_failed_blocks():
|
||||
"""Test BlockHandler skips blocks that fail to instantiate."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock one good block and one bad block
|
||||
good_block = MagicMock()
|
||||
good_instance = MagicMock()
|
||||
good_instance.name = "Good Block"
|
||||
good_instance.description = "Works fine"
|
||||
good_instance.categories = []
|
||||
good_block.return_value = good_instance
|
||||
|
||||
bad_block = MagicMock()
|
||||
bad_block.side_effect = Exception("Instantiation failed")
|
||||
|
||||
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
# Should only get the good block
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "good-block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_missing_docs_directory():
|
||||
"""Test DocumentationHandler handles missing docs directory gracefully."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Mock _get_docs_root to return non-existent path
|
||||
fake_path = Path("/nonexistent/docs")
|
||||
with patch.object(handler, "_get_docs_root", return_value=fake_path):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
assert items == []
|
||||
|
||||
stats = await handler.get_stats()
|
||||
assert stats["total"] == 0
|
||||
assert stats["with_embeddings"] == 0
|
||||
assert stats["without_embeddings"] == 0
|
||||
@@ -1,8 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import fastapi
|
||||
import prisma.enums
|
||||
@@ -10,7 +9,7 @@ import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
from backend.data.db import query_raw_with_schema, transaction
|
||||
from backend.data.db import transaction
|
||||
from backend.data.graph import (
|
||||
GraphMeta,
|
||||
GraphModel,
|
||||
@@ -30,6 +29,8 @@ from backend.util.settings import Settings
|
||||
|
||||
from . import exceptions as store_exceptions
|
||||
from . import model as store_model
|
||||
from .embeddings import ensure_embedding
|
||||
from .hybrid_search import hybrid_search
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
@@ -50,128 +51,77 @@ async def get_store_agents(
|
||||
page_size: int = 20,
|
||||
) -> store_model.StoreAgentsResponse:
|
||||
"""
|
||||
Get PUBLIC store agents from the StoreAgent view
|
||||
Get PUBLIC store agents from the StoreAgent view.
|
||||
|
||||
Search behavior:
|
||||
- With search_query: Uses hybrid search (semantic + lexical)
|
||||
- Fallback: If embeddings unavailable, gracefully degrades to lexical-only
|
||||
- Rationale: User-facing endpoint prioritizes availability over accuracy
|
||||
|
||||
Note: Admin operations (approval) use fail-fast to prevent inconsistent state.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
||||
)
|
||||
|
||||
search_used_hybrid = False
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
agents: list[dict[str, Any]] = []
|
||||
total = 0
|
||||
total_pages = 0
|
||||
|
||||
try:
|
||||
# If search_query is provided, use full-text search
|
||||
# If search_query is provided, use hybrid search (embeddings + tsvector)
|
||||
if search_query:
|
||||
offset = (page - 1) * page_size
|
||||
# Try hybrid search combining semantic and lexical signals
|
||||
# Falls back to lexical-only if OpenAI unavailable (user-facing, high SLA)
|
||||
try:
|
||||
agents, total = await hybrid_search(
|
||||
query=search_query,
|
||||
featured=featured,
|
||||
creators=creators,
|
||||
category=category,
|
||||
sorted_by="relevance", # Use hybrid scoring for relevance
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
search_used_hybrid = True
|
||||
except Exception as e:
|
||||
# Log error but fall back to lexical search for better UX
|
||||
logger.error(
|
||||
f"Hybrid search failed (likely OpenAI unavailable), "
|
||||
f"falling back to lexical search: {e}"
|
||||
)
|
||||
# search_used_hybrid remains False, will use fallback path below
|
||||
|
||||
# Whitelist allowed order_by columns
|
||||
ALLOWED_ORDER_BY = {
|
||||
"rating": "rating DESC, rank DESC",
|
||||
"runs": "runs DESC, rank DESC",
|
||||
"name": "agent_name ASC, rank ASC",
|
||||
"updated_at": "updated_at DESC, rank DESC",
|
||||
}
|
||||
# Convert hybrid search results (dict format) if hybrid succeeded
|
||||
if search_used_hybrid:
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
try:
|
||||
store_agent = store_model.StoreAgent(
|
||||
slug=agent["slug"],
|
||||
agent_name=agent["agent_name"],
|
||||
agent_image=(
|
||||
agent["agent_image"][0] if agent["agent_image"] else ""
|
||||
),
|
||||
creator=agent["creator_username"] or "Needs Profile",
|
||||
creator_avatar=agent["creator_avatar"] or "",
|
||||
sub_heading=agent["sub_heading"],
|
||||
description=agent["description"],
|
||||
runs=agent["runs"],
|
||||
rating=agent["rating"],
|
||||
)
|
||||
store_agents.append(store_agent)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error parsing Store agent from hybrid search results: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Validate and get order clause
|
||||
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
|
||||
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
|
||||
else:
|
||||
order_by_clause = "updated_at DESC, rank DESC"
|
||||
|
||||
# Build WHERE conditions and parameters list
|
||||
where_parts: list[str] = []
|
||||
params: list[typing.Any] = [search_query] # $1 - search term
|
||||
param_index = 2 # Start at $2 for next parameter
|
||||
|
||||
# Always filter for available agents
|
||||
where_parts.append("is_available = true")
|
||||
|
||||
if featured:
|
||||
where_parts.append("featured = true")
|
||||
|
||||
if creators and creators:
|
||||
# Use ANY with array parameter
|
||||
where_parts.append(f"creator_username = ANY(${param_index})")
|
||||
params.append(creators)
|
||||
param_index += 1
|
||||
|
||||
if category and category:
|
||||
where_parts.append(f"${param_index} = ANY(categories)")
|
||||
params.append(category)
|
||||
param_index += 1
|
||||
|
||||
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
|
||||
|
||||
# Add pagination params
|
||||
params.extend([page_size, offset])
|
||||
limit_param = f"${param_index}"
|
||||
offset_param = f"${param_index + 1}"
|
||||
|
||||
# Execute full-text search query with parameterized values
|
||||
sql_query = f"""
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
ts_rank_cd(search, query) AS rank
|
||||
FROM {{schema_prefix}}"StoreAgent",
|
||||
plainto_tsquery('english', $1) AS query
|
||||
WHERE {sql_where_clause}
|
||||
AND search @@ query
|
||||
ORDER BY {order_by_clause}
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
# Count query for pagination - only uses search term parameter
|
||||
count_query = f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {{schema_prefix}}"StoreAgent",
|
||||
plainto_tsquery('english', $1) AS query
|
||||
WHERE {sql_where_clause}
|
||||
AND search @@ query
|
||||
"""
|
||||
|
||||
# Execute both queries with parameters
|
||||
agents = await query_raw_with_schema(sql_query, *params)
|
||||
|
||||
# For count, use params without pagination (last 2 params)
|
||||
count_params = params[:-2]
|
||||
count_result = await query_raw_with_schema(count_query, *count_params)
|
||||
|
||||
total = count_result[0]["count"] if count_result else 0
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
# Convert raw results to StoreAgent models
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
try:
|
||||
store_agent = store_model.StoreAgent(
|
||||
slug=agent["slug"],
|
||||
agent_name=agent["agent_name"],
|
||||
agent_image=(
|
||||
agent["agent_image"][0] if agent["agent_image"] else ""
|
||||
),
|
||||
creator=agent["creator_username"] or "Needs Profile",
|
||||
creator_avatar=agent["creator_avatar"] or "",
|
||||
sub_heading=agent["sub_heading"],
|
||||
description=agent["description"],
|
||||
runs=agent["runs"],
|
||||
rating=agent["rating"],
|
||||
)
|
||||
store_agents.append(store_agent)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing Store agent from search results: {e}")
|
||||
continue
|
||||
|
||||
else:
|
||||
# Non-search query path (original logic)
|
||||
if not search_used_hybrid:
|
||||
# Fallback path - use basic search or no search
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
@@ -180,6 +130,14 @@ async def get_store_agents(
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
# Add basic text search if search_query provided but hybrid failed
|
||||
if search_query:
|
||||
where_clause["OR"] = [
|
||||
{"agent_name": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"sub_heading": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"description": {"contains": search_query, "mode": "insensitive"}},
|
||||
]
|
||||
|
||||
order_by = []
|
||||
if sorted_by == "rating":
|
||||
order_by.append({"rating": "desc"})
|
||||
@@ -188,7 +146,7 @@ async def get_store_agents(
|
||||
elif sorted_by == "name":
|
||||
order_by.append({"agent_name": "asc"})
|
||||
|
||||
agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
db_agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
@@ -199,7 +157,7 @@ async def get_store_agents(
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
for agent in db_agents:
|
||||
try:
|
||||
# Create the StoreAgent object safely
|
||||
store_agent = store_model.StoreAgent(
|
||||
@@ -1577,7 +1535,7 @@ async def review_store_submission(
|
||||
)
|
||||
|
||||
# Update the AgentGraph with store listing data
|
||||
await prisma.models.AgentGraph.prisma().update(
|
||||
await prisma.models.AgentGraph.prisma(tx).update(
|
||||
where={
|
||||
"graphVersionId": {
|
||||
"id": store_listing_version.agentGraphId,
|
||||
@@ -1592,6 +1550,23 @@ async def review_store_submission(
|
||||
},
|
||||
)
|
||||
|
||||
# Generate embedding for approved listing (blocking - admin operation)
|
||||
# Inside transaction: if embedding fails, entire transaction rolls back
|
||||
embedding_success = await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=store_listing_version.name,
|
||||
description=store_listing_version.description,
|
||||
sub_heading=store_listing_version.subHeading,
|
||||
categories=store_listing_version.categories or [],
|
||||
tx=tx,
|
||||
)
|
||||
if not embedding_success:
|
||||
raise ValueError(
|
||||
f"Failed to generate embedding for listing {store_listing_version_id}. "
|
||||
"This is likely due to OpenAI API being unavailable. "
|
||||
"Please try again later or contact support if the issue persists."
|
||||
)
|
||||
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": store_listing_version.StoreListing.id},
|
||||
data={
|
||||
|
||||
@@ -0,0 +1,628 @@
|
||||
"""
|
||||
Unified Content Embeddings Service
|
||||
|
||||
Handles generation and storage of OpenAI embeddings for all content types
|
||||
(store listings, blocks, documentation, library agents) to enable semantic/hybrid search.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import prisma
|
||||
from prisma.enums import ContentType
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from backend.api.features.store.content_handlers import CONTENT_HANDLERS
|
||||
from backend.data.db import execute_raw_with_schema, query_raw_with_schema
|
||||
from backend.util.clients import get_openai_client
|
||||
from backend.util.json import dumps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# OpenAI embedding model configuration
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
# OpenAI embedding token limit (8,191 with 1 token buffer for safety)
|
||||
EMBEDDING_MAX_TOKENS = 8191
|
||||
|
||||
|
||||
def build_searchable_text(
|
||||
name: str,
|
||||
description: str,
|
||||
sub_heading: str,
|
||||
categories: list[str],
|
||||
) -> str:
|
||||
"""
|
||||
Build searchable text from listing version fields.
|
||||
|
||||
Combines relevant fields into a single string for embedding.
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# Name is important - include it
|
||||
if name:
|
||||
parts.append(name)
|
||||
|
||||
# Sub-heading provides context
|
||||
if sub_heading:
|
||||
parts.append(sub_heading)
|
||||
|
||||
# Description is the main content
|
||||
if description:
|
||||
parts.append(description)
|
||||
|
||||
# Categories help with semantic matching
|
||||
if categories:
|
||||
parts.append(" ".join(categories))
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
async def generate_embedding(text: str) -> list[float] | None:
|
||||
"""
|
||||
Generate embedding for text using OpenAI API.
|
||||
|
||||
Returns None if embedding generation fails.
|
||||
Fail-fast: no retries to maintain consistency with approval flow.
|
||||
"""
|
||||
try:
|
||||
client = get_openai_client()
|
||||
if not client:
|
||||
logger.error("openai_internal_api_key not set, cannot generate embedding")
|
||||
return None
|
||||
|
||||
# Truncate text to token limit using tiktoken
|
||||
# Character-based truncation is insufficient because token ratios vary by content type
|
||||
enc = encoding_for_model(EMBEDDING_MODEL)
|
||||
tokens = enc.encode(text)
|
||||
if len(tokens) > EMBEDDING_MAX_TOKENS:
|
||||
tokens = tokens[:EMBEDDING_MAX_TOKENS]
|
||||
truncated_text = enc.decode(tokens)
|
||||
logger.info(
|
||||
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
|
||||
)
|
||||
else:
|
||||
truncated_text = text
|
||||
|
||||
start_time = time.time()
|
||||
response = await client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=truncated_text,
|
||||
)
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
|
||||
embedding = response.data[0].embedding
|
||||
logger.info(
|
||||
f"Generated embedding: {len(embedding)} dims, "
|
||||
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
|
||||
)
|
||||
return embedding
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate embedding: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def store_embedding(
|
||||
version_id: str,
|
||||
embedding: list[float],
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Store embedding in the database.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||
DEPRECATED: Use ensure_embedding() instead (includes searchable_text).
|
||||
"""
|
||||
return await store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id=version_id,
|
||||
embedding=embedding,
|
||||
searchable_text="", # Empty for backward compat; ensure_embedding() populates this
|
||||
metadata=None,
|
||||
user_id=None, # Store agents are public
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
|
||||
async def store_content_embedding(
|
||||
content_type: ContentType,
|
||||
content_id: str,
|
||||
embedding: list[float],
|
||||
searchable_text: str,
|
||||
metadata: dict | None = None,
|
||||
user_id: str | None = None,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Store embedding in the unified content embeddings table.
|
||||
|
||||
New function for unified content embedding storage.
|
||||
Uses raw SQL since Prisma doesn't natively support pgvector.
|
||||
"""
|
||||
try:
|
||||
client = tx if tx else prisma.get_client()
|
||||
|
||||
# Convert embedding to PostgreSQL vector format
|
||||
embedding_str = embedding_to_vector_string(embedding)
|
||||
metadata_json = dumps(metadata or {})
|
||||
|
||||
# Upsert the embedding
|
||||
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
|
||||
await execute_raw_with_schema(
|
||||
"""
|
||||
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
|
||||
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
|
||||
)
|
||||
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
|
||||
ON CONFLICT ("contentType", "contentId", "userId")
|
||||
DO UPDATE SET
|
||||
"embedding" = $4::vector,
|
||||
"searchableText" = $5,
|
||||
"metadata" = $6::jsonb,
|
||||
"updatedAt" = NOW()
|
||||
WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
|
||||
AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
|
||||
AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
embedding_str,
|
||||
searchable_text,
|
||||
metadata_json,
|
||||
client=client,
|
||||
set_public_search_path=True,
|
||||
)
|
||||
|
||||
logger.info(f"Stored embedding for {content_type}:{content_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_embedding(version_id: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Retrieve embedding record for a listing version.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||
Returns dict with storeListingVersionId, embedding, timestamps or None if not found.
|
||||
"""
|
||||
result = await get_content_embedding(
|
||||
ContentType.STORE_AGENT, version_id, user_id=None
|
||||
)
|
||||
if result:
|
||||
# Transform to old format for backward compatibility
|
||||
return {
|
||||
"storeListingVersionId": result["contentId"],
|
||||
"embedding": result["embedding"],
|
||||
"createdAt": result["createdAt"],
|
||||
"updatedAt": result["updatedAt"],
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def get_content_embedding(
|
||||
content_type: ContentType, content_id: str, user_id: str | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Retrieve embedding record for any content type.
|
||||
|
||||
New function for unified content embedding retrieval.
|
||||
Returns dict with contentType, contentId, embedding, timestamps or None if not found.
|
||||
"""
|
||||
try:
|
||||
result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT
|
||||
"contentType",
|
||||
"contentId",
|
||||
"userId",
|
||||
"embedding"::text as "embedding",
|
||||
"searchableText",
|
||||
"metadata",
|
||||
"createdAt",
|
||||
"updatedAt"
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
set_public_search_path=True,
|
||||
)
|
||||
|
||||
if result and len(result) > 0:
|
||||
return result[0]
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def ensure_embedding(
|
||||
version_id: str,
|
||||
name: str,
|
||||
description: str,
|
||||
sub_heading: str,
|
||||
categories: list[str],
|
||||
force: bool = False,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Ensure an embedding exists for the listing version.
|
||||
|
||||
Creates embedding if missing. Use force=True to regenerate.
|
||||
Backward-compatible wrapper for store listings.
|
||||
|
||||
Args:
|
||||
version_id: The StoreListingVersion ID
|
||||
name: Agent name
|
||||
description: Agent description
|
||||
sub_heading: Agent sub-heading
|
||||
categories: Agent categories
|
||||
force: Force regeneration even if embedding exists
|
||||
tx: Optional transaction client
|
||||
|
||||
Returns:
|
||||
True if embedding exists/was created, False on failure
|
||||
"""
|
||||
try:
|
||||
# Check if embedding already exists
|
||||
if not force:
|
||||
existing = await get_embedding(version_id)
|
||||
if existing and existing.get("embedding"):
|
||||
logger.debug(f"Embedding for version {version_id} already exists")
|
||||
return True
|
||||
|
||||
# Build searchable text for embedding
|
||||
searchable_text = build_searchable_text(
|
||||
name, description, sub_heading, categories
|
||||
)
|
||||
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
if embedding is None:
|
||||
logger.warning(f"Could not generate embedding for version {version_id}")
|
||||
return False
|
||||
|
||||
# Store the embedding with metadata using new function
|
||||
metadata = {
|
||||
"name": name,
|
||||
"subHeading": sub_heading,
|
||||
"categories": categories,
|
||||
}
|
||||
return await store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id=version_id,
|
||||
embedding=embedding,
|
||||
searchable_text=searchable_text,
|
||||
metadata=metadata,
|
||||
user_id=None, # Store agents are public
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def delete_embedding(version_id: str) -> bool:
|
||||
"""
|
||||
Delete embedding for a listing version.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||
Note: This is usually handled automatically by CASCADE delete,
|
||||
but provided for manual cleanup if needed.
|
||||
"""
|
||||
return await delete_content_embedding(ContentType.STORE_AGENT, version_id)
|
||||
|
||||
|
||||
async def delete_content_embedding(
|
||||
content_type: ContentType, content_id: str, user_id: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Delete embedding for any content type.
|
||||
|
||||
New function for unified content embedding deletion.
|
||||
Note: This is usually handled automatically by CASCADE delete,
|
||||
but provided for manual cleanup if needed.
|
||||
|
||||
Args:
|
||||
content_type: The type of content (STORE_AGENT, LIBRARY_AGENT, etc.)
|
||||
content_id: The unique identifier for the content
|
||||
user_id: Optional user ID. For public content (STORE_AGENT, BLOCK), pass None.
|
||||
For user-scoped content (LIBRARY_AGENT), pass the user's ID to avoid
|
||||
deleting embeddings belonging to other users.
|
||||
|
||||
Returns:
|
||||
True if deletion succeeded, False otherwise
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
await execute_raw_with_schema(
|
||||
"""
|
||||
DELETE FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = $1::{schema_prefix}"ContentType"
|
||||
AND "contentId" = $2
|
||||
AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
client=client,
|
||||
)
|
||||
|
||||
user_str = f" (user: {user_id})" if user_id else ""
|
||||
logger.info(f"Deleted embedding for {content_type}:{content_id}{user_str}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete embedding for {content_type}:{content_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_embedding_stats() -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics about embedding coverage for all content types.
|
||||
|
||||
Returns stats per content type and overall totals.
|
||||
"""
|
||||
try:
|
||||
stats_by_type = {}
|
||||
total_items = 0
|
||||
total_with_embeddings = 0
|
||||
total_without_embeddings = 0
|
||||
|
||||
# Aggregate stats from all handlers
|
||||
for content_type, handler in CONTENT_HANDLERS.items():
|
||||
try:
|
||||
stats = await handler.get_stats()
|
||||
stats_by_type[content_type.value] = {
|
||||
"total": stats["total"],
|
||||
"with_embeddings": stats["with_embeddings"],
|
||||
"without_embeddings": stats["without_embeddings"],
|
||||
"coverage_percent": (
|
||||
round(stats["with_embeddings"] / stats["total"] * 100, 1)
|
||||
if stats["total"] > 0
|
||||
else 0
|
||||
),
|
||||
}
|
||||
|
||||
total_items += stats["total"]
|
||||
total_with_embeddings += stats["with_embeddings"]
|
||||
total_without_embeddings += stats["without_embeddings"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get stats for {content_type.value}: {e}")
|
||||
stats_by_type[content_type.value] = {
|
||||
"total": 0,
|
||||
"with_embeddings": 0,
|
||||
"without_embeddings": 0,
|
||||
"coverage_percent": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
return {
|
||||
"by_type": stats_by_type,
|
||||
"totals": {
|
||||
"total": total_items,
|
||||
"with_embeddings": total_with_embeddings,
|
||||
"without_embeddings": total_without_embeddings,
|
||||
"coverage_percent": (
|
||||
round(total_with_embeddings / total_items * 100, 1)
|
||||
if total_items > 0
|
||||
else 0
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embedding stats: {e}")
|
||||
return {
|
||||
"by_type": {},
|
||||
"totals": {
|
||||
"total": 0,
|
||||
"with_embeddings": 0,
|
||||
"without_embeddings": 0,
|
||||
"coverage_percent": 0,
|
||||
},
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]:
|
||||
"""
|
||||
Generate embeddings for approved listings that don't have them.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing usage.
|
||||
This now delegates to backfill_all_content_types() to process all content types.
|
||||
|
||||
Args:
|
||||
batch_size: Number of embeddings to generate per content type
|
||||
|
||||
Returns:
|
||||
Dict with success/failure counts aggregated across all content types
|
||||
"""
|
||||
# Delegate to the new generic backfill system
|
||||
result = await backfill_all_content_types(batch_size)
|
||||
|
||||
# Return in the old format for backward compatibility
|
||||
return result["totals"]
|
||||
|
||||
|
||||
async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
|
||||
"""
|
||||
Generate embeddings for all content types using registered handlers.
|
||||
|
||||
Processes content types in order: BLOCK → STORE_AGENT → DOCUMENTATION.
|
||||
This ensures foundational content (blocks) are searchable first.
|
||||
|
||||
Args:
|
||||
batch_size: Number of embeddings to generate per content type
|
||||
|
||||
Returns:
|
||||
Dict with stats per content type and overall totals
|
||||
"""
|
||||
results_by_type = {}
|
||||
total_processed = 0
|
||||
total_success = 0
|
||||
total_failed = 0
|
||||
|
||||
# Process content types in explicit order
|
||||
processing_order = [
|
||||
ContentType.BLOCK,
|
||||
ContentType.STORE_AGENT,
|
||||
ContentType.DOCUMENTATION,
|
||||
]
|
||||
|
||||
for content_type in processing_order:
|
||||
handler = CONTENT_HANDLERS.get(content_type)
|
||||
if not handler:
|
||||
logger.warning(f"No handler registered for {content_type.value}")
|
||||
continue
|
||||
try:
|
||||
logger.info(f"Processing {content_type.value} content type...")
|
||||
|
||||
# Get missing items from handler
|
||||
missing_items = await handler.get_missing_items(batch_size)
|
||||
|
||||
if not missing_items:
|
||||
results_by_type[content_type.value] = {
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"message": "No missing embeddings",
|
||||
}
|
||||
continue
|
||||
|
||||
# Process embeddings concurrently for better performance
|
||||
embedding_tasks = [
|
||||
ensure_content_embedding(
|
||||
content_type=item.content_type,
|
||||
content_id=item.content_id,
|
||||
searchable_text=item.searchable_text,
|
||||
metadata=item.metadata,
|
||||
user_id=item.user_id,
|
||||
)
|
||||
for item in missing_items
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*embedding_tasks, return_exceptions=True)
|
||||
|
||||
success = sum(1 for result in results if result is True)
|
||||
failed = len(results) - success
|
||||
|
||||
results_by_type[content_type.value] = {
|
||||
"processed": len(missing_items),
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"message": f"Backfilled {success} embeddings, {failed} failed",
|
||||
}
|
||||
|
||||
total_processed += len(missing_items)
|
||||
total_success += success
|
||||
total_failed += failed
|
||||
|
||||
logger.info(
|
||||
f"{content_type.value}: processed {len(missing_items)}, "
|
||||
f"success {success}, failed {failed}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process {content_type.value}: {e}")
|
||||
results_by_type[content_type.value] = {
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
return {
|
||||
"by_type": results_by_type,
|
||||
"totals": {
|
||||
"processed": total_processed,
|
||||
"success": total_success,
|
||||
"failed": total_failed,
|
||||
"message": f"Overall: {total_success} succeeded, {total_failed} failed",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def embed_query(query: str) -> list[float] | None:
|
||||
"""
|
||||
Generate embedding for a search query.
|
||||
|
||||
Same as generate_embedding but with clearer intent.
|
||||
"""
|
||||
return await generate_embedding(query)
|
||||
|
||||
|
||||
def embedding_to_vector_string(embedding: list[float]) -> str:
|
||||
"""Convert embedding list to PostgreSQL vector string format."""
|
||||
return "[" + ",".join(str(x) for x in embedding) + "]"
|
||||
|
||||
|
||||
async def ensure_content_embedding(
|
||||
content_type: ContentType,
|
||||
content_id: str,
|
||||
searchable_text: str,
|
||||
metadata: dict | None = None,
|
||||
user_id: str | None = None,
|
||||
force: bool = False,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Ensure an embedding exists for any content type.
|
||||
|
||||
Generic function for creating embeddings for store agents, blocks, docs, etc.
|
||||
|
||||
Args:
|
||||
content_type: ContentType enum value (STORE_AGENT, BLOCK, etc.)
|
||||
content_id: Unique identifier for the content
|
||||
searchable_text: Combined text for embedding generation
|
||||
metadata: Optional metadata to store with embedding
|
||||
force: Force regeneration even if embedding exists
|
||||
tx: Optional transaction client
|
||||
|
||||
Returns:
|
||||
True if embedding exists/was created, False on failure
|
||||
"""
|
||||
try:
|
||||
# Check if embedding already exists
|
||||
if not force:
|
||||
existing = await get_content_embedding(content_type, content_id, user_id)
|
||||
if existing and existing.get("embedding"):
|
||||
logger.debug(
|
||||
f"Embedding for {content_type}:{content_id} already exists"
|
||||
)
|
||||
return True
|
||||
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
if embedding is None:
|
||||
logger.warning(
|
||||
f"Could not generate embedding for {content_type}:{content_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Store the embedding
|
||||
return await store_content_embedding(
|
||||
content_type=content_type,
|
||||
content_id=content_id,
|
||||
embedding=embedding,
|
||||
searchable_text=searchable_text,
|
||||
metadata=metadata or {},
|
||||
user_id=user_id,
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
|
||||
return False
|
||||
@@ -0,0 +1,329 @@
|
||||
"""
|
||||
Integration tests for embeddings with schema handling.
|
||||
|
||||
These tests verify that embeddings operations work correctly across different database schemas.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store import embeddings
|
||||
|
||||
# Schema prefix tests removed - functionality moved to db.raw_with_schema() helper
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_store_content_embedding_with_schema():
|
||||
"""Test storing embeddings with proper schema handling."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch("prisma.get_client") as mock_get_client:
|
||||
mock_client = AsyncMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id="test-id",
|
||||
embedding=[0.1] * 1536,
|
||||
searchable_text="test text",
|
||||
metadata={"test": "data"},
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Verify the query was called
|
||||
assert mock_client.execute_raw.called
|
||||
|
||||
# Get the SQL query that was executed
|
||||
call_args = mock_client.execute_raw.call_args
|
||||
sql_query = call_args[0][0]
|
||||
|
||||
# Verify schema prefix is in the query
|
||||
assert '"platform"."UnifiedContentEmbedding"' in sql_query
|
||||
|
||||
# Verify result
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_get_content_embedding_with_schema():
|
||||
"""Test retrieving embeddings with proper schema handling."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch("prisma.get_client") as mock_get_client:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.query_raw.return_value = [
|
||||
{
|
||||
"contentType": "STORE_AGENT",
|
||||
"contentId": "test-id",
|
||||
"userId": None,
|
||||
"embedding": "[0.1, 0.2]",
|
||||
"searchableText": "test",
|
||||
"metadata": {},
|
||||
"createdAt": "2024-01-01",
|
||||
"updatedAt": "2024-01-01",
|
||||
}
|
||||
]
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.get_content_embedding(
|
||||
ContentType.STORE_AGENT,
|
||||
"test-id",
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Verify the query was called
|
||||
assert mock_client.query_raw.called
|
||||
|
||||
# Get the SQL query that was executed
|
||||
call_args = mock_client.query_raw.call_args
|
||||
sql_query = call_args[0][0]
|
||||
|
||||
# Verify schema prefix is in the query
|
||||
assert '"platform"."UnifiedContentEmbedding"' in sql_query
|
||||
|
||||
# Verify result
|
||||
assert result is not None
|
||||
assert result["contentId"] == "test-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_delete_content_embedding_with_schema():
|
||||
"""Test deleting embeddings with proper schema handling."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch("prisma.get_client") as mock_get_client:
|
||||
mock_client = AsyncMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.delete_content_embedding(
|
||||
ContentType.STORE_AGENT,
|
||||
"test-id",
|
||||
)
|
||||
|
||||
# Verify the query was called
|
||||
assert mock_client.execute_raw.called
|
||||
|
||||
# Get the SQL query that was executed
|
||||
call_args = mock_client.execute_raw.call_args
|
||||
sql_query = call_args[0][0]
|
||||
|
||||
# Verify schema prefix is in the query
|
||||
assert '"platform"."UnifiedContentEmbedding"' in sql_query
|
||||
|
||||
# Verify result
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_get_embedding_stats_with_schema():
|
||||
"""Test embedding statistics with proper schema handling."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch("prisma.get_client") as mock_get_client:
|
||||
mock_client = AsyncMock()
|
||||
# Mock both query results
|
||||
mock_client.query_raw.side_effect = [
|
||||
[{"count": 100}], # total_approved
|
||||
[{"count": 80}], # with_embeddings
|
||||
]
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.get_embedding_stats()
|
||||
|
||||
# Verify both queries were called
|
||||
assert mock_client.query_raw.call_count == 2
|
||||
|
||||
# Get both SQL queries
|
||||
first_call = mock_client.query_raw.call_args_list[0]
|
||||
second_call = mock_client.query_raw.call_args_list[1]
|
||||
|
||||
first_sql = first_call[0][0]
|
||||
second_sql = second_call[0][0]
|
||||
|
||||
# Verify schema prefix in both queries
|
||||
assert '"platform"."StoreListingVersion"' in first_sql
|
||||
assert '"platform"."StoreListingVersion"' in second_sql
|
||||
assert '"platform"."UnifiedContentEmbedding"' in second_sql
|
||||
|
||||
# Verify results
|
||||
assert result["total_approved"] == 100
|
||||
assert result["with_embeddings"] == 80
|
||||
assert result["without_embeddings"] == 20
|
||||
assert result["coverage_percent"] == 80.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_backfill_missing_embeddings_with_schema():
|
||||
"""Test backfilling embeddings with proper schema handling."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch("prisma.get_client") as mock_get_client:
|
||||
mock_client = AsyncMock()
|
||||
# Mock missing embeddings query
|
||||
mock_client.query_raw.return_value = [
|
||||
{
|
||||
"id": "version-1",
|
||||
"name": "Test Agent",
|
||||
"description": "Test description",
|
||||
"subHeading": "Test heading",
|
||||
"categories": ["test"],
|
||||
}
|
||||
]
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.ensure_embedding"
|
||||
) as mock_ensure:
|
||||
mock_ensure.return_value = True
|
||||
|
||||
result = await embeddings.backfill_missing_embeddings(batch_size=10)
|
||||
|
||||
# Verify the query was called
|
||||
assert mock_client.query_raw.called
|
||||
|
||||
# Get the SQL query
|
||||
call_args = mock_client.query_raw.call_args
|
||||
sql_query = call_args[0][0]
|
||||
|
||||
# Verify schema prefix in query
|
||||
assert '"platform"."StoreListingVersion"' in sql_query
|
||||
assert '"platform"."UnifiedContentEmbedding"' in sql_query
|
||||
|
||||
# Verify ensure_embedding was called
|
||||
assert mock_ensure.called
|
||||
|
||||
# Verify results
|
||||
assert result["processed"] == 1
|
||||
assert result["success"] == 1
|
||||
assert result["failed"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_ensure_content_embedding_with_schema():
|
||||
"""Test ensuring embeddings exist with proper schema handling."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_content_embedding"
|
||||
) as mock_get:
|
||||
# Simulate no existing embedding
|
||||
mock_get.return_value = None
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.generate_embedding"
|
||||
) as mock_generate:
|
||||
mock_generate.return_value = [0.1] * 1536
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.store_content_embedding"
|
||||
) as mock_store:
|
||||
mock_store.return_value = True
|
||||
|
||||
result = await embeddings.ensure_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id="test-id",
|
||||
searchable_text="test text",
|
||||
metadata={"test": "data"},
|
||||
user_id=None,
|
||||
force=False,
|
||||
)
|
||||
|
||||
# Verify the flow
|
||||
assert mock_get.called
|
||||
assert mock_generate.called
|
||||
assert mock_store.called
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_backward_compatibility_store_embedding():
|
||||
"""Test backward compatibility wrapper for store_embedding."""
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.store_content_embedding"
|
||||
) as mock_store:
|
||||
mock_store.return_value = True
|
||||
|
||||
result = await embeddings.store_embedding(
|
||||
version_id="test-version-id",
|
||||
embedding=[0.1] * 1536,
|
||||
tx=None,
|
||||
)
|
||||
|
||||
# Verify it calls the new function with correct parameters
|
||||
assert mock_store.called
|
||||
call_args = mock_store.call_args
|
||||
|
||||
assert call_args[1]["content_type"] == ContentType.STORE_AGENT
|
||||
assert call_args[1]["content_id"] == "test-version-id"
|
||||
assert call_args[1]["user_id"] is None
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_backward_compatibility_get_embedding():
|
||||
"""Test backward compatibility wrapper for get_embedding."""
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_content_embedding"
|
||||
) as mock_get:
|
||||
mock_get.return_value = {
|
||||
"contentType": "STORE_AGENT",
|
||||
"contentId": "test-version-id",
|
||||
"embedding": "[0.1, 0.2]",
|
||||
"createdAt": "2024-01-01",
|
||||
"updatedAt": "2024-01-01",
|
||||
}
|
||||
|
||||
result = await embeddings.get_embedding("test-version-id")
|
||||
|
||||
# Verify it calls the new function
|
||||
assert mock_get.called
|
||||
|
||||
# Verify it transforms to old format
|
||||
assert result is not None
|
||||
assert result["storeListingVersionId"] == "test-version-id"
|
||||
assert "embedding" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_schema_handling_error_cases():
|
||||
"""Test error handling in schema-aware operations."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch("prisma.get_client") as mock_get_client:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.execute_raw.side_effect = Exception("Database error")
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id="test-id",
|
||||
embedding=[0.1] * 1536,
|
||||
searchable_text="test",
|
||||
metadata=None,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Should return False on error, not raise
|
||||
assert result is False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
@@ -0,0 +1,387 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import prisma
|
||||
import pytest
|
||||
from prisma import Prisma
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store import embeddings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_prisma():
|
||||
"""Setup Prisma client for tests."""
|
||||
try:
|
||||
Prisma()
|
||||
except prisma.errors.ClientAlreadyRegisteredError:
|
||||
pass
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_build_searchable_text():
|
||||
"""Test searchable text building from listing fields."""
|
||||
result = embeddings.build_searchable_text(
|
||||
name="AI Assistant",
|
||||
description="A helpful AI assistant for productivity",
|
||||
sub_heading="Boost your productivity",
|
||||
categories=["AI", "Productivity"],
|
||||
)
|
||||
|
||||
expected = "AI Assistant Boost your productivity A helpful AI assistant for productivity AI Productivity"
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_build_searchable_text_empty_fields():
|
||||
"""Test searchable text building with empty fields."""
|
||||
result = embeddings.build_searchable_text(
|
||||
name="", description="Test description", sub_heading="", categories=[]
|
||||
)
|
||||
|
||||
assert result == "Test description"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_generate_embedding_success():
|
||||
"""Test successful embedding generation."""
|
||||
# Mock OpenAI response
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [MagicMock()]
|
||||
mock_response.data[0].embedding = [0.1, 0.2, 0.3] * 512 # 1536 dimensions
|
||||
|
||||
# Use AsyncMock for async embeddings.create method
|
||||
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Patch at the point of use in embeddings.py
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_openai_client"
|
||||
) as mock_get_client:
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.generate_embedding("test text")
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 1536
|
||||
assert result[0] == 0.1
|
||||
|
||||
mock_client.embeddings.create.assert_called_once_with(
|
||||
model="text-embedding-3-small", input="test text"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_generate_embedding_no_api_key():
|
||||
"""Test embedding generation without API key."""
|
||||
# Patch at the point of use in embeddings.py
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_openai_client"
|
||||
) as mock_get_client:
|
||||
mock_get_client.return_value = None
|
||||
|
||||
result = await embeddings.generate_embedding("test text")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_generate_embedding_api_error():
|
||||
"""Test embedding generation with API error."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.embeddings.create = AsyncMock(side_effect=Exception("API Error"))
|
||||
|
||||
# Patch at the point of use in embeddings.py
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_openai_client"
|
||||
) as mock_get_client:
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.generate_embedding("test text")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_generate_embedding_text_truncation():
|
||||
"""Test that long text is properly truncated using tiktoken."""
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [MagicMock()]
|
||||
mock_response.data[0].embedding = [0.1] * 1536
|
||||
|
||||
# Use AsyncMock for async embeddings.create method
|
||||
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Patch at the point of use in embeddings.py
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_openai_client"
|
||||
) as mock_get_client:
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# Create text that will exceed 8191 tokens
|
||||
# Use varied characters to ensure token-heavy text: each word is ~1 token
|
||||
words = [f"word{i}" for i in range(10000)]
|
||||
long_text = " ".join(words) # ~10000 tokens
|
||||
|
||||
await embeddings.generate_embedding(long_text)
|
||||
|
||||
# Verify text was truncated to 8191 tokens
|
||||
call_args = mock_client.embeddings.create.call_args
|
||||
truncated_text = call_args.kwargs["input"]
|
||||
|
||||
# Count actual tokens in truncated text
|
||||
enc = encoding_for_model("text-embedding-3-small")
|
||||
actual_tokens = len(enc.encode(truncated_text))
|
||||
|
||||
# Should be at or just under 8191 tokens
|
||||
assert actual_tokens <= 8191
|
||||
# Should be close to the limit (not over-truncated)
|
||||
assert actual_tokens >= 8100
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_embedding_success(mocker):
|
||||
"""Test successful embedding storage."""
|
||||
mock_client = mocker.AsyncMock()
|
||||
mock_client.execute_raw = mocker.AsyncMock()
|
||||
|
||||
embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
result = await embeddings.store_embedding(
|
||||
version_id="test-version-id", embedding=embedding, tx=mock_client
|
||||
)
|
||||
|
||||
assert result is True
|
||||
# execute_raw is called twice: once for SET search_path, once for INSERT
|
||||
assert mock_client.execute_raw.call_count == 2
|
||||
|
||||
# First call: SET search_path
|
||||
first_call_args = mock_client.execute_raw.call_args_list[0][0]
|
||||
assert "SET search_path" in first_call_args[0]
|
||||
|
||||
# Second call: INSERT query with the actual data
|
||||
second_call_args = mock_client.execute_raw.call_args_list[1][0]
|
||||
assert "test-version-id" in second_call_args
|
||||
assert "[0.1,0.2,0.3]" in second_call_args
|
||||
assert None in second_call_args # userId should be None for store agents
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_embedding_database_error(mocker):
|
||||
"""Test embedding storage with database error."""
|
||||
mock_client = mocker.AsyncMock()
|
||||
mock_client.execute_raw.side_effect = Exception("Database error")
|
||||
|
||||
embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
result = await embeddings.store_embedding(
|
||||
version_id="test-version-id", embedding=embedding, tx=mock_client
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_embedding_success():
|
||||
"""Test successful embedding retrieval."""
|
||||
mock_result = [
|
||||
{
|
||||
"contentType": "STORE_AGENT",
|
||||
"contentId": "test-version-id",
|
||||
"userId": None,
|
||||
"embedding": "[0.1,0.2,0.3]",
|
||||
"searchableText": "Test text",
|
||||
"metadata": {},
|
||||
"createdAt": "2024-01-01T00:00:00Z",
|
||||
"updatedAt": "2024-01-01T00:00:00Z",
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await embeddings.get_embedding("test-version-id")
|
||||
|
||||
assert result is not None
|
||||
assert result["storeListingVersionId"] == "test-version-id"
|
||||
assert result["embedding"] == "[0.1,0.2,0.3]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_embedding_not_found():
|
||||
"""Test embedding retrieval when not found."""
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
result = await embeddings.get_embedding("test-version-id")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
@patch("backend.api.features.store.embeddings.store_embedding")
|
||||
@patch("backend.api.features.store.embeddings.get_embedding")
|
||||
async def test_ensure_embedding_already_exists(mock_get, mock_store, mock_generate):
|
||||
"""Test ensure_embedding when embedding already exists."""
|
||||
mock_get.return_value = {"embedding": "[0.1,0.2,0.3]"}
|
||||
|
||||
result = await embeddings.ensure_embedding(
|
||||
version_id="test-id",
|
||||
name="Test",
|
||||
description="Test description",
|
||||
sub_heading="Test heading",
|
||||
categories=["test"],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_generate.assert_not_called()
|
||||
mock_store.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
@patch("backend.api.features.store.embeddings.store_content_embedding")
|
||||
@patch("backend.api.features.store.embeddings.get_embedding")
|
||||
async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate):
|
||||
"""Test ensure_embedding creating new embedding."""
|
||||
mock_get.return_value = None
|
||||
mock_generate.return_value = [0.1, 0.2, 0.3]
|
||||
mock_store.return_value = True
|
||||
|
||||
result = await embeddings.ensure_embedding(
|
||||
version_id="test-id",
|
||||
name="Test",
|
||||
description="Test description",
|
||||
sub_heading="Test heading",
|
||||
categories=["test"],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_generate.assert_called_once_with("Test Test heading Test description test")
|
||||
mock_store.assert_called_once_with(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id="test-id",
|
||||
embedding=[0.1, 0.2, 0.3],
|
||||
searchable_text="Test Test heading Test description test",
|
||||
metadata={"name": "Test", "subHeading": "Test heading", "categories": ["test"]},
|
||||
user_id=None,
|
||||
tx=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
@patch("backend.api.features.store.embeddings.get_embedding")
|
||||
async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
|
||||
"""Test ensure_embedding when generation fails."""
|
||||
mock_get.return_value = None
|
||||
mock_generate.return_value = None
|
||||
|
||||
result = await embeddings.ensure_embedding(
|
||||
version_id="test-id",
|
||||
name="Test",
|
||||
description="Test description",
|
||||
sub_heading="Test heading",
|
||||
categories=["test"],
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_embedding_stats():
|
||||
"""Test embedding statistics retrieval."""
|
||||
# Mock approved count query and embedded count query
|
||||
mock_approved_result = [{"count": 100}]
|
||||
mock_embedded_result = [{"count": 75}]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
side_effect=[mock_approved_result, mock_embedded_result],
|
||||
):
|
||||
result = await embeddings.get_embedding_stats()
|
||||
|
||||
assert result["total_approved"] == 100
|
||||
assert result["with_embeddings"] == 75
|
||||
assert result["without_embeddings"] == 25
|
||||
assert result["coverage_percent"] == 75.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.ensure_embedding")
|
||||
async def test_backfill_missing_embeddings_success(mock_ensure):
|
||||
"""Test backfill with successful embedding generation."""
|
||||
# Mock missing embeddings query
|
||||
mock_missing = [
|
||||
{
|
||||
"id": "version-1",
|
||||
"name": "Agent 1",
|
||||
"description": "Description 1",
|
||||
"subHeading": "Heading 1",
|
||||
"categories": ["AI"],
|
||||
},
|
||||
{
|
||||
"id": "version-2",
|
||||
"name": "Agent 2",
|
||||
"description": "Description 2",
|
||||
"subHeading": "Heading 2",
|
||||
"categories": ["Productivity"],
|
||||
},
|
||||
]
|
||||
|
||||
# Mock ensure_embedding to succeed for first, fail for second
|
||||
mock_ensure.side_effect = [True, False]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
return_value=mock_missing,
|
||||
):
|
||||
result = await embeddings.backfill_missing_embeddings(batch_size=5)
|
||||
|
||||
assert result["processed"] == 2
|
||||
assert result["success"] == 1
|
||||
assert result["failed"] == 1
|
||||
assert mock_ensure.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_backfill_missing_embeddings_no_missing():
|
||||
"""Test backfill when no embeddings are missing."""
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
result = await embeddings.backfill_missing_embeddings(batch_size=5)
|
||||
|
||||
assert result["processed"] == 0
|
||||
assert result["success"] == 0
|
||||
assert result["failed"] == 0
|
||||
assert result["message"] == "No missing embeddings"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_embedding_to_vector_string():
|
||||
"""Test embedding to PostgreSQL vector string conversion."""
|
||||
embedding = [0.1, 0.2, 0.3, -0.4]
|
||||
result = embeddings.embedding_to_vector_string(embedding)
|
||||
assert result == "[0.1,0.2,0.3,-0.4]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_embed_query():
|
||||
"""Test embed_query function (alias for generate_embedding)."""
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.generate_embedding"
|
||||
) as mock_generate:
|
||||
mock_generate.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
result = await embeddings.embed_query("test query")
|
||||
|
||||
assert result == [0.1, 0.2, 0.3]
|
||||
mock_generate.assert_called_once_with("test query")
|
||||
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
Hybrid Search for Store Agents
|
||||
|
||||
Combines semantic (embedding) search with lexical (tsvector) search
|
||||
for improved relevance in marketplace agent discovery.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from backend.api.features.store.embeddings import (
|
||||
embed_query,
|
||||
embedding_to_vector_string,
|
||||
)
|
||||
from backend.data.db import query_raw_with_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HybridSearchWeights:
|
||||
"""Weights for combining search signals."""
|
||||
|
||||
semantic: float = 0.30 # Embedding cosine similarity
|
||||
lexical: float = 0.30 # tsvector ts_rank_cd score
|
||||
category: float = 0.20 # Category match boost
|
||||
recency: float = 0.10 # Newer agents ranked higher
|
||||
popularity: float = 0.10 # Agent usage/runs (PageRank-like)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate weights are non-negative and sum to approximately 1.0."""
|
||||
total = (
|
||||
self.semantic
|
||||
+ self.lexical
|
||||
+ self.category
|
||||
+ self.recency
|
||||
+ self.popularity
|
||||
)
|
||||
|
||||
if any(
|
||||
w < 0
|
||||
for w in [
|
||||
self.semantic,
|
||||
self.lexical,
|
||||
self.category,
|
||||
self.recency,
|
||||
self.popularity,
|
||||
]
|
||||
):
|
||||
raise ValueError("All weights must be non-negative")
|
||||
|
||||
if not (0.99 <= total <= 1.01):
|
||||
raise ValueError(f"Weights must sum to ~1.0, got {total:.3f}")
|
||||
|
||||
|
||||
DEFAULT_WEIGHTS = HybridSearchWeights()
|
||||
|
||||
# Minimum relevance score threshold - agents below this are filtered out
|
||||
# With weights (0.30 semantic + 0.30 lexical + 0.20 category + 0.10 recency + 0.10 popularity):
|
||||
# - 0.20 means at least ~60% semantic match OR strong lexical match required
|
||||
# - Ensures only genuinely relevant results are returned
|
||||
# - Recency/popularity alone (0.10 each) won't pass the threshold
|
||||
DEFAULT_MIN_SCORE = 0.20
|
||||
|
||||
|
||||
@dataclass
|
||||
class HybridSearchResult:
|
||||
"""A single search result with score breakdown."""
|
||||
|
||||
slug: str
|
||||
agent_name: str
|
||||
agent_image: str
|
||||
creator_username: str
|
||||
creator_avatar: str
|
||||
sub_heading: str
|
||||
description: str
|
||||
runs: int
|
||||
rating: float
|
||||
categories: list[str]
|
||||
featured: bool
|
||||
is_available: bool
|
||||
updated_at: datetime
|
||||
|
||||
# Score breakdown (for debugging/tuning)
|
||||
combined_score: float
|
||||
semantic_score: float = 0.0
|
||||
lexical_score: float = 0.0
|
||||
category_score: float = 0.0
|
||||
recency_score: float = 0.0
|
||||
popularity_score: float = 0.0
|
||||
|
||||
|
||||
async def hybrid_search(
|
||||
query: str,
|
||||
featured: bool = False,
|
||||
creators: list[str] | None = None,
|
||||
category: str | None = None,
|
||||
sorted_by: (
|
||||
Literal["relevance", "rating", "runs", "name", "updated_at"] | None
|
||||
) = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
weights: HybridSearchWeights | None = None,
|
||||
min_score: float | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Perform hybrid search combining semantic and lexical signals.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
featured: Filter for featured agents only
|
||||
creators: Filter by creator usernames
|
||||
category: Filter by category
|
||||
sorted_by: Sort order (relevance uses hybrid scoring)
|
||||
page: Page number (1-indexed)
|
||||
page_size: Results per page
|
||||
weights: Custom weights for search signals
|
||||
min_score: Minimum relevance score threshold (0-1). Results below
|
||||
this score are filtered out. Defaults to DEFAULT_MIN_SCORE.
|
||||
|
||||
Returns:
|
||||
Tuple of (results list, total count). Returns empty list if no
|
||||
results meet the minimum relevance threshold.
|
||||
"""
|
||||
# Validate inputs
|
||||
query = query.strip()
|
||||
if not query:
|
||||
return [], 0 # Empty query returns no results
|
||||
|
||||
if page < 1:
|
||||
page = 1
|
||||
if page_size < 1:
|
||||
page_size = 1
|
||||
if page_size > 100: # Cap at reasonable limit to prevent performance issues
|
||||
page_size = 100
|
||||
|
||||
if weights is None:
|
||||
weights = DEFAULT_WEIGHTS
|
||||
if min_score is None:
|
||||
min_score = DEFAULT_MIN_SCORE
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await embed_query(query)
|
||||
|
||||
# Build WHERE clause conditions
|
||||
where_parts: list[str] = ["sa.is_available = true"]
|
||||
params: list[Any] = []
|
||||
param_index = 1
|
||||
|
||||
# Add search query for lexical matching
|
||||
params.append(query)
|
||||
query_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
# Add lowercased query for category matching
|
||||
params.append(query.lower())
|
||||
query_lower_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
if featured:
|
||||
where_parts.append("sa.featured = true")
|
||||
|
||||
if creators:
|
||||
where_parts.append(f"sa.creator_username = ANY(${param_index})")
|
||||
params.append(creators)
|
||||
param_index += 1
|
||||
|
||||
if category:
|
||||
where_parts.append(f"${param_index} = ANY(sa.categories)")
|
||||
params.append(category)
|
||||
param_index += 1
|
||||
|
||||
# Safe: where_parts only contains hardcoded strings with $N parameter placeholders
|
||||
# No user input is concatenated directly into the SQL string
|
||||
where_clause = " AND ".join(where_parts)
|
||||
|
||||
# Embedding is required for hybrid search - fail fast if unavailable
|
||||
if query_embedding is None or not query_embedding:
|
||||
# Log detailed error server-side
|
||||
logger.error(
|
||||
"Failed to generate query embedding. "
|
||||
"Check that openai_internal_api_key is configured and OpenAI API is accessible."
|
||||
)
|
||||
# Raise generic error to client
|
||||
raise ValueError("Search service temporarily unavailable")
|
||||
|
||||
# Add embedding parameter
|
||||
embedding_str = embedding_to_vector_string(query_embedding)
|
||||
params.append(embedding_str)
|
||||
embedding_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
# Add weight parameters for SQL calculation
|
||||
params.append(weights.semantic)
|
||||
weight_semantic_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
params.append(weights.lexical)
|
||||
weight_lexical_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
params.append(weights.category)
|
||||
weight_category_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
params.append(weights.recency)
|
||||
weight_recency_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
params.append(weights.popularity)
|
||||
weight_popularity_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
# Add min_score parameter
|
||||
params.append(min_score)
|
||||
min_score_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
# Optimized hybrid search query:
|
||||
# 1. Direct join to UnifiedContentEmbedding via contentId=storeListingVersionId (no redundant JOINs)
|
||||
# 2. UNION approach (deduplicates agents matching both branches)
|
||||
# 3. COUNT(*) OVER() to get total count in single query
|
||||
# 4. Optimized category matching with EXISTS + unnest
|
||||
# 5. Pre-calculated max values for lexical and popularity normalization
|
||||
# 6. Simplified recency calculation with linear decay
|
||||
# 7. Logarithmic popularity scaling to prevent viral agents from dominating
|
||||
sql_query = f"""
|
||||
WITH candidates AS (
|
||||
-- Lexical matches (uses GIN index on search column)
|
||||
SELECT sa."storeListingVersionId"
|
||||
FROM {{schema_prefix}}"StoreAgent" sa
|
||||
WHERE {where_clause}
|
||||
AND sa.search @@ plainto_tsquery('english', {query_param})
|
||||
|
||||
UNION
|
||||
|
||||
-- Semantic matches (uses HNSW index on embedding with KNN)
|
||||
SELECT "storeListingVersionId"
|
||||
FROM (
|
||||
SELECT sa."storeListingVersionId", uce.embedding
|
||||
FROM {{schema_prefix}}"StoreAgent" sa
|
||||
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
WHERE {where_clause}
|
||||
ORDER BY uce.embedding <=> {embedding_param}::vector
|
||||
LIMIT 200
|
||||
) semantic_results
|
||||
),
|
||||
search_scores AS (
|
||||
SELECT
|
||||
sa.slug,
|
||||
sa.agent_name,
|
||||
sa.agent_image,
|
||||
sa.creator_username,
|
||||
sa.creator_avatar,
|
||||
sa.sub_heading,
|
||||
sa.description,
|
||||
sa.runs,
|
||||
sa.rating,
|
||||
sa.categories,
|
||||
sa.featured,
|
||||
sa.is_available,
|
||||
sa.updated_at,
|
||||
-- Semantic score: cosine similarity (1 - distance)
|
||||
COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
||||
-- Lexical score: ts_rank_cd (will be normalized later)
|
||||
COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||
-- Category match: optimized with unnest for better performance
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM unnest(sa.categories) cat
|
||||
WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%'
|
||||
)
|
||||
THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
-- Recency score: linear decay over 90 days (simpler than exponential)
|
||||
GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score,
|
||||
-- Popularity raw: agent runs count (will be normalized with log scaling)
|
||||
sa.runs as popularity_raw
|
||||
FROM candidates c
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON c."storeListingVersionId" = sa."storeListingVersionId"
|
||||
LEFT JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
),
|
||||
max_lexical AS (
|
||||
SELECT MAX(lexical_raw) as max_val FROM search_scores
|
||||
),
|
||||
max_popularity AS (
|
||||
SELECT MAX(popularity_raw) as max_val FROM search_scores
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
ss.*,
|
||||
-- Normalize lexical score by pre-calculated max
|
||||
CASE
|
||||
WHEN ml.max_val > 0
|
||||
THEN ss.lexical_raw / ml.max_val
|
||||
ELSE 0
|
||||
END as lexical_score,
|
||||
-- Normalize popularity with logarithmic scaling to prevent viral agents from dominating
|
||||
-- LOG(1 + runs) / LOG(1 + max_runs) ensures score is 0-1 range
|
||||
CASE
|
||||
WHEN mp.max_val > 0 AND ss.popularity_raw > 0
|
||||
THEN LN(1 + ss.popularity_raw) / LN(1 + mp.max_val)
|
||||
ELSE 0
|
||||
END as popularity_score
|
||||
FROM search_scores ss
|
||||
CROSS JOIN max_lexical ml
|
||||
CROSS JOIN max_popularity mp
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
category_score,
|
||||
recency_score,
|
||||
popularity_score,
|
||||
(
|
||||
{weight_semantic_param} * semantic_score +
|
||||
{weight_lexical_param} * lexical_score +
|
||||
{weight_category_param} * category_score +
|
||||
{weight_recency_param} * recency_score +
|
||||
{weight_popularity_param} * popularity_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
),
|
||||
filtered AS (
|
||||
SELECT
|
||||
*,
|
||||
COUNT(*) OVER () as total_count
|
||||
FROM scored
|
||||
WHERE combined_score >= {min_score_param}
|
||||
)
|
||||
SELECT * FROM filtered
|
||||
ORDER BY combined_score DESC
|
||||
LIMIT ${param_index} OFFSET ${param_index + 1}
|
||||
"""
|
||||
|
||||
# Add pagination params
|
||||
params.extend([page_size, offset])
|
||||
|
||||
# Execute search query - includes total_count via window function
|
||||
results = await query_raw_with_schema(
|
||||
sql_query, *params, set_public_search_path=True
|
||||
)
|
||||
|
||||
# Extract total count from first result (all rows have same count)
|
||||
total = results[0]["total_count"] if results else 0
|
||||
|
||||
# Remove total_count from results before returning
|
||||
for result in results:
|
||||
result.pop("total_count", None)
|
||||
|
||||
# Log without sensitive query content
|
||||
logger.info(f"Hybrid search: {len(results)} results, {total} total")
|
||||
|
||||
return results, total
|
||||
|
||||
|
||||
async def hybrid_search_simple(
|
||||
query: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Simplified hybrid search for common use cases.
|
||||
|
||||
Uses default weights and no filters.
|
||||
"""
|
||||
return await hybrid_search(
|
||||
query=query,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
@@ -0,0 +1,334 @@
|
||||
"""
|
||||
Integration tests for hybrid search with schema handling.
|
||||
|
||||
These tests verify that hybrid search works correctly across different database schemas.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.store.hybrid_search import HybridSearchWeights, hybrid_search
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_with_schema_handling():
|
||||
"""Test that hybrid search correctly handles database schema prefixes."""
|
||||
# Test with a mock query to ensure schema handling works
|
||||
query = "test agent"
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
# Mock the query result
|
||||
mock_query.return_value = [
|
||||
{
|
||||
"slug": "test/agent",
|
||||
"agent_name": "Test Agent",
|
||||
"agent_image": "test.png",
|
||||
"creator_username": "test",
|
||||
"creator_avatar": "avatar.png",
|
||||
"sub_heading": "Test sub-heading",
|
||||
"description": "Test description",
|
||||
"runs": 10,
|
||||
"rating": 4.5,
|
||||
"categories": ["test"],
|
||||
"featured": False,
|
||||
"is_available": True,
|
||||
"updated_at": "2024-01-01T00:00:00Z",
|
||||
"combined_score": 0.8,
|
||||
"semantic_score": 0.7,
|
||||
"lexical_score": 0.6,
|
||||
"category_score": 0.5,
|
||||
"recency_score": 0.4,
|
||||
"total_count": 1,
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536 # Mock embedding
|
||||
|
||||
results, total = await hybrid_search(
|
||||
query=query,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify the query was called
|
||||
assert mock_query.called
|
||||
# Verify the SQL template uses schema_prefix placeholder
|
||||
call_args = mock_query.call_args
|
||||
sql_template = call_args[0][0]
|
||||
assert "{schema_prefix}" in sql_template
|
||||
|
||||
# Verify results
|
||||
assert len(results) == 1
|
||||
assert total == 1
|
||||
assert results[0]["slug"] == "test/agent"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_with_public_schema():
|
||||
"""Test hybrid search when using public schema (no prefix needed)."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "public"
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
mock_query.return_value = []
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify the mock was set up correctly
|
||||
assert mock_schema.return_value == "public"
|
||||
|
||||
# Results should work even with empty results
|
||||
assert results == []
|
||||
assert total == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_with_custom_schema():
|
||||
"""Test hybrid search when using custom schema (e.g., 'platform')."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
mock_query.return_value = []
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify the mock was set up correctly
|
||||
assert mock_schema.return_value == "platform"
|
||||
|
||||
assert results == []
|
||||
assert total == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_without_embeddings():
|
||||
"""Test hybrid search fails fast when embeddings are unavailable."""
|
||||
# Patch where the function is used, not where it's defined
|
||||
with patch("backend.api.features.store.hybrid_search.embed_query") as mock_embed:
|
||||
# Simulate embedding failure
|
||||
mock_embed.return_value = None
|
||||
|
||||
# Should raise ValueError with helpful message
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await hybrid_search(
|
||||
query="test",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify error message is generic (doesn't leak implementation details)
|
||||
assert "Search service temporarily unavailable" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_with_filters():
|
||||
"""Test hybrid search with various filters."""
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
mock_query.return_value = []
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
|
||||
# Test with featured filter
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
featured=True,
|
||||
creators=["user1", "user2"],
|
||||
category="productivity",
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
# Verify filters were applied in the query
|
||||
call_args = mock_query.call_args
|
||||
params = call_args[0][1:] # Skip SQL template
|
||||
|
||||
# Should have query, query_lower, creators array, category
|
||||
assert len(params) >= 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_weights():
|
||||
"""Test hybrid search with custom weights."""
|
||||
custom_weights = HybridSearchWeights(
|
||||
semantic=0.5,
|
||||
lexical=0.3,
|
||||
category=0.1,
|
||||
recency=0.1,
|
||||
popularity=0.0,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
mock_query.return_value = []
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
weights=custom_weights,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify custom weights were used in the query
|
||||
call_args = mock_query.call_args
|
||||
sql_template = call_args[0][0]
|
||||
params = call_args[0][1:] # Get all parameters passed
|
||||
|
||||
# Check that SQL uses parameterized weights (not f-string interpolation)
|
||||
assert "$" in sql_template # Verify parameterization is used
|
||||
|
||||
# Check that custom weights are in the params
|
||||
assert 0.5 in params # semantic weight
|
||||
assert 0.3 in params # lexical weight
|
||||
assert 0.1 in params # category and recency weights
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_min_score_filtering():
|
||||
"""Test hybrid search minimum score threshold."""
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
# Return results with varying scores
|
||||
mock_query.return_value = [
|
||||
{
|
||||
"slug": "high-score/agent",
|
||||
"agent_name": "High Score Agent",
|
||||
"combined_score": 0.8,
|
||||
"total_count": 1,
|
||||
# ... other fields
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
|
||||
# Test with custom min_score
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
min_score=0.5, # High threshold
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify min_score was applied in query
|
||||
call_args = mock_query.call_args
|
||||
sql_template = call_args[0][0]
|
||||
params = call_args[0][1:] # Get all parameters
|
||||
|
||||
# Check that SQL uses parameterized min_score
|
||||
assert "combined_score >=" in sql_template
|
||||
assert "$" in sql_template # Verify parameterization
|
||||
|
||||
# Check that custom min_score is in the params
|
||||
assert 0.5 in params
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_pagination():
|
||||
"""Test hybrid search pagination."""
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
mock_query.return_value = []
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
|
||||
# Test page 2 with page_size 10
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
page=2,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
# Verify pagination parameters
|
||||
call_args = mock_query.call_args
|
||||
params = call_args[0]
|
||||
|
||||
# Last two params should be LIMIT and OFFSET
|
||||
limit = params[-2]
|
||||
offset = params[-1]
|
||||
|
||||
assert limit == 10 # page_size
|
||||
assert offset == 10 # (page - 1) * page_size = (2 - 1) * 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_error_handling():
|
||||
"""Test hybrid search error handling."""
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
# Simulate database error
|
||||
mock_query.side_effect = Exception("Database connection error")
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
|
||||
# Should raise exception
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await hybrid_search(
|
||||
query="test",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
assert "Database connection error" in str(exc_info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
@@ -38,6 +38,20 @@ POOL_TIMEOUT = os.getenv("DB_POOL_TIMEOUT")
|
||||
if POOL_TIMEOUT:
|
||||
DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)
|
||||
|
||||
# Add public schema to search_path for pgvector type access
|
||||
# The vector extension is in public schema, but search_path is determined by schema parameter
|
||||
# Extract the schema from DATABASE_URL or default to 'platform'
|
||||
parsed_url = urlparse(DATABASE_URL)
|
||||
url_params = dict(parse_qsl(parsed_url.query))
|
||||
db_schema = url_params.get("schema", "platform")
|
||||
# Build search_path, avoiding duplicates if db_schema is already 'public'
|
||||
search_path_schemas = list(
|
||||
dict.fromkeys([db_schema, "public"])
|
||||
) # Preserves order, removes duplicates
|
||||
search_path = ",".join(search_path_schemas)
|
||||
# This allows using ::vector without schema qualification
|
||||
DATABASE_URL = add_param(DATABASE_URL, "options", f"-c search_path={search_path}")
|
||||
|
||||
HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None
|
||||
|
||||
prisma = Prisma(
|
||||
@@ -108,21 +122,102 @@ def get_database_schema() -> str:
|
||||
return query_params.get("schema", "public")
|
||||
|
||||
|
||||
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
|
||||
"""Execute raw SQL query with proper schema handling."""
|
||||
async def _raw_with_schema(
|
||||
query_template: str,
|
||||
*args,
|
||||
execute: bool = False,
|
||||
client: Prisma | None = None,
|
||||
set_public_search_path: bool = False,
|
||||
) -> list[dict] | int:
|
||||
"""Internal: Execute raw SQL with proper schema handling.
|
||||
|
||||
Use query_raw_with_schema() or execute_raw_with_schema() instead.
|
||||
|
||||
Args:
|
||||
query_template: SQL query with {schema_prefix} placeholder
|
||||
*args: Query parameters
|
||||
execute: If False, executes SELECT query. If True, executes INSERT/UPDATE/DELETE.
|
||||
client: Optional Prisma client for transactions (only used when execute=True).
|
||||
set_public_search_path: If True, sets search_path to include public schema.
|
||||
Needed for pgvector types and other public schema objects.
|
||||
|
||||
Returns:
|
||||
- list[dict] if execute=False (query results)
|
||||
- int if execute=True (number of affected rows)
|
||||
"""
|
||||
schema = get_database_schema()
|
||||
schema_prefix = f'"{schema}".' if schema != "public" else ""
|
||||
formatted_query = query_template.format(schema_prefix=schema_prefix)
|
||||
|
||||
import prisma as prisma_module
|
||||
|
||||
result = await prisma_module.get_client().query_raw(
|
||||
formatted_query, *args # type: ignore
|
||||
)
|
||||
db_client = client if client else prisma_module.get_client()
|
||||
|
||||
# Set search_path to include public schema if requested
|
||||
# Prisma doesn't support the 'options' connection parameter, so we set it per-session
|
||||
# This is idempotent and safe to call multiple times
|
||||
if set_public_search_path:
|
||||
await db_client.execute_raw(f"SET search_path = {schema}, public") # type: ignore
|
||||
|
||||
if execute:
|
||||
result = await db_client.execute_raw(formatted_query, *args) # type: ignore
|
||||
else:
|
||||
result = await db_client.query_raw(formatted_query, *args) # type: ignore
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def query_raw_with_schema(
|
||||
query_template: str, *args, set_public_search_path: bool = False
|
||||
) -> list[dict]:
|
||||
"""Execute raw SQL SELECT query with proper schema handling.
|
||||
|
||||
Args:
|
||||
query_template: SQL query with {schema_prefix} placeholder
|
||||
*args: Query parameters
|
||||
set_public_search_path: If True, sets search_path to include public schema.
|
||||
Needed for pgvector types and other public schema objects.
|
||||
|
||||
Returns:
|
||||
List of result rows as dictionaries
|
||||
|
||||
Example:
|
||||
results = await query_raw_with_schema(
|
||||
'SELECT * FROM {schema_prefix}"User" WHERE id = $1',
|
||||
user_id
|
||||
)
|
||||
"""
|
||||
return await _raw_with_schema(query_template, *args, execute=False, set_public_search_path=set_public_search_path) # type: ignore
|
||||
|
||||
|
||||
async def execute_raw_with_schema(
|
||||
query_template: str,
|
||||
*args,
|
||||
client: Prisma | None = None,
|
||||
set_public_search_path: bool = False,
|
||||
) -> int:
|
||||
"""Execute raw SQL command (INSERT/UPDATE/DELETE) with proper schema handling.
|
||||
|
||||
Args:
|
||||
query_template: SQL query with {schema_prefix} placeholder
|
||||
*args: Query parameters
|
||||
client: Optional Prisma client for transactions
|
||||
set_public_search_path: If True, sets search_path to include public schema.
|
||||
Needed for pgvector types and other public schema objects.
|
||||
|
||||
Returns:
|
||||
Number of affected rows
|
||||
|
||||
Example:
|
||||
await execute_raw_with_schema(
|
||||
'INSERT INTO {schema_prefix}"User" (id, name) VALUES ($1, $2)',
|
||||
user_id, name,
|
||||
client=tx # Optional transaction client
|
||||
)
|
||||
"""
|
||||
return await _raw_with_schema(query_template, *args, execute=True, client=client, set_public_search_path=set_public_search_path) # type: ignore
|
||||
|
||||
|
||||
class BaseDbModel(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import fastapi.exceptions
|
||||
@@ -18,6 +19,17 @@ from backend.usecases.sample import create_test_user
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def mock_embedding_functions():
|
||||
"""Mock embedding functions for all tests to avoid database/API dependencies."""
|
||||
with patch(
|
||||
"backend.api.features.store.db.ensure_embedding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_graph_creation(server: SpinTestServer, snapshot: Snapshot):
|
||||
"""
|
||||
|
||||
@@ -7,6 +7,10 @@ from backend.api.features.library.db import (
|
||||
list_library_agents,
|
||||
)
|
||||
from backend.api.features.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.api.features.store.embeddings import (
|
||||
backfill_missing_embeddings,
|
||||
get_embedding_stats,
|
||||
)
|
||||
from backend.data import db
|
||||
from backend.data.analytics import (
|
||||
get_accuracy_trends_and_alerts,
|
||||
@@ -208,6 +212,10 @@ class DatabaseManager(AppService):
|
||||
get_store_agents = _(get_store_agents)
|
||||
get_store_agent_details = _(get_store_agent_details)
|
||||
|
||||
# Store Embeddings
|
||||
get_embedding_stats = _(get_embedding_stats)
|
||||
backfill_missing_embeddings = _(backfill_missing_embeddings)
|
||||
|
||||
# Summary data - async
|
||||
get_user_execution_summary_data = _(get_user_execution_summary_data)
|
||||
|
||||
@@ -259,6 +267,10 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
get_store_agents = _(d.get_store_agents)
|
||||
get_store_agent_details = _(d.get_store_agent_details)
|
||||
|
||||
# Store Embeddings
|
||||
get_embedding_stats = _(d.get_embedding_stats)
|
||||
backfill_missing_embeddings = _(d.backfill_missing_embeddings)
|
||||
|
||||
|
||||
class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import fastapi.responses
|
||||
import pytest
|
||||
@@ -19,6 +20,17 @@ from backend.util.test import SpinTestServer, wait_execution
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def mock_embedding_functions():
|
||||
"""Mock embedding functions for all tests to avoid database/API dependencies."""
|
||||
with patch(
|
||||
"backend.api.features.store.db.ensure_embedding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph:
|
||||
logger.info(f"Creating graph for user {u.id}")
|
||||
return await s.agent_server.test_create_graph(CreateGraph(graph=g), u.id)
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
@@ -37,7 +38,7 @@ from backend.monitoring import (
|
||||
report_execution_accuracy_alerts,
|
||||
report_late_executions,
|
||||
)
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.clients import get_database_manager_client, get_scheduler_client
|
||||
from backend.util.cloud_storage import cleanup_expired_files_async
|
||||
from backend.util.exceptions import (
|
||||
GraphNotFoundError,
|
||||
@@ -254,6 +255,88 @@ def execution_accuracy_alerts():
|
||||
return report_execution_accuracy_alerts()
|
||||
|
||||
|
||||
def ensure_embeddings_coverage():
|
||||
"""
|
||||
Ensure all content types (store agents, blocks, docs) have embeddings for search.
|
||||
|
||||
Processes ALL missing embeddings in batches of 10 per content type until 100% coverage.
|
||||
Missing embeddings = content invisible in hybrid search.
|
||||
|
||||
Schedule: Runs every 6 hours (balanced between coverage and API costs).
|
||||
- Catches new content added between scheduled runs
|
||||
- Batch size 10 per content type: gradual processing to avoid rate limits
|
||||
- Manual trigger available via execute_ensure_embeddings_coverage endpoint
|
||||
"""
|
||||
db_client = get_database_manager_client()
|
||||
stats = db_client.get_embedding_stats()
|
||||
|
||||
# Check for error from get_embedding_stats() first
|
||||
if "error" in stats:
|
||||
logger.error(
|
||||
f"Failed to get embedding stats: {stats['error']} - skipping backfill"
|
||||
)
|
||||
return {"processed": 0, "success": 0, "failed": 0, "error": stats["error"]}
|
||||
|
||||
# Extract totals from new stats structure
|
||||
totals = stats.get("totals", {})
|
||||
without_embeddings = totals.get("without_embeddings", 0)
|
||||
coverage_percent = totals.get("coverage_percent", 0)
|
||||
|
||||
if without_embeddings == 0:
|
||||
logger.info("All content has embeddings, skipping backfill")
|
||||
return {"processed": 0, "success": 0, "failed": 0}
|
||||
|
||||
# Log per-content-type stats for visibility
|
||||
by_type = stats.get("by_type", {})
|
||||
for content_type, type_stats in by_type.items():
|
||||
if type_stats.get("without_embeddings", 0) > 0:
|
||||
logger.info(
|
||||
f"{content_type}: {type_stats['without_embeddings']} items without embeddings "
|
||||
f"({type_stats['coverage_percent']}% coverage)"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Total: {without_embeddings} items without embeddings "
|
||||
f"({coverage_percent}% coverage) - processing all"
|
||||
)
|
||||
|
||||
total_processed = 0
|
||||
total_success = 0
|
||||
total_failed = 0
|
||||
|
||||
# Process in batches until no more missing embeddings
|
||||
while True:
|
||||
result = db_client.backfill_missing_embeddings(batch_size=10)
|
||||
|
||||
total_processed += result["processed"]
|
||||
total_success += result["success"]
|
||||
total_failed += result["failed"]
|
||||
|
||||
if result["processed"] == 0:
|
||||
# No more missing embeddings
|
||||
break
|
||||
|
||||
if result["success"] == 0 and result["processed"] > 0:
|
||||
# All attempts in this batch failed - stop to avoid infinite loop
|
||||
logger.error(
|
||||
f"All {result['processed']} embedding attempts failed - stopping backfill"
|
||||
)
|
||||
break
|
||||
|
||||
# Small delay between batches to avoid rate limits
|
||||
time.sleep(1)
|
||||
|
||||
logger.info(
|
||||
f"Embedding backfill completed: {total_success}/{total_processed} succeeded, "
|
||||
f"{total_failed} failed"
|
||||
)
|
||||
return {
|
||||
"processed": total_processed,
|
||||
"success": total_success,
|
||||
"failed": total_failed,
|
||||
}
|
||||
|
||||
|
||||
# Monitoring functions are now imported from monitoring module
|
||||
|
||||
|
||||
@@ -475,6 +558,19 @@ class Scheduler(AppService):
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
# Embedding Coverage - Every 6 hours
|
||||
# Ensures all approved agents have embeddings for hybrid search
|
||||
# Critical: missing embeddings = agents invisible in search
|
||||
self.scheduler.add_job(
|
||||
ensure_embeddings_coverage,
|
||||
id="ensure_embeddings_coverage",
|
||||
trigger="interval",
|
||||
hours=6,
|
||||
replace_existing=True,
|
||||
max_instances=1, # Prevent overlapping runs
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
||||
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
|
||||
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
|
||||
@@ -632,6 +728,11 @@ class Scheduler(AppService):
|
||||
"""Manually trigger execution accuracy alert checking."""
|
||||
return execution_accuracy_alerts()
|
||||
|
||||
@expose
|
||||
def execute_ensure_embeddings_coverage(self):
|
||||
"""Manually trigger embedding backfill for approved store agents."""
|
||||
return ensure_embeddings_coverage()
|
||||
|
||||
|
||||
class SchedulerClient(AppServiceClient):
|
||||
@classmethod
|
||||
|
||||
@@ -10,6 +10,7 @@ from backend.util.settings import Settings
|
||||
settings = Settings()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI
|
||||
from supabase import AClient, Client
|
||||
|
||||
from backend.data.execution import (
|
||||
@@ -139,6 +140,24 @@ async def get_async_supabase() -> "AClient":
|
||||
)
|
||||
|
||||
|
||||
# ============ OpenAI Client ============ #
|
||||
|
||||
|
||||
@cached(ttl_seconds=3600)
|
||||
def get_openai_client() -> "AsyncOpenAI | None":
|
||||
"""
|
||||
Get a process-cached async OpenAI client for embeddings.
|
||||
|
||||
Returns None if API key is not configured.
|
||||
"""
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
api_key = settings.secrets.openai_internal_api_key
|
||||
if not api_key:
|
||||
return None
|
||||
return AsyncOpenAI(api_key=api_key)
|
||||
|
||||
|
||||
# ============ Notification Queue Helpers ============ #
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
-- CreateExtension
|
||||
-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
|
||||
-- Create in public schema so vector type is available across all schemas
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "vector" WITH SCHEMA "public";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'vector extension not available or already exists, skipping';
|
||||
END $$;
|
||||
|
||||
-- CreateEnum
|
||||
CREATE TYPE "ContentType" AS ENUM ('STORE_AGENT', 'BLOCK', 'INTEGRATION', 'DOCUMENTATION', 'LIBRARY_AGENT');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "UnifiedContentEmbedding" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"contentType" "ContentType" NOT NULL,
|
||||
"contentId" TEXT NOT NULL,
|
||||
"userId" TEXT,
|
||||
"embedding" public.vector(1536) NOT NULL,
|
||||
"searchableText" TEXT NOT NULL,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
|
||||
CONSTRAINT "UnifiedContentEmbedding_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "UnifiedContentEmbedding_contentType_idx" ON "UnifiedContentEmbedding"("contentType");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "UnifiedContentEmbedding_userId_idx" ON "UnifiedContentEmbedding"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "UnifiedContentEmbedding_contentType_userId_idx" ON "UnifiedContentEmbedding"("contentType", "userId");
|
||||
|
||||
-- CreateIndex
|
||||
-- NULLS NOT DISTINCT ensures only one public (NULL userId) embedding per contentType+contentId
|
||||
-- Requires PostgreSQL 15+. Supabase uses PostgreSQL 15+.
|
||||
CREATE UNIQUE INDEX "UnifiedContentEmbedding_contentType_contentId_userId_key" ON "UnifiedContentEmbedding"("contentType", "contentId", "userId") NULLS NOT DISTINCT;
|
||||
|
||||
-- CreateIndex
|
||||
-- HNSW index for fast vector similarity search on embeddings
|
||||
-- Uses cosine distance operator (<=>), which matches the query in hybrid_search.py
|
||||
CREATE INDEX "UnifiedContentEmbedding_embedding_idx" ON "UnifiedContentEmbedding" USING hnsw ("embedding" public.vector_cosine_ops);
|
||||
@@ -0,0 +1,71 @@
|
||||
-- Acknowledge Supabase-managed extensions to prevent drift warnings
|
||||
-- These extensions are pre-installed by Supabase in specific schemas
|
||||
-- This migration ensures they exist where available (Supabase) or skips gracefully (CI)
|
||||
|
||||
-- Create schemas (safe in both CI and Supabase)
|
||||
CREATE SCHEMA IF NOT EXISTS "extensions";
|
||||
|
||||
-- Extensions that exist in both CI and Supabase
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgcrypto extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'uuid-ossp extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
-- Supabase-specific extensions (skip gracefully in CI)
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_stat_statements" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_stat_statements extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_net" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_net extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pgjwt" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgjwt extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "graphql";
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_graphql" WITH SCHEMA "graphql";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_graphql extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "pgsodium";
|
||||
CREATE EXTENSION IF NOT EXISTS "pgsodium" WITH SCHEMA "pgsodium";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgsodium extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "vault";
|
||||
CREATE EXTENSION IF NOT EXISTS "supabase_vault" WITH SCHEMA "vault";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'supabase_vault extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
|
||||
-- Return to platform
|
||||
CREATE SCHEMA IF NOT EXISTS "platform";
|
||||
@@ -1,14 +1,15 @@
|
||||
datasource db {
|
||||
provider = "postgresql"
|
||||
url = env("DATABASE_URL")
|
||||
directUrl = env("DIRECT_URL")
|
||||
provider = "postgresql"
|
||||
url = env("DATABASE_URL")
|
||||
directUrl = env("DIRECT_URL")
|
||||
extensions = [pgvector(map: "vector")]
|
||||
}
|
||||
|
||||
generator client {
|
||||
provider = "prisma-client-py"
|
||||
recursive_type_depth = -1
|
||||
interface = "asyncio"
|
||||
previewFeatures = ["views", "fullTextSearch"]
|
||||
previewFeatures = ["views", "fullTextSearch", "postgresqlExtensions"]
|
||||
partial_type_generator = "backend/data/partial_types.py"
|
||||
}
|
||||
|
||||
@@ -127,8 +128,8 @@ model BuilderSearchHistory {
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
searchQuery String
|
||||
filter String[] @default([])
|
||||
byCreator String[] @default([])
|
||||
filter String[] @default([])
|
||||
byCreator String[] @default([])
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
@@ -721,26 +722,25 @@ view StoreAgent {
|
||||
storeListingVersionId String
|
||||
updated_at DateTime
|
||||
|
||||
slug String
|
||||
agent_name String
|
||||
agent_video String?
|
||||
agent_output_demo String?
|
||||
agent_image String[]
|
||||
slug String
|
||||
agent_name String
|
||||
agent_video String?
|
||||
agent_output_demo String?
|
||||
agent_image String[]
|
||||
|
||||
featured Boolean @default(false)
|
||||
creator_username String?
|
||||
creator_avatar String?
|
||||
sub_heading String
|
||||
description String
|
||||
categories String[]
|
||||
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
|
||||
runs Int
|
||||
rating Float
|
||||
versions String[]
|
||||
agentGraphVersions String[]
|
||||
agentGraphId String
|
||||
is_available Boolean @default(true)
|
||||
useForOnboarding Boolean @default(false)
|
||||
featured Boolean @default(false)
|
||||
creator_username String?
|
||||
creator_avatar String?
|
||||
sub_heading String
|
||||
description String
|
||||
categories String[]
|
||||
runs Int
|
||||
rating Float
|
||||
versions String[]
|
||||
agentGraphVersions String[]
|
||||
agentGraphId String
|
||||
is_available Boolean @default(true)
|
||||
useForOnboarding Boolean @default(false)
|
||||
|
||||
// Materialized views used (refreshed every 15 minutes via pg_cron):
|
||||
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
|
||||
@@ -856,14 +856,14 @@ model StoreListingVersion {
|
||||
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
|
||||
|
||||
// Content fields
|
||||
name String
|
||||
subHeading String
|
||||
videoUrl String?
|
||||
agentOutputDemoUrl String?
|
||||
imageUrls String[]
|
||||
description String
|
||||
instructions String?
|
||||
categories String[]
|
||||
name String
|
||||
subHeading String
|
||||
videoUrl String?
|
||||
agentOutputDemoUrl String?
|
||||
imageUrls String[]
|
||||
description String
|
||||
instructions String?
|
||||
categories String[]
|
||||
|
||||
isFeatured Boolean @default(false)
|
||||
|
||||
@@ -899,6 +899,9 @@ model StoreListingVersion {
|
||||
// Reviews for this specific version
|
||||
Reviews StoreListingReview[]
|
||||
|
||||
// Note: Embeddings now stored in UnifiedContentEmbedding table
|
||||
// Use contentType=STORE_AGENT and contentId=storeListingVersionId
|
||||
|
||||
@@unique([storeListingId, version])
|
||||
@@index([storeListingId, submissionStatus, isAvailable])
|
||||
@@index([submissionStatus])
|
||||
@@ -906,6 +909,42 @@ model StoreListingVersion {
|
||||
@@index([agentGraphId, agentGraphVersion]) // Non-unique index for efficient lookups
|
||||
}
|
||||
|
||||
// Content type enum for unified search across store agents, blocks, docs
|
||||
// Note: BLOCK/INTEGRATION are file-based (Python classes), not DB records
|
||||
// DOCUMENTATION are file-based (.md files), not DB records
|
||||
// Only STORE_AGENT and LIBRARY_AGENT are stored in database
|
||||
enum ContentType {
|
||||
STORE_AGENT // Database: StoreListingVersion
|
||||
BLOCK // File-based: Python classes in /backend/blocks/
|
||||
INTEGRATION // File-based: Python classes (blocks with credentials)
|
||||
DOCUMENTATION // File-based: .md/.mdx files
|
||||
LIBRARY_AGENT // Database: User's personal agents
|
||||
}
|
||||
|
||||
// Unified embeddings table for all searchable content types
|
||||
// Supports both public content (userId=null) and user-specific content (userId=userID)
|
||||
model UnifiedContentEmbedding {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
// Content identification
|
||||
contentType ContentType
|
||||
contentId String // DB ID (storeListingVersionId) or file identifier (block.id, file_path)
|
||||
userId String? // NULL for public content (store, blocks, docs), userId for private content (library agents)
|
||||
|
||||
// Search data
|
||||
embedding Unsupported("vector(1536)") // pgvector embedding (extension in platform schema)
|
||||
searchableText String // Combined text for search and fallback
|
||||
metadata Json @default("{}") // Content-specific metadata
|
||||
|
||||
@@unique([contentType, contentId, userId], map: "UnifiedContentEmbedding_contentType_contentId_userId_key")
|
||||
@@index([contentType])
|
||||
@@index([userId])
|
||||
@@index([contentType, userId])
|
||||
@@index([embedding], map: "UnifiedContentEmbedding_embedding_idx")
|
||||
}
|
||||
|
||||
model StoreListingReview {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
@@ -998,16 +1037,16 @@ model OAuthApplication {
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
// Application metadata
|
||||
name String
|
||||
description String?
|
||||
logoUrl String? // URL to app logo stored in GCS
|
||||
clientId String @unique
|
||||
clientSecret String // Hashed with Scrypt (same as API keys)
|
||||
clientSecretSalt String // Salt for Scrypt hashing
|
||||
name String
|
||||
description String?
|
||||
logoUrl String? // URL to app logo stored in GCS
|
||||
clientId String @unique
|
||||
clientSecret String // Hashed with Scrypt (same as API keys)
|
||||
clientSecretSalt String // Salt for Scrypt hashing
|
||||
|
||||
// OAuth configuration
|
||||
redirectUris String[] // Allowed callback URLs
|
||||
grantTypes String[] @default(["authorization_code", "refresh_token"])
|
||||
grantTypes String[] @default(["authorization_code", "refresh_token"])
|
||||
scopes APIKeyPermission[] // Which permissions the app can request
|
||||
|
||||
// Application management
|
||||
|
||||
@@ -81,16 +81,18 @@ export const RunInputDialog = ({
|
||||
Inputs
|
||||
</Text>
|
||||
</div>
|
||||
<FormRenderer
|
||||
jsonSchema={inputSchema as RJSFSchema}
|
||||
handleChange={(v) => handleInputChange(v.formData)}
|
||||
uiSchema={uiSchema}
|
||||
initialValues={{}}
|
||||
formContext={{
|
||||
showHandles: false,
|
||||
size: "large",
|
||||
}}
|
||||
/>
|
||||
<div className="px-2">
|
||||
<FormRenderer
|
||||
jsonSchema={inputSchema as RJSFSchema}
|
||||
handleChange={(v) => handleInputChange(v.formData)}
|
||||
uiSchema={uiSchema}
|
||||
initialValues={{}}
|
||||
formContext={{
|
||||
showHandles: false,
|
||||
size: "large",
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import { useGetV2GetSpecificBlocks } from "@/app/api/__generated__/endpoints/def
|
||||
import {
|
||||
useGetV1GetExecutionDetails,
|
||||
useGetV1GetSpecificGraph,
|
||||
useGetV1ListUserGraphs,
|
||||
} from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
|
||||
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
|
||||
@@ -18,7 +17,6 @@ import { useReactFlow } from "@xyflow/react";
|
||||
import { useControlPanelStore } from "../../../stores/controlPanelStore";
|
||||
import { useHistoryStore } from "../../../stores/historyStore";
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
|
||||
export const useFlow = () => {
|
||||
const [isLocked, setIsLocked] = useState(false);
|
||||
@@ -38,9 +36,6 @@ export const useFlow = () => {
|
||||
const setGraphExecutionStatus = useGraphStore(
|
||||
useShallow((state) => state.setGraphExecutionStatus),
|
||||
);
|
||||
const setAvailableSubGraphs = useGraphStore(
|
||||
useShallow((state) => state.setAvailableSubGraphs),
|
||||
);
|
||||
const updateEdgeBeads = useEdgeStore(
|
||||
useShallow((state) => state.updateEdgeBeads),
|
||||
);
|
||||
@@ -67,11 +62,6 @@ export const useFlow = () => {
|
||||
},
|
||||
);
|
||||
|
||||
// Fetch all available graphs for sub-agent update detection
|
||||
const { data: availableGraphs } = useGetV1ListUserGraphs({
|
||||
query: { select: okData },
|
||||
});
|
||||
|
||||
const { data: graph, isLoading: isGraphLoading } = useGetV1GetSpecificGraph(
|
||||
flowID ?? "",
|
||||
flowVersion !== null ? { version: flowVersion } : {},
|
||||
@@ -126,18 +116,10 @@ export const useFlow = () => {
|
||||
}
|
||||
}, [graph]);
|
||||
|
||||
// Update available sub-graphs in store for sub-agent update detection
|
||||
useEffect(() => {
|
||||
if (availableGraphs) {
|
||||
setAvailableSubGraphs(availableGraphs);
|
||||
}
|
||||
}, [availableGraphs, setAvailableSubGraphs]);
|
||||
|
||||
// adding nodes
|
||||
useEffect(() => {
|
||||
if (customNodes.length > 0) {
|
||||
useNodeStore.getState().setNodes([]);
|
||||
useNodeStore.getState().clearResolutionState();
|
||||
addNodes(customNodes);
|
||||
|
||||
// Sync hardcoded values with handle IDs.
|
||||
@@ -221,7 +203,6 @@ export const useFlow = () => {
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
useNodeStore.getState().setNodes([]);
|
||||
useNodeStore.getState().clearResolutionState();
|
||||
useEdgeStore.getState().setEdges([]);
|
||||
useGraphStore.getState().reset();
|
||||
useEdgeStore.getState().resetEdgeBeads();
|
||||
|
||||
@@ -8,7 +8,6 @@ import {
|
||||
getBezierPath,
|
||||
} from "@xyflow/react";
|
||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { XIcon } from "@phosphor-icons/react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { NodeExecutionResult } from "@/lib/autogpt-server-api";
|
||||
@@ -36,8 +35,6 @@ const CustomEdge = ({
|
||||
selected,
|
||||
}: EdgeProps<CustomEdge>) => {
|
||||
const removeConnection = useEdgeStore((state) => state.removeEdge);
|
||||
// Subscribe to the brokenEdgeIDs map and check if this edge is broken across any node
|
||||
const isBroken = useNodeStore((state) => state.isEdgeBroken(id));
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
|
||||
const [edgePath, labelX, labelY] = getBezierPath({
|
||||
@@ -53,12 +50,6 @@ const CustomEdge = ({
|
||||
const beadUp = data?.beadUp ?? 0;
|
||||
const beadDown = data?.beadDown ?? 0;
|
||||
|
||||
const handleRemoveEdge = () => {
|
||||
removeConnection(id);
|
||||
// Note: broken edge tracking is cleaned up automatically by useSubAgentUpdateState
|
||||
// when it detects the edge no longer exists
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<BaseEdge
|
||||
@@ -66,11 +57,9 @@ const CustomEdge = ({
|
||||
markerEnd={markerEnd}
|
||||
className={cn(
|
||||
isStatic && "!stroke-[1.5px] [stroke-dasharray:6]",
|
||||
isBroken
|
||||
? "!stroke-red-500 !stroke-[2px] [stroke-dasharray:4]"
|
||||
: selected
|
||||
? "stroke-zinc-800"
|
||||
: "stroke-zinc-500/50 hover:stroke-zinc-500",
|
||||
selected
|
||||
? "stroke-zinc-800"
|
||||
: "stroke-zinc-500/50 hover:stroke-zinc-500",
|
||||
)}
|
||||
/>
|
||||
<JSBeads
|
||||
@@ -81,16 +70,12 @@ const CustomEdge = ({
|
||||
/>
|
||||
<EdgeLabelRenderer>
|
||||
<Button
|
||||
onClick={handleRemoveEdge}
|
||||
onClick={() => removeConnection(id)}
|
||||
className={cn(
|
||||
"absolute h-fit min-w-0 p-1 transition-opacity",
|
||||
isBroken
|
||||
? "bg-red-500 opacity-100 hover:bg-red-600"
|
||||
: isHovered
|
||||
? "opacity-100"
|
||||
: "opacity-0",
|
||||
isHovered ? "opacity-100" : "opacity-0",
|
||||
)}
|
||||
variant={isBroken ? "primary" : "secondary"}
|
||||
variant="secondary"
|
||||
style={{
|
||||
transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
|
||||
pointerEvents: "all",
|
||||
|
||||
@@ -3,7 +3,6 @@ import { Handle, Position } from "@xyflow/react";
|
||||
import { useEdgeStore } from "../../../stores/edgeStore";
|
||||
import { cleanUpHandleId } from "@/components/renderers/InputRenderer/helpers";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useNodeStore } from "../../../stores/nodeStore";
|
||||
|
||||
const InputNodeHandle = ({
|
||||
handleId,
|
||||
@@ -16,9 +15,6 @@ const InputNodeHandle = ({
|
||||
const isInputConnected = useEdgeStore((state) =>
|
||||
state.isInputConnected(nodeId ?? "", cleanedHandleId),
|
||||
);
|
||||
const isInputBroken = useNodeStore((state) =>
|
||||
state.isInputBroken(nodeId, cleanedHandleId),
|
||||
);
|
||||
|
||||
return (
|
||||
<Handle
|
||||
@@ -31,10 +27,7 @@ const InputNodeHandle = ({
|
||||
<CircleIcon
|
||||
size={16}
|
||||
weight={isInputConnected ? "fill" : "duotone"}
|
||||
className={cn(
|
||||
"text-gray-400 opacity-100",
|
||||
isInputBroken && "text-red-500",
|
||||
)}
|
||||
className={"text-gray-400 opacity-100"}
|
||||
/>
|
||||
</div>
|
||||
</Handle>
|
||||
@@ -45,17 +38,14 @@ const OutputNodeHandle = ({
|
||||
field_name,
|
||||
nodeId,
|
||||
hexColor,
|
||||
isBroken,
|
||||
}: {
|
||||
field_name: string;
|
||||
nodeId: string;
|
||||
hexColor: string;
|
||||
isBroken: boolean;
|
||||
}) => {
|
||||
const isOutputConnected = useEdgeStore((state) =>
|
||||
state.isOutputConnected(nodeId, field_name),
|
||||
);
|
||||
|
||||
return (
|
||||
<Handle
|
||||
type={"source"}
|
||||
@@ -68,10 +58,7 @@ const OutputNodeHandle = ({
|
||||
size={16}
|
||||
weight={"duotone"}
|
||||
color={isOutputConnected ? hexColor : "gray"}
|
||||
className={cn(
|
||||
"text-gray-400 opacity-100",
|
||||
isBroken && "text-red-500",
|
||||
)}
|
||||
className={cn("text-gray-400 opacity-100")}
|
||||
/>
|
||||
</div>
|
||||
</Handle>
|
||||
|
||||
@@ -20,8 +20,6 @@ import { NodeDataRenderer } from "./components/NodeOutput/NodeOutput";
|
||||
import { NodeRightClickMenu } from "./components/NodeRightClickMenu";
|
||||
import { StickyNoteBlock } from "./components/StickyNoteBlock";
|
||||
import { WebhookDisclaimer } from "./components/WebhookDisclaimer";
|
||||
import { SubAgentUpdateFeature } from "./components/SubAgentUpdate/SubAgentUpdateFeature";
|
||||
import { useCustomNode } from "./useCustomNode";
|
||||
|
||||
export type CustomNodeData = {
|
||||
hardcodedValues: {
|
||||
@@ -47,10 +45,6 @@ export type CustomNode = XYNode<CustomNodeData, "custom">;
|
||||
|
||||
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
({ data, id: nodeId, selected }) => {
|
||||
const { inputSchema, outputSchema } = useCustomNode({ data, nodeId });
|
||||
|
||||
const isAgent = data.uiType === BlockUIType.AGENT;
|
||||
|
||||
if (data.uiType === BlockUIType.NOTE) {
|
||||
return (
|
||||
<StickyNoteBlock data={data} selected={selected} nodeId={nodeId} />
|
||||
@@ -69,6 +63,16 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
|
||||
const isAyrshare = data.uiType === BlockUIType.AYRSHARE;
|
||||
|
||||
const inputSchema =
|
||||
data.uiType === BlockUIType.AGENT
|
||||
? (data.hardcodedValues.input_schema ?? {})
|
||||
: data.inputSchema;
|
||||
|
||||
const outputSchema =
|
||||
data.uiType === BlockUIType.AGENT
|
||||
? (data.hardcodedValues.output_schema ?? {})
|
||||
: data.outputSchema;
|
||||
|
||||
const hasConfigErrors =
|
||||
data.errors &&
|
||||
Object.values(data.errors).some(
|
||||
@@ -83,11 +87,12 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
|
||||
const hasErrors = hasConfigErrors || hasOutputError;
|
||||
|
||||
// Currently all blockTypes design are similar - that's why i am using the same component for all of them
|
||||
// If in future - if we need some drastic change in some blockTypes design - we can create separate components for them
|
||||
const node = (
|
||||
<NodeContainer selected={selected} nodeId={nodeId} hasErrors={hasErrors}>
|
||||
<div className="rounded-xlarge bg-white">
|
||||
<NodeHeader data={data} nodeId={nodeId} />
|
||||
{isAgent && <SubAgentUpdateFeature nodeID={nodeId} nodeData={data} />}
|
||||
{isWebhook && <WebhookDisclaimer nodeId={nodeId} />}
|
||||
{isAyrshare && <AyrshareConnectButton />}
|
||||
<FormCreator
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
import React from "react";
|
||||
import { ArrowUpIcon, WarningIcon } from "@phosphor-icons/react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { cn, beautifyString } from "@/lib/utils";
|
||||
import { CustomNodeData } from "../../CustomNode";
|
||||
import { useSubAgentUpdateState } from "./useSubAgentUpdateState";
|
||||
import { IncompatibleUpdateDialog } from "./components/IncompatibleUpdateDialog";
|
||||
import { ResolutionModeBar } from "./components/ResolutionModeBar";
|
||||
|
||||
/**
|
||||
* Inline component for the update bar that can be placed after the header.
|
||||
* Use this inside the node content where you want the bar to appear.
|
||||
*/
|
||||
type SubAgentUpdateFeatureProps = {
|
||||
nodeID: string;
|
||||
nodeData: CustomNodeData;
|
||||
};
|
||||
|
||||
export function SubAgentUpdateFeature({
|
||||
nodeID,
|
||||
nodeData,
|
||||
}: SubAgentUpdateFeatureProps) {
|
||||
const {
|
||||
updateInfo,
|
||||
isInResolutionMode,
|
||||
handleUpdateClick,
|
||||
showIncompatibilityDialog,
|
||||
setShowIncompatibilityDialog,
|
||||
handleConfirmIncompatibleUpdate,
|
||||
} = useSubAgentUpdateState({ nodeID: nodeID, nodeData: nodeData });
|
||||
|
||||
const agentName = nodeData.title || "Agent";
|
||||
|
||||
if (!updateInfo.hasUpdate && !isInResolutionMode) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
{isInResolutionMode ? (
|
||||
<ResolutionModeBar incompatibilities={updateInfo.incompatibilities} />
|
||||
) : (
|
||||
<SubAgentUpdateAvailableBar
|
||||
currentVersion={updateInfo.currentVersion}
|
||||
latestVersion={updateInfo.latestVersion}
|
||||
isCompatible={updateInfo.isCompatible}
|
||||
onUpdate={handleUpdateClick}
|
||||
/>
|
||||
)}
|
||||
{/* Incompatibility dialog - rendered here since this component owns the state */}
|
||||
{updateInfo.incompatibilities && (
|
||||
<IncompatibleUpdateDialog
|
||||
isOpen={showIncompatibilityDialog}
|
||||
onClose={() => setShowIncompatibilityDialog(false)}
|
||||
onConfirm={handleConfirmIncompatibleUpdate}
|
||||
currentVersion={updateInfo.currentVersion}
|
||||
latestVersion={updateInfo.latestVersion}
|
||||
agentName={beautifyString(agentName)}
|
||||
incompatibilities={updateInfo.incompatibilities}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
type SubAgentUpdateAvailableBarProps = {
|
||||
currentVersion: number;
|
||||
latestVersion: number;
|
||||
isCompatible: boolean;
|
||||
onUpdate: () => void;
|
||||
};
|
||||
|
||||
function SubAgentUpdateAvailableBar({
|
||||
currentVersion,
|
||||
latestVersion,
|
||||
isCompatible,
|
||||
onUpdate,
|
||||
}: SubAgentUpdateAvailableBarProps): React.ReactElement {
|
||||
return (
|
||||
<div className="flex items-center justify-between gap-2 rounded-t-xl bg-blue-50 px-3 py-2 dark:bg-blue-900/30">
|
||||
<div className="flex items-center gap-2">
|
||||
<ArrowUpIcon className="h-4 w-4 text-blue-600 dark:text-blue-400" />
|
||||
<span className="text-sm text-blue-700 dark:text-blue-300">
|
||||
Update available (v{currentVersion} → v{latestVersion})
|
||||
</span>
|
||||
{!isCompatible && (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<WarningIcon className="h-4 w-4 text-amber-500" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className="max-w-xs">
|
||||
<p className="font-medium">Incompatible changes detected</p>
|
||||
<p className="text-xs text-gray-400">
|
||||
Click Update to see details
|
||||
</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
</div>
|
||||
<Button
|
||||
size="small"
|
||||
variant={isCompatible ? "primary" : "outline"}
|
||||
onClick={onUpdate}
|
||||
className={cn(
|
||||
"h-7 text-xs",
|
||||
!isCompatible && "border-amber-500 text-amber-600 hover:bg-amber-50",
|
||||
)}
|
||||
>
|
||||
Update
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,274 +0,0 @@
|
||||
import React from "react";
|
||||
import {
|
||||
WarningIcon,
|
||||
XCircleIcon,
|
||||
PlusCircleIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { beautifyString } from "@/lib/utils";
|
||||
import { IncompatibilityInfo } from "@/app/(platform)/build/hooks/useSubAgentUpdate/types";
|
||||
|
||||
type IncompatibleUpdateDialogProps = {
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
onConfirm: () => void;
|
||||
currentVersion: number;
|
||||
latestVersion: number;
|
||||
agentName: string;
|
||||
incompatibilities: IncompatibilityInfo;
|
||||
};
|
||||
|
||||
export function IncompatibleUpdateDialog({
|
||||
isOpen,
|
||||
onClose,
|
||||
onConfirm,
|
||||
currentVersion,
|
||||
latestVersion,
|
||||
agentName,
|
||||
incompatibilities,
|
||||
}: IncompatibleUpdateDialogProps) {
|
||||
const hasMissingInputs = incompatibilities.missingInputs.length > 0;
|
||||
const hasMissingOutputs = incompatibilities.missingOutputs.length > 0;
|
||||
const hasNewInputs = incompatibilities.newInputs.length > 0;
|
||||
const hasNewOutputs = incompatibilities.newOutputs.length > 0;
|
||||
const hasNewRequired = incompatibilities.newRequiredInputs.length > 0;
|
||||
const hasTypeMismatches = incompatibilities.inputTypeMismatches.length > 0;
|
||||
|
||||
const hasInputChanges = hasMissingInputs || hasNewInputs;
|
||||
const hasOutputChanges = hasMissingOutputs || hasNewOutputs;
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title={
|
||||
<div className="flex items-center gap-2">
|
||||
<WarningIcon className="h-5 w-5 text-amber-500" weight="fill" />
|
||||
Incompatible Update
|
||||
</div>
|
||||
}
|
||||
controlled={{
|
||||
isOpen,
|
||||
set: async (open) => {
|
||||
if (!open) onClose();
|
||||
},
|
||||
}}
|
||||
onClose={onClose}
|
||||
styling={{ maxWidth: "32rem" }}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="space-y-4">
|
||||
<p className="text-sm text-gray-600 dark:text-gray-400">
|
||||
Updating <strong>{beautifyString(agentName)}</strong> from v
|
||||
{currentVersion} to v{latestVersion} will break some connections.
|
||||
</p>
|
||||
|
||||
{/* Input changes - two column layout */}
|
||||
{hasInputChanges && (
|
||||
<TwoColumnSection
|
||||
title="Input Changes"
|
||||
leftIcon={
|
||||
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
|
||||
}
|
||||
leftTitle="Removed"
|
||||
leftItems={incompatibilities.missingInputs}
|
||||
rightIcon={
|
||||
<PlusCircleIcon
|
||||
className="h-4 w-4 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
}
|
||||
rightTitle="Added"
|
||||
rightItems={incompatibilities.newInputs}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Output changes - two column layout */}
|
||||
{hasOutputChanges && (
|
||||
<TwoColumnSection
|
||||
title="Output Changes"
|
||||
leftIcon={
|
||||
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
|
||||
}
|
||||
leftTitle="Removed"
|
||||
leftItems={incompatibilities.missingOutputs}
|
||||
rightIcon={
|
||||
<PlusCircleIcon
|
||||
className="h-4 w-4 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
}
|
||||
rightTitle="Added"
|
||||
rightItems={incompatibilities.newOutputs}
|
||||
/>
|
||||
)}
|
||||
|
||||
{hasTypeMismatches && (
|
||||
<SingleColumnSection
|
||||
icon={
|
||||
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
|
||||
}
|
||||
title="Type Changed"
|
||||
description="These connected inputs have a different type:"
|
||||
items={incompatibilities.inputTypeMismatches.map(
|
||||
(m) => `${m.name} (${m.oldType} → ${m.newType})`,
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{hasNewRequired && (
|
||||
<SingleColumnSection
|
||||
icon={
|
||||
<PlusCircleIcon
|
||||
className="h-4 w-4 text-amber-500"
|
||||
weight="fill"
|
||||
/>
|
||||
}
|
||||
title="New Required Inputs"
|
||||
description="These inputs are now required:"
|
||||
items={incompatibilities.newRequiredInputs}
|
||||
/>
|
||||
)}
|
||||
|
||||
<Alert variant="warning">
|
||||
<AlertDescription>
|
||||
If you proceed, you'll need to remove the broken connections
|
||||
before you can save or run your agent.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button variant="ghost" size="small" onClick={onClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={onConfirm}
|
||||
className="border-amber-700 bg-amber-600 hover:bg-amber-700"
|
||||
>
|
||||
Update Anyway
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
|
||||
type TwoColumnSectionProps = {
|
||||
title: string;
|
||||
leftIcon: React.ReactNode;
|
||||
leftTitle: string;
|
||||
leftItems: string[];
|
||||
rightIcon: React.ReactNode;
|
||||
rightTitle: string;
|
||||
rightItems: string[];
|
||||
};
|
||||
|
||||
function TwoColumnSection({
|
||||
title,
|
||||
leftIcon,
|
||||
leftTitle,
|
||||
leftItems,
|
||||
rightIcon,
|
||||
rightTitle,
|
||||
rightItems,
|
||||
}: TwoColumnSectionProps) {
|
||||
return (
|
||||
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
|
||||
<span className="font-medium">{title}</span>
|
||||
<div className="mt-2 grid grid-cols-2 items-start gap-4">
|
||||
{/* Left column - Breaking changes */}
|
||||
<div className="min-w-0">
|
||||
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
|
||||
{leftIcon}
|
||||
<span>{leftTitle}</span>
|
||||
</div>
|
||||
<ul className="mt-1.5 space-y-1">
|
||||
{leftItems.length > 0 ? (
|
||||
leftItems.map((item) => (
|
||||
<li
|
||||
key={item}
|
||||
className="text-sm text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<code className="rounded bg-red-50 px-1 py-0.5 font-mono text-xs text-red-700 dark:bg-red-900/30 dark:text-red-300">
|
||||
{item}
|
||||
</code>
|
||||
</li>
|
||||
))
|
||||
) : (
|
||||
<li className="text-sm italic text-gray-400 dark:text-gray-500">
|
||||
None
|
||||
</li>
|
||||
)}
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
{/* Right column - Possible solutions */}
|
||||
<div className="min-w-0">
|
||||
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
|
||||
{rightIcon}
|
||||
<span>{rightTitle}</span>
|
||||
</div>
|
||||
<ul className="mt-1.5 space-y-1">
|
||||
{rightItems.length > 0 ? (
|
||||
rightItems.map((item) => (
|
||||
<li
|
||||
key={item}
|
||||
className="text-sm text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<code className="rounded bg-green-50 px-1 py-0.5 font-mono text-xs text-green-700 dark:bg-green-900/30 dark:text-green-300">
|
||||
{item}
|
||||
</code>
|
||||
</li>
|
||||
))
|
||||
) : (
|
||||
<li className="text-sm italic text-gray-400 dark:text-gray-500">
|
||||
None
|
||||
</li>
|
||||
)}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
type SingleColumnSectionProps = {
|
||||
icon: React.ReactNode;
|
||||
title: string;
|
||||
description: string;
|
||||
items: string[];
|
||||
};
|
||||
|
||||
function SingleColumnSection({
|
||||
icon,
|
||||
title,
|
||||
description,
|
||||
items,
|
||||
}: SingleColumnSectionProps) {
|
||||
return (
|
||||
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
|
||||
<div className="flex items-center gap-2">
|
||||
{icon}
|
||||
<span className="font-medium">{title}</span>
|
||||
</div>
|
||||
<p className="mt-1 text-sm text-gray-500 dark:text-gray-400">
|
||||
{description}
|
||||
</p>
|
||||
<ul className="mt-2 space-y-1">
|
||||
{items.map((item) => (
|
||||
<li
|
||||
key={item}
|
||||
className="ml-4 list-disc text-sm text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<code className="rounded bg-gray-100 px-1 py-0.5 font-mono text-xs dark:bg-gray-800">
|
||||
{item}
|
||||
</code>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
import React from "react";
|
||||
import { InfoIcon, WarningIcon } from "@phosphor-icons/react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { IncompatibilityInfo } from "@/app/(platform)/build/hooks/useSubAgentUpdate/types";
|
||||
|
||||
type ResolutionModeBarProps = {
|
||||
incompatibilities: IncompatibilityInfo | null;
|
||||
};
|
||||
|
||||
export function ResolutionModeBar({
|
||||
incompatibilities,
|
||||
}: ResolutionModeBarProps): React.ReactElement {
|
||||
const renderIncompatibilities = () => {
|
||||
if (!incompatibilities) return <span>No incompatibilities</span>;
|
||||
|
||||
const sections: React.ReactNode[] = [];
|
||||
|
||||
if (incompatibilities.missingInputs.length > 0) {
|
||||
sections.push(
|
||||
<div key="missing-inputs" className="mb-1">
|
||||
<span className="font-semibold">Missing inputs: </span>
|
||||
{incompatibilities.missingInputs.map((name, i) => (
|
||||
<React.Fragment key={name}>
|
||||
<code className="font-mono">{name}</code>
|
||||
{i < incompatibilities.missingInputs.length - 1 && ", "}
|
||||
</React.Fragment>
|
||||
))}
|
||||
</div>,
|
||||
);
|
||||
}
|
||||
if (incompatibilities.missingOutputs.length > 0) {
|
||||
sections.push(
|
||||
<div key="missing-outputs" className="mb-1">
|
||||
<span className="font-semibold">Missing outputs: </span>
|
||||
{incompatibilities.missingOutputs.map((name, i) => (
|
||||
<React.Fragment key={name}>
|
||||
<code className="font-mono">{name}</code>
|
||||
{i < incompatibilities.missingOutputs.length - 1 && ", "}
|
||||
</React.Fragment>
|
||||
))}
|
||||
</div>,
|
||||
);
|
||||
}
|
||||
if (incompatibilities.newRequiredInputs.length > 0) {
|
||||
sections.push(
|
||||
<div key="new-required" className="mb-1">
|
||||
<span className="font-semibold">New required inputs: </span>
|
||||
{incompatibilities.newRequiredInputs.map((name, i) => (
|
||||
<React.Fragment key={name}>
|
||||
<code className="font-mono">{name}</code>
|
||||
{i < incompatibilities.newRequiredInputs.length - 1 && ", "}
|
||||
</React.Fragment>
|
||||
))}
|
||||
</div>,
|
||||
);
|
||||
}
|
||||
if (incompatibilities.inputTypeMismatches.length > 0) {
|
||||
sections.push(
|
||||
<div key="type-mismatches" className="mb-1">
|
||||
<span className="font-semibold">Type changed: </span>
|
||||
{incompatibilities.inputTypeMismatches.map((m, i) => (
|
||||
<React.Fragment key={m.name}>
|
||||
<code className="font-mono">{m.name}</code>
|
||||
<span className="text-gray-400">
|
||||
{" "}
|
||||
({m.oldType} → {m.newType})
|
||||
</span>
|
||||
{i < incompatibilities.inputTypeMismatches.length - 1 && ", "}
|
||||
</React.Fragment>
|
||||
))}
|
||||
</div>,
|
||||
);
|
||||
}
|
||||
|
||||
return <>{sections}</>;
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-between gap-2 rounded-t-xl bg-amber-50 px-3 py-2 dark:bg-amber-900/30">
|
||||
<div className="flex items-center gap-2">
|
||||
<WarningIcon className="h-4 w-4 text-amber-600 dark:text-amber-400" />
|
||||
<span className="text-sm text-amber-700 dark:text-amber-300">
|
||||
Remove incompatible connections
|
||||
</span>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<InfoIcon className="h-4 w-4 cursor-help text-amber-500" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className="max-w-sm">
|
||||
<p className="mb-2 font-semibold">Incompatible changes:</p>
|
||||
<div className="text-xs">{renderIncompatibilities()}</div>
|
||||
<p className="mt-2 text-xs text-gray-400">
|
||||
{(incompatibilities?.newRequiredInputs.length ?? 0) > 0
|
||||
? "Replace / delete"
|
||||
: "Delete"}{" "}
|
||||
the red connections to continue
|
||||
</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,194 +0,0 @@
|
||||
import { useState, useCallback, useEffect } from "react";
|
||||
import { useShallow } from "zustand/react/shallow";
|
||||
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
|
||||
import {
|
||||
useNodeStore,
|
||||
NodeResolutionData,
|
||||
} from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||
import {
|
||||
useSubAgentUpdate,
|
||||
createUpdatedAgentNodeInputs,
|
||||
getBrokenEdgeIDs,
|
||||
} from "@/app/(platform)/build/hooks/useSubAgentUpdate";
|
||||
import { GraphInputSchema, GraphOutputSchema } from "@/lib/autogpt-server-api";
|
||||
import { CustomNodeData } from "../../CustomNode";
|
||||
|
||||
// Stable empty set to avoid creating new references in selectors
|
||||
const EMPTY_SET: Set<string> = new Set();
|
||||
|
||||
type UseSubAgentUpdateParams = {
|
||||
nodeID: string;
|
||||
nodeData: CustomNodeData;
|
||||
};
|
||||
|
||||
export function useSubAgentUpdateState({
|
||||
nodeID,
|
||||
nodeData,
|
||||
}: UseSubAgentUpdateParams) {
|
||||
const [showIncompatibilityDialog, setShowIncompatibilityDialog] =
|
||||
useState(false);
|
||||
|
||||
// Get store actions
|
||||
const updateNodeData = useNodeStore(
|
||||
useShallow((state) => state.updateNodeData),
|
||||
);
|
||||
const setNodeResolutionMode = useNodeStore(
|
||||
useShallow((state) => state.setNodeResolutionMode),
|
||||
);
|
||||
const isNodeInResolutionMode = useNodeStore(
|
||||
useShallow((state) => state.isNodeInResolutionMode),
|
||||
);
|
||||
const setBrokenEdgeIDs = useNodeStore(
|
||||
useShallow((state) => state.setBrokenEdgeIDs),
|
||||
);
|
||||
// Get this node's broken edge IDs from the per-node map
|
||||
// Use EMPTY_SET as fallback to maintain referential stability
|
||||
const brokenEdgeIDs = useNodeStore(
|
||||
(state) => state.brokenEdgeIDs.get(nodeID) || EMPTY_SET,
|
||||
);
|
||||
const getNodeResolutionData = useNodeStore(
|
||||
useShallow((state) => state.getNodeResolutionData),
|
||||
);
|
||||
const connectedEdges = useEdgeStore(
|
||||
useShallow((state) => state.getNodeEdges(nodeID)),
|
||||
);
|
||||
const availableSubGraphs = useGraphStore(
|
||||
useShallow((state) => state.availableSubGraphs),
|
||||
);
|
||||
|
||||
// Extract agent-specific data
|
||||
const graphID = nodeData.hardcodedValues?.graph_id as string | undefined;
|
||||
const graphVersion = nodeData.hardcodedValues?.graph_version as
|
||||
| number
|
||||
| undefined;
|
||||
const currentInputSchema = nodeData.hardcodedValues?.input_schema as
|
||||
| GraphInputSchema
|
||||
| undefined;
|
||||
const currentOutputSchema = nodeData.hardcodedValues?.output_schema as
|
||||
| GraphOutputSchema
|
||||
| undefined;
|
||||
|
||||
// Use the sub-agent update hook
|
||||
const updateInfo = useSubAgentUpdate(
|
||||
nodeID,
|
||||
graphID,
|
||||
graphVersion,
|
||||
currentInputSchema,
|
||||
currentOutputSchema,
|
||||
connectedEdges,
|
||||
availableSubGraphs,
|
||||
);
|
||||
|
||||
const isInResolutionMode = isNodeInResolutionMode(nodeID);
|
||||
|
||||
// Handle update button click
|
||||
const handleUpdateClick = useCallback(() => {
|
||||
if (!updateInfo.hasUpdate || !updateInfo.latestGraph) return;
|
||||
|
||||
if (updateInfo.isCompatible) {
|
||||
// Compatible update - apply directly
|
||||
const newHardcodedValues = createUpdatedAgentNodeInputs(
|
||||
nodeData.hardcodedValues,
|
||||
updateInfo.latestGraph,
|
||||
);
|
||||
updateNodeData(nodeID, { hardcodedValues: newHardcodedValues });
|
||||
} else {
|
||||
// Incompatible update - show dialog
|
||||
setShowIncompatibilityDialog(true);
|
||||
}
|
||||
}, [
|
||||
updateInfo.hasUpdate,
|
||||
updateInfo.latestGraph,
|
||||
updateInfo.isCompatible,
|
||||
nodeData.hardcodedValues,
|
||||
updateNodeData,
|
||||
nodeID,
|
||||
]);
|
||||
|
||||
// Handle confirming an incompatible update
|
||||
function handleConfirmIncompatibleUpdate() {
|
||||
if (!updateInfo.latestGraph || !updateInfo.incompatibilities) return;
|
||||
|
||||
const latestGraph = updateInfo.latestGraph;
|
||||
|
||||
// Get the new schemas from the latest graph version
|
||||
const newInputSchema =
|
||||
(latestGraph.input_schema as Record<string, unknown>) || {};
|
||||
const newOutputSchema =
|
||||
(latestGraph.output_schema as Record<string, unknown>) || {};
|
||||
|
||||
// Create the updated hardcoded values but DON'T apply them yet
|
||||
// We'll apply them when resolution is complete
|
||||
const pendingHardcodedValues = createUpdatedAgentNodeInputs(
|
||||
nodeData.hardcodedValues,
|
||||
latestGraph,
|
||||
);
|
||||
|
||||
// Get broken edge IDs and store them for this node
|
||||
const brokenIds = getBrokenEdgeIDs(
|
||||
connectedEdges,
|
||||
updateInfo.incompatibilities,
|
||||
nodeID,
|
||||
);
|
||||
setBrokenEdgeIDs(nodeID, brokenIds);
|
||||
|
||||
// Enter resolution mode with both old and new schemas
|
||||
// DON'T apply the update yet - keep old schema so connections remain visible
|
||||
const resolutionData: NodeResolutionData = {
|
||||
incompatibilities: updateInfo.incompatibilities,
|
||||
pendingUpdate: {
|
||||
input_schema: newInputSchema,
|
||||
output_schema: newOutputSchema,
|
||||
},
|
||||
currentSchema: {
|
||||
input_schema: (currentInputSchema as Record<string, unknown>) || {},
|
||||
output_schema: (currentOutputSchema as Record<string, unknown>) || {},
|
||||
},
|
||||
pendingHardcodedValues,
|
||||
};
|
||||
setNodeResolutionMode(nodeID, true, resolutionData);
|
||||
|
||||
setShowIncompatibilityDialog(false);
|
||||
}
|
||||
|
||||
// Check if resolution is complete (all broken edges removed)
|
||||
const resolutionData = getNodeResolutionData(nodeID);
|
||||
|
||||
// Auto-check resolution on edge changes
|
||||
useEffect(() => {
|
||||
if (!isInResolutionMode) return;
|
||||
|
||||
// Check if any broken edges still exist
|
||||
const remainingBroken = Array.from(brokenEdgeIDs).filter((edgeId) =>
|
||||
connectedEdges.some((e) => e.id === edgeId),
|
||||
);
|
||||
|
||||
if (remainingBroken.length === 0) {
|
||||
// Resolution complete - now apply the pending update
|
||||
if (resolutionData?.pendingHardcodedValues) {
|
||||
updateNodeData(nodeID, {
|
||||
hardcodedValues: resolutionData.pendingHardcodedValues,
|
||||
});
|
||||
}
|
||||
// setNodeResolutionMode will clean up this node's broken edges automatically
|
||||
setNodeResolutionMode(nodeID, false);
|
||||
}
|
||||
}, [
|
||||
isInResolutionMode,
|
||||
brokenEdgeIDs,
|
||||
connectedEdges,
|
||||
resolutionData,
|
||||
nodeID,
|
||||
]);
|
||||
|
||||
return {
|
||||
updateInfo,
|
||||
isInResolutionMode,
|
||||
resolutionData,
|
||||
showIncompatibilityDialog,
|
||||
setShowIncompatibilityDialog,
|
||||
handleUpdateClick,
|
||||
handleConfirmIncompatibleUpdate,
|
||||
};
|
||||
}
|
||||
@@ -1,6 +1,4 @@
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import { NodeResolutionData } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { RJSFSchema } from "@rjsf/utils";
|
||||
|
||||
export const nodeStyleBasedOnStatus: Record<AgentExecutionStatus, string> = {
|
||||
INCOMPLETE: "ring-slate-300 bg-slate-300",
|
||||
@@ -11,48 +9,3 @@ export const nodeStyleBasedOnStatus: Record<AgentExecutionStatus, string> = {
|
||||
TERMINATED: "ring-orange-300 bg-orange-300 ",
|
||||
FAILED: "ring-red-300 bg-red-300",
|
||||
};
|
||||
|
||||
/**
|
||||
* Merges schemas during resolution mode to include removed inputs/outputs
|
||||
* that still have connections, so users can see and delete them.
|
||||
*/
|
||||
export function mergeSchemaForResolution(
|
||||
currentSchema: Record<string, unknown>,
|
||||
newSchema: Record<string, unknown>,
|
||||
resolutionData: NodeResolutionData,
|
||||
type: "input" | "output",
|
||||
): Record<string, unknown> {
|
||||
const newProps = (newSchema.properties as RJSFSchema) || {};
|
||||
const currentProps = (currentSchema.properties as RJSFSchema) || {};
|
||||
const mergedProps = { ...newProps };
|
||||
const incomp = resolutionData.incompatibilities;
|
||||
|
||||
if (type === "input") {
|
||||
// Add back missing inputs that have connections
|
||||
incomp.missingInputs.forEach((inputName: string) => {
|
||||
if (currentProps[inputName]) {
|
||||
mergedProps[inputName] = currentProps[inputName];
|
||||
}
|
||||
});
|
||||
// Add back inputs with type mismatches (keep old type so connection works visually)
|
||||
incomp.inputTypeMismatches.forEach(
|
||||
(mismatch: { name: string; oldType: string; newType: string }) => {
|
||||
if (currentProps[mismatch.name]) {
|
||||
mergedProps[mismatch.name] = currentProps[mismatch.name];
|
||||
}
|
||||
},
|
||||
);
|
||||
} else {
|
||||
// Add back missing outputs that have connections
|
||||
incomp.missingOutputs.forEach((outputName: string) => {
|
||||
if (currentProps[outputName]) {
|
||||
mergedProps[outputName] = currentProps[outputName];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
...newSchema,
|
||||
properties: mergedProps,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { CustomNodeData } from "./CustomNode";
|
||||
import { BlockUIType } from "../../../types";
|
||||
import { useMemo } from "react";
|
||||
import { mergeSchemaForResolution } from "./helpers";
|
||||
|
||||
export const useCustomNode = ({
|
||||
data,
|
||||
nodeId,
|
||||
}: {
|
||||
data: CustomNodeData;
|
||||
nodeId: string;
|
||||
}) => {
|
||||
const isInResolutionMode = useNodeStore((state) =>
|
||||
state.nodesInResolutionMode.has(nodeId),
|
||||
);
|
||||
const resolutionData = useNodeStore((state) =>
|
||||
state.nodeResolutionData.get(nodeId),
|
||||
);
|
||||
|
||||
const isAgent = data.uiType === BlockUIType.AGENT;
|
||||
|
||||
const currentInputSchema = isAgent
|
||||
? (data.hardcodedValues.input_schema ?? {})
|
||||
: data.inputSchema;
|
||||
const currentOutputSchema = isAgent
|
||||
? (data.hardcodedValues.output_schema ?? {})
|
||||
: data.outputSchema;
|
||||
|
||||
const inputSchema = useMemo(() => {
|
||||
if (isAgent && isInResolutionMode && resolutionData) {
|
||||
return mergeSchemaForResolution(
|
||||
resolutionData.currentSchema.input_schema,
|
||||
resolutionData.pendingUpdate.input_schema,
|
||||
resolutionData,
|
||||
"input",
|
||||
);
|
||||
}
|
||||
return currentInputSchema;
|
||||
}, [isAgent, isInResolutionMode, resolutionData, currentInputSchema]);
|
||||
|
||||
const outputSchema = useMemo(() => {
|
||||
if (isAgent && isInResolutionMode && resolutionData) {
|
||||
return mergeSchemaForResolution(
|
||||
resolutionData.currentSchema.output_schema,
|
||||
resolutionData.pendingUpdate.output_schema,
|
||||
resolutionData,
|
||||
"output",
|
||||
);
|
||||
}
|
||||
return currentOutputSchema;
|
||||
}, [isAgent, isInResolutionMode, resolutionData, currentOutputSchema]);
|
||||
|
||||
return {
|
||||
inputSchema,
|
||||
outputSchema,
|
||||
};
|
||||
};
|
||||
@@ -5,16 +5,20 @@ import { useNodeStore } from "../../../stores/nodeStore";
|
||||
import { BlockUIType } from "../../types";
|
||||
import { FormRenderer } from "@/components/renderers/InputRenderer/FormRenderer";
|
||||
|
||||
interface FormCreatorProps {
|
||||
jsonSchema: RJSFSchema;
|
||||
nodeId: string;
|
||||
uiType: BlockUIType;
|
||||
showHandles?: boolean;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export const FormCreator: React.FC<FormCreatorProps> = React.memo(
|
||||
({ jsonSchema, nodeId, uiType, showHandles = true, className }) => {
|
||||
export const FormCreator = React.memo(
|
||||
({
|
||||
jsonSchema,
|
||||
nodeId,
|
||||
uiType,
|
||||
showHandles = true,
|
||||
className,
|
||||
}: {
|
||||
jsonSchema: RJSFSchema;
|
||||
nodeId: string;
|
||||
uiType: BlockUIType;
|
||||
showHandles?: boolean;
|
||||
className?: string;
|
||||
}) => {
|
||||
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
||||
|
||||
const getHardCodedValues = useNodeStore(
|
||||
|
||||
@@ -14,8 +14,6 @@ import {
|
||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||
import { getTypeDisplayInfo } from "./helpers";
|
||||
import { BlockUIType } from "../../types";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useBrokenOutputs } from "./useBrokenOutputs";
|
||||
|
||||
export const OutputHandler = ({
|
||||
outputSchema,
|
||||
@@ -29,9 +27,6 @@ export const OutputHandler = ({
|
||||
const { isOutputConnected } = useEdgeStore();
|
||||
const properties = outputSchema?.properties || {};
|
||||
const [isOutputVisible, setIsOutputVisible] = useState(true);
|
||||
const brokenOutputs = useBrokenOutputs(nodeId);
|
||||
|
||||
console.log("brokenOutputs", brokenOutputs);
|
||||
|
||||
const showHandles = uiType !== BlockUIType.OUTPUT;
|
||||
|
||||
@@ -49,7 +44,6 @@ export const OutputHandler = ({
|
||||
const shouldShow = isConnected || isOutputVisible;
|
||||
const { displayType, colorClass, hexColor } =
|
||||
getTypeDisplayInfo(fieldSchema);
|
||||
const isBroken = brokenOutputs.has(fullKey);
|
||||
|
||||
return shouldShow ? (
|
||||
<div key={fullKey} className="flex flex-col items-end gap-2">
|
||||
@@ -70,29 +64,15 @@ export const OutputHandler = ({
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
)}
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"text-slate-700",
|
||||
isBroken && "text-red-500 line-through",
|
||||
)}
|
||||
>
|
||||
<Text variant="body" className="text-slate-700">
|
||||
{fieldTitle}
|
||||
</Text>
|
||||
<Text
|
||||
variant="small"
|
||||
as="span"
|
||||
className={cn(
|
||||
colorClass,
|
||||
isBroken && "!text-red-500 line-through",
|
||||
)}
|
||||
>
|
||||
<Text variant="small" as="span" className={colorClass}>
|
||||
({displayType})
|
||||
</Text>
|
||||
|
||||
{showHandles && (
|
||||
<OutputNodeHandle
|
||||
isBroken={isBroken}
|
||||
field_name={fullKey}
|
||||
nodeId={nodeId}
|
||||
hexColor={hexColor}
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
import { useMemo } from "react";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
|
||||
/**
|
||||
* Hook to get the set of broken output names for a node in resolution mode.
|
||||
*/
|
||||
export function useBrokenOutputs(nodeID: string): Set<string> {
|
||||
// Subscribe to the actual state values, not just methods
|
||||
const isInResolution = useNodeStore((state) =>
|
||||
state.nodesInResolutionMode.has(nodeID),
|
||||
);
|
||||
const resolutionData = useNodeStore((state) =>
|
||||
state.nodeResolutionData.get(nodeID),
|
||||
);
|
||||
|
||||
return useMemo(() => {
|
||||
if (!isInResolution || !resolutionData) {
|
||||
return new Set<string>();
|
||||
}
|
||||
|
||||
return new Set(resolutionData.incompatibilities.missingOutputs);
|
||||
}, [isInResolution, resolutionData]);
|
||||
}
|
||||
@@ -25,7 +25,7 @@ export const RightSidebar = () => {
|
||||
>
|
||||
<div className="mb-4">
|
||||
<h2 className="text-lg font-semibold text-slate-800 dark:text-slate-200">
|
||||
Graph Debug Panel
|
||||
Flow Debug Panel
|
||||
</h2>
|
||||
</div>
|
||||
|
||||
@@ -65,7 +65,7 @@ export const RightSidebar = () => {
|
||||
{l.source_id}[{l.source_name}] → {l.sink_id}[{l.sink_name}]
|
||||
</div>
|
||||
<div className="mt-1 text-slate-500 dark:text-slate-400">
|
||||
edge.id: {l.id}
|
||||
edge_id: {l.id}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
|
||||
@@ -12,14 +12,7 @@ import {
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/__legacy__/ui/popover";
|
||||
import {
|
||||
Block,
|
||||
BlockIORootSchema,
|
||||
BlockUIType,
|
||||
GraphInputSchema,
|
||||
GraphOutputSchema,
|
||||
SpecialBlockID,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
import { Block, BlockUIType, SpecialBlockID } from "@/lib/autogpt-server-api";
|
||||
import { MagnifyingGlassIcon, PlusIcon } from "@radix-ui/react-icons";
|
||||
import { IconToyBrick } from "@/components/__legacy__/ui/icons";
|
||||
import { getPrimaryCategoryColor } from "@/lib/utils";
|
||||
@@ -31,10 +24,8 @@ import {
|
||||
import { GraphMeta } from "@/lib/autogpt-server-api";
|
||||
import jaro from "jaro-winkler";
|
||||
|
||||
type _Block = Omit<Block, "inputSchema" | "outputSchema"> & {
|
||||
type _Block = Block & {
|
||||
uiKey?: string;
|
||||
inputSchema: BlockIORootSchema | GraphInputSchema;
|
||||
outputSchema: BlockIORootSchema | GraphOutputSchema;
|
||||
hardcodedValues?: Record<string, any>;
|
||||
_cached?: {
|
||||
blockName: string;
|
||||
|
||||
@@ -2,7 +2,7 @@ import React from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import { LogOut } from "lucide-react";
|
||||
import { ClockIcon, WarningIcon } from "@phosphor-icons/react";
|
||||
import { ClockIcon } from "@phosphor-icons/react";
|
||||
import { IconPlay, IconSquare } from "@/components/__legacy__/ui/icons";
|
||||
|
||||
interface Props {
|
||||
@@ -13,7 +13,6 @@ interface Props {
|
||||
isRunning: boolean;
|
||||
isDisabled: boolean;
|
||||
className?: string;
|
||||
resolutionModeActive?: boolean;
|
||||
}
|
||||
|
||||
export const BuildActionBar: React.FC<Props> = ({
|
||||
@@ -24,30 +23,9 @@ export const BuildActionBar: React.FC<Props> = ({
|
||||
isRunning,
|
||||
isDisabled,
|
||||
className,
|
||||
resolutionModeActive = false,
|
||||
}) => {
|
||||
const buttonClasses =
|
||||
"flex items-center gap-2 text-sm font-medium md:text-lg";
|
||||
|
||||
// Show resolution mode message instead of action buttons
|
||||
if (resolutionModeActive) {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex w-fit select-none items-center justify-center p-4",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-3 rounded-lg border border-amber-300 bg-amber-50 px-4 py-3 dark:border-amber-700 dark:bg-amber-900/30">
|
||||
<WarningIcon className="size-5 text-amber-600 dark:text-amber-400" />
|
||||
<span className="text-sm font-medium text-amber-800 dark:text-amber-200">
|
||||
Remove incompatible connections to continue
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
|
||||
@@ -60,16 +60,10 @@ export function CustomEdge({
|
||||
targetY - 5,
|
||||
);
|
||||
const { deleteElements } = useReactFlow<Node, CustomEdge>();
|
||||
const builderContext = useContext(BuilderContext);
|
||||
const { visualizeBeads } = builderContext ?? {
|
||||
const { visualizeBeads } = useContext(BuilderContext) ?? {
|
||||
visualizeBeads: "no",
|
||||
};
|
||||
|
||||
// Check if this edge is broken (during resolution mode)
|
||||
const isBroken =
|
||||
builderContext?.resolutionMode?.active &&
|
||||
builderContext?.resolutionMode?.brokenEdgeIds?.includes(id);
|
||||
|
||||
const onEdgeRemoveClick = () => {
|
||||
deleteElements({ edges: [{ id }] });
|
||||
};
|
||||
@@ -177,27 +171,12 @@ export function CustomEdge({
|
||||
|
||||
const middle = getPointForT(0.5);
|
||||
|
||||
// Determine edge color - red for broken edges
|
||||
const baseColor = data?.edgeColor ?? "#555555";
|
||||
const edgeColor = isBroken ? "#ef4444" : baseColor;
|
||||
// Add opacity to hex color (99 = 60% opacity, 80 = 50% opacity)
|
||||
const strokeColor = isBroken
|
||||
? `${edgeColor}99`
|
||||
: selected
|
||||
? edgeColor
|
||||
: `${edgeColor}80`;
|
||||
|
||||
return (
|
||||
<>
|
||||
<BaseEdge
|
||||
path={svgPath}
|
||||
markerEnd={markerEnd}
|
||||
style={{
|
||||
stroke: strokeColor,
|
||||
strokeWidth: data?.isStatic ? 2.5 : 2,
|
||||
strokeDasharray: data?.isStatic ? "5 3" : undefined,
|
||||
}}
|
||||
className="data-sentry-unmask transition-all duration-200"
|
||||
className={`data-sentry-unmask transition-all duration-200 ${data?.isStatic ? "[stroke-dasharray:5_3]" : "[stroke-dasharray:0]"} [stroke-width:${data?.isStatic ? 2.5 : 2}px] hover:[stroke-width:${data?.isStatic ? 3.5 : 3}px] ${selected ? `[stroke:${data?.edgeColor ?? "#555555"}]` : `[stroke:${data?.edgeColor ?? "#555555"}80] hover:[stroke:${data?.edgeColor ?? "#555555"}]`}`}
|
||||
/>
|
||||
<path
|
||||
d={svgPath}
|
||||
|
||||
@@ -18,8 +18,6 @@ import {
|
||||
BlockIOSubSchema,
|
||||
BlockUIType,
|
||||
Category,
|
||||
GraphInputSchema,
|
||||
GraphOutputSchema,
|
||||
NodeExecutionResult,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
import {
|
||||
@@ -64,21 +62,14 @@ import { NodeGenericInputField, NodeTextBoxInput } from "../NodeInputs";
|
||||
import NodeOutputs from "../NodeOutputs";
|
||||
import OutputModalComponent from "../OutputModalComponent";
|
||||
import "./customnode.css";
|
||||
import { SubAgentUpdateBar } from "./SubAgentUpdateBar";
|
||||
import { IncompatibilityDialog } from "./IncompatibilityDialog";
|
||||
import {
|
||||
useSubAgentUpdate,
|
||||
createUpdatedAgentNodeInputs,
|
||||
getBrokenEdgeIDs,
|
||||
} from "../../../hooks/useSubAgentUpdate";
|
||||
|
||||
export type ConnectedEdge = {
|
||||
id: string;
|
||||
export type ConnectionData = Array<{
|
||||
edge_id: string;
|
||||
source: string;
|
||||
sourceHandle: string;
|
||||
target: string;
|
||||
targetHandle: string;
|
||||
};
|
||||
}>;
|
||||
|
||||
export type CustomNodeData = {
|
||||
blockType: string;
|
||||
@@ -89,7 +80,7 @@ export type CustomNodeData = {
|
||||
inputSchema: BlockIORootSchema;
|
||||
outputSchema: BlockIORootSchema;
|
||||
hardcodedValues: { [key: string]: any };
|
||||
connections: ConnectedEdge[];
|
||||
connections: ConnectionData;
|
||||
isOutputOpen: boolean;
|
||||
status?: NodeExecutionResult["status"];
|
||||
/** executionResults contains outputs across multiple executions
|
||||
@@ -136,199 +127,20 @@ export const CustomNode = React.memo(
|
||||
|
||||
let subGraphID = "";
|
||||
|
||||
if (data.uiType === BlockUIType.AGENT) {
|
||||
// Display the graph's schema instead AgentExecutorBlock's schema.
|
||||
data.inputSchema = data.hardcodedValues?.input_schema || {};
|
||||
data.outputSchema = data.hardcodedValues?.output_schema || {};
|
||||
subGraphID = data.hardcodedValues?.graph_id || subGraphID;
|
||||
}
|
||||
|
||||
if (!builderContext) {
|
||||
throw new Error(
|
||||
"BuilderContext consumer must be inside FlowEditor component",
|
||||
);
|
||||
}
|
||||
|
||||
const {
|
||||
libraryAgent,
|
||||
setIsAnyModalOpen,
|
||||
getNextNodeId,
|
||||
availableFlows,
|
||||
resolutionMode,
|
||||
enterResolutionMode,
|
||||
} = builderContext;
|
||||
|
||||
// Check if this node is in resolution mode (moved up for schema merge logic)
|
||||
const isInResolutionMode =
|
||||
resolutionMode.active && resolutionMode.nodeId === id;
|
||||
|
||||
if (data.uiType === BlockUIType.AGENT) {
|
||||
// Display the graph's schema instead AgentExecutorBlock's schema.
|
||||
const currentInputSchema = data.hardcodedValues?.input_schema || {};
|
||||
const currentOutputSchema = data.hardcodedValues?.output_schema || {};
|
||||
subGraphID = data.hardcodedValues?.graph_id || subGraphID;
|
||||
|
||||
// During resolution mode, merge old connected inputs/outputs with new schema
|
||||
if (isInResolutionMode && resolutionMode.pendingUpdate) {
|
||||
const newInputSchema =
|
||||
(resolutionMode.pendingUpdate.input_schema as BlockIORootSchema) ||
|
||||
{};
|
||||
const newOutputSchema =
|
||||
(resolutionMode.pendingUpdate.output_schema as BlockIORootSchema) ||
|
||||
{};
|
||||
|
||||
// Merge input schemas: start with new schema, add old connected inputs that are missing
|
||||
const mergedInputProps = { ...newInputSchema.properties };
|
||||
const incomp = resolutionMode.incompatibilities;
|
||||
if (incomp && currentInputSchema.properties) {
|
||||
// Add back missing inputs that have connections (so user can see/delete them)
|
||||
incomp.missingInputs.forEach((inputName) => {
|
||||
if (currentInputSchema.properties[inputName]) {
|
||||
mergedInputProps[inputName] =
|
||||
currentInputSchema.properties[inputName];
|
||||
}
|
||||
});
|
||||
// Add back inputs with type mismatches (keep old type so connection still works visually)
|
||||
incomp.inputTypeMismatches.forEach((mismatch) => {
|
||||
if (currentInputSchema.properties[mismatch.name]) {
|
||||
mergedInputProps[mismatch.name] =
|
||||
currentInputSchema.properties[mismatch.name];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Merge output schemas: start with new schema, add old connected outputs that are missing
|
||||
const mergedOutputProps = { ...newOutputSchema.properties };
|
||||
if (incomp && currentOutputSchema.properties) {
|
||||
incomp.missingOutputs.forEach((outputName) => {
|
||||
if (currentOutputSchema.properties[outputName]) {
|
||||
mergedOutputProps[outputName] =
|
||||
currentOutputSchema.properties[outputName];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
data.inputSchema = {
|
||||
...newInputSchema,
|
||||
properties: mergedInputProps,
|
||||
};
|
||||
data.outputSchema = {
|
||||
...newOutputSchema,
|
||||
properties: mergedOutputProps,
|
||||
};
|
||||
} else {
|
||||
data.inputSchema = currentInputSchema;
|
||||
data.outputSchema = currentOutputSchema;
|
||||
}
|
||||
}
|
||||
|
||||
const setHardcodedValues = useCallback(
|
||||
(values: any) => {
|
||||
updateNodeData(id, { hardcodedValues: values });
|
||||
},
|
||||
[id, updateNodeData],
|
||||
);
|
||||
|
||||
// Sub-agent update detection
|
||||
const isAgentBlock = data.uiType === BlockUIType.AGENT;
|
||||
const graphId = isAgentBlock ? data.hardcodedValues?.graph_id : undefined;
|
||||
const graphVersion = isAgentBlock
|
||||
? data.hardcodedValues?.graph_version
|
||||
: undefined;
|
||||
|
||||
const subAgentUpdate = useSubAgentUpdate(
|
||||
id,
|
||||
graphId,
|
||||
graphVersion,
|
||||
isAgentBlock
|
||||
? (data.hardcodedValues?.input_schema as GraphInputSchema)
|
||||
: undefined,
|
||||
isAgentBlock
|
||||
? (data.hardcodedValues?.output_schema as GraphOutputSchema)
|
||||
: undefined,
|
||||
data.connections,
|
||||
availableFlows,
|
||||
);
|
||||
|
||||
const [showIncompatibilityDialog, setShowIncompatibilityDialog] =
|
||||
useState(false);
|
||||
|
||||
// Helper to check if a handle is broken (for resolution mode)
|
||||
const isInputHandleBroken = useCallback(
|
||||
(handleName: string): boolean => {
|
||||
if (!isInResolutionMode || !resolutionMode.incompatibilities) {
|
||||
return false;
|
||||
}
|
||||
const incomp = resolutionMode.incompatibilities;
|
||||
return (
|
||||
incomp.missingInputs.includes(handleName) ||
|
||||
incomp.inputTypeMismatches.some((m) => m.name === handleName)
|
||||
);
|
||||
},
|
||||
[isInResolutionMode, resolutionMode.incompatibilities],
|
||||
);
|
||||
|
||||
const isOutputHandleBroken = useCallback(
|
||||
(handleName: string): boolean => {
|
||||
if (!isInResolutionMode || !resolutionMode.incompatibilities) {
|
||||
return false;
|
||||
}
|
||||
return resolutionMode.incompatibilities.missingOutputs.includes(
|
||||
handleName,
|
||||
);
|
||||
},
|
||||
[isInResolutionMode, resolutionMode.incompatibilities],
|
||||
);
|
||||
|
||||
// Handle update button click
|
||||
const handleUpdateClick = useCallback(() => {
|
||||
if (!subAgentUpdate.latestGraph) return;
|
||||
|
||||
if (subAgentUpdate.isCompatible) {
|
||||
// Compatible update - directly apply
|
||||
const updatedValues = createUpdatedAgentNodeInputs(
|
||||
data.hardcodedValues,
|
||||
subAgentUpdate.latestGraph,
|
||||
);
|
||||
setHardcodedValues(updatedValues);
|
||||
toast({
|
||||
title: "Agent updated",
|
||||
description: `Updated to version ${subAgentUpdate.latestVersion}`,
|
||||
});
|
||||
} else {
|
||||
// Incompatible update - show dialog
|
||||
setShowIncompatibilityDialog(true);
|
||||
}
|
||||
}, [subAgentUpdate, data.hardcodedValues, setHardcodedValues]);
|
||||
|
||||
// Handle confirm incompatible update
|
||||
const handleConfirmIncompatibleUpdate = useCallback(() => {
|
||||
if (!subAgentUpdate.latestGraph || !subAgentUpdate.incompatibilities) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Create the updated values but DON'T apply them yet
|
||||
const updatedValues = createUpdatedAgentNodeInputs(
|
||||
data.hardcodedValues,
|
||||
subAgentUpdate.latestGraph,
|
||||
);
|
||||
|
||||
// Get broken edge IDs
|
||||
const brokenEdgeIds = getBrokenEdgeIDs(
|
||||
data.connections,
|
||||
subAgentUpdate.incompatibilities,
|
||||
id,
|
||||
);
|
||||
|
||||
// Enter resolution mode with pending update (don't apply schema yet)
|
||||
enterResolutionMode(
|
||||
id,
|
||||
subAgentUpdate.incompatibilities,
|
||||
brokenEdgeIds,
|
||||
updatedValues,
|
||||
);
|
||||
|
||||
setShowIncompatibilityDialog(false);
|
||||
}, [
|
||||
subAgentUpdate,
|
||||
data.hardcodedValues,
|
||||
data.connections,
|
||||
id,
|
||||
enterResolutionMode,
|
||||
]);
|
||||
const { libraryAgent, setIsAnyModalOpen, getNextNodeId } = builderContext;
|
||||
|
||||
useEffect(() => {
|
||||
if (data.executionResults || data.status) {
|
||||
@@ -344,6 +156,13 @@ export const CustomNode = React.memo(
|
||||
setIsAnyModalOpen?.(isModalOpen || isOutputModalOpen);
|
||||
}, [isModalOpen, isOutputModalOpen, data, setIsAnyModalOpen]);
|
||||
|
||||
const setHardcodedValues = useCallback(
|
||||
(values: any) => {
|
||||
updateNodeData(id, { hardcodedValues: values });
|
||||
},
|
||||
[id, updateNodeData],
|
||||
);
|
||||
|
||||
const handleTitleEdit = useCallback(() => {
|
||||
setIsEditingTitle(true);
|
||||
setTimeout(() => {
|
||||
@@ -436,7 +255,6 @@ export const CustomNode = React.memo(
|
||||
isConnected={isOutputHandleConnected(propKey)}
|
||||
schema={fieldSchema}
|
||||
side="right"
|
||||
isBroken={isOutputHandleBroken(propKey)}
|
||||
/>
|
||||
{"properties" in fieldSchema &&
|
||||
renderHandles(
|
||||
@@ -567,7 +385,6 @@ export const CustomNode = React.memo(
|
||||
isRequired={isRequired}
|
||||
schema={propSchema}
|
||||
side="left"
|
||||
isBroken={isInputHandleBroken(propKey)}
|
||||
/>
|
||||
) : (
|
||||
propKey !== "credentials" &&
|
||||
@@ -1056,22 +873,6 @@ export const CustomNode = React.memo(
|
||||
<ContextMenuContent />
|
||||
</div>
|
||||
|
||||
{/* Sub-agent Update Bar - shown below header */}
|
||||
{isAgentBlock && (subAgentUpdate.hasUpdate || isInResolutionMode) && (
|
||||
<SubAgentUpdateBar
|
||||
currentVersion={subAgentUpdate.currentVersion}
|
||||
latestVersion={subAgentUpdate.latestVersion}
|
||||
isCompatible={subAgentUpdate.isCompatible}
|
||||
incompatibilities={
|
||||
isInResolutionMode
|
||||
? resolutionMode.incompatibilities
|
||||
: subAgentUpdate.incompatibilities
|
||||
}
|
||||
onUpdate={handleUpdateClick}
|
||||
isInResolutionMode={isInResolutionMode}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Body */}
|
||||
<div className="mx-5 my-6 rounded-b-xl">
|
||||
{/* Input Handles */}
|
||||
@@ -1243,24 +1044,9 @@ export const CustomNode = React.memo(
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<ContextMenu.Root>
|
||||
<ContextMenu.Trigger>{nodeContent()}</ContextMenu.Trigger>
|
||||
</ContextMenu.Root>
|
||||
|
||||
{/* Incompatibility Dialog for sub-agent updates */}
|
||||
{isAgentBlock && subAgentUpdate.incompatibilities && (
|
||||
<IncompatibilityDialog
|
||||
isOpen={showIncompatibilityDialog}
|
||||
onClose={() => setShowIncompatibilityDialog(false)}
|
||||
onConfirm={handleConfirmIncompatibleUpdate}
|
||||
currentVersion={subAgentUpdate.currentVersion}
|
||||
latestVersion={subAgentUpdate.latestVersion}
|
||||
agentName={data.blockType || "Agent"}
|
||||
incompatibilities={subAgentUpdate.incompatibilities}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
<ContextMenu.Root>
|
||||
<ContextMenu.Trigger>{nodeContent()}</ContextMenu.Trigger>
|
||||
</ContextMenu.Root>
|
||||
);
|
||||
},
|
||||
(prevProps, nextProps) => {
|
||||
|
||||
@@ -1,244 +0,0 @@
|
||||
import React from "react";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/__legacy__/ui/dialog";
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import { AlertTriangle, XCircle, PlusCircle } from "lucide-react";
|
||||
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
|
||||
import { beautifyString } from "@/lib/utils";
|
||||
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
|
||||
|
||||
interface IncompatibilityDialogProps {
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
onConfirm: () => void;
|
||||
currentVersion: number;
|
||||
latestVersion: number;
|
||||
agentName: string;
|
||||
incompatibilities: IncompatibilityInfo;
|
||||
}
|
||||
|
||||
export const IncompatibilityDialog: React.FC<IncompatibilityDialogProps> = ({
|
||||
isOpen,
|
||||
onClose,
|
||||
onConfirm,
|
||||
currentVersion,
|
||||
latestVersion,
|
||||
agentName,
|
||||
incompatibilities,
|
||||
}) => {
|
||||
const hasMissingInputs = incompatibilities.missingInputs.length > 0;
|
||||
const hasMissingOutputs = incompatibilities.missingOutputs.length > 0;
|
||||
const hasNewInputs = incompatibilities.newInputs.length > 0;
|
||||
const hasNewOutputs = incompatibilities.newOutputs.length > 0;
|
||||
const hasNewRequired = incompatibilities.newRequiredInputs.length > 0;
|
||||
const hasTypeMismatches = incompatibilities.inputTypeMismatches.length > 0;
|
||||
|
||||
const hasInputChanges = hasMissingInputs || hasNewInputs;
|
||||
const hasOutputChanges = hasMissingOutputs || hasNewOutputs;
|
||||
|
||||
return (
|
||||
<Dialog open={isOpen} onOpenChange={(open) => !open && onClose()}>
|
||||
<DialogContent className="max-w-lg">
|
||||
<DialogHeader>
|
||||
<DialogTitle className="flex items-center gap-2">
|
||||
<AlertTriangle className="h-5 w-5 text-amber-500" />
|
||||
Incompatible Update
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
Updating <strong>{beautifyString(agentName)}</strong> from v
|
||||
{currentVersion} to v{latestVersion} will break some connections.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div className="space-y-4 py-2">
|
||||
{/* Input changes - two column layout */}
|
||||
{hasInputChanges && (
|
||||
<TwoColumnSection
|
||||
title="Input Changes"
|
||||
leftIcon={<XCircle className="h-4 w-4 text-red-500" />}
|
||||
leftTitle="Removed"
|
||||
leftItems={incompatibilities.missingInputs}
|
||||
rightIcon={<PlusCircle className="h-4 w-4 text-green-500" />}
|
||||
rightTitle="Added"
|
||||
rightItems={incompatibilities.newInputs}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Output changes - two column layout */}
|
||||
{hasOutputChanges && (
|
||||
<TwoColumnSection
|
||||
title="Output Changes"
|
||||
leftIcon={<XCircle className="h-4 w-4 text-red-500" />}
|
||||
leftTitle="Removed"
|
||||
leftItems={incompatibilities.missingOutputs}
|
||||
rightIcon={<PlusCircle className="h-4 w-4 text-green-500" />}
|
||||
rightTitle="Added"
|
||||
rightItems={incompatibilities.newOutputs}
|
||||
/>
|
||||
)}
|
||||
|
||||
{hasTypeMismatches && (
|
||||
<SingleColumnSection
|
||||
icon={<XCircle className="h-4 w-4 text-red-500" />}
|
||||
title="Type Changed"
|
||||
description="These connected inputs have a different type:"
|
||||
items={incompatibilities.inputTypeMismatches.map(
|
||||
(m) => `${m.name} (${m.oldType} → ${m.newType})`,
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{hasNewRequired && (
|
||||
<SingleColumnSection
|
||||
icon={<PlusCircle className="h-4 w-4 text-amber-500" />}
|
||||
title="New Required Inputs"
|
||||
description="These inputs are now required:"
|
||||
items={incompatibilities.newRequiredInputs}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<Alert variant="warning">
|
||||
<AlertDescription>
|
||||
If you proceed, you'll need to remove the broken connections
|
||||
before you can save or run your agent.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<DialogFooter className="gap-2 sm:gap-0">
|
||||
<Button variant="outline" onClick={onClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="destructive"
|
||||
onClick={onConfirm}
|
||||
className="bg-amber-600 hover:bg-amber-700"
|
||||
>
|
||||
Update Anyway
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
interface TwoColumnSectionProps {
|
||||
title: string;
|
||||
leftIcon: React.ReactNode;
|
||||
leftTitle: string;
|
||||
leftItems: string[];
|
||||
rightIcon: React.ReactNode;
|
||||
rightTitle: string;
|
||||
rightItems: string[];
|
||||
}
|
||||
|
||||
const TwoColumnSection: React.FC<TwoColumnSectionProps> = ({
|
||||
title,
|
||||
leftIcon,
|
||||
leftTitle,
|
||||
leftItems,
|
||||
rightIcon,
|
||||
rightTitle,
|
||||
rightItems,
|
||||
}) => (
|
||||
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
|
||||
<span className="font-medium">{title}</span>
|
||||
<div className="mt-2 grid grid-cols-2 items-start gap-4">
|
||||
{/* Left column - Breaking changes */}
|
||||
<div className="min-w-0">
|
||||
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
|
||||
{leftIcon}
|
||||
<span>{leftTitle}</span>
|
||||
</div>
|
||||
<ul className="mt-1.5 space-y-1">
|
||||
{leftItems.length > 0 ? (
|
||||
leftItems.map((item) => (
|
||||
<li
|
||||
key={item}
|
||||
className="text-sm text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<code className="rounded bg-red-50 px-1 py-0.5 font-mono text-xs text-red-700 dark:bg-red-900/30 dark:text-red-300">
|
||||
{item}
|
||||
</code>
|
||||
</li>
|
||||
))
|
||||
) : (
|
||||
<li className="text-sm italic text-gray-400 dark:text-gray-500">
|
||||
None
|
||||
</li>
|
||||
)}
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
{/* Right column - Possible solutions */}
|
||||
<div className="min-w-0">
|
||||
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
|
||||
{rightIcon}
|
||||
<span>{rightTitle}</span>
|
||||
</div>
|
||||
<ul className="mt-1.5 space-y-1">
|
||||
{rightItems.length > 0 ? (
|
||||
rightItems.map((item) => (
|
||||
<li
|
||||
key={item}
|
||||
className="text-sm text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<code className="rounded bg-green-50 px-1 py-0.5 font-mono text-xs text-green-700 dark:bg-green-900/30 dark:text-green-300">
|
||||
{item}
|
||||
</code>
|
||||
</li>
|
||||
))
|
||||
) : (
|
||||
<li className="text-sm italic text-gray-400 dark:text-gray-500">
|
||||
None
|
||||
</li>
|
||||
)}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
interface SingleColumnSectionProps {
|
||||
icon: React.ReactNode;
|
||||
title: string;
|
||||
description: string;
|
||||
items: string[];
|
||||
}
|
||||
|
||||
const SingleColumnSection: React.FC<SingleColumnSectionProps> = ({
|
||||
icon,
|
||||
title,
|
||||
description,
|
||||
items,
|
||||
}) => (
|
||||
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
|
||||
<div className="flex items-center gap-2">
|
||||
{icon}
|
||||
<span className="font-medium">{title}</span>
|
||||
</div>
|
||||
<p className="mt-1 text-sm text-gray-500 dark:text-gray-400">
|
||||
{description}
|
||||
</p>
|
||||
<ul className="mt-2 space-y-1">
|
||||
{items.map((item) => (
|
||||
<li
|
||||
key={item}
|
||||
className="ml-4 list-disc text-sm text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<code className="rounded bg-gray-100 px-1 py-0.5 font-mono text-xs dark:bg-gray-800">
|
||||
{item}
|
||||
</code>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
);
|
||||
|
||||
export default IncompatibilityDialog;
|
||||
@@ -1,130 +0,0 @@
|
||||
import React from "react";
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import { ArrowUp, AlertTriangle, Info } from "lucide-react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface SubAgentUpdateBarProps {
|
||||
currentVersion: number;
|
||||
latestVersion: number;
|
||||
isCompatible: boolean;
|
||||
incompatibilities: IncompatibilityInfo | null;
|
||||
onUpdate: () => void;
|
||||
isInResolutionMode?: boolean;
|
||||
}
|
||||
|
||||
export const SubAgentUpdateBar: React.FC<SubAgentUpdateBarProps> = ({
|
||||
currentVersion,
|
||||
latestVersion,
|
||||
isCompatible,
|
||||
incompatibilities,
|
||||
onUpdate,
|
||||
isInResolutionMode = false,
|
||||
}) => {
|
||||
if (isInResolutionMode) {
|
||||
return <ResolutionModeBar incompatibilities={incompatibilities} />;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-between gap-2 rounded-t-lg bg-blue-50 px-3 py-2 dark:bg-blue-900/30">
|
||||
<div className="flex items-center gap-2">
|
||||
<ArrowUp className="h-4 w-4 text-blue-600 dark:text-blue-400" />
|
||||
<span className="text-sm text-blue-700 dark:text-blue-300">
|
||||
Update available (v{currentVersion} → v{latestVersion})
|
||||
</span>
|
||||
{!isCompatible && (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<AlertTriangle className="h-4 w-4 text-amber-500" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className="max-w-xs">
|
||||
<p className="font-medium">Incompatible changes detected</p>
|
||||
<p className="text-xs text-gray-400">
|
||||
Click Update to see details
|
||||
</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
</div>
|
||||
<Button
|
||||
size="sm"
|
||||
variant={isCompatible ? "default" : "outline"}
|
||||
onClick={onUpdate}
|
||||
className={cn(
|
||||
"h-7 text-xs",
|
||||
!isCompatible && "border-amber-500 text-amber-600 hover:bg-amber-50",
|
||||
)}
|
||||
>
|
||||
Update
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
interface ResolutionModeBarProps {
|
||||
incompatibilities: IncompatibilityInfo | null;
|
||||
}
|
||||
|
||||
const ResolutionModeBar: React.FC<ResolutionModeBarProps> = ({
|
||||
incompatibilities,
|
||||
}) => {
|
||||
const formatIncompatibilities = () => {
|
||||
if (!incompatibilities) return "No incompatibilities";
|
||||
|
||||
const items: string[] = [];
|
||||
|
||||
if (incompatibilities.missingInputs.length > 0) {
|
||||
items.push(
|
||||
`Missing inputs: ${incompatibilities.missingInputs.join(", ")}`,
|
||||
);
|
||||
}
|
||||
if (incompatibilities.missingOutputs.length > 0) {
|
||||
items.push(
|
||||
`Missing outputs: ${incompatibilities.missingOutputs.join(", ")}`,
|
||||
);
|
||||
}
|
||||
if (incompatibilities.newRequiredInputs.length > 0) {
|
||||
items.push(
|
||||
`New required inputs: ${incompatibilities.newRequiredInputs.join(", ")}`,
|
||||
);
|
||||
}
|
||||
if (incompatibilities.inputTypeMismatches.length > 0) {
|
||||
const mismatches = incompatibilities.inputTypeMismatches
|
||||
.map((m) => `${m.name} (${m.oldType} → ${m.newType})`)
|
||||
.join(", ");
|
||||
items.push(`Type changed: ${mismatches}`);
|
||||
}
|
||||
|
||||
return items.join("\n");
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-between gap-2 rounded-t-lg bg-amber-50 px-3 py-2 dark:bg-amber-900/30">
|
||||
<div className="flex items-center gap-2">
|
||||
<AlertTriangle className="h-4 w-4 text-amber-600 dark:text-amber-400" />
|
||||
<span className="text-sm text-amber-700 dark:text-amber-300">
|
||||
Remove incompatible connections
|
||||
</span>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Info className="h-4 w-4 cursor-help text-amber-500" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className="max-w-sm whitespace-pre-line">
|
||||
<p className="font-medium">Incompatible changes:</p>
|
||||
<p className="mt-1 text-xs">{formatIncompatibilities()}</p>
|
||||
<p className="mt-2 text-xs text-gray-400">
|
||||
Delete the red connections to continue
|
||||
</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default SubAgentUpdateBar;
|
||||
@@ -26,17 +26,15 @@ import {
|
||||
applyNodeChanges,
|
||||
} from "@xyflow/react";
|
||||
import "@xyflow/react/dist/style.css";
|
||||
import { ConnectedEdge, CustomNode } from "../CustomNode/CustomNode";
|
||||
import { CustomNode } from "../CustomNode/CustomNode";
|
||||
import "./flow.css";
|
||||
import {
|
||||
BlockUIType,
|
||||
formatEdgeID,
|
||||
GraphExecutionID,
|
||||
GraphID,
|
||||
GraphMeta,
|
||||
LibraryAgent,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
|
||||
import { Key, storage } from "@/services/storage/local-storage";
|
||||
import { findNewlyAddedBlockCoordinates, getTypeColor } from "@/lib/utils";
|
||||
import { history } from "../history";
|
||||
@@ -74,30 +72,12 @@ import { FloatingSafeModeToggle } from "../../FloatingSafeModeToogle";
|
||||
// It helps to prevent spamming the history with small movements especially when pressing on a input in a block
|
||||
const MINIMUM_MOVE_BEFORE_LOG = 50;
|
||||
|
||||
export type ResolutionModeState = {
|
||||
active: boolean;
|
||||
nodeId: string | null;
|
||||
incompatibilities: IncompatibilityInfo | null;
|
||||
brokenEdgeIds: string[];
|
||||
pendingUpdate: Record<string, unknown> | null; // The hardcoded values to apply after resolution
|
||||
};
|
||||
|
||||
type BuilderContextType = {
|
||||
libraryAgent: LibraryAgent | null;
|
||||
visualizeBeads: "no" | "static" | "animate";
|
||||
setIsAnyModalOpen: (isOpen: boolean) => void;
|
||||
getNextNodeId: () => string;
|
||||
getNodeTitle: (nodeID: string) => string | null;
|
||||
availableFlows: GraphMeta[];
|
||||
resolutionMode: ResolutionModeState;
|
||||
enterResolutionMode: (
|
||||
nodeId: string,
|
||||
incompatibilities: IncompatibilityInfo,
|
||||
brokenEdgeIds: string[],
|
||||
pendingUpdate: Record<string, unknown>,
|
||||
) => void;
|
||||
exitResolutionMode: () => void;
|
||||
applyPendingUpdate: () => void;
|
||||
};
|
||||
|
||||
export type NodeDimension = {
|
||||
@@ -192,92 +172,6 @@ const FlowEditor: React.FC<{
|
||||
// It stores the dimension of all nodes with position as well
|
||||
const [nodeDimensions, setNodeDimensions] = useState<NodeDimension>({});
|
||||
|
||||
// Resolution mode state for sub-agent incompatible updates
|
||||
const [resolutionMode, setResolutionMode] = useState<ResolutionModeState>({
|
||||
active: false,
|
||||
nodeId: null,
|
||||
incompatibilities: null,
|
||||
brokenEdgeIds: [],
|
||||
pendingUpdate: null,
|
||||
});
|
||||
|
||||
const enterResolutionMode = useCallback(
|
||||
(
|
||||
nodeId: string,
|
||||
incompatibilities: IncompatibilityInfo,
|
||||
brokenEdgeIds: string[],
|
||||
pendingUpdate: Record<string, unknown>,
|
||||
) => {
|
||||
setResolutionMode({
|
||||
active: true,
|
||||
nodeId,
|
||||
incompatibilities,
|
||||
brokenEdgeIds,
|
||||
pendingUpdate,
|
||||
});
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
const exitResolutionMode = useCallback(() => {
|
||||
setResolutionMode({
|
||||
active: false,
|
||||
nodeId: null,
|
||||
incompatibilities: null,
|
||||
brokenEdgeIds: [],
|
||||
pendingUpdate: null,
|
||||
});
|
||||
}, []);
|
||||
|
||||
// Apply pending update after resolution mode completes
|
||||
const applyPendingUpdate = useCallback(() => {
|
||||
if (!resolutionMode.nodeId || !resolutionMode.pendingUpdate) return;
|
||||
|
||||
const node = nodes.find((n) => n.id === resolutionMode.nodeId);
|
||||
if (node) {
|
||||
const pendingUpdate = resolutionMode.pendingUpdate as {
|
||||
[key: string]: any;
|
||||
};
|
||||
setNodes((nds) =>
|
||||
nds.map((n) =>
|
||||
n.id === resolutionMode.nodeId
|
||||
? { ...n, data: { ...n.data, hardcodedValues: pendingUpdate } }
|
||||
: n,
|
||||
),
|
||||
);
|
||||
}
|
||||
exitResolutionMode();
|
||||
toast({
|
||||
title: "Update complete",
|
||||
description: "Agent has been updated to the new version.",
|
||||
});
|
||||
}, [resolutionMode, nodes, setNodes, exitResolutionMode, toast]);
|
||||
|
||||
// Check if all broken edges have been removed and auto-apply pending update
|
||||
useEffect(() => {
|
||||
if (!resolutionMode.active || resolutionMode.brokenEdgeIds.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const currentEdgeIds = new Set(edges.map((e) => e.id));
|
||||
const remainingBrokenEdges = resolutionMode.brokenEdgeIds.filter((id) =>
|
||||
currentEdgeIds.has(id),
|
||||
);
|
||||
|
||||
if (remainingBrokenEdges.length === 0) {
|
||||
// All broken edges have been removed, apply pending update
|
||||
applyPendingUpdate();
|
||||
} else if (
|
||||
remainingBrokenEdges.length !== resolutionMode.brokenEdgeIds.length
|
||||
) {
|
||||
// Update the list of broken edges
|
||||
setResolutionMode((prev) => ({
|
||||
...prev,
|
||||
brokenEdgeIds: remainingBrokenEdges,
|
||||
}));
|
||||
}
|
||||
}, [edges, resolutionMode, applyPendingUpdate]);
|
||||
|
||||
// Set page title with or without graph name
|
||||
useEffect(() => {
|
||||
document.title = savedAgent
|
||||
@@ -537,19 +431,17 @@ const FlowEditor: React.FC<{
|
||||
...node.data.connections.filter(
|
||||
(conn) =>
|
||||
!removedEdges.some(
|
||||
(removedEdge) => removedEdge.id === conn.id,
|
||||
(removedEdge) => removedEdge.id === conn.edge_id,
|
||||
),
|
||||
),
|
||||
// Add node connections for added edges
|
||||
...addedEdges.map(
|
||||
(addedEdge): ConnectedEdge => ({
|
||||
id: addedEdge.item.id,
|
||||
source: addedEdge.item.source,
|
||||
target: addedEdge.item.target,
|
||||
sourceHandle: addedEdge.item.sourceHandle!,
|
||||
targetHandle: addedEdge.item.targetHandle!,
|
||||
}),
|
||||
),
|
||||
...addedEdges.map((addedEdge) => ({
|
||||
edge_id: addedEdge.item.id,
|
||||
source: addedEdge.item.source,
|
||||
target: addedEdge.item.target,
|
||||
sourceHandle: addedEdge.item.sourceHandle!,
|
||||
targetHandle: addedEdge.item.targetHandle!,
|
||||
})),
|
||||
],
|
||||
},
|
||||
}));
|
||||
@@ -575,15 +467,13 @@ const FlowEditor: React.FC<{
|
||||
data: {
|
||||
...node.data,
|
||||
connections: [
|
||||
...replaceEdges.map(
|
||||
(replaceEdge): ConnectedEdge => ({
|
||||
id: replaceEdge.item.id,
|
||||
source: replaceEdge.item.source,
|
||||
target: replaceEdge.item.target,
|
||||
sourceHandle: replaceEdge.item.sourceHandle!,
|
||||
targetHandle: replaceEdge.item.targetHandle!,
|
||||
}),
|
||||
),
|
||||
...replaceEdges.map((replaceEdge) => ({
|
||||
edge_id: replaceEdge.item.id,
|
||||
source: replaceEdge.item.source,
|
||||
target: replaceEdge.item.target,
|
||||
sourceHandle: replaceEdge.item.sourceHandle!,
|
||||
targetHandle: replaceEdge.item.targetHandle!,
|
||||
})),
|
||||
],
|
||||
},
|
||||
})),
|
||||
@@ -1000,23 +890,8 @@ const FlowEditor: React.FC<{
|
||||
setIsAnyModalOpen,
|
||||
getNextNodeId,
|
||||
getNodeTitle,
|
||||
availableFlows,
|
||||
resolutionMode,
|
||||
enterResolutionMode,
|
||||
exitResolutionMode,
|
||||
applyPendingUpdate,
|
||||
}),
|
||||
[
|
||||
libraryAgent,
|
||||
visualizeBeads,
|
||||
getNextNodeId,
|
||||
getNodeTitle,
|
||||
availableFlows,
|
||||
resolutionMode,
|
||||
enterResolutionMode,
|
||||
applyPendingUpdate,
|
||||
exitResolutionMode,
|
||||
],
|
||||
[libraryAgent, visualizeBeads, getNextNodeId, getNodeTitle],
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -1116,7 +991,6 @@ const FlowEditor: React.FC<{
|
||||
onClickScheduleButton={handleScheduleButton}
|
||||
isDisabled={!savedAgent}
|
||||
isRunning={isRunning}
|
||||
resolutionModeActive={resolutionMode.active}
|
||||
/>
|
||||
) : (
|
||||
<Alert className="absolute bottom-4 left-1/2 z-20 w-auto -translate-x-1/2 select-none">
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
import { BlockIOSubSchema } from "@/lib/autogpt-server-api/types";
|
||||
import {
|
||||
cn,
|
||||
beautifyString,
|
||||
getTypeBgColor,
|
||||
getTypeTextColor,
|
||||
getEffectiveType,
|
||||
} from "@/lib/utils";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { beautifyString, getTypeBgColor, getTypeTextColor } from "@/lib/utils";
|
||||
import { FC, memo, useCallback } from "react";
|
||||
import { Handle, Position } from "@xyflow/react";
|
||||
import { InformationTooltip } from "@/components/molecules/InformationTooltip/InformationTooltip";
|
||||
@@ -18,7 +13,6 @@ type HandleProps = {
|
||||
side: "left" | "right";
|
||||
title?: string;
|
||||
className?: string;
|
||||
isBroken?: boolean;
|
||||
};
|
||||
|
||||
// Move the constant out of the component to avoid re-creation on every render.
|
||||
@@ -33,23 +27,18 @@ const TYPE_NAME: Record<string, string> = {
|
||||
};
|
||||
|
||||
// Extract and memoize the Dot component so that it doesn't re-render unnecessarily.
|
||||
const Dot: FC<{ isConnected: boolean; type?: string; isBroken?: boolean }> =
|
||||
memo(({ isConnected, type, isBroken }) => {
|
||||
const color = isBroken
|
||||
? "border-red-500 bg-red-100 dark:bg-red-900/30"
|
||||
: isConnected
|
||||
? getTypeBgColor(type || "any")
|
||||
: "border-gray-300 dark:border-gray-600";
|
||||
const Dot: FC<{ isConnected: boolean; type?: string }> = memo(
|
||||
({ isConnected, type }) => {
|
||||
const color = isConnected
|
||||
? getTypeBgColor(type || "any")
|
||||
: "border-gray-300 dark:border-gray-600";
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"m-1 h-4 w-4 rounded-full border-2 bg-white transition-colors duration-100 group-hover:bg-gray-300 dark:bg-slate-800 dark:group-hover:bg-gray-700",
|
||||
color,
|
||||
isBroken && "opacity-50",
|
||||
)}
|
||||
className={`${color} m-1 h-4 w-4 rounded-full border-2 bg-white transition-colors duration-100 group-hover:bg-gray-300 dark:bg-slate-800 dark:group-hover:bg-gray-700`}
|
||||
/>
|
||||
);
|
||||
});
|
||||
},
|
||||
);
|
||||
Dot.displayName = "Dot";
|
||||
|
||||
const NodeHandle: FC<HandleProps> = ({
|
||||
@@ -60,34 +49,24 @@ const NodeHandle: FC<HandleProps> = ({
|
||||
side,
|
||||
title,
|
||||
className,
|
||||
isBroken = false,
|
||||
}) => {
|
||||
// Extract effective type from schema (handles anyOf/oneOf/allOf wrappers)
|
||||
const effectiveType = getEffectiveType(schema);
|
||||
|
||||
const typeClass = `text-sm ${getTypeTextColor(effectiveType || "any")} ${
|
||||
const typeClass = `text-sm ${getTypeTextColor(schema.type || "any")} ${
|
||||
side === "left" ? "text-left" : "text-right"
|
||||
}`;
|
||||
|
||||
const label = (
|
||||
<div className={cn("flex flex-grow flex-row", isBroken && "opacity-50")}>
|
||||
<div className="flex flex-grow flex-row">
|
||||
<span
|
||||
className={cn(
|
||||
"data-sentry-unmask text-m green flex items-end pr-2 text-gray-900 dark:text-gray-100",
|
||||
className,
|
||||
isBroken && "text-red-500 line-through",
|
||||
)}
|
||||
>
|
||||
{title || schema.title || beautifyString(keyName.toLowerCase())}
|
||||
{isRequired ? "*" : ""}
|
||||
</span>
|
||||
<span
|
||||
className={cn(
|
||||
`${typeClass} data-sentry-unmask flex items-end`,
|
||||
isBroken && "text-red-400",
|
||||
)}
|
||||
>
|
||||
({TYPE_NAME[effectiveType as keyof typeof TYPE_NAME] || "any"})
|
||||
<span className={`${typeClass} data-sentry-unmask flex items-end`}>
|
||||
({TYPE_NAME[schema.type as keyof typeof TYPE_NAME] || "any"})
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
@@ -105,7 +84,7 @@ const NodeHandle: FC<HandleProps> = ({
|
||||
return (
|
||||
<div
|
||||
key={keyName}
|
||||
className={cn("handle-container", isBroken && "pointer-events-none")}
|
||||
className="handle-container"
|
||||
onContextMenu={handleContextMenu}
|
||||
>
|
||||
<Handle
|
||||
@@ -113,15 +92,10 @@ const NodeHandle: FC<HandleProps> = ({
|
||||
data-testid={`input-handle-${keyName}`}
|
||||
position={Position.Left}
|
||||
id={keyName}
|
||||
className={cn("group -ml-[38px]", isBroken && "cursor-not-allowed")}
|
||||
isConnectable={!isBroken}
|
||||
className="group -ml-[38px]"
|
||||
>
|
||||
<div className="pointer-events-none flex items-center">
|
||||
<Dot
|
||||
isConnected={isConnected}
|
||||
type={effectiveType}
|
||||
isBroken={isBroken}
|
||||
/>
|
||||
<Dot isConnected={isConnected} type={schema.type} />
|
||||
{label}
|
||||
</div>
|
||||
</Handle>
|
||||
@@ -132,10 +106,7 @@ const NodeHandle: FC<HandleProps> = ({
|
||||
return (
|
||||
<div
|
||||
key={keyName}
|
||||
className={cn(
|
||||
"handle-container justify-end",
|
||||
isBroken && "pointer-events-none",
|
||||
)}
|
||||
className="handle-container justify-end"
|
||||
onContextMenu={handleContextMenu}
|
||||
>
|
||||
<Handle
|
||||
@@ -143,16 +114,11 @@ const NodeHandle: FC<HandleProps> = ({
|
||||
data-testid={`output-handle-${keyName}`}
|
||||
position={Position.Right}
|
||||
id={keyName}
|
||||
className={cn("group -mr-[38px]", isBroken && "cursor-not-allowed")}
|
||||
isConnectable={!isBroken}
|
||||
className="group -mr-[38px]"
|
||||
>
|
||||
<div className="pointer-events-none flex items-center">
|
||||
{label}
|
||||
<Dot
|
||||
isConnected={isConnected}
|
||||
type={effectiveType}
|
||||
isBroken={isBroken}
|
||||
/>
|
||||
<Dot isConnected={isConnected} type={schema.type} />
|
||||
</div>
|
||||
</Handle>
|
||||
</div>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import {
|
||||
ConnectedEdge,
|
||||
ConnectionData,
|
||||
CustomNodeData,
|
||||
} from "@/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode";
|
||||
import { NodeTableInput } from "@/app/(platform)/build/components/legacy-builder/NodeTableInput";
|
||||
@@ -65,7 +65,7 @@ type NodeObjectInputTreeProps = {
|
||||
selfKey?: string;
|
||||
schema: BlockIORootSchema | BlockIOObjectSubSchema;
|
||||
object?: { [key: string]: any };
|
||||
connections: ConnectedEdge[];
|
||||
connections: ConnectionData;
|
||||
handleInputClick: (key: string) => void;
|
||||
handleInputChange: (key: string, value: any) => void;
|
||||
errors: { [key: string]: string | undefined };
|
||||
@@ -585,7 +585,7 @@ const NodeOneOfDiscriminatorField: FC<{
|
||||
currentValue?: any;
|
||||
defaultValue?: any;
|
||||
errors: { [key: string]: string | undefined };
|
||||
connections: ConnectedEdge[];
|
||||
connections: ConnectionData;
|
||||
handleInputChange: (key: string, value: any) => void;
|
||||
handleInputClick: (key: string) => void;
|
||||
className?: string;
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
import { FC, useCallback, useEffect, useState } from "react";
|
||||
|
||||
import NodeHandle from "@/app/(platform)/build/components/legacy-builder/NodeHandle";
|
||||
import type {
|
||||
import {
|
||||
BlockIOTableSubSchema,
|
||||
TableCellValue,
|
||||
TableRow,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import type { ConnectedEdge } from "./CustomNode/CustomNode";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { PlusIcon, XIcon } from "@phosphor-icons/react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import { Button } from "../../../../../components/atoms/Button/Button";
|
||||
import { Input } from "../../../../../components/atoms/Input/Input";
|
||||
|
||||
interface NodeTableInputProps {
|
||||
/** Unique identifier for the node in the builder graph */
|
||||
@@ -26,7 +25,13 @@ interface NodeTableInputProps {
|
||||
/** Validation errors mapped by field key */
|
||||
errors: { [key: string]: string | undefined };
|
||||
/** Graph connections between nodes in the builder */
|
||||
connections: ConnectedEdge[];
|
||||
connections: {
|
||||
edge_id: string;
|
||||
source: string;
|
||||
sourceHandle: string;
|
||||
target: string;
|
||||
targetHandle: string;
|
||||
}[];
|
||||
/** Callback when table data changes */
|
||||
handleInputChange: (key: string, value: TableRow[]) => void;
|
||||
/** Callback when input field is clicked (for builder selection) */
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { useCallback } from "react";
|
||||
import { Node, Edge, useReactFlow } from "@xyflow/react";
|
||||
import { Key, storage } from "@/services/storage/local-storage";
|
||||
import { ConnectedEdge } from "./CustomNode/CustomNode";
|
||||
|
||||
interface CopyableData {
|
||||
nodes: Node[];
|
||||
@@ -112,15 +111,13 @@ export function useCopyPaste(getNextNodeId: () => string) {
|
||||
(edge: Edge) =>
|
||||
edge.source === node.id || edge.target === node.id,
|
||||
)
|
||||
.map(
|
||||
(edge: Edge): ConnectedEdge => ({
|
||||
id: edge.id,
|
||||
source: edge.source,
|
||||
target: edge.target,
|
||||
sourceHandle: edge.sourceHandle!,
|
||||
targetHandle: edge.targetHandle!,
|
||||
}),
|
||||
);
|
||||
.map((edge: Edge) => ({
|
||||
edge_id: edge.id,
|
||||
source: edge.source,
|
||||
target: edge.target,
|
||||
sourceHandle: edge.sourceHandle,
|
||||
targetHandle: edge.targetHandle,
|
||||
}));
|
||||
|
||||
return {
|
||||
...node,
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
import { GraphInputSchema } from "@/lib/autogpt-server-api";
|
||||
import { GraphMetaLike, IncompatibilityInfo } from "./types";
|
||||
|
||||
// Helper type for schema properties - the generated types are too loose
|
||||
type SchemaProperties = Record<string, GraphInputSchema["properties"][string]>;
|
||||
type SchemaRequired = string[];
|
||||
|
||||
// Helper to safely extract schema properties
|
||||
export function getSchemaProperties(schema: unknown): SchemaProperties {
|
||||
if (
|
||||
schema &&
|
||||
typeof schema === "object" &&
|
||||
"properties" in schema &&
|
||||
typeof schema.properties === "object" &&
|
||||
schema.properties !== null
|
||||
) {
|
||||
return schema.properties as SchemaProperties;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
export function getSchemaRequired(schema: unknown): SchemaRequired {
|
||||
if (
|
||||
schema &&
|
||||
typeof schema === "object" &&
|
||||
"required" in schema &&
|
||||
Array.isArray(schema.required)
|
||||
) {
|
||||
return schema.required as SchemaRequired;
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates the updated agent node inputs for a sub-agent node
|
||||
*/
|
||||
export function createUpdatedAgentNodeInputs(
|
||||
currentInputs: Record<string, unknown>,
|
||||
latestSubGraphVersion: GraphMetaLike,
|
||||
): Record<string, unknown> {
|
||||
return {
|
||||
...currentInputs,
|
||||
graph_version: latestSubGraphVersion.version,
|
||||
input_schema: latestSubGraphVersion.input_schema,
|
||||
output_schema: latestSubGraphVersion.output_schema,
|
||||
};
|
||||
}
|
||||
|
||||
/** Generic edge type that works with both builders:
|
||||
* - New builder uses CustomEdge with (formally) optional handles
|
||||
* - Legacy builder uses ConnectedEdge type with required handles */
|
||||
export type EdgeLike = {
|
||||
id: string;
|
||||
source: string;
|
||||
target: string;
|
||||
sourceHandle?: string | null;
|
||||
targetHandle?: string | null;
|
||||
};
|
||||
|
||||
/**
|
||||
* Determines which edges are broken after an incompatible update.
|
||||
* Works with both legacy ConnectedEdge and new CustomEdge.
|
||||
*/
|
||||
export function getBrokenEdgeIDs(
|
||||
connections: EdgeLike[],
|
||||
incompatibilities: IncompatibilityInfo,
|
||||
nodeID: string,
|
||||
): string[] {
|
||||
const brokenEdgeIDs: string[] = [];
|
||||
const typeMismatchInputNames = new Set(
|
||||
incompatibilities.inputTypeMismatches.map((m) => m.name),
|
||||
);
|
||||
|
||||
connections.forEach((conn) => {
|
||||
// Check if this connection uses a missing input (node is target)
|
||||
if (
|
||||
conn.target === nodeID &&
|
||||
conn.targetHandle &&
|
||||
incompatibilities.missingInputs.includes(conn.targetHandle)
|
||||
) {
|
||||
brokenEdgeIDs.push(conn.id);
|
||||
}
|
||||
|
||||
// Check if this connection uses an input with a type mismatch (node is target)
|
||||
if (
|
||||
conn.target === nodeID &&
|
||||
conn.targetHandle &&
|
||||
typeMismatchInputNames.has(conn.targetHandle)
|
||||
) {
|
||||
brokenEdgeIDs.push(conn.id);
|
||||
}
|
||||
|
||||
// Check if this connection uses a missing output (node is source)
|
||||
if (
|
||||
conn.source === nodeID &&
|
||||
conn.sourceHandle &&
|
||||
incompatibilities.missingOutputs.includes(conn.sourceHandle)
|
||||
) {
|
||||
brokenEdgeIDs.push(conn.id);
|
||||
}
|
||||
});
|
||||
|
||||
return brokenEdgeIDs;
|
||||
}
|
||||
@@ -1,2 +0,0 @@
|
||||
export { useSubAgentUpdate } from "./useSubAgentUpdate";
|
||||
export { createUpdatedAgentNodeInputs, getBrokenEdgeIDs } from "./helpers";
|
||||
@@ -1,27 +0,0 @@
|
||||
import type { GraphMeta as LegacyGraphMeta } from "@/lib/autogpt-server-api";
|
||||
import type { GraphMeta as GeneratedGraphMeta } from "@/app/api/__generated__/models/graphMeta";
|
||||
|
||||
export type SubAgentUpdateInfo<T extends GraphMetaLike = GraphMetaLike> = {
|
||||
hasUpdate: boolean;
|
||||
currentVersion: number;
|
||||
latestVersion: number;
|
||||
latestGraph: T | null;
|
||||
isCompatible: boolean;
|
||||
incompatibilities: IncompatibilityInfo | null;
|
||||
};
|
||||
|
||||
// Union type for GraphMeta that works with both legacy and new builder
|
||||
export type GraphMetaLike = LegacyGraphMeta | GeneratedGraphMeta;
|
||||
|
||||
export type IncompatibilityInfo = {
|
||||
missingInputs: string[]; // Connected inputs that no longer exist
|
||||
missingOutputs: string[]; // Connected outputs that no longer exist
|
||||
newInputs: string[]; // Inputs that exist in new version but not in current
|
||||
newOutputs: string[]; // Outputs that exist in new version but not in current
|
||||
newRequiredInputs: string[]; // New required inputs not in current version or not required
|
||||
inputTypeMismatches: Array<{
|
||||
name: string;
|
||||
oldType: string;
|
||||
newType: string;
|
||||
}>; // Connected inputs where the type has changed
|
||||
};
|
||||
@@ -1,160 +0,0 @@
|
||||
import { useMemo } from "react";
|
||||
import { GraphInputSchema, GraphOutputSchema } from "@/lib/autogpt-server-api";
|
||||
import { getEffectiveType } from "@/lib/utils";
|
||||
import { EdgeLike, getSchemaProperties, getSchemaRequired } from "./helpers";
|
||||
import {
|
||||
GraphMetaLike,
|
||||
IncompatibilityInfo,
|
||||
SubAgentUpdateInfo,
|
||||
} from "./types";
|
||||
|
||||
/**
|
||||
* Checks if a newer version of a sub-agent is available and determines compatibility
|
||||
*/
|
||||
export function useSubAgentUpdate<T extends GraphMetaLike>(
|
||||
nodeID: string,
|
||||
graphID: string | undefined,
|
||||
graphVersion: number | undefined,
|
||||
currentInputSchema: GraphInputSchema | undefined,
|
||||
currentOutputSchema: GraphOutputSchema | undefined,
|
||||
connections: EdgeLike[],
|
||||
availableGraphs: T[],
|
||||
): SubAgentUpdateInfo<T> {
|
||||
// Find the latest version of the same graph
|
||||
const latestGraph = useMemo(() => {
|
||||
if (!graphID) return null;
|
||||
return availableGraphs.find((graph) => graph.id === graphID) || null;
|
||||
}, [graphID, availableGraphs]);
|
||||
|
||||
// Check if there's an update available
|
||||
const hasUpdate = useMemo(() => {
|
||||
if (!latestGraph || graphVersion === undefined) return false;
|
||||
return latestGraph.version! > graphVersion;
|
||||
}, [latestGraph, graphVersion]);
|
||||
|
||||
// Get connected input and output handles for this specific node
|
||||
const connectedHandles = useMemo(() => {
|
||||
const inputHandles = new Set<string>();
|
||||
const outputHandles = new Set<string>();
|
||||
|
||||
connections.forEach((conn) => {
|
||||
// If this node is the target, the targetHandle is an input on this node
|
||||
if (conn.target === nodeID && conn.targetHandle) {
|
||||
inputHandles.add(conn.targetHandle);
|
||||
}
|
||||
// If this node is the source, the sourceHandle is an output on this node
|
||||
if (conn.source === nodeID && conn.sourceHandle) {
|
||||
outputHandles.add(conn.sourceHandle);
|
||||
}
|
||||
});
|
||||
|
||||
return { inputHandles, outputHandles };
|
||||
}, [connections, nodeID]);
|
||||
|
||||
// Check schema compatibility
|
||||
const compatibilityResult = useMemo((): {
|
||||
isCompatible: boolean;
|
||||
incompatibilities: IncompatibilityInfo | null;
|
||||
} => {
|
||||
if (!hasUpdate || !latestGraph) {
|
||||
return { isCompatible: true, incompatibilities: null };
|
||||
}
|
||||
|
||||
const newInputProps = getSchemaProperties(latestGraph.input_schema);
|
||||
const newOutputProps = getSchemaProperties(latestGraph.output_schema);
|
||||
const newRequiredInputs = getSchemaRequired(latestGraph.input_schema);
|
||||
|
||||
const currentInputProps = getSchemaProperties(currentInputSchema);
|
||||
const currentOutputProps = getSchemaProperties(currentOutputSchema);
|
||||
const currentRequiredInputs = getSchemaRequired(currentInputSchema);
|
||||
|
||||
const incompatibilities: IncompatibilityInfo = {
|
||||
missingInputs: [],
|
||||
missingOutputs: [],
|
||||
newInputs: [],
|
||||
newOutputs: [],
|
||||
newRequiredInputs: [],
|
||||
inputTypeMismatches: [],
|
||||
};
|
||||
|
||||
// Check for missing connected inputs and type mismatches
|
||||
connectedHandles.inputHandles.forEach((inputHandle) => {
|
||||
if (!(inputHandle in newInputProps)) {
|
||||
incompatibilities.missingInputs.push(inputHandle);
|
||||
} else {
|
||||
// Check for type mismatch on connected inputs
|
||||
const currentProp = currentInputProps[inputHandle];
|
||||
const newProp = newInputProps[inputHandle];
|
||||
const currentType = getEffectiveType(currentProp);
|
||||
const newType = getEffectiveType(newProp);
|
||||
|
||||
if (currentType && newType && currentType !== newType) {
|
||||
incompatibilities.inputTypeMismatches.push({
|
||||
name: inputHandle,
|
||||
oldType: currentType,
|
||||
newType: newType,
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Check for missing connected outputs
|
||||
connectedHandles.outputHandles.forEach((outputHandle) => {
|
||||
if (!(outputHandle in newOutputProps)) {
|
||||
incompatibilities.missingOutputs.push(outputHandle);
|
||||
}
|
||||
});
|
||||
|
||||
// Check for new required inputs that didn't exist or weren't required before
|
||||
newRequiredInputs.forEach((requiredInput) => {
|
||||
const existedBefore = requiredInput in currentInputProps;
|
||||
const wasRequiredBefore = currentRequiredInputs.includes(
|
||||
requiredInput as string,
|
||||
);
|
||||
|
||||
if (!existedBefore || !wasRequiredBefore) {
|
||||
incompatibilities.newRequiredInputs.push(requiredInput as string);
|
||||
}
|
||||
});
|
||||
|
||||
// Check for new inputs that don't exist in the current version
|
||||
Object.keys(newInputProps).forEach((inputName) => {
|
||||
if (!(inputName in currentInputProps)) {
|
||||
incompatibilities.newInputs.push(inputName);
|
||||
}
|
||||
});
|
||||
|
||||
// Check for new outputs that don't exist in the current version
|
||||
Object.keys(newOutputProps).forEach((outputName) => {
|
||||
if (!(outputName in currentOutputProps)) {
|
||||
incompatibilities.newOutputs.push(outputName);
|
||||
}
|
||||
});
|
||||
|
||||
const hasIncompatibilities =
|
||||
incompatibilities.missingInputs.length > 0 ||
|
||||
incompatibilities.missingOutputs.length > 0 ||
|
||||
incompatibilities.newRequiredInputs.length > 0 ||
|
||||
incompatibilities.inputTypeMismatches.length > 0;
|
||||
|
||||
return {
|
||||
isCompatible: !hasIncompatibilities,
|
||||
incompatibilities: hasIncompatibilities ? incompatibilities : null,
|
||||
};
|
||||
}, [
|
||||
hasUpdate,
|
||||
latestGraph,
|
||||
currentInputSchema,
|
||||
currentOutputSchema,
|
||||
connectedHandles,
|
||||
]);
|
||||
|
||||
return {
|
||||
hasUpdate,
|
||||
currentVersion: graphVersion || 0,
|
||||
latestVersion: latestGraph?.version || 0,
|
||||
latestGraph,
|
||||
isCompatible: compatibilityResult.isCompatible,
|
||||
incompatibilities: compatibilityResult.incompatibilities,
|
||||
};
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
import { create } from "zustand";
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
|
||||
|
||||
interface GraphStore {
|
||||
graphExecutionStatus: AgentExecutionStatus | undefined;
|
||||
@@ -18,10 +17,6 @@ interface GraphStore {
|
||||
outputSchema: Record<string, any> | null,
|
||||
) => void;
|
||||
|
||||
// Available graphs; used for sub-graph updates
|
||||
availableSubGraphs: GraphMeta[];
|
||||
setAvailableSubGraphs: (graphs: GraphMeta[]) => void;
|
||||
|
||||
hasInputs: () => boolean;
|
||||
hasCredentials: () => boolean;
|
||||
hasOutputs: () => boolean;
|
||||
@@ -34,7 +29,6 @@ export const useGraphStore = create<GraphStore>((set, get) => ({
|
||||
inputSchema: null,
|
||||
credentialsInputSchema: null,
|
||||
outputSchema: null,
|
||||
availableSubGraphs: [],
|
||||
|
||||
setGraphExecutionStatus: (status: AgentExecutionStatus | undefined) => {
|
||||
set({
|
||||
@@ -52,8 +46,6 @@ export const useGraphStore = create<GraphStore>((set, get) => ({
|
||||
setGraphSchemas: (inputSchema, credentialsInputSchema, outputSchema) =>
|
||||
set({ inputSchema, credentialsInputSchema, outputSchema }),
|
||||
|
||||
setAvailableSubGraphs: (graphs) => set({ availableSubGraphs: graphs }),
|
||||
|
||||
hasOutputs: () => {
|
||||
const { outputSchema } = get();
|
||||
return Object.keys(outputSchema?.properties ?? {}).length > 0;
|
||||
|
||||
@@ -17,25 +17,6 @@ import {
|
||||
ensurePathExists,
|
||||
parseHandleIdToPath,
|
||||
} from "@/components/renderers/InputRenderer/helpers";
|
||||
import { IncompatibilityInfo } from "../hooks/useSubAgentUpdate/types";
|
||||
|
||||
// Resolution mode data stored per node
|
||||
export type NodeResolutionData = {
|
||||
incompatibilities: IncompatibilityInfo;
|
||||
// The NEW schema from the update (what we're updating TO)
|
||||
pendingUpdate: {
|
||||
input_schema: Record<string, unknown>;
|
||||
output_schema: Record<string, unknown>;
|
||||
};
|
||||
// The OLD schema before the update (what we're updating FROM)
|
||||
// Needed to merge and show removed inputs during resolution
|
||||
currentSchema: {
|
||||
input_schema: Record<string, unknown>;
|
||||
output_schema: Record<string, unknown>;
|
||||
};
|
||||
// The full updated hardcoded values to apply when resolution completes
|
||||
pendingHardcodedValues: Record<string, unknown>;
|
||||
};
|
||||
|
||||
// Minimum movement (in pixels) required before logging position change to history
|
||||
// Prevents spamming history with small movements when clicking on inputs inside blocks
|
||||
@@ -84,32 +65,12 @@ type NodeStore = {
|
||||
backendId: string,
|
||||
errors: { [key: string]: string },
|
||||
) => void;
|
||||
clearAllNodeErrors: () => void; // Add this
|
||||
|
||||
syncHardcodedValuesWithHandleIds: (nodeId: string) => void;
|
||||
|
||||
// Credentials optional helpers
|
||||
setCredentialsOptional: (nodeId: string, optional: boolean) => void;
|
||||
clearAllNodeErrors: () => void;
|
||||
|
||||
nodesInResolutionMode: Set<string>;
|
||||
brokenEdgeIDs: Map<string, Set<string>>;
|
||||
nodeResolutionData: Map<string, NodeResolutionData>;
|
||||
setNodeResolutionMode: (
|
||||
nodeID: string,
|
||||
inResolution: boolean,
|
||||
resolutionData?: NodeResolutionData,
|
||||
) => void;
|
||||
isNodeInResolutionMode: (nodeID: string) => boolean;
|
||||
getNodeResolutionData: (nodeID: string) => NodeResolutionData | undefined;
|
||||
setBrokenEdgeIDs: (nodeID: string, edgeIDs: string[]) => void;
|
||||
removeBrokenEdgeID: (nodeID: string, edgeID: string) => void;
|
||||
isEdgeBroken: (edgeID: string) => boolean;
|
||||
clearResolutionState: () => void;
|
||||
|
||||
isInputBroken: (nodeID: string, handleID: string) => boolean;
|
||||
getInputTypeMismatch: (
|
||||
nodeID: string,
|
||||
handleID: string,
|
||||
) => string | undefined;
|
||||
};
|
||||
|
||||
export const useNodeStore = create<NodeStore>((set, get) => ({
|
||||
@@ -413,99 +374,4 @@ export const useNodeStore = create<NodeStore>((set, get) => ({
|
||||
|
||||
useHistoryStore.getState().pushState(newState);
|
||||
},
|
||||
|
||||
// Sub-agent resolution mode state
|
||||
nodesInResolutionMode: new Set<string>(),
|
||||
brokenEdgeIDs: new Map<string, Set<string>>(),
|
||||
nodeResolutionData: new Map<string, NodeResolutionData>(),
|
||||
|
||||
setNodeResolutionMode: (
|
||||
nodeID: string,
|
||||
inResolution: boolean,
|
||||
resolutionData?: NodeResolutionData,
|
||||
) => {
|
||||
set((state) => {
|
||||
const newNodesSet = new Set(state.nodesInResolutionMode);
|
||||
const newResolutionDataMap = new Map(state.nodeResolutionData);
|
||||
const newBrokenEdgeIDs = new Map(state.brokenEdgeIDs);
|
||||
|
||||
if (inResolution) {
|
||||
newNodesSet.add(nodeID);
|
||||
if (resolutionData) {
|
||||
newResolutionDataMap.set(nodeID, resolutionData);
|
||||
}
|
||||
} else {
|
||||
newNodesSet.delete(nodeID);
|
||||
newResolutionDataMap.delete(nodeID);
|
||||
newBrokenEdgeIDs.delete(nodeID); // Clean up broken edges when exiting resolution mode
|
||||
}
|
||||
|
||||
return {
|
||||
nodesInResolutionMode: newNodesSet,
|
||||
nodeResolutionData: newResolutionDataMap,
|
||||
brokenEdgeIDs: newBrokenEdgeIDs,
|
||||
};
|
||||
});
|
||||
},
|
||||
|
||||
isNodeInResolutionMode: (nodeID: string) => {
|
||||
return get().nodesInResolutionMode.has(nodeID);
|
||||
},
|
||||
|
||||
getNodeResolutionData: (nodeID: string) => {
|
||||
return get().nodeResolutionData.get(nodeID);
|
||||
},
|
||||
|
||||
setBrokenEdgeIDs: (nodeID: string, edgeIDs: string[]) => {
|
||||
set((state) => {
|
||||
const newMap = new Map(state.brokenEdgeIDs);
|
||||
newMap.set(nodeID, new Set(edgeIDs));
|
||||
return { brokenEdgeIDs: newMap };
|
||||
});
|
||||
},
|
||||
|
||||
removeBrokenEdgeID: (nodeID: string, edgeID: string) => {
|
||||
set((state) => {
|
||||
const newMap = new Map(state.brokenEdgeIDs);
|
||||
const nodeSet = new Set(newMap.get(nodeID) || []);
|
||||
nodeSet.delete(edgeID);
|
||||
newMap.set(nodeID, nodeSet);
|
||||
return { brokenEdgeIDs: newMap };
|
||||
});
|
||||
},
|
||||
|
||||
isEdgeBroken: (edgeID: string) => {
|
||||
// Check across all nodes
|
||||
const brokenEdgeIDs = get().brokenEdgeIDs;
|
||||
for (const edgeSet of brokenEdgeIDs.values()) {
|
||||
if (edgeSet.has(edgeID)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
},
|
||||
|
||||
clearResolutionState: () => {
|
||||
set({
|
||||
nodesInResolutionMode: new Set<string>(),
|
||||
brokenEdgeIDs: new Map<string, Set<string>>(),
|
||||
nodeResolutionData: new Map<string, NodeResolutionData>(),
|
||||
});
|
||||
},
|
||||
|
||||
// Helper functions for input renderers
|
||||
isInputBroken: (nodeID: string, handleID: string) => {
|
||||
const resolutionData = get().nodeResolutionData.get(nodeID);
|
||||
if (!resolutionData) return false;
|
||||
return resolutionData.incompatibilities.missingInputs.includes(handleID);
|
||||
},
|
||||
|
||||
getInputTypeMismatch: (nodeID: string, handleID: string) => {
|
||||
const resolutionData = get().nodeResolutionData.get(nodeID);
|
||||
if (!resolutionData) return undefined;
|
||||
const mismatch = resolutionData.incompatibilities.inputTypeMismatches.find(
|
||||
(m) => m.name === handleID,
|
||||
);
|
||||
return mismatch?.newType;
|
||||
},
|
||||
}));
|
||||
|
||||
@@ -18,7 +18,7 @@ function ErrorPageContent() {
|
||||
) {
|
||||
window.location.href = "/login";
|
||||
} else {
|
||||
window.document.location.reload();
|
||||
window.location.href = "/marketplace";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ export const extendedButtonVariants = cva(
|
||||
),
|
||||
},
|
||||
size: {
|
||||
small: "px-3 py-2 text-sm gap-1.5 h-[2.25rem] min-w-[5.5rem]",
|
||||
small: "px-3 py-2 text-sm gap-1.5 h-[2.25rem]",
|
||||
large: "px-4 py-3 text-sm gap-2 h-[3.25rem]",
|
||||
icon: "p-3 !min-w-0",
|
||||
},
|
||||
|
||||
@@ -30,6 +30,8 @@ export const FormRenderer = ({
|
||||
return generateUiSchemaForCustomFields(preprocessedSchema, uiSchema);
|
||||
}, [preprocessedSchema, uiSchema]);
|
||||
|
||||
console.log("preprocessedSchema", preprocessedSchema);
|
||||
|
||||
return (
|
||||
<div className={"mb-6 mt-4"}>
|
||||
<Form
|
||||
|
||||
@@ -2,12 +2,10 @@ import { FieldProps, getUiOptions, getWidget } from "@rjsf/utils";
|
||||
import { AnyOfFieldTitle } from "./components/AnyOfFieldTitle";
|
||||
import { isEmpty } from "lodash";
|
||||
import { useAnyOfField } from "./useAnyOfField";
|
||||
import { cleanUpHandleId, getHandleId, updateUiOption } from "../../helpers";
|
||||
import { getHandleId, updateUiOption } from "../../helpers";
|
||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||
import { ANY_OF_FLAG } from "../../constants";
|
||||
import { findCustomFieldId } from "../../registry";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export const AnyOfField = (props: FieldProps) => {
|
||||
const { registry, schema } = props;
|
||||
@@ -23,8 +21,6 @@ export const AnyOfField = (props: FieldProps) => {
|
||||
field_id,
|
||||
} = useAnyOfField(props);
|
||||
|
||||
const isInputBroken = useNodeStore((state) => state.isInputBroken);
|
||||
|
||||
const parentCustomFieldId = findCustomFieldId(schema);
|
||||
if (parentCustomFieldId) {
|
||||
return null;
|
||||
@@ -47,7 +43,6 @@ export const AnyOfField = (props: FieldProps) => {
|
||||
});
|
||||
|
||||
const isHandleConnected = isInputConnected(nodeId, handleId);
|
||||
const isAnyOfInputBroken = isInputBroken(nodeId, cleanUpHandleId(handleId));
|
||||
|
||||
// Now anyOf can render - custom fields if the option schema matches a custom field
|
||||
const optionCustomFieldId = optionSchema
|
||||
@@ -83,11 +78,7 @@ export const AnyOfField = (props: FieldProps) => {
|
||||
registry={registry}
|
||||
placeholder={props.placeholder}
|
||||
autocomplete={props.autocomplete}
|
||||
className={cn(
|
||||
"-ml-1 h-[22px] w-fit gap-1 px-1 pl-2 text-xs font-medium",
|
||||
isAnyOfInputBroken &&
|
||||
"border-red-500 bg-red-100 text-red-600 line-through",
|
||||
)}
|
||||
className="-ml-1 h-[22px] w-fit gap-1 px-1 pl-2 text-xs font-medium"
|
||||
autofocus={props.autofocus}
|
||||
label=""
|
||||
hideLabel={true}
|
||||
@@ -102,7 +93,7 @@ export const AnyOfField = (props: FieldProps) => {
|
||||
selector={selector}
|
||||
uiSchema={updatedUiSchema}
|
||||
/>
|
||||
{!isHandleConnected && !isAnyOfInputBroken && optionsSchemaField}
|
||||
{!isHandleConnected && optionsSchemaField}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -13,7 +13,6 @@ import { Text } from "@/components/atoms/Text/Text";
|
||||
import { isOptionalType } from "../../../utils/schema-utils";
|
||||
import { getTypeDisplayInfo } from "@/app/(platform)/build/components/FlowEditor/nodes/helpers";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
|
||||
interface customFieldProps extends FieldProps {
|
||||
selector: JSX.Element;
|
||||
@@ -52,13 +51,6 @@ export const AnyOfFieldTitle = (props: customFieldProps) => {
|
||||
shouldShowTypeSelector(schema) && !isArrayItem && !isHandleConnected;
|
||||
const shoudlShowType = isHandleConnected || (isOptional && type);
|
||||
|
||||
const isInputBroken = useNodeStore((state) =>
|
||||
state.isInputBroken(nodeId, cleanUpHandleId(uiOptions.handleId)),
|
||||
);
|
||||
const inputMismatch = useNodeStore((state) =>
|
||||
state.getInputTypeMismatch(nodeId, cleanUpHandleId(uiOptions.handleId)),
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="flex items-center gap-2">
|
||||
<TitleFieldTemplate
|
||||
@@ -70,16 +62,8 @@ export const AnyOfFieldTitle = (props: customFieldProps) => {
|
||||
uiSchema={uiSchema}
|
||||
/>
|
||||
{shoudlShowType && (
|
||||
<Text
|
||||
variant="small"
|
||||
className={cn(
|
||||
"text-zinc-700",
|
||||
isInputBroken && "line-through",
|
||||
colorClass,
|
||||
inputMismatch && "rounded-md bg-red-100 px-1 !text-red-500",
|
||||
)}
|
||||
>
|
||||
{isOptional ? `(${inputMismatch || displayType})` : "(any)"}
|
||||
<Text variant="small" className={cn("text-zinc-700", colorClass)}>
|
||||
{isOptional ? `(${displayType})` : "(any)"}
|
||||
</Text>
|
||||
)}
|
||||
{shouldShowSelector && selector}
|
||||
|
||||
@@ -9,9 +9,8 @@ import { Text } from "@/components/atoms/Text/Text";
|
||||
import { getTypeDisplayInfo } from "@/app/(platform)/build/components/FlowEditor/nodes/helpers";
|
||||
import { isAnyOfSchema } from "../../utils/schema-utils";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { cleanUpHandleId, isArrayItem } from "../../helpers";
|
||||
import { isArrayItem } from "../../helpers";
|
||||
import { InputNodeHandle } from "@/app/(platform)/build/components/FlowEditor/handlers/NodeHandle";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
|
||||
export default function TitleField(props: TitleFieldProps) {
|
||||
const { id, title, required, schema, registry, uiSchema } = props;
|
||||
@@ -27,11 +26,6 @@ export default function TitleField(props: TitleFieldProps) {
|
||||
const smallText = isArrayItemFlag || additional;
|
||||
|
||||
const showHandle = uiOptions.showHandles ?? showHandles;
|
||||
|
||||
const isInputBroken = useNodeStore((state) =>
|
||||
state.isInputBroken(nodeId, cleanUpHandleId(uiOptions.handleId)),
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="flex items-center">
|
||||
{showHandle !== false && (
|
||||
@@ -40,11 +34,7 @@ export default function TitleField(props: TitleFieldProps) {
|
||||
<Text
|
||||
variant={isArrayItemFlag ? "small" : "body"}
|
||||
id={id}
|
||||
className={cn(
|
||||
"line-clamp-1",
|
||||
smallText && "text-sm text-zinc-700",
|
||||
isInputBroken && "text-red-500 line-through",
|
||||
)}
|
||||
className={cn("line-clamp-1", smallText && "text-sm text-zinc-700")}
|
||||
>
|
||||
{title}
|
||||
</Text>
|
||||
@@ -54,7 +44,7 @@ export default function TitleField(props: TitleFieldProps) {
|
||||
{!isAnyOf && (
|
||||
<Text
|
||||
variant="small"
|
||||
className={cn("ml-2", isInputBroken && "line-through", colorClass)}
|
||||
className={cn("ml-2", colorClass)}
|
||||
id={description_id}
|
||||
>
|
||||
({displayType})
|
||||
|
||||
@@ -30,8 +30,6 @@ export function updateUiOption<T extends Record<string, any>>(
|
||||
}
|
||||
|
||||
export const cleanUpHandleId = (handleId: string) => {
|
||||
if (!handleId) return "";
|
||||
|
||||
let newHandleId = handleId;
|
||||
if (handleId.includes(ANY_OF_FLAG)) {
|
||||
newHandleId = newHandleId.replace(ANY_OF_FLAG, "");
|
||||
|
||||
@@ -233,14 +233,13 @@ export default function useAgentGraph(
|
||||
title: `${block.name} ${node.id}`,
|
||||
inputSchema: block.inputSchema,
|
||||
outputSchema: block.outputSchema,
|
||||
isOutputStatic: block.staticOutput,
|
||||
hardcodedValues: node.input_default,
|
||||
uiType: block.uiType,
|
||||
metadata: metadata,
|
||||
connections: graph.links
|
||||
.filter((l) => [l.source_id, l.sink_id].includes(node.id))
|
||||
.map((link) => ({
|
||||
id: formatEdgeID(link),
|
||||
edge_id: formatEdgeID(link),
|
||||
source: link.source_id,
|
||||
sourceHandle: link.source_name,
|
||||
target: link.sink_id,
|
||||
|
||||
@@ -245,8 +245,8 @@ export type BlockIONullSubSchema = BlockIOSubSchemaMeta & {
|
||||
// At the time of writing, combined schemas only occur on the first nested level in a
|
||||
// block schema. It is typed this way to make the use of these objects less tedious.
|
||||
type BlockIOCombinedTypeSubSchema = BlockIOSubSchemaMeta & {
|
||||
type?: never;
|
||||
const?: never;
|
||||
type: never;
|
||||
const: never;
|
||||
} & (
|
||||
| {
|
||||
allOf: [BlockIOSimpleTypeSubSchema];
|
||||
@@ -368,8 +368,8 @@ export type GraphMeta = {
|
||||
recommended_schedule_cron: string | null;
|
||||
forked_from_id?: GraphID | null;
|
||||
forked_from_version?: number | null;
|
||||
input_schema: GraphInputSchema;
|
||||
output_schema: GraphOutputSchema;
|
||||
input_schema: GraphIOSchema;
|
||||
output_schema: GraphIOSchema;
|
||||
credentials_input_schema: CredentialsInputSchema;
|
||||
} & (
|
||||
| {
|
||||
@@ -385,51 +385,19 @@ export type GraphMeta = {
|
||||
export type GraphID = Brand<string, "GraphID">;
|
||||
|
||||
/* Derived from backend/data/graph.py:Graph._generate_schema() */
|
||||
export type GraphInputSchema = {
|
||||
export type GraphIOSchema = {
|
||||
type: "object";
|
||||
properties: Record<string, GraphInputSubSchema>;
|
||||
required: (keyof GraphInputSchema["properties"])[];
|
||||
properties: Record<string, GraphIOSubSchema>;
|
||||
required: (keyof BlockIORootSchema["properties"])[];
|
||||
};
|
||||
export type GraphInputSubSchema = GraphOutputSubSchema &
|
||||
(
|
||||
| { type?: never; default: any | null } // AgentInputBlock (generic Any type)
|
||||
| { type: "string"; format: "short-text"; default: string | null } // AgentShortTextInputBlock
|
||||
| { type: "string"; format: "long-text"; default: string | null } // AgentLongTextInputBlock
|
||||
| { type: "integer"; default: number | null } // AgentNumberInputBlock
|
||||
| { type: "string"; format: "date"; default: string | null } // AgentDateInputBlock
|
||||
| { type: "string"; format: "time"; default: string | null } // AgentTimeInputBlock
|
||||
| { type: "string"; format: "file"; default: string | null } // AgentFileInputBlock
|
||||
| { type: "string"; enum: string[]; default: string | null } // AgentDropdownInputBlock
|
||||
| { type: "boolean"; default: boolean } // AgentToggleInputBlock
|
||||
| {
|
||||
// AgentTableInputBlock
|
||||
type: "array";
|
||||
format: "table";
|
||||
items: {
|
||||
type: "object";
|
||||
properties: Record<string, { type: "string" }>;
|
||||
};
|
||||
default: Array<Record<string, string>> | null;
|
||||
}
|
||||
| {
|
||||
// AgentGoogleDriveFileInputBlock
|
||||
type: "object";
|
||||
format: "google-drive-picker";
|
||||
google_drive_picker_config?: GoogleDrivePickerConfig;
|
||||
default: GoogleDriveFile | null;
|
||||
}
|
||||
);
|
||||
export type GraphOutputSchema = {
|
||||
type: "object";
|
||||
properties: Record<string, GraphOutputSubSchema>;
|
||||
required: (keyof GraphOutputSchema["properties"])[];
|
||||
};
|
||||
export type GraphOutputSubSchema = {
|
||||
// TODO: typed outputs based on the incoming edges?
|
||||
title: string;
|
||||
description?: string;
|
||||
advanced: boolean;
|
||||
export type GraphIOSubSchema = Omit<
|
||||
BlockIOSubSchemaMeta,
|
||||
"placeholder" | "depends_on" | "hidden"
|
||||
> & {
|
||||
type: never; // bodge to avoid type checking hell; doesn't exist at runtime
|
||||
default?: string;
|
||||
secret: boolean;
|
||||
metadata?: any;
|
||||
};
|
||||
|
||||
export type CredentialsInputSchema = {
|
||||
@@ -472,8 +440,8 @@ export type GraphUpdateable = Omit<
|
||||
is_active?: boolean;
|
||||
nodes: NodeCreatable[];
|
||||
links: LinkCreatable[];
|
||||
input_schema?: GraphInputSchema;
|
||||
output_schema?: GraphOutputSchema;
|
||||
input_schema?: GraphIOSchema;
|
||||
output_schema?: GraphIOSchema;
|
||||
};
|
||||
|
||||
export type GraphCreatable = _GraphCreatableInner & {
|
||||
@@ -529,8 +497,8 @@ export type LibraryAgent = {
|
||||
name: string;
|
||||
description: string;
|
||||
instructions?: string | null;
|
||||
input_schema: GraphInputSchema;
|
||||
output_schema: GraphOutputSchema;
|
||||
input_schema: GraphIOSchema;
|
||||
output_schema: GraphIOSchema;
|
||||
credentials_input_schema: CredentialsInputSchema;
|
||||
new_output: boolean;
|
||||
can_access_graph: boolean;
|
||||
|
||||
@@ -6,10 +6,7 @@ import { NodeDimension } from "@/app/(platform)/build/components/legacy-builder/
|
||||
import {
|
||||
BlockIOObjectSubSchema,
|
||||
BlockIORootSchema,
|
||||
BlockIOSubSchema,
|
||||
Category,
|
||||
GraphInputSubSchema,
|
||||
GraphOutputSubSchema,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
|
||||
export function cn(...inputs: ClassValue[]) {
|
||||
@@ -79,8 +76,8 @@ export function getTypeBgColor(type: string | null): string {
|
||||
);
|
||||
}
|
||||
|
||||
export function getTypeColor(type: string | undefined): string {
|
||||
if (!type) return "#6b7280";
|
||||
export function getTypeColor(type: string | null): string {
|
||||
if (type === null) return "#6b7280";
|
||||
return (
|
||||
{
|
||||
string: "#22c55e",
|
||||
@@ -91,59 +88,11 @@ export function getTypeColor(type: string | undefined): string {
|
||||
array: "#6366f1",
|
||||
null: "#6b7280",
|
||||
any: "#6b7280",
|
||||
"": "#6b7280",
|
||||
}[type] || "#6b7280"
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts the effective type from a JSON schema, handling anyOf/oneOf/allOf wrappers.
|
||||
* Returns the first non-null type found in the schema structure.
|
||||
*/
|
||||
export function getEffectiveType(
|
||||
schema:
|
||||
| BlockIOSubSchema
|
||||
| GraphInputSubSchema
|
||||
| GraphOutputSubSchema
|
||||
| null
|
||||
| undefined,
|
||||
): string | undefined {
|
||||
if (!schema) return undefined;
|
||||
|
||||
// Direct type property
|
||||
if ("type" in schema && schema.type) {
|
||||
return String(schema.type);
|
||||
}
|
||||
|
||||
// Handle allOf - typically a single-item wrapper
|
||||
if (
|
||||
"allOf" in schema &&
|
||||
Array.isArray(schema.allOf) &&
|
||||
schema.allOf.length > 0
|
||||
) {
|
||||
return getEffectiveType(schema.allOf[0]);
|
||||
}
|
||||
|
||||
// Handle anyOf - e.g. [{ type: "string" }, { type: "null" }]
|
||||
if ("anyOf" in schema && Array.isArray(schema.anyOf)) {
|
||||
for (const item of schema.anyOf) {
|
||||
if ("type" in item && item.type !== "null") {
|
||||
return String(item.type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle oneOf
|
||||
if ("oneOf" in schema && Array.isArray(schema.oneOf)) {
|
||||
for (const item of schema.oneOf) {
|
||||
if ("type" in item && item.type !== "null") {
|
||||
return String(item.type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
export function beautifyString(name: string): string {
|
||||
// Regular expression to identify places to split, considering acronyms
|
||||
const result = name
|
||||
|
||||
Reference in New Issue
Block a user