mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-15 01:58:23 -05:00
Compare commits
3 Commits
fix/creden
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5ac941fe2f | ||
|
|
b01ea3fcbd | ||
|
|
3b09a94e3f |
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -176,7 +176,7 @@ jobs:
|
|||||||
}
|
}
|
||||||
|
|
||||||
- name: Run Database Migrations
|
- name: Run Database Migrations
|
||||||
run: poetry run prisma migrate dev --name updates
|
run: poetry run prisma migrate deploy
|
||||||
env:
|
env:
|
||||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||||
|
|||||||
9
.github/workflows/platform-frontend-ci.yml
vendored
9
.github/workflows/platform-frontend-ci.yml
vendored
@@ -11,6 +11,7 @@ on:
|
|||||||
- ".github/workflows/platform-frontend-ci.yml"
|
- ".github/workflows/platform-frontend-ci.yml"
|
||||||
- "autogpt_platform/frontend/**"
|
- "autogpt_platform/frontend/**"
|
||||||
merge_group:
|
merge_group:
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.event_name == 'merge_group' && format('merge-queue-{0}', github.ref) || format('{0}-{1}', github.ref, github.event.pull_request.number || github.sha) }}
|
group: ${{ github.workflow }}-${{ github.event_name == 'merge_group' && format('merge-queue-{0}', github.ref) || format('{0}-{1}', github.ref, github.event.pull_request.number || github.sha) }}
|
||||||
@@ -151,6 +152,14 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
cp ../.env.default ../.env
|
cp ../.env.default ../.env
|
||||||
|
|
||||||
|
- name: Copy backend .env and set OpenAI API key
|
||||||
|
run: |
|
||||||
|
cp ../backend/.env.default ../backend/.env
|
||||||
|
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||||
|
env:
|
||||||
|
# Used by E2E test data script to generate embeddings for approved store agents
|
||||||
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
|||||||
1
autogpt_platform/backend/.gitignore
vendored
1
autogpt_platform/backend/.gitignore
vendored
@@ -18,3 +18,4 @@ load-tests/results/
|
|||||||
load-tests/*.json
|
load-tests/*.json
|
||||||
load-tests/*.log
|
load-tests/*.log
|
||||||
load-tests/node_modules/*
|
load-tests/node_modules/*
|
||||||
|
migrations/*/rollback*.sql
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
import pytest
|
import pytest
|
||||||
@@ -17,6 +18,17 @@ setup_test_data = setup_test_data
|
|||||||
setup_firecrawl_test_data = setup_firecrawl_test_data
|
setup_firecrawl_test_data = setup_firecrawl_test_data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def mock_embedding_functions():
|
||||||
|
"""Mock embedding functions for all tests to avoid database/API dependencies."""
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.db.ensure_embedding",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=True,
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="session")
|
@pytest.mark.asyncio(scope="session")
|
||||||
async def test_run_agent(setup_test_data):
|
async def test_run_agent(setup_test_data):
|
||||||
"""Test that the run_agent tool successfully executes an approved agent"""
|
"""Test that the run_agent tool successfully executes an approved agent"""
|
||||||
|
|||||||
@@ -35,11 +35,7 @@ from backend.data.model import (
|
|||||||
OAuth2Credentials,
|
OAuth2Credentials,
|
||||||
UserIntegrations,
|
UserIntegrations,
|
||||||
)
|
)
|
||||||
from backend.data.onboarding import (
|
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
||||||
OnboardingStep,
|
|
||||||
complete_onboarding_step,
|
|
||||||
increment_runs,
|
|
||||||
)
|
|
||||||
from backend.data.user import get_user_integrations
|
from backend.data.user import get_user_integrations
|
||||||
from backend.executor.utils import add_graph_execution
|
from backend.executor.utils import add_graph_execution
|
||||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||||
@@ -378,7 +374,6 @@ async def webhook_ingress_generic(
|
|||||||
return
|
return
|
||||||
|
|
||||||
await complete_onboarding_step(user_id, OnboardingStep.TRIGGER_WEBHOOK)
|
await complete_onboarding_step(user_id, OnboardingStep.TRIGGER_WEBHOOK)
|
||||||
await increment_runs(user_id)
|
|
||||||
|
|
||||||
# Execute all triggers concurrently for better performance
|
# Execute all triggers concurrently for better performance
|
||||||
tasks = []
|
tasks = []
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from backend.data.execution import GraphExecutionMeta
|
|||||||
from backend.data.graph import get_graph
|
from backend.data.graph import get_graph
|
||||||
from backend.data.integrations import get_webhook
|
from backend.data.integrations import get_webhook
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
from backend.data.onboarding import increment_runs
|
|
||||||
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
|
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.webhooks import get_webhook_manager
|
from backend.integrations.webhooks import get_webhook_manager
|
||||||
@@ -403,8 +402,6 @@ async def execute_preset(
|
|||||||
merged_node_input = preset.inputs | inputs
|
merged_node_input = preset.inputs | inputs
|
||||||
merged_credential_inputs = preset.credentials | credential_inputs
|
merged_credential_inputs = preset.credentials | credential_inputs
|
||||||
|
|
||||||
await increment_runs(user_id)
|
|
||||||
|
|
||||||
return await add_graph_execution(
|
return await add_graph_execution(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
graph_id=preset.graph_id,
|
graph_id=preset.graph_id,
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import typing
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import prisma.enums
|
import prisma.enums
|
||||||
@@ -10,7 +9,7 @@ import prisma.errors
|
|||||||
import prisma.models
|
import prisma.models
|
||||||
import prisma.types
|
import prisma.types
|
||||||
|
|
||||||
from backend.data.db import query_raw_with_schema, transaction
|
from backend.data.db import transaction
|
||||||
from backend.data.graph import (
|
from backend.data.graph import (
|
||||||
GraphMeta,
|
GraphMeta,
|
||||||
GraphModel,
|
GraphModel,
|
||||||
@@ -30,6 +29,8 @@ from backend.util.settings import Settings
|
|||||||
|
|
||||||
from . import exceptions as store_exceptions
|
from . import exceptions as store_exceptions
|
||||||
from . import model as store_model
|
from . import model as store_model
|
||||||
|
from .embeddings import ensure_embedding
|
||||||
|
from .hybrid_search import hybrid_search
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
@@ -50,128 +51,77 @@ async def get_store_agents(
|
|||||||
page_size: int = 20,
|
page_size: int = 20,
|
||||||
) -> store_model.StoreAgentsResponse:
|
) -> store_model.StoreAgentsResponse:
|
||||||
"""
|
"""
|
||||||
Get PUBLIC store agents from the StoreAgent view
|
Get PUBLIC store agents from the StoreAgent view.
|
||||||
|
|
||||||
|
Search behavior:
|
||||||
|
- With search_query: Uses hybrid search (semantic + lexical)
|
||||||
|
- Fallback: If embeddings unavailable, gracefully degrades to lexical-only
|
||||||
|
- Rationale: User-facing endpoint prioritizes availability over accuracy
|
||||||
|
|
||||||
|
Note: Admin operations (approval) use fail-fast to prevent inconsistent state.
|
||||||
"""
|
"""
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
search_used_hybrid = False
|
||||||
|
store_agents: list[store_model.StoreAgent] = []
|
||||||
|
agents: list[dict[str, Any]] = []
|
||||||
|
total = 0
|
||||||
|
total_pages = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# If search_query is provided, use full-text search
|
# If search_query is provided, use hybrid search (embeddings + tsvector)
|
||||||
if search_query:
|
if search_query:
|
||||||
offset = (page - 1) * page_size
|
# Try hybrid search combining semantic and lexical signals
|
||||||
|
# Falls back to lexical-only if OpenAI unavailable (user-facing, high SLA)
|
||||||
|
try:
|
||||||
|
agents, total = await hybrid_search(
|
||||||
|
query=search_query,
|
||||||
|
featured=featured,
|
||||||
|
creators=creators,
|
||||||
|
category=category,
|
||||||
|
sorted_by="relevance", # Use hybrid scoring for relevance
|
||||||
|
page=page,
|
||||||
|
page_size=page_size,
|
||||||
|
)
|
||||||
|
search_used_hybrid = True
|
||||||
|
except Exception as e:
|
||||||
|
# Log error but fall back to lexical search for better UX
|
||||||
|
logger.error(
|
||||||
|
f"Hybrid search failed (likely OpenAI unavailable), "
|
||||||
|
f"falling back to lexical search: {e}"
|
||||||
|
)
|
||||||
|
# search_used_hybrid remains False, will use fallback path below
|
||||||
|
|
||||||
# Whitelist allowed order_by columns
|
# Convert hybrid search results (dict format) if hybrid succeeded
|
||||||
ALLOWED_ORDER_BY = {
|
if search_used_hybrid:
|
||||||
"rating": "rating DESC, rank DESC",
|
total_pages = (total + page_size - 1) // page_size
|
||||||
"runs": "runs DESC, rank DESC",
|
store_agents: list[store_model.StoreAgent] = []
|
||||||
"name": "agent_name ASC, rank ASC",
|
for agent in agents:
|
||||||
"updated_at": "updated_at DESC, rank DESC",
|
try:
|
||||||
}
|
store_agent = store_model.StoreAgent(
|
||||||
|
slug=agent["slug"],
|
||||||
|
agent_name=agent["agent_name"],
|
||||||
|
agent_image=(
|
||||||
|
agent["agent_image"][0] if agent["agent_image"] else ""
|
||||||
|
),
|
||||||
|
creator=agent["creator_username"] or "Needs Profile",
|
||||||
|
creator_avatar=agent["creator_avatar"] or "",
|
||||||
|
sub_heading=agent["sub_heading"],
|
||||||
|
description=agent["description"],
|
||||||
|
runs=agent["runs"],
|
||||||
|
rating=agent["rating"],
|
||||||
|
)
|
||||||
|
store_agents.append(store_agent)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error parsing Store agent from hybrid search results: {e}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
# Validate and get order clause
|
if not search_used_hybrid:
|
||||||
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
|
# Fallback path - use basic search or no search
|
||||||
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
|
|
||||||
else:
|
|
||||||
order_by_clause = "updated_at DESC, rank DESC"
|
|
||||||
|
|
||||||
# Build WHERE conditions and parameters list
|
|
||||||
where_parts: list[str] = []
|
|
||||||
params: list[typing.Any] = [search_query] # $1 - search term
|
|
||||||
param_index = 2 # Start at $2 for next parameter
|
|
||||||
|
|
||||||
# Always filter for available agents
|
|
||||||
where_parts.append("is_available = true")
|
|
||||||
|
|
||||||
if featured:
|
|
||||||
where_parts.append("featured = true")
|
|
||||||
|
|
||||||
if creators and creators:
|
|
||||||
# Use ANY with array parameter
|
|
||||||
where_parts.append(f"creator_username = ANY(${param_index})")
|
|
||||||
params.append(creators)
|
|
||||||
param_index += 1
|
|
||||||
|
|
||||||
if category and category:
|
|
||||||
where_parts.append(f"${param_index} = ANY(categories)")
|
|
||||||
params.append(category)
|
|
||||||
param_index += 1
|
|
||||||
|
|
||||||
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
|
|
||||||
|
|
||||||
# Add pagination params
|
|
||||||
params.extend([page_size, offset])
|
|
||||||
limit_param = f"${param_index}"
|
|
||||||
offset_param = f"${param_index + 1}"
|
|
||||||
|
|
||||||
# Execute full-text search query with parameterized values
|
|
||||||
sql_query = f"""
|
|
||||||
SELECT
|
|
||||||
slug,
|
|
||||||
agent_name,
|
|
||||||
agent_image,
|
|
||||||
creator_username,
|
|
||||||
creator_avatar,
|
|
||||||
sub_heading,
|
|
||||||
description,
|
|
||||||
runs,
|
|
||||||
rating,
|
|
||||||
categories,
|
|
||||||
featured,
|
|
||||||
is_available,
|
|
||||||
updated_at,
|
|
||||||
ts_rank_cd(search, query) AS rank
|
|
||||||
FROM {{schema_prefix}}"StoreAgent",
|
|
||||||
plainto_tsquery('english', $1) AS query
|
|
||||||
WHERE {sql_where_clause}
|
|
||||||
AND search @@ query
|
|
||||||
ORDER BY {order_by_clause}
|
|
||||||
LIMIT {limit_param} OFFSET {offset_param}
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Count query for pagination - only uses search term parameter
|
|
||||||
count_query = f"""
|
|
||||||
SELECT COUNT(*) as count
|
|
||||||
FROM {{schema_prefix}}"StoreAgent",
|
|
||||||
plainto_tsquery('english', $1) AS query
|
|
||||||
WHERE {sql_where_clause}
|
|
||||||
AND search @@ query
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Execute both queries with parameters
|
|
||||||
agents = await query_raw_with_schema(sql_query, *params)
|
|
||||||
|
|
||||||
# For count, use params without pagination (last 2 params)
|
|
||||||
count_params = params[:-2]
|
|
||||||
count_result = await query_raw_with_schema(count_query, *count_params)
|
|
||||||
|
|
||||||
total = count_result[0]["count"] if count_result else 0
|
|
||||||
total_pages = (total + page_size - 1) // page_size
|
|
||||||
|
|
||||||
# Convert raw results to StoreAgent models
|
|
||||||
store_agents: list[store_model.StoreAgent] = []
|
|
||||||
for agent in agents:
|
|
||||||
try:
|
|
||||||
store_agent = store_model.StoreAgent(
|
|
||||||
slug=agent["slug"],
|
|
||||||
agent_name=agent["agent_name"],
|
|
||||||
agent_image=(
|
|
||||||
agent["agent_image"][0] if agent["agent_image"] else ""
|
|
||||||
),
|
|
||||||
creator=agent["creator_username"] or "Needs Profile",
|
|
||||||
creator_avatar=agent["creator_avatar"] or "",
|
|
||||||
sub_heading=agent["sub_heading"],
|
|
||||||
description=agent["description"],
|
|
||||||
runs=agent["runs"],
|
|
||||||
rating=agent["rating"],
|
|
||||||
)
|
|
||||||
store_agents.append(store_agent)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error parsing Store agent from search results: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Non-search query path (original logic)
|
|
||||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||||
if featured:
|
if featured:
|
||||||
where_clause["featured"] = featured
|
where_clause["featured"] = featured
|
||||||
@@ -180,6 +130,14 @@ async def get_store_agents(
|
|||||||
if category:
|
if category:
|
||||||
where_clause["categories"] = {"has": category}
|
where_clause["categories"] = {"has": category}
|
||||||
|
|
||||||
|
# Add basic text search if search_query provided but hybrid failed
|
||||||
|
if search_query:
|
||||||
|
where_clause["OR"] = [
|
||||||
|
{"agent_name": {"contains": search_query, "mode": "insensitive"}},
|
||||||
|
{"sub_heading": {"contains": search_query, "mode": "insensitive"}},
|
||||||
|
{"description": {"contains": search_query, "mode": "insensitive"}},
|
||||||
|
]
|
||||||
|
|
||||||
order_by = []
|
order_by = []
|
||||||
if sorted_by == "rating":
|
if sorted_by == "rating":
|
||||||
order_by.append({"rating": "desc"})
|
order_by.append({"rating": "desc"})
|
||||||
@@ -188,7 +146,7 @@ async def get_store_agents(
|
|||||||
elif sorted_by == "name":
|
elif sorted_by == "name":
|
||||||
order_by.append({"agent_name": "asc"})
|
order_by.append({"agent_name": "asc"})
|
||||||
|
|
||||||
agents = await prisma.models.StoreAgent.prisma().find_many(
|
db_agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||||
where=where_clause,
|
where=where_clause,
|
||||||
order=order_by,
|
order=order_by,
|
||||||
skip=(page - 1) * page_size,
|
skip=(page - 1) * page_size,
|
||||||
@@ -199,7 +157,7 @@ async def get_store_agents(
|
|||||||
total_pages = (total + page_size - 1) // page_size
|
total_pages = (total + page_size - 1) // page_size
|
||||||
|
|
||||||
store_agents: list[store_model.StoreAgent] = []
|
store_agents: list[store_model.StoreAgent] = []
|
||||||
for agent in agents:
|
for agent in db_agents:
|
||||||
try:
|
try:
|
||||||
# Create the StoreAgent object safely
|
# Create the StoreAgent object safely
|
||||||
store_agent = store_model.StoreAgent(
|
store_agent = store_model.StoreAgent(
|
||||||
@@ -1577,7 +1535,7 @@ async def review_store_submission(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Update the AgentGraph with store listing data
|
# Update the AgentGraph with store listing data
|
||||||
await prisma.models.AgentGraph.prisma().update(
|
await prisma.models.AgentGraph.prisma(tx).update(
|
||||||
where={
|
where={
|
||||||
"graphVersionId": {
|
"graphVersionId": {
|
||||||
"id": store_listing_version.agentGraphId,
|
"id": store_listing_version.agentGraphId,
|
||||||
@@ -1592,6 +1550,23 @@ async def review_store_submission(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Generate embedding for approved listing (blocking - admin operation)
|
||||||
|
# Inside transaction: if embedding fails, entire transaction rolls back
|
||||||
|
embedding_success = await ensure_embedding(
|
||||||
|
version_id=store_listing_version_id,
|
||||||
|
name=store_listing_version.name,
|
||||||
|
description=store_listing_version.description,
|
||||||
|
sub_heading=store_listing_version.subHeading,
|
||||||
|
categories=store_listing_version.categories or [],
|
||||||
|
tx=tx,
|
||||||
|
)
|
||||||
|
if not embedding_success:
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to generate embedding for listing {store_listing_version_id}. "
|
||||||
|
"This is likely due to OpenAI API being unavailable. "
|
||||||
|
"Please try again later or contact support if the issue persists."
|
||||||
|
)
|
||||||
|
|
||||||
await prisma.models.StoreListing.prisma(tx).update(
|
await prisma.models.StoreListing.prisma(tx).update(
|
||||||
where={"id": store_listing_version.StoreListing.id},
|
where={"id": store_listing_version.StoreListing.id},
|
||||||
data={
|
data={
|
||||||
|
|||||||
@@ -0,0 +1,568 @@
|
|||||||
|
"""
|
||||||
|
Unified Content Embeddings Service
|
||||||
|
|
||||||
|
Handles generation and storage of OpenAI embeddings for all content types
|
||||||
|
(store listings, blocks, documentation, library agents) to enable semantic/hybrid search.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import prisma
|
||||||
|
from prisma.enums import ContentType
|
||||||
|
from tiktoken import encoding_for_model
|
||||||
|
|
||||||
|
from backend.data.db import execute_raw_with_schema, query_raw_with_schema
|
||||||
|
from backend.util.clients import get_openai_client
|
||||||
|
from backend.util.json import dumps
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# OpenAI embedding model configuration
|
||||||
|
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||||
|
# OpenAI embedding token limit (8,191 with 1 token buffer for safety)
|
||||||
|
EMBEDDING_MAX_TOKENS = 8191
|
||||||
|
|
||||||
|
|
||||||
|
def build_searchable_text(
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
sub_heading: str,
|
||||||
|
categories: list[str],
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Build searchable text from listing version fields.
|
||||||
|
|
||||||
|
Combines relevant fields into a single string for embedding.
|
||||||
|
"""
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
# Name is important - include it
|
||||||
|
if name:
|
||||||
|
parts.append(name)
|
||||||
|
|
||||||
|
# Sub-heading provides context
|
||||||
|
if sub_heading:
|
||||||
|
parts.append(sub_heading)
|
||||||
|
|
||||||
|
# Description is the main content
|
||||||
|
if description:
|
||||||
|
parts.append(description)
|
||||||
|
|
||||||
|
# Categories help with semantic matching
|
||||||
|
if categories:
|
||||||
|
parts.append(" ".join(categories))
|
||||||
|
|
||||||
|
return " ".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_embedding(text: str) -> list[float] | None:
|
||||||
|
"""
|
||||||
|
Generate embedding for text using OpenAI API.
|
||||||
|
|
||||||
|
Returns None if embedding generation fails.
|
||||||
|
Fail-fast: no retries to maintain consistency with approval flow.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = get_openai_client()
|
||||||
|
if not client:
|
||||||
|
logger.error("openai_internal_api_key not set, cannot generate embedding")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Truncate text to token limit using tiktoken
|
||||||
|
# Character-based truncation is insufficient because token ratios vary by content type
|
||||||
|
enc = encoding_for_model(EMBEDDING_MODEL)
|
||||||
|
tokens = enc.encode(text)
|
||||||
|
if len(tokens) > EMBEDDING_MAX_TOKENS:
|
||||||
|
tokens = tokens[:EMBEDDING_MAX_TOKENS]
|
||||||
|
truncated_text = enc.decode(tokens)
|
||||||
|
logger.info(
|
||||||
|
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
truncated_text = text
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
response = await client.embeddings.create(
|
||||||
|
model=EMBEDDING_MODEL,
|
||||||
|
input=truncated_text,
|
||||||
|
)
|
||||||
|
latency_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
embedding = response.data[0].embedding
|
||||||
|
logger.info(
|
||||||
|
f"Generated embedding: {len(embedding)} dims, "
|
||||||
|
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
|
||||||
|
)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to generate embedding: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def store_embedding(
|
||||||
|
version_id: str,
|
||||||
|
embedding: list[float],
|
||||||
|
tx: prisma.Prisma | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Store embedding in the database.
|
||||||
|
|
||||||
|
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||||
|
DEPRECATED: Use ensure_embedding() instead (includes searchable_text).
|
||||||
|
"""
|
||||||
|
return await store_content_embedding(
|
||||||
|
content_type=ContentType.STORE_AGENT,
|
||||||
|
content_id=version_id,
|
||||||
|
embedding=embedding,
|
||||||
|
searchable_text="", # Empty for backward compat; ensure_embedding() populates this
|
||||||
|
metadata=None,
|
||||||
|
user_id=None, # Store agents are public
|
||||||
|
tx=tx,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def store_content_embedding(
|
||||||
|
content_type: ContentType,
|
||||||
|
content_id: str,
|
||||||
|
embedding: list[float],
|
||||||
|
searchable_text: str,
|
||||||
|
metadata: dict | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
tx: prisma.Prisma | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Store embedding in the unified content embeddings table.
|
||||||
|
|
||||||
|
New function for unified content embedding storage.
|
||||||
|
Uses raw SQL since Prisma doesn't natively support pgvector.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = tx if tx else prisma.get_client()
|
||||||
|
|
||||||
|
# Convert embedding to PostgreSQL vector format
|
||||||
|
embedding_str = embedding_to_vector_string(embedding)
|
||||||
|
metadata_json = dumps(metadata or {})
|
||||||
|
|
||||||
|
# Upsert the embedding
|
||||||
|
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
|
||||||
|
await execute_raw_with_schema(
|
||||||
|
"""
|
||||||
|
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
|
||||||
|
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
|
||||||
|
)
|
||||||
|
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
|
||||||
|
ON CONFLICT ("contentType", "contentId", "userId")
|
||||||
|
DO UPDATE SET
|
||||||
|
"embedding" = $4::vector,
|
||||||
|
"searchableText" = $5,
|
||||||
|
"metadata" = $6::jsonb,
|
||||||
|
"updatedAt" = NOW()
|
||||||
|
WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
|
||||||
|
AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
|
||||||
|
AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
|
||||||
|
""",
|
||||||
|
content_type,
|
||||||
|
content_id,
|
||||||
|
user_id,
|
||||||
|
embedding_str,
|
||||||
|
searchable_text,
|
||||||
|
metadata_json,
|
||||||
|
client=client,
|
||||||
|
set_public_search_path=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Stored embedding for {content_type}:{content_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def get_embedding(version_id: str) -> dict[str, Any] | None:
|
||||||
|
"""
|
||||||
|
Retrieve embedding record for a listing version.
|
||||||
|
|
||||||
|
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||||
|
Returns dict with storeListingVersionId, embedding, timestamps or None if not found.
|
||||||
|
"""
|
||||||
|
result = await get_content_embedding(
|
||||||
|
ContentType.STORE_AGENT, version_id, user_id=None
|
||||||
|
)
|
||||||
|
if result:
|
||||||
|
# Transform to old format for backward compatibility
|
||||||
|
return {
|
||||||
|
"storeListingVersionId": result["contentId"],
|
||||||
|
"embedding": result["embedding"],
|
||||||
|
"createdAt": result["createdAt"],
|
||||||
|
"updatedAt": result["updatedAt"],
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_content_embedding(
|
||||||
|
content_type: ContentType, content_id: str, user_id: str | None = None
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""
|
||||||
|
Retrieve embedding record for any content type.
|
||||||
|
|
||||||
|
New function for unified content embedding retrieval.
|
||||||
|
Returns dict with contentType, contentId, embedding, timestamps or None if not found.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await query_raw_with_schema(
|
||||||
|
"""
|
||||||
|
SELECT
|
||||||
|
"contentType",
|
||||||
|
"contentId",
|
||||||
|
"userId",
|
||||||
|
"embedding"::text as "embedding",
|
||||||
|
"searchableText",
|
||||||
|
"metadata",
|
||||||
|
"createdAt",
|
||||||
|
"updatedAt"
|
||||||
|
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||||
|
WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
|
||||||
|
""",
|
||||||
|
content_type,
|
||||||
|
content_id,
|
||||||
|
user_id,
|
||||||
|
set_public_search_path=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result and len(result) > 0:
|
||||||
|
return result[0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def ensure_embedding(
|
||||||
|
version_id: str,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
sub_heading: str,
|
||||||
|
categories: list[str],
|
||||||
|
force: bool = False,
|
||||||
|
tx: prisma.Prisma | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Ensure an embedding exists for the listing version.
|
||||||
|
|
||||||
|
Creates embedding if missing. Use force=True to regenerate.
|
||||||
|
Backward-compatible wrapper for store listings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version_id: The StoreListingVersion ID
|
||||||
|
name: Agent name
|
||||||
|
description: Agent description
|
||||||
|
sub_heading: Agent sub-heading
|
||||||
|
categories: Agent categories
|
||||||
|
force: Force regeneration even if embedding exists
|
||||||
|
tx: Optional transaction client
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if embedding exists/was created, False on failure
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Check if embedding already exists
|
||||||
|
if not force:
|
||||||
|
existing = await get_embedding(version_id)
|
||||||
|
if existing and existing.get("embedding"):
|
||||||
|
logger.debug(f"Embedding for version {version_id} already exists")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Build searchable text for embedding
|
||||||
|
searchable_text = build_searchable_text(
|
||||||
|
name, description, sub_heading, categories
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate new embedding
|
||||||
|
embedding = await generate_embedding(searchable_text)
|
||||||
|
if embedding is None:
|
||||||
|
logger.warning(f"Could not generate embedding for version {version_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Store the embedding with metadata using new function
|
||||||
|
metadata = {
|
||||||
|
"name": name,
|
||||||
|
"subHeading": sub_heading,
|
||||||
|
"categories": categories,
|
||||||
|
}
|
||||||
|
return await store_content_embedding(
|
||||||
|
content_type=ContentType.STORE_AGENT,
|
||||||
|
content_id=version_id,
|
||||||
|
embedding=embedding,
|
||||||
|
searchable_text=searchable_text,
|
||||||
|
metadata=metadata,
|
||||||
|
user_id=None, # Store agents are public
|
||||||
|
tx=tx,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_embedding(version_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete embedding for a listing version.
|
||||||
|
|
||||||
|
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||||
|
Note: This is usually handled automatically by CASCADE delete,
|
||||||
|
but provided for manual cleanup if needed.
|
||||||
|
"""
|
||||||
|
return await delete_content_embedding(ContentType.STORE_AGENT, version_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_content_embedding(
|
||||||
|
content_type: ContentType, content_id: str, user_id: str | None = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Delete embedding for any content type.
|
||||||
|
|
||||||
|
New function for unified content embedding deletion.
|
||||||
|
Note: This is usually handled automatically by CASCADE delete,
|
||||||
|
but provided for manual cleanup if needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content_type: The type of content (STORE_AGENT, LIBRARY_AGENT, etc.)
|
||||||
|
content_id: The unique identifier for the content
|
||||||
|
user_id: Optional user ID. For public content (STORE_AGENT, BLOCK), pass None.
|
||||||
|
For user-scoped content (LIBRARY_AGENT), pass the user's ID to avoid
|
||||||
|
deleting embeddings belonging to other users.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deletion succeeded, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = prisma.get_client()
|
||||||
|
|
||||||
|
await execute_raw_with_schema(
|
||||||
|
"""
|
||||||
|
DELETE FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||||
|
WHERE "contentType" = $1::{schema_prefix}"ContentType"
|
||||||
|
AND "contentId" = $2
|
||||||
|
AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
|
||||||
|
""",
|
||||||
|
content_type,
|
||||||
|
content_id,
|
||||||
|
user_id,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
user_str = f" (user: {user_id})" if user_id else ""
|
||||||
|
logger.info(f"Deleted embedding for {content_type}:{content_id}{user_str}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to delete embedding for {content_type}:{content_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def get_embedding_stats() -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get statistics about embedding coverage.
|
||||||
|
|
||||||
|
Returns counts of:
|
||||||
|
- Total approved listing versions
|
||||||
|
- Versions with embeddings
|
||||||
|
- Versions without embeddings
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Count approved versions
|
||||||
|
approved_result = await query_raw_with_schema(
|
||||||
|
"""
|
||||||
|
SELECT COUNT(*) as count
|
||||||
|
FROM {schema_prefix}"StoreListingVersion"
|
||||||
|
WHERE "submissionStatus" = 'APPROVED'
|
||||||
|
AND "isDeleted" = false
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
total_approved = approved_result[0]["count"] if approved_result else 0
|
||||||
|
|
||||||
|
# Count versions with embeddings
|
||||||
|
embedded_result = await query_raw_with_schema(
|
||||||
|
"""
|
||||||
|
SELECT COUNT(*) as count
|
||||||
|
FROM {schema_prefix}"StoreListingVersion" slv
|
||||||
|
JOIN {schema_prefix}"UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
|
||||||
|
WHERE slv."submissionStatus" = 'APPROVED'
|
||||||
|
AND slv."isDeleted" = false
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_approved": total_approved,
|
||||||
|
"with_embeddings": with_embeddings,
|
||||||
|
"without_embeddings": total_approved - with_embeddings,
|
||||||
|
"coverage_percent": (
|
||||||
|
round(with_embeddings / total_approved * 100, 1)
|
||||||
|
if total_approved > 0
|
||||||
|
else 0
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get embedding stats: {e}")
|
||||||
|
return {
|
||||||
|
"total_approved": 0,
|
||||||
|
"with_embeddings": 0,
|
||||||
|
"without_embeddings": 0,
|
||||||
|
"coverage_percent": 0,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Generate embeddings for approved listings that don't have them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: Number of embeddings to generate in one call
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with success/failure counts
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Find approved versions without embeddings
|
||||||
|
missing = await query_raw_with_schema(
|
||||||
|
"""
|
||||||
|
SELECT
|
||||||
|
slv.id,
|
||||||
|
slv.name,
|
||||||
|
slv.description,
|
||||||
|
slv."subHeading",
|
||||||
|
slv.categories
|
||||||
|
FROM {schema_prefix}"StoreListingVersion" slv
|
||||||
|
LEFT JOIN {schema_prefix}"UnifiedContentEmbedding" uce
|
||||||
|
ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
|
||||||
|
WHERE slv."submissionStatus" = 'APPROVED'
|
||||||
|
AND slv."isDeleted" = false
|
||||||
|
AND uce."contentId" IS NULL
|
||||||
|
LIMIT $1
|
||||||
|
""",
|
||||||
|
batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not missing:
|
||||||
|
return {
|
||||||
|
"processed": 0,
|
||||||
|
"success": 0,
|
||||||
|
"failed": 0,
|
||||||
|
"message": "No missing embeddings",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process embeddings concurrently for better performance
|
||||||
|
embedding_tasks = [
|
||||||
|
ensure_embedding(
|
||||||
|
version_id=row["id"],
|
||||||
|
name=row["name"],
|
||||||
|
description=row["description"],
|
||||||
|
sub_heading=row["subHeading"],
|
||||||
|
categories=row["categories"] or [],
|
||||||
|
)
|
||||||
|
for row in missing
|
||||||
|
]
|
||||||
|
|
||||||
|
results = await asyncio.gather(*embedding_tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
success = sum(1 for result in results if result is True)
|
||||||
|
failed = len(results) - success
|
||||||
|
|
||||||
|
return {
|
||||||
|
"processed": len(missing),
|
||||||
|
"success": success,
|
||||||
|
"failed": failed,
|
||||||
|
"message": f"Backfilled {success} embeddings, {failed} failed",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to backfill embeddings: {e}")
|
||||||
|
return {
|
||||||
|
"processed": 0,
|
||||||
|
"success": 0,
|
||||||
|
"failed": 0,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def embed_query(query: str) -> list[float] | None:
|
||||||
|
"""
|
||||||
|
Generate embedding for a search query.
|
||||||
|
|
||||||
|
Same as generate_embedding but with clearer intent.
|
||||||
|
"""
|
||||||
|
return await generate_embedding(query)
|
||||||
|
|
||||||
|
|
||||||
|
def embedding_to_vector_string(embedding: list[float]) -> str:
|
||||||
|
"""Convert embedding list to PostgreSQL vector string format."""
|
||||||
|
return "[" + ",".join(str(x) for x in embedding) + "]"
|
||||||
|
|
||||||
|
|
||||||
|
async def ensure_content_embedding(
|
||||||
|
content_type: ContentType,
|
||||||
|
content_id: str,
|
||||||
|
searchable_text: str,
|
||||||
|
metadata: dict | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
force: bool = False,
|
||||||
|
tx: prisma.Prisma | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Ensure an embedding exists for any content type.
|
||||||
|
|
||||||
|
Generic function for creating embeddings for store agents, blocks, docs, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content_type: ContentType enum value (STORE_AGENT, BLOCK, etc.)
|
||||||
|
content_id: Unique identifier for the content
|
||||||
|
searchable_text: Combined text for embedding generation
|
||||||
|
metadata: Optional metadata to store with embedding
|
||||||
|
force: Force regeneration even if embedding exists
|
||||||
|
tx: Optional transaction client
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if embedding exists/was created, False on failure
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Check if embedding already exists
|
||||||
|
if not force:
|
||||||
|
existing = await get_content_embedding(content_type, content_id, user_id)
|
||||||
|
if existing and existing.get("embedding"):
|
||||||
|
logger.debug(
|
||||||
|
f"Embedding for {content_type}:{content_id} already exists"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Generate new embedding
|
||||||
|
embedding = await generate_embedding(searchable_text)
|
||||||
|
if embedding is None:
|
||||||
|
logger.warning(
|
||||||
|
f"Could not generate embedding for {content_type}:{content_id}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Store the embedding
|
||||||
|
return await store_content_embedding(
|
||||||
|
content_type=content_type,
|
||||||
|
content_id=content_id,
|
||||||
|
embedding=embedding,
|
||||||
|
searchable_text=searchable_text,
|
||||||
|
metadata=metadata or {},
|
||||||
|
user_id=user_id,
|
||||||
|
tx=tx,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
|
||||||
|
return False
|
||||||
@@ -0,0 +1,329 @@
|
|||||||
|
"""
|
||||||
|
Integration tests for embeddings with schema handling.
|
||||||
|
|
||||||
|
These tests verify that embeddings operations work correctly across different database schemas.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from prisma.enums import ContentType
|
||||||
|
|
||||||
|
from backend.api.features.store import embeddings
|
||||||
|
|
||||||
|
# Schema prefix tests removed - functionality moved to db.raw_with_schema() helper
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_store_content_embedding_with_schema():
|
||||||
|
"""Test storing embeddings with proper schema handling."""
|
||||||
|
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||||
|
mock_schema.return_value = "platform"
|
||||||
|
|
||||||
|
with patch("prisma.get_client") as mock_get_client:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
result = await embeddings.store_content_embedding(
|
||||||
|
content_type=ContentType.STORE_AGENT,
|
||||||
|
content_id="test-id",
|
||||||
|
embedding=[0.1] * 1536,
|
||||||
|
searchable_text="test text",
|
||||||
|
metadata={"test": "data"},
|
||||||
|
user_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the query was called
|
||||||
|
assert mock_client.execute_raw.called
|
||||||
|
|
||||||
|
# Get the SQL query that was executed
|
||||||
|
call_args = mock_client.execute_raw.call_args
|
||||||
|
sql_query = call_args[0][0]
|
||||||
|
|
||||||
|
# Verify schema prefix is in the query
|
||||||
|
assert '"platform"."UnifiedContentEmbedding"' in sql_query
|
||||||
|
|
||||||
|
# Verify result
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_get_content_embedding_with_schema():
|
||||||
|
"""Test retrieving embeddings with proper schema handling."""
|
||||||
|
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||||
|
mock_schema.return_value = "platform"
|
||||||
|
|
||||||
|
with patch("prisma.get_client") as mock_get_client:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.query_raw.return_value = [
|
||||||
|
{
|
||||||
|
"contentType": "STORE_AGENT",
|
||||||
|
"contentId": "test-id",
|
||||||
|
"userId": None,
|
||||||
|
"embedding": "[0.1, 0.2]",
|
||||||
|
"searchableText": "test",
|
||||||
|
"metadata": {},
|
||||||
|
"createdAt": "2024-01-01",
|
||||||
|
"updatedAt": "2024-01-01",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
result = await embeddings.get_content_embedding(
|
||||||
|
ContentType.STORE_AGENT,
|
||||||
|
"test-id",
|
||||||
|
user_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the query was called
|
||||||
|
assert mock_client.query_raw.called
|
||||||
|
|
||||||
|
# Get the SQL query that was executed
|
||||||
|
call_args = mock_client.query_raw.call_args
|
||||||
|
sql_query = call_args[0][0]
|
||||||
|
|
||||||
|
# Verify schema prefix is in the query
|
||||||
|
assert '"platform"."UnifiedContentEmbedding"' in sql_query
|
||||||
|
|
||||||
|
# Verify result
|
||||||
|
assert result is not None
|
||||||
|
assert result["contentId"] == "test-id"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_delete_content_embedding_with_schema():
|
||||||
|
"""Test deleting embeddings with proper schema handling."""
|
||||||
|
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||||
|
mock_schema.return_value = "platform"
|
||||||
|
|
||||||
|
with patch("prisma.get_client") as mock_get_client:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
result = await embeddings.delete_content_embedding(
|
||||||
|
ContentType.STORE_AGENT,
|
||||||
|
"test-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the query was called
|
||||||
|
assert mock_client.execute_raw.called
|
||||||
|
|
||||||
|
# Get the SQL query that was executed
|
||||||
|
call_args = mock_client.execute_raw.call_args
|
||||||
|
sql_query = call_args[0][0]
|
||||||
|
|
||||||
|
# Verify schema prefix is in the query
|
||||||
|
assert '"platform"."UnifiedContentEmbedding"' in sql_query
|
||||||
|
|
||||||
|
# Verify result
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_get_embedding_stats_with_schema():
|
||||||
|
"""Test embedding statistics with proper schema handling."""
|
||||||
|
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||||
|
mock_schema.return_value = "platform"
|
||||||
|
|
||||||
|
with patch("prisma.get_client") as mock_get_client:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
# Mock both query results
|
||||||
|
mock_client.query_raw.side_effect = [
|
||||||
|
[{"count": 100}], # total_approved
|
||||||
|
[{"count": 80}], # with_embeddings
|
||||||
|
]
|
||||||
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
result = await embeddings.get_embedding_stats()
|
||||||
|
|
||||||
|
# Verify both queries were called
|
||||||
|
assert mock_client.query_raw.call_count == 2
|
||||||
|
|
||||||
|
# Get both SQL queries
|
||||||
|
first_call = mock_client.query_raw.call_args_list[0]
|
||||||
|
second_call = mock_client.query_raw.call_args_list[1]
|
||||||
|
|
||||||
|
first_sql = first_call[0][0]
|
||||||
|
second_sql = second_call[0][0]
|
||||||
|
|
||||||
|
# Verify schema prefix in both queries
|
||||||
|
assert '"platform"."StoreListingVersion"' in first_sql
|
||||||
|
assert '"platform"."StoreListingVersion"' in second_sql
|
||||||
|
assert '"platform"."UnifiedContentEmbedding"' in second_sql
|
||||||
|
|
||||||
|
# Verify results
|
||||||
|
assert result["total_approved"] == 100
|
||||||
|
assert result["with_embeddings"] == 80
|
||||||
|
assert result["without_embeddings"] == 20
|
||||||
|
assert result["coverage_percent"] == 80.0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_backfill_missing_embeddings_with_schema():
|
||||||
|
"""Test backfilling embeddings with proper schema handling."""
|
||||||
|
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||||
|
mock_schema.return_value = "platform"
|
||||||
|
|
||||||
|
with patch("prisma.get_client") as mock_get_client:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
# Mock missing embeddings query
|
||||||
|
mock_client.query_raw.return_value = [
|
||||||
|
{
|
||||||
|
"id": "version-1",
|
||||||
|
"name": "Test Agent",
|
||||||
|
"description": "Test description",
|
||||||
|
"subHeading": "Test heading",
|
||||||
|
"categories": ["test"],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.embeddings.ensure_embedding"
|
||||||
|
) as mock_ensure:
|
||||||
|
mock_ensure.return_value = True
|
||||||
|
|
||||||
|
result = await embeddings.backfill_missing_embeddings(batch_size=10)
|
||||||
|
|
||||||
|
# Verify the query was called
|
||||||
|
assert mock_client.query_raw.called
|
||||||
|
|
||||||
|
# Get the SQL query
|
||||||
|
call_args = mock_client.query_raw.call_args
|
||||||
|
sql_query = call_args[0][0]
|
||||||
|
|
||||||
|
# Verify schema prefix in query
|
||||||
|
assert '"platform"."StoreListingVersion"' in sql_query
|
||||||
|
assert '"platform"."UnifiedContentEmbedding"' in sql_query
|
||||||
|
|
||||||
|
# Verify ensure_embedding was called
|
||||||
|
assert mock_ensure.called
|
||||||
|
|
||||||
|
# Verify results
|
||||||
|
assert result["processed"] == 1
|
||||||
|
assert result["success"] == 1
|
||||||
|
assert result["failed"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_ensure_content_embedding_with_schema():
|
||||||
|
"""Test ensuring embeddings exist with proper schema handling."""
|
||||||
|
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||||
|
mock_schema.return_value = "platform"
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.embeddings.get_content_embedding"
|
||||||
|
) as mock_get:
|
||||||
|
# Simulate no existing embedding
|
||||||
|
mock_get.return_value = None
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.embeddings.generate_embedding"
|
||||||
|
) as mock_generate:
|
||||||
|
mock_generate.return_value = [0.1] * 1536
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.embeddings.store_content_embedding"
|
||||||
|
) as mock_store:
|
||||||
|
mock_store.return_value = True
|
||||||
|
|
||||||
|
result = await embeddings.ensure_content_embedding(
|
||||||
|
content_type=ContentType.STORE_AGENT,
|
||||||
|
content_id="test-id",
|
||||||
|
searchable_text="test text",
|
||||||
|
metadata={"test": "data"},
|
||||||
|
user_id=None,
|
||||||
|
force=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the flow
|
||||||
|
assert mock_get.called
|
||||||
|
assert mock_generate.called
|
||||||
|
assert mock_store.called
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_backward_compatibility_store_embedding():
|
||||||
|
"""Test backward compatibility wrapper for store_embedding."""
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.embeddings.store_content_embedding"
|
||||||
|
) as mock_store:
|
||||||
|
mock_store.return_value = True
|
||||||
|
|
||||||
|
result = await embeddings.store_embedding(
|
||||||
|
version_id="test-version-id",
|
||||||
|
embedding=[0.1] * 1536,
|
||||||
|
tx=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify it calls the new function with correct parameters
|
||||||
|
assert mock_store.called
|
||||||
|
call_args = mock_store.call_args
|
||||||
|
|
||||||
|
assert call_args[1]["content_type"] == ContentType.STORE_AGENT
|
||||||
|
assert call_args[1]["content_id"] == "test-version-id"
|
||||||
|
assert call_args[1]["user_id"] is None
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_backward_compatibility_get_embedding():
|
||||||
|
"""Test backward compatibility wrapper for get_embedding."""
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.embeddings.get_content_embedding"
|
||||||
|
) as mock_get:
|
||||||
|
mock_get.return_value = {
|
||||||
|
"contentType": "STORE_AGENT",
|
||||||
|
"contentId": "test-version-id",
|
||||||
|
"embedding": "[0.1, 0.2]",
|
||||||
|
"createdAt": "2024-01-01",
|
||||||
|
"updatedAt": "2024-01-01",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await embeddings.get_embedding("test-version-id")
|
||||||
|
|
||||||
|
# Verify it calls the new function
|
||||||
|
assert mock_get.called
|
||||||
|
|
||||||
|
# Verify it transforms to old format
|
||||||
|
assert result is not None
|
||||||
|
assert result["storeListingVersionId"] == "test-version-id"
|
||||||
|
assert "embedding" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_schema_handling_error_cases():
|
||||||
|
"""Test error handling in schema-aware operations."""
|
||||||
|
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||||
|
mock_schema.return_value = "platform"
|
||||||
|
|
||||||
|
with patch("prisma.get_client") as mock_get_client:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.execute_raw.side_effect = Exception("Database error")
|
||||||
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
result = await embeddings.store_content_embedding(
|
||||||
|
content_type=ContentType.STORE_AGENT,
|
||||||
|
content_id="test-id",
|
||||||
|
embedding=[0.1] * 1536,
|
||||||
|
searchable_text="test",
|
||||||
|
metadata=None,
|
||||||
|
user_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return False on error, not raise
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v", "-s"])
|
||||||
@@ -0,0 +1,387 @@
|
|||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import prisma
|
||||||
|
import pytest
|
||||||
|
from prisma import Prisma
|
||||||
|
from prisma.enums import ContentType
|
||||||
|
|
||||||
|
from backend.api.features.store import embeddings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def setup_prisma():
|
||||||
|
"""Setup Prisma client for tests."""
|
||||||
|
try:
|
||||||
|
Prisma()
|
||||||
|
except prisma.errors.ClientAlreadyRegisteredError:
|
||||||
|
pass
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_build_searchable_text():
|
||||||
|
"""Test searchable text building from listing fields."""
|
||||||
|
result = embeddings.build_searchable_text(
|
||||||
|
name="AI Assistant",
|
||||||
|
description="A helpful AI assistant for productivity",
|
||||||
|
sub_heading="Boost your productivity",
|
||||||
|
categories=["AI", "Productivity"],
|
||||||
|
)
|
||||||
|
|
||||||
|
expected = "AI Assistant Boost your productivity A helpful AI assistant for productivity AI Productivity"
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_build_searchable_text_empty_fields():
|
||||||
|
"""Test searchable text building with empty fields."""
|
||||||
|
result = embeddings.build_searchable_text(
|
||||||
|
name="", description="Test description", sub_heading="", categories=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == "Test description"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_generate_embedding_success():
|
||||||
|
"""Test successful embedding generation."""
|
||||||
|
# Mock OpenAI response
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data = [MagicMock()]
|
||||||
|
mock_response.data[0].embedding = [0.1, 0.2, 0.3] * 512 # 1536 dimensions
|
||||||
|
|
||||||
|
# Use AsyncMock for async embeddings.create method
|
||||||
|
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
# Patch at the point of use in embeddings.py
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.embeddings.get_openai_client"
|
||||||
|
) as mock_get_client:
|
||||||
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
result = await embeddings.generate_embedding("test text")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert len(result) == 1536
|
||||||
|
assert result[0] == 0.1
|
||||||
|
|
||||||
|
mock_client.embeddings.create.assert_called_once_with(
|
||||||
|
model="text-embedding-3-small", input="test text"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_generate_embedding_no_api_key():
|
||||||
|
"""Test embedding generation without API key."""
|
||||||
|
# Patch at the point of use in embeddings.py
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.embeddings.get_openai_client"
|
||||||
|
) as mock_get_client:
|
||||||
|
mock_get_client.return_value = None
|
||||||
|
|
||||||
|
result = await embeddings.generate_embedding("test text")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_generate_embedding_api_error():
|
||||||
|
"""Test embedding generation with API error."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.embeddings.create = AsyncMock(side_effect=Exception("API Error"))
|
||||||
|
|
||||||
|
# Patch at the point of use in embeddings.py
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.embeddings.get_openai_client"
|
||||||
|
) as mock_get_client:
|
||||||
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
result = await embeddings.generate_embedding("test text")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_generate_embedding_text_truncation():
|
||||||
|
"""Test that long text is properly truncated using tiktoken."""
|
||||||
|
from tiktoken import encoding_for_model
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data = [MagicMock()]
|
||||||
|
mock_response.data[0].embedding = [0.1] * 1536
|
||||||
|
|
||||||
|
# Use AsyncMock for async embeddings.create method
|
||||||
|
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
# Patch at the point of use in embeddings.py
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.embeddings.get_openai_client"
|
||||||
|
) as mock_get_client:
|
||||||
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
# Create text that will exceed 8191 tokens
|
||||||
|
# Use varied characters to ensure token-heavy text: each word is ~1 token
|
||||||
|
words = [f"word{i}" for i in range(10000)]
|
||||||
|
long_text = " ".join(words) # ~10000 tokens
|
||||||
|
|
||||||
|
await embeddings.generate_embedding(long_text)
|
||||||
|
|
||||||
|
# Verify text was truncated to 8191 tokens
|
||||||
|
call_args = mock_client.embeddings.create.call_args
|
||||||
|
truncated_text = call_args.kwargs["input"]
|
||||||
|
|
||||||
|
# Count actual tokens in truncated text
|
||||||
|
enc = encoding_for_model("text-embedding-3-small")
|
||||||
|
actual_tokens = len(enc.encode(truncated_text))
|
||||||
|
|
||||||
|
# Should be at or just under 8191 tokens
|
||||||
|
assert actual_tokens <= 8191
|
||||||
|
# Should be close to the limit (not over-truncated)
|
||||||
|
assert actual_tokens >= 8100
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_store_embedding_success(mocker):
|
||||||
|
"""Test successful embedding storage."""
|
||||||
|
mock_client = mocker.AsyncMock()
|
||||||
|
mock_client.execute_raw = mocker.AsyncMock()
|
||||||
|
|
||||||
|
embedding = [0.1, 0.2, 0.3]
|
||||||
|
|
||||||
|
result = await embeddings.store_embedding(
|
||||||
|
version_id="test-version-id", embedding=embedding, tx=mock_client
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
# execute_raw is called twice: once for SET search_path, once for INSERT
|
||||||
|
assert mock_client.execute_raw.call_count == 2
|
||||||
|
|
||||||
|
# First call: SET search_path
|
||||||
|
first_call_args = mock_client.execute_raw.call_args_list[0][0]
|
||||||
|
assert "SET search_path" in first_call_args[0]
|
||||||
|
|
||||||
|
# Second call: INSERT query with the actual data
|
||||||
|
second_call_args = mock_client.execute_raw.call_args_list[1][0]
|
||||||
|
assert "test-version-id" in second_call_args
|
||||||
|
assert "[0.1,0.2,0.3]" in second_call_args
|
||||||
|
assert None in second_call_args # userId should be None for store agents
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_store_embedding_database_error(mocker):
|
||||||
|
"""Test embedding storage with database error."""
|
||||||
|
mock_client = mocker.AsyncMock()
|
||||||
|
mock_client.execute_raw.side_effect = Exception("Database error")
|
||||||
|
|
||||||
|
embedding = [0.1, 0.2, 0.3]
|
||||||
|
|
||||||
|
result = await embeddings.store_embedding(
|
||||||
|
version_id="test-version-id", embedding=embedding, tx=mock_client
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_get_embedding_success():
|
||||||
|
"""Test successful embedding retrieval."""
|
||||||
|
mock_result = [
|
||||||
|
{
|
||||||
|
"contentType": "STORE_AGENT",
|
||||||
|
"contentId": "test-version-id",
|
||||||
|
"userId": None,
|
||||||
|
"embedding": "[0.1,0.2,0.3]",
|
||||||
|
"searchableText": "Test text",
|
||||||
|
"metadata": {},
|
||||||
|
"createdAt": "2024-01-01T00:00:00Z",
|
||||||
|
"updatedAt": "2024-01-01T00:00:00Z",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||||
|
return_value=mock_result,
|
||||||
|
):
|
||||||
|
result = await embeddings.get_embedding("test-version-id")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["storeListingVersionId"] == "test-version-id"
|
||||||
|
assert result["embedding"] == "[0.1,0.2,0.3]"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_get_embedding_not_found():
|
||||||
|
"""Test embedding retrieval when not found."""
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||||
|
return_value=[],
|
||||||
|
):
|
||||||
|
result = await embeddings.get_embedding("test-version-id")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||||
|
@patch("backend.api.features.store.embeddings.store_embedding")
|
||||||
|
@patch("backend.api.features.store.embeddings.get_embedding")
|
||||||
|
async def test_ensure_embedding_already_exists(mock_get, mock_store, mock_generate):
|
||||||
|
"""Test ensure_embedding when embedding already exists."""
|
||||||
|
mock_get.return_value = {"embedding": "[0.1,0.2,0.3]"}
|
||||||
|
|
||||||
|
result = await embeddings.ensure_embedding(
|
||||||
|
version_id="test-id",
|
||||||
|
name="Test",
|
||||||
|
description="Test description",
|
||||||
|
sub_heading="Test heading",
|
||||||
|
categories=["test"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_generate.assert_not_called()
|
||||||
|
mock_store.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||||
|
@patch("backend.api.features.store.embeddings.store_content_embedding")
|
||||||
|
@patch("backend.api.features.store.embeddings.get_embedding")
|
||||||
|
async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate):
|
||||||
|
"""Test ensure_embedding creating new embedding."""
|
||||||
|
mock_get.return_value = None
|
||||||
|
mock_generate.return_value = [0.1, 0.2, 0.3]
|
||||||
|
mock_store.return_value = True
|
||||||
|
|
||||||
|
result = await embeddings.ensure_embedding(
|
||||||
|
version_id="test-id",
|
||||||
|
name="Test",
|
||||||
|
description="Test description",
|
||||||
|
sub_heading="Test heading",
|
||||||
|
categories=["test"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_generate.assert_called_once_with("Test Test heading Test description test")
|
||||||
|
mock_store.assert_called_once_with(
|
||||||
|
content_type=ContentType.STORE_AGENT,
|
||||||
|
content_id="test-id",
|
||||||
|
embedding=[0.1, 0.2, 0.3],
|
||||||
|
searchable_text="Test Test heading Test description test",
|
||||||
|
metadata={"name": "Test", "subHeading": "Test heading", "categories": ["test"]},
|
||||||
|
user_id=None,
|
||||||
|
tx=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||||
|
@patch("backend.api.features.store.embeddings.get_embedding")
|
||||||
|
async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
|
||||||
|
"""Test ensure_embedding when generation fails."""
|
||||||
|
mock_get.return_value = None
|
||||||
|
mock_generate.return_value = None
|
||||||
|
|
||||||
|
result = await embeddings.ensure_embedding(
|
||||||
|
version_id="test-id",
|
||||||
|
name="Test",
|
||||||
|
description="Test description",
|
||||||
|
sub_heading="Test heading",
|
||||||
|
categories=["test"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_get_embedding_stats():
|
||||||
|
"""Test embedding statistics retrieval."""
|
||||||
|
# Mock approved count query and embedded count query
|
||||||
|
mock_approved_result = [{"count": 100}]
|
||||||
|
mock_embedded_result = [{"count": 75}]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||||
|
side_effect=[mock_approved_result, mock_embedded_result],
|
||||||
|
):
|
||||||
|
result = await embeddings.get_embedding_stats()
|
||||||
|
|
||||||
|
assert result["total_approved"] == 100
|
||||||
|
assert result["with_embeddings"] == 75
|
||||||
|
assert result["without_embeddings"] == 25
|
||||||
|
assert result["coverage_percent"] == 75.0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@patch("backend.api.features.store.embeddings.ensure_embedding")
|
||||||
|
async def test_backfill_missing_embeddings_success(mock_ensure):
|
||||||
|
"""Test backfill with successful embedding generation."""
|
||||||
|
# Mock missing embeddings query
|
||||||
|
mock_missing = [
|
||||||
|
{
|
||||||
|
"id": "version-1",
|
||||||
|
"name": "Agent 1",
|
||||||
|
"description": "Description 1",
|
||||||
|
"subHeading": "Heading 1",
|
||||||
|
"categories": ["AI"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "version-2",
|
||||||
|
"name": "Agent 2",
|
||||||
|
"description": "Description 2",
|
||||||
|
"subHeading": "Heading 2",
|
||||||
|
"categories": ["Productivity"],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Mock ensure_embedding to succeed for first, fail for second
|
||||||
|
mock_ensure.side_effect = [True, False]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||||
|
return_value=mock_missing,
|
||||||
|
):
|
||||||
|
result = await embeddings.backfill_missing_embeddings(batch_size=5)
|
||||||
|
|
||||||
|
assert result["processed"] == 2
|
||||||
|
assert result["success"] == 1
|
||||||
|
assert result["failed"] == 1
|
||||||
|
assert mock_ensure.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_backfill_missing_embeddings_no_missing():
|
||||||
|
"""Test backfill when no embeddings are missing."""
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||||
|
return_value=[],
|
||||||
|
):
|
||||||
|
result = await embeddings.backfill_missing_embeddings(batch_size=5)
|
||||||
|
|
||||||
|
assert result["processed"] == 0
|
||||||
|
assert result["success"] == 0
|
||||||
|
assert result["failed"] == 0
|
||||||
|
assert result["message"] == "No missing embeddings"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_embedding_to_vector_string():
|
||||||
|
"""Test embedding to PostgreSQL vector string conversion."""
|
||||||
|
embedding = [0.1, 0.2, 0.3, -0.4]
|
||||||
|
result = embeddings.embedding_to_vector_string(embedding)
|
||||||
|
assert result == "[0.1,0.2,0.3,-0.4]"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_embed_query():
|
||||||
|
"""Test embed_query function (alias for generate_embedding)."""
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.embeddings.generate_embedding"
|
||||||
|
) as mock_generate:
|
||||||
|
mock_generate.return_value = [0.1, 0.2, 0.3]
|
||||||
|
|
||||||
|
result = await embeddings.embed_query("test query")
|
||||||
|
|
||||||
|
assert result == [0.1, 0.2, 0.3]
|
||||||
|
mock_generate.assert_called_once_with("test query")
|
||||||
@@ -0,0 +1,393 @@
|
|||||||
|
"""
|
||||||
|
Hybrid Search for Store Agents
|
||||||
|
|
||||||
|
Combines semantic (embedding) search with lexical (tsvector) search
|
||||||
|
for improved relevance in marketplace agent discovery.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from backend.api.features.store.embeddings import (
|
||||||
|
embed_query,
|
||||||
|
embedding_to_vector_string,
|
||||||
|
)
|
||||||
|
from backend.data.db import query_raw_with_schema
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HybridSearchWeights:
|
||||||
|
"""Weights for combining search signals."""
|
||||||
|
|
||||||
|
semantic: float = 0.30 # Embedding cosine similarity
|
||||||
|
lexical: float = 0.30 # tsvector ts_rank_cd score
|
||||||
|
category: float = 0.20 # Category match boost
|
||||||
|
recency: float = 0.10 # Newer agents ranked higher
|
||||||
|
popularity: float = 0.10 # Agent usage/runs (PageRank-like)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""Validate weights are non-negative and sum to approximately 1.0."""
|
||||||
|
total = (
|
||||||
|
self.semantic
|
||||||
|
+ self.lexical
|
||||||
|
+ self.category
|
||||||
|
+ self.recency
|
||||||
|
+ self.popularity
|
||||||
|
)
|
||||||
|
|
||||||
|
if any(
|
||||||
|
w < 0
|
||||||
|
for w in [
|
||||||
|
self.semantic,
|
||||||
|
self.lexical,
|
||||||
|
self.category,
|
||||||
|
self.recency,
|
||||||
|
self.popularity,
|
||||||
|
]
|
||||||
|
):
|
||||||
|
raise ValueError("All weights must be non-negative")
|
||||||
|
|
||||||
|
if not (0.99 <= total <= 1.01):
|
||||||
|
raise ValueError(f"Weights must sum to ~1.0, got {total:.3f}")
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_WEIGHTS = HybridSearchWeights()
|
||||||
|
|
||||||
|
# Minimum relevance score threshold - agents below this are filtered out
|
||||||
|
# With weights (0.30 semantic + 0.30 lexical + 0.20 category + 0.10 recency + 0.10 popularity):
|
||||||
|
# - 0.20 means at least ~60% semantic match OR strong lexical match required
|
||||||
|
# - Ensures only genuinely relevant results are returned
|
||||||
|
# - Recency/popularity alone (0.10 each) won't pass the threshold
|
||||||
|
DEFAULT_MIN_SCORE = 0.20
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HybridSearchResult:
|
||||||
|
"""A single search result with score breakdown."""
|
||||||
|
|
||||||
|
slug: str
|
||||||
|
agent_name: str
|
||||||
|
agent_image: str
|
||||||
|
creator_username: str
|
||||||
|
creator_avatar: str
|
||||||
|
sub_heading: str
|
||||||
|
description: str
|
||||||
|
runs: int
|
||||||
|
rating: float
|
||||||
|
categories: list[str]
|
||||||
|
featured: bool
|
||||||
|
is_available: bool
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
# Score breakdown (for debugging/tuning)
|
||||||
|
combined_score: float
|
||||||
|
semantic_score: float = 0.0
|
||||||
|
lexical_score: float = 0.0
|
||||||
|
category_score: float = 0.0
|
||||||
|
recency_score: float = 0.0
|
||||||
|
popularity_score: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
async def hybrid_search(
|
||||||
|
query: str,
|
||||||
|
featured: bool = False,
|
||||||
|
creators: list[str] | None = None,
|
||||||
|
category: str | None = None,
|
||||||
|
sorted_by: (
|
||||||
|
Literal["relevance", "rating", "runs", "name", "updated_at"] | None
|
||||||
|
) = None,
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 20,
|
||||||
|
weights: HybridSearchWeights | None = None,
|
||||||
|
min_score: float | None = None,
|
||||||
|
) -> tuple[list[dict[str, Any]], int]:
|
||||||
|
"""
|
||||||
|
Perform hybrid search combining semantic and lexical signals.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query string
|
||||||
|
featured: Filter for featured agents only
|
||||||
|
creators: Filter by creator usernames
|
||||||
|
category: Filter by category
|
||||||
|
sorted_by: Sort order (relevance uses hybrid scoring)
|
||||||
|
page: Page number (1-indexed)
|
||||||
|
page_size: Results per page
|
||||||
|
weights: Custom weights for search signals
|
||||||
|
min_score: Minimum relevance score threshold (0-1). Results below
|
||||||
|
this score are filtered out. Defaults to DEFAULT_MIN_SCORE.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (results list, total count). Returns empty list if no
|
||||||
|
results meet the minimum relevance threshold.
|
||||||
|
"""
|
||||||
|
# Validate inputs
|
||||||
|
query = query.strip()
|
||||||
|
if not query:
|
||||||
|
return [], 0 # Empty query returns no results
|
||||||
|
|
||||||
|
if page < 1:
|
||||||
|
page = 1
|
||||||
|
if page_size < 1:
|
||||||
|
page_size = 1
|
||||||
|
if page_size > 100: # Cap at reasonable limit to prevent performance issues
|
||||||
|
page_size = 100
|
||||||
|
|
||||||
|
if weights is None:
|
||||||
|
weights = DEFAULT_WEIGHTS
|
||||||
|
if min_score is None:
|
||||||
|
min_score = DEFAULT_MIN_SCORE
|
||||||
|
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
|
# Generate query embedding
|
||||||
|
query_embedding = await embed_query(query)
|
||||||
|
|
||||||
|
# Build WHERE clause conditions
|
||||||
|
where_parts: list[str] = ["sa.is_available = true"]
|
||||||
|
params: list[Any] = []
|
||||||
|
param_index = 1
|
||||||
|
|
||||||
|
# Add search query for lexical matching
|
||||||
|
params.append(query)
|
||||||
|
query_param = f"${param_index}"
|
||||||
|
param_index += 1
|
||||||
|
|
||||||
|
# Add lowercased query for category matching
|
||||||
|
params.append(query.lower())
|
||||||
|
query_lower_param = f"${param_index}"
|
||||||
|
param_index += 1
|
||||||
|
|
||||||
|
if featured:
|
||||||
|
where_parts.append("sa.featured = true")
|
||||||
|
|
||||||
|
if creators:
|
||||||
|
where_parts.append(f"sa.creator_username = ANY(${param_index})")
|
||||||
|
params.append(creators)
|
||||||
|
param_index += 1
|
||||||
|
|
||||||
|
if category:
|
||||||
|
where_parts.append(f"${param_index} = ANY(sa.categories)")
|
||||||
|
params.append(category)
|
||||||
|
param_index += 1
|
||||||
|
|
||||||
|
# Safe: where_parts only contains hardcoded strings with $N parameter placeholders
|
||||||
|
# No user input is concatenated directly into the SQL string
|
||||||
|
where_clause = " AND ".join(where_parts)
|
||||||
|
|
||||||
|
# Embedding is required for hybrid search - fail fast if unavailable
|
||||||
|
if query_embedding is None or not query_embedding:
|
||||||
|
# Log detailed error server-side
|
||||||
|
logger.error(
|
||||||
|
"Failed to generate query embedding. "
|
||||||
|
"Check that openai_internal_api_key is configured and OpenAI API is accessible."
|
||||||
|
)
|
||||||
|
# Raise generic error to client
|
||||||
|
raise ValueError("Search service temporarily unavailable")
|
||||||
|
|
||||||
|
# Add embedding parameter
|
||||||
|
embedding_str = embedding_to_vector_string(query_embedding)
|
||||||
|
params.append(embedding_str)
|
||||||
|
embedding_param = f"${param_index}"
|
||||||
|
param_index += 1
|
||||||
|
|
||||||
|
# Add weight parameters for SQL calculation
|
||||||
|
params.append(weights.semantic)
|
||||||
|
weight_semantic_param = f"${param_index}"
|
||||||
|
param_index += 1
|
||||||
|
|
||||||
|
params.append(weights.lexical)
|
||||||
|
weight_lexical_param = f"${param_index}"
|
||||||
|
param_index += 1
|
||||||
|
|
||||||
|
params.append(weights.category)
|
||||||
|
weight_category_param = f"${param_index}"
|
||||||
|
param_index += 1
|
||||||
|
|
||||||
|
params.append(weights.recency)
|
||||||
|
weight_recency_param = f"${param_index}"
|
||||||
|
param_index += 1
|
||||||
|
|
||||||
|
params.append(weights.popularity)
|
||||||
|
weight_popularity_param = f"${param_index}"
|
||||||
|
param_index += 1
|
||||||
|
|
||||||
|
# Add min_score parameter
|
||||||
|
params.append(min_score)
|
||||||
|
min_score_param = f"${param_index}"
|
||||||
|
param_index += 1
|
||||||
|
|
||||||
|
# Optimized hybrid search query:
|
||||||
|
# 1. Direct join to UnifiedContentEmbedding via contentId=storeListingVersionId (no redundant JOINs)
|
||||||
|
# 2. UNION approach (deduplicates agents matching both branches)
|
||||||
|
# 3. COUNT(*) OVER() to get total count in single query
|
||||||
|
# 4. Optimized category matching with EXISTS + unnest
|
||||||
|
# 5. Pre-calculated max values for lexical and popularity normalization
|
||||||
|
# 6. Simplified recency calculation with linear decay
|
||||||
|
# 7. Logarithmic popularity scaling to prevent viral agents from dominating
|
||||||
|
sql_query = f"""
|
||||||
|
WITH candidates AS (
|
||||||
|
-- Lexical matches (uses GIN index on search column)
|
||||||
|
SELECT sa."storeListingVersionId"
|
||||||
|
FROM {{schema_prefix}}"StoreAgent" sa
|
||||||
|
WHERE {where_clause}
|
||||||
|
AND sa.search @@ plainto_tsquery('english', {query_param})
|
||||||
|
|
||||||
|
UNION
|
||||||
|
|
||||||
|
-- Semantic matches (uses HNSW index on embedding with KNN)
|
||||||
|
SELECT "storeListingVersionId"
|
||||||
|
FROM (
|
||||||
|
SELECT sa."storeListingVersionId", uce.embedding
|
||||||
|
FROM {{schema_prefix}}"StoreAgent" sa
|
||||||
|
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||||
|
ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||||
|
WHERE {where_clause}
|
||||||
|
ORDER BY uce.embedding <=> {embedding_param}::vector
|
||||||
|
LIMIT 200
|
||||||
|
) semantic_results
|
||||||
|
),
|
||||||
|
search_scores AS (
|
||||||
|
SELECT
|
||||||
|
sa.slug,
|
||||||
|
sa.agent_name,
|
||||||
|
sa.agent_image,
|
||||||
|
sa.creator_username,
|
||||||
|
sa.creator_avatar,
|
||||||
|
sa.sub_heading,
|
||||||
|
sa.description,
|
||||||
|
sa.runs,
|
||||||
|
sa.rating,
|
||||||
|
sa.categories,
|
||||||
|
sa.featured,
|
||||||
|
sa.is_available,
|
||||||
|
sa.updated_at,
|
||||||
|
-- Semantic score: cosine similarity (1 - distance)
|
||||||
|
COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
||||||
|
-- Lexical score: ts_rank_cd (will be normalized later)
|
||||||
|
COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||||
|
-- Category match: optimized with unnest for better performance
|
||||||
|
CASE
|
||||||
|
WHEN EXISTS (
|
||||||
|
SELECT 1 FROM unnest(sa.categories) cat
|
||||||
|
WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%'
|
||||||
|
)
|
||||||
|
THEN 1.0
|
||||||
|
ELSE 0.0
|
||||||
|
END as category_score,
|
||||||
|
-- Recency score: linear decay over 90 days (simpler than exponential)
|
||||||
|
GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score,
|
||||||
|
-- Popularity raw: agent runs count (will be normalized with log scaling)
|
||||||
|
sa.runs as popularity_raw
|
||||||
|
FROM candidates c
|
||||||
|
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||||
|
ON c."storeListingVersionId" = sa."storeListingVersionId"
|
||||||
|
LEFT JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||||
|
ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||||
|
),
|
||||||
|
max_lexical AS (
|
||||||
|
SELECT MAX(lexical_raw) as max_val FROM search_scores
|
||||||
|
),
|
||||||
|
max_popularity AS (
|
||||||
|
SELECT MAX(popularity_raw) as max_val FROM search_scores
|
||||||
|
),
|
||||||
|
normalized AS (
|
||||||
|
SELECT
|
||||||
|
ss.*,
|
||||||
|
-- Normalize lexical score by pre-calculated max
|
||||||
|
CASE
|
||||||
|
WHEN ml.max_val > 0
|
||||||
|
THEN ss.lexical_raw / ml.max_val
|
||||||
|
ELSE 0
|
||||||
|
END as lexical_score,
|
||||||
|
-- Normalize popularity with logarithmic scaling to prevent viral agents from dominating
|
||||||
|
-- LOG(1 + runs) / LOG(1 + max_runs) ensures score is 0-1 range
|
||||||
|
CASE
|
||||||
|
WHEN mp.max_val > 0 AND ss.popularity_raw > 0
|
||||||
|
THEN LN(1 + ss.popularity_raw) / LN(1 + mp.max_val)
|
||||||
|
ELSE 0
|
||||||
|
END as popularity_score
|
||||||
|
FROM search_scores ss
|
||||||
|
CROSS JOIN max_lexical ml
|
||||||
|
CROSS JOIN max_popularity mp
|
||||||
|
),
|
||||||
|
scored AS (
|
||||||
|
SELECT
|
||||||
|
slug,
|
||||||
|
agent_name,
|
||||||
|
agent_image,
|
||||||
|
creator_username,
|
||||||
|
creator_avatar,
|
||||||
|
sub_heading,
|
||||||
|
description,
|
||||||
|
runs,
|
||||||
|
rating,
|
||||||
|
categories,
|
||||||
|
featured,
|
||||||
|
is_available,
|
||||||
|
updated_at,
|
||||||
|
semantic_score,
|
||||||
|
lexical_score,
|
||||||
|
category_score,
|
||||||
|
recency_score,
|
||||||
|
popularity_score,
|
||||||
|
(
|
||||||
|
{weight_semantic_param} * semantic_score +
|
||||||
|
{weight_lexical_param} * lexical_score +
|
||||||
|
{weight_category_param} * category_score +
|
||||||
|
{weight_recency_param} * recency_score +
|
||||||
|
{weight_popularity_param} * popularity_score
|
||||||
|
) as combined_score
|
||||||
|
FROM normalized
|
||||||
|
),
|
||||||
|
filtered AS (
|
||||||
|
SELECT
|
||||||
|
*,
|
||||||
|
COUNT(*) OVER () as total_count
|
||||||
|
FROM scored
|
||||||
|
WHERE combined_score >= {min_score_param}
|
||||||
|
)
|
||||||
|
SELECT * FROM filtered
|
||||||
|
ORDER BY combined_score DESC
|
||||||
|
LIMIT ${param_index} OFFSET ${param_index + 1}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Add pagination params
|
||||||
|
params.extend([page_size, offset])
|
||||||
|
|
||||||
|
# Execute search query - includes total_count via window function
|
||||||
|
results = await query_raw_with_schema(
|
||||||
|
sql_query, *params, set_public_search_path=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract total count from first result (all rows have same count)
|
||||||
|
total = results[0]["total_count"] if results else 0
|
||||||
|
|
||||||
|
# Remove total_count from results before returning
|
||||||
|
for result in results:
|
||||||
|
result.pop("total_count", None)
|
||||||
|
|
||||||
|
# Log without sensitive query content
|
||||||
|
logger.info(f"Hybrid search: {len(results)} results, {total} total")
|
||||||
|
|
||||||
|
return results, total
|
||||||
|
|
||||||
|
|
||||||
|
async def hybrid_search_simple(
|
||||||
|
query: str,
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 20,
|
||||||
|
) -> tuple[list[dict[str, Any]], int]:
|
||||||
|
"""
|
||||||
|
Simplified hybrid search for common use cases.
|
||||||
|
|
||||||
|
Uses default weights and no filters.
|
||||||
|
"""
|
||||||
|
return await hybrid_search(
|
||||||
|
query=query,
|
||||||
|
page=page,
|
||||||
|
page_size=page_size,
|
||||||
|
)
|
||||||
@@ -0,0 +1,334 @@
|
|||||||
|
"""
|
||||||
|
Integration tests for hybrid search with schema handling.
|
||||||
|
|
||||||
|
These tests verify that hybrid search works correctly across different database schemas.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.api.features.store.hybrid_search import HybridSearchWeights, hybrid_search
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_hybrid_search_with_schema_handling():
|
||||||
|
"""Test that hybrid search correctly handles database schema prefixes."""
|
||||||
|
# Test with a mock query to ensure schema handling works
|
||||||
|
query = "test agent"
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||||
|
) as mock_query:
|
||||||
|
# Mock the query result
|
||||||
|
mock_query.return_value = [
|
||||||
|
{
|
||||||
|
"slug": "test/agent",
|
||||||
|
"agent_name": "Test Agent",
|
||||||
|
"agent_image": "test.png",
|
||||||
|
"creator_username": "test",
|
||||||
|
"creator_avatar": "avatar.png",
|
||||||
|
"sub_heading": "Test sub-heading",
|
||||||
|
"description": "Test description",
|
||||||
|
"runs": 10,
|
||||||
|
"rating": 4.5,
|
||||||
|
"categories": ["test"],
|
||||||
|
"featured": False,
|
||||||
|
"is_available": True,
|
||||||
|
"updated_at": "2024-01-01T00:00:00Z",
|
||||||
|
"combined_score": 0.8,
|
||||||
|
"semantic_score": 0.7,
|
||||||
|
"lexical_score": 0.6,
|
||||||
|
"category_score": 0.5,
|
||||||
|
"recency_score": 0.4,
|
||||||
|
"total_count": 1,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.hybrid_search.embed_query"
|
||||||
|
) as mock_embed:
|
||||||
|
mock_embed.return_value = [0.1] * 1536 # Mock embedding
|
||||||
|
|
||||||
|
results, total = await hybrid_search(
|
||||||
|
query=query,
|
||||||
|
page=1,
|
||||||
|
page_size=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the query was called
|
||||||
|
assert mock_query.called
|
||||||
|
# Verify the SQL template uses schema_prefix placeholder
|
||||||
|
call_args = mock_query.call_args
|
||||||
|
sql_template = call_args[0][0]
|
||||||
|
assert "{schema_prefix}" in sql_template
|
||||||
|
|
||||||
|
# Verify results
|
||||||
|
assert len(results) == 1
|
||||||
|
assert total == 1
|
||||||
|
assert results[0]["slug"] == "test/agent"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_hybrid_search_with_public_schema():
|
||||||
|
"""Test hybrid search when using public schema (no prefix needed)."""
|
||||||
|
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||||
|
mock_schema.return_value = "public"
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||||
|
) as mock_query:
|
||||||
|
mock_query.return_value = []
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.hybrid_search.embed_query"
|
||||||
|
) as mock_embed:
|
||||||
|
mock_embed.return_value = [0.1] * 1536
|
||||||
|
|
||||||
|
results, total = await hybrid_search(
|
||||||
|
query="test",
|
||||||
|
page=1,
|
||||||
|
page_size=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the mock was set up correctly
|
||||||
|
assert mock_schema.return_value == "public"
|
||||||
|
|
||||||
|
# Results should work even with empty results
|
||||||
|
assert results == []
|
||||||
|
assert total == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_hybrid_search_with_custom_schema():
|
||||||
|
"""Test hybrid search when using custom schema (e.g., 'platform')."""
|
||||||
|
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||||
|
mock_schema.return_value = "platform"
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||||
|
) as mock_query:
|
||||||
|
mock_query.return_value = []
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.hybrid_search.embed_query"
|
||||||
|
) as mock_embed:
|
||||||
|
mock_embed.return_value = [0.1] * 1536
|
||||||
|
|
||||||
|
results, total = await hybrid_search(
|
||||||
|
query="test",
|
||||||
|
page=1,
|
||||||
|
page_size=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the mock was set up correctly
|
||||||
|
assert mock_schema.return_value == "platform"
|
||||||
|
|
||||||
|
assert results == []
|
||||||
|
assert total == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_hybrid_search_without_embeddings():
|
||||||
|
"""Test hybrid search fails fast when embeddings are unavailable."""
|
||||||
|
# Patch where the function is used, not where it's defined
|
||||||
|
with patch("backend.api.features.store.hybrid_search.embed_query") as mock_embed:
|
||||||
|
# Simulate embedding failure
|
||||||
|
mock_embed.return_value = None
|
||||||
|
|
||||||
|
# Should raise ValueError with helpful message
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
await hybrid_search(
|
||||||
|
query="test",
|
||||||
|
page=1,
|
||||||
|
page_size=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify error message is generic (doesn't leak implementation details)
|
||||||
|
assert "Search service temporarily unavailable" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_hybrid_search_with_filters():
|
||||||
|
"""Test hybrid search with various filters."""
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||||
|
) as mock_query:
|
||||||
|
mock_query.return_value = []
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.hybrid_search.embed_query"
|
||||||
|
) as mock_embed:
|
||||||
|
mock_embed.return_value = [0.1] * 1536
|
||||||
|
|
||||||
|
# Test with featured filter
|
||||||
|
results, total = await hybrid_search(
|
||||||
|
query="test",
|
||||||
|
featured=True,
|
||||||
|
creators=["user1", "user2"],
|
||||||
|
category="productivity",
|
||||||
|
page=1,
|
||||||
|
page_size=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify filters were applied in the query
|
||||||
|
call_args = mock_query.call_args
|
||||||
|
params = call_args[0][1:] # Skip SQL template
|
||||||
|
|
||||||
|
# Should have query, query_lower, creators array, category
|
||||||
|
assert len(params) >= 4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_hybrid_search_weights():
|
||||||
|
"""Test hybrid search with custom weights."""
|
||||||
|
custom_weights = HybridSearchWeights(
|
||||||
|
semantic=0.5,
|
||||||
|
lexical=0.3,
|
||||||
|
category=0.1,
|
||||||
|
recency=0.1,
|
||||||
|
popularity=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||||
|
) as mock_query:
|
||||||
|
mock_query.return_value = []
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.hybrid_search.embed_query"
|
||||||
|
) as mock_embed:
|
||||||
|
mock_embed.return_value = [0.1] * 1536
|
||||||
|
|
||||||
|
results, total = await hybrid_search(
|
||||||
|
query="test",
|
||||||
|
weights=custom_weights,
|
||||||
|
page=1,
|
||||||
|
page_size=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify custom weights were used in the query
|
||||||
|
call_args = mock_query.call_args
|
||||||
|
sql_template = call_args[0][0]
|
||||||
|
params = call_args[0][1:] # Get all parameters passed
|
||||||
|
|
||||||
|
# Check that SQL uses parameterized weights (not f-string interpolation)
|
||||||
|
assert "$" in sql_template # Verify parameterization is used
|
||||||
|
|
||||||
|
# Check that custom weights are in the params
|
||||||
|
assert 0.5 in params # semantic weight
|
||||||
|
assert 0.3 in params # lexical weight
|
||||||
|
assert 0.1 in params # category and recency weights
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_hybrid_search_min_score_filtering():
|
||||||
|
"""Test hybrid search minimum score threshold."""
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||||
|
) as mock_query:
|
||||||
|
# Return results with varying scores
|
||||||
|
mock_query.return_value = [
|
||||||
|
{
|
||||||
|
"slug": "high-score/agent",
|
||||||
|
"agent_name": "High Score Agent",
|
||||||
|
"combined_score": 0.8,
|
||||||
|
"total_count": 1,
|
||||||
|
# ... other fields
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.hybrid_search.embed_query"
|
||||||
|
) as mock_embed:
|
||||||
|
mock_embed.return_value = [0.1] * 1536
|
||||||
|
|
||||||
|
# Test with custom min_score
|
||||||
|
results, total = await hybrid_search(
|
||||||
|
query="test",
|
||||||
|
min_score=0.5, # High threshold
|
||||||
|
page=1,
|
||||||
|
page_size=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify min_score was applied in query
|
||||||
|
call_args = mock_query.call_args
|
||||||
|
sql_template = call_args[0][0]
|
||||||
|
params = call_args[0][1:] # Get all parameters
|
||||||
|
|
||||||
|
# Check that SQL uses parameterized min_score
|
||||||
|
assert "combined_score >=" in sql_template
|
||||||
|
assert "$" in sql_template # Verify parameterization
|
||||||
|
|
||||||
|
# Check that custom min_score is in the params
|
||||||
|
assert 0.5 in params
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_hybrid_search_pagination():
|
||||||
|
"""Test hybrid search pagination."""
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||||
|
) as mock_query:
|
||||||
|
mock_query.return_value = []
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.hybrid_search.embed_query"
|
||||||
|
) as mock_embed:
|
||||||
|
mock_embed.return_value = [0.1] * 1536
|
||||||
|
|
||||||
|
# Test page 2 with page_size 10
|
||||||
|
results, total = await hybrid_search(
|
||||||
|
query="test",
|
||||||
|
page=2,
|
||||||
|
page_size=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify pagination parameters
|
||||||
|
call_args = mock_query.call_args
|
||||||
|
params = call_args[0]
|
||||||
|
|
||||||
|
# Last two params should be LIMIT and OFFSET
|
||||||
|
limit = params[-2]
|
||||||
|
offset = params[-1]
|
||||||
|
|
||||||
|
assert limit == 10 # page_size
|
||||||
|
assert offset == 10 # (page - 1) * page_size = (2 - 1) * 10
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_hybrid_search_error_handling():
|
||||||
|
"""Test hybrid search error handling."""
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||||
|
) as mock_query:
|
||||||
|
# Simulate database error
|
||||||
|
mock_query.side_effect = Exception("Database connection error")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.hybrid_search.embed_query"
|
||||||
|
) as mock_embed:
|
||||||
|
mock_embed.return_value = [0.1] * 1536
|
||||||
|
|
||||||
|
# Should raise exception
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await hybrid_search(
|
||||||
|
query="test",
|
||||||
|
page=1,
|
||||||
|
page_size=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Database connection error" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v", "-s"])
|
||||||
@@ -64,7 +64,6 @@ from backend.data.onboarding import (
|
|||||||
complete_re_run_agent,
|
complete_re_run_agent,
|
||||||
get_recommended_agents,
|
get_recommended_agents,
|
||||||
get_user_onboarding,
|
get_user_onboarding,
|
||||||
increment_runs,
|
|
||||||
onboarding_enabled,
|
onboarding_enabled,
|
||||||
reset_user_onboarding,
|
reset_user_onboarding,
|
||||||
update_user_onboarding,
|
update_user_onboarding,
|
||||||
@@ -975,7 +974,6 @@ async def execute_graph(
|
|||||||
# Record successful graph execution
|
# Record successful graph execution
|
||||||
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
|
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
|
||||||
record_graph_operation(operation="execute", status="success")
|
record_graph_operation(operation="execute", status="success")
|
||||||
await increment_runs(user_id)
|
|
||||||
await complete_re_run_agent(user_id, graph_id)
|
await complete_re_run_agent(user_id, graph_id)
|
||||||
if source == "library":
|
if source == "library":
|
||||||
await complete_onboarding_step(
|
await complete_onboarding_step(
|
||||||
|
|||||||
@@ -38,6 +38,20 @@ POOL_TIMEOUT = os.getenv("DB_POOL_TIMEOUT")
|
|||||||
if POOL_TIMEOUT:
|
if POOL_TIMEOUT:
|
||||||
DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)
|
DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)
|
||||||
|
|
||||||
|
# Add public schema to search_path for pgvector type access
|
||||||
|
# The vector extension is in public schema, but search_path is determined by schema parameter
|
||||||
|
# Extract the schema from DATABASE_URL or default to 'public' (matching get_database_schema())
|
||||||
|
parsed_url = urlparse(DATABASE_URL)
|
||||||
|
url_params = dict(parse_qsl(parsed_url.query))
|
||||||
|
db_schema = url_params.get("schema", "public")
|
||||||
|
# Build search_path, avoiding duplicates if db_schema is already 'public'
|
||||||
|
search_path_schemas = list(
|
||||||
|
dict.fromkeys([db_schema, "public"])
|
||||||
|
) # Preserves order, removes duplicates
|
||||||
|
search_path = ",".join(search_path_schemas)
|
||||||
|
# This allows using ::vector without schema qualification
|
||||||
|
DATABASE_URL = add_param(DATABASE_URL, "options", f"-c search_path={search_path}")
|
||||||
|
|
||||||
HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None
|
HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None
|
||||||
|
|
||||||
prisma = Prisma(
|
prisma = Prisma(
|
||||||
@@ -108,21 +122,102 @@ def get_database_schema() -> str:
|
|||||||
return query_params.get("schema", "public")
|
return query_params.get("schema", "public")
|
||||||
|
|
||||||
|
|
||||||
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
|
async def _raw_with_schema(
|
||||||
"""Execute raw SQL query with proper schema handling."""
|
query_template: str,
|
||||||
|
*args,
|
||||||
|
execute: bool = False,
|
||||||
|
client: Prisma | None = None,
|
||||||
|
set_public_search_path: bool = False,
|
||||||
|
) -> list[dict] | int:
|
||||||
|
"""Internal: Execute raw SQL with proper schema handling.
|
||||||
|
|
||||||
|
Use query_raw_with_schema() or execute_raw_with_schema() instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_template: SQL query with {schema_prefix} placeholder
|
||||||
|
*args: Query parameters
|
||||||
|
execute: If False, executes SELECT query. If True, executes INSERT/UPDATE/DELETE.
|
||||||
|
client: Optional Prisma client for transactions (only used when execute=True).
|
||||||
|
set_public_search_path: If True, sets search_path to include public schema.
|
||||||
|
Needed for pgvector types and other public schema objects.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- list[dict] if execute=False (query results)
|
||||||
|
- int if execute=True (number of affected rows)
|
||||||
|
"""
|
||||||
schema = get_database_schema()
|
schema = get_database_schema()
|
||||||
schema_prefix = f'"{schema}".' if schema != "public" else ""
|
schema_prefix = f'"{schema}".' if schema != "public" else ""
|
||||||
formatted_query = query_template.format(schema_prefix=schema_prefix)
|
formatted_query = query_template.format(schema_prefix=schema_prefix)
|
||||||
|
|
||||||
import prisma as prisma_module
|
import prisma as prisma_module
|
||||||
|
|
||||||
result = await prisma_module.get_client().query_raw(
|
db_client = client if client else prisma_module.get_client()
|
||||||
formatted_query, *args # type: ignore
|
|
||||||
)
|
# Set search_path to include public schema if requested
|
||||||
|
# Prisma doesn't support the 'options' connection parameter, so we set it per-session
|
||||||
|
# This is idempotent and safe to call multiple times
|
||||||
|
if set_public_search_path:
|
||||||
|
await db_client.execute_raw(f"SET search_path = {schema}, public") # type: ignore
|
||||||
|
|
||||||
|
if execute:
|
||||||
|
result = await db_client.execute_raw(formatted_query, *args) # type: ignore
|
||||||
|
else:
|
||||||
|
result = await db_client.query_raw(formatted_query, *args) # type: ignore
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def query_raw_with_schema(
|
||||||
|
query_template: str, *args, set_public_search_path: bool = False
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Execute raw SQL SELECT query with proper schema handling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_template: SQL query with {schema_prefix} placeholder
|
||||||
|
*args: Query parameters
|
||||||
|
set_public_search_path: If True, sets search_path to include public schema.
|
||||||
|
Needed for pgvector types and other public schema objects.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of result rows as dictionaries
|
||||||
|
|
||||||
|
Example:
|
||||||
|
results = await query_raw_with_schema(
|
||||||
|
'SELECT * FROM {schema_prefix}"User" WHERE id = $1',
|
||||||
|
user_id
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
return await _raw_with_schema(query_template, *args, execute=False, set_public_search_path=set_public_search_path) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_raw_with_schema(
|
||||||
|
query_template: str,
|
||||||
|
*args,
|
||||||
|
client: Prisma | None = None,
|
||||||
|
set_public_search_path: bool = False,
|
||||||
|
) -> int:
|
||||||
|
"""Execute raw SQL command (INSERT/UPDATE/DELETE) with proper schema handling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_template: SQL query with {schema_prefix} placeholder
|
||||||
|
*args: Query parameters
|
||||||
|
client: Optional Prisma client for transactions
|
||||||
|
set_public_search_path: If True, sets search_path to include public schema.
|
||||||
|
Needed for pgvector types and other public schema objects.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of affected rows
|
||||||
|
|
||||||
|
Example:
|
||||||
|
await execute_raw_with_schema(
|
||||||
|
'INSERT INTO {schema_prefix}"User" (id, name) VALUES ($1, $2)',
|
||||||
|
user_id, name,
|
||||||
|
client=tx # Optional transaction client
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
return await _raw_with_schema(query_template, *args, execute=True, client=client, set_public_search_path=set_public_search_path) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class BaseDbModel(BaseModel):
|
class BaseDbModel(BaseModel):
|
||||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import fastapi.exceptions
|
import fastapi.exceptions
|
||||||
@@ -18,6 +19,17 @@ from backend.usecases.sample import create_test_user
|
|||||||
from backend.util.test import SpinTestServer
|
from backend.util.test import SpinTestServer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def mock_embedding_functions():
|
||||||
|
"""Mock embedding functions for all tests to avoid database/API dependencies."""
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.db.ensure_embedding",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=True,
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_graph_creation(server: SpinTestServer, snapshot: Snapshot):
|
async def test_graph_creation(server: SpinTestServer, snapshot: Snapshot):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -334,7 +334,7 @@ async def _get_user_timezone(user_id: str) -> str:
|
|||||||
return get_user_timezone_or_utc(user.timezone if user else None)
|
return get_user_timezone_or_utc(user.timezone if user else None)
|
||||||
|
|
||||||
|
|
||||||
async def increment_runs(user_id: str):
|
async def increment_onboarding_runs(user_id: str):
|
||||||
"""
|
"""
|
||||||
Increment a user's run counters and trigger any onboarding milestones.
|
Increment a user's run counters and trigger any onboarding milestones.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -7,6 +7,10 @@ from backend.api.features.library.db import (
|
|||||||
list_library_agents,
|
list_library_agents,
|
||||||
)
|
)
|
||||||
from backend.api.features.store.db import get_store_agent_details, get_store_agents
|
from backend.api.features.store.db import get_store_agent_details, get_store_agents
|
||||||
|
from backend.api.features.store.embeddings import (
|
||||||
|
backfill_missing_embeddings,
|
||||||
|
get_embedding_stats,
|
||||||
|
)
|
||||||
from backend.data import db
|
from backend.data import db
|
||||||
from backend.data.analytics import (
|
from backend.data.analytics import (
|
||||||
get_accuracy_trends_and_alerts,
|
get_accuracy_trends_and_alerts,
|
||||||
@@ -20,6 +24,7 @@ from backend.data.execution import (
|
|||||||
get_execution_kv_data,
|
get_execution_kv_data,
|
||||||
get_execution_outputs_by_node_exec_id,
|
get_execution_outputs_by_node_exec_id,
|
||||||
get_frequently_executed_graphs,
|
get_frequently_executed_graphs,
|
||||||
|
get_graph_execution,
|
||||||
get_graph_execution_meta,
|
get_graph_execution_meta,
|
||||||
get_graph_executions,
|
get_graph_executions,
|
||||||
get_graph_executions_count,
|
get_graph_executions_count,
|
||||||
@@ -57,6 +62,7 @@ from backend.data.notifications import (
|
|||||||
get_user_notification_oldest_message_in_batch,
|
get_user_notification_oldest_message_in_batch,
|
||||||
remove_notifications_from_batch,
|
remove_notifications_from_batch,
|
||||||
)
|
)
|
||||||
|
from backend.data.onboarding import increment_onboarding_runs
|
||||||
from backend.data.user import (
|
from backend.data.user import (
|
||||||
get_active_user_ids_in_timerange,
|
get_active_user_ids_in_timerange,
|
||||||
get_user_by_id,
|
get_user_by_id,
|
||||||
@@ -140,6 +146,7 @@ class DatabaseManager(AppService):
|
|||||||
get_child_graph_executions = _(get_child_graph_executions)
|
get_child_graph_executions = _(get_child_graph_executions)
|
||||||
get_graph_executions = _(get_graph_executions)
|
get_graph_executions = _(get_graph_executions)
|
||||||
get_graph_executions_count = _(get_graph_executions_count)
|
get_graph_executions_count = _(get_graph_executions_count)
|
||||||
|
get_graph_execution = _(get_graph_execution)
|
||||||
get_graph_execution_meta = _(get_graph_execution_meta)
|
get_graph_execution_meta = _(get_graph_execution_meta)
|
||||||
create_graph_execution = _(create_graph_execution)
|
create_graph_execution = _(create_graph_execution)
|
||||||
get_node_execution = _(get_node_execution)
|
get_node_execution = _(get_node_execution)
|
||||||
@@ -204,10 +211,17 @@ class DatabaseManager(AppService):
|
|||||||
add_store_agent_to_library = _(add_store_agent_to_library)
|
add_store_agent_to_library = _(add_store_agent_to_library)
|
||||||
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
|
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
|
||||||
|
|
||||||
|
# Onboarding
|
||||||
|
increment_onboarding_runs = _(increment_onboarding_runs)
|
||||||
|
|
||||||
# Store
|
# Store
|
||||||
get_store_agents = _(get_store_agents)
|
get_store_agents = _(get_store_agents)
|
||||||
get_store_agent_details = _(get_store_agent_details)
|
get_store_agent_details = _(get_store_agent_details)
|
||||||
|
|
||||||
|
# Store Embeddings
|
||||||
|
get_embedding_stats = _(get_embedding_stats)
|
||||||
|
backfill_missing_embeddings = _(backfill_missing_embeddings)
|
||||||
|
|
||||||
# Summary data - async
|
# Summary data - async
|
||||||
get_user_execution_summary_data = _(get_user_execution_summary_data)
|
get_user_execution_summary_data = _(get_user_execution_summary_data)
|
||||||
|
|
||||||
@@ -259,6 +273,10 @@ class DatabaseManagerClient(AppServiceClient):
|
|||||||
get_store_agents = _(d.get_store_agents)
|
get_store_agents = _(d.get_store_agents)
|
||||||
get_store_agent_details = _(d.get_store_agent_details)
|
get_store_agent_details = _(d.get_store_agent_details)
|
||||||
|
|
||||||
|
# Store Embeddings
|
||||||
|
get_embedding_stats = _(d.get_embedding_stats)
|
||||||
|
backfill_missing_embeddings = _(d.backfill_missing_embeddings)
|
||||||
|
|
||||||
|
|
||||||
class DatabaseManagerAsyncClient(AppServiceClient):
|
class DatabaseManagerAsyncClient(AppServiceClient):
|
||||||
d = DatabaseManager
|
d = DatabaseManager
|
||||||
@@ -274,6 +292,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
|||||||
get_graph = d.get_graph
|
get_graph = d.get_graph
|
||||||
get_graph_metadata = d.get_graph_metadata
|
get_graph_metadata = d.get_graph_metadata
|
||||||
get_graph_settings = d.get_graph_settings
|
get_graph_settings = d.get_graph_settings
|
||||||
|
get_graph_execution = d.get_graph_execution
|
||||||
get_graph_execution_meta = d.get_graph_execution_meta
|
get_graph_execution_meta = d.get_graph_execution_meta
|
||||||
get_node = d.get_node
|
get_node = d.get_node
|
||||||
get_node_execution = d.get_node_execution
|
get_node_execution = d.get_node_execution
|
||||||
@@ -318,6 +337,9 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
|||||||
add_store_agent_to_library = d.add_store_agent_to_library
|
add_store_agent_to_library = d.add_store_agent_to_library
|
||||||
validate_graph_execution_permissions = d.validate_graph_execution_permissions
|
validate_graph_execution_permissions = d.validate_graph_execution_permissions
|
||||||
|
|
||||||
|
# Onboarding
|
||||||
|
increment_onboarding_runs = d.increment_onboarding_runs
|
||||||
|
|
||||||
# Store
|
# Store
|
||||||
get_store_agents = d.get_store_agents
|
get_store_agents = d.get_store_agents
|
||||||
get_store_agent_details = d.get_store_agent_details
|
get_store_agent_details = d.get_store_agent_details
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import fastapi.responses
|
import fastapi.responses
|
||||||
import pytest
|
import pytest
|
||||||
@@ -19,6 +20,17 @@ from backend.util.test import SpinTestServer, wait_execution
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def mock_embedding_functions():
|
||||||
|
"""Mock embedding functions for all tests to avoid database/API dependencies."""
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.db.ensure_embedding",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=True,
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph:
|
async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph:
|
||||||
logger.info(f"Creating graph for user {u.id}")
|
logger.info(f"Creating graph for user {u.id}")
|
||||||
return await s.agent_server.test_create_graph(CreateGraph(graph=g), u.id)
|
return await s.agent_server.test_create_graph(CreateGraph(graph=g), u.id)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -27,7 +28,6 @@ from backend.data.auth.oauth import cleanup_expired_oauth_tokens
|
|||||||
from backend.data.block import BlockInput
|
from backend.data.block import BlockInput
|
||||||
from backend.data.execution import GraphExecutionWithNodes
|
from backend.data.execution import GraphExecutionWithNodes
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
from backend.data.onboarding import increment_runs
|
|
||||||
from backend.executor import utils as execution_utils
|
from backend.executor import utils as execution_utils
|
||||||
from backend.monitoring import (
|
from backend.monitoring import (
|
||||||
NotificationJobArgs,
|
NotificationJobArgs,
|
||||||
@@ -37,7 +37,7 @@ from backend.monitoring import (
|
|||||||
report_execution_accuracy_alerts,
|
report_execution_accuracy_alerts,
|
||||||
report_late_executions,
|
report_late_executions,
|
||||||
)
|
)
|
||||||
from backend.util.clients import get_scheduler_client
|
from backend.util.clients import get_database_manager_client, get_scheduler_client
|
||||||
from backend.util.cloud_storage import cleanup_expired_files_async
|
from backend.util.cloud_storage import cleanup_expired_files_async
|
||||||
from backend.util.exceptions import (
|
from backend.util.exceptions import (
|
||||||
GraphNotFoundError,
|
GraphNotFoundError,
|
||||||
@@ -156,7 +156,6 @@ async def _execute_graph(**kwargs):
|
|||||||
inputs=args.input_data,
|
inputs=args.input_data,
|
||||||
graph_credentials_inputs=args.input_credentials,
|
graph_credentials_inputs=args.input_credentials,
|
||||||
)
|
)
|
||||||
await increment_runs(args.user_id)
|
|
||||||
elapsed = asyncio.get_event_loop().time() - start_time
|
elapsed = asyncio.get_event_loop().time() - start_time
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
|
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
|
||||||
@@ -254,6 +253,74 @@ def execution_accuracy_alerts():
|
|||||||
return report_execution_accuracy_alerts()
|
return report_execution_accuracy_alerts()
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_embeddings_coverage():
|
||||||
|
"""
|
||||||
|
Ensure approved store agents have embeddings for hybrid search.
|
||||||
|
|
||||||
|
Processes ALL missing embeddings in batches of 10 until 100% coverage.
|
||||||
|
Missing embeddings = agents invisible in hybrid search.
|
||||||
|
|
||||||
|
Schedule: Runs every 6 hours (balanced between coverage and API costs).
|
||||||
|
- Catches agents approved between scheduled runs
|
||||||
|
- Batch size 10: gradual processing to avoid rate limits
|
||||||
|
- Manual trigger available via execute_ensure_embeddings_coverage endpoint
|
||||||
|
"""
|
||||||
|
db_client = get_database_manager_client()
|
||||||
|
stats = db_client.get_embedding_stats()
|
||||||
|
|
||||||
|
# Check for error from get_embedding_stats() first
|
||||||
|
if "error" in stats:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to get embedding stats: {stats['error']} - skipping backfill"
|
||||||
|
)
|
||||||
|
return {"processed": 0, "success": 0, "failed": 0, "error": stats["error"]}
|
||||||
|
|
||||||
|
if stats["without_embeddings"] == 0:
|
||||||
|
logger.info("All approved agents have embeddings, skipping backfill")
|
||||||
|
return {"processed": 0, "success": 0, "failed": 0}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Found {stats['without_embeddings']} agents without embeddings "
|
||||||
|
f"({stats['coverage_percent']}% coverage) - processing all"
|
||||||
|
)
|
||||||
|
|
||||||
|
total_processed = 0
|
||||||
|
total_success = 0
|
||||||
|
total_failed = 0
|
||||||
|
|
||||||
|
# Process in batches until no more missing embeddings
|
||||||
|
while True:
|
||||||
|
result = db_client.backfill_missing_embeddings(batch_size=10)
|
||||||
|
|
||||||
|
total_processed += result["processed"]
|
||||||
|
total_success += result["success"]
|
||||||
|
total_failed += result["failed"]
|
||||||
|
|
||||||
|
if result["processed"] == 0:
|
||||||
|
# No more missing embeddings
|
||||||
|
break
|
||||||
|
|
||||||
|
if result["success"] == 0 and result["processed"] > 0:
|
||||||
|
# All attempts in this batch failed - stop to avoid infinite loop
|
||||||
|
logger.error(
|
||||||
|
f"All {result['processed']} embedding attempts failed - stopping backfill"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Small delay between batches to avoid rate limits
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Embedding backfill completed: {total_success}/{total_processed} succeeded, "
|
||||||
|
f"{total_failed} failed"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"processed": total_processed,
|
||||||
|
"success": total_success,
|
||||||
|
"failed": total_failed,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# Monitoring functions are now imported from monitoring module
|
# Monitoring functions are now imported from monitoring module
|
||||||
|
|
||||||
|
|
||||||
@@ -475,6 +542,19 @@ class Scheduler(AppService):
|
|||||||
jobstore=Jobstores.EXECUTION.value,
|
jobstore=Jobstores.EXECUTION.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Embedding Coverage - Every 6 hours
|
||||||
|
# Ensures all approved agents have embeddings for hybrid search
|
||||||
|
# Critical: missing embeddings = agents invisible in search
|
||||||
|
self.scheduler.add_job(
|
||||||
|
ensure_embeddings_coverage,
|
||||||
|
id="ensure_embeddings_coverage",
|
||||||
|
trigger="interval",
|
||||||
|
hours=6,
|
||||||
|
replace_existing=True,
|
||||||
|
max_instances=1, # Prevent overlapping runs
|
||||||
|
jobstore=Jobstores.EXECUTION.value,
|
||||||
|
)
|
||||||
|
|
||||||
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
||||||
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
|
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
|
||||||
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
|
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
|
||||||
@@ -632,6 +712,11 @@ class Scheduler(AppService):
|
|||||||
"""Manually trigger execution accuracy alert checking."""
|
"""Manually trigger execution accuracy alert checking."""
|
||||||
return execution_accuracy_alerts()
|
return execution_accuracy_alerts()
|
||||||
|
|
||||||
|
@expose
|
||||||
|
def execute_ensure_embeddings_coverage(self):
|
||||||
|
"""Manually trigger embedding backfill for approved store agents."""
|
||||||
|
return ensure_embeddings_coverage()
|
||||||
|
|
||||||
|
|
||||||
class SchedulerClient(AppServiceClient):
|
class SchedulerClient(AppServiceClient):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from pydantic import BaseModel, JsonValue, ValidationError
|
|||||||
|
|
||||||
from backend.data import execution as execution_db
|
from backend.data import execution as execution_db
|
||||||
from backend.data import graph as graph_db
|
from backend.data import graph as graph_db
|
||||||
|
from backend.data import onboarding as onboarding_db
|
||||||
from backend.data import user as user_db
|
from backend.data import user as user_db
|
||||||
from backend.data.block import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
@@ -31,7 +32,6 @@ from backend.data.execution import (
|
|||||||
GraphExecutionStats,
|
GraphExecutionStats,
|
||||||
GraphExecutionWithNodes,
|
GraphExecutionWithNodes,
|
||||||
NodesInputMasks,
|
NodesInputMasks,
|
||||||
get_graph_execution,
|
|
||||||
)
|
)
|
||||||
from backend.data.graph import GraphModel, Node
|
from backend.data.graph import GraphModel, Node
|
||||||
from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput
|
from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput
|
||||||
@@ -809,13 +809,14 @@ async def add_graph_execution(
|
|||||||
edb = execution_db
|
edb = execution_db
|
||||||
udb = user_db
|
udb = user_db
|
||||||
gdb = graph_db
|
gdb = graph_db
|
||||||
|
odb = onboarding_db
|
||||||
else:
|
else:
|
||||||
edb = udb = gdb = get_database_manager_async_client()
|
edb = udb = gdb = odb = get_database_manager_async_client()
|
||||||
|
|
||||||
# Get or create the graph execution
|
# Get or create the graph execution
|
||||||
if graph_exec_id:
|
if graph_exec_id:
|
||||||
# Resume existing execution
|
# Resume existing execution
|
||||||
graph_exec = await get_graph_execution(
|
graph_exec = await edb.get_graph_execution(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
execution_id=graph_exec_id,
|
execution_id=graph_exec_id,
|
||||||
include_node_executions=True,
|
include_node_executions=True,
|
||||||
@@ -891,6 +892,7 @@ async def add_graph_execution(
|
|||||||
)
|
)
|
||||||
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
|
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
|
||||||
|
|
||||||
|
# Publish to execution queue for executor to pick up
|
||||||
exec_queue = await get_async_execution_queue()
|
exec_queue = await get_async_execution_queue()
|
||||||
await exec_queue.publish_message(
|
await exec_queue.publish_message(
|
||||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||||
@@ -899,14 +901,12 @@ async def add_graph_execution(
|
|||||||
)
|
)
|
||||||
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
|
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
|
||||||
|
|
||||||
|
# Update execution status to QUEUED
|
||||||
graph_exec.status = ExecutionStatus.QUEUED
|
graph_exec.status = ExecutionStatus.QUEUED
|
||||||
await edb.update_graph_execution_stats(
|
await edb.update_graph_execution_stats(
|
||||||
graph_exec_id=graph_exec.id,
|
graph_exec_id=graph_exec.id,
|
||||||
status=graph_exec.status,
|
status=graph_exec.status,
|
||||||
)
|
)
|
||||||
await get_async_execution_event_bus().publish(graph_exec)
|
|
||||||
|
|
||||||
return graph_exec
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
err = str(e) or type(e).__name__
|
err = str(e) or type(e).__name__
|
||||||
if not graph_exec:
|
if not graph_exec:
|
||||||
@@ -927,6 +927,24 @@ async def add_graph_execution(
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
try:
|
||||||
|
await get_async_execution_event_bus().publish(graph_exec)
|
||||||
|
logger.info(f"Published update for execution #{graph_exec.id} to event bus")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to publish execution event for graph exec #{graph_exec.id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await odb.increment_onboarding_runs(user_id)
|
||||||
|
logger.info(
|
||||||
|
f"Incremented user #{user_id} onboarding runs for exec #{graph_exec.id}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to increment onboarding runs for user #{user_id}: {e}")
|
||||||
|
|
||||||
|
return graph_exec
|
||||||
|
|
||||||
|
|
||||||
# ============ Execution Output Helpers ============ #
|
# ============ Execution Output Helpers ============ #
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from backend.util.settings import Settings
|
|||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from openai import AsyncOpenAI
|
||||||
from supabase import AClient, Client
|
from supabase import AClient, Client
|
||||||
|
|
||||||
from backend.data.execution import (
|
from backend.data.execution import (
|
||||||
@@ -139,6 +140,24 @@ async def get_async_supabase() -> "AClient":
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============ OpenAI Client ============ #
|
||||||
|
|
||||||
|
|
||||||
|
@cached(ttl_seconds=3600)
|
||||||
|
def get_openai_client() -> "AsyncOpenAI | None":
|
||||||
|
"""
|
||||||
|
Get a process-cached async OpenAI client for embeddings.
|
||||||
|
|
||||||
|
Returns None if API key is not configured.
|
||||||
|
"""
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
api_key = settings.secrets.openai_internal_api_key
|
||||||
|
if not api_key:
|
||||||
|
return None
|
||||||
|
return AsyncOpenAI(api_key=api_key)
|
||||||
|
|
||||||
|
|
||||||
# ============ Notification Queue Helpers ============ #
|
# ============ Notification Queue Helpers ============ #
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,46 @@
|
|||||||
|
-- CreateExtension
|
||||||
|
-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
|
||||||
|
-- Create in public schema so vector type is available across all schemas
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "vector" WITH SCHEMA "public";
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'vector extension not available or already exists, skipping';
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
-- CreateEnum
|
||||||
|
CREATE TYPE "ContentType" AS ENUM ('STORE_AGENT', 'BLOCK', 'INTEGRATION', 'DOCUMENTATION', 'LIBRARY_AGENT');
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "UnifiedContentEmbedding" (
|
||||||
|
"id" TEXT NOT NULL,
|
||||||
|
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||||
|
"contentType" "ContentType" NOT NULL,
|
||||||
|
"contentId" TEXT NOT NULL,
|
||||||
|
"userId" TEXT,
|
||||||
|
"embedding" public.vector(1536) NOT NULL,
|
||||||
|
"searchableText" TEXT NOT NULL,
|
||||||
|
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||||
|
|
||||||
|
CONSTRAINT "UnifiedContentEmbedding_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE INDEX "UnifiedContentEmbedding_contentType_idx" ON "UnifiedContentEmbedding"("contentType");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE INDEX "UnifiedContentEmbedding_userId_idx" ON "UnifiedContentEmbedding"("userId");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE INDEX "UnifiedContentEmbedding_contentType_userId_idx" ON "UnifiedContentEmbedding"("contentType", "userId");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
-- NULLS NOT DISTINCT ensures only one public (NULL userId) embedding per contentType+contentId
|
||||||
|
-- Requires PostgreSQL 15+. Supabase uses PostgreSQL 15+.
|
||||||
|
CREATE UNIQUE INDEX "UnifiedContentEmbedding_contentType_contentId_userId_key" ON "UnifiedContentEmbedding"("contentType", "contentId", "userId") NULLS NOT DISTINCT;
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
-- HNSW index for fast vector similarity search on embeddings
|
||||||
|
-- Uses cosine distance operator (<=>), which matches the query in hybrid_search.py
|
||||||
|
CREATE INDEX "UnifiedContentEmbedding_embedding_idx" ON "UnifiedContentEmbedding" USING hnsw ("embedding" public.vector_cosine_ops);
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
-- Acknowledge Supabase-managed extensions to prevent drift warnings
|
||||||
|
-- These extensions are pre-installed by Supabase in specific schemas
|
||||||
|
-- This migration ensures they exist where available (Supabase) or skips gracefully (CI)
|
||||||
|
|
||||||
|
-- Create schemas (safe in both CI and Supabase)
|
||||||
|
CREATE SCHEMA IF NOT EXISTS "extensions";
|
||||||
|
|
||||||
|
-- Extensions that exist in both CI and Supabase
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA "extensions";
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'pgcrypto extension not available, skipping';
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA "extensions";
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'uuid-ossp extension not available, skipping';
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
-- Supabase-specific extensions (skip gracefully in CI)
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "pg_stat_statements" WITH SCHEMA "extensions";
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'pg_stat_statements extension not available, skipping';
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "pg_net" WITH SCHEMA "extensions";
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'pg_net extension not available, skipping';
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "pgjwt" WITH SCHEMA "extensions";
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'pgjwt extension not available, skipping';
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
CREATE SCHEMA IF NOT EXISTS "graphql";
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "pg_graphql" WITH SCHEMA "graphql";
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'pg_graphql extension not available, skipping';
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
CREATE SCHEMA IF NOT EXISTS "pgsodium";
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "pgsodium" WITH SCHEMA "pgsodium";
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'pgsodium extension not available, skipping';
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
CREATE SCHEMA IF NOT EXISTS "vault";
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "supabase_vault" WITH SCHEMA "vault";
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'supabase_vault extension not available, skipping';
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
|
||||||
|
-- Return to platform
|
||||||
|
CREATE SCHEMA IF NOT EXISTS "platform";
|
||||||
@@ -1,14 +1,15 @@
|
|||||||
datasource db {
|
datasource db {
|
||||||
provider = "postgresql"
|
provider = "postgresql"
|
||||||
url = env("DATABASE_URL")
|
url = env("DATABASE_URL")
|
||||||
directUrl = env("DIRECT_URL")
|
directUrl = env("DIRECT_URL")
|
||||||
|
extensions = [pgvector(map: "vector")]
|
||||||
}
|
}
|
||||||
|
|
||||||
generator client {
|
generator client {
|
||||||
provider = "prisma-client-py"
|
provider = "prisma-client-py"
|
||||||
recursive_type_depth = -1
|
recursive_type_depth = -1
|
||||||
interface = "asyncio"
|
interface = "asyncio"
|
||||||
previewFeatures = ["views", "fullTextSearch"]
|
previewFeatures = ["views", "fullTextSearch", "postgresqlExtensions"]
|
||||||
partial_type_generator = "backend/data/partial_types.py"
|
partial_type_generator = "backend/data/partial_types.py"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,8 +128,8 @@ model BuilderSearchHistory {
|
|||||||
updatedAt DateTime @default(now()) @updatedAt
|
updatedAt DateTime @default(now()) @updatedAt
|
||||||
|
|
||||||
searchQuery String
|
searchQuery String
|
||||||
filter String[] @default([])
|
filter String[] @default([])
|
||||||
byCreator String[] @default([])
|
byCreator String[] @default([])
|
||||||
|
|
||||||
userId String
|
userId String
|
||||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||||
@@ -721,26 +722,25 @@ view StoreAgent {
|
|||||||
storeListingVersionId String
|
storeListingVersionId String
|
||||||
updated_at DateTime
|
updated_at DateTime
|
||||||
|
|
||||||
slug String
|
slug String
|
||||||
agent_name String
|
agent_name String
|
||||||
agent_video String?
|
agent_video String?
|
||||||
agent_output_demo String?
|
agent_output_demo String?
|
||||||
agent_image String[]
|
agent_image String[]
|
||||||
|
|
||||||
featured Boolean @default(false)
|
featured Boolean @default(false)
|
||||||
creator_username String?
|
creator_username String?
|
||||||
creator_avatar String?
|
creator_avatar String?
|
||||||
sub_heading String
|
sub_heading String
|
||||||
description String
|
description String
|
||||||
categories String[]
|
categories String[]
|
||||||
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
|
runs Int
|
||||||
runs Int
|
rating Float
|
||||||
rating Float
|
versions String[]
|
||||||
versions String[]
|
agentGraphVersions String[]
|
||||||
agentGraphVersions String[]
|
agentGraphId String
|
||||||
agentGraphId String
|
is_available Boolean @default(true)
|
||||||
is_available Boolean @default(true)
|
useForOnboarding Boolean @default(false)
|
||||||
useForOnboarding Boolean @default(false)
|
|
||||||
|
|
||||||
// Materialized views used (refreshed every 15 minutes via pg_cron):
|
// Materialized views used (refreshed every 15 minutes via pg_cron):
|
||||||
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
|
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
|
||||||
@@ -856,14 +856,14 @@ model StoreListingVersion {
|
|||||||
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
|
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
|
||||||
|
|
||||||
// Content fields
|
// Content fields
|
||||||
name String
|
name String
|
||||||
subHeading String
|
subHeading String
|
||||||
videoUrl String?
|
videoUrl String?
|
||||||
agentOutputDemoUrl String?
|
agentOutputDemoUrl String?
|
||||||
imageUrls String[]
|
imageUrls String[]
|
||||||
description String
|
description String
|
||||||
instructions String?
|
instructions String?
|
||||||
categories String[]
|
categories String[]
|
||||||
|
|
||||||
isFeatured Boolean @default(false)
|
isFeatured Boolean @default(false)
|
||||||
|
|
||||||
@@ -899,6 +899,9 @@ model StoreListingVersion {
|
|||||||
// Reviews for this specific version
|
// Reviews for this specific version
|
||||||
Reviews StoreListingReview[]
|
Reviews StoreListingReview[]
|
||||||
|
|
||||||
|
// Note: Embeddings now stored in UnifiedContentEmbedding table
|
||||||
|
// Use contentType=STORE_AGENT and contentId=storeListingVersionId
|
||||||
|
|
||||||
@@unique([storeListingId, version])
|
@@unique([storeListingId, version])
|
||||||
@@index([storeListingId, submissionStatus, isAvailable])
|
@@index([storeListingId, submissionStatus, isAvailable])
|
||||||
@@index([submissionStatus])
|
@@index([submissionStatus])
|
||||||
@@ -906,6 +909,42 @@ model StoreListingVersion {
|
|||||||
@@index([agentGraphId, agentGraphVersion]) // Non-unique index for efficient lookups
|
@@index([agentGraphId, agentGraphVersion]) // Non-unique index for efficient lookups
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Content type enum for unified search across store agents, blocks, docs
|
||||||
|
// Note: BLOCK/INTEGRATION are file-based (Python classes), not DB records
|
||||||
|
// DOCUMENTATION are file-based (.md files), not DB records
|
||||||
|
// Only STORE_AGENT and LIBRARY_AGENT are stored in database
|
||||||
|
enum ContentType {
|
||||||
|
STORE_AGENT // Database: StoreListingVersion
|
||||||
|
BLOCK // File-based: Python classes in /backend/blocks/
|
||||||
|
INTEGRATION // File-based: Python classes (blocks with credentials)
|
||||||
|
DOCUMENTATION // File-based: .md/.mdx files
|
||||||
|
LIBRARY_AGENT // Database: User's personal agents
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unified embeddings table for all searchable content types
|
||||||
|
// Supports both public content (userId=null) and user-specific content (userId=userID)
|
||||||
|
model UnifiedContentEmbedding {
|
||||||
|
id String @id @default(uuid())
|
||||||
|
createdAt DateTime @default(now())
|
||||||
|
updatedAt DateTime @updatedAt
|
||||||
|
|
||||||
|
// Content identification
|
||||||
|
contentType ContentType
|
||||||
|
contentId String // DB ID (storeListingVersionId) or file identifier (block.id, file_path)
|
||||||
|
userId String? // NULL for public content (store, blocks, docs), userId for private content (library agents)
|
||||||
|
|
||||||
|
// Search data
|
||||||
|
embedding Unsupported("vector(1536)") // pgvector embedding (extension in platform schema)
|
||||||
|
searchableText String // Combined text for search and fallback
|
||||||
|
metadata Json @default("{}") // Content-specific metadata
|
||||||
|
|
||||||
|
@@unique([contentType, contentId, userId], map: "UnifiedContentEmbedding_contentType_contentId_userId_key")
|
||||||
|
@@index([contentType])
|
||||||
|
@@index([userId])
|
||||||
|
@@index([contentType, userId])
|
||||||
|
@@index([embedding], map: "UnifiedContentEmbedding_embedding_idx")
|
||||||
|
}
|
||||||
|
|
||||||
model StoreListingReview {
|
model StoreListingReview {
|
||||||
id String @id @default(uuid())
|
id String @id @default(uuid())
|
||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
@@ -998,16 +1037,16 @@ model OAuthApplication {
|
|||||||
updatedAt DateTime @updatedAt
|
updatedAt DateTime @updatedAt
|
||||||
|
|
||||||
// Application metadata
|
// Application metadata
|
||||||
name String
|
name String
|
||||||
description String?
|
description String?
|
||||||
logoUrl String? // URL to app logo stored in GCS
|
logoUrl String? // URL to app logo stored in GCS
|
||||||
clientId String @unique
|
clientId String @unique
|
||||||
clientSecret String // Hashed with Scrypt (same as API keys)
|
clientSecret String // Hashed with Scrypt (same as API keys)
|
||||||
clientSecretSalt String // Salt for Scrypt hashing
|
clientSecretSalt String // Salt for Scrypt hashing
|
||||||
|
|
||||||
// OAuth configuration
|
// OAuth configuration
|
||||||
redirectUris String[] // Allowed callback URLs
|
redirectUris String[] // Allowed callback URLs
|
||||||
grantTypes String[] @default(["authorization_code", "refresh_token"])
|
grantTypes String[] @default(["authorization_code", "refresh_token"])
|
||||||
scopes APIKeyPermission[] // Which permissions the app can request
|
scopes APIKeyPermission[] // Which permissions the app can request
|
||||||
|
|
||||||
// Application management
|
// Application management
|
||||||
|
|||||||
@@ -81,18 +81,16 @@ export const RunInputDialog = ({
|
|||||||
Inputs
|
Inputs
|
||||||
</Text>
|
</Text>
|
||||||
</div>
|
</div>
|
||||||
<div className="px-2">
|
<FormRenderer
|
||||||
<FormRenderer
|
jsonSchema={inputSchema as RJSFSchema}
|
||||||
jsonSchema={inputSchema as RJSFSchema}
|
handleChange={(v) => handleInputChange(v.formData)}
|
||||||
handleChange={(v) => handleInputChange(v.formData)}
|
uiSchema={uiSchema}
|
||||||
uiSchema={uiSchema}
|
initialValues={{}}
|
||||||
initialValues={{}}
|
formContext={{
|
||||||
formContext={{
|
showHandles: false,
|
||||||
showHandles: false,
|
size: "large",
|
||||||
size: "large",
|
}}
|
||||||
}}
|
/>
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import { useGetV2GetSpecificBlocks } from "@/app/api/__generated__/endpoints/def
|
|||||||
import {
|
import {
|
||||||
useGetV1GetExecutionDetails,
|
useGetV1GetExecutionDetails,
|
||||||
useGetV1GetSpecificGraph,
|
useGetV1GetSpecificGraph,
|
||||||
|
useGetV1ListUserGraphs,
|
||||||
} from "@/app/api/__generated__/endpoints/graphs/graphs";
|
} from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||||
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
|
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
|
||||||
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
|
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
|
||||||
@@ -17,6 +18,7 @@ import { useReactFlow } from "@xyflow/react";
|
|||||||
import { useControlPanelStore } from "../../../stores/controlPanelStore";
|
import { useControlPanelStore } from "../../../stores/controlPanelStore";
|
||||||
import { useHistoryStore } from "../../../stores/historyStore";
|
import { useHistoryStore } from "../../../stores/historyStore";
|
||||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||||
|
import { okData } from "@/app/api/helpers";
|
||||||
|
|
||||||
export const useFlow = () => {
|
export const useFlow = () => {
|
||||||
const [isLocked, setIsLocked] = useState(false);
|
const [isLocked, setIsLocked] = useState(false);
|
||||||
@@ -36,6 +38,9 @@ export const useFlow = () => {
|
|||||||
const setGraphExecutionStatus = useGraphStore(
|
const setGraphExecutionStatus = useGraphStore(
|
||||||
useShallow((state) => state.setGraphExecutionStatus),
|
useShallow((state) => state.setGraphExecutionStatus),
|
||||||
);
|
);
|
||||||
|
const setAvailableSubGraphs = useGraphStore(
|
||||||
|
useShallow((state) => state.setAvailableSubGraphs),
|
||||||
|
);
|
||||||
const updateEdgeBeads = useEdgeStore(
|
const updateEdgeBeads = useEdgeStore(
|
||||||
useShallow((state) => state.updateEdgeBeads),
|
useShallow((state) => state.updateEdgeBeads),
|
||||||
);
|
);
|
||||||
@@ -62,6 +67,11 @@ export const useFlow = () => {
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Fetch all available graphs for sub-agent update detection
|
||||||
|
const { data: availableGraphs } = useGetV1ListUserGraphs({
|
||||||
|
query: { select: okData },
|
||||||
|
});
|
||||||
|
|
||||||
const { data: graph, isLoading: isGraphLoading } = useGetV1GetSpecificGraph(
|
const { data: graph, isLoading: isGraphLoading } = useGetV1GetSpecificGraph(
|
||||||
flowID ?? "",
|
flowID ?? "",
|
||||||
flowVersion !== null ? { version: flowVersion } : {},
|
flowVersion !== null ? { version: flowVersion } : {},
|
||||||
@@ -116,10 +126,18 @@ export const useFlow = () => {
|
|||||||
}
|
}
|
||||||
}, [graph]);
|
}, [graph]);
|
||||||
|
|
||||||
|
// Update available sub-graphs in store for sub-agent update detection
|
||||||
|
useEffect(() => {
|
||||||
|
if (availableGraphs) {
|
||||||
|
setAvailableSubGraphs(availableGraphs);
|
||||||
|
}
|
||||||
|
}, [availableGraphs, setAvailableSubGraphs]);
|
||||||
|
|
||||||
// adding nodes
|
// adding nodes
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (customNodes.length > 0) {
|
if (customNodes.length > 0) {
|
||||||
useNodeStore.getState().setNodes([]);
|
useNodeStore.getState().setNodes([]);
|
||||||
|
useNodeStore.getState().clearResolutionState();
|
||||||
addNodes(customNodes);
|
addNodes(customNodes);
|
||||||
|
|
||||||
// Sync hardcoded values with handle IDs.
|
// Sync hardcoded values with handle IDs.
|
||||||
@@ -203,6 +221,7 @@ export const useFlow = () => {
|
|||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
return () => {
|
return () => {
|
||||||
useNodeStore.getState().setNodes([]);
|
useNodeStore.getState().setNodes([]);
|
||||||
|
useNodeStore.getState().clearResolutionState();
|
||||||
useEdgeStore.getState().setEdges([]);
|
useEdgeStore.getState().setEdges([]);
|
||||||
useGraphStore.getState().reset();
|
useGraphStore.getState().reset();
|
||||||
useEdgeStore.getState().resetEdgeBeads();
|
useEdgeStore.getState().resetEdgeBeads();
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import {
|
|||||||
getBezierPath,
|
getBezierPath,
|
||||||
} from "@xyflow/react";
|
} from "@xyflow/react";
|
||||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||||
|
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||||
import { XIcon } from "@phosphor-icons/react";
|
import { XIcon } from "@phosphor-icons/react";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { NodeExecutionResult } from "@/lib/autogpt-server-api";
|
import { NodeExecutionResult } from "@/lib/autogpt-server-api";
|
||||||
@@ -35,6 +36,8 @@ const CustomEdge = ({
|
|||||||
selected,
|
selected,
|
||||||
}: EdgeProps<CustomEdge>) => {
|
}: EdgeProps<CustomEdge>) => {
|
||||||
const removeConnection = useEdgeStore((state) => state.removeEdge);
|
const removeConnection = useEdgeStore((state) => state.removeEdge);
|
||||||
|
// Subscribe to the brokenEdgeIDs map and check if this edge is broken across any node
|
||||||
|
const isBroken = useNodeStore((state) => state.isEdgeBroken(id));
|
||||||
const [isHovered, setIsHovered] = useState(false);
|
const [isHovered, setIsHovered] = useState(false);
|
||||||
|
|
||||||
const [edgePath, labelX, labelY] = getBezierPath({
|
const [edgePath, labelX, labelY] = getBezierPath({
|
||||||
@@ -50,6 +53,12 @@ const CustomEdge = ({
|
|||||||
const beadUp = data?.beadUp ?? 0;
|
const beadUp = data?.beadUp ?? 0;
|
||||||
const beadDown = data?.beadDown ?? 0;
|
const beadDown = data?.beadDown ?? 0;
|
||||||
|
|
||||||
|
const handleRemoveEdge = () => {
|
||||||
|
removeConnection(id);
|
||||||
|
// Note: broken edge tracking is cleaned up automatically by useSubAgentUpdateState
|
||||||
|
// when it detects the edge no longer exists
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<BaseEdge
|
<BaseEdge
|
||||||
@@ -57,9 +66,11 @@ const CustomEdge = ({
|
|||||||
markerEnd={markerEnd}
|
markerEnd={markerEnd}
|
||||||
className={cn(
|
className={cn(
|
||||||
isStatic && "!stroke-[1.5px] [stroke-dasharray:6]",
|
isStatic && "!stroke-[1.5px] [stroke-dasharray:6]",
|
||||||
selected
|
isBroken
|
||||||
? "stroke-zinc-800"
|
? "!stroke-red-500 !stroke-[2px] [stroke-dasharray:4]"
|
||||||
: "stroke-zinc-500/50 hover:stroke-zinc-500",
|
: selected
|
||||||
|
? "stroke-zinc-800"
|
||||||
|
: "stroke-zinc-500/50 hover:stroke-zinc-500",
|
||||||
)}
|
)}
|
||||||
/>
|
/>
|
||||||
<JSBeads
|
<JSBeads
|
||||||
@@ -70,12 +81,16 @@ const CustomEdge = ({
|
|||||||
/>
|
/>
|
||||||
<EdgeLabelRenderer>
|
<EdgeLabelRenderer>
|
||||||
<Button
|
<Button
|
||||||
onClick={() => removeConnection(id)}
|
onClick={handleRemoveEdge}
|
||||||
className={cn(
|
className={cn(
|
||||||
"absolute h-fit min-w-0 p-1 transition-opacity",
|
"absolute h-fit min-w-0 p-1 transition-opacity",
|
||||||
isHovered ? "opacity-100" : "opacity-0",
|
isBroken
|
||||||
|
? "bg-red-500 opacity-100 hover:bg-red-600"
|
||||||
|
: isHovered
|
||||||
|
? "opacity-100"
|
||||||
|
: "opacity-0",
|
||||||
)}
|
)}
|
||||||
variant="secondary"
|
variant={isBroken ? "primary" : "secondary"}
|
||||||
style={{
|
style={{
|
||||||
transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
|
transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
|
||||||
pointerEvents: "all",
|
pointerEvents: "all",
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import { Handle, Position } from "@xyflow/react";
|
|||||||
import { useEdgeStore } from "../../../stores/edgeStore";
|
import { useEdgeStore } from "../../../stores/edgeStore";
|
||||||
import { cleanUpHandleId } from "@/components/renderers/InputRenderer/helpers";
|
import { cleanUpHandleId } from "@/components/renderers/InputRenderer/helpers";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
import { useNodeStore } from "../../../stores/nodeStore";
|
||||||
|
|
||||||
const InputNodeHandle = ({
|
const InputNodeHandle = ({
|
||||||
handleId,
|
handleId,
|
||||||
@@ -15,6 +16,9 @@ const InputNodeHandle = ({
|
|||||||
const isInputConnected = useEdgeStore((state) =>
|
const isInputConnected = useEdgeStore((state) =>
|
||||||
state.isInputConnected(nodeId ?? "", cleanedHandleId),
|
state.isInputConnected(nodeId ?? "", cleanedHandleId),
|
||||||
);
|
);
|
||||||
|
const isInputBroken = useNodeStore((state) =>
|
||||||
|
state.isInputBroken(nodeId, cleanedHandleId),
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Handle
|
<Handle
|
||||||
@@ -27,7 +31,10 @@ const InputNodeHandle = ({
|
|||||||
<CircleIcon
|
<CircleIcon
|
||||||
size={16}
|
size={16}
|
||||||
weight={isInputConnected ? "fill" : "duotone"}
|
weight={isInputConnected ? "fill" : "duotone"}
|
||||||
className={"text-gray-400 opacity-100"}
|
className={cn(
|
||||||
|
"text-gray-400 opacity-100",
|
||||||
|
isInputBroken && "text-red-500",
|
||||||
|
)}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</Handle>
|
</Handle>
|
||||||
@@ -38,14 +45,17 @@ const OutputNodeHandle = ({
|
|||||||
field_name,
|
field_name,
|
||||||
nodeId,
|
nodeId,
|
||||||
hexColor,
|
hexColor,
|
||||||
|
isBroken,
|
||||||
}: {
|
}: {
|
||||||
field_name: string;
|
field_name: string;
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
hexColor: string;
|
hexColor: string;
|
||||||
|
isBroken: boolean;
|
||||||
}) => {
|
}) => {
|
||||||
const isOutputConnected = useEdgeStore((state) =>
|
const isOutputConnected = useEdgeStore((state) =>
|
||||||
state.isOutputConnected(nodeId, field_name),
|
state.isOutputConnected(nodeId, field_name),
|
||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Handle
|
<Handle
|
||||||
type={"source"}
|
type={"source"}
|
||||||
@@ -58,7 +68,10 @@ const OutputNodeHandle = ({
|
|||||||
size={16}
|
size={16}
|
||||||
weight={"duotone"}
|
weight={"duotone"}
|
||||||
color={isOutputConnected ? hexColor : "gray"}
|
color={isOutputConnected ? hexColor : "gray"}
|
||||||
className={cn("text-gray-400 opacity-100")}
|
className={cn(
|
||||||
|
"text-gray-400 opacity-100",
|
||||||
|
isBroken && "text-red-500",
|
||||||
|
)}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</Handle>
|
</Handle>
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ import { NodeDataRenderer } from "./components/NodeOutput/NodeOutput";
|
|||||||
import { NodeRightClickMenu } from "./components/NodeRightClickMenu";
|
import { NodeRightClickMenu } from "./components/NodeRightClickMenu";
|
||||||
import { StickyNoteBlock } from "./components/StickyNoteBlock";
|
import { StickyNoteBlock } from "./components/StickyNoteBlock";
|
||||||
import { WebhookDisclaimer } from "./components/WebhookDisclaimer";
|
import { WebhookDisclaimer } from "./components/WebhookDisclaimer";
|
||||||
|
import { SubAgentUpdateFeature } from "./components/SubAgentUpdate/SubAgentUpdateFeature";
|
||||||
|
import { useCustomNode } from "./useCustomNode";
|
||||||
|
|
||||||
export type CustomNodeData = {
|
export type CustomNodeData = {
|
||||||
hardcodedValues: {
|
hardcodedValues: {
|
||||||
@@ -45,6 +47,10 @@ export type CustomNode = XYNode<CustomNodeData, "custom">;
|
|||||||
|
|
||||||
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||||
({ data, id: nodeId, selected }) => {
|
({ data, id: nodeId, selected }) => {
|
||||||
|
const { inputSchema, outputSchema } = useCustomNode({ data, nodeId });
|
||||||
|
|
||||||
|
const isAgent = data.uiType === BlockUIType.AGENT;
|
||||||
|
|
||||||
if (data.uiType === BlockUIType.NOTE) {
|
if (data.uiType === BlockUIType.NOTE) {
|
||||||
return (
|
return (
|
||||||
<StickyNoteBlock data={data} selected={selected} nodeId={nodeId} />
|
<StickyNoteBlock data={data} selected={selected} nodeId={nodeId} />
|
||||||
@@ -63,16 +69,6 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
|||||||
|
|
||||||
const isAyrshare = data.uiType === BlockUIType.AYRSHARE;
|
const isAyrshare = data.uiType === BlockUIType.AYRSHARE;
|
||||||
|
|
||||||
const inputSchema =
|
|
||||||
data.uiType === BlockUIType.AGENT
|
|
||||||
? (data.hardcodedValues.input_schema ?? {})
|
|
||||||
: data.inputSchema;
|
|
||||||
|
|
||||||
const outputSchema =
|
|
||||||
data.uiType === BlockUIType.AGENT
|
|
||||||
? (data.hardcodedValues.output_schema ?? {})
|
|
||||||
: data.outputSchema;
|
|
||||||
|
|
||||||
const hasConfigErrors =
|
const hasConfigErrors =
|
||||||
data.errors &&
|
data.errors &&
|
||||||
Object.values(data.errors).some(
|
Object.values(data.errors).some(
|
||||||
@@ -87,12 +83,11 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
|||||||
|
|
||||||
const hasErrors = hasConfigErrors || hasOutputError;
|
const hasErrors = hasConfigErrors || hasOutputError;
|
||||||
|
|
||||||
// Currently all blockTypes design are similar - that's why i am using the same component for all of them
|
|
||||||
// If in future - if we need some drastic change in some blockTypes design - we can create separate components for them
|
|
||||||
const node = (
|
const node = (
|
||||||
<NodeContainer selected={selected} nodeId={nodeId} hasErrors={hasErrors}>
|
<NodeContainer selected={selected} nodeId={nodeId} hasErrors={hasErrors}>
|
||||||
<div className="rounded-xlarge bg-white">
|
<div className="rounded-xlarge bg-white">
|
||||||
<NodeHeader data={data} nodeId={nodeId} />
|
<NodeHeader data={data} nodeId={nodeId} />
|
||||||
|
{isAgent && <SubAgentUpdateFeature nodeID={nodeId} nodeData={data} />}
|
||||||
{isWebhook && <WebhookDisclaimer nodeId={nodeId} />}
|
{isWebhook && <WebhookDisclaimer nodeId={nodeId} />}
|
||||||
{isAyrshare && <AyrshareConnectButton />}
|
{isAyrshare && <AyrshareConnectButton />}
|
||||||
<FormCreator
|
<FormCreator
|
||||||
|
|||||||
@@ -0,0 +1,118 @@
|
|||||||
|
import React from "react";
|
||||||
|
import { ArrowUpIcon, WarningIcon } from "@phosphor-icons/react";
|
||||||
|
import { Button } from "@/components/atoms/Button/Button";
|
||||||
|
import {
|
||||||
|
Tooltip,
|
||||||
|
TooltipContent,
|
||||||
|
TooltipTrigger,
|
||||||
|
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||||
|
import { cn, beautifyString } from "@/lib/utils";
|
||||||
|
import { CustomNodeData } from "../../CustomNode";
|
||||||
|
import { useSubAgentUpdateState } from "./useSubAgentUpdateState";
|
||||||
|
import { IncompatibleUpdateDialog } from "./components/IncompatibleUpdateDialog";
|
||||||
|
import { ResolutionModeBar } from "./components/ResolutionModeBar";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Inline component for the update bar that can be placed after the header.
|
||||||
|
* Use this inside the node content where you want the bar to appear.
|
||||||
|
*/
|
||||||
|
type SubAgentUpdateFeatureProps = {
|
||||||
|
nodeID: string;
|
||||||
|
nodeData: CustomNodeData;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function SubAgentUpdateFeature({
|
||||||
|
nodeID,
|
||||||
|
nodeData,
|
||||||
|
}: SubAgentUpdateFeatureProps) {
|
||||||
|
const {
|
||||||
|
updateInfo,
|
||||||
|
isInResolutionMode,
|
||||||
|
handleUpdateClick,
|
||||||
|
showIncompatibilityDialog,
|
||||||
|
setShowIncompatibilityDialog,
|
||||||
|
handleConfirmIncompatibleUpdate,
|
||||||
|
} = useSubAgentUpdateState({ nodeID: nodeID, nodeData: nodeData });
|
||||||
|
|
||||||
|
const agentName = nodeData.title || "Agent";
|
||||||
|
|
||||||
|
if (!updateInfo.hasUpdate && !isInResolutionMode) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{isInResolutionMode ? (
|
||||||
|
<ResolutionModeBar incompatibilities={updateInfo.incompatibilities} />
|
||||||
|
) : (
|
||||||
|
<SubAgentUpdateAvailableBar
|
||||||
|
currentVersion={updateInfo.currentVersion}
|
||||||
|
latestVersion={updateInfo.latestVersion}
|
||||||
|
isCompatible={updateInfo.isCompatible}
|
||||||
|
onUpdate={handleUpdateClick}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{/* Incompatibility dialog - rendered here since this component owns the state */}
|
||||||
|
{updateInfo.incompatibilities && (
|
||||||
|
<IncompatibleUpdateDialog
|
||||||
|
isOpen={showIncompatibilityDialog}
|
||||||
|
onClose={() => setShowIncompatibilityDialog(false)}
|
||||||
|
onConfirm={handleConfirmIncompatibleUpdate}
|
||||||
|
currentVersion={updateInfo.currentVersion}
|
||||||
|
latestVersion={updateInfo.latestVersion}
|
||||||
|
agentName={beautifyString(agentName)}
|
||||||
|
incompatibilities={updateInfo.incompatibilities}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
type SubAgentUpdateAvailableBarProps = {
|
||||||
|
currentVersion: number;
|
||||||
|
latestVersion: number;
|
||||||
|
isCompatible: boolean;
|
||||||
|
onUpdate: () => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
function SubAgentUpdateAvailableBar({
|
||||||
|
currentVersion,
|
||||||
|
latestVersion,
|
||||||
|
isCompatible,
|
||||||
|
onUpdate,
|
||||||
|
}: SubAgentUpdateAvailableBarProps): React.ReactElement {
|
||||||
|
return (
|
||||||
|
<div className="flex items-center justify-between gap-2 rounded-t-xl bg-blue-50 px-3 py-2 dark:bg-blue-900/30">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<ArrowUpIcon className="h-4 w-4 text-blue-600 dark:text-blue-400" />
|
||||||
|
<span className="text-sm text-blue-700 dark:text-blue-300">
|
||||||
|
Update available (v{currentVersion} → v{latestVersion})
|
||||||
|
</span>
|
||||||
|
{!isCompatible && (
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger asChild>
|
||||||
|
<WarningIcon className="h-4 w-4 text-amber-500" />
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipContent className="max-w-xs">
|
||||||
|
<p className="font-medium">Incompatible changes detected</p>
|
||||||
|
<p className="text-xs text-gray-400">
|
||||||
|
Click Update to see details
|
||||||
|
</p>
|
||||||
|
</TooltipContent>
|
||||||
|
</Tooltip>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<Button
|
||||||
|
size="small"
|
||||||
|
variant={isCompatible ? "primary" : "outline"}
|
||||||
|
onClick={onUpdate}
|
||||||
|
className={cn(
|
||||||
|
"h-7 text-xs",
|
||||||
|
!isCompatible && "border-amber-500 text-amber-600 hover:bg-amber-50",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
Update
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,274 @@
|
|||||||
|
import React from "react";
|
||||||
|
import {
|
||||||
|
WarningIcon,
|
||||||
|
XCircleIcon,
|
||||||
|
PlusCircleIcon,
|
||||||
|
} from "@phosphor-icons/react";
|
||||||
|
import { Button } from "@/components/atoms/Button/Button";
|
||||||
|
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
|
||||||
|
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||||
|
import { beautifyString } from "@/lib/utils";
|
||||||
|
import { IncompatibilityInfo } from "@/app/(platform)/build/hooks/useSubAgentUpdate/types";
|
||||||
|
|
||||||
|
type IncompatibleUpdateDialogProps = {
|
||||||
|
isOpen: boolean;
|
||||||
|
onClose: () => void;
|
||||||
|
onConfirm: () => void;
|
||||||
|
currentVersion: number;
|
||||||
|
latestVersion: number;
|
||||||
|
agentName: string;
|
||||||
|
incompatibilities: IncompatibilityInfo;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function IncompatibleUpdateDialog({
|
||||||
|
isOpen,
|
||||||
|
onClose,
|
||||||
|
onConfirm,
|
||||||
|
currentVersion,
|
||||||
|
latestVersion,
|
||||||
|
agentName,
|
||||||
|
incompatibilities,
|
||||||
|
}: IncompatibleUpdateDialogProps) {
|
||||||
|
const hasMissingInputs = incompatibilities.missingInputs.length > 0;
|
||||||
|
const hasMissingOutputs = incompatibilities.missingOutputs.length > 0;
|
||||||
|
const hasNewInputs = incompatibilities.newInputs.length > 0;
|
||||||
|
const hasNewOutputs = incompatibilities.newOutputs.length > 0;
|
||||||
|
const hasNewRequired = incompatibilities.newRequiredInputs.length > 0;
|
||||||
|
const hasTypeMismatches = incompatibilities.inputTypeMismatches.length > 0;
|
||||||
|
|
||||||
|
const hasInputChanges = hasMissingInputs || hasNewInputs;
|
||||||
|
const hasOutputChanges = hasMissingOutputs || hasNewOutputs;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Dialog
|
||||||
|
title={
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<WarningIcon className="h-5 w-5 text-amber-500" weight="fill" />
|
||||||
|
Incompatible Update
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
controlled={{
|
||||||
|
isOpen,
|
||||||
|
set: async (open) => {
|
||||||
|
if (!open) onClose();
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
onClose={onClose}
|
||||||
|
styling={{ maxWidth: "32rem" }}
|
||||||
|
>
|
||||||
|
<Dialog.Content>
|
||||||
|
<div className="space-y-4">
|
||||||
|
<p className="text-sm text-gray-600 dark:text-gray-400">
|
||||||
|
Updating <strong>{beautifyString(agentName)}</strong> from v
|
||||||
|
{currentVersion} to v{latestVersion} will break some connections.
|
||||||
|
</p>
|
||||||
|
|
||||||
|
{/* Input changes - two column layout */}
|
||||||
|
{hasInputChanges && (
|
||||||
|
<TwoColumnSection
|
||||||
|
title="Input Changes"
|
||||||
|
leftIcon={
|
||||||
|
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
|
||||||
|
}
|
||||||
|
leftTitle="Removed"
|
||||||
|
leftItems={incompatibilities.missingInputs}
|
||||||
|
rightIcon={
|
||||||
|
<PlusCircleIcon
|
||||||
|
className="h-4 w-4 text-green-500"
|
||||||
|
weight="fill"
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
rightTitle="Added"
|
||||||
|
rightItems={incompatibilities.newInputs}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Output changes - two column layout */}
|
||||||
|
{hasOutputChanges && (
|
||||||
|
<TwoColumnSection
|
||||||
|
title="Output Changes"
|
||||||
|
leftIcon={
|
||||||
|
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
|
||||||
|
}
|
||||||
|
leftTitle="Removed"
|
||||||
|
leftItems={incompatibilities.missingOutputs}
|
||||||
|
rightIcon={
|
||||||
|
<PlusCircleIcon
|
||||||
|
className="h-4 w-4 text-green-500"
|
||||||
|
weight="fill"
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
rightTitle="Added"
|
||||||
|
rightItems={incompatibilities.newOutputs}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{hasTypeMismatches && (
|
||||||
|
<SingleColumnSection
|
||||||
|
icon={
|
||||||
|
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
|
||||||
|
}
|
||||||
|
title="Type Changed"
|
||||||
|
description="These connected inputs have a different type:"
|
||||||
|
items={incompatibilities.inputTypeMismatches.map(
|
||||||
|
(m) => `${m.name} (${m.oldType} → ${m.newType})`,
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{hasNewRequired && (
|
||||||
|
<SingleColumnSection
|
||||||
|
icon={
|
||||||
|
<PlusCircleIcon
|
||||||
|
className="h-4 w-4 text-amber-500"
|
||||||
|
weight="fill"
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
title="New Required Inputs"
|
||||||
|
description="These inputs are now required:"
|
||||||
|
items={incompatibilities.newRequiredInputs}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<Alert variant="warning">
|
||||||
|
<AlertDescription>
|
||||||
|
If you proceed, you'll need to remove the broken connections
|
||||||
|
before you can save or run your agent.
|
||||||
|
</AlertDescription>
|
||||||
|
</Alert>
|
||||||
|
|
||||||
|
<Dialog.Footer>
|
||||||
|
<Button variant="ghost" size="small" onClick={onClose}>
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
variant="primary"
|
||||||
|
size="small"
|
||||||
|
onClick={onConfirm}
|
||||||
|
className="border-amber-700 bg-amber-600 hover:bg-amber-700"
|
||||||
|
>
|
||||||
|
Update Anyway
|
||||||
|
</Button>
|
||||||
|
</Dialog.Footer>
|
||||||
|
</div>
|
||||||
|
</Dialog.Content>
|
||||||
|
</Dialog>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
type TwoColumnSectionProps = {
|
||||||
|
title: string;
|
||||||
|
leftIcon: React.ReactNode;
|
||||||
|
leftTitle: string;
|
||||||
|
leftItems: string[];
|
||||||
|
rightIcon: React.ReactNode;
|
||||||
|
rightTitle: string;
|
||||||
|
rightItems: string[];
|
||||||
|
};
|
||||||
|
|
||||||
|
function TwoColumnSection({
|
||||||
|
title,
|
||||||
|
leftIcon,
|
||||||
|
leftTitle,
|
||||||
|
leftItems,
|
||||||
|
rightIcon,
|
||||||
|
rightTitle,
|
||||||
|
rightItems,
|
||||||
|
}: TwoColumnSectionProps) {
|
||||||
|
return (
|
||||||
|
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
|
||||||
|
<span className="font-medium">{title}</span>
|
||||||
|
<div className="mt-2 grid grid-cols-2 items-start gap-4">
|
||||||
|
{/* Left column - Breaking changes */}
|
||||||
|
<div className="min-w-0">
|
||||||
|
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
|
||||||
|
{leftIcon}
|
||||||
|
<span>{leftTitle}</span>
|
||||||
|
</div>
|
||||||
|
<ul className="mt-1.5 space-y-1">
|
||||||
|
{leftItems.length > 0 ? (
|
||||||
|
leftItems.map((item) => (
|
||||||
|
<li
|
||||||
|
key={item}
|
||||||
|
className="text-sm text-gray-700 dark:text-gray-300"
|
||||||
|
>
|
||||||
|
<code className="rounded bg-red-50 px-1 py-0.5 font-mono text-xs text-red-700 dark:bg-red-900/30 dark:text-red-300">
|
||||||
|
{item}
|
||||||
|
</code>
|
||||||
|
</li>
|
||||||
|
))
|
||||||
|
) : (
|
||||||
|
<li className="text-sm italic text-gray-400 dark:text-gray-500">
|
||||||
|
None
|
||||||
|
</li>
|
||||||
|
)}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Right column - Possible solutions */}
|
||||||
|
<div className="min-w-0">
|
||||||
|
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
|
||||||
|
{rightIcon}
|
||||||
|
<span>{rightTitle}</span>
|
||||||
|
</div>
|
||||||
|
<ul className="mt-1.5 space-y-1">
|
||||||
|
{rightItems.length > 0 ? (
|
||||||
|
rightItems.map((item) => (
|
||||||
|
<li
|
||||||
|
key={item}
|
||||||
|
className="text-sm text-gray-700 dark:text-gray-300"
|
||||||
|
>
|
||||||
|
<code className="rounded bg-green-50 px-1 py-0.5 font-mono text-xs text-green-700 dark:bg-green-900/30 dark:text-green-300">
|
||||||
|
{item}
|
||||||
|
</code>
|
||||||
|
</li>
|
||||||
|
))
|
||||||
|
) : (
|
||||||
|
<li className="text-sm italic text-gray-400 dark:text-gray-500">
|
||||||
|
None
|
||||||
|
</li>
|
||||||
|
)}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
type SingleColumnSectionProps = {
|
||||||
|
icon: React.ReactNode;
|
||||||
|
title: string;
|
||||||
|
description: string;
|
||||||
|
items: string[];
|
||||||
|
};
|
||||||
|
|
||||||
|
function SingleColumnSection({
|
||||||
|
icon,
|
||||||
|
title,
|
||||||
|
description,
|
||||||
|
items,
|
||||||
|
}: SingleColumnSectionProps) {
|
||||||
|
return (
|
||||||
|
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
{icon}
|
||||||
|
<span className="font-medium">{title}</span>
|
||||||
|
</div>
|
||||||
|
<p className="mt-1 text-sm text-gray-500 dark:text-gray-400">
|
||||||
|
{description}
|
||||||
|
</p>
|
||||||
|
<ul className="mt-2 space-y-1">
|
||||||
|
{items.map((item) => (
|
||||||
|
<li
|
||||||
|
key={item}
|
||||||
|
className="ml-4 list-disc text-sm text-gray-700 dark:text-gray-300"
|
||||||
|
>
|
||||||
|
<code className="rounded bg-gray-100 px-1 py-0.5 font-mono text-xs dark:bg-gray-800">
|
||||||
|
{item}
|
||||||
|
</code>
|
||||||
|
</li>
|
||||||
|
))}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,107 @@
|
|||||||
|
import React from "react";
|
||||||
|
import { InfoIcon, WarningIcon } from "@phosphor-icons/react";
|
||||||
|
import {
|
||||||
|
Tooltip,
|
||||||
|
TooltipContent,
|
||||||
|
TooltipTrigger,
|
||||||
|
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||||
|
import { IncompatibilityInfo } from "@/app/(platform)/build/hooks/useSubAgentUpdate/types";
|
||||||
|
|
||||||
|
type ResolutionModeBarProps = {
|
||||||
|
incompatibilities: IncompatibilityInfo | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function ResolutionModeBar({
|
||||||
|
incompatibilities,
|
||||||
|
}: ResolutionModeBarProps): React.ReactElement {
|
||||||
|
const renderIncompatibilities = () => {
|
||||||
|
if (!incompatibilities) return <span>No incompatibilities</span>;
|
||||||
|
|
||||||
|
const sections: React.ReactNode[] = [];
|
||||||
|
|
||||||
|
if (incompatibilities.missingInputs.length > 0) {
|
||||||
|
sections.push(
|
||||||
|
<div key="missing-inputs" className="mb-1">
|
||||||
|
<span className="font-semibold">Missing inputs: </span>
|
||||||
|
{incompatibilities.missingInputs.map((name, i) => (
|
||||||
|
<React.Fragment key={name}>
|
||||||
|
<code className="font-mono">{name}</code>
|
||||||
|
{i < incompatibilities.missingInputs.length - 1 && ", "}
|
||||||
|
</React.Fragment>
|
||||||
|
))}
|
||||||
|
</div>,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (incompatibilities.missingOutputs.length > 0) {
|
||||||
|
sections.push(
|
||||||
|
<div key="missing-outputs" className="mb-1">
|
||||||
|
<span className="font-semibold">Missing outputs: </span>
|
||||||
|
{incompatibilities.missingOutputs.map((name, i) => (
|
||||||
|
<React.Fragment key={name}>
|
||||||
|
<code className="font-mono">{name}</code>
|
||||||
|
{i < incompatibilities.missingOutputs.length - 1 && ", "}
|
||||||
|
</React.Fragment>
|
||||||
|
))}
|
||||||
|
</div>,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (incompatibilities.newRequiredInputs.length > 0) {
|
||||||
|
sections.push(
|
||||||
|
<div key="new-required" className="mb-1">
|
||||||
|
<span className="font-semibold">New required inputs: </span>
|
||||||
|
{incompatibilities.newRequiredInputs.map((name, i) => (
|
||||||
|
<React.Fragment key={name}>
|
||||||
|
<code className="font-mono">{name}</code>
|
||||||
|
{i < incompatibilities.newRequiredInputs.length - 1 && ", "}
|
||||||
|
</React.Fragment>
|
||||||
|
))}
|
||||||
|
</div>,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (incompatibilities.inputTypeMismatches.length > 0) {
|
||||||
|
sections.push(
|
||||||
|
<div key="type-mismatches" className="mb-1">
|
||||||
|
<span className="font-semibold">Type changed: </span>
|
||||||
|
{incompatibilities.inputTypeMismatches.map((m, i) => (
|
||||||
|
<React.Fragment key={m.name}>
|
||||||
|
<code className="font-mono">{m.name}</code>
|
||||||
|
<span className="text-gray-400">
|
||||||
|
{" "}
|
||||||
|
({m.oldType} → {m.newType})
|
||||||
|
</span>
|
||||||
|
{i < incompatibilities.inputTypeMismatches.length - 1 && ", "}
|
||||||
|
</React.Fragment>
|
||||||
|
))}
|
||||||
|
</div>,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return <>{sections}</>;
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex items-center justify-between gap-2 rounded-t-xl bg-amber-50 px-3 py-2 dark:bg-amber-900/30">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<WarningIcon className="h-4 w-4 text-amber-600 dark:text-amber-400" />
|
||||||
|
<span className="text-sm text-amber-700 dark:text-amber-300">
|
||||||
|
Remove incompatible connections
|
||||||
|
</span>
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger asChild>
|
||||||
|
<InfoIcon className="h-4 w-4 cursor-help text-amber-500" />
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipContent className="max-w-sm">
|
||||||
|
<p className="mb-2 font-semibold">Incompatible changes:</p>
|
||||||
|
<div className="text-xs">{renderIncompatibilities()}</div>
|
||||||
|
<p className="mt-2 text-xs text-gray-400">
|
||||||
|
{(incompatibilities?.newRequiredInputs.length ?? 0) > 0
|
||||||
|
? "Replace / delete"
|
||||||
|
: "Delete"}{" "}
|
||||||
|
the red connections to continue
|
||||||
|
</p>
|
||||||
|
</TooltipContent>
|
||||||
|
</Tooltip>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,194 @@
|
|||||||
|
import { useState, useCallback, useEffect } from "react";
|
||||||
|
import { useShallow } from "zustand/react/shallow";
|
||||||
|
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
|
||||||
|
import {
|
||||||
|
useNodeStore,
|
||||||
|
NodeResolutionData,
|
||||||
|
} from "@/app/(platform)/build/stores/nodeStore";
|
||||||
|
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||||
|
import {
|
||||||
|
useSubAgentUpdate,
|
||||||
|
createUpdatedAgentNodeInputs,
|
||||||
|
getBrokenEdgeIDs,
|
||||||
|
} from "@/app/(platform)/build/hooks/useSubAgentUpdate";
|
||||||
|
import { GraphInputSchema, GraphOutputSchema } from "@/lib/autogpt-server-api";
|
||||||
|
import { CustomNodeData } from "../../CustomNode";
|
||||||
|
|
||||||
|
// Stable empty set to avoid creating new references in selectors
|
||||||
|
const EMPTY_SET: Set<string> = new Set();
|
||||||
|
|
||||||
|
type UseSubAgentUpdateParams = {
|
||||||
|
nodeID: string;
|
||||||
|
nodeData: CustomNodeData;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function useSubAgentUpdateState({
|
||||||
|
nodeID,
|
||||||
|
nodeData,
|
||||||
|
}: UseSubAgentUpdateParams) {
|
||||||
|
const [showIncompatibilityDialog, setShowIncompatibilityDialog] =
|
||||||
|
useState(false);
|
||||||
|
|
||||||
|
// Get store actions
|
||||||
|
const updateNodeData = useNodeStore(
|
||||||
|
useShallow((state) => state.updateNodeData),
|
||||||
|
);
|
||||||
|
const setNodeResolutionMode = useNodeStore(
|
||||||
|
useShallow((state) => state.setNodeResolutionMode),
|
||||||
|
);
|
||||||
|
const isNodeInResolutionMode = useNodeStore(
|
||||||
|
useShallow((state) => state.isNodeInResolutionMode),
|
||||||
|
);
|
||||||
|
const setBrokenEdgeIDs = useNodeStore(
|
||||||
|
useShallow((state) => state.setBrokenEdgeIDs),
|
||||||
|
);
|
||||||
|
// Get this node's broken edge IDs from the per-node map
|
||||||
|
// Use EMPTY_SET as fallback to maintain referential stability
|
||||||
|
const brokenEdgeIDs = useNodeStore(
|
||||||
|
(state) => state.brokenEdgeIDs.get(nodeID) || EMPTY_SET,
|
||||||
|
);
|
||||||
|
const getNodeResolutionData = useNodeStore(
|
||||||
|
useShallow((state) => state.getNodeResolutionData),
|
||||||
|
);
|
||||||
|
const connectedEdges = useEdgeStore(
|
||||||
|
useShallow((state) => state.getNodeEdges(nodeID)),
|
||||||
|
);
|
||||||
|
const availableSubGraphs = useGraphStore(
|
||||||
|
useShallow((state) => state.availableSubGraphs),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Extract agent-specific data
|
||||||
|
const graphID = nodeData.hardcodedValues?.graph_id as string | undefined;
|
||||||
|
const graphVersion = nodeData.hardcodedValues?.graph_version as
|
||||||
|
| number
|
||||||
|
| undefined;
|
||||||
|
const currentInputSchema = nodeData.hardcodedValues?.input_schema as
|
||||||
|
| GraphInputSchema
|
||||||
|
| undefined;
|
||||||
|
const currentOutputSchema = nodeData.hardcodedValues?.output_schema as
|
||||||
|
| GraphOutputSchema
|
||||||
|
| undefined;
|
||||||
|
|
||||||
|
// Use the sub-agent update hook
|
||||||
|
const updateInfo = useSubAgentUpdate(
|
||||||
|
nodeID,
|
||||||
|
graphID,
|
||||||
|
graphVersion,
|
||||||
|
currentInputSchema,
|
||||||
|
currentOutputSchema,
|
||||||
|
connectedEdges,
|
||||||
|
availableSubGraphs,
|
||||||
|
);
|
||||||
|
|
||||||
|
const isInResolutionMode = isNodeInResolutionMode(nodeID);
|
||||||
|
|
||||||
|
// Handle update button click
|
||||||
|
const handleUpdateClick = useCallback(() => {
|
||||||
|
if (!updateInfo.hasUpdate || !updateInfo.latestGraph) return;
|
||||||
|
|
||||||
|
if (updateInfo.isCompatible) {
|
||||||
|
// Compatible update - apply directly
|
||||||
|
const newHardcodedValues = createUpdatedAgentNodeInputs(
|
||||||
|
nodeData.hardcodedValues,
|
||||||
|
updateInfo.latestGraph,
|
||||||
|
);
|
||||||
|
updateNodeData(nodeID, { hardcodedValues: newHardcodedValues });
|
||||||
|
} else {
|
||||||
|
// Incompatible update - show dialog
|
||||||
|
setShowIncompatibilityDialog(true);
|
||||||
|
}
|
||||||
|
}, [
|
||||||
|
updateInfo.hasUpdate,
|
||||||
|
updateInfo.latestGraph,
|
||||||
|
updateInfo.isCompatible,
|
||||||
|
nodeData.hardcodedValues,
|
||||||
|
updateNodeData,
|
||||||
|
nodeID,
|
||||||
|
]);
|
||||||
|
|
||||||
|
// Handle confirming an incompatible update
|
||||||
|
function handleConfirmIncompatibleUpdate() {
|
||||||
|
if (!updateInfo.latestGraph || !updateInfo.incompatibilities) return;
|
||||||
|
|
||||||
|
const latestGraph = updateInfo.latestGraph;
|
||||||
|
|
||||||
|
// Get the new schemas from the latest graph version
|
||||||
|
const newInputSchema =
|
||||||
|
(latestGraph.input_schema as Record<string, unknown>) || {};
|
||||||
|
const newOutputSchema =
|
||||||
|
(latestGraph.output_schema as Record<string, unknown>) || {};
|
||||||
|
|
||||||
|
// Create the updated hardcoded values but DON'T apply them yet
|
||||||
|
// We'll apply them when resolution is complete
|
||||||
|
const pendingHardcodedValues = createUpdatedAgentNodeInputs(
|
||||||
|
nodeData.hardcodedValues,
|
||||||
|
latestGraph,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Get broken edge IDs and store them for this node
|
||||||
|
const brokenIds = getBrokenEdgeIDs(
|
||||||
|
connectedEdges,
|
||||||
|
updateInfo.incompatibilities,
|
||||||
|
nodeID,
|
||||||
|
);
|
||||||
|
setBrokenEdgeIDs(nodeID, brokenIds);
|
||||||
|
|
||||||
|
// Enter resolution mode with both old and new schemas
|
||||||
|
// DON'T apply the update yet - keep old schema so connections remain visible
|
||||||
|
const resolutionData: NodeResolutionData = {
|
||||||
|
incompatibilities: updateInfo.incompatibilities,
|
||||||
|
pendingUpdate: {
|
||||||
|
input_schema: newInputSchema,
|
||||||
|
output_schema: newOutputSchema,
|
||||||
|
},
|
||||||
|
currentSchema: {
|
||||||
|
input_schema: (currentInputSchema as Record<string, unknown>) || {},
|
||||||
|
output_schema: (currentOutputSchema as Record<string, unknown>) || {},
|
||||||
|
},
|
||||||
|
pendingHardcodedValues,
|
||||||
|
};
|
||||||
|
setNodeResolutionMode(nodeID, true, resolutionData);
|
||||||
|
|
||||||
|
setShowIncompatibilityDialog(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if resolution is complete (all broken edges removed)
|
||||||
|
const resolutionData = getNodeResolutionData(nodeID);
|
||||||
|
|
||||||
|
// Auto-check resolution on edge changes
|
||||||
|
useEffect(() => {
|
||||||
|
if (!isInResolutionMode) return;
|
||||||
|
|
||||||
|
// Check if any broken edges still exist
|
||||||
|
const remainingBroken = Array.from(brokenEdgeIDs).filter((edgeId) =>
|
||||||
|
connectedEdges.some((e) => e.id === edgeId),
|
||||||
|
);
|
||||||
|
|
||||||
|
if (remainingBroken.length === 0) {
|
||||||
|
// Resolution complete - now apply the pending update
|
||||||
|
if (resolutionData?.pendingHardcodedValues) {
|
||||||
|
updateNodeData(nodeID, {
|
||||||
|
hardcodedValues: resolutionData.pendingHardcodedValues,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// setNodeResolutionMode will clean up this node's broken edges automatically
|
||||||
|
setNodeResolutionMode(nodeID, false);
|
||||||
|
}
|
||||||
|
}, [
|
||||||
|
isInResolutionMode,
|
||||||
|
brokenEdgeIDs,
|
||||||
|
connectedEdges,
|
||||||
|
resolutionData,
|
||||||
|
nodeID,
|
||||||
|
]);
|
||||||
|
|
||||||
|
return {
|
||||||
|
updateInfo,
|
||||||
|
isInResolutionMode,
|
||||||
|
resolutionData,
|
||||||
|
showIncompatibilityDialog,
|
||||||
|
setShowIncompatibilityDialog,
|
||||||
|
handleUpdateClick,
|
||||||
|
handleConfirmIncompatibleUpdate,
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -1,4 +1,6 @@
|
|||||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||||
|
import { NodeResolutionData } from "@/app/(platform)/build/stores/nodeStore";
|
||||||
|
import { RJSFSchema } from "@rjsf/utils";
|
||||||
|
|
||||||
export const nodeStyleBasedOnStatus: Record<AgentExecutionStatus, string> = {
|
export const nodeStyleBasedOnStatus: Record<AgentExecutionStatus, string> = {
|
||||||
INCOMPLETE: "ring-slate-300 bg-slate-300",
|
INCOMPLETE: "ring-slate-300 bg-slate-300",
|
||||||
@@ -9,3 +11,48 @@ export const nodeStyleBasedOnStatus: Record<AgentExecutionStatus, string> = {
|
|||||||
TERMINATED: "ring-orange-300 bg-orange-300 ",
|
TERMINATED: "ring-orange-300 bg-orange-300 ",
|
||||||
FAILED: "ring-red-300 bg-red-300",
|
FAILED: "ring-red-300 bg-red-300",
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Merges schemas during resolution mode to include removed inputs/outputs
|
||||||
|
* that still have connections, so users can see and delete them.
|
||||||
|
*/
|
||||||
|
export function mergeSchemaForResolution(
|
||||||
|
currentSchema: Record<string, unknown>,
|
||||||
|
newSchema: Record<string, unknown>,
|
||||||
|
resolutionData: NodeResolutionData,
|
||||||
|
type: "input" | "output",
|
||||||
|
): Record<string, unknown> {
|
||||||
|
const newProps = (newSchema.properties as RJSFSchema) || {};
|
||||||
|
const currentProps = (currentSchema.properties as RJSFSchema) || {};
|
||||||
|
const mergedProps = { ...newProps };
|
||||||
|
const incomp = resolutionData.incompatibilities;
|
||||||
|
|
||||||
|
if (type === "input") {
|
||||||
|
// Add back missing inputs that have connections
|
||||||
|
incomp.missingInputs.forEach((inputName: string) => {
|
||||||
|
if (currentProps[inputName]) {
|
||||||
|
mergedProps[inputName] = currentProps[inputName];
|
||||||
|
}
|
||||||
|
});
|
||||||
|
// Add back inputs with type mismatches (keep old type so connection works visually)
|
||||||
|
incomp.inputTypeMismatches.forEach(
|
||||||
|
(mismatch: { name: string; oldType: string; newType: string }) => {
|
||||||
|
if (currentProps[mismatch.name]) {
|
||||||
|
mergedProps[mismatch.name] = currentProps[mismatch.name];
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
// Add back missing outputs that have connections
|
||||||
|
incomp.missingOutputs.forEach((outputName: string) => {
|
||||||
|
if (currentProps[outputName]) {
|
||||||
|
mergedProps[outputName] = currentProps[outputName];
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
...newSchema,
|
||||||
|
properties: mergedProps,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,58 @@
|
|||||||
|
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||||
|
import { CustomNodeData } from "./CustomNode";
|
||||||
|
import { BlockUIType } from "../../../types";
|
||||||
|
import { useMemo } from "react";
|
||||||
|
import { mergeSchemaForResolution } from "./helpers";
|
||||||
|
|
||||||
|
export const useCustomNode = ({
|
||||||
|
data,
|
||||||
|
nodeId,
|
||||||
|
}: {
|
||||||
|
data: CustomNodeData;
|
||||||
|
nodeId: string;
|
||||||
|
}) => {
|
||||||
|
const isInResolutionMode = useNodeStore((state) =>
|
||||||
|
state.nodesInResolutionMode.has(nodeId),
|
||||||
|
);
|
||||||
|
const resolutionData = useNodeStore((state) =>
|
||||||
|
state.nodeResolutionData.get(nodeId),
|
||||||
|
);
|
||||||
|
|
||||||
|
const isAgent = data.uiType === BlockUIType.AGENT;
|
||||||
|
|
||||||
|
const currentInputSchema = isAgent
|
||||||
|
? (data.hardcodedValues.input_schema ?? {})
|
||||||
|
: data.inputSchema;
|
||||||
|
const currentOutputSchema = isAgent
|
||||||
|
? (data.hardcodedValues.output_schema ?? {})
|
||||||
|
: data.outputSchema;
|
||||||
|
|
||||||
|
const inputSchema = useMemo(() => {
|
||||||
|
if (isAgent && isInResolutionMode && resolutionData) {
|
||||||
|
return mergeSchemaForResolution(
|
||||||
|
resolutionData.currentSchema.input_schema,
|
||||||
|
resolutionData.pendingUpdate.input_schema,
|
||||||
|
resolutionData,
|
||||||
|
"input",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return currentInputSchema;
|
||||||
|
}, [isAgent, isInResolutionMode, resolutionData, currentInputSchema]);
|
||||||
|
|
||||||
|
const outputSchema = useMemo(() => {
|
||||||
|
if (isAgent && isInResolutionMode && resolutionData) {
|
||||||
|
return mergeSchemaForResolution(
|
||||||
|
resolutionData.currentSchema.output_schema,
|
||||||
|
resolutionData.pendingUpdate.output_schema,
|
||||||
|
resolutionData,
|
||||||
|
"output",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return currentOutputSchema;
|
||||||
|
}, [isAgent, isInResolutionMode, resolutionData, currentOutputSchema]);
|
||||||
|
|
||||||
|
return {
|
||||||
|
inputSchema,
|
||||||
|
outputSchema,
|
||||||
|
};
|
||||||
|
};
|
||||||
@@ -5,20 +5,16 @@ import { useNodeStore } from "../../../stores/nodeStore";
|
|||||||
import { BlockUIType } from "../../types";
|
import { BlockUIType } from "../../types";
|
||||||
import { FormRenderer } from "@/components/renderers/InputRenderer/FormRenderer";
|
import { FormRenderer } from "@/components/renderers/InputRenderer/FormRenderer";
|
||||||
|
|
||||||
export const FormCreator = React.memo(
|
interface FormCreatorProps {
|
||||||
({
|
jsonSchema: RJSFSchema;
|
||||||
jsonSchema,
|
nodeId: string;
|
||||||
nodeId,
|
uiType: BlockUIType;
|
||||||
uiType,
|
showHandles?: boolean;
|
||||||
showHandles = true,
|
className?: string;
|
||||||
className,
|
}
|
||||||
}: {
|
|
||||||
jsonSchema: RJSFSchema;
|
export const FormCreator: React.FC<FormCreatorProps> = React.memo(
|
||||||
nodeId: string;
|
({ jsonSchema, nodeId, uiType, showHandles = true, className }) => {
|
||||||
uiType: BlockUIType;
|
|
||||||
showHandles?: boolean;
|
|
||||||
className?: string;
|
|
||||||
}) => {
|
|
||||||
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
||||||
|
|
||||||
const getHardCodedValues = useNodeStore(
|
const getHardCodedValues = useNodeStore(
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ import {
|
|||||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||||
import { getTypeDisplayInfo } from "./helpers";
|
import { getTypeDisplayInfo } from "./helpers";
|
||||||
import { BlockUIType } from "../../types";
|
import { BlockUIType } from "../../types";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import { useBrokenOutputs } from "./useBrokenOutputs";
|
||||||
|
|
||||||
export const OutputHandler = ({
|
export const OutputHandler = ({
|
||||||
outputSchema,
|
outputSchema,
|
||||||
@@ -27,6 +29,9 @@ export const OutputHandler = ({
|
|||||||
const { isOutputConnected } = useEdgeStore();
|
const { isOutputConnected } = useEdgeStore();
|
||||||
const properties = outputSchema?.properties || {};
|
const properties = outputSchema?.properties || {};
|
||||||
const [isOutputVisible, setIsOutputVisible] = useState(true);
|
const [isOutputVisible, setIsOutputVisible] = useState(true);
|
||||||
|
const brokenOutputs = useBrokenOutputs(nodeId);
|
||||||
|
|
||||||
|
console.log("brokenOutputs", brokenOutputs);
|
||||||
|
|
||||||
const showHandles = uiType !== BlockUIType.OUTPUT;
|
const showHandles = uiType !== BlockUIType.OUTPUT;
|
||||||
|
|
||||||
@@ -44,6 +49,7 @@ export const OutputHandler = ({
|
|||||||
const shouldShow = isConnected || isOutputVisible;
|
const shouldShow = isConnected || isOutputVisible;
|
||||||
const { displayType, colorClass, hexColor } =
|
const { displayType, colorClass, hexColor } =
|
||||||
getTypeDisplayInfo(fieldSchema);
|
getTypeDisplayInfo(fieldSchema);
|
||||||
|
const isBroken = brokenOutputs.has(fullKey);
|
||||||
|
|
||||||
return shouldShow ? (
|
return shouldShow ? (
|
||||||
<div key={fullKey} className="flex flex-col items-end gap-2">
|
<div key={fullKey} className="flex flex-col items-end gap-2">
|
||||||
@@ -64,15 +70,29 @@ export const OutputHandler = ({
|
|||||||
</Tooltip>
|
</Tooltip>
|
||||||
</TooltipProvider>
|
</TooltipProvider>
|
||||||
)}
|
)}
|
||||||
<Text variant="body" className="text-slate-700">
|
<Text
|
||||||
|
variant="body"
|
||||||
|
className={cn(
|
||||||
|
"text-slate-700",
|
||||||
|
isBroken && "text-red-500 line-through",
|
||||||
|
)}
|
||||||
|
>
|
||||||
{fieldTitle}
|
{fieldTitle}
|
||||||
</Text>
|
</Text>
|
||||||
<Text variant="small" as="span" className={colorClass}>
|
<Text
|
||||||
|
variant="small"
|
||||||
|
as="span"
|
||||||
|
className={cn(
|
||||||
|
colorClass,
|
||||||
|
isBroken && "!text-red-500 line-through",
|
||||||
|
)}
|
||||||
|
>
|
||||||
({displayType})
|
({displayType})
|
||||||
</Text>
|
</Text>
|
||||||
|
|
||||||
{showHandles && (
|
{showHandles && (
|
||||||
<OutputNodeHandle
|
<OutputNodeHandle
|
||||||
|
isBroken={isBroken}
|
||||||
field_name={fullKey}
|
field_name={fullKey}
|
||||||
nodeId={nodeId}
|
nodeId={nodeId}
|
||||||
hexColor={hexColor}
|
hexColor={hexColor}
|
||||||
|
|||||||
@@ -0,0 +1,23 @@
|
|||||||
|
import { useMemo } from "react";
|
||||||
|
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Hook to get the set of broken output names for a node in resolution mode.
|
||||||
|
*/
|
||||||
|
export function useBrokenOutputs(nodeID: string): Set<string> {
|
||||||
|
// Subscribe to the actual state values, not just methods
|
||||||
|
const isInResolution = useNodeStore((state) =>
|
||||||
|
state.nodesInResolutionMode.has(nodeID),
|
||||||
|
);
|
||||||
|
const resolutionData = useNodeStore((state) =>
|
||||||
|
state.nodeResolutionData.get(nodeID),
|
||||||
|
);
|
||||||
|
|
||||||
|
return useMemo(() => {
|
||||||
|
if (!isInResolution || !resolutionData) {
|
||||||
|
return new Set<string>();
|
||||||
|
}
|
||||||
|
|
||||||
|
return new Set(resolutionData.incompatibilities.missingOutputs);
|
||||||
|
}, [isInResolution, resolutionData]);
|
||||||
|
}
|
||||||
@@ -25,7 +25,7 @@ export const RightSidebar = () => {
|
|||||||
>
|
>
|
||||||
<div className="mb-4">
|
<div className="mb-4">
|
||||||
<h2 className="text-lg font-semibold text-slate-800 dark:text-slate-200">
|
<h2 className="text-lg font-semibold text-slate-800 dark:text-slate-200">
|
||||||
Flow Debug Panel
|
Graph Debug Panel
|
||||||
</h2>
|
</h2>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -65,7 +65,7 @@ export const RightSidebar = () => {
|
|||||||
{l.source_id}[{l.source_name}] → {l.sink_id}[{l.sink_name}]
|
{l.source_id}[{l.source_name}] → {l.sink_id}[{l.sink_name}]
|
||||||
</div>
|
</div>
|
||||||
<div className="mt-1 text-slate-500 dark:text-slate-400">
|
<div className="mt-1 text-slate-500 dark:text-slate-400">
|
||||||
edge_id: {l.id}
|
edge.id: {l.id}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
))}
|
))}
|
||||||
|
|||||||
@@ -12,7 +12,14 @@ import {
|
|||||||
PopoverContent,
|
PopoverContent,
|
||||||
PopoverTrigger,
|
PopoverTrigger,
|
||||||
} from "@/components/__legacy__/ui/popover";
|
} from "@/components/__legacy__/ui/popover";
|
||||||
import { Block, BlockUIType, SpecialBlockID } from "@/lib/autogpt-server-api";
|
import {
|
||||||
|
Block,
|
||||||
|
BlockIORootSchema,
|
||||||
|
BlockUIType,
|
||||||
|
GraphInputSchema,
|
||||||
|
GraphOutputSchema,
|
||||||
|
SpecialBlockID,
|
||||||
|
} from "@/lib/autogpt-server-api";
|
||||||
import { MagnifyingGlassIcon, PlusIcon } from "@radix-ui/react-icons";
|
import { MagnifyingGlassIcon, PlusIcon } from "@radix-ui/react-icons";
|
||||||
import { IconToyBrick } from "@/components/__legacy__/ui/icons";
|
import { IconToyBrick } from "@/components/__legacy__/ui/icons";
|
||||||
import { getPrimaryCategoryColor } from "@/lib/utils";
|
import { getPrimaryCategoryColor } from "@/lib/utils";
|
||||||
@@ -24,8 +31,10 @@ import {
|
|||||||
import { GraphMeta } from "@/lib/autogpt-server-api";
|
import { GraphMeta } from "@/lib/autogpt-server-api";
|
||||||
import jaro from "jaro-winkler";
|
import jaro from "jaro-winkler";
|
||||||
|
|
||||||
type _Block = Block & {
|
type _Block = Omit<Block, "inputSchema" | "outputSchema"> & {
|
||||||
uiKey?: string;
|
uiKey?: string;
|
||||||
|
inputSchema: BlockIORootSchema | GraphInputSchema;
|
||||||
|
outputSchema: BlockIORootSchema | GraphOutputSchema;
|
||||||
hardcodedValues?: Record<string, any>;
|
hardcodedValues?: Record<string, any>;
|
||||||
_cached?: {
|
_cached?: {
|
||||||
blockName: string;
|
blockName: string;
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import React from "react";
|
|||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { Button } from "@/components/__legacy__/ui/button";
|
import { Button } from "@/components/__legacy__/ui/button";
|
||||||
import { LogOut } from "lucide-react";
|
import { LogOut } from "lucide-react";
|
||||||
import { ClockIcon } from "@phosphor-icons/react";
|
import { ClockIcon, WarningIcon } from "@phosphor-icons/react";
|
||||||
import { IconPlay, IconSquare } from "@/components/__legacy__/ui/icons";
|
import { IconPlay, IconSquare } from "@/components/__legacy__/ui/icons";
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
@@ -13,6 +13,7 @@ interface Props {
|
|||||||
isRunning: boolean;
|
isRunning: boolean;
|
||||||
isDisabled: boolean;
|
isDisabled: boolean;
|
||||||
className?: string;
|
className?: string;
|
||||||
|
resolutionModeActive?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const BuildActionBar: React.FC<Props> = ({
|
export const BuildActionBar: React.FC<Props> = ({
|
||||||
@@ -23,9 +24,30 @@ export const BuildActionBar: React.FC<Props> = ({
|
|||||||
isRunning,
|
isRunning,
|
||||||
isDisabled,
|
isDisabled,
|
||||||
className,
|
className,
|
||||||
|
resolutionModeActive = false,
|
||||||
}) => {
|
}) => {
|
||||||
const buttonClasses =
|
const buttonClasses =
|
||||||
"flex items-center gap-2 text-sm font-medium md:text-lg";
|
"flex items-center gap-2 text-sm font-medium md:text-lg";
|
||||||
|
|
||||||
|
// Show resolution mode message instead of action buttons
|
||||||
|
if (resolutionModeActive) {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"flex w-fit select-none items-center justify-center p-4",
|
||||||
|
className,
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<div className="flex items-center gap-3 rounded-lg border border-amber-300 bg-amber-50 px-4 py-3 dark:border-amber-700 dark:bg-amber-900/30">
|
||||||
|
<WarningIcon className="size-5 text-amber-600 dark:text-amber-400" />
|
||||||
|
<span className="text-sm font-medium text-amber-800 dark:text-amber-200">
|
||||||
|
Remove incompatible connections to continue
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
|
|||||||
@@ -60,10 +60,16 @@ export function CustomEdge({
|
|||||||
targetY - 5,
|
targetY - 5,
|
||||||
);
|
);
|
||||||
const { deleteElements } = useReactFlow<Node, CustomEdge>();
|
const { deleteElements } = useReactFlow<Node, CustomEdge>();
|
||||||
const { visualizeBeads } = useContext(BuilderContext) ?? {
|
const builderContext = useContext(BuilderContext);
|
||||||
|
const { visualizeBeads } = builderContext ?? {
|
||||||
visualizeBeads: "no",
|
visualizeBeads: "no",
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Check if this edge is broken (during resolution mode)
|
||||||
|
const isBroken =
|
||||||
|
builderContext?.resolutionMode?.active &&
|
||||||
|
builderContext?.resolutionMode?.brokenEdgeIds?.includes(id);
|
||||||
|
|
||||||
const onEdgeRemoveClick = () => {
|
const onEdgeRemoveClick = () => {
|
||||||
deleteElements({ edges: [{ id }] });
|
deleteElements({ edges: [{ id }] });
|
||||||
};
|
};
|
||||||
@@ -171,12 +177,27 @@ export function CustomEdge({
|
|||||||
|
|
||||||
const middle = getPointForT(0.5);
|
const middle = getPointForT(0.5);
|
||||||
|
|
||||||
|
// Determine edge color - red for broken edges
|
||||||
|
const baseColor = data?.edgeColor ?? "#555555";
|
||||||
|
const edgeColor = isBroken ? "#ef4444" : baseColor;
|
||||||
|
// Add opacity to hex color (99 = 60% opacity, 80 = 50% opacity)
|
||||||
|
const strokeColor = isBroken
|
||||||
|
? `${edgeColor}99`
|
||||||
|
: selected
|
||||||
|
? edgeColor
|
||||||
|
: `${edgeColor}80`;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<BaseEdge
|
<BaseEdge
|
||||||
path={svgPath}
|
path={svgPath}
|
||||||
markerEnd={markerEnd}
|
markerEnd={markerEnd}
|
||||||
className={`data-sentry-unmask transition-all duration-200 ${data?.isStatic ? "[stroke-dasharray:5_3]" : "[stroke-dasharray:0]"} [stroke-width:${data?.isStatic ? 2.5 : 2}px] hover:[stroke-width:${data?.isStatic ? 3.5 : 3}px] ${selected ? `[stroke:${data?.edgeColor ?? "#555555"}]` : `[stroke:${data?.edgeColor ?? "#555555"}80] hover:[stroke:${data?.edgeColor ?? "#555555"}]`}`}
|
style={{
|
||||||
|
stroke: strokeColor,
|
||||||
|
strokeWidth: data?.isStatic ? 2.5 : 2,
|
||||||
|
strokeDasharray: data?.isStatic ? "5 3" : undefined,
|
||||||
|
}}
|
||||||
|
className="data-sentry-unmask transition-all duration-200"
|
||||||
/>
|
/>
|
||||||
<path
|
<path
|
||||||
d={svgPath}
|
d={svgPath}
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ import {
|
|||||||
BlockIOSubSchema,
|
BlockIOSubSchema,
|
||||||
BlockUIType,
|
BlockUIType,
|
||||||
Category,
|
Category,
|
||||||
|
GraphInputSchema,
|
||||||
|
GraphOutputSchema,
|
||||||
NodeExecutionResult,
|
NodeExecutionResult,
|
||||||
} from "@/lib/autogpt-server-api";
|
} from "@/lib/autogpt-server-api";
|
||||||
import {
|
import {
|
||||||
@@ -62,14 +64,21 @@ import { NodeGenericInputField, NodeTextBoxInput } from "../NodeInputs";
|
|||||||
import NodeOutputs from "../NodeOutputs";
|
import NodeOutputs from "../NodeOutputs";
|
||||||
import OutputModalComponent from "../OutputModalComponent";
|
import OutputModalComponent from "../OutputModalComponent";
|
||||||
import "./customnode.css";
|
import "./customnode.css";
|
||||||
|
import { SubAgentUpdateBar } from "./SubAgentUpdateBar";
|
||||||
|
import { IncompatibilityDialog } from "./IncompatibilityDialog";
|
||||||
|
import {
|
||||||
|
useSubAgentUpdate,
|
||||||
|
createUpdatedAgentNodeInputs,
|
||||||
|
getBrokenEdgeIDs,
|
||||||
|
} from "../../../hooks/useSubAgentUpdate";
|
||||||
|
|
||||||
export type ConnectionData = Array<{
|
export type ConnectedEdge = {
|
||||||
edge_id: string;
|
id: string;
|
||||||
source: string;
|
source: string;
|
||||||
sourceHandle: string;
|
sourceHandle: string;
|
||||||
target: string;
|
target: string;
|
||||||
targetHandle: string;
|
targetHandle: string;
|
||||||
}>;
|
};
|
||||||
|
|
||||||
export type CustomNodeData = {
|
export type CustomNodeData = {
|
||||||
blockType: string;
|
blockType: string;
|
||||||
@@ -80,7 +89,7 @@ export type CustomNodeData = {
|
|||||||
inputSchema: BlockIORootSchema;
|
inputSchema: BlockIORootSchema;
|
||||||
outputSchema: BlockIORootSchema;
|
outputSchema: BlockIORootSchema;
|
||||||
hardcodedValues: { [key: string]: any };
|
hardcodedValues: { [key: string]: any };
|
||||||
connections: ConnectionData;
|
connections: ConnectedEdge[];
|
||||||
isOutputOpen: boolean;
|
isOutputOpen: boolean;
|
||||||
status?: NodeExecutionResult["status"];
|
status?: NodeExecutionResult["status"];
|
||||||
/** executionResults contains outputs across multiple executions
|
/** executionResults contains outputs across multiple executions
|
||||||
@@ -127,20 +136,199 @@ export const CustomNode = React.memo(
|
|||||||
|
|
||||||
let subGraphID = "";
|
let subGraphID = "";
|
||||||
|
|
||||||
if (data.uiType === BlockUIType.AGENT) {
|
|
||||||
// Display the graph's schema instead AgentExecutorBlock's schema.
|
|
||||||
data.inputSchema = data.hardcodedValues?.input_schema || {};
|
|
||||||
data.outputSchema = data.hardcodedValues?.output_schema || {};
|
|
||||||
subGraphID = data.hardcodedValues?.graph_id || subGraphID;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!builderContext) {
|
if (!builderContext) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
"BuilderContext consumer must be inside FlowEditor component",
|
"BuilderContext consumer must be inside FlowEditor component",
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const { libraryAgent, setIsAnyModalOpen, getNextNodeId } = builderContext;
|
const {
|
||||||
|
libraryAgent,
|
||||||
|
setIsAnyModalOpen,
|
||||||
|
getNextNodeId,
|
||||||
|
availableFlows,
|
||||||
|
resolutionMode,
|
||||||
|
enterResolutionMode,
|
||||||
|
} = builderContext;
|
||||||
|
|
||||||
|
// Check if this node is in resolution mode (moved up for schema merge logic)
|
||||||
|
const isInResolutionMode =
|
||||||
|
resolutionMode.active && resolutionMode.nodeId === id;
|
||||||
|
|
||||||
|
if (data.uiType === BlockUIType.AGENT) {
|
||||||
|
// Display the graph's schema instead AgentExecutorBlock's schema.
|
||||||
|
const currentInputSchema = data.hardcodedValues?.input_schema || {};
|
||||||
|
const currentOutputSchema = data.hardcodedValues?.output_schema || {};
|
||||||
|
subGraphID = data.hardcodedValues?.graph_id || subGraphID;
|
||||||
|
|
||||||
|
// During resolution mode, merge old connected inputs/outputs with new schema
|
||||||
|
if (isInResolutionMode && resolutionMode.pendingUpdate) {
|
||||||
|
const newInputSchema =
|
||||||
|
(resolutionMode.pendingUpdate.input_schema as BlockIORootSchema) ||
|
||||||
|
{};
|
||||||
|
const newOutputSchema =
|
||||||
|
(resolutionMode.pendingUpdate.output_schema as BlockIORootSchema) ||
|
||||||
|
{};
|
||||||
|
|
||||||
|
// Merge input schemas: start with new schema, add old connected inputs that are missing
|
||||||
|
const mergedInputProps = { ...newInputSchema.properties };
|
||||||
|
const incomp = resolutionMode.incompatibilities;
|
||||||
|
if (incomp && currentInputSchema.properties) {
|
||||||
|
// Add back missing inputs that have connections (so user can see/delete them)
|
||||||
|
incomp.missingInputs.forEach((inputName) => {
|
||||||
|
if (currentInputSchema.properties[inputName]) {
|
||||||
|
mergedInputProps[inputName] =
|
||||||
|
currentInputSchema.properties[inputName];
|
||||||
|
}
|
||||||
|
});
|
||||||
|
// Add back inputs with type mismatches (keep old type so connection still works visually)
|
||||||
|
incomp.inputTypeMismatches.forEach((mismatch) => {
|
||||||
|
if (currentInputSchema.properties[mismatch.name]) {
|
||||||
|
mergedInputProps[mismatch.name] =
|
||||||
|
currentInputSchema.properties[mismatch.name];
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge output schemas: start with new schema, add old connected outputs that are missing
|
||||||
|
const mergedOutputProps = { ...newOutputSchema.properties };
|
||||||
|
if (incomp && currentOutputSchema.properties) {
|
||||||
|
incomp.missingOutputs.forEach((outputName) => {
|
||||||
|
if (currentOutputSchema.properties[outputName]) {
|
||||||
|
mergedOutputProps[outputName] =
|
||||||
|
currentOutputSchema.properties[outputName];
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
data.inputSchema = {
|
||||||
|
...newInputSchema,
|
||||||
|
properties: mergedInputProps,
|
||||||
|
};
|
||||||
|
data.outputSchema = {
|
||||||
|
...newOutputSchema,
|
||||||
|
properties: mergedOutputProps,
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
data.inputSchema = currentInputSchema;
|
||||||
|
data.outputSchema = currentOutputSchema;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const setHardcodedValues = useCallback(
|
||||||
|
(values: any) => {
|
||||||
|
updateNodeData(id, { hardcodedValues: values });
|
||||||
|
},
|
||||||
|
[id, updateNodeData],
|
||||||
|
);
|
||||||
|
|
||||||
|
// Sub-agent update detection
|
||||||
|
const isAgentBlock = data.uiType === BlockUIType.AGENT;
|
||||||
|
const graphId = isAgentBlock ? data.hardcodedValues?.graph_id : undefined;
|
||||||
|
const graphVersion = isAgentBlock
|
||||||
|
? data.hardcodedValues?.graph_version
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
const subAgentUpdate = useSubAgentUpdate(
|
||||||
|
id,
|
||||||
|
graphId,
|
||||||
|
graphVersion,
|
||||||
|
isAgentBlock
|
||||||
|
? (data.hardcodedValues?.input_schema as GraphInputSchema)
|
||||||
|
: undefined,
|
||||||
|
isAgentBlock
|
||||||
|
? (data.hardcodedValues?.output_schema as GraphOutputSchema)
|
||||||
|
: undefined,
|
||||||
|
data.connections,
|
||||||
|
availableFlows,
|
||||||
|
);
|
||||||
|
|
||||||
|
const [showIncompatibilityDialog, setShowIncompatibilityDialog] =
|
||||||
|
useState(false);
|
||||||
|
|
||||||
|
// Helper to check if a handle is broken (for resolution mode)
|
||||||
|
const isInputHandleBroken = useCallback(
|
||||||
|
(handleName: string): boolean => {
|
||||||
|
if (!isInResolutionMode || !resolutionMode.incompatibilities) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const incomp = resolutionMode.incompatibilities;
|
||||||
|
return (
|
||||||
|
incomp.missingInputs.includes(handleName) ||
|
||||||
|
incomp.inputTypeMismatches.some((m) => m.name === handleName)
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[isInResolutionMode, resolutionMode.incompatibilities],
|
||||||
|
);
|
||||||
|
|
||||||
|
const isOutputHandleBroken = useCallback(
|
||||||
|
(handleName: string): boolean => {
|
||||||
|
if (!isInResolutionMode || !resolutionMode.incompatibilities) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return resolutionMode.incompatibilities.missingOutputs.includes(
|
||||||
|
handleName,
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[isInResolutionMode, resolutionMode.incompatibilities],
|
||||||
|
);
|
||||||
|
|
||||||
|
// Handle update button click
|
||||||
|
const handleUpdateClick = useCallback(() => {
|
||||||
|
if (!subAgentUpdate.latestGraph) return;
|
||||||
|
|
||||||
|
if (subAgentUpdate.isCompatible) {
|
||||||
|
// Compatible update - directly apply
|
||||||
|
const updatedValues = createUpdatedAgentNodeInputs(
|
||||||
|
data.hardcodedValues,
|
||||||
|
subAgentUpdate.latestGraph,
|
||||||
|
);
|
||||||
|
setHardcodedValues(updatedValues);
|
||||||
|
toast({
|
||||||
|
title: "Agent updated",
|
||||||
|
description: `Updated to version ${subAgentUpdate.latestVersion}`,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// Incompatible update - show dialog
|
||||||
|
setShowIncompatibilityDialog(true);
|
||||||
|
}
|
||||||
|
}, [subAgentUpdate, data.hardcodedValues, setHardcodedValues]);
|
||||||
|
|
||||||
|
// Handle confirm incompatible update
|
||||||
|
const handleConfirmIncompatibleUpdate = useCallback(() => {
|
||||||
|
if (!subAgentUpdate.latestGraph || !subAgentUpdate.incompatibilities) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the updated values but DON'T apply them yet
|
||||||
|
const updatedValues = createUpdatedAgentNodeInputs(
|
||||||
|
data.hardcodedValues,
|
||||||
|
subAgentUpdate.latestGraph,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Get broken edge IDs
|
||||||
|
const brokenEdgeIds = getBrokenEdgeIDs(
|
||||||
|
data.connections,
|
||||||
|
subAgentUpdate.incompatibilities,
|
||||||
|
id,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Enter resolution mode with pending update (don't apply schema yet)
|
||||||
|
enterResolutionMode(
|
||||||
|
id,
|
||||||
|
subAgentUpdate.incompatibilities,
|
||||||
|
brokenEdgeIds,
|
||||||
|
updatedValues,
|
||||||
|
);
|
||||||
|
|
||||||
|
setShowIncompatibilityDialog(false);
|
||||||
|
}, [
|
||||||
|
subAgentUpdate,
|
||||||
|
data.hardcodedValues,
|
||||||
|
data.connections,
|
||||||
|
id,
|
||||||
|
enterResolutionMode,
|
||||||
|
]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (data.executionResults || data.status) {
|
if (data.executionResults || data.status) {
|
||||||
@@ -156,13 +344,6 @@ export const CustomNode = React.memo(
|
|||||||
setIsAnyModalOpen?.(isModalOpen || isOutputModalOpen);
|
setIsAnyModalOpen?.(isModalOpen || isOutputModalOpen);
|
||||||
}, [isModalOpen, isOutputModalOpen, data, setIsAnyModalOpen]);
|
}, [isModalOpen, isOutputModalOpen, data, setIsAnyModalOpen]);
|
||||||
|
|
||||||
const setHardcodedValues = useCallback(
|
|
||||||
(values: any) => {
|
|
||||||
updateNodeData(id, { hardcodedValues: values });
|
|
||||||
},
|
|
||||||
[id, updateNodeData],
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleTitleEdit = useCallback(() => {
|
const handleTitleEdit = useCallback(() => {
|
||||||
setIsEditingTitle(true);
|
setIsEditingTitle(true);
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
@@ -255,6 +436,7 @@ export const CustomNode = React.memo(
|
|||||||
isConnected={isOutputHandleConnected(propKey)}
|
isConnected={isOutputHandleConnected(propKey)}
|
||||||
schema={fieldSchema}
|
schema={fieldSchema}
|
||||||
side="right"
|
side="right"
|
||||||
|
isBroken={isOutputHandleBroken(propKey)}
|
||||||
/>
|
/>
|
||||||
{"properties" in fieldSchema &&
|
{"properties" in fieldSchema &&
|
||||||
renderHandles(
|
renderHandles(
|
||||||
@@ -385,6 +567,7 @@ export const CustomNode = React.memo(
|
|||||||
isRequired={isRequired}
|
isRequired={isRequired}
|
||||||
schema={propSchema}
|
schema={propSchema}
|
||||||
side="left"
|
side="left"
|
||||||
|
isBroken={isInputHandleBroken(propKey)}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
propKey !== "credentials" &&
|
propKey !== "credentials" &&
|
||||||
@@ -873,6 +1056,22 @@ export const CustomNode = React.memo(
|
|||||||
<ContextMenuContent />
|
<ContextMenuContent />
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{/* Sub-agent Update Bar - shown below header */}
|
||||||
|
{isAgentBlock && (subAgentUpdate.hasUpdate || isInResolutionMode) && (
|
||||||
|
<SubAgentUpdateBar
|
||||||
|
currentVersion={subAgentUpdate.currentVersion}
|
||||||
|
latestVersion={subAgentUpdate.latestVersion}
|
||||||
|
isCompatible={subAgentUpdate.isCompatible}
|
||||||
|
incompatibilities={
|
||||||
|
isInResolutionMode
|
||||||
|
? resolutionMode.incompatibilities
|
||||||
|
: subAgentUpdate.incompatibilities
|
||||||
|
}
|
||||||
|
onUpdate={handleUpdateClick}
|
||||||
|
isInResolutionMode={isInResolutionMode}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
{/* Body */}
|
{/* Body */}
|
||||||
<div className="mx-5 my-6 rounded-b-xl">
|
<div className="mx-5 my-6 rounded-b-xl">
|
||||||
{/* Input Handles */}
|
{/* Input Handles */}
|
||||||
@@ -1044,9 +1243,24 @@ export const CustomNode = React.memo(
|
|||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ContextMenu.Root>
|
<>
|
||||||
<ContextMenu.Trigger>{nodeContent()}</ContextMenu.Trigger>
|
<ContextMenu.Root>
|
||||||
</ContextMenu.Root>
|
<ContextMenu.Trigger>{nodeContent()}</ContextMenu.Trigger>
|
||||||
|
</ContextMenu.Root>
|
||||||
|
|
||||||
|
{/* Incompatibility Dialog for sub-agent updates */}
|
||||||
|
{isAgentBlock && subAgentUpdate.incompatibilities && (
|
||||||
|
<IncompatibilityDialog
|
||||||
|
isOpen={showIncompatibilityDialog}
|
||||||
|
onClose={() => setShowIncompatibilityDialog(false)}
|
||||||
|
onConfirm={handleConfirmIncompatibleUpdate}
|
||||||
|
currentVersion={subAgentUpdate.currentVersion}
|
||||||
|
latestVersion={subAgentUpdate.latestVersion}
|
||||||
|
agentName={data.blockType || "Agent"}
|
||||||
|
incompatibilities={subAgentUpdate.incompatibilities}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
(prevProps, nextProps) => {
|
(prevProps, nextProps) => {
|
||||||
|
|||||||
@@ -0,0 +1,244 @@
|
|||||||
|
import React from "react";
|
||||||
|
import {
|
||||||
|
Dialog,
|
||||||
|
DialogContent,
|
||||||
|
DialogDescription,
|
||||||
|
DialogFooter,
|
||||||
|
DialogHeader,
|
||||||
|
DialogTitle,
|
||||||
|
} from "@/components/__legacy__/ui/dialog";
|
||||||
|
import { Button } from "@/components/__legacy__/ui/button";
|
||||||
|
import { AlertTriangle, XCircle, PlusCircle } from "lucide-react";
|
||||||
|
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
|
||||||
|
import { beautifyString } from "@/lib/utils";
|
||||||
|
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
|
||||||
|
|
||||||
|
interface IncompatibilityDialogProps {
|
||||||
|
isOpen: boolean;
|
||||||
|
onClose: () => void;
|
||||||
|
onConfirm: () => void;
|
||||||
|
currentVersion: number;
|
||||||
|
latestVersion: number;
|
||||||
|
agentName: string;
|
||||||
|
incompatibilities: IncompatibilityInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const IncompatibilityDialog: React.FC<IncompatibilityDialogProps> = ({
|
||||||
|
isOpen,
|
||||||
|
onClose,
|
||||||
|
onConfirm,
|
||||||
|
currentVersion,
|
||||||
|
latestVersion,
|
||||||
|
agentName,
|
||||||
|
incompatibilities,
|
||||||
|
}) => {
|
||||||
|
const hasMissingInputs = incompatibilities.missingInputs.length > 0;
|
||||||
|
const hasMissingOutputs = incompatibilities.missingOutputs.length > 0;
|
||||||
|
const hasNewInputs = incompatibilities.newInputs.length > 0;
|
||||||
|
const hasNewOutputs = incompatibilities.newOutputs.length > 0;
|
||||||
|
const hasNewRequired = incompatibilities.newRequiredInputs.length > 0;
|
||||||
|
const hasTypeMismatches = incompatibilities.inputTypeMismatches.length > 0;
|
||||||
|
|
||||||
|
const hasInputChanges = hasMissingInputs || hasNewInputs;
|
||||||
|
const hasOutputChanges = hasMissingOutputs || hasNewOutputs;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Dialog open={isOpen} onOpenChange={(open) => !open && onClose()}>
|
||||||
|
<DialogContent className="max-w-lg">
|
||||||
|
<DialogHeader>
|
||||||
|
<DialogTitle className="flex items-center gap-2">
|
||||||
|
<AlertTriangle className="h-5 w-5 text-amber-500" />
|
||||||
|
Incompatible Update
|
||||||
|
</DialogTitle>
|
||||||
|
<DialogDescription>
|
||||||
|
Updating <strong>{beautifyString(agentName)}</strong> from v
|
||||||
|
{currentVersion} to v{latestVersion} will break some connections.
|
||||||
|
</DialogDescription>
|
||||||
|
</DialogHeader>
|
||||||
|
|
||||||
|
<div className="space-y-4 py-2">
|
||||||
|
{/* Input changes - two column layout */}
|
||||||
|
{hasInputChanges && (
|
||||||
|
<TwoColumnSection
|
||||||
|
title="Input Changes"
|
||||||
|
leftIcon={<XCircle className="h-4 w-4 text-red-500" />}
|
||||||
|
leftTitle="Removed"
|
||||||
|
leftItems={incompatibilities.missingInputs}
|
||||||
|
rightIcon={<PlusCircle className="h-4 w-4 text-green-500" />}
|
||||||
|
rightTitle="Added"
|
||||||
|
rightItems={incompatibilities.newInputs}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Output changes - two column layout */}
|
||||||
|
{hasOutputChanges && (
|
||||||
|
<TwoColumnSection
|
||||||
|
title="Output Changes"
|
||||||
|
leftIcon={<XCircle className="h-4 w-4 text-red-500" />}
|
||||||
|
leftTitle="Removed"
|
||||||
|
leftItems={incompatibilities.missingOutputs}
|
||||||
|
rightIcon={<PlusCircle className="h-4 w-4 text-green-500" />}
|
||||||
|
rightTitle="Added"
|
||||||
|
rightItems={incompatibilities.newOutputs}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{hasTypeMismatches && (
|
||||||
|
<SingleColumnSection
|
||||||
|
icon={<XCircle className="h-4 w-4 text-red-500" />}
|
||||||
|
title="Type Changed"
|
||||||
|
description="These connected inputs have a different type:"
|
||||||
|
items={incompatibilities.inputTypeMismatches.map(
|
||||||
|
(m) => `${m.name} (${m.oldType} → ${m.newType})`,
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{hasNewRequired && (
|
||||||
|
<SingleColumnSection
|
||||||
|
icon={<PlusCircle className="h-4 w-4 text-amber-500" />}
|
||||||
|
title="New Required Inputs"
|
||||||
|
description="These inputs are now required:"
|
||||||
|
items={incompatibilities.newRequiredInputs}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Alert variant="warning">
|
||||||
|
<AlertDescription>
|
||||||
|
If you proceed, you'll need to remove the broken connections
|
||||||
|
before you can save or run your agent.
|
||||||
|
</AlertDescription>
|
||||||
|
</Alert>
|
||||||
|
|
||||||
|
<DialogFooter className="gap-2 sm:gap-0">
|
||||||
|
<Button variant="outline" onClick={onClose}>
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
variant="destructive"
|
||||||
|
onClick={onConfirm}
|
||||||
|
className="bg-amber-600 hover:bg-amber-700"
|
||||||
|
>
|
||||||
|
Update Anyway
|
||||||
|
</Button>
|
||||||
|
</DialogFooter>
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
interface TwoColumnSectionProps {
|
||||||
|
title: string;
|
||||||
|
leftIcon: React.ReactNode;
|
||||||
|
leftTitle: string;
|
||||||
|
leftItems: string[];
|
||||||
|
rightIcon: React.ReactNode;
|
||||||
|
rightTitle: string;
|
||||||
|
rightItems: string[];
|
||||||
|
}
|
||||||
|
|
||||||
|
const TwoColumnSection: React.FC<TwoColumnSectionProps> = ({
|
||||||
|
title,
|
||||||
|
leftIcon,
|
||||||
|
leftTitle,
|
||||||
|
leftItems,
|
||||||
|
rightIcon,
|
||||||
|
rightTitle,
|
||||||
|
rightItems,
|
||||||
|
}) => (
|
||||||
|
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
|
||||||
|
<span className="font-medium">{title}</span>
|
||||||
|
<div className="mt-2 grid grid-cols-2 items-start gap-4">
|
||||||
|
{/* Left column - Breaking changes */}
|
||||||
|
<div className="min-w-0">
|
||||||
|
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
|
||||||
|
{leftIcon}
|
||||||
|
<span>{leftTitle}</span>
|
||||||
|
</div>
|
||||||
|
<ul className="mt-1.5 space-y-1">
|
||||||
|
{leftItems.length > 0 ? (
|
||||||
|
leftItems.map((item) => (
|
||||||
|
<li
|
||||||
|
key={item}
|
||||||
|
className="text-sm text-gray-700 dark:text-gray-300"
|
||||||
|
>
|
||||||
|
<code className="rounded bg-red-50 px-1 py-0.5 font-mono text-xs text-red-700 dark:bg-red-900/30 dark:text-red-300">
|
||||||
|
{item}
|
||||||
|
</code>
|
||||||
|
</li>
|
||||||
|
))
|
||||||
|
) : (
|
||||||
|
<li className="text-sm italic text-gray-400 dark:text-gray-500">
|
||||||
|
None
|
||||||
|
</li>
|
||||||
|
)}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Right column - Possible solutions */}
|
||||||
|
<div className="min-w-0">
|
||||||
|
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
|
||||||
|
{rightIcon}
|
||||||
|
<span>{rightTitle}</span>
|
||||||
|
</div>
|
||||||
|
<ul className="mt-1.5 space-y-1">
|
||||||
|
{rightItems.length > 0 ? (
|
||||||
|
rightItems.map((item) => (
|
||||||
|
<li
|
||||||
|
key={item}
|
||||||
|
className="text-sm text-gray-700 dark:text-gray-300"
|
||||||
|
>
|
||||||
|
<code className="rounded bg-green-50 px-1 py-0.5 font-mono text-xs text-green-700 dark:bg-green-900/30 dark:text-green-300">
|
||||||
|
{item}
|
||||||
|
</code>
|
||||||
|
</li>
|
||||||
|
))
|
||||||
|
) : (
|
||||||
|
<li className="text-sm italic text-gray-400 dark:text-gray-500">
|
||||||
|
None
|
||||||
|
</li>
|
||||||
|
)}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
|
||||||
|
interface SingleColumnSectionProps {
|
||||||
|
icon: React.ReactNode;
|
||||||
|
title: string;
|
||||||
|
description: string;
|
||||||
|
items: string[];
|
||||||
|
}
|
||||||
|
|
||||||
|
const SingleColumnSection: React.FC<SingleColumnSectionProps> = ({
|
||||||
|
icon,
|
||||||
|
title,
|
||||||
|
description,
|
||||||
|
items,
|
||||||
|
}) => (
|
||||||
|
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
{icon}
|
||||||
|
<span className="font-medium">{title}</span>
|
||||||
|
</div>
|
||||||
|
<p className="mt-1 text-sm text-gray-500 dark:text-gray-400">
|
||||||
|
{description}
|
||||||
|
</p>
|
||||||
|
<ul className="mt-2 space-y-1">
|
||||||
|
{items.map((item) => (
|
||||||
|
<li
|
||||||
|
key={item}
|
||||||
|
className="ml-4 list-disc text-sm text-gray-700 dark:text-gray-300"
|
||||||
|
>
|
||||||
|
<code className="rounded bg-gray-100 px-1 py-0.5 font-mono text-xs dark:bg-gray-800">
|
||||||
|
{item}
|
||||||
|
</code>
|
||||||
|
</li>
|
||||||
|
))}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
|
||||||
|
export default IncompatibilityDialog;
|
||||||
@@ -0,0 +1,130 @@
|
|||||||
|
import React from "react";
|
||||||
|
import { Button } from "@/components/__legacy__/ui/button";
|
||||||
|
import { ArrowUp, AlertTriangle, Info } from "lucide-react";
|
||||||
|
import {
|
||||||
|
Tooltip,
|
||||||
|
TooltipContent,
|
||||||
|
TooltipTrigger,
|
||||||
|
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||||
|
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
|
interface SubAgentUpdateBarProps {
|
||||||
|
currentVersion: number;
|
||||||
|
latestVersion: number;
|
||||||
|
isCompatible: boolean;
|
||||||
|
incompatibilities: IncompatibilityInfo | null;
|
||||||
|
onUpdate: () => void;
|
||||||
|
isInResolutionMode?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const SubAgentUpdateBar: React.FC<SubAgentUpdateBarProps> = ({
|
||||||
|
currentVersion,
|
||||||
|
latestVersion,
|
||||||
|
isCompatible,
|
||||||
|
incompatibilities,
|
||||||
|
onUpdate,
|
||||||
|
isInResolutionMode = false,
|
||||||
|
}) => {
|
||||||
|
if (isInResolutionMode) {
|
||||||
|
return <ResolutionModeBar incompatibilities={incompatibilities} />;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex items-center justify-between gap-2 rounded-t-lg bg-blue-50 px-3 py-2 dark:bg-blue-900/30">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<ArrowUp className="h-4 w-4 text-blue-600 dark:text-blue-400" />
|
||||||
|
<span className="text-sm text-blue-700 dark:text-blue-300">
|
||||||
|
Update available (v{currentVersion} → v{latestVersion})
|
||||||
|
</span>
|
||||||
|
{!isCompatible && (
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger asChild>
|
||||||
|
<AlertTriangle className="h-4 w-4 text-amber-500" />
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipContent className="max-w-xs">
|
||||||
|
<p className="font-medium">Incompatible changes detected</p>
|
||||||
|
<p className="text-xs text-gray-400">
|
||||||
|
Click Update to see details
|
||||||
|
</p>
|
||||||
|
</TooltipContent>
|
||||||
|
</Tooltip>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
variant={isCompatible ? "default" : "outline"}
|
||||||
|
onClick={onUpdate}
|
||||||
|
className={cn(
|
||||||
|
"h-7 text-xs",
|
||||||
|
!isCompatible && "border-amber-500 text-amber-600 hover:bg-amber-50",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
Update
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
interface ResolutionModeBarProps {
|
||||||
|
incompatibilities: IncompatibilityInfo | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const ResolutionModeBar: React.FC<ResolutionModeBarProps> = ({
|
||||||
|
incompatibilities,
|
||||||
|
}) => {
|
||||||
|
const formatIncompatibilities = () => {
|
||||||
|
if (!incompatibilities) return "No incompatibilities";
|
||||||
|
|
||||||
|
const items: string[] = [];
|
||||||
|
|
||||||
|
if (incompatibilities.missingInputs.length > 0) {
|
||||||
|
items.push(
|
||||||
|
`Missing inputs: ${incompatibilities.missingInputs.join(", ")}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (incompatibilities.missingOutputs.length > 0) {
|
||||||
|
items.push(
|
||||||
|
`Missing outputs: ${incompatibilities.missingOutputs.join(", ")}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (incompatibilities.newRequiredInputs.length > 0) {
|
||||||
|
items.push(
|
||||||
|
`New required inputs: ${incompatibilities.newRequiredInputs.join(", ")}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (incompatibilities.inputTypeMismatches.length > 0) {
|
||||||
|
const mismatches = incompatibilities.inputTypeMismatches
|
||||||
|
.map((m) => `${m.name} (${m.oldType} → ${m.newType})`)
|
||||||
|
.join(", ");
|
||||||
|
items.push(`Type changed: ${mismatches}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
return items.join("\n");
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex items-center justify-between gap-2 rounded-t-lg bg-amber-50 px-3 py-2 dark:bg-amber-900/30">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<AlertTriangle className="h-4 w-4 text-amber-600 dark:text-amber-400" />
|
||||||
|
<span className="text-sm text-amber-700 dark:text-amber-300">
|
||||||
|
Remove incompatible connections
|
||||||
|
</span>
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger asChild>
|
||||||
|
<Info className="h-4 w-4 cursor-help text-amber-500" />
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipContent className="max-w-sm whitespace-pre-line">
|
||||||
|
<p className="font-medium">Incompatible changes:</p>
|
||||||
|
<p className="mt-1 text-xs">{formatIncompatibilities()}</p>
|
||||||
|
<p className="mt-2 text-xs text-gray-400">
|
||||||
|
Delete the red connections to continue
|
||||||
|
</p>
|
||||||
|
</TooltipContent>
|
||||||
|
</Tooltip>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SubAgentUpdateBar;
|
||||||
@@ -26,15 +26,17 @@ import {
|
|||||||
applyNodeChanges,
|
applyNodeChanges,
|
||||||
} from "@xyflow/react";
|
} from "@xyflow/react";
|
||||||
import "@xyflow/react/dist/style.css";
|
import "@xyflow/react/dist/style.css";
|
||||||
import { CustomNode } from "../CustomNode/CustomNode";
|
import { ConnectedEdge, CustomNode } from "../CustomNode/CustomNode";
|
||||||
import "./flow.css";
|
import "./flow.css";
|
||||||
import {
|
import {
|
||||||
BlockUIType,
|
BlockUIType,
|
||||||
formatEdgeID,
|
formatEdgeID,
|
||||||
GraphExecutionID,
|
GraphExecutionID,
|
||||||
GraphID,
|
GraphID,
|
||||||
|
GraphMeta,
|
||||||
LibraryAgent,
|
LibraryAgent,
|
||||||
} from "@/lib/autogpt-server-api";
|
} from "@/lib/autogpt-server-api";
|
||||||
|
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
|
||||||
import { Key, storage } from "@/services/storage/local-storage";
|
import { Key, storage } from "@/services/storage/local-storage";
|
||||||
import { findNewlyAddedBlockCoordinates, getTypeColor } from "@/lib/utils";
|
import { findNewlyAddedBlockCoordinates, getTypeColor } from "@/lib/utils";
|
||||||
import { history } from "../history";
|
import { history } from "../history";
|
||||||
@@ -72,12 +74,30 @@ import { FloatingSafeModeToggle } from "../../FloatingSafeModeToogle";
|
|||||||
// It helps to prevent spamming the history with small movements especially when pressing on a input in a block
|
// It helps to prevent spamming the history with small movements especially when pressing on a input in a block
|
||||||
const MINIMUM_MOVE_BEFORE_LOG = 50;
|
const MINIMUM_MOVE_BEFORE_LOG = 50;
|
||||||
|
|
||||||
|
export type ResolutionModeState = {
|
||||||
|
active: boolean;
|
||||||
|
nodeId: string | null;
|
||||||
|
incompatibilities: IncompatibilityInfo | null;
|
||||||
|
brokenEdgeIds: string[];
|
||||||
|
pendingUpdate: Record<string, unknown> | null; // The hardcoded values to apply after resolution
|
||||||
|
};
|
||||||
|
|
||||||
type BuilderContextType = {
|
type BuilderContextType = {
|
||||||
libraryAgent: LibraryAgent | null;
|
libraryAgent: LibraryAgent | null;
|
||||||
visualizeBeads: "no" | "static" | "animate";
|
visualizeBeads: "no" | "static" | "animate";
|
||||||
setIsAnyModalOpen: (isOpen: boolean) => void;
|
setIsAnyModalOpen: (isOpen: boolean) => void;
|
||||||
getNextNodeId: () => string;
|
getNextNodeId: () => string;
|
||||||
getNodeTitle: (nodeID: string) => string | null;
|
getNodeTitle: (nodeID: string) => string | null;
|
||||||
|
availableFlows: GraphMeta[];
|
||||||
|
resolutionMode: ResolutionModeState;
|
||||||
|
enterResolutionMode: (
|
||||||
|
nodeId: string,
|
||||||
|
incompatibilities: IncompatibilityInfo,
|
||||||
|
brokenEdgeIds: string[],
|
||||||
|
pendingUpdate: Record<string, unknown>,
|
||||||
|
) => void;
|
||||||
|
exitResolutionMode: () => void;
|
||||||
|
applyPendingUpdate: () => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type NodeDimension = {
|
export type NodeDimension = {
|
||||||
@@ -172,6 +192,92 @@ const FlowEditor: React.FC<{
|
|||||||
// It stores the dimension of all nodes with position as well
|
// It stores the dimension of all nodes with position as well
|
||||||
const [nodeDimensions, setNodeDimensions] = useState<NodeDimension>({});
|
const [nodeDimensions, setNodeDimensions] = useState<NodeDimension>({});
|
||||||
|
|
||||||
|
// Resolution mode state for sub-agent incompatible updates
|
||||||
|
const [resolutionMode, setResolutionMode] = useState<ResolutionModeState>({
|
||||||
|
active: false,
|
||||||
|
nodeId: null,
|
||||||
|
incompatibilities: null,
|
||||||
|
brokenEdgeIds: [],
|
||||||
|
pendingUpdate: null,
|
||||||
|
});
|
||||||
|
|
||||||
|
const enterResolutionMode = useCallback(
|
||||||
|
(
|
||||||
|
nodeId: string,
|
||||||
|
incompatibilities: IncompatibilityInfo,
|
||||||
|
brokenEdgeIds: string[],
|
||||||
|
pendingUpdate: Record<string, unknown>,
|
||||||
|
) => {
|
||||||
|
setResolutionMode({
|
||||||
|
active: true,
|
||||||
|
nodeId,
|
||||||
|
incompatibilities,
|
||||||
|
brokenEdgeIds,
|
||||||
|
pendingUpdate,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
);
|
||||||
|
|
||||||
|
const exitResolutionMode = useCallback(() => {
|
||||||
|
setResolutionMode({
|
||||||
|
active: false,
|
||||||
|
nodeId: null,
|
||||||
|
incompatibilities: null,
|
||||||
|
brokenEdgeIds: [],
|
||||||
|
pendingUpdate: null,
|
||||||
|
});
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// Apply pending update after resolution mode completes
|
||||||
|
const applyPendingUpdate = useCallback(() => {
|
||||||
|
if (!resolutionMode.nodeId || !resolutionMode.pendingUpdate) return;
|
||||||
|
|
||||||
|
const node = nodes.find((n) => n.id === resolutionMode.nodeId);
|
||||||
|
if (node) {
|
||||||
|
const pendingUpdate = resolutionMode.pendingUpdate as {
|
||||||
|
[key: string]: any;
|
||||||
|
};
|
||||||
|
setNodes((nds) =>
|
||||||
|
nds.map((n) =>
|
||||||
|
n.id === resolutionMode.nodeId
|
||||||
|
? { ...n, data: { ...n.data, hardcodedValues: pendingUpdate } }
|
||||||
|
: n,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
exitResolutionMode();
|
||||||
|
toast({
|
||||||
|
title: "Update complete",
|
||||||
|
description: "Agent has been updated to the new version.",
|
||||||
|
});
|
||||||
|
}, [resolutionMode, nodes, setNodes, exitResolutionMode, toast]);
|
||||||
|
|
||||||
|
// Check if all broken edges have been removed and auto-apply pending update
|
||||||
|
useEffect(() => {
|
||||||
|
if (!resolutionMode.active || resolutionMode.brokenEdgeIds.length === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const currentEdgeIds = new Set(edges.map((e) => e.id));
|
||||||
|
const remainingBrokenEdges = resolutionMode.brokenEdgeIds.filter((id) =>
|
||||||
|
currentEdgeIds.has(id),
|
||||||
|
);
|
||||||
|
|
||||||
|
if (remainingBrokenEdges.length === 0) {
|
||||||
|
// All broken edges have been removed, apply pending update
|
||||||
|
applyPendingUpdate();
|
||||||
|
} else if (
|
||||||
|
remainingBrokenEdges.length !== resolutionMode.brokenEdgeIds.length
|
||||||
|
) {
|
||||||
|
// Update the list of broken edges
|
||||||
|
setResolutionMode((prev) => ({
|
||||||
|
...prev,
|
||||||
|
brokenEdgeIds: remainingBrokenEdges,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}, [edges, resolutionMode, applyPendingUpdate]);
|
||||||
|
|
||||||
// Set page title with or without graph name
|
// Set page title with or without graph name
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
document.title = savedAgent
|
document.title = savedAgent
|
||||||
@@ -431,17 +537,19 @@ const FlowEditor: React.FC<{
|
|||||||
...node.data.connections.filter(
|
...node.data.connections.filter(
|
||||||
(conn) =>
|
(conn) =>
|
||||||
!removedEdges.some(
|
!removedEdges.some(
|
||||||
(removedEdge) => removedEdge.id === conn.edge_id,
|
(removedEdge) => removedEdge.id === conn.id,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
// Add node connections for added edges
|
// Add node connections for added edges
|
||||||
...addedEdges.map((addedEdge) => ({
|
...addedEdges.map(
|
||||||
edge_id: addedEdge.item.id,
|
(addedEdge): ConnectedEdge => ({
|
||||||
source: addedEdge.item.source,
|
id: addedEdge.item.id,
|
||||||
target: addedEdge.item.target,
|
source: addedEdge.item.source,
|
||||||
sourceHandle: addedEdge.item.sourceHandle!,
|
target: addedEdge.item.target,
|
||||||
targetHandle: addedEdge.item.targetHandle!,
|
sourceHandle: addedEdge.item.sourceHandle!,
|
||||||
})),
|
targetHandle: addedEdge.item.targetHandle!,
|
||||||
|
}),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
@@ -467,13 +575,15 @@ const FlowEditor: React.FC<{
|
|||||||
data: {
|
data: {
|
||||||
...node.data,
|
...node.data,
|
||||||
connections: [
|
connections: [
|
||||||
...replaceEdges.map((replaceEdge) => ({
|
...replaceEdges.map(
|
||||||
edge_id: replaceEdge.item.id,
|
(replaceEdge): ConnectedEdge => ({
|
||||||
source: replaceEdge.item.source,
|
id: replaceEdge.item.id,
|
||||||
target: replaceEdge.item.target,
|
source: replaceEdge.item.source,
|
||||||
sourceHandle: replaceEdge.item.sourceHandle!,
|
target: replaceEdge.item.target,
|
||||||
targetHandle: replaceEdge.item.targetHandle!,
|
sourceHandle: replaceEdge.item.sourceHandle!,
|
||||||
})),
|
targetHandle: replaceEdge.item.targetHandle!,
|
||||||
|
}),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
})),
|
})),
|
||||||
@@ -890,8 +1000,23 @@ const FlowEditor: React.FC<{
|
|||||||
setIsAnyModalOpen,
|
setIsAnyModalOpen,
|
||||||
getNextNodeId,
|
getNextNodeId,
|
||||||
getNodeTitle,
|
getNodeTitle,
|
||||||
|
availableFlows,
|
||||||
|
resolutionMode,
|
||||||
|
enterResolutionMode,
|
||||||
|
exitResolutionMode,
|
||||||
|
applyPendingUpdate,
|
||||||
}),
|
}),
|
||||||
[libraryAgent, visualizeBeads, getNextNodeId, getNodeTitle],
|
[
|
||||||
|
libraryAgent,
|
||||||
|
visualizeBeads,
|
||||||
|
getNextNodeId,
|
||||||
|
getNodeTitle,
|
||||||
|
availableFlows,
|
||||||
|
resolutionMode,
|
||||||
|
enterResolutionMode,
|
||||||
|
applyPendingUpdate,
|
||||||
|
exitResolutionMode,
|
||||||
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -991,6 +1116,7 @@ const FlowEditor: React.FC<{
|
|||||||
onClickScheduleButton={handleScheduleButton}
|
onClickScheduleButton={handleScheduleButton}
|
||||||
isDisabled={!savedAgent}
|
isDisabled={!savedAgent}
|
||||||
isRunning={isRunning}
|
isRunning={isRunning}
|
||||||
|
resolutionModeActive={resolutionMode.active}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
<Alert className="absolute bottom-4 left-1/2 z-20 w-auto -translate-x-1/2 select-none">
|
<Alert className="absolute bottom-4 left-1/2 z-20 w-auto -translate-x-1/2 select-none">
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
import { BlockIOSubSchema } from "@/lib/autogpt-server-api/types";
|
import { BlockIOSubSchema } from "@/lib/autogpt-server-api/types";
|
||||||
import { cn } from "@/lib/utils";
|
import {
|
||||||
import { beautifyString, getTypeBgColor, getTypeTextColor } from "@/lib/utils";
|
cn,
|
||||||
|
beautifyString,
|
||||||
|
getTypeBgColor,
|
||||||
|
getTypeTextColor,
|
||||||
|
getEffectiveType,
|
||||||
|
} from "@/lib/utils";
|
||||||
import { FC, memo, useCallback } from "react";
|
import { FC, memo, useCallback } from "react";
|
||||||
import { Handle, Position } from "@xyflow/react";
|
import { Handle, Position } from "@xyflow/react";
|
||||||
import { InformationTooltip } from "@/components/molecules/InformationTooltip/InformationTooltip";
|
import { InformationTooltip } from "@/components/molecules/InformationTooltip/InformationTooltip";
|
||||||
@@ -13,6 +18,7 @@ type HandleProps = {
|
|||||||
side: "left" | "right";
|
side: "left" | "right";
|
||||||
title?: string;
|
title?: string;
|
||||||
className?: string;
|
className?: string;
|
||||||
|
isBroken?: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Move the constant out of the component to avoid re-creation on every render.
|
// Move the constant out of the component to avoid re-creation on every render.
|
||||||
@@ -27,18 +33,23 @@ const TYPE_NAME: Record<string, string> = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Extract and memoize the Dot component so that it doesn't re-render unnecessarily.
|
// Extract and memoize the Dot component so that it doesn't re-render unnecessarily.
|
||||||
const Dot: FC<{ isConnected: boolean; type?: string }> = memo(
|
const Dot: FC<{ isConnected: boolean; type?: string; isBroken?: boolean }> =
|
||||||
({ isConnected, type }) => {
|
memo(({ isConnected, type, isBroken }) => {
|
||||||
const color = isConnected
|
const color = isBroken
|
||||||
? getTypeBgColor(type || "any")
|
? "border-red-500 bg-red-100 dark:bg-red-900/30"
|
||||||
: "border-gray-300 dark:border-gray-600";
|
: isConnected
|
||||||
|
? getTypeBgColor(type || "any")
|
||||||
|
: "border-gray-300 dark:border-gray-600";
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={`${color} m-1 h-4 w-4 rounded-full border-2 bg-white transition-colors duration-100 group-hover:bg-gray-300 dark:bg-slate-800 dark:group-hover:bg-gray-700`}
|
className={cn(
|
||||||
|
"m-1 h-4 w-4 rounded-full border-2 bg-white transition-colors duration-100 group-hover:bg-gray-300 dark:bg-slate-800 dark:group-hover:bg-gray-700",
|
||||||
|
color,
|
||||||
|
isBroken && "opacity-50",
|
||||||
|
)}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
},
|
});
|
||||||
);
|
|
||||||
Dot.displayName = "Dot";
|
Dot.displayName = "Dot";
|
||||||
|
|
||||||
const NodeHandle: FC<HandleProps> = ({
|
const NodeHandle: FC<HandleProps> = ({
|
||||||
@@ -49,24 +60,34 @@ const NodeHandle: FC<HandleProps> = ({
|
|||||||
side,
|
side,
|
||||||
title,
|
title,
|
||||||
className,
|
className,
|
||||||
|
isBroken = false,
|
||||||
}) => {
|
}) => {
|
||||||
const typeClass = `text-sm ${getTypeTextColor(schema.type || "any")} ${
|
// Extract effective type from schema (handles anyOf/oneOf/allOf wrappers)
|
||||||
|
const effectiveType = getEffectiveType(schema);
|
||||||
|
|
||||||
|
const typeClass = `text-sm ${getTypeTextColor(effectiveType || "any")} ${
|
||||||
side === "left" ? "text-left" : "text-right"
|
side === "left" ? "text-left" : "text-right"
|
||||||
}`;
|
}`;
|
||||||
|
|
||||||
const label = (
|
const label = (
|
||||||
<div className="flex flex-grow flex-row">
|
<div className={cn("flex flex-grow flex-row", isBroken && "opacity-50")}>
|
||||||
<span
|
<span
|
||||||
className={cn(
|
className={cn(
|
||||||
"data-sentry-unmask text-m green flex items-end pr-2 text-gray-900 dark:text-gray-100",
|
"data-sentry-unmask text-m green flex items-end pr-2 text-gray-900 dark:text-gray-100",
|
||||||
className,
|
className,
|
||||||
|
isBroken && "text-red-500 line-through",
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{title || schema.title || beautifyString(keyName.toLowerCase())}
|
{title || schema.title || beautifyString(keyName.toLowerCase())}
|
||||||
{isRequired ? "*" : ""}
|
{isRequired ? "*" : ""}
|
||||||
</span>
|
</span>
|
||||||
<span className={`${typeClass} data-sentry-unmask flex items-end`}>
|
<span
|
||||||
({TYPE_NAME[schema.type as keyof typeof TYPE_NAME] || "any"})
|
className={cn(
|
||||||
|
`${typeClass} data-sentry-unmask flex items-end`,
|
||||||
|
isBroken && "text-red-400",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
({TYPE_NAME[effectiveType as keyof typeof TYPE_NAME] || "any"})
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
@@ -84,7 +105,7 @@ const NodeHandle: FC<HandleProps> = ({
|
|||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
key={keyName}
|
key={keyName}
|
||||||
className="handle-container"
|
className={cn("handle-container", isBroken && "pointer-events-none")}
|
||||||
onContextMenu={handleContextMenu}
|
onContextMenu={handleContextMenu}
|
||||||
>
|
>
|
||||||
<Handle
|
<Handle
|
||||||
@@ -92,10 +113,15 @@ const NodeHandle: FC<HandleProps> = ({
|
|||||||
data-testid={`input-handle-${keyName}`}
|
data-testid={`input-handle-${keyName}`}
|
||||||
position={Position.Left}
|
position={Position.Left}
|
||||||
id={keyName}
|
id={keyName}
|
||||||
className="group -ml-[38px]"
|
className={cn("group -ml-[38px]", isBroken && "cursor-not-allowed")}
|
||||||
|
isConnectable={!isBroken}
|
||||||
>
|
>
|
||||||
<div className="pointer-events-none flex items-center">
|
<div className="pointer-events-none flex items-center">
|
||||||
<Dot isConnected={isConnected} type={schema.type} />
|
<Dot
|
||||||
|
isConnected={isConnected}
|
||||||
|
type={effectiveType}
|
||||||
|
isBroken={isBroken}
|
||||||
|
/>
|
||||||
{label}
|
{label}
|
||||||
</div>
|
</div>
|
||||||
</Handle>
|
</Handle>
|
||||||
@@ -106,7 +132,10 @@ const NodeHandle: FC<HandleProps> = ({
|
|||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
key={keyName}
|
key={keyName}
|
||||||
className="handle-container justify-end"
|
className={cn(
|
||||||
|
"handle-container justify-end",
|
||||||
|
isBroken && "pointer-events-none",
|
||||||
|
)}
|
||||||
onContextMenu={handleContextMenu}
|
onContextMenu={handleContextMenu}
|
||||||
>
|
>
|
||||||
<Handle
|
<Handle
|
||||||
@@ -114,11 +143,16 @@ const NodeHandle: FC<HandleProps> = ({
|
|||||||
data-testid={`output-handle-${keyName}`}
|
data-testid={`output-handle-${keyName}`}
|
||||||
position={Position.Right}
|
position={Position.Right}
|
||||||
id={keyName}
|
id={keyName}
|
||||||
className="group -mr-[38px]"
|
className={cn("group -mr-[38px]", isBroken && "cursor-not-allowed")}
|
||||||
|
isConnectable={!isBroken}
|
||||||
>
|
>
|
||||||
<div className="pointer-events-none flex items-center">
|
<div className="pointer-events-none flex items-center">
|
||||||
{label}
|
{label}
|
||||||
<Dot isConnected={isConnected} type={schema.type} />
|
<Dot
|
||||||
|
isConnected={isConnected}
|
||||||
|
type={effectiveType}
|
||||||
|
isBroken={isBroken}
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
</Handle>
|
</Handle>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import {
|
import {
|
||||||
ConnectionData,
|
ConnectedEdge,
|
||||||
CustomNodeData,
|
CustomNodeData,
|
||||||
} from "@/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode";
|
} from "@/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode";
|
||||||
import { NodeTableInput } from "@/app/(platform)/build/components/legacy-builder/NodeTableInput";
|
import { NodeTableInput } from "@/app/(platform)/build/components/legacy-builder/NodeTableInput";
|
||||||
@@ -65,7 +65,7 @@ type NodeObjectInputTreeProps = {
|
|||||||
selfKey?: string;
|
selfKey?: string;
|
||||||
schema: BlockIORootSchema | BlockIOObjectSubSchema;
|
schema: BlockIORootSchema | BlockIOObjectSubSchema;
|
||||||
object?: { [key: string]: any };
|
object?: { [key: string]: any };
|
||||||
connections: ConnectionData;
|
connections: ConnectedEdge[];
|
||||||
handleInputClick: (key: string) => void;
|
handleInputClick: (key: string) => void;
|
||||||
handleInputChange: (key: string, value: any) => void;
|
handleInputChange: (key: string, value: any) => void;
|
||||||
errors: { [key: string]: string | undefined };
|
errors: { [key: string]: string | undefined };
|
||||||
@@ -585,7 +585,7 @@ const NodeOneOfDiscriminatorField: FC<{
|
|||||||
currentValue?: any;
|
currentValue?: any;
|
||||||
defaultValue?: any;
|
defaultValue?: any;
|
||||||
errors: { [key: string]: string | undefined };
|
errors: { [key: string]: string | undefined };
|
||||||
connections: ConnectionData;
|
connections: ConnectedEdge[];
|
||||||
handleInputChange: (key: string, value: any) => void;
|
handleInputChange: (key: string, value: any) => void;
|
||||||
handleInputClick: (key: string) => void;
|
handleInputClick: (key: string) => void;
|
||||||
className?: string;
|
className?: string;
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
import { FC, useCallback, useEffect, useState } from "react";
|
import { FC, useCallback, useEffect, useState } from "react";
|
||||||
|
|
||||||
import NodeHandle from "@/app/(platform)/build/components/legacy-builder/NodeHandle";
|
import NodeHandle from "@/app/(platform)/build/components/legacy-builder/NodeHandle";
|
||||||
import {
|
import type {
|
||||||
BlockIOTableSubSchema,
|
BlockIOTableSubSchema,
|
||||||
TableCellValue,
|
TableCellValue,
|
||||||
TableRow,
|
TableRow,
|
||||||
} from "@/lib/autogpt-server-api/types";
|
} from "@/lib/autogpt-server-api/types";
|
||||||
|
import type { ConnectedEdge } from "./CustomNode/CustomNode";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { PlusIcon, XIcon } from "@phosphor-icons/react";
|
import { PlusIcon, XIcon } from "@phosphor-icons/react";
|
||||||
import { Button } from "../../../../../components/atoms/Button/Button";
|
import { Button } from "@/components/atoms/Button/Button";
|
||||||
import { Input } from "../../../../../components/atoms/Input/Input";
|
import { Input } from "@/components/atoms/Input/Input";
|
||||||
|
|
||||||
interface NodeTableInputProps {
|
interface NodeTableInputProps {
|
||||||
/** Unique identifier for the node in the builder graph */
|
/** Unique identifier for the node in the builder graph */
|
||||||
@@ -25,13 +26,7 @@ interface NodeTableInputProps {
|
|||||||
/** Validation errors mapped by field key */
|
/** Validation errors mapped by field key */
|
||||||
errors: { [key: string]: string | undefined };
|
errors: { [key: string]: string | undefined };
|
||||||
/** Graph connections between nodes in the builder */
|
/** Graph connections between nodes in the builder */
|
||||||
connections: {
|
connections: ConnectedEdge[];
|
||||||
edge_id: string;
|
|
||||||
source: string;
|
|
||||||
sourceHandle: string;
|
|
||||||
target: string;
|
|
||||||
targetHandle: string;
|
|
||||||
}[];
|
|
||||||
/** Callback when table data changes */
|
/** Callback when table data changes */
|
||||||
handleInputChange: (key: string, value: TableRow[]) => void;
|
handleInputChange: (key: string, value: TableRow[]) => void;
|
||||||
/** Callback when input field is clicked (for builder selection) */
|
/** Callback when input field is clicked (for builder selection) */
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import { useCallback } from "react";
|
import { useCallback } from "react";
|
||||||
import { Node, Edge, useReactFlow } from "@xyflow/react";
|
import { Node, Edge, useReactFlow } from "@xyflow/react";
|
||||||
import { Key, storage } from "@/services/storage/local-storage";
|
import { Key, storage } from "@/services/storage/local-storage";
|
||||||
|
import { ConnectedEdge } from "./CustomNode/CustomNode";
|
||||||
|
|
||||||
interface CopyableData {
|
interface CopyableData {
|
||||||
nodes: Node[];
|
nodes: Node[];
|
||||||
@@ -111,13 +112,15 @@ export function useCopyPaste(getNextNodeId: () => string) {
|
|||||||
(edge: Edge) =>
|
(edge: Edge) =>
|
||||||
edge.source === node.id || edge.target === node.id,
|
edge.source === node.id || edge.target === node.id,
|
||||||
)
|
)
|
||||||
.map((edge: Edge) => ({
|
.map(
|
||||||
edge_id: edge.id,
|
(edge: Edge): ConnectedEdge => ({
|
||||||
source: edge.source,
|
id: edge.id,
|
||||||
target: edge.target,
|
source: edge.source,
|
||||||
sourceHandle: edge.sourceHandle,
|
target: edge.target,
|
||||||
targetHandle: edge.targetHandle,
|
sourceHandle: edge.sourceHandle!,
|
||||||
}));
|
targetHandle: edge.targetHandle!,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...node,
|
...node,
|
||||||
|
|||||||
@@ -0,0 +1,104 @@
|
|||||||
|
import { GraphInputSchema } from "@/lib/autogpt-server-api";
|
||||||
|
import { GraphMetaLike, IncompatibilityInfo } from "./types";
|
||||||
|
|
||||||
|
// Helper type for schema properties - the generated types are too loose
|
||||||
|
type SchemaProperties = Record<string, GraphInputSchema["properties"][string]>;
|
||||||
|
type SchemaRequired = string[];
|
||||||
|
|
||||||
|
// Helper to safely extract schema properties
|
||||||
|
export function getSchemaProperties(schema: unknown): SchemaProperties {
|
||||||
|
if (
|
||||||
|
schema &&
|
||||||
|
typeof schema === "object" &&
|
||||||
|
"properties" in schema &&
|
||||||
|
typeof schema.properties === "object" &&
|
||||||
|
schema.properties !== null
|
||||||
|
) {
|
||||||
|
return schema.properties as SchemaProperties;
|
||||||
|
}
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getSchemaRequired(schema: unknown): SchemaRequired {
|
||||||
|
if (
|
||||||
|
schema &&
|
||||||
|
typeof schema === "object" &&
|
||||||
|
"required" in schema &&
|
||||||
|
Array.isArray(schema.required)
|
||||||
|
) {
|
||||||
|
return schema.required as SchemaRequired;
|
||||||
|
}
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates the updated agent node inputs for a sub-agent node
|
||||||
|
*/
|
||||||
|
export function createUpdatedAgentNodeInputs(
|
||||||
|
currentInputs: Record<string, unknown>,
|
||||||
|
latestSubGraphVersion: GraphMetaLike,
|
||||||
|
): Record<string, unknown> {
|
||||||
|
return {
|
||||||
|
...currentInputs,
|
||||||
|
graph_version: latestSubGraphVersion.version,
|
||||||
|
input_schema: latestSubGraphVersion.input_schema,
|
||||||
|
output_schema: latestSubGraphVersion.output_schema,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Generic edge type that works with both builders:
|
||||||
|
* - New builder uses CustomEdge with (formally) optional handles
|
||||||
|
* - Legacy builder uses ConnectedEdge type with required handles */
|
||||||
|
export type EdgeLike = {
|
||||||
|
id: string;
|
||||||
|
source: string;
|
||||||
|
target: string;
|
||||||
|
sourceHandle?: string | null;
|
||||||
|
targetHandle?: string | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Determines which edges are broken after an incompatible update.
|
||||||
|
* Works with both legacy ConnectedEdge and new CustomEdge.
|
||||||
|
*/
|
||||||
|
export function getBrokenEdgeIDs(
|
||||||
|
connections: EdgeLike[],
|
||||||
|
incompatibilities: IncompatibilityInfo,
|
||||||
|
nodeID: string,
|
||||||
|
): string[] {
|
||||||
|
const brokenEdgeIDs: string[] = [];
|
||||||
|
const typeMismatchInputNames = new Set(
|
||||||
|
incompatibilities.inputTypeMismatches.map((m) => m.name),
|
||||||
|
);
|
||||||
|
|
||||||
|
connections.forEach((conn) => {
|
||||||
|
// Check if this connection uses a missing input (node is target)
|
||||||
|
if (
|
||||||
|
conn.target === nodeID &&
|
||||||
|
conn.targetHandle &&
|
||||||
|
incompatibilities.missingInputs.includes(conn.targetHandle)
|
||||||
|
) {
|
||||||
|
brokenEdgeIDs.push(conn.id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this connection uses an input with a type mismatch (node is target)
|
||||||
|
if (
|
||||||
|
conn.target === nodeID &&
|
||||||
|
conn.targetHandle &&
|
||||||
|
typeMismatchInputNames.has(conn.targetHandle)
|
||||||
|
) {
|
||||||
|
brokenEdgeIDs.push(conn.id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this connection uses a missing output (node is source)
|
||||||
|
if (
|
||||||
|
conn.source === nodeID &&
|
||||||
|
conn.sourceHandle &&
|
||||||
|
incompatibilities.missingOutputs.includes(conn.sourceHandle)
|
||||||
|
) {
|
||||||
|
brokenEdgeIDs.push(conn.id);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return brokenEdgeIDs;
|
||||||
|
}
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
export { useSubAgentUpdate } from "./useSubAgentUpdate";
|
||||||
|
export { createUpdatedAgentNodeInputs, getBrokenEdgeIDs } from "./helpers";
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
import type { GraphMeta as LegacyGraphMeta } from "@/lib/autogpt-server-api";
|
||||||
|
import type { GraphMeta as GeneratedGraphMeta } from "@/app/api/__generated__/models/graphMeta";
|
||||||
|
|
||||||
|
export type SubAgentUpdateInfo<T extends GraphMetaLike = GraphMetaLike> = {
|
||||||
|
hasUpdate: boolean;
|
||||||
|
currentVersion: number;
|
||||||
|
latestVersion: number;
|
||||||
|
latestGraph: T | null;
|
||||||
|
isCompatible: boolean;
|
||||||
|
incompatibilities: IncompatibilityInfo | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Union type for GraphMeta that works with both legacy and new builder
|
||||||
|
export type GraphMetaLike = LegacyGraphMeta | GeneratedGraphMeta;
|
||||||
|
|
||||||
|
export type IncompatibilityInfo = {
|
||||||
|
missingInputs: string[]; // Connected inputs that no longer exist
|
||||||
|
missingOutputs: string[]; // Connected outputs that no longer exist
|
||||||
|
newInputs: string[]; // Inputs that exist in new version but not in current
|
||||||
|
newOutputs: string[]; // Outputs that exist in new version but not in current
|
||||||
|
newRequiredInputs: string[]; // New required inputs not in current version or not required
|
||||||
|
inputTypeMismatches: Array<{
|
||||||
|
name: string;
|
||||||
|
oldType: string;
|
||||||
|
newType: string;
|
||||||
|
}>; // Connected inputs where the type has changed
|
||||||
|
};
|
||||||
@@ -0,0 +1,160 @@
|
|||||||
|
import { useMemo } from "react";
|
||||||
|
import { GraphInputSchema, GraphOutputSchema } from "@/lib/autogpt-server-api";
|
||||||
|
import { getEffectiveType } from "@/lib/utils";
|
||||||
|
import { EdgeLike, getSchemaProperties, getSchemaRequired } from "./helpers";
|
||||||
|
import {
|
||||||
|
GraphMetaLike,
|
||||||
|
IncompatibilityInfo,
|
||||||
|
SubAgentUpdateInfo,
|
||||||
|
} from "./types";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if a newer version of a sub-agent is available and determines compatibility
|
||||||
|
*/
|
||||||
|
export function useSubAgentUpdate<T extends GraphMetaLike>(
|
||||||
|
nodeID: string,
|
||||||
|
graphID: string | undefined,
|
||||||
|
graphVersion: number | undefined,
|
||||||
|
currentInputSchema: GraphInputSchema | undefined,
|
||||||
|
currentOutputSchema: GraphOutputSchema | undefined,
|
||||||
|
connections: EdgeLike[],
|
||||||
|
availableGraphs: T[],
|
||||||
|
): SubAgentUpdateInfo<T> {
|
||||||
|
// Find the latest version of the same graph
|
||||||
|
const latestGraph = useMemo(() => {
|
||||||
|
if (!graphID) return null;
|
||||||
|
return availableGraphs.find((graph) => graph.id === graphID) || null;
|
||||||
|
}, [graphID, availableGraphs]);
|
||||||
|
|
||||||
|
// Check if there's an update available
|
||||||
|
const hasUpdate = useMemo(() => {
|
||||||
|
if (!latestGraph || graphVersion === undefined) return false;
|
||||||
|
return latestGraph.version! > graphVersion;
|
||||||
|
}, [latestGraph, graphVersion]);
|
||||||
|
|
||||||
|
// Get connected input and output handles for this specific node
|
||||||
|
const connectedHandles = useMemo(() => {
|
||||||
|
const inputHandles = new Set<string>();
|
||||||
|
const outputHandles = new Set<string>();
|
||||||
|
|
||||||
|
connections.forEach((conn) => {
|
||||||
|
// If this node is the target, the targetHandle is an input on this node
|
||||||
|
if (conn.target === nodeID && conn.targetHandle) {
|
||||||
|
inputHandles.add(conn.targetHandle);
|
||||||
|
}
|
||||||
|
// If this node is the source, the sourceHandle is an output on this node
|
||||||
|
if (conn.source === nodeID && conn.sourceHandle) {
|
||||||
|
outputHandles.add(conn.sourceHandle);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return { inputHandles, outputHandles };
|
||||||
|
}, [connections, nodeID]);
|
||||||
|
|
||||||
|
// Check schema compatibility
|
||||||
|
const compatibilityResult = useMemo((): {
|
||||||
|
isCompatible: boolean;
|
||||||
|
incompatibilities: IncompatibilityInfo | null;
|
||||||
|
} => {
|
||||||
|
if (!hasUpdate || !latestGraph) {
|
||||||
|
return { isCompatible: true, incompatibilities: null };
|
||||||
|
}
|
||||||
|
|
||||||
|
const newInputProps = getSchemaProperties(latestGraph.input_schema);
|
||||||
|
const newOutputProps = getSchemaProperties(latestGraph.output_schema);
|
||||||
|
const newRequiredInputs = getSchemaRequired(latestGraph.input_schema);
|
||||||
|
|
||||||
|
const currentInputProps = getSchemaProperties(currentInputSchema);
|
||||||
|
const currentOutputProps = getSchemaProperties(currentOutputSchema);
|
||||||
|
const currentRequiredInputs = getSchemaRequired(currentInputSchema);
|
||||||
|
|
||||||
|
const incompatibilities: IncompatibilityInfo = {
|
||||||
|
missingInputs: [],
|
||||||
|
missingOutputs: [],
|
||||||
|
newInputs: [],
|
||||||
|
newOutputs: [],
|
||||||
|
newRequiredInputs: [],
|
||||||
|
inputTypeMismatches: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check for missing connected inputs and type mismatches
|
||||||
|
connectedHandles.inputHandles.forEach((inputHandle) => {
|
||||||
|
if (!(inputHandle in newInputProps)) {
|
||||||
|
incompatibilities.missingInputs.push(inputHandle);
|
||||||
|
} else {
|
||||||
|
// Check for type mismatch on connected inputs
|
||||||
|
const currentProp = currentInputProps[inputHandle];
|
||||||
|
const newProp = newInputProps[inputHandle];
|
||||||
|
const currentType = getEffectiveType(currentProp);
|
||||||
|
const newType = getEffectiveType(newProp);
|
||||||
|
|
||||||
|
if (currentType && newType && currentType !== newType) {
|
||||||
|
incompatibilities.inputTypeMismatches.push({
|
||||||
|
name: inputHandle,
|
||||||
|
oldType: currentType,
|
||||||
|
newType: newType,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Check for missing connected outputs
|
||||||
|
connectedHandles.outputHandles.forEach((outputHandle) => {
|
||||||
|
if (!(outputHandle in newOutputProps)) {
|
||||||
|
incompatibilities.missingOutputs.push(outputHandle);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Check for new required inputs that didn't exist or weren't required before
|
||||||
|
newRequiredInputs.forEach((requiredInput) => {
|
||||||
|
const existedBefore = requiredInput in currentInputProps;
|
||||||
|
const wasRequiredBefore = currentRequiredInputs.includes(
|
||||||
|
requiredInput as string,
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!existedBefore || !wasRequiredBefore) {
|
||||||
|
incompatibilities.newRequiredInputs.push(requiredInput as string);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Check for new inputs that don't exist in the current version
|
||||||
|
Object.keys(newInputProps).forEach((inputName) => {
|
||||||
|
if (!(inputName in currentInputProps)) {
|
||||||
|
incompatibilities.newInputs.push(inputName);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Check for new outputs that don't exist in the current version
|
||||||
|
Object.keys(newOutputProps).forEach((outputName) => {
|
||||||
|
if (!(outputName in currentOutputProps)) {
|
||||||
|
incompatibilities.newOutputs.push(outputName);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
const hasIncompatibilities =
|
||||||
|
incompatibilities.missingInputs.length > 0 ||
|
||||||
|
incompatibilities.missingOutputs.length > 0 ||
|
||||||
|
incompatibilities.newRequiredInputs.length > 0 ||
|
||||||
|
incompatibilities.inputTypeMismatches.length > 0;
|
||||||
|
|
||||||
|
return {
|
||||||
|
isCompatible: !hasIncompatibilities,
|
||||||
|
incompatibilities: hasIncompatibilities ? incompatibilities : null,
|
||||||
|
};
|
||||||
|
}, [
|
||||||
|
hasUpdate,
|
||||||
|
latestGraph,
|
||||||
|
currentInputSchema,
|
||||||
|
currentOutputSchema,
|
||||||
|
connectedHandles,
|
||||||
|
]);
|
||||||
|
|
||||||
|
return {
|
||||||
|
hasUpdate,
|
||||||
|
currentVersion: graphVersion || 0,
|
||||||
|
latestVersion: latestGraph?.version || 0,
|
||||||
|
latestGraph,
|
||||||
|
isCompatible: compatibilityResult.isCompatible,
|
||||||
|
incompatibilities: compatibilityResult.incompatibilities,
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
import { create } from "zustand";
|
import { create } from "zustand";
|
||||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||||
|
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
|
||||||
|
|
||||||
interface GraphStore {
|
interface GraphStore {
|
||||||
graphExecutionStatus: AgentExecutionStatus | undefined;
|
graphExecutionStatus: AgentExecutionStatus | undefined;
|
||||||
@@ -17,6 +18,10 @@ interface GraphStore {
|
|||||||
outputSchema: Record<string, any> | null,
|
outputSchema: Record<string, any> | null,
|
||||||
) => void;
|
) => void;
|
||||||
|
|
||||||
|
// Available graphs; used for sub-graph updates
|
||||||
|
availableSubGraphs: GraphMeta[];
|
||||||
|
setAvailableSubGraphs: (graphs: GraphMeta[]) => void;
|
||||||
|
|
||||||
hasInputs: () => boolean;
|
hasInputs: () => boolean;
|
||||||
hasCredentials: () => boolean;
|
hasCredentials: () => boolean;
|
||||||
hasOutputs: () => boolean;
|
hasOutputs: () => boolean;
|
||||||
@@ -29,6 +34,7 @@ export const useGraphStore = create<GraphStore>((set, get) => ({
|
|||||||
inputSchema: null,
|
inputSchema: null,
|
||||||
credentialsInputSchema: null,
|
credentialsInputSchema: null,
|
||||||
outputSchema: null,
|
outputSchema: null,
|
||||||
|
availableSubGraphs: [],
|
||||||
|
|
||||||
setGraphExecutionStatus: (status: AgentExecutionStatus | undefined) => {
|
setGraphExecutionStatus: (status: AgentExecutionStatus | undefined) => {
|
||||||
set({
|
set({
|
||||||
@@ -46,6 +52,8 @@ export const useGraphStore = create<GraphStore>((set, get) => ({
|
|||||||
setGraphSchemas: (inputSchema, credentialsInputSchema, outputSchema) =>
|
setGraphSchemas: (inputSchema, credentialsInputSchema, outputSchema) =>
|
||||||
set({ inputSchema, credentialsInputSchema, outputSchema }),
|
set({ inputSchema, credentialsInputSchema, outputSchema }),
|
||||||
|
|
||||||
|
setAvailableSubGraphs: (graphs) => set({ availableSubGraphs: graphs }),
|
||||||
|
|
||||||
hasOutputs: () => {
|
hasOutputs: () => {
|
||||||
const { outputSchema } = get();
|
const { outputSchema } = get();
|
||||||
return Object.keys(outputSchema?.properties ?? {}).length > 0;
|
return Object.keys(outputSchema?.properties ?? {}).length > 0;
|
||||||
|
|||||||
@@ -17,6 +17,25 @@ import {
|
|||||||
ensurePathExists,
|
ensurePathExists,
|
||||||
parseHandleIdToPath,
|
parseHandleIdToPath,
|
||||||
} from "@/components/renderers/InputRenderer/helpers";
|
} from "@/components/renderers/InputRenderer/helpers";
|
||||||
|
import { IncompatibilityInfo } from "../hooks/useSubAgentUpdate/types";
|
||||||
|
|
||||||
|
// Resolution mode data stored per node
|
||||||
|
export type NodeResolutionData = {
|
||||||
|
incompatibilities: IncompatibilityInfo;
|
||||||
|
// The NEW schema from the update (what we're updating TO)
|
||||||
|
pendingUpdate: {
|
||||||
|
input_schema: Record<string, unknown>;
|
||||||
|
output_schema: Record<string, unknown>;
|
||||||
|
};
|
||||||
|
// The OLD schema before the update (what we're updating FROM)
|
||||||
|
// Needed to merge and show removed inputs during resolution
|
||||||
|
currentSchema: {
|
||||||
|
input_schema: Record<string, unknown>;
|
||||||
|
output_schema: Record<string, unknown>;
|
||||||
|
};
|
||||||
|
// The full updated hardcoded values to apply when resolution completes
|
||||||
|
pendingHardcodedValues: Record<string, unknown>;
|
||||||
|
};
|
||||||
|
|
||||||
// Minimum movement (in pixels) required before logging position change to history
|
// Minimum movement (in pixels) required before logging position change to history
|
||||||
// Prevents spamming history with small movements when clicking on inputs inside blocks
|
// Prevents spamming history with small movements when clicking on inputs inside blocks
|
||||||
@@ -65,12 +84,32 @@ type NodeStore = {
|
|||||||
backendId: string,
|
backendId: string,
|
||||||
errors: { [key: string]: string },
|
errors: { [key: string]: string },
|
||||||
) => void;
|
) => void;
|
||||||
clearAllNodeErrors: () => void; // Add this
|
|
||||||
|
|
||||||
syncHardcodedValuesWithHandleIds: (nodeId: string) => void;
|
syncHardcodedValuesWithHandleIds: (nodeId: string) => void;
|
||||||
|
|
||||||
// Credentials optional helpers
|
|
||||||
setCredentialsOptional: (nodeId: string, optional: boolean) => void;
|
setCredentialsOptional: (nodeId: string, optional: boolean) => void;
|
||||||
|
clearAllNodeErrors: () => void;
|
||||||
|
|
||||||
|
nodesInResolutionMode: Set<string>;
|
||||||
|
brokenEdgeIDs: Map<string, Set<string>>;
|
||||||
|
nodeResolutionData: Map<string, NodeResolutionData>;
|
||||||
|
setNodeResolutionMode: (
|
||||||
|
nodeID: string,
|
||||||
|
inResolution: boolean,
|
||||||
|
resolutionData?: NodeResolutionData,
|
||||||
|
) => void;
|
||||||
|
isNodeInResolutionMode: (nodeID: string) => boolean;
|
||||||
|
getNodeResolutionData: (nodeID: string) => NodeResolutionData | undefined;
|
||||||
|
setBrokenEdgeIDs: (nodeID: string, edgeIDs: string[]) => void;
|
||||||
|
removeBrokenEdgeID: (nodeID: string, edgeID: string) => void;
|
||||||
|
isEdgeBroken: (edgeID: string) => boolean;
|
||||||
|
clearResolutionState: () => void;
|
||||||
|
|
||||||
|
isInputBroken: (nodeID: string, handleID: string) => boolean;
|
||||||
|
getInputTypeMismatch: (
|
||||||
|
nodeID: string,
|
||||||
|
handleID: string,
|
||||||
|
) => string | undefined;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const useNodeStore = create<NodeStore>((set, get) => ({
|
export const useNodeStore = create<NodeStore>((set, get) => ({
|
||||||
@@ -374,4 +413,99 @@ export const useNodeStore = create<NodeStore>((set, get) => ({
|
|||||||
|
|
||||||
useHistoryStore.getState().pushState(newState);
|
useHistoryStore.getState().pushState(newState);
|
||||||
},
|
},
|
||||||
|
|
||||||
|
// Sub-agent resolution mode state
|
||||||
|
nodesInResolutionMode: new Set<string>(),
|
||||||
|
brokenEdgeIDs: new Map<string, Set<string>>(),
|
||||||
|
nodeResolutionData: new Map<string, NodeResolutionData>(),
|
||||||
|
|
||||||
|
setNodeResolutionMode: (
|
||||||
|
nodeID: string,
|
||||||
|
inResolution: boolean,
|
||||||
|
resolutionData?: NodeResolutionData,
|
||||||
|
) => {
|
||||||
|
set((state) => {
|
||||||
|
const newNodesSet = new Set(state.nodesInResolutionMode);
|
||||||
|
const newResolutionDataMap = new Map(state.nodeResolutionData);
|
||||||
|
const newBrokenEdgeIDs = new Map(state.brokenEdgeIDs);
|
||||||
|
|
||||||
|
if (inResolution) {
|
||||||
|
newNodesSet.add(nodeID);
|
||||||
|
if (resolutionData) {
|
||||||
|
newResolutionDataMap.set(nodeID, resolutionData);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
newNodesSet.delete(nodeID);
|
||||||
|
newResolutionDataMap.delete(nodeID);
|
||||||
|
newBrokenEdgeIDs.delete(nodeID); // Clean up broken edges when exiting resolution mode
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
nodesInResolutionMode: newNodesSet,
|
||||||
|
nodeResolutionData: newResolutionDataMap,
|
||||||
|
brokenEdgeIDs: newBrokenEdgeIDs,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
},
|
||||||
|
|
||||||
|
isNodeInResolutionMode: (nodeID: string) => {
|
||||||
|
return get().nodesInResolutionMode.has(nodeID);
|
||||||
|
},
|
||||||
|
|
||||||
|
getNodeResolutionData: (nodeID: string) => {
|
||||||
|
return get().nodeResolutionData.get(nodeID);
|
||||||
|
},
|
||||||
|
|
||||||
|
setBrokenEdgeIDs: (nodeID: string, edgeIDs: string[]) => {
|
||||||
|
set((state) => {
|
||||||
|
const newMap = new Map(state.brokenEdgeIDs);
|
||||||
|
newMap.set(nodeID, new Set(edgeIDs));
|
||||||
|
return { brokenEdgeIDs: newMap };
|
||||||
|
});
|
||||||
|
},
|
||||||
|
|
||||||
|
removeBrokenEdgeID: (nodeID: string, edgeID: string) => {
|
||||||
|
set((state) => {
|
||||||
|
const newMap = new Map(state.brokenEdgeIDs);
|
||||||
|
const nodeSet = new Set(newMap.get(nodeID) || []);
|
||||||
|
nodeSet.delete(edgeID);
|
||||||
|
newMap.set(nodeID, nodeSet);
|
||||||
|
return { brokenEdgeIDs: newMap };
|
||||||
|
});
|
||||||
|
},
|
||||||
|
|
||||||
|
isEdgeBroken: (edgeID: string) => {
|
||||||
|
// Check across all nodes
|
||||||
|
const brokenEdgeIDs = get().brokenEdgeIDs;
|
||||||
|
for (const edgeSet of brokenEdgeIDs.values()) {
|
||||||
|
if (edgeSet.has(edgeID)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
},
|
||||||
|
|
||||||
|
clearResolutionState: () => {
|
||||||
|
set({
|
||||||
|
nodesInResolutionMode: new Set<string>(),
|
||||||
|
brokenEdgeIDs: new Map<string, Set<string>>(),
|
||||||
|
nodeResolutionData: new Map<string, NodeResolutionData>(),
|
||||||
|
});
|
||||||
|
},
|
||||||
|
|
||||||
|
// Helper functions for input renderers
|
||||||
|
isInputBroken: (nodeID: string, handleID: string) => {
|
||||||
|
const resolutionData = get().nodeResolutionData.get(nodeID);
|
||||||
|
if (!resolutionData) return false;
|
||||||
|
return resolutionData.incompatibilities.missingInputs.includes(handleID);
|
||||||
|
},
|
||||||
|
|
||||||
|
getInputTypeMismatch: (nodeID: string, handleID: string) => {
|
||||||
|
const resolutionData = get().nodeResolutionData.get(nodeID);
|
||||||
|
if (!resolutionData) return undefined;
|
||||||
|
const mismatch = resolutionData.incompatibilities.inputTypeMismatches.find(
|
||||||
|
(m) => m.name === handleID,
|
||||||
|
);
|
||||||
|
return mismatch?.newType;
|
||||||
|
},
|
||||||
}));
|
}));
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ function ErrorPageContent() {
|
|||||||
) {
|
) {
|
||||||
window.location.href = "/login";
|
window.location.href = "/login";
|
||||||
} else {
|
} else {
|
||||||
window.location.href = "/marketplace";
|
window.document.location.reload();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ export const extendedButtonVariants = cva(
|
|||||||
),
|
),
|
||||||
},
|
},
|
||||||
size: {
|
size: {
|
||||||
small: "px-3 py-2 text-sm gap-1.5 h-[2.25rem]",
|
small: "px-3 py-2 text-sm gap-1.5 h-[2.25rem] min-w-[5.5rem]",
|
||||||
large: "px-4 py-3 text-sm gap-2 h-[3.25rem]",
|
large: "px-4 py-3 text-sm gap-2 h-[3.25rem]",
|
||||||
icon: "p-3 !min-w-0",
|
icon: "p-3 !min-w-0",
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -30,8 +30,6 @@ export const FormRenderer = ({
|
|||||||
return generateUiSchemaForCustomFields(preprocessedSchema, uiSchema);
|
return generateUiSchemaForCustomFields(preprocessedSchema, uiSchema);
|
||||||
}, [preprocessedSchema, uiSchema]);
|
}, [preprocessedSchema, uiSchema]);
|
||||||
|
|
||||||
console.log("preprocessedSchema", preprocessedSchema);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={"mb-6 mt-4"}>
|
<div className={"mb-6 mt-4"}>
|
||||||
<Form
|
<Form
|
||||||
|
|||||||
@@ -2,10 +2,12 @@ import { FieldProps, getUiOptions, getWidget } from "@rjsf/utils";
|
|||||||
import { AnyOfFieldTitle } from "./components/AnyOfFieldTitle";
|
import { AnyOfFieldTitle } from "./components/AnyOfFieldTitle";
|
||||||
import { isEmpty } from "lodash";
|
import { isEmpty } from "lodash";
|
||||||
import { useAnyOfField } from "./useAnyOfField";
|
import { useAnyOfField } from "./useAnyOfField";
|
||||||
import { getHandleId, updateUiOption } from "../../helpers";
|
import { cleanUpHandleId, getHandleId, updateUiOption } from "../../helpers";
|
||||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||||
import { ANY_OF_FLAG } from "../../constants";
|
import { ANY_OF_FLAG } from "../../constants";
|
||||||
import { findCustomFieldId } from "../../registry";
|
import { findCustomFieldId } from "../../registry";
|
||||||
|
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
export const AnyOfField = (props: FieldProps) => {
|
export const AnyOfField = (props: FieldProps) => {
|
||||||
const { registry, schema } = props;
|
const { registry, schema } = props;
|
||||||
@@ -21,6 +23,8 @@ export const AnyOfField = (props: FieldProps) => {
|
|||||||
field_id,
|
field_id,
|
||||||
} = useAnyOfField(props);
|
} = useAnyOfField(props);
|
||||||
|
|
||||||
|
const isInputBroken = useNodeStore((state) => state.isInputBroken);
|
||||||
|
|
||||||
const parentCustomFieldId = findCustomFieldId(schema);
|
const parentCustomFieldId = findCustomFieldId(schema);
|
||||||
if (parentCustomFieldId) {
|
if (parentCustomFieldId) {
|
||||||
return null;
|
return null;
|
||||||
@@ -43,6 +47,7 @@ export const AnyOfField = (props: FieldProps) => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
const isHandleConnected = isInputConnected(nodeId, handleId);
|
const isHandleConnected = isInputConnected(nodeId, handleId);
|
||||||
|
const isAnyOfInputBroken = isInputBroken(nodeId, cleanUpHandleId(handleId));
|
||||||
|
|
||||||
// Now anyOf can render - custom fields if the option schema matches a custom field
|
// Now anyOf can render - custom fields if the option schema matches a custom field
|
||||||
const optionCustomFieldId = optionSchema
|
const optionCustomFieldId = optionSchema
|
||||||
@@ -78,7 +83,11 @@ export const AnyOfField = (props: FieldProps) => {
|
|||||||
registry={registry}
|
registry={registry}
|
||||||
placeholder={props.placeholder}
|
placeholder={props.placeholder}
|
||||||
autocomplete={props.autocomplete}
|
autocomplete={props.autocomplete}
|
||||||
className="-ml-1 h-[22px] w-fit gap-1 px-1 pl-2 text-xs font-medium"
|
className={cn(
|
||||||
|
"-ml-1 h-[22px] w-fit gap-1 px-1 pl-2 text-xs font-medium",
|
||||||
|
isAnyOfInputBroken &&
|
||||||
|
"border-red-500 bg-red-100 text-red-600 line-through",
|
||||||
|
)}
|
||||||
autofocus={props.autofocus}
|
autofocus={props.autofocus}
|
||||||
label=""
|
label=""
|
||||||
hideLabel={true}
|
hideLabel={true}
|
||||||
@@ -93,7 +102,7 @@ export const AnyOfField = (props: FieldProps) => {
|
|||||||
selector={selector}
|
selector={selector}
|
||||||
uiSchema={updatedUiSchema}
|
uiSchema={updatedUiSchema}
|
||||||
/>
|
/>
|
||||||
{!isHandleConnected && optionsSchemaField}
|
{!isHandleConnected && !isAnyOfInputBroken && optionsSchemaField}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import { Text } from "@/components/atoms/Text/Text";
|
|||||||
import { isOptionalType } from "../../../utils/schema-utils";
|
import { isOptionalType } from "../../../utils/schema-utils";
|
||||||
import { getTypeDisplayInfo } from "@/app/(platform)/build/components/FlowEditor/nodes/helpers";
|
import { getTypeDisplayInfo } from "@/app/(platform)/build/components/FlowEditor/nodes/helpers";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||||
|
|
||||||
interface customFieldProps extends FieldProps {
|
interface customFieldProps extends FieldProps {
|
||||||
selector: JSX.Element;
|
selector: JSX.Element;
|
||||||
@@ -51,6 +52,13 @@ export const AnyOfFieldTitle = (props: customFieldProps) => {
|
|||||||
shouldShowTypeSelector(schema) && !isArrayItem && !isHandleConnected;
|
shouldShowTypeSelector(schema) && !isArrayItem && !isHandleConnected;
|
||||||
const shoudlShowType = isHandleConnected || (isOptional && type);
|
const shoudlShowType = isHandleConnected || (isOptional && type);
|
||||||
|
|
||||||
|
const isInputBroken = useNodeStore((state) =>
|
||||||
|
state.isInputBroken(nodeId, cleanUpHandleId(uiOptions.handleId)),
|
||||||
|
);
|
||||||
|
const inputMismatch = useNodeStore((state) =>
|
||||||
|
state.getInputTypeMismatch(nodeId, cleanUpHandleId(uiOptions.handleId)),
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex items-center gap-2">
|
<div className="flex items-center gap-2">
|
||||||
<TitleFieldTemplate
|
<TitleFieldTemplate
|
||||||
@@ -62,8 +70,16 @@ export const AnyOfFieldTitle = (props: customFieldProps) => {
|
|||||||
uiSchema={uiSchema}
|
uiSchema={uiSchema}
|
||||||
/>
|
/>
|
||||||
{shoudlShowType && (
|
{shoudlShowType && (
|
||||||
<Text variant="small" className={cn("text-zinc-700", colorClass)}>
|
<Text
|
||||||
{isOptional ? `(${displayType})` : "(any)"}
|
variant="small"
|
||||||
|
className={cn(
|
||||||
|
"text-zinc-700",
|
||||||
|
isInputBroken && "line-through",
|
||||||
|
colorClass,
|
||||||
|
inputMismatch && "rounded-md bg-red-100 px-1 !text-red-500",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{isOptional ? `(${inputMismatch || displayType})` : "(any)"}
|
||||||
</Text>
|
</Text>
|
||||||
)}
|
)}
|
||||||
{shouldShowSelector && selector}
|
{shouldShowSelector && selector}
|
||||||
|
|||||||
@@ -9,8 +9,9 @@ import { Text } from "@/components/atoms/Text/Text";
|
|||||||
import { getTypeDisplayInfo } from "@/app/(platform)/build/components/FlowEditor/nodes/helpers";
|
import { getTypeDisplayInfo } from "@/app/(platform)/build/components/FlowEditor/nodes/helpers";
|
||||||
import { isAnyOfSchema } from "../../utils/schema-utils";
|
import { isAnyOfSchema } from "../../utils/schema-utils";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { isArrayItem } from "../../helpers";
|
import { cleanUpHandleId, isArrayItem } from "../../helpers";
|
||||||
import { InputNodeHandle } from "@/app/(platform)/build/components/FlowEditor/handlers/NodeHandle";
|
import { InputNodeHandle } from "@/app/(platform)/build/components/FlowEditor/handlers/NodeHandle";
|
||||||
|
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||||
|
|
||||||
export default function TitleField(props: TitleFieldProps) {
|
export default function TitleField(props: TitleFieldProps) {
|
||||||
const { id, title, required, schema, registry, uiSchema } = props;
|
const { id, title, required, schema, registry, uiSchema } = props;
|
||||||
@@ -26,6 +27,11 @@ export default function TitleField(props: TitleFieldProps) {
|
|||||||
const smallText = isArrayItemFlag || additional;
|
const smallText = isArrayItemFlag || additional;
|
||||||
|
|
||||||
const showHandle = uiOptions.showHandles ?? showHandles;
|
const showHandle = uiOptions.showHandles ?? showHandles;
|
||||||
|
|
||||||
|
const isInputBroken = useNodeStore((state) =>
|
||||||
|
state.isInputBroken(nodeId, cleanUpHandleId(uiOptions.handleId)),
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex items-center">
|
<div className="flex items-center">
|
||||||
{showHandle !== false && (
|
{showHandle !== false && (
|
||||||
@@ -34,7 +40,11 @@ export default function TitleField(props: TitleFieldProps) {
|
|||||||
<Text
|
<Text
|
||||||
variant={isArrayItemFlag ? "small" : "body"}
|
variant={isArrayItemFlag ? "small" : "body"}
|
||||||
id={id}
|
id={id}
|
||||||
className={cn("line-clamp-1", smallText && "text-sm text-zinc-700")}
|
className={cn(
|
||||||
|
"line-clamp-1",
|
||||||
|
smallText && "text-sm text-zinc-700",
|
||||||
|
isInputBroken && "text-red-500 line-through",
|
||||||
|
)}
|
||||||
>
|
>
|
||||||
{title}
|
{title}
|
||||||
</Text>
|
</Text>
|
||||||
@@ -44,7 +54,7 @@ export default function TitleField(props: TitleFieldProps) {
|
|||||||
{!isAnyOf && (
|
{!isAnyOf && (
|
||||||
<Text
|
<Text
|
||||||
variant="small"
|
variant="small"
|
||||||
className={cn("ml-2", colorClass)}
|
className={cn("ml-2", isInputBroken && "line-through", colorClass)}
|
||||||
id={description_id}
|
id={description_id}
|
||||||
>
|
>
|
||||||
({displayType})
|
({displayType})
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ export function updateUiOption<T extends Record<string, any>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
export const cleanUpHandleId = (handleId: string) => {
|
export const cleanUpHandleId = (handleId: string) => {
|
||||||
|
if (!handleId) return "";
|
||||||
|
|
||||||
let newHandleId = handleId;
|
let newHandleId = handleId;
|
||||||
if (handleId.includes(ANY_OF_FLAG)) {
|
if (handleId.includes(ANY_OF_FLAG)) {
|
||||||
newHandleId = newHandleId.replace(ANY_OF_FLAG, "");
|
newHandleId = newHandleId.replace(ANY_OF_FLAG, "");
|
||||||
|
|||||||
@@ -233,13 +233,14 @@ export default function useAgentGraph(
|
|||||||
title: `${block.name} ${node.id}`,
|
title: `${block.name} ${node.id}`,
|
||||||
inputSchema: block.inputSchema,
|
inputSchema: block.inputSchema,
|
||||||
outputSchema: block.outputSchema,
|
outputSchema: block.outputSchema,
|
||||||
|
isOutputStatic: block.staticOutput,
|
||||||
hardcodedValues: node.input_default,
|
hardcodedValues: node.input_default,
|
||||||
uiType: block.uiType,
|
uiType: block.uiType,
|
||||||
metadata: metadata,
|
metadata: metadata,
|
||||||
connections: graph.links
|
connections: graph.links
|
||||||
.filter((l) => [l.source_id, l.sink_id].includes(node.id))
|
.filter((l) => [l.source_id, l.sink_id].includes(node.id))
|
||||||
.map((link) => ({
|
.map((link) => ({
|
||||||
edge_id: formatEdgeID(link),
|
id: formatEdgeID(link),
|
||||||
source: link.source_id,
|
source: link.source_id,
|
||||||
sourceHandle: link.source_name,
|
sourceHandle: link.source_name,
|
||||||
target: link.sink_id,
|
target: link.sink_id,
|
||||||
|
|||||||
@@ -245,8 +245,8 @@ export type BlockIONullSubSchema = BlockIOSubSchemaMeta & {
|
|||||||
// At the time of writing, combined schemas only occur on the first nested level in a
|
// At the time of writing, combined schemas only occur on the first nested level in a
|
||||||
// block schema. It is typed this way to make the use of these objects less tedious.
|
// block schema. It is typed this way to make the use of these objects less tedious.
|
||||||
type BlockIOCombinedTypeSubSchema = BlockIOSubSchemaMeta & {
|
type BlockIOCombinedTypeSubSchema = BlockIOSubSchemaMeta & {
|
||||||
type: never;
|
type?: never;
|
||||||
const: never;
|
const?: never;
|
||||||
} & (
|
} & (
|
||||||
| {
|
| {
|
||||||
allOf: [BlockIOSimpleTypeSubSchema];
|
allOf: [BlockIOSimpleTypeSubSchema];
|
||||||
@@ -368,8 +368,8 @@ export type GraphMeta = {
|
|||||||
recommended_schedule_cron: string | null;
|
recommended_schedule_cron: string | null;
|
||||||
forked_from_id?: GraphID | null;
|
forked_from_id?: GraphID | null;
|
||||||
forked_from_version?: number | null;
|
forked_from_version?: number | null;
|
||||||
input_schema: GraphIOSchema;
|
input_schema: GraphInputSchema;
|
||||||
output_schema: GraphIOSchema;
|
output_schema: GraphOutputSchema;
|
||||||
credentials_input_schema: CredentialsInputSchema;
|
credentials_input_schema: CredentialsInputSchema;
|
||||||
} & (
|
} & (
|
||||||
| {
|
| {
|
||||||
@@ -385,19 +385,51 @@ export type GraphMeta = {
|
|||||||
export type GraphID = Brand<string, "GraphID">;
|
export type GraphID = Brand<string, "GraphID">;
|
||||||
|
|
||||||
/* Derived from backend/data/graph.py:Graph._generate_schema() */
|
/* Derived from backend/data/graph.py:Graph._generate_schema() */
|
||||||
export type GraphIOSchema = {
|
export type GraphInputSchema = {
|
||||||
type: "object";
|
type: "object";
|
||||||
properties: Record<string, GraphIOSubSchema>;
|
properties: Record<string, GraphInputSubSchema>;
|
||||||
required: (keyof BlockIORootSchema["properties"])[];
|
required: (keyof GraphInputSchema["properties"])[];
|
||||||
};
|
};
|
||||||
export type GraphIOSubSchema = Omit<
|
export type GraphInputSubSchema = GraphOutputSubSchema &
|
||||||
BlockIOSubSchemaMeta,
|
(
|
||||||
"placeholder" | "depends_on" | "hidden"
|
| { type?: never; default: any | null } // AgentInputBlock (generic Any type)
|
||||||
> & {
|
| { type: "string"; format: "short-text"; default: string | null } // AgentShortTextInputBlock
|
||||||
type: never; // bodge to avoid type checking hell; doesn't exist at runtime
|
| { type: "string"; format: "long-text"; default: string | null } // AgentLongTextInputBlock
|
||||||
default?: string;
|
| { type: "integer"; default: number | null } // AgentNumberInputBlock
|
||||||
|
| { type: "string"; format: "date"; default: string | null } // AgentDateInputBlock
|
||||||
|
| { type: "string"; format: "time"; default: string | null } // AgentTimeInputBlock
|
||||||
|
| { type: "string"; format: "file"; default: string | null } // AgentFileInputBlock
|
||||||
|
| { type: "string"; enum: string[]; default: string | null } // AgentDropdownInputBlock
|
||||||
|
| { type: "boolean"; default: boolean } // AgentToggleInputBlock
|
||||||
|
| {
|
||||||
|
// AgentTableInputBlock
|
||||||
|
type: "array";
|
||||||
|
format: "table";
|
||||||
|
items: {
|
||||||
|
type: "object";
|
||||||
|
properties: Record<string, { type: "string" }>;
|
||||||
|
};
|
||||||
|
default: Array<Record<string, string>> | null;
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
// AgentGoogleDriveFileInputBlock
|
||||||
|
type: "object";
|
||||||
|
format: "google-drive-picker";
|
||||||
|
google_drive_picker_config?: GoogleDrivePickerConfig;
|
||||||
|
default: GoogleDriveFile | null;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
export type GraphOutputSchema = {
|
||||||
|
type: "object";
|
||||||
|
properties: Record<string, GraphOutputSubSchema>;
|
||||||
|
required: (keyof GraphOutputSchema["properties"])[];
|
||||||
|
};
|
||||||
|
export type GraphOutputSubSchema = {
|
||||||
|
// TODO: typed outputs based on the incoming edges?
|
||||||
|
title: string;
|
||||||
|
description?: string;
|
||||||
|
advanced: boolean;
|
||||||
secret: boolean;
|
secret: boolean;
|
||||||
metadata?: any;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export type CredentialsInputSchema = {
|
export type CredentialsInputSchema = {
|
||||||
@@ -440,8 +472,8 @@ export type GraphUpdateable = Omit<
|
|||||||
is_active?: boolean;
|
is_active?: boolean;
|
||||||
nodes: NodeCreatable[];
|
nodes: NodeCreatable[];
|
||||||
links: LinkCreatable[];
|
links: LinkCreatable[];
|
||||||
input_schema?: GraphIOSchema;
|
input_schema?: GraphInputSchema;
|
||||||
output_schema?: GraphIOSchema;
|
output_schema?: GraphOutputSchema;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type GraphCreatable = _GraphCreatableInner & {
|
export type GraphCreatable = _GraphCreatableInner & {
|
||||||
@@ -497,8 +529,8 @@ export type LibraryAgent = {
|
|||||||
name: string;
|
name: string;
|
||||||
description: string;
|
description: string;
|
||||||
instructions?: string | null;
|
instructions?: string | null;
|
||||||
input_schema: GraphIOSchema;
|
input_schema: GraphInputSchema;
|
||||||
output_schema: GraphIOSchema;
|
output_schema: GraphOutputSchema;
|
||||||
credentials_input_schema: CredentialsInputSchema;
|
credentials_input_schema: CredentialsInputSchema;
|
||||||
new_output: boolean;
|
new_output: boolean;
|
||||||
can_access_graph: boolean;
|
can_access_graph: boolean;
|
||||||
|
|||||||
@@ -6,7 +6,10 @@ import { NodeDimension } from "@/app/(platform)/build/components/legacy-builder/
|
|||||||
import {
|
import {
|
||||||
BlockIOObjectSubSchema,
|
BlockIOObjectSubSchema,
|
||||||
BlockIORootSchema,
|
BlockIORootSchema,
|
||||||
|
BlockIOSubSchema,
|
||||||
Category,
|
Category,
|
||||||
|
GraphInputSubSchema,
|
||||||
|
GraphOutputSubSchema,
|
||||||
} from "@/lib/autogpt-server-api/types";
|
} from "@/lib/autogpt-server-api/types";
|
||||||
|
|
||||||
export function cn(...inputs: ClassValue[]) {
|
export function cn(...inputs: ClassValue[]) {
|
||||||
@@ -76,8 +79,8 @@ export function getTypeBgColor(type: string | null): string {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
export function getTypeColor(type: string | null): string {
|
export function getTypeColor(type: string | undefined): string {
|
||||||
if (type === null) return "#6b7280";
|
if (!type) return "#6b7280";
|
||||||
return (
|
return (
|
||||||
{
|
{
|
||||||
string: "#22c55e",
|
string: "#22c55e",
|
||||||
@@ -88,11 +91,59 @@ export function getTypeColor(type: string | null): string {
|
|||||||
array: "#6366f1",
|
array: "#6366f1",
|
||||||
null: "#6b7280",
|
null: "#6b7280",
|
||||||
any: "#6b7280",
|
any: "#6b7280",
|
||||||
"": "#6b7280",
|
|
||||||
}[type] || "#6b7280"
|
}[type] || "#6b7280"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts the effective type from a JSON schema, handling anyOf/oneOf/allOf wrappers.
|
||||||
|
* Returns the first non-null type found in the schema structure.
|
||||||
|
*/
|
||||||
|
export function getEffectiveType(
|
||||||
|
schema:
|
||||||
|
| BlockIOSubSchema
|
||||||
|
| GraphInputSubSchema
|
||||||
|
| GraphOutputSubSchema
|
||||||
|
| null
|
||||||
|
| undefined,
|
||||||
|
): string | undefined {
|
||||||
|
if (!schema) return undefined;
|
||||||
|
|
||||||
|
// Direct type property
|
||||||
|
if ("type" in schema && schema.type) {
|
||||||
|
return String(schema.type);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle allOf - typically a single-item wrapper
|
||||||
|
if (
|
||||||
|
"allOf" in schema &&
|
||||||
|
Array.isArray(schema.allOf) &&
|
||||||
|
schema.allOf.length > 0
|
||||||
|
) {
|
||||||
|
return getEffectiveType(schema.allOf[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle anyOf - e.g. [{ type: "string" }, { type: "null" }]
|
||||||
|
if ("anyOf" in schema && Array.isArray(schema.anyOf)) {
|
||||||
|
for (const item of schema.anyOf) {
|
||||||
|
if ("type" in item && item.type !== "null") {
|
||||||
|
return String(item.type);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle oneOf
|
||||||
|
if ("oneOf" in schema && Array.isArray(schema.oneOf)) {
|
||||||
|
for (const item of schema.oneOf) {
|
||||||
|
if ("type" in item && item.type !== "null") {
|
||||||
|
return String(item.type);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
export function beautifyString(name: string): string {
|
export function beautifyString(name: string): string {
|
||||||
// Regular expression to identify places to split, considering acronyms
|
// Regular expression to identify places to split, considering acronyms
|
||||||
const result = name
|
const result = name
|
||||||
|
|||||||
Reference in New Issue
Block a user