Compare commits


4 Commits

Author SHA1 Message Date
Swifty
5ac941fe2f feat(backend): add hybrid search for store listings, docs and blocks (#11721)
This PR adds hybrid search functionality combining semantic embeddings
with traditional text search for improved store listing discovery.

### Changes 🏗️

- Add `embeddings.py` - OpenAI-based embedding generation and similarity
search
- Add `hybrid_search.py` - Combines vector similarity with text matching
for better search results (see the scoring sketch after this list)
- Add `backfill_embeddings.py` - Script to generate embeddings for
existing store listings
- Update `db.py` - Integrate hybrid search into store database queries
- Update `schema.prisma` - Add embedding storage fields and indexes
- Add migrations for embedding columns and HNSW index for vector search
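
As a rough illustration of how `hybrid_search.py` combines signals, here is a minimal scoring sketch using the weights and threshold it defines (`HybridSearchWeights`, `DEFAULT_MIN_SCORE`). The per-signal scores are assumed to be pre-normalized to [0, 1]; the real implementation computes them in SQL (pgvector cosine similarity plus `ts_rank_cd`), so this is only the combining step:

```python
# Weights from HybridSearchWeights in hybrid_search.py
WEIGHTS = {
    "semantic": 0.30,    # embedding cosine similarity
    "lexical": 0.30,     # tsvector ts_rank_cd score
    "category": 0.20,    # category match boost
    "recency": 0.10,     # newer agents ranked higher
    "popularity": 0.10,  # agent usage/runs
}
MIN_SCORE = 0.20  # DEFAULT_MIN_SCORE: results below this are filtered out


def combined_score(signals: dict[str, float]) -> float:
    """Weighted sum of normalized per-signal scores."""
    return sum(WEIGHTS[name] * signals.get(name, 0.0) for name in WEIGHTS)


# A strong semantic match alone clears the threshold...
assert combined_score({"semantic": 0.9}) > MIN_SCORE
# ...while recency + popularity alone cannot exceed it.
assert combined_score({"recency": 1.0, "popularity": 1.0}) <= MIN_SCORE
```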

### Architecture Decisions 🏛️

**Fail-Fast Approach (No Silent Fallbacks)**

We explicitly chose NOT to implement graceful degradation when embedding
generation fails in the admin approval flow. (User-facing search, by contrast,
falls back to lexical-only; see `get_store_agents` in the diff.) Here's why:

**Benefits:**
- Errors surface immediately → faster fixes
- Tests verify hybrid search actually works (not just fallback)
- Consistent search quality for all users
- Forces proper infrastructure setup (API keys, database)

**Why Not Fallback:**
- Silent degradation hides production issues
- Users get inconsistent results without knowing why
- Tests can pass even when hybrid search is broken
- Reduces operational visibility

**How We Prevent Failures:**
1. Embedding generation in approval flow (db.py:1545)
2. Error logging with `logger.error` (not warning)
3. Clear error messages (ValueError explains what's wrong)
4. Comprehensive test coverage (9/9 tests passing)

If embeddings fail, it indicates a real infrastructure issue (missing
API key, OpenAI down, database issues) that needs immediate attention,
not silent degradation.
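
The approval-flow change in `db.py` (excerpted from the diff further down, lightly condensed) shows the pattern:

```python
# Inside the approval transaction in review_store_submission (db.py):
# if embedding generation fails, the raise rolls the whole approval back
# instead of approving a listing that can't be found via hybrid search.
embedding_success = await ensure_embedding(
    version_id=store_listing_version_id,
    name=store_listing_version.name,
    description=store_listing_version.description,
    sub_heading=store_listing_version.subHeading,
    categories=store_listing_version.categories or [],
    tx=tx,  # same transaction as the listing updates
)
if not embedding_success:
    raise ValueError(
        f"Failed to generate embedding for listing {store_listing_version_id}. "
        "This is likely due to OpenAI API being unavailable."
    )
```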

### Test Coverage 

**All tests passing (1625 total):**
- 9/9 hybrid_search tests (including fail-fast validation)
- 3/3 db search integration tests
- Full schema compatibility (public/platform schemas)
- Error handling verification

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Test hybrid search returns relevant results
  - [x] Test embedding generation for new listings
  - [x] Test backfill script on existing data
  - [x] Verify search performance with embeddings
  - [x] Test fail-fast behavior when embeddings unavailable

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] Configuration: Requires `openai_internal_api_key` in secrets

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 04:17:03 +00:00
Reinier van der Leer
b01ea3fcbd fix(backend/executor): Centralize increment_runs calls & make add_graph_execution more robust (#11764)
[OPEN-2946: \[Scheduler\] Error executing graph <graph_id> after 19.83s:
ClientNotConnectedError: Client is not connected to the query engine,
you must call `connect()` before attempting to query
data.](https://linear.app/autogpt/issue/OPEN-2946)

- Follow-up to #11375
  <sub>(broken `increment_runs` call)</sub>
- Follow-up to #11380
  <sub>(direct `get_graph_execution` call)</sub>

### Changes 🏗️

- Move `increment_runs` call from `scheduler._execute_graph` to
`executor.utils.add_graph_execution` so it can be made through
`DatabaseManager` (see the sketch after this list)
  - Add `increment_onboarding_runs` to `DatabaseManager`
- Remove now-redundant `increment_onboarding_runs` calls in other places
- Make `add_graph_execution` more resilient
  - Split up large try/except block
  - Fix direct `get_graph_execution` call
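
A hedged sketch of the centralized flow described above; everything except the `add_graph_execution` and `increment_onboarding_runs` names is illustrative stand-in code, not the actual implementation:

```python
import asyncio


class DatabaseManager:
    """Stand-in for the executor's service client to the database layer."""

    async def increment_onboarding_runs(self, user_id: str) -> None:
        print(f"incrementing onboarding runs for {user_id}")


db = DatabaseManager()


async def add_graph_execution(user_id: str, graph_id: str) -> dict:
    execution = {"user_id": user_id, "graph_id": graph_id}  # placeholder record
    # Centralized here so every entry point (scheduler, webhook ingress,
    # preset execution) counts the run exactly once, through DatabaseManager,
    # instead of each caller hitting the database directly.
    await db.increment_onboarding_runs(user_id)
    return execution


asyncio.run(add_graph_execution("user-1", "graph-1"))
```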

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - CI + a thorough review
2026-01-15 04:08:19 +00:00
Reinier van der Leer
3b09a94e3f feat(frontend/builder): Add sub-graph update UX (#11631)
[OPEN-2743: Ability to Update Sub-Agents in Graph (Without
Re-Adding)](https://linear.app/autogpt/issue/OPEN-2743/ability-to-update-sub-agents-in-graph-without-re-adding)

Updating sub-graphs is a cumbersome experience at the moment; this
should help. :)

Demo in Builder v2:


https://github.com/user-attachments/assets/df564f32-4d1d-432c-bb91-fe9065068360


https://github.com/user-attachments/assets/f169471a-1f22-46e9-a958-ddb72d3f65af


### Changes 🏗️

- Add sub-graph update banner with I/O incompatibility notification and resolution mode
  - Red visual indicators for broken inputs/outputs and edges
  - Update bars and tooltips show compatibility details
- Sub-agent update UI with compatibility checks, incompatibility dialog, and guided resolution workflow
- Resolution mode banner guiding users to remove incompatible connections
- Visual controls to stage/apply updates and auto-apply when broken connections are fixed

Technical:
- Builder v1: Add `CustomNode` > `IncompatibilityDialog` + `SubAgentUpdateBar` sub-components
- Builder v2: Add `SubAgentUpdateFeature` + `ResolutionModeBar` + `IncompatibleUpdateDialog` + `useSubAgentUpdateState` sub-components
  - Add `useSubAgentUpdate` hook

- Related fixes in Builder v1:
  - Fix static edges not rendering as such
  - Fix edge styling not applying
- Related fixes in Builder v2:
  - Fix excess spacing for nested node input fields

Other:
- "Retry" button in error view now reloads the page instead of
navigating to `/marketplace`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - CI for existing frontend UX flows
  - [x] Updating to a new sub-agent version with compatibility issues: UX flow works
  - [x] Updating to a new sub-agent version with *no* compatibility issues: works
  - [x] Designer approves of the look

---------

Co-authored-by: abhi1992002 <abhimanyu1992002@gmail.com>
Co-authored-by: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com>
2026-01-14 13:25:20 +00:00
Zamil Majdy
61efee4139 fix(frontend): Remove hardcoded bypass of billing feature flag (#11762)
## Summary

Fixes a critical security issue where the billing button in the settings
sidebar was always visible to all users, bypassing the
`ENABLE_PLATFORM_PAYMENT` feature flag.

## Changes 🏗️

- Removed hardcoded `|| true` condition in
`frontend/src/app/(platform)/profile/(user)/layout.tsx:32` that was
bypassing the feature flag check
- The billing button is now properly gated by the
`ENABLE_PLATFORM_PAYMENT` feature flag as intended

## Root Cause

The `|| true` was accidentally left in commit
3dbc03e488 (PR #11617 - OAuth API & Single
Sign-On) from December 19, 2025. It was likely added temporarily during
development/testing to always show the billing button, but was not
removed before merging.

## Test Plan

1. Verify feature flag is set to disabled in LaunchDarkly
2. Navigate to settings page (`/profile/settings`)
3. Confirm billing button is NOT visible in the sidebar
4. Enable feature flag in LaunchDarkly
5. Refresh page and confirm billing button IS now visible
6. Verify billing page (`/profile/credits`) is still accessible via
direct URL when feature flag is disabled

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan

Fixes SECRT-1791

2026-01-14 03:28:36 +00:00
68 changed files with 5003 additions and 988 deletions

View File

@@ -176,7 +176,7 @@ jobs:
           }
       - name: Run Database Migrations
-        run: poetry run prisma migrate dev --name updates
+        run: poetry run prisma migrate deploy
         env:
           DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
           DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}

View File

@@ -11,6 +11,7 @@ on:
       - ".github/workflows/platform-frontend-ci.yml"
       - "autogpt_platform/frontend/**"
   merge_group:
+  workflow_dispatch:

 concurrency:
   group: ${{ github.workflow }}-${{ github.event_name == 'merge_group' && format('merge-queue-{0}', github.ref) || format('{0}-{1}', github.ref, github.event.pull_request.number || github.sha) }}
@@ -151,6 +152,14 @@ jobs:
         run: |
           cp ../.env.default ../.env

+      - name: Copy backend .env and set OpenAI API key
+        run: |
+          cp ../backend/.env.default ../backend/.env
+          echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
+        env:
+          # Used by E2E test data script to generate embeddings for approved store agents
+          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3

View File

@@ -18,3 +18,4 @@ load-tests/results/
 load-tests/*.json
 load-tests/*.log
 load-tests/node_modules/*
+migrations/*/rollback*.sql

View File

@@ -1,4 +1,5 @@
 import uuid
+from unittest.mock import AsyncMock, patch

 import orjson
 import pytest
@@ -17,6 +18,17 @@ setup_test_data = setup_test_data
 setup_firecrawl_test_data = setup_firecrawl_test_data


+@pytest.fixture(scope="session", autouse=True)
+def mock_embedding_functions():
+    """Mock embedding functions for all tests to avoid database/API dependencies."""
+    with patch(
+        "backend.api.features.store.db.ensure_embedding",
+        new_callable=AsyncMock,
+        return_value=True,
+    ):
+        yield
+
+
 @pytest.mark.asyncio(scope="session")
 async def test_run_agent(setup_test_data):
     """Test that the run_agent tool successfully executes an approved agent"""

View File

@@ -35,11 +35,7 @@ from backend.data.model import (
     OAuth2Credentials,
     UserIntegrations,
 )
-from backend.data.onboarding import (
-    OnboardingStep,
-    complete_onboarding_step,
-    increment_runs,
-)
+from backend.data.onboarding import OnboardingStep, complete_onboarding_step
 from backend.data.user import get_user_integrations
 from backend.executor.utils import add_graph_execution
 from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
@@ -378,7 +374,6 @@ async def webhook_ingress_generic(
         return
     await complete_onboarding_step(user_id, OnboardingStep.TRIGGER_WEBHOOK)
-    await increment_runs(user_id)

     # Execute all triggers concurrently for better performance
     tasks = []

View File

@@ -8,7 +8,6 @@ from backend.data.execution import GraphExecutionMeta
 from backend.data.graph import get_graph
 from backend.data.integrations import get_webhook
 from backend.data.model import CredentialsMetaInput
-from backend.data.onboarding import increment_runs
 from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.integrations.webhooks import get_webhook_manager
@@ -403,8 +402,6 @@ async def execute_preset(
     merged_node_input = preset.inputs | inputs
     merged_credential_inputs = preset.credentials | credential_inputs

-    await increment_runs(user_id)
-
     return await add_graph_execution(
         user_id=user_id,
         graph_id=preset.graph_id,

View File

@@ -1,8 +1,7 @@
 import asyncio
 import logging
-import typing
 from datetime import datetime, timezone
-from typing import Literal
+from typing import Any, Literal

 import fastapi
 import prisma.enums
@@ -10,7 +9,7 @@ import prisma.errors
 import prisma.models
 import prisma.types

-from backend.data.db import query_raw_with_schema, transaction
+from backend.data.db import transaction
 from backend.data.graph import (
     GraphMeta,
     GraphModel,
@@ -30,6 +29,8 @@ from backend.util.settings import Settings
 from . import exceptions as store_exceptions
 from . import model as store_model
+from .embeddings import ensure_embedding
+from .hybrid_search import hybrid_search

 logger = logging.getLogger(__name__)
 settings = Settings()
@@ -50,105 +51,52 @@ async def get_store_agents(
     page_size: int = 20,
 ) -> store_model.StoreAgentsResponse:
     """
-    Get PUBLIC store agents from the StoreAgent view
+    Get PUBLIC store agents from the StoreAgent view.
+
+    Search behavior:
+    - With search_query: Uses hybrid search (semantic + lexical)
+    - Fallback: If embeddings unavailable, gracefully degrades to lexical-only
+    - Rationale: User-facing endpoint prioritizes availability over accuracy
+
+    Note: Admin operations (approval) use fail-fast to prevent inconsistent state.
     """
     logger.debug(
         f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
     )

+    search_used_hybrid = False
+    store_agents: list[store_model.StoreAgent] = []
+    agents: list[dict[str, Any]] = []
+    total = 0
+    total_pages = 0
+
     try:
-        # If search_query is provided, use full-text search
+        # If search_query is provided, use hybrid search (embeddings + tsvector)
         if search_query:
-            offset = (page - 1) * page_size
-
-            # Whitelist allowed order_by columns
-            ALLOWED_ORDER_BY = {
-                "rating": "rating DESC, rank DESC",
-                "runs": "runs DESC, rank DESC",
-                "name": "agent_name ASC, rank ASC",
-                "updated_at": "updated_at DESC, rank DESC",
-            }
-
-            # Validate and get order clause
-            if sorted_by and sorted_by in ALLOWED_ORDER_BY:
-                order_by_clause = ALLOWED_ORDER_BY[sorted_by]
-            else:
-                order_by_clause = "updated_at DESC, rank DESC"
-
-            # Build WHERE conditions and parameters list
-            where_parts: list[str] = []
-            params: list[typing.Any] = [search_query]  # $1 - search term
-            param_index = 2  # Start at $2 for next parameter
-
-            # Always filter for available agents
-            where_parts.append("is_available = true")
-
-            if featured:
-                where_parts.append("featured = true")
-
-            if creators and creators:
-                # Use ANY with array parameter
-                where_parts.append(f"creator_username = ANY(${param_index})")
-                params.append(creators)
-                param_index += 1
-
-            if category and category:
-                where_parts.append(f"${param_index} = ANY(categories)")
-                params.append(category)
-                param_index += 1
-
-            sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
-
-            # Add pagination params
-            params.extend([page_size, offset])
-            limit_param = f"${param_index}"
-            offset_param = f"${param_index + 1}"
-
-            # Execute full-text search query with parameterized values
-            sql_query = f"""
-                SELECT
-                    slug,
-                    agent_name,
-                    agent_image,
-                    creator_username,
-                    creator_avatar,
-                    sub_heading,
-                    description,
-                    runs,
-                    rating,
-                    categories,
-                    featured,
-                    is_available,
-                    updated_at,
-                    ts_rank_cd(search, query) AS rank
-                FROM {schema_prefix}"StoreAgent",
-                     plainto_tsquery('english', $1) AS query
-                WHERE {sql_where_clause}
-                  AND search @@ query
-                ORDER BY {order_by_clause}
-                LIMIT {limit_param} OFFSET {offset_param}
-            """
-
-            # Count query for pagination - only uses search term parameter
-            count_query = f"""
-                SELECT COUNT(*) as count
-                FROM {schema_prefix}"StoreAgent",
-                     plainto_tsquery('english', $1) AS query
-                WHERE {sql_where_clause}
-                  AND search @@ query
-            """
-
-            # Execute both queries with parameters
-            agents = await query_raw_with_schema(sql_query, *params)
-
-            # For count, use params without pagination (last 2 params)
-            count_params = params[:-2]
-            count_result = await query_raw_with_schema(count_query, *count_params)
-            total = count_result[0]["count"] if count_result else 0
+            # Try hybrid search combining semantic and lexical signals
+            # Falls back to lexical-only if OpenAI unavailable (user-facing, high SLA)
+            try:
+                agents, total = await hybrid_search(
+                    query=search_query,
+                    featured=featured,
+                    creators=creators,
+                    category=category,
+                    sorted_by="relevance",  # Use hybrid scoring for relevance
+                    page=page,
+                    page_size=page_size,
+                )
+                search_used_hybrid = True
+            except Exception as e:
+                # Log error but fall back to lexical search for better UX
+                logger.error(
+                    f"Hybrid search failed (likely OpenAI unavailable), "
+                    f"falling back to lexical search: {e}"
+                )
+                # search_used_hybrid remains False, will use fallback path below
+
+        # Convert hybrid search results (dict format) if hybrid succeeded
+        if search_used_hybrid:
             total_pages = (total + page_size - 1) // page_size

-            # Convert raw results to StoreAgent models
             store_agents: list[store_model.StoreAgent] = []
             for agent in agents:
                 try:
@@ -167,11 +115,13 @@
                     )
                     store_agents.append(store_agent)
                 except Exception as e:
-                    logger.error(f"Error parsing Store agent from search results: {e}")
+                    logger.error(
+                        f"Error parsing Store agent from hybrid search results: {e}"
+                    )
                     continue

-        else:
-            # Non-search query path (original logic)
+        if not search_used_hybrid:
+            # Fallback path - use basic search or no search
             where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
             if featured:
                 where_clause["featured"] = featured
@@ -180,6 +130,14 @@
             if category:
                 where_clause["categories"] = {"has": category}

+            # Add basic text search if search_query provided but hybrid failed
+            if search_query:
+                where_clause["OR"] = [
+                    {"agent_name": {"contains": search_query, "mode": "insensitive"}},
+                    {"sub_heading": {"contains": search_query, "mode": "insensitive"}},
+                    {"description": {"contains": search_query, "mode": "insensitive"}},
+                ]
+
             order_by = []
             if sorted_by == "rating":
                 order_by.append({"rating": "desc"})
@@ -188,7 +146,7 @@
             elif sorted_by == "name":
                 order_by.append({"agent_name": "asc"})

-            agents = await prisma.models.StoreAgent.prisma().find_many(
+            db_agents = await prisma.models.StoreAgent.prisma().find_many(
                 where=where_clause,
                 order=order_by,
                 skip=(page - 1) * page_size,
@@ -199,7 +157,7 @@
             total_pages = (total + page_size - 1) // page_size

             store_agents: list[store_model.StoreAgent] = []
-            for agent in agents:
+            for agent in db_agents:
                 try:
                     # Create the StoreAgent object safely
                     store_agent = store_model.StoreAgent(
@@ -1577,7 +1535,7 @@
         )

         # Update the AgentGraph with store listing data
-        await prisma.models.AgentGraph.prisma().update(
+        await prisma.models.AgentGraph.prisma(tx).update(
             where={
                 "graphVersionId": {
                     "id": store_listing_version.agentGraphId,
@@ -1592,6 +1550,23 @@
             },
         )

+        # Generate embedding for approved listing (blocking - admin operation)
+        # Inside transaction: if embedding fails, entire transaction rolls back
+        embedding_success = await ensure_embedding(
+            version_id=store_listing_version_id,
+            name=store_listing_version.name,
+            description=store_listing_version.description,
+            sub_heading=store_listing_version.subHeading,
+            categories=store_listing_version.categories or [],
+            tx=tx,
+        )
+        if not embedding_success:
+            raise ValueError(
+                f"Failed to generate embedding for listing {store_listing_version_id}. "
+                "This is likely due to OpenAI API being unavailable. "
+                "Please try again later or contact support if the issue persists."
+            )
+
         await prisma.models.StoreListing.prisma(tx).update(
             where={"id": store_listing_version.StoreListing.id},
             data={

View File

@@ -0,0 +1,568 @@
"""
Unified Content Embeddings Service
Handles generation and storage of OpenAI embeddings for all content types
(store listings, blocks, documentation, library agents) to enable semantic/hybrid search.
"""
import asyncio
import logging
import time
from typing import Any
import prisma
from prisma.enums import ContentType
from tiktoken import encoding_for_model
from backend.data.db import execute_raw_with_schema, query_raw_with_schema
from backend.util.clients import get_openai_client
from backend.util.json import dumps
logger = logging.getLogger(__name__)
# OpenAI embedding model configuration
EMBEDDING_MODEL = "text-embedding-3-small"
# OpenAI embedding token limit (8,191 with 1 token buffer for safety)
EMBEDDING_MAX_TOKENS = 8191
def build_searchable_text(
name: str,
description: str,
sub_heading: str,
categories: list[str],
) -> str:
"""
Build searchable text from listing version fields.
Combines relevant fields into a single string for embedding.
"""
parts = []
# Name is important - include it
if name:
parts.append(name)
# Sub-heading provides context
if sub_heading:
parts.append(sub_heading)
# Description is the main content
if description:
parts.append(description)
# Categories help with semantic matching
if categories:
parts.append(" ".join(categories))
return " ".join(parts)
async def generate_embedding(text: str) -> list[float] | None:
"""
Generate embedding for text using OpenAI API.
Returns None if embedding generation fails.
Fail-fast: no retries to maintain consistency with approval flow.
"""
try:
client = get_openai_client()
if not client:
logger.error("openai_internal_api_key not set, cannot generate embedding")
return None
# Truncate text to token limit using tiktoken
# Character-based truncation is insufficient because token ratios vary by content type
enc = encoding_for_model(EMBEDDING_MODEL)
tokens = enc.encode(text)
if len(tokens) > EMBEDDING_MAX_TOKENS:
tokens = tokens[:EMBEDDING_MAX_TOKENS]
truncated_text = enc.decode(tokens)
logger.info(
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
)
else:
truncated_text = text
start_time = time.time()
response = await client.embeddings.create(
model=EMBEDDING_MODEL,
input=truncated_text,
)
latency_ms = (time.time() - start_time) * 1000
embedding = response.data[0].embedding
logger.info(
f"Generated embedding: {len(embedding)} dims, "
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
)
return embedding
except Exception as e:
logger.error(f"Failed to generate embedding: {e}")
return None
async def store_embedding(
version_id: str,
embedding: list[float],
tx: prisma.Prisma | None = None,
) -> bool:
"""
Store embedding in the database.
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
DEPRECATED: Use ensure_embedding() instead (includes searchable_text).
"""
return await store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=version_id,
embedding=embedding,
searchable_text="", # Empty for backward compat; ensure_embedding() populates this
metadata=None,
user_id=None, # Store agents are public
tx=tx,
)
async def store_content_embedding(
content_type: ContentType,
content_id: str,
embedding: list[float],
searchable_text: str,
metadata: dict | None = None,
user_id: str | None = None,
tx: prisma.Prisma | None = None,
) -> bool:
"""
Store embedding in the unified content embeddings table.
New function for unified content embedding storage.
Uses raw SQL since Prisma doesn't natively support pgvector.
"""
try:
client = tx if tx else prisma.get_client()
# Convert embedding to PostgreSQL vector format
embedding_str = embedding_to_vector_string(embedding)
metadata_json = dumps(metadata or {})
# Upsert the embedding
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
await execute_raw_with_schema(
"""
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
)
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
ON CONFLICT ("contentType", "contentId", "userId")
DO UPDATE SET
"embedding" = $4::vector,
"searchableText" = $5,
"metadata" = $6::jsonb,
"updatedAt" = NOW()
WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
""",
content_type,
content_id,
user_id,
embedding_str,
searchable_text,
metadata_json,
client=client,
set_public_search_path=True,
)
logger.info(f"Stored embedding for {content_type}:{content_id}")
return True
except Exception as e:
logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}")
return False
async def get_embedding(version_id: str) -> dict[str, Any] | None:
"""
Retrieve embedding record for a listing version.
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
Returns dict with storeListingVersionId, embedding, timestamps or None if not found.
"""
result = await get_content_embedding(
ContentType.STORE_AGENT, version_id, user_id=None
)
if result:
# Transform to old format for backward compatibility
return {
"storeListingVersionId": result["contentId"],
"embedding": result["embedding"],
"createdAt": result["createdAt"],
"updatedAt": result["updatedAt"],
}
return None
async def get_content_embedding(
content_type: ContentType, content_id: str, user_id: str | None = None
) -> dict[str, Any] | None:
"""
Retrieve embedding record for any content type.
New function for unified content embedding retrieval.
Returns dict with contentType, contentId, embedding, timestamps or None if not found.
"""
try:
result = await query_raw_with_schema(
"""
SELECT
"contentType",
"contentId",
"userId",
"embedding"::text as "embedding",
"searchableText",
"metadata",
"createdAt",
"updatedAt"
FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
""",
content_type,
content_id,
user_id,
set_public_search_path=True,
)
if result and len(result) > 0:
return result[0]
return None
except Exception as e:
logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}")
return None
async def ensure_embedding(
version_id: str,
name: str,
description: str,
sub_heading: str,
categories: list[str],
force: bool = False,
tx: prisma.Prisma | None = None,
) -> bool:
"""
Ensure an embedding exists for the listing version.
Creates embedding if missing. Use force=True to regenerate.
Backward-compatible wrapper for store listings.
Args:
version_id: The StoreListingVersion ID
name: Agent name
description: Agent description
sub_heading: Agent sub-heading
categories: Agent categories
force: Force regeneration even if embedding exists
tx: Optional transaction client
Returns:
True if embedding exists/was created, False on failure
"""
try:
# Check if embedding already exists
if not force:
existing = await get_embedding(version_id)
if existing and existing.get("embedding"):
logger.debug(f"Embedding for version {version_id} already exists")
return True
# Build searchable text for embedding
searchable_text = build_searchable_text(
name, description, sub_heading, categories
)
# Generate new embedding
embedding = await generate_embedding(searchable_text)
if embedding is None:
logger.warning(f"Could not generate embedding for version {version_id}")
return False
# Store the embedding with metadata using new function
metadata = {
"name": name,
"subHeading": sub_heading,
"categories": categories,
}
return await store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=version_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata,
user_id=None, # Store agents are public
tx=tx,
)
except Exception as e:
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
return False
async def delete_embedding(version_id: str) -> bool:
"""
Delete embedding for a listing version.
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
Note: This is usually handled automatically by CASCADE delete,
but provided for manual cleanup if needed.
"""
return await delete_content_embedding(ContentType.STORE_AGENT, version_id)
async def delete_content_embedding(
content_type: ContentType, content_id: str, user_id: str | None = None
) -> bool:
"""
Delete embedding for any content type.
New function for unified content embedding deletion.
Note: This is usually handled automatically by CASCADE delete,
but provided for manual cleanup if needed.
Args:
content_type: The type of content (STORE_AGENT, LIBRARY_AGENT, etc.)
content_id: The unique identifier for the content
user_id: Optional user ID. For public content (STORE_AGENT, BLOCK), pass None.
For user-scoped content (LIBRARY_AGENT), pass the user's ID to avoid
deleting embeddings belonging to other users.
Returns:
True if deletion succeeded, False otherwise
"""
try:
client = prisma.get_client()
await execute_raw_with_schema(
"""
DELETE FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" = $1::{schema_prefix}"ContentType"
AND "contentId" = $2
AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
""",
content_type,
content_id,
user_id,
client=client,
)
user_str = f" (user: {user_id})" if user_id else ""
logger.info(f"Deleted embedding for {content_type}:{content_id}{user_str}")
return True
except Exception as e:
logger.error(f"Failed to delete embedding for {content_type}:{content_id}: {e}")
return False
async def get_embedding_stats() -> dict[str, Any]:
"""
Get statistics about embedding coverage.
Returns counts of:
- Total approved listing versions
- Versions with embeddings
- Versions without embeddings
"""
try:
# Count approved versions
approved_result = await query_raw_with_schema(
"""
SELECT COUNT(*) as count
FROM {schema_prefix}"StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
AND "isDeleted" = false
"""
)
total_approved = approved_result[0]["count"] if approved_result else 0
# Count versions with embeddings
embedded_result = await query_raw_with_schema(
"""
SELECT COUNT(*) as count
FROM {schema_prefix}"StoreListingVersion" slv
JOIN {schema_prefix}"UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
WHERE slv."submissionStatus" = 'APPROVED'
AND slv."isDeleted" = false
"""
)
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
return {
"total_approved": total_approved,
"with_embeddings": with_embeddings,
"without_embeddings": total_approved - with_embeddings,
"coverage_percent": (
round(with_embeddings / total_approved * 100, 1)
if total_approved > 0
else 0
),
}
except Exception as e:
logger.error(f"Failed to get embedding stats: {e}")
return {
"total_approved": 0,
"with_embeddings": 0,
"without_embeddings": 0,
"coverage_percent": 0,
"error": str(e),
}
async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]:
"""
Generate embeddings for approved listings that don't have them.
Args:
batch_size: Number of embeddings to generate in one call
Returns:
Dict with success/failure counts
"""
try:
# Find approved versions without embeddings
missing = await query_raw_with_schema(
"""
SELECT
slv.id,
slv.name,
slv.description,
slv."subHeading",
slv.categories
FROM {schema_prefix}"StoreListingVersion" slv
LEFT JOIN {schema_prefix}"UnifiedContentEmbedding" uce
ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
WHERE slv."submissionStatus" = 'APPROVED'
AND slv."isDeleted" = false
AND uce."contentId" IS NULL
LIMIT $1
""",
batch_size,
)
if not missing:
return {
"processed": 0,
"success": 0,
"failed": 0,
"message": "No missing embeddings",
}
# Process embeddings concurrently for better performance
embedding_tasks = [
ensure_embedding(
version_id=row["id"],
name=row["name"],
description=row["description"],
sub_heading=row["subHeading"],
categories=row["categories"] or [],
)
for row in missing
]
results = await asyncio.gather(*embedding_tasks, return_exceptions=True)
success = sum(1 for result in results if result is True)
failed = len(results) - success
return {
"processed": len(missing),
"success": success,
"failed": failed,
"message": f"Backfilled {success} embeddings, {failed} failed",
}
except Exception as e:
logger.error(f"Failed to backfill embeddings: {e}")
return {
"processed": 0,
"success": 0,
"failed": 0,
"error": str(e),
}
async def embed_query(query: str) -> list[float] | None:
"""
Generate embedding for a search query.
Same as generate_embedding but with clearer intent.
"""
return await generate_embedding(query)
def embedding_to_vector_string(embedding: list[float]) -> str:
"""Convert embedding list to PostgreSQL vector string format."""
return "[" + ",".join(str(x) for x in embedding) + "]"
async def ensure_content_embedding(
content_type: ContentType,
content_id: str,
searchable_text: str,
metadata: dict | None = None,
user_id: str | None = None,
force: bool = False,
tx: prisma.Prisma | None = None,
) -> bool:
"""
Ensure an embedding exists for any content type.
Generic function for creating embeddings for store agents, blocks, docs, etc.
Args:
content_type: ContentType enum value (STORE_AGENT, BLOCK, etc.)
content_id: Unique identifier for the content
searchable_text: Combined text for embedding generation
metadata: Optional metadata to store with embedding
force: Force regeneration even if embedding exists
tx: Optional transaction client
Returns:
True if embedding exists/was created, False on failure
"""
try:
# Check if embedding already exists
if not force:
existing = await get_content_embedding(content_type, content_id, user_id)
if existing and existing.get("embedding"):
logger.debug(
f"Embedding for {content_type}:{content_id} already exists"
)
return True
# Generate new embedding
embedding = await generate_embedding(searchable_text)
if embedding is None:
logger.warning(
f"Could not generate embedding for {content_type}:{content_id}"
)
return False
# Store the embedding
return await store_content_embedding(
content_type=content_type,
content_id=content_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata or {},
user_id=user_id,
tx=tx,
)
except Exception as e:
logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
return False

View File

@@ -0,0 +1,329 @@
"""
Integration tests for embeddings with schema handling.
These tests verify that embeddings operations work correctly across different database schemas.
"""
from unittest.mock import AsyncMock, patch
import pytest
from prisma.enums import ContentType
from backend.api.features.store import embeddings
# Schema prefix tests removed - functionality moved to db.raw_with_schema() helper
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_store_content_embedding_with_schema():
"""Test storing embeddings with proper schema handling."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch("prisma.get_client") as mock_get_client:
mock_client = AsyncMock()
mock_get_client.return_value = mock_client
result = await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
embedding=[0.1] * 1536,
searchable_text="test text",
metadata={"test": "data"},
user_id=None,
)
# Verify the query was called
assert mock_client.execute_raw.called
# Get the SQL query that was executed
call_args = mock_client.execute_raw.call_args
sql_query = call_args[0][0]
# Verify schema prefix is in the query
assert '"platform"."UnifiedContentEmbedding"' in sql_query
# Verify result
assert result is True
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_get_content_embedding_with_schema():
"""Test retrieving embeddings with proper schema handling."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch("prisma.get_client") as mock_get_client:
mock_client = AsyncMock()
mock_client.query_raw.return_value = [
{
"contentType": "STORE_AGENT",
"contentId": "test-id",
"userId": None,
"embedding": "[0.1, 0.2]",
"searchableText": "test",
"metadata": {},
"createdAt": "2024-01-01",
"updatedAt": "2024-01-01",
}
]
mock_get_client.return_value = mock_client
result = await embeddings.get_content_embedding(
ContentType.STORE_AGENT,
"test-id",
user_id=None,
)
# Verify the query was called
assert mock_client.query_raw.called
# Get the SQL query that was executed
call_args = mock_client.query_raw.call_args
sql_query = call_args[0][0]
# Verify schema prefix is in the query
assert '"platform"."UnifiedContentEmbedding"' in sql_query
# Verify result
assert result is not None
assert result["contentId"] == "test-id"
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_delete_content_embedding_with_schema():
"""Test deleting embeddings with proper schema handling."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch("prisma.get_client") as mock_get_client:
mock_client = AsyncMock()
mock_get_client.return_value = mock_client
result = await embeddings.delete_content_embedding(
ContentType.STORE_AGENT,
"test-id",
)
# Verify the query was called
assert mock_client.execute_raw.called
# Get the SQL query that was executed
call_args = mock_client.execute_raw.call_args
sql_query = call_args[0][0]
# Verify schema prefix is in the query
assert '"platform"."UnifiedContentEmbedding"' in sql_query
# Verify result
assert result is True
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_get_embedding_stats_with_schema():
"""Test embedding statistics with proper schema handling."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch("prisma.get_client") as mock_get_client:
mock_client = AsyncMock()
# Mock both query results
mock_client.query_raw.side_effect = [
[{"count": 100}], # total_approved
[{"count": 80}], # with_embeddings
]
mock_get_client.return_value = mock_client
result = await embeddings.get_embedding_stats()
# Verify both queries were called
assert mock_client.query_raw.call_count == 2
# Get both SQL queries
first_call = mock_client.query_raw.call_args_list[0]
second_call = mock_client.query_raw.call_args_list[1]
first_sql = first_call[0][0]
second_sql = second_call[0][0]
# Verify schema prefix in both queries
assert '"platform"."StoreListingVersion"' in first_sql
assert '"platform"."StoreListingVersion"' in second_sql
assert '"platform"."UnifiedContentEmbedding"' in second_sql
# Verify results
assert result["total_approved"] == 100
assert result["with_embeddings"] == 80
assert result["without_embeddings"] == 20
assert result["coverage_percent"] == 80.0
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_backfill_missing_embeddings_with_schema():
"""Test backfilling embeddings with proper schema handling."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch("prisma.get_client") as mock_get_client:
mock_client = AsyncMock()
# Mock missing embeddings query
mock_client.query_raw.return_value = [
{
"id": "version-1",
"name": "Test Agent",
"description": "Test description",
"subHeading": "Test heading",
"categories": ["test"],
}
]
mock_get_client.return_value = mock_client
with patch(
"backend.api.features.store.embeddings.ensure_embedding"
) as mock_ensure:
mock_ensure.return_value = True
result = await embeddings.backfill_missing_embeddings(batch_size=10)
# Verify the query was called
assert mock_client.query_raw.called
# Get the SQL query
call_args = mock_client.query_raw.call_args
sql_query = call_args[0][0]
# Verify schema prefix in query
assert '"platform"."StoreListingVersion"' in sql_query
assert '"platform"."UnifiedContentEmbedding"' in sql_query
# Verify ensure_embedding was called
assert mock_ensure.called
# Verify results
assert result["processed"] == 1
assert result["success"] == 1
assert result["failed"] == 0
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_ensure_content_embedding_with_schema():
"""Test ensuring embeddings exist with proper schema handling."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch(
"backend.api.features.store.embeddings.get_content_embedding"
) as mock_get:
# Simulate no existing embedding
mock_get.return_value = None
with patch(
"backend.api.features.store.embeddings.generate_embedding"
) as mock_generate:
mock_generate.return_value = [0.1] * 1536
with patch(
"backend.api.features.store.embeddings.store_content_embedding"
) as mock_store:
mock_store.return_value = True
result = await embeddings.ensure_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
searchable_text="test text",
metadata={"test": "data"},
user_id=None,
force=False,
)
# Verify the flow
assert mock_get.called
assert mock_generate.called
assert mock_store.called
assert result is True
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_backward_compatibility_store_embedding():
"""Test backward compatibility wrapper for store_embedding."""
with patch(
"backend.api.features.store.embeddings.store_content_embedding"
) as mock_store:
mock_store.return_value = True
result = await embeddings.store_embedding(
version_id="test-version-id",
embedding=[0.1] * 1536,
tx=None,
)
# Verify it calls the new function with correct parameters
assert mock_store.called
call_args = mock_store.call_args
assert call_args[1]["content_type"] == ContentType.STORE_AGENT
assert call_args[1]["content_id"] == "test-version-id"
assert call_args[1]["user_id"] is None
assert result is True
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_backward_compatibility_get_embedding():
"""Test backward compatibility wrapper for get_embedding."""
with patch(
"backend.api.features.store.embeddings.get_content_embedding"
) as mock_get:
mock_get.return_value = {
"contentType": "STORE_AGENT",
"contentId": "test-version-id",
"embedding": "[0.1, 0.2]",
"createdAt": "2024-01-01",
"updatedAt": "2024-01-01",
}
result = await embeddings.get_embedding("test-version-id")
# Verify it calls the new function
assert mock_get.called
# Verify it transforms to old format
assert result is not None
assert result["storeListingVersionId"] == "test-version-id"
assert "embedding" in result
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_schema_handling_error_cases():
"""Test error handling in schema-aware operations."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch("prisma.get_client") as mock_get_client:
mock_client = AsyncMock()
mock_client.execute_raw.side_effect = Exception("Database error")
mock_get_client.return_value = mock_client
result = await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
embedding=[0.1] * 1536,
searchable_text="test",
metadata=None,
user_id=None,
)
# Should return False on error, not raise
assert result is False
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])

View File

@@ -0,0 +1,387 @@
from unittest.mock import AsyncMock, MagicMock, patch
import prisma
import pytest
from prisma import Prisma
from prisma.enums import ContentType
from backend.api.features.store import embeddings
@pytest.fixture(autouse=True)
async def setup_prisma():
"""Setup Prisma client for tests."""
try:
Prisma()
except prisma.errors.ClientAlreadyRegisteredError:
pass
yield
@pytest.mark.asyncio(loop_scope="session")
async def test_build_searchable_text():
"""Test searchable text building from listing fields."""
result = embeddings.build_searchable_text(
name="AI Assistant",
description="A helpful AI assistant for productivity",
sub_heading="Boost your productivity",
categories=["AI", "Productivity"],
)
expected = "AI Assistant Boost your productivity A helpful AI assistant for productivity AI Productivity"
assert result == expected
@pytest.mark.asyncio(loop_scope="session")
async def test_build_searchable_text_empty_fields():
"""Test searchable text building with empty fields."""
result = embeddings.build_searchable_text(
name="", description="Test description", sub_heading="", categories=[]
)
assert result == "Test description"
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_success():
"""Test successful embedding generation."""
# Mock OpenAI response
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.data = [MagicMock()]
mock_response.data[0].embedding = [0.1, 0.2, 0.3] * 512 # 1536 dimensions
# Use AsyncMock for async embeddings.create method
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
# Patch at the point of use in embeddings.py
with patch(
"backend.api.features.store.embeddings.get_openai_client"
) as mock_get_client:
mock_get_client.return_value = mock_client
result = await embeddings.generate_embedding("test text")
assert result is not None
assert len(result) == 1536
assert result[0] == 0.1
mock_client.embeddings.create.assert_called_once_with(
model="text-embedding-3-small", input="test text"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_no_api_key():
"""Test embedding generation without API key."""
# Patch at the point of use in embeddings.py
with patch(
"backend.api.features.store.embeddings.get_openai_client"
) as mock_get_client:
mock_get_client.return_value = None
result = await embeddings.generate_embedding("test text")
assert result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_api_error():
"""Test embedding generation with API error."""
mock_client = MagicMock()
mock_client.embeddings.create = AsyncMock(side_effect=Exception("API Error"))
# Patch at the point of use in embeddings.py
with patch(
"backend.api.features.store.embeddings.get_openai_client"
) as mock_get_client:
mock_get_client.return_value = mock_client
result = await embeddings.generate_embedding("test text")
assert result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_text_truncation():
"""Test that long text is properly truncated using tiktoken."""
from tiktoken import encoding_for_model
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.data = [MagicMock()]
mock_response.data[0].embedding = [0.1] * 1536
# Use AsyncMock for async embeddings.create method
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
# Patch at the point of use in embeddings.py
with patch(
"backend.api.features.store.embeddings.get_openai_client"
) as mock_get_client:
mock_get_client.return_value = mock_client
# Create text that will exceed 8191 tokens
# Use varied characters to ensure token-heavy text: each word is ~1 token
words = [f"word{i}" for i in range(10000)]
long_text = " ".join(words) # ~10000 tokens
await embeddings.generate_embedding(long_text)
# Verify text was truncated to 8191 tokens
call_args = mock_client.embeddings.create.call_args
truncated_text = call_args.kwargs["input"]
# Count actual tokens in truncated text
enc = encoding_for_model("text-embedding-3-small")
actual_tokens = len(enc.encode(truncated_text))
# Should be at or just under 8191 tokens
assert actual_tokens <= 8191
# Should be close to the limit (not over-truncated)
assert actual_tokens >= 8100
@pytest.mark.asyncio(loop_scope="session")
async def test_store_embedding_success(mocker):
"""Test successful embedding storage."""
mock_client = mocker.AsyncMock()
mock_client.execute_raw = mocker.AsyncMock()
embedding = [0.1, 0.2, 0.3]
result = await embeddings.store_embedding(
version_id="test-version-id", embedding=embedding, tx=mock_client
)
assert result is True
# execute_raw is called twice: once for SET search_path, once for INSERT
assert mock_client.execute_raw.call_count == 2
# First call: SET search_path
first_call_args = mock_client.execute_raw.call_args_list[0][0]
assert "SET search_path" in first_call_args[0]
# Second call: INSERT query with the actual data
second_call_args = mock_client.execute_raw.call_args_list[1][0]
assert "test-version-id" in second_call_args
assert "[0.1,0.2,0.3]" in second_call_args
assert None in second_call_args # userId should be None for store agents
@pytest.mark.asyncio(loop_scope="session")
async def test_store_embedding_database_error(mocker):
"""Test embedding storage with database error."""
mock_client = mocker.AsyncMock()
mock_client.execute_raw.side_effect = Exception("Database error")
embedding = [0.1, 0.2, 0.3]
result = await embeddings.store_embedding(
version_id="test-version-id", embedding=embedding, tx=mock_client
)
assert result is False
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_success():
"""Test successful embedding retrieval."""
mock_result = [
{
"contentType": "STORE_AGENT",
"contentId": "test-version-id",
"userId": None,
"embedding": "[0.1,0.2,0.3]",
"searchableText": "Test text",
"metadata": {},
"createdAt": "2024-01-01T00:00:00Z",
"updatedAt": "2024-01-01T00:00:00Z",
}
]
with patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=mock_result,
):
result = await embeddings.get_embedding("test-version-id")
assert result is not None
assert result["storeListingVersionId"] == "test-version-id"
assert result["embedding"] == "[0.1,0.2,0.3]"
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_not_found():
"""Test embedding retrieval when not found."""
with patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=[],
):
result = await embeddings.get_embedding("test-version-id")
assert result is None
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
@patch("backend.api.features.store.embeddings.store_embedding")
@patch("backend.api.features.store.embeddings.get_embedding")
async def test_ensure_embedding_already_exists(mock_get, mock_store, mock_generate):
"""Test ensure_embedding when embedding already exists."""
mock_get.return_value = {"embedding": "[0.1,0.2,0.3]"}
result = await embeddings.ensure_embedding(
version_id="test-id",
name="Test",
description="Test description",
sub_heading="Test heading",
categories=["test"],
)
assert result is True
mock_generate.assert_not_called()
mock_store.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
@patch("backend.api.features.store.embeddings.store_content_embedding")
@patch("backend.api.features.store.embeddings.get_embedding")
async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate):
"""Test ensure_embedding creating new embedding."""
mock_get.return_value = None
mock_generate.return_value = [0.1, 0.2, 0.3]
mock_store.return_value = True
result = await embeddings.ensure_embedding(
version_id="test-id",
name="Test",
description="Test description",
sub_heading="Test heading",
categories=["test"],
)
assert result is True
mock_generate.assert_called_once_with("Test Test heading Test description test")
mock_store.assert_called_once_with(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
embedding=[0.1, 0.2, 0.3],
searchable_text="Test Test heading Test description test",
metadata={"name": "Test", "subHeading": "Test heading", "categories": ["test"]},
user_id=None,
tx=None,
)
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
@patch("backend.api.features.store.embeddings.get_embedding")
async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
"""Test ensure_embedding when generation fails."""
mock_get.return_value = None
mock_generate.return_value = None
result = await embeddings.ensure_embedding(
version_id="test-id",
name="Test",
description="Test description",
sub_heading="Test heading",
categories=["test"],
)
assert result is False
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_stats():
"""Test embedding statistics retrieval."""
# Mock approved count query and embedded count query
mock_approved_result = [{"count": 100}]
mock_embedded_result = [{"count": 75}]
with patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
side_effect=[mock_approved_result, mock_embedded_result],
):
result = await embeddings.get_embedding_stats()
assert result["total_approved"] == 100
assert result["with_embeddings"] == 75
assert result["without_embeddings"] == 25
assert result["coverage_percent"] == 75.0
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.ensure_embedding")
async def test_backfill_missing_embeddings_success(mock_ensure):
"""Test backfill with successful embedding generation."""
# Mock missing embeddings query
mock_missing = [
{
"id": "version-1",
"name": "Agent 1",
"description": "Description 1",
"subHeading": "Heading 1",
"categories": ["AI"],
},
{
"id": "version-2",
"name": "Agent 2",
"description": "Description 2",
"subHeading": "Heading 2",
"categories": ["Productivity"],
},
]
# Mock ensure_embedding to succeed for first, fail for second
mock_ensure.side_effect = [True, False]
with patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=mock_missing,
):
result = await embeddings.backfill_missing_embeddings(batch_size=5)
assert result["processed"] == 2
assert result["success"] == 1
assert result["failed"] == 1
assert mock_ensure.call_count == 2
@pytest.mark.asyncio(loop_scope="session")
async def test_backfill_missing_embeddings_no_missing():
"""Test backfill when no embeddings are missing."""
with patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=[],
):
result = await embeddings.backfill_missing_embeddings(batch_size=5)
assert result["processed"] == 0
assert result["success"] == 0
assert result["failed"] == 0
assert result["message"] == "No missing embeddings"
@pytest.mark.asyncio(loop_scope="session")
async def test_embedding_to_vector_string():
"""Test embedding to PostgreSQL vector string conversion."""
embedding = [0.1, 0.2, 0.3, -0.4]
result = embeddings.embedding_to_vector_string(embedding)
assert result == "[0.1,0.2,0.3,-0.4]"
@pytest.mark.asyncio(loop_scope="session")
async def test_embed_query():
"""Test embed_query function (alias for generate_embedding)."""
with patch(
"backend.api.features.store.embeddings.generate_embedding"
) as mock_generate:
mock_generate.return_value = [0.1, 0.2, 0.3]
result = await embeddings.embed_query("test query")
assert result == [0.1, 0.2, 0.3]
mock_generate.assert_called_once_with("test query")

View File

@@ -0,0 +1,393 @@
"""
Hybrid Search for Store Agents
Combines semantic (embedding) search with lexical (tsvector) search
for improved relevance in marketplace agent discovery.
"""
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Literal
from backend.api.features.store.embeddings import (
embed_query,
embedding_to_vector_string,
)
from backend.data.db import query_raw_with_schema
logger = logging.getLogger(__name__)
@dataclass
class HybridSearchWeights:
"""Weights for combining search signals."""
semantic: float = 0.30 # Embedding cosine similarity
lexical: float = 0.30 # tsvector ts_rank_cd score
category: float = 0.20 # Category match boost
recency: float = 0.10 # Newer agents ranked higher
popularity: float = 0.10 # Agent usage/runs (PageRank-like)
def __post_init__(self):
"""Validate weights are non-negative and sum to approximately 1.0."""
total = (
self.semantic
+ self.lexical
+ self.category
+ self.recency
+ self.popularity
)
if any(
w < 0
for w in [
self.semantic,
self.lexical,
self.category,
self.recency,
self.popularity,
]
):
raise ValueError("All weights must be non-negative")
if not (0.99 <= total <= 1.01):
raise ValueError(f"Weights must sum to ~1.0, got {total:.3f}")
DEFAULT_WEIGHTS = HybridSearchWeights()
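# Illustrative examples of the validator above:
#   HybridSearchWeights(semantic=0.5, lexical=0.3, category=0.1, recency=0.1, popularity=0.0)
# is accepted (sums to 1.0), while
#   HybridSearchWeights(semantic=0.5, lexical=0.5, category=0.2, recency=0.0, popularity=0.0)
# raises ValueError (sums to 1.2, outside the 0.99-1.01 tolerance).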
# Minimum relevance score threshold - agents below this are filtered out
# With weights (0.30 semantic + 0.30 lexical + 0.20 category + 0.10 recency + 0.10 popularity):
# - 0.20 requires roughly a 67% semantic match on its own (0.30 * 0.67 ~= 0.20) or a comparably strong lexical match
# - Ensures only genuinely relevant results are returned
# - Recency/popularity alone (0.10 each) won't pass the threshold
DEFAULT_MIN_SCORE = 0.20
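# Illustrative arithmetic with the default weights: a candidate scoring
# 0.70 semantic, 0.20 lexical, no category match, 0.50 recency, 0.10 popularity
# combines to 0.30*0.70 + 0.30*0.20 + 0.20*0.0 + 0.10*0.50 + 0.10*0.10
# = 0.21 + 0.06 + 0.00 + 0.05 + 0.01 = 0.33, which clears the 0.20 threshold.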
@dataclass
class HybridSearchResult:
"""A single search result with score breakdown."""
slug: str
agent_name: str
agent_image: str
creator_username: str
creator_avatar: str
sub_heading: str
description: str
runs: int
rating: float
categories: list[str]
featured: bool
is_available: bool
updated_at: datetime
# Score breakdown (for debugging/tuning)
combined_score: float
semantic_score: float = 0.0
lexical_score: float = 0.0
category_score: float = 0.0
recency_score: float = 0.0
popularity_score: float = 0.0
async def hybrid_search(
query: str,
featured: bool = False,
creators: list[str] | None = None,
category: str | None = None,
sorted_by: (
Literal["relevance", "rating", "runs", "name", "updated_at"] | None
) = None,
page: int = 1,
page_size: int = 20,
weights: HybridSearchWeights | None = None,
min_score: float | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""
Perform hybrid search combining semantic and lexical signals.
Args:
query: Search query string
featured: Filter for featured agents only
creators: Filter by creator usernames
category: Filter by category
sorted_by: Sort order (relevance uses hybrid scoring)
page: Page number (1-indexed)
page_size: Results per page
weights: Custom weights for search signals
min_score: Minimum relevance score threshold (0-1). Results below
this score are filtered out. Defaults to DEFAULT_MIN_SCORE.
Returns:
Tuple of (results list, total count). Returns empty list if no
results meet the minimum relevance threshold.
"""
# Validate inputs
query = query.strip()
if not query:
return [], 0 # Empty query returns no results
if page < 1:
page = 1
if page_size < 1:
page_size = 1
if page_size > 100: # Cap at reasonable limit to prevent performance issues
page_size = 100
if weights is None:
weights = DEFAULT_WEIGHTS
if min_score is None:
min_score = DEFAULT_MIN_SCORE
offset = (page - 1) * page_size
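# e.g. page=3, page_size=20 -> offset 40 (rows 41-60 of the ranked results)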
# Generate query embedding
query_embedding = await embed_query(query)
# Build WHERE clause conditions
where_parts: list[str] = ["sa.is_available = true"]
params: list[Any] = []
param_index = 1
# Add search query for lexical matching
params.append(query)
query_param = f"${param_index}"
param_index += 1
# Add lowercased query for category matching
params.append(query.lower())
query_lower_param = f"${param_index}"
param_index += 1
if featured:
where_parts.append("sa.featured = true")
if creators:
where_parts.append(f"sa.creator_username = ANY(${param_index})")
params.append(creators)
param_index += 1
if category:
where_parts.append(f"${param_index} = ANY(sa.categories)")
params.append(category)
param_index += 1
# Safe: where_parts only contains hardcoded strings with $N parameter placeholders
# No user input is concatenated directly into the SQL string
where_clause = " AND ".join(where_parts)
# Embedding is required for hybrid search - fail fast if unavailable
if not query_embedding:
# Log detailed error server-side
logger.error(
"Failed to generate query embedding. "
"Check that openai_internal_api_key is configured and OpenAI API is accessible."
)
# Raise generic error to client
raise ValueError("Search service temporarily unavailable")
# Add embedding parameter
embedding_str = embedding_to_vector_string(query_embedding)
params.append(embedding_str)
embedding_param = f"${param_index}"
param_index += 1
# Add weight parameters for SQL calculation
params.append(weights.semantic)
weight_semantic_param = f"${param_index}"
param_index += 1
params.append(weights.lexical)
weight_lexical_param = f"${param_index}"
param_index += 1
params.append(weights.category)
weight_category_param = f"${param_index}"
param_index += 1
params.append(weights.recency)
weight_recency_param = f"${param_index}"
param_index += 1
params.append(weights.popularity)
weight_popularity_param = f"${param_index}"
param_index += 1
# Add min_score parameter
params.append(min_score)
min_score_param = f"${param_index}"
param_index += 1
# Optimized hybrid search query:
# 1. Direct join to UnifiedContentEmbedding via contentId=storeListingVersionId (no redundant JOINs)
# 2. UNION approach (deduplicates agents matching both branches)
# 3. COUNT(*) OVER() to get total count in single query
# 4. Optimized category matching with EXISTS + unnest
# 5. Pre-calculated max values for lexical and popularity normalization
# 6. Simplified recency calculation with linear decay
# 7. Logarithmic popularity scaling to prevent viral agents from dominating
sql_query = f"""
WITH candidates AS (
-- Lexical matches (uses GIN index on search column)
SELECT sa."storeListingVersionId"
FROM {{schema_prefix}}"StoreAgent" sa
WHERE {where_clause}
AND sa.search @@ plainto_tsquery('english', {query_param})
UNION
-- Semantic matches (uses HNSW index on embedding with KNN)
SELECT "storeListingVersionId"
FROM (
SELECT sa."storeListingVersionId", uce.embedding
FROM {{schema_prefix}}"StoreAgent" sa
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
WHERE {where_clause}
ORDER BY uce.embedding <=> {embedding_param}::vector
LIMIT 200
) semantic_results
),
search_scores AS (
SELECT
sa.slug,
sa.agent_name,
sa.agent_image,
sa.creator_username,
sa.creator_avatar,
sa.sub_heading,
sa.description,
sa.runs,
sa.rating,
sa.categories,
sa.featured,
sa.is_available,
sa.updated_at,
-- Semantic score: cosine similarity (1 - distance)
COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
-- Lexical score: ts_rank_cd (will be normalized later)
COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
-- Category match: optimized with unnest for better performance
CASE
WHEN EXISTS (
SELECT 1 FROM unnest(sa.categories) cat
WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%'
)
THEN 1.0
ELSE 0.0
END as category_score,
-- Recency score: linear decay over 90 days (simpler than exponential)
GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score,
-- Popularity raw: agent runs count (will be normalized with log scaling)
sa.runs as popularity_raw
FROM candidates c
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON c."storeListingVersionId" = sa."storeListingVersionId"
LEFT JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
),
max_lexical AS (
SELECT MAX(lexical_raw) as max_val FROM search_scores
),
max_popularity AS (
SELECT MAX(popularity_raw) as max_val FROM search_scores
),
normalized AS (
SELECT
ss.*,
-- Normalize lexical score by pre-calculated max
CASE
WHEN ml.max_val > 0
THEN ss.lexical_raw / ml.max_val
ELSE 0
END as lexical_score,
-- Normalize popularity with logarithmic scaling to prevent viral agents from dominating
-- LOG(1 + runs) / LOG(1 + max_runs) ensures score is 0-1 range
CASE
WHEN mp.max_val > 0 AND ss.popularity_raw > 0
THEN LN(1 + ss.popularity_raw) / LN(1 + mp.max_val)
ELSE 0
END as popularity_score
FROM search_scores ss
CROSS JOIN max_lexical ml
CROSS JOIN max_popularity mp
),
scored AS (
SELECT
slug,
agent_name,
agent_image,
creator_username,
creator_avatar,
sub_heading,
description,
runs,
rating,
categories,
featured,
is_available,
updated_at,
semantic_score,
lexical_score,
category_score,
recency_score,
popularity_score,
(
{weight_semantic_param} * semantic_score +
{weight_lexical_param} * lexical_score +
{weight_category_param} * category_score +
{weight_recency_param} * recency_score +
{weight_popularity_param} * popularity_score
) as combined_score
FROM normalized
),
filtered AS (
SELECT
*,
COUNT(*) OVER () as total_count
FROM scored
WHERE combined_score >= {min_score_param}
)
SELECT * FROM filtered
ORDER BY combined_score DESC
LIMIT ${param_index} OFFSET ${param_index + 1}
"""
# Add pagination params
params.extend([page_size, offset])
# Execute search query - includes total_count via window function
results = await query_raw_with_schema(
sql_query, *params, set_public_search_path=True
)
# Extract total count from first result (all rows have same count)
total = results[0]["total_count"] if results else 0
# Remove total_count from results before returning
for result in results:
result.pop("total_count", None)
# Log without sensitive query content
logger.info(f"Hybrid search: {len(results)} results, {total} total")
return results, total
async def hybrid_search_simple(
query: str,
page: int = 1,
page_size: int = 20,
) -> tuple[list[dict[str, Any]], int]:
"""
Simplified hybrid search for common use cases.
Uses default weights and no filters.
"""
return await hybrid_search(
query=query,
page=page,
page_size=page_size,
)
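
A minimal caller sketch for the API above (illustrative, not part of the module): it assumes an async context with the database and openai_internal_api_key configured, and the weight values are arbitrary but sum to 1.0 as the validator requires.

async def search_marketplace() -> None:
    # Custom weights are optional; omit them to use DEFAULT_WEIGHTS.
    results, total = await hybrid_search(
        query="email summarizer",
        category="productivity",
        weights=HybridSearchWeights(
            semantic=0.4, lexical=0.3, category=0.1, recency=0.1, popularity=0.1
        ),
        page=1,
        page_size=20,
    )
    for row in results:
        print(row["slug"], round(row["combined_score"], 3))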

View File

@@ -0,0 +1,334 @@
"""
Integration tests for hybrid search with schema handling.
These tests verify that hybrid search works correctly across different database schemas.
"""
from unittest.mock import patch
import pytest
from backend.api.features.store.hybrid_search import HybridSearchWeights, hybrid_search
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_schema_handling():
"""Test that hybrid search correctly handles database schema prefixes."""
# Test with a mock query to ensure schema handling works
query = "test agent"
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
# Mock the query result
mock_query.return_value = [
{
"slug": "test/agent",
"agent_name": "Test Agent",
"agent_image": "test.png",
"creator_username": "test",
"creator_avatar": "avatar.png",
"sub_heading": "Test sub-heading",
"description": "Test description",
"runs": 10,
"rating": 4.5,
"categories": ["test"],
"featured": False,
"is_available": True,
"updated_at": "2024-01-01T00:00:00Z",
"combined_score": 0.8,
"semantic_score": 0.7,
"lexical_score": 0.6,
"category_score": 0.5,
"recency_score": 0.4,
"total_count": 1,
}
]
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * 1536 # Mock embedding
results, total = await hybrid_search(
query=query,
page=1,
page_size=20,
)
# Verify the query was called
assert mock_query.called
# Verify the SQL template uses schema_prefix placeholder
call_args = mock_query.call_args
sql_template = call_args[0][0]
assert "{schema_prefix}" in sql_template
# Verify results
assert len(results) == 1
assert total == 1
assert results[0]["slug"] == "test/agent"
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_public_schema():
"""Test hybrid search when using public schema (no prefix needed)."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "public"
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
mock_query.return_value = []
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * 1536
results, total = await hybrid_search(
query="test",
page=1,
page_size=20,
)
# Verify the mock was set up correctly
assert mock_schema.return_value == "public"
# Results should work even with empty results
assert results == []
assert total == 0
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_custom_schema():
"""Test hybrid search when using custom schema (e.g., 'platform')."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
mock_query.return_value = []
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * 1536
results, total = await hybrid_search(
query="test",
page=1,
page_size=20,
)
# Verify the mock was set up correctly
assert mock_schema.return_value == "platform"
assert results == []
assert total == 0
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_without_embeddings():
"""Test hybrid search fails fast when embeddings are unavailable."""
# Patch where the function is used, not where it's defined
with patch("backend.api.features.store.hybrid_search.embed_query") as mock_embed:
# Simulate embedding failure
mock_embed.return_value = None
# Should raise ValueError with helpful message
with pytest.raises(ValueError) as exc_info:
await hybrid_search(
query="test",
page=1,
page_size=20,
)
# Verify error message is generic (doesn't leak implementation details)
assert "Search service temporarily unavailable" in str(exc_info.value)
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_filters():
"""Test hybrid search with various filters."""
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
mock_query.return_value = []
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * 1536
# Test with featured filter
results, total = await hybrid_search(
query="test",
featured=True,
creators=["user1", "user2"],
category="productivity",
page=1,
page_size=10,
)
# Verify filters were applied in the query
call_args = mock_query.call_args
params = call_args[0][1:] # Skip SQL template
# Should have query, query_lower, creators array, category
assert len(params) >= 4
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_weights():
"""Test hybrid search with custom weights."""
custom_weights = HybridSearchWeights(
semantic=0.5,
lexical=0.3,
category=0.1,
recency=0.1,
popularity=0.0,
)
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
mock_query.return_value = []
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * 1536
results, total = await hybrid_search(
query="test",
weights=custom_weights,
page=1,
page_size=20,
)
# Verify custom weights were used in the query
call_args = mock_query.call_args
sql_template = call_args[0][0]
params = call_args[0][1:]  # Parameters after the SQL template
# Check that SQL uses parameterized weights (not f-string interpolation)
assert "$" in sql_template # Verify parameterization is used
# Check that custom weights are in the params
assert 0.5 in params # semantic weight
assert 0.3 in params # lexical weight
assert 0.1 in params # category and recency weights
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_min_score_filtering():
"""Test hybrid search minimum score threshold."""
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
# Return results with varying scores
mock_query.return_value = [
{
"slug": "high-score/agent",
"agent_name": "High Score Agent",
"combined_score": 0.8,
"total_count": 1,
# ... other fields
}
]
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * 1536
# Test with custom min_score
results, total = await hybrid_search(
query="test",
min_score=0.5, # High threshold
page=1,
page_size=20,
)
# Verify min_score was applied in query
call_args = mock_query.call_args
sql_template = call_args[0][0]
params = call_args[0][1:]  # Parameters after the SQL template
# Check that SQL uses parameterized min_score
assert "combined_score >=" in sql_template
assert "$" in sql_template # Verify parameterization
# Check that custom min_score is in the params
assert 0.5 in params
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_pagination():
"""Test hybrid search pagination."""
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
mock_query.return_value = []
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * 1536
# Test page 2 with page_size 10
results, total = await hybrid_search(
query="test",
page=2,
page_size=10,
)
# Verify pagination parameters
call_args = mock_query.call_args
params = call_args[0]
# Last two params should be LIMIT and OFFSET
limit = params[-2]
offset = params[-1]
assert limit == 10 # page_size
assert offset == 10 # (page - 1) * page_size = (2 - 1) * 10
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_error_handling():
"""Test hybrid search error handling."""
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
# Simulate database error
mock_query.side_effect = Exception("Database connection error")
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * 1536
# Should raise exception
with pytest.raises(Exception) as exc_info:
await hybrid_search(
query="test",
page=1,
page_size=20,
)
assert "Database connection error" in str(exc_info.value)
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])

View File

@@ -64,7 +64,6 @@ from backend.data.onboarding import (
     complete_re_run_agent,
     get_recommended_agents,
     get_user_onboarding,
-    increment_runs,
     onboarding_enabled,
     reset_user_onboarding,
     update_user_onboarding,
@@ -975,7 +974,6 @@ async def execute_graph(
     # Record successful graph execution
     record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
     record_graph_operation(operation="execute", status="success")
-    await increment_runs(user_id)
     await complete_re_run_agent(user_id, graph_id)
     if source == "library":
         await complete_onboarding_step(

View File

@@ -1,631 +0,0 @@
import json
import shlex
import uuid
from typing import Literal, Optional
from e2b import AsyncSandbox as BaseAsyncSandbox
from pydantic import BaseModel, SecretStr
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
# Test credentials for E2B
TEST_E2B_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="e2b",
api_key=SecretStr("mock-e2b-api-key"),
title="Mock E2B API key",
expires_at=None,
)
TEST_E2B_CREDENTIALS_INPUT = {
"provider": TEST_E2B_CREDENTIALS.provider,
"id": TEST_E2B_CREDENTIALS.id,
"type": TEST_E2B_CREDENTIALS.type,
"title": TEST_E2B_CREDENTIALS.title,
}
# Test credentials for Anthropic
TEST_ANTHROPIC_CREDENTIALS = APIKeyCredentials(
id="2e568a2b-b2ea-475a-8564-9a676bf31c56",
provider="anthropic",
api_key=SecretStr("mock-anthropic-api-key"),
title="Mock Anthropic API key",
expires_at=None,
)
TEST_ANTHROPIC_CREDENTIALS_INPUT = {
"provider": TEST_ANTHROPIC_CREDENTIALS.provider,
"id": TEST_ANTHROPIC_CREDENTIALS.id,
"type": TEST_ANTHROPIC_CREDENTIALS.type,
"title": TEST_ANTHROPIC_CREDENTIALS.title,
}
class ClaudeCodeBlock(Block):
"""
Execute tasks using Claude Code (Anthropic's AI coding assistant) in an E2B sandbox.
Claude Code can create files, install tools, run commands, and perform complex
coding tasks autonomously within a secure sandbox environment.
"""
# Use base template - we'll install Claude Code ourselves for latest version
DEFAULT_TEMPLATE = "base"
class Input(BlockSchemaInput):
e2b_credentials: CredentialsMetaInput[
Literal[ProviderName.E2B], Literal["api_key"]
] = CredentialsField(
description=(
"API key for the E2B platform to create the sandbox. "
"Get one on the [e2b website](https://e2b.dev/docs)"
),
)
anthropic_credentials: CredentialsMetaInput[
Literal[ProviderName.ANTHROPIC], Literal["api_key"]
] = CredentialsField(
description=(
"API key for Anthropic to power Claude Code. "
"Get one at [Anthropic's website](https://console.anthropic.com)"
),
)
prompt: str = SchemaField(
description=(
"The task or instruction for Claude Code to execute. "
"Claude Code can create files, install packages, run commands, "
"and perform complex coding tasks."
),
placeholder="Create a hello world index.html file",
default="",
advanced=False,
)
timeout: int = SchemaField(
description=(
"Sandbox timeout in seconds. Claude Code tasks can take "
"a while, so set this appropriately for your task complexity. "
"Note: This only applies when creating a new sandbox. "
"When reconnecting to an existing sandbox via sandbox_id, "
"the original timeout is retained."
),
default=300, # 5 minutes default
advanced=True,
)
setup_commands: list[str] = SchemaField(
description=(
"Optional shell commands to run before executing Claude Code. "
"Useful for installing dependencies or setting up the environment."
),
default_factory=list,
advanced=True,
)
working_directory: str = SchemaField(
description="Working directory for Claude Code to operate in.",
default="/home/user",
advanced=True,
)
# Session/continuation support
session_id: str = SchemaField(
description=(
"Session ID to resume a previous conversation. "
"Leave empty for a new conversation. "
"Use the session_id from a previous run to continue that conversation."
),
default="",
advanced=True,
)
sandbox_id: str = SchemaField(
description=(
"Sandbox ID to reconnect to an existing sandbox. "
"Required when resuming a session (along with session_id). "
"Use the sandbox_id from a previous run where dispose_sandbox was False."
),
default="",
advanced=True,
)
conversation_history: str = SchemaField(
description=(
"Previous conversation history to continue from. "
"Use this to restore context on a fresh sandbox if the previous one timed out. "
"Pass the conversation_history output from a previous run."
),
default="",
advanced=True,
)
dispose_sandbox: bool = SchemaField(
description=(
"Whether to dispose of the sandbox immediately after execution. "
"Set to False if you want to continue the conversation later "
"(you'll need both sandbox_id and session_id from the output)."
),
default=True,
advanced=True,
)
class FileOutput(BaseModel):
"""A file extracted from the sandbox."""
path: str
relative_path: str # Path relative to working directory (for GitHub, etc.)
name: str
content: str
class Output(BlockSchemaOutput):
response: str = SchemaField(
description="The output/response from Claude Code execution"
)
files: list["ClaudeCodeBlock.FileOutput"] = SchemaField(
description=(
"List of text files created/modified by Claude Code during this execution. "
"Each file has 'path', 'relative_path', 'name', and 'content' fields."
)
)
conversation_history: str = SchemaField(
description=(
"Full conversation history including this turn. "
"Pass this to conversation_history input to continue on a fresh sandbox "
"if the previous sandbox timed out."
)
)
session_id: str = SchemaField(
description=(
"Session ID for this conversation. "
"Pass this back along with sandbox_id to continue the conversation."
)
)
sandbox_id: Optional[str] = SchemaField(
description=(
"ID of the sandbox instance. "
"Pass this back along with session_id to continue the conversation. "
"This is None if dispose_sandbox was True (sandbox was disposed)."
),
default=None,
)
error: str = SchemaField(description="Error message if execution failed")
def __init__(self):
super().__init__(
id="4e34f4a5-9b89-4326-ba77-2dd6750b7194",
description=(
"Execute tasks using Claude Code in an E2B sandbox. "
"Claude Code can create files, install tools, run commands, "
"and perform complex coding tasks autonomously."
),
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.AI},
input_schema=ClaudeCodeBlock.Input,
output_schema=ClaudeCodeBlock.Output,
test_credentials={
"e2b_credentials": TEST_E2B_CREDENTIALS,
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS,
},
test_input={
"e2b_credentials": TEST_E2B_CREDENTIALS_INPUT,
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS_INPUT,
"prompt": "Create a hello world HTML file",
"timeout": 300,
"setup_commands": [],
"working_directory": "/home/user",
"session_id": "",
"sandbox_id": "",
"conversation_history": "",
"dispose_sandbox": True,
},
test_output=[
("response", "Created index.html with hello world content"),
(
"files",
[
{
"path": "/home/user/index.html",
"relative_path": "index.html",
"name": "index.html",
"content": "<html>Hello World</html>",
}
],
),
(
"conversation_history",
"User: Create a hello world HTML file\n"
"Claude: Created index.html with hello world content",
),
("session_id", str),
("sandbox_id", None), # None because dispose_sandbox=True in test_input
],
test_mock={
"execute_claude_code": lambda *args, **kwargs: (
"Created index.html with hello world content", # response
[
ClaudeCodeBlock.FileOutput(
path="/home/user/index.html",
relative_path="index.html",
name="index.html",
content="<html>Hello World</html>",
)
], # files
"User: Create a hello world HTML file\n"
"Claude: Created index.html with hello world content", # conversation_history
"test-session-id", # session_id
"sandbox_id", # sandbox_id
),
},
)
async def execute_claude_code(
self,
e2b_api_key: str,
anthropic_api_key: str,
prompt: str,
timeout: int,
setup_commands: list[str],
working_directory: str,
session_id: str,
existing_sandbox_id: str,
conversation_history: str,
dispose_sandbox: bool,
) -> tuple[str, list["ClaudeCodeBlock.FileOutput"], str, str, str]:
"""
Execute Claude Code in an E2B sandbox.
Returns:
Tuple of (response, files, conversation_history, session_id, sandbox_id)
"""
# Validate that sandbox_id is provided when resuming a session
if session_id and not existing_sandbox_id:
raise ValueError(
"sandbox_id is required when resuming a session with session_id. "
"The session state is stored in the original sandbox. "
"If the sandbox has timed out, use conversation_history instead "
"to restore context on a fresh sandbox."
)
sandbox = None
try:
# Either reconnect to existing sandbox or create a new one
if existing_sandbox_id:
# Reconnect to existing sandbox for conversation continuation
sandbox = await BaseAsyncSandbox.connect(
sandbox_id=existing_sandbox_id,
api_key=e2b_api_key,
)
else:
# Create new sandbox
sandbox = await BaseAsyncSandbox.create(
template=self.DEFAULT_TEMPLATE,
api_key=e2b_api_key,
timeout=timeout,
envs={"ANTHROPIC_API_KEY": anthropic_api_key},
)
# Install Claude Code from npm (ensures we get the latest version)
install_result = await sandbox.commands.run(
"npm install -g @anthropic-ai/claude-code@latest",
timeout=120, # 2 min timeout for install
)
if install_result.exit_code != 0:
raise Exception(
f"Failed to install Claude Code: {install_result.stderr}"
)
# Run any user-provided setup commands
for cmd in setup_commands:
setup_result = await sandbox.commands.run(cmd)
if setup_result.exit_code != 0:
raise Exception(
f"Setup command failed: {cmd}\n"
f"Exit code: {setup_result.exit_code}\n"
f"Stdout: {setup_result.stdout}\n"
f"Stderr: {setup_result.stderr}"
)
# Generate or use provided session ID
current_session_id = session_id if session_id else str(uuid.uuid4())
# Build base Claude flags
base_flags = "-p --dangerously-skip-permissions --output-format json"
# Add conversation history context if provided (for fresh sandbox continuation)
history_flag = ""
if conversation_history and not session_id:
# Inject previous conversation as context via system prompt
# Use consistent escaping via _escape_prompt helper
escaped_history = self._escape_prompt(
f"Previous conversation context: {conversation_history}"
)
history_flag = f" --append-system-prompt {escaped_history}"
# Build Claude command based on whether we're resuming or starting new
# Use shlex.quote for working_directory and session IDs to prevent injection
safe_working_dir = shlex.quote(working_directory)
if session_id:
# Resuming existing session (sandbox still alive)
safe_session_id = shlex.quote(session_id)
claude_command = (
f"cd {safe_working_dir} && "
f"echo {self._escape_prompt(prompt)} | "
f"claude --resume {safe_session_id} {base_flags}"
)
else:
# New session with specific ID
safe_current_session_id = shlex.quote(current_session_id)
claude_command = (
f"cd {safe_working_dir} && "
f"echo {self._escape_prompt(prompt)} | "
f"claude --session-id {safe_current_session_id} {base_flags}{history_flag}"
)
# Capture timestamp before running Claude Code to filter files later
# Capture timestamp 1 second in the past to avoid race condition with file creation
timestamp_result = await sandbox.commands.run(
"date -u -d '1 second ago' +%Y-%m-%dT%H:%M:%S"
)
if timestamp_result.exit_code != 0:
raise RuntimeError(
f"Failed to capture timestamp: {timestamp_result.stderr}"
)
start_timestamp = (
timestamp_result.stdout.strip() if timestamp_result.stdout else None
)
result = await sandbox.commands.run(
claude_command,
timeout=0, # No command timeout - let sandbox timeout handle it
)
# Check for command failure
if result.exit_code != 0:
error_msg = result.stderr or result.stdout or "Unknown error"
raise Exception(
f"Claude Code command failed with exit code {result.exit_code}:\n"
f"{error_msg}"
)
raw_output = result.stdout or ""
sandbox_id = sandbox.sandbox_id
# Parse JSON output to extract response and build conversation history
response = ""
new_conversation_history = conversation_history or ""
try:
# The JSON output contains the result
output_data = json.loads(raw_output)
response = output_data.get("result", raw_output)
# Build conversation history entry
turn_entry = f"User: {prompt}\nClaude: {response}"
if new_conversation_history:
new_conversation_history = (
f"{new_conversation_history}\n\n{turn_entry}"
)
else:
new_conversation_history = turn_entry
except json.JSONDecodeError:
# If not valid JSON, use raw output
response = raw_output
turn_entry = f"User: {prompt}\nClaude: {response}"
if new_conversation_history:
new_conversation_history = (
f"{new_conversation_history}\n\n{turn_entry}"
)
else:
new_conversation_history = turn_entry
# Extract files created/modified during this run
files = await self._extract_files(
sandbox, working_directory, start_timestamp
)
return (
response,
files,
new_conversation_history,
current_session_id,
sandbox_id,
)
finally:
if dispose_sandbox and sandbox:
await sandbox.kill()
async def _extract_files(
self,
sandbox: BaseAsyncSandbox,
working_directory: str,
since_timestamp: str | None = None,
) -> list["ClaudeCodeBlock.FileOutput"]:
"""
Extract text files created/modified during this Claude Code execution.
Args:
sandbox: The E2B sandbox instance
working_directory: Directory to search for files
since_timestamp: ISO timestamp - only return files modified after this time
Returns:
List of FileOutput objects with path, relative_path, name, and content
"""
files: list[ClaudeCodeBlock.FileOutput] = []
# Text file extensions we can safely read as text
text_extensions = {
".txt",
".md",
".html",
".htm",
".css",
".js",
".ts",
".jsx",
".tsx",
".json",
".xml",
".yaml",
".yml",
".toml",
".ini",
".cfg",
".conf",
".py",
".rb",
".php",
".java",
".c",
".cpp",
".h",
".hpp",
".cs",
".go",
".rs",
".swift",
".kt",
".scala",
".sh",
".bash",
".zsh",
".sql",
".graphql",
".env",
".gitignore",
".dockerfile",
"Dockerfile",
".vue",
".svelte",
".astro",
".mdx",
".rst",
".tex",
".csv",
".log",
}
try:
# List files recursively using find command
# Exclude node_modules and .git directories, but allow hidden files
# like .env and .gitignore (they're filtered by text_extensions later)
# Filter by timestamp to only get files created/modified during this run
safe_working_dir = shlex.quote(working_directory)
timestamp_filter = ""
if since_timestamp:
timestamp_filter = f"-newermt {shlex.quote(since_timestamp)} "
find_result = await sandbox.commands.run(
f"find {safe_working_dir} -type f "
f"{timestamp_filter}"
f"-not -path '*/node_modules/*' "
f"-not -path '*/.git/*' "
f"2>/dev/null"
)
if find_result.stdout:
for file_path in find_result.stdout.strip().split("\n"):
if not file_path:
continue
# Check if it's a text file we can read
is_text = any(
file_path.endswith(ext) for ext in text_extensions
) or file_path.endswith("Dockerfile")
if is_text:
try:
content = await sandbox.files.read(file_path)
# Handle bytes or string
if isinstance(content, bytes):
content = content.decode("utf-8", errors="replace")
# Extract filename from path
file_name = file_path.split("/")[-1]
# Calculate relative path by stripping working directory
relative_path = file_path
if file_path.startswith(working_directory):
relative_path = file_path[len(working_directory) :]
# Remove leading slash if present
if relative_path.startswith("/"):
relative_path = relative_path[1:]
files.append(
ClaudeCodeBlock.FileOutput(
path=file_path,
relative_path=relative_path,
name=file_name,
content=content,
)
)
except Exception:
# Skip files that can't be read
pass
except Exception:
# If file extraction fails, return empty results
pass
return files
def _escape_prompt(self, prompt: str) -> str:
"""Escape the prompt for safe shell execution."""
# Use single quotes and escape any single quotes in the prompt
escaped = prompt.replace("'", "'\"'\"'")
return f"'{escaped}'"
async def run(
self,
input_data: Input,
*,
e2b_credentials: APIKeyCredentials,
anthropic_credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
try:
(
response,
files,
conversation_history,
session_id,
sandbox_id,
) = await self.execute_claude_code(
e2b_api_key=e2b_credentials.api_key.get_secret_value(),
anthropic_api_key=anthropic_credentials.api_key.get_secret_value(),
prompt=input_data.prompt,
timeout=input_data.timeout,
setup_commands=input_data.setup_commands,
working_directory=input_data.working_directory,
session_id=input_data.session_id,
existing_sandbox_id=input_data.sandbox_id,
conversation_history=input_data.conversation_history,
dispose_sandbox=input_data.dispose_sandbox,
)
yield "response", response
# Always yield files (empty list if none) to match Output schema
yield "files", [f.model_dump() for f in files]
# Always yield conversation_history so user can restore context on fresh sandbox
yield "conversation_history", conversation_history
# Always yield session_id so user can continue conversation
yield "session_id", session_id
# Always yield sandbox_id (None if disposed) to match Output schema
yield "sandbox_id", sandbox_id if not input_data.dispose_sandbox else None
except Exception as e:
yield "error", str(e)

View File

@@ -38,6 +38,20 @@ POOL_TIMEOUT = os.getenv("DB_POOL_TIMEOUT")
 if POOL_TIMEOUT:
     DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)
 
+# Add public schema to search_path for pgvector type access
+# The vector extension is in public schema, but search_path is determined by schema parameter
+# Extract the schema from DATABASE_URL or default to 'public' (matching get_database_schema())
+parsed_url = urlparse(DATABASE_URL)
+url_params = dict(parse_qsl(parsed_url.query))
+db_schema = url_params.get("schema", "public")
+# Build search_path, avoiding duplicates if db_schema is already 'public'
+search_path_schemas = list(
+    dict.fromkeys([db_schema, "public"])
+)  # Preserves order, removes duplicates
+search_path = ",".join(search_path_schemas)
+# This allows using ::vector without schema qualification
+DATABASE_URL = add_param(DATABASE_URL, "options", f"-c search_path={search_path}")
+
 HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None
 
 prisma = Prisma(
@@ -108,21 +122,102 @@ def get_database_schema() -> str:
     return query_params.get("schema", "public")
 
-async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
-    """Execute raw SQL query with proper schema handling."""
+async def _raw_with_schema(
+    query_template: str,
+    *args,
+    execute: bool = False,
+    client: Prisma | None = None,
+    set_public_search_path: bool = False,
+) -> list[dict] | int:
+    """Internal: Execute raw SQL with proper schema handling.
+
+    Use query_raw_with_schema() or execute_raw_with_schema() instead.
+
+    Args:
+        query_template: SQL query with {schema_prefix} placeholder
+        *args: Query parameters
+        execute: If False, executes SELECT query. If True, executes INSERT/UPDATE/DELETE.
+        client: Optional Prisma client for transactions (only used when execute=True).
+        set_public_search_path: If True, sets search_path to include public schema.
+            Needed for pgvector types and other public schema objects.
+
+    Returns:
+        - list[dict] if execute=False (query results)
+        - int if execute=True (number of affected rows)
+    """
     schema = get_database_schema()
     schema_prefix = f'"{schema}".' if schema != "public" else ""
     formatted_query = query_template.format(schema_prefix=schema_prefix)
     import prisma as prisma_module
 
-    result = await prisma_module.get_client().query_raw(
-        formatted_query, *args  # type: ignore
-    )
+    db_client = client if client else prisma_module.get_client()
+
+    # Set search_path to include public schema if requested
+    # Prisma doesn't support the 'options' connection parameter, so we set it per-session
+    # This is idempotent and safe to call multiple times
+    if set_public_search_path:
+        await db_client.execute_raw(f"SET search_path = {schema}, public")  # type: ignore
+
+    if execute:
+        result = await db_client.execute_raw(formatted_query, *args)  # type: ignore
+    else:
+        result = await db_client.query_raw(formatted_query, *args)  # type: ignore
     return result
 
+async def query_raw_with_schema(
+    query_template: str, *args, set_public_search_path: bool = False
+) -> list[dict]:
+    """Execute raw SQL SELECT query with proper schema handling.
+
+    Args:
+        query_template: SQL query with {schema_prefix} placeholder
+        *args: Query parameters
+        set_public_search_path: If True, sets search_path to include public schema.
+            Needed for pgvector types and other public schema objects.
+
+    Returns:
+        List of result rows as dictionaries
+
+    Example:
+        results = await query_raw_with_schema(
+            'SELECT * FROM {schema_prefix}"User" WHERE id = $1',
+            user_id
+        )
+    """
+    return await _raw_with_schema(query_template, *args, execute=False, set_public_search_path=set_public_search_path)  # type: ignore
+
+async def execute_raw_with_schema(
+    query_template: str,
+    *args,
+    client: Prisma | None = None,
+    set_public_search_path: bool = False,
+) -> int:
+    """Execute raw SQL command (INSERT/UPDATE/DELETE) with proper schema handling.
+
+    Args:
+        query_template: SQL query with {schema_prefix} placeholder
+        *args: Query parameters
+        client: Optional Prisma client for transactions
+        set_public_search_path: If True, sets search_path to include public schema.
+            Needed for pgvector types and other public schema objects.
+
+    Returns:
+        Number of affected rows
+
+    Example:
+        await execute_raw_with_schema(
+            'INSERT INTO {schema_prefix}"User" (id, name) VALUES ($1, $2)',
+            user_id, name,
+            client=tx  # Optional transaction client
+        )
+    """
+    return await _raw_with_schema(query_template, *args, execute=True, client=client, set_public_search_path=set_public_search_path)  # type: ignore
+
 class BaseDbModel(BaseModel):
     id: str = Field(default_factory=lambda: str(uuid4()))

View File

@@ -1,5 +1,6 @@
 import json
 from typing import Any
+from unittest.mock import AsyncMock, patch
 from uuid import UUID
 
 import fastapi.exceptions
@@ -18,6 +19,17 @@ from backend.usecases.sample import create_test_user
 from backend.util.test import SpinTestServer
 
+@pytest.fixture(scope="session", autouse=True)
+def mock_embedding_functions():
+    """Mock embedding functions for all tests to avoid database/API dependencies."""
+    with patch(
+        "backend.api.features.store.db.ensure_embedding",
+        new_callable=AsyncMock,
+        return_value=True,
+    ):
+        yield
+
 @pytest.mark.asyncio(loop_scope="session")
 async def test_graph_creation(server: SpinTestServer, snapshot: Snapshot):
     """

View File

@@ -334,7 +334,7 @@ async def _get_user_timezone(user_id: str) -> str:
     return get_user_timezone_or_utc(user.timezone if user else None)
 
-async def increment_runs(user_id: str):
+async def increment_onboarding_runs(user_id: str):
     """
     Increment a user's run counters and trigger any onboarding milestones.
     """

View File

@@ -7,6 +7,10 @@ from backend.api.features.library.db import (
     list_library_agents,
 )
 from backend.api.features.store.db import get_store_agent_details, get_store_agents
+from backend.api.features.store.embeddings import (
+    backfill_missing_embeddings,
+    get_embedding_stats,
+)
 from backend.data import db
 from backend.data.analytics import (
     get_accuracy_trends_and_alerts,
@@ -20,6 +24,7 @@ from backend.data.execution import (
     get_execution_kv_data,
     get_execution_outputs_by_node_exec_id,
     get_frequently_executed_graphs,
+    get_graph_execution,
     get_graph_execution_meta,
     get_graph_executions,
     get_graph_executions_count,
@@ -57,6 +62,7 @@ from backend.data.notifications import (
     get_user_notification_oldest_message_in_batch,
     remove_notifications_from_batch,
 )
+from backend.data.onboarding import increment_onboarding_runs
 from backend.data.user import (
     get_active_user_ids_in_timerange,
     get_user_by_id,
@@ -140,6 +146,7 @@ class DatabaseManager(AppService):
     get_child_graph_executions = _(get_child_graph_executions)
     get_graph_executions = _(get_graph_executions)
     get_graph_executions_count = _(get_graph_executions_count)
+    get_graph_execution = _(get_graph_execution)
     get_graph_execution_meta = _(get_graph_execution_meta)
     create_graph_execution = _(create_graph_execution)
     get_node_execution = _(get_node_execution)
@@ -204,10 +211,17 @@ class DatabaseManager(AppService):
     add_store_agent_to_library = _(add_store_agent_to_library)
     validate_graph_execution_permissions = _(validate_graph_execution_permissions)
 
+    # Onboarding
+    increment_onboarding_runs = _(increment_onboarding_runs)
+
     # Store
     get_store_agents = _(get_store_agents)
     get_store_agent_details = _(get_store_agent_details)
 
+    # Store Embeddings
+    get_embedding_stats = _(get_embedding_stats)
+    backfill_missing_embeddings = _(backfill_missing_embeddings)
+
     # Summary data - async
     get_user_execution_summary_data = _(get_user_execution_summary_data)
@@ -259,6 +273,10 @@ class DatabaseManagerClient(AppServiceClient):
     get_store_agents = _(d.get_store_agents)
     get_store_agent_details = _(d.get_store_agent_details)
 
+    # Store Embeddings
+    get_embedding_stats = _(d.get_embedding_stats)
+    backfill_missing_embeddings = _(d.backfill_missing_embeddings)
+
 class DatabaseManagerAsyncClient(AppServiceClient):
     d = DatabaseManager
@@ -274,6 +292,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
     get_graph = d.get_graph
     get_graph_metadata = d.get_graph_metadata
     get_graph_settings = d.get_graph_settings
+    get_graph_execution = d.get_graph_execution
     get_graph_execution_meta = d.get_graph_execution_meta
     get_node = d.get_node
     get_node_execution = d.get_node_execution
@@ -318,6 +337,9 @@ class DatabaseManagerAsyncClient(AppServiceClient):
     add_store_agent_to_library = d.add_store_agent_to_library
     validate_graph_execution_permissions = d.validate_graph_execution_permissions
 
+    # Onboarding
+    increment_onboarding_runs = d.increment_onboarding_runs
+
     # Store
     get_store_agents = d.get_store_agents
     get_store_agent_details = d.get_store_agent_details

View File

@@ -1,4 +1,5 @@
 import logging
+from unittest.mock import AsyncMock, patch
 
 import fastapi.responses
 import pytest
@@ -19,6 +20,17 @@ from backend.util.test import SpinTestServer, wait_execution
 logger = logging.getLogger(__name__)
 
+@pytest.fixture(scope="session", autouse=True)
+def mock_embedding_functions():
+    """Mock embedding functions for all tests to avoid database/API dependencies."""
+    with patch(
+        "backend.api.features.store.db.ensure_embedding",
+        new_callable=AsyncMock,
+        return_value=True,
+    ):
+        yield
+
 async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph:
     logger.info(f"Creating graph for user {u.id}")
     return await s.agent_server.test_create_graph(CreateGraph(graph=g), u.id)

View File

@@ -2,6 +2,7 @@ import asyncio
 import logging
 import os
 import threading
+import time
 import uuid
 from enum import Enum
 from typing import Optional
@@ -27,7 +28,6 @@ from backend.data.auth.oauth import cleanup_expired_oauth_tokens
 from backend.data.block import BlockInput
 from backend.data.execution import GraphExecutionWithNodes
 from backend.data.model import CredentialsMetaInput
-from backend.data.onboarding import increment_runs
 from backend.executor import utils as execution_utils
 from backend.monitoring import (
     NotificationJobArgs,
@@ -37,7 +37,7 @@ from backend.monitoring import (
     report_execution_accuracy_alerts,
     report_late_executions,
 )
-from backend.util.clients import get_scheduler_client
+from backend.util.clients import get_database_manager_client, get_scheduler_client
 from backend.util.cloud_storage import cleanup_expired_files_async
 from backend.util.exceptions import (
     GraphNotFoundError,
@@ -156,7 +156,6 @@ async def _execute_graph(**kwargs):
         inputs=args.input_data,
         graph_credentials_inputs=args.input_credentials,
     )
-    await increment_runs(args.user_id)
     elapsed = asyncio.get_event_loop().time() - start_time
     logger.info(
         f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
@@ -254,6 +253,74 @@ def execution_accuracy_alerts():
     return report_execution_accuracy_alerts()
 
+def ensure_embeddings_coverage():
+    """
+    Ensure approved store agents have embeddings for hybrid search.
+
+    Processes ALL missing embeddings in batches of 10 until 100% coverage.
+    Missing embeddings = agents invisible in hybrid search.
+
+    Schedule: Runs every 6 hours (balanced between coverage and API costs).
+    - Catches agents approved between scheduled runs
+    - Batch size 10: gradual processing to avoid rate limits
+    - Manual trigger available via execute_ensure_embeddings_coverage endpoint
+    """
+    db_client = get_database_manager_client()
+    stats = db_client.get_embedding_stats()
+
+    # Check for error from get_embedding_stats() first
+    if "error" in stats:
+        logger.error(
+            f"Failed to get embedding stats: {stats['error']} - skipping backfill"
+        )
+        return {"processed": 0, "success": 0, "failed": 0, "error": stats["error"]}
+
+    if stats["without_embeddings"] == 0:
+        logger.info("All approved agents have embeddings, skipping backfill")
+        return {"processed": 0, "success": 0, "failed": 0}
+
+    logger.info(
+        f"Found {stats['without_embeddings']} agents without embeddings "
+        f"({stats['coverage_percent']}% coverage) - processing all"
+    )
+
+    total_processed = 0
+    total_success = 0
+    total_failed = 0
+
+    # Process in batches until no more missing embeddings
+    while True:
+        result = db_client.backfill_missing_embeddings(batch_size=10)
+        total_processed += result["processed"]
+        total_success += result["success"]
+        total_failed += result["failed"]
+
+        if result["processed"] == 0:
+            # No more missing embeddings
+            break
+
+        if result["success"] == 0 and result["processed"] > 0:
+            # All attempts in this batch failed - stop to avoid infinite loop
+            logger.error(
+                f"All {result['processed']} embedding attempts failed - stopping backfill"
+            )
+            break
+
+        # Small delay between batches to avoid rate limits
+        time.sleep(1)
+
+    logger.info(
+        f"Embedding backfill completed: {total_success}/{total_processed} succeeded, "
+        f"{total_failed} failed"
+    )
+    return {
+        "processed": total_processed,
+        "success": total_success,
+        "failed": total_failed,
+    }
+
 # Monitoring functions are now imported from monitoring module
@@ -475,6 +542,19 @@ class Scheduler(AppService):
             jobstore=Jobstores.EXECUTION.value,
         )
 
+        # Embedding Coverage - Every 6 hours
+        # Ensures all approved agents have embeddings for hybrid search
+        # Critical: missing embeddings = agents invisible in search
+        self.scheduler.add_job(
+            ensure_embeddings_coverage,
+            id="ensure_embeddings_coverage",
+            trigger="interval",
+            hours=6,
+            replace_existing=True,
+            max_instances=1,  # Prevent overlapping runs
+            jobstore=Jobstores.EXECUTION.value,
+        )
+
         self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
         self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
         self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
@@ -632,6 +712,11 @@ class Scheduler(AppService):
         """Manually trigger execution accuracy alert checking."""
         return execution_accuracy_alerts()
 
+    @expose
+    def execute_ensure_embeddings_coverage(self):
+        """Manually trigger embedding backfill for approved store agents."""
+        return ensure_embeddings_coverage()
+
 class SchedulerClient(AppServiceClient):
     @classmethod

View File

@@ -10,6 +10,7 @@ from pydantic import BaseModel, JsonValue, ValidationError
from backend.data import execution as execution_db from backend.data import execution as execution_db
from backend.data import graph as graph_db from backend.data import graph as graph_db
from backend.data import onboarding as onboarding_db
from backend.data import user as user_db from backend.data import user as user_db
from backend.data.block import ( from backend.data.block import (
Block, Block,
@@ -31,7 +32,6 @@ from backend.data.execution import (
GraphExecutionStats, GraphExecutionStats,
GraphExecutionWithNodes, GraphExecutionWithNodes,
NodesInputMasks, NodesInputMasks,
get_graph_execution,
) )
from backend.data.graph import GraphModel, Node from backend.data.graph import GraphModel, Node
from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput
@@ -809,13 +809,14 @@ async def add_graph_execution(
edb = execution_db edb = execution_db
udb = user_db udb = user_db
gdb = graph_db gdb = graph_db
odb = onboarding_db
else: else:
edb = udb = gdb = get_database_manager_async_client() edb = udb = gdb = odb = get_database_manager_async_client()
# Get or create the graph execution # Get or create the graph execution
if graph_exec_id: if graph_exec_id:
# Resume existing execution # Resume existing execution
graph_exec = await get_graph_execution( graph_exec = await edb.get_graph_execution(
user_id=user_id, user_id=user_id,
execution_id=graph_exec_id, execution_id=graph_exec_id,
include_node_executions=True, include_node_executions=True,
@@ -891,6 +892,7 @@ async def add_graph_execution(
) )
logger.info(f"Publishing execution {graph_exec.id} to execution queue") logger.info(f"Publishing execution {graph_exec.id} to execution queue")
# Publish to execution queue for executor to pick up
exec_queue = await get_async_execution_queue() exec_queue = await get_async_execution_queue()
await exec_queue.publish_message( await exec_queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY, routing_key=GRAPH_EXECUTION_ROUTING_KEY,
@@ -899,14 +901,12 @@ async def add_graph_execution(
) )
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue") logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
# Update execution status to QUEUED
graph_exec.status = ExecutionStatus.QUEUED
await edb.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=graph_exec.status,
)
await get_async_execution_event_bus().publish(graph_exec)
return graph_exec
except BaseException as e:
err = str(e) or type(e).__name__
if not graph_exec:
@@ -927,6 +927,24 @@ async def add_graph_execution(
)
raise
try:
await get_async_execution_event_bus().publish(graph_exec)
logger.info(f"Published update for execution #{graph_exec.id} to event bus")
except Exception as e:
logger.error(
f"Failed to publish execution event for graph exec #{graph_exec.id}: {e}"
)
try:
await odb.increment_onboarding_runs(user_id)
logger.info(
f"Incremented user #{user_id} onboarding runs for exec #{graph_exec.id}"
)
except Exception as e:
logger.error(f"Failed to increment onboarding runs for user #{user_id}: {e}")
return graph_exec
# ============ Execution Output Helpers ============ #

View File

@@ -10,6 +10,7 @@ from backend.util.settings import Settings
settings = Settings()
if TYPE_CHECKING:
from openai import AsyncOpenAI
from supabase import AClient, Client
from backend.data.execution import (
@@ -139,6 +140,24 @@ async def get_async_supabase() -> "AClient":
)
# ============ OpenAI Client ============ #
@cached(ttl_seconds=3600)
def get_openai_client() -> "AsyncOpenAI | None":
"""
Get a process-cached async OpenAI client for embeddings.
Returns None if API key is not configured.
"""
from openai import AsyncOpenAI
api_key = settings.secrets.openai_internal_api_key
if not api_key:
return None
return AsyncOpenAI(api_key=api_key)
# ============ Notification Queue Helpers ============ #
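
A minimal usage sketch (not part of the diff) of how the cached client might back embedding generation in embeddings.py, which this compare view does not include. The helper name, import path, and model are assumptions; text-embedding-3-small is assumed only because it returns 1536-dimensional vectors, matching the vector(1536) column added below.

# Hedged sketch, assuming names not shown in this diff.
from backend.util.clients import get_openai_client  # hypothetical import path

async def generate_embedding(text: str) -> list[float]:
    client = get_openai_client()
    if client is None:
        # Fail fast: matches the PR's no-silent-fallback policy
        raise ValueError("openai_internal_api_key is not configured")
    response = await client.embeddings.create(
        model="text-embedding-3-small",  # assumed model; 1536-dim output
        input=text,
    )
    return response.data[0].embedding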

View File

@@ -0,0 +1,46 @@
-- CreateExtension
-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
-- Create in public schema so vector type is available across all schemas
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "vector" WITH SCHEMA "public";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'vector extension not available or already exists, skipping';
END $$;
-- CreateEnum
CREATE TYPE "ContentType" AS ENUM ('STORE_AGENT', 'BLOCK', 'INTEGRATION', 'DOCUMENTATION', 'LIBRARY_AGENT');
-- CreateTable
CREATE TABLE "UnifiedContentEmbedding" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"contentType" "ContentType" NOT NULL,
"contentId" TEXT NOT NULL,
"userId" TEXT,
"embedding" public.vector(1536) NOT NULL,
"searchableText" TEXT NOT NULL,
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "UnifiedContentEmbedding_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "UnifiedContentEmbedding_contentType_idx" ON "UnifiedContentEmbedding"("contentType");
-- CreateIndex
CREATE INDEX "UnifiedContentEmbedding_userId_idx" ON "UnifiedContentEmbedding"("userId");
-- CreateIndex
CREATE INDEX "UnifiedContentEmbedding_contentType_userId_idx" ON "UnifiedContentEmbedding"("contentType", "userId");
-- CreateIndex
-- NULLS NOT DISTINCT ensures only one public (NULL userId) embedding per contentType+contentId
-- Requires PostgreSQL 15+. Supabase uses PostgreSQL 15+.
CREATE UNIQUE INDEX "UnifiedContentEmbedding_contentType_contentId_userId_key" ON "UnifiedContentEmbedding"("contentType", "contentId", "userId") NULLS NOT DISTINCT;
-- CreateIndex
-- HNSW index for fast vector similarity search on embeddings
-- Uses cosine distance operator (<=>), which matches the query in hybrid_search.py
CREATE INDEX "UnifiedContentEmbedding_embedding_idx" ON "UnifiedContentEmbedding" USING hnsw ("embedding" public.vector_cosine_ops);

View File

@@ -0,0 +1,71 @@
-- Acknowledge Supabase-managed extensions to prevent drift warnings
-- These extensions are pre-installed by Supabase in specific schemas
-- This migration ensures they exist where available (Supabase) or skips gracefully (CI)
-- Create schemas (safe in both CI and Supabase)
CREATE SCHEMA IF NOT EXISTS "extensions";
-- Extensions that exist in both CI and Supabase
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pgcrypto extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'uuid-ossp extension not available, skipping';
END $$;
-- Supabase-specific extensions (skip gracefully in CI)
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "pg_stat_statements" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pg_stat_statements extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "pg_net" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pg_net extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "pgjwt" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pgjwt extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE SCHEMA IF NOT EXISTS "graphql";
CREATE EXTENSION IF NOT EXISTS "pg_graphql" WITH SCHEMA "graphql";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pg_graphql extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE SCHEMA IF NOT EXISTS "pgsodium";
CREATE EXTENSION IF NOT EXISTS "pgsodium" WITH SCHEMA "pgsodium";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pgsodium extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE SCHEMA IF NOT EXISTS "vault";
CREATE EXTENSION IF NOT EXISTS "supabase_vault" WITH SCHEMA "vault";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'supabase_vault extension not available, skipping';
END $$;
-- Return to platform
CREATE SCHEMA IF NOT EXISTS "platform";

View File

@@ -2,13 +2,14 @@ datasource db {
provider = "postgresql" provider = "postgresql"
url = env("DATABASE_URL") url = env("DATABASE_URL")
directUrl = env("DIRECT_URL") directUrl = env("DIRECT_URL")
extensions = [pgvector(map: "vector")]
} }
generator client { generator client {
provider = "prisma-client-py" provider = "prisma-client-py"
recursive_type_depth = -1 recursive_type_depth = -1
interface = "asyncio" interface = "asyncio"
previewFeatures = ["views", "fullTextSearch"] previewFeatures = ["views", "fullTextSearch", "postgresqlExtensions"]
partial_type_generator = "backend/data/partial_types.py" partial_type_generator = "backend/data/partial_types.py"
} }
@@ -733,7 +734,6 @@ view StoreAgent {
sub_heading String
description String
categories String[]
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
runs Int
rating Float
versions String[]
@@ -899,6 +899,9 @@ model StoreListingVersion {
// Reviews for this specific version
Reviews StoreListingReview[]
// Note: Embeddings now stored in UnifiedContentEmbedding table
// Use contentType=STORE_AGENT and contentId=storeListingVersionId
@@unique([storeListingId, version])
@@index([storeListingId, submissionStatus, isAvailable])
@@index([submissionStatus])
@@ -906,6 +909,42 @@ model StoreListingVersion {
@@index([agentGraphId, agentGraphVersion]) // Non-unique index for efficient lookups
}
// Content type enum for unified search across store agents, blocks, docs
// Note: BLOCK/INTEGRATION are file-based (Python classes), not DB records
// DOCUMENTATION are file-based (.md files), not DB records
// Only STORE_AGENT and LIBRARY_AGENT are stored in database
enum ContentType {
STORE_AGENT // Database: StoreListingVersion
BLOCK // File-based: Python classes in /backend/blocks/
INTEGRATION // File-based: Python classes (blocks with credentials)
DOCUMENTATION // File-based: .md/.mdx files
LIBRARY_AGENT // Database: User's personal agents
}
// Unified embeddings table for all searchable content types
// Supports both public content (userId=null) and user-specific content (userId=userID)
model UnifiedContentEmbedding {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
// Content identification
contentType ContentType
contentId String // DB ID (storeListingVersionId) or file identifier (block.id, file_path)
userId String? // NULL for public content (store, blocks, docs), userId for private content (library agents)
// Search data
embedding Unsupported("vector(1536)") // pgvector embedding (extension in public schema)
searchableText String // Combined text for search and fallback
metadata Json @default("{}") // Content-specific metadata
@@unique([contentType, contentId, userId], map: "UnifiedContentEmbedding_contentType_contentId_userId_key")
@@index([contentType])
@@index([userId])
@@index([contentType, userId])
@@index([embedding], map: "UnifiedContentEmbedding_embedding_idx")
}
model StoreListingReview {
id String @id @default(uuid())
createdAt DateTime @default(now())

View File

@@ -81,7 +81,6 @@ export const RunInputDialog = ({
Inputs
</Text>
</div>
<div className="px-2">
<FormRenderer
jsonSchema={inputSchema as RJSFSchema}
handleChange={(v) => handleInputChange(v.formData)}
@@ -93,7 +92,6 @@ export const RunInputDialog = ({
}}
/>
</div>
</div>
)}
{/* Action Button */}

View File

@@ -3,6 +3,7 @@ import { useGetV2GetSpecificBlocks } from "@/app/api/__generated__/endpoints/def
import {
useGetV1GetExecutionDetails,
useGetV1GetSpecificGraph,
useGetV1ListUserGraphs,
} from "@/app/api/__generated__/endpoints/graphs/graphs";
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
@@ -17,6 +18,7 @@ import { useReactFlow } from "@xyflow/react";
import { useControlPanelStore } from "../../../stores/controlPanelStore";
import { useHistoryStore } from "../../../stores/historyStore";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import { okData } from "@/app/api/helpers";
export const useFlow = () => {
const [isLocked, setIsLocked] = useState(false);
@@ -36,6 +38,9 @@ export const useFlow = () => {
const setGraphExecutionStatus = useGraphStore(
useShallow((state) => state.setGraphExecutionStatus),
);
const setAvailableSubGraphs = useGraphStore(
useShallow((state) => state.setAvailableSubGraphs),
);
const updateEdgeBeads = useEdgeStore(
useShallow((state) => state.updateEdgeBeads),
);
@@ -62,6 +67,11 @@ export const useFlow = () => {
},
);
// Fetch all available graphs for sub-agent update detection
const { data: availableGraphs } = useGetV1ListUserGraphs({
query: { select: okData },
});
const { data: graph, isLoading: isGraphLoading } = useGetV1GetSpecificGraph(
flowID ?? "",
flowVersion !== null ? { version: flowVersion } : {},
@@ -116,10 +126,18 @@ export const useFlow = () => {
}
}, [graph]);
// Update available sub-graphs in store for sub-agent update detection
useEffect(() => {
if (availableGraphs) {
setAvailableSubGraphs(availableGraphs);
}
}, [availableGraphs, setAvailableSubGraphs]);
// adding nodes
useEffect(() => {
if (customNodes.length > 0) {
useNodeStore.getState().setNodes([]);
useNodeStore.getState().clearResolutionState();
addNodes(customNodes);
// Sync hardcoded values with handle IDs.
@@ -203,6 +221,7 @@ export const useFlow = () => {
useEffect(() => {
return () => {
useNodeStore.getState().setNodes([]);
useNodeStore.getState().clearResolutionState();
useEdgeStore.getState().setEdges([]);
useGraphStore.getState().reset();
useEdgeStore.getState().resetEdgeBeads();

View File

@@ -8,6 +8,7 @@ import {
getBezierPath,
} from "@xyflow/react";
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import { XIcon } from "@phosphor-icons/react";
import { cn } from "@/lib/utils";
import { NodeExecutionResult } from "@/lib/autogpt-server-api";
@@ -35,6 +36,8 @@ const CustomEdge = ({
selected,
}: EdgeProps<CustomEdge>) => {
const removeConnection = useEdgeStore((state) => state.removeEdge);
// Subscribe to the brokenEdgeIDs map and check if this edge is broken across any node
const isBroken = useNodeStore((state) => state.isEdgeBroken(id));
const [isHovered, setIsHovered] = useState(false);
const [edgePath, labelX, labelY] = getBezierPath({
@@ -50,6 +53,12 @@ const CustomEdge = ({
const beadUp = data?.beadUp ?? 0;
const beadDown = data?.beadDown ?? 0;
const handleRemoveEdge = () => {
removeConnection(id);
// Note: broken edge tracking is cleaned up automatically by useSubAgentUpdateState
// when it detects the edge no longer exists
};
return (
<>
<BaseEdge
@@ -57,7 +66,9 @@ const CustomEdge = ({
markerEnd={markerEnd}
className={cn(
isStatic && "!stroke-[1.5px] [stroke-dasharray:6]",
selected
isBroken
? "!stroke-red-500 !stroke-[2px] [stroke-dasharray:4]"
: selected
? "stroke-zinc-800"
: "stroke-zinc-500/50 hover:stroke-zinc-500",
)}
@@ -70,12 +81,16 @@ const CustomEdge = ({
/>
<EdgeLabelRenderer>
<Button
onClick={() => removeConnection(id)}
onClick={handleRemoveEdge}
className={cn(
"absolute h-fit min-w-0 p-1 transition-opacity",
isHovered ? "opacity-100" : "opacity-0",
isBroken
? "bg-red-500 opacity-100 hover:bg-red-600"
: isHovered
? "opacity-100"
: "opacity-0",
)}
variant="secondary"
variant={isBroken ? "primary" : "secondary"}
style={{
transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
pointerEvents: "all",

View File

@@ -3,6 +3,7 @@ import { Handle, Position } from "@xyflow/react";
import { useEdgeStore } from "../../../stores/edgeStore";
import { cleanUpHandleId } from "@/components/renderers/InputRenderer/helpers";
import { cn } from "@/lib/utils";
import { useNodeStore } from "../../../stores/nodeStore";
const InputNodeHandle = ({
handleId,
@@ -15,6 +16,9 @@ const InputNodeHandle = ({
const isInputConnected = useEdgeStore((state) =>
state.isInputConnected(nodeId ?? "", cleanedHandleId),
);
const isInputBroken = useNodeStore((state) =>
state.isInputBroken(nodeId, cleanedHandleId),
);
return (
<Handle
@@ -27,7 +31,10 @@ const InputNodeHandle = ({
<CircleIcon
size={16}
weight={isInputConnected ? "fill" : "duotone"}
className={"text-gray-400 opacity-100"}
className={cn(
"text-gray-400 opacity-100",
isInputBroken && "text-red-500",
)}
/>
</div>
</Handle>
@@ -38,14 +45,17 @@ const OutputNodeHandle = ({
field_name,
nodeId,
hexColor,
isBroken,
}: {
field_name: string;
nodeId: string;
hexColor: string;
isBroken: boolean;
}) => {
const isOutputConnected = useEdgeStore((state) =>
state.isOutputConnected(nodeId, field_name),
);
return (
<Handle
type={"source"}
@@ -58,7 +68,10 @@ const OutputNodeHandle = ({
size={16}
weight={"duotone"}
color={isOutputConnected ? hexColor : "gray"}
className={cn("text-gray-400 opacity-100")}
className={cn(
"text-gray-400 opacity-100",
isBroken && "text-red-500",
)}
/>
</div>
</Handle>

View File

@@ -20,6 +20,8 @@ import { NodeDataRenderer } from "./components/NodeOutput/NodeOutput";
import { NodeRightClickMenu } from "./components/NodeRightClickMenu";
import { StickyNoteBlock } from "./components/StickyNoteBlock";
import { WebhookDisclaimer } from "./components/WebhookDisclaimer";
import { SubAgentUpdateFeature } from "./components/SubAgentUpdate/SubAgentUpdateFeature";
import { useCustomNode } from "./useCustomNode";
export type CustomNodeData = {
hardcodedValues: {
@@ -45,6 +47,10 @@ export type CustomNode = XYNode<CustomNodeData, "custom">;
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
({ data, id: nodeId, selected }) => {
const { inputSchema, outputSchema } = useCustomNode({ data, nodeId });
const isAgent = data.uiType === BlockUIType.AGENT;
if (data.uiType === BlockUIType.NOTE) {
return (
<StickyNoteBlock data={data} selected={selected} nodeId={nodeId} />
@@ -63,16 +69,6 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
const isAyrshare = data.uiType === BlockUIType.AYRSHARE;
const inputSchema =
data.uiType === BlockUIType.AGENT
? (data.hardcodedValues.input_schema ?? {})
: data.inputSchema;
const outputSchema =
data.uiType === BlockUIType.AGENT
? (data.hardcodedValues.output_schema ?? {})
: data.outputSchema;
const hasConfigErrors =
data.errors &&
Object.values(data.errors).some(
@@ -87,12 +83,11 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
const hasErrors = hasConfigErrors || hasOutputError;
// Currently all blockTypes design are similar - that's why i am using the same component for all of them
// If in future - if we need some drastic change in some blockTypes design - we can create separate components for them
const node = (
<NodeContainer selected={selected} nodeId={nodeId} hasErrors={hasErrors}>
<div className="rounded-xlarge bg-white">
<NodeHeader data={data} nodeId={nodeId} />
{isAgent && <SubAgentUpdateFeature nodeID={nodeId} nodeData={data} />}
{isWebhook && <WebhookDisclaimer nodeId={nodeId} />}
{isAyrshare && <AyrshareConnectButton />}
<FormCreator

View File

@@ -0,0 +1,118 @@
import React from "react";
import { ArrowUpIcon, WarningIcon } from "@phosphor-icons/react";
import { Button } from "@/components/atoms/Button/Button";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
import { cn, beautifyString } from "@/lib/utils";
import { CustomNodeData } from "../../CustomNode";
import { useSubAgentUpdateState } from "./useSubAgentUpdateState";
import { IncompatibleUpdateDialog } from "./components/IncompatibleUpdateDialog";
import { ResolutionModeBar } from "./components/ResolutionModeBar";
/**
* Inline component for the update bar that can be placed after the header.
* Use this inside the node content where you want the bar to appear.
*/
type SubAgentUpdateFeatureProps = {
nodeID: string;
nodeData: CustomNodeData;
};
export function SubAgentUpdateFeature({
nodeID,
nodeData,
}: SubAgentUpdateFeatureProps) {
const {
updateInfo,
isInResolutionMode,
handleUpdateClick,
showIncompatibilityDialog,
setShowIncompatibilityDialog,
handleConfirmIncompatibleUpdate,
} = useSubAgentUpdateState({ nodeID: nodeID, nodeData: nodeData });
const agentName = nodeData.title || "Agent";
if (!updateInfo.hasUpdate && !isInResolutionMode) {
return null;
}
return (
<>
{isInResolutionMode ? (
<ResolutionModeBar incompatibilities={updateInfo.incompatibilities} />
) : (
<SubAgentUpdateAvailableBar
currentVersion={updateInfo.currentVersion}
latestVersion={updateInfo.latestVersion}
isCompatible={updateInfo.isCompatible}
onUpdate={handleUpdateClick}
/>
)}
{/* Incompatibility dialog - rendered here since this component owns the state */}
{updateInfo.incompatibilities && (
<IncompatibleUpdateDialog
isOpen={showIncompatibilityDialog}
onClose={() => setShowIncompatibilityDialog(false)}
onConfirm={handleConfirmIncompatibleUpdate}
currentVersion={updateInfo.currentVersion}
latestVersion={updateInfo.latestVersion}
agentName={beautifyString(agentName)}
incompatibilities={updateInfo.incompatibilities}
/>
)}
</>
);
}
type SubAgentUpdateAvailableBarProps = {
currentVersion: number;
latestVersion: number;
isCompatible: boolean;
onUpdate: () => void;
};
function SubAgentUpdateAvailableBar({
currentVersion,
latestVersion,
isCompatible,
onUpdate,
}: SubAgentUpdateAvailableBarProps): React.ReactElement {
return (
<div className="flex items-center justify-between gap-2 rounded-t-xl bg-blue-50 px-3 py-2 dark:bg-blue-900/30">
<div className="flex items-center gap-2">
<ArrowUpIcon className="h-4 w-4 text-blue-600 dark:text-blue-400" />
<span className="text-sm text-blue-700 dark:text-blue-300">
Update available (v{currentVersion} → v{latestVersion})
</span>
{!isCompatible && (
<Tooltip>
<TooltipTrigger asChild>
<WarningIcon className="h-4 w-4 text-amber-500" />
</TooltipTrigger>
<TooltipContent className="max-w-xs">
<p className="font-medium">Incompatible changes detected</p>
<p className="text-xs text-gray-400">
Click Update to see details
</p>
</TooltipContent>
</Tooltip>
)}
</div>
<Button
size="small"
variant={isCompatible ? "primary" : "outline"}
onClick={onUpdate}
className={cn(
"h-7 text-xs",
!isCompatible && "border-amber-500 text-amber-600 hover:bg-amber-50",
)}
>
Update
</Button>
</div>
);
}

View File

@@ -0,0 +1,274 @@
import React from "react";
import {
WarningIcon,
XCircleIcon,
PlusCircleIcon,
} from "@phosphor-icons/react";
import { Button } from "@/components/atoms/Button/Button";
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { beautifyString } from "@/lib/utils";
import { IncompatibilityInfo } from "@/app/(platform)/build/hooks/useSubAgentUpdate/types";
type IncompatibleUpdateDialogProps = {
isOpen: boolean;
onClose: () => void;
onConfirm: () => void;
currentVersion: number;
latestVersion: number;
agentName: string;
incompatibilities: IncompatibilityInfo;
};
export function IncompatibleUpdateDialog({
isOpen,
onClose,
onConfirm,
currentVersion,
latestVersion,
agentName,
incompatibilities,
}: IncompatibleUpdateDialogProps) {
const hasMissingInputs = incompatibilities.missingInputs.length > 0;
const hasMissingOutputs = incompatibilities.missingOutputs.length > 0;
const hasNewInputs = incompatibilities.newInputs.length > 0;
const hasNewOutputs = incompatibilities.newOutputs.length > 0;
const hasNewRequired = incompatibilities.newRequiredInputs.length > 0;
const hasTypeMismatches = incompatibilities.inputTypeMismatches.length > 0;
const hasInputChanges = hasMissingInputs || hasNewInputs;
const hasOutputChanges = hasMissingOutputs || hasNewOutputs;
return (
<Dialog
title={
<div className="flex items-center gap-2">
<WarningIcon className="h-5 w-5 text-amber-500" weight="fill" />
Incompatible Update
</div>
}
controlled={{
isOpen,
set: async (open) => {
if (!open) onClose();
},
}}
onClose={onClose}
styling={{ maxWidth: "32rem" }}
>
<Dialog.Content>
<div className="space-y-4">
<p className="text-sm text-gray-600 dark:text-gray-400">
Updating <strong>{beautifyString(agentName)}</strong> from v
{currentVersion} to v{latestVersion} will break some connections.
</p>
{/* Input changes - two column layout */}
{hasInputChanges && (
<TwoColumnSection
title="Input Changes"
leftIcon={
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
}
leftTitle="Removed"
leftItems={incompatibilities.missingInputs}
rightIcon={
<PlusCircleIcon
className="h-4 w-4 text-green-500"
weight="fill"
/>
}
rightTitle="Added"
rightItems={incompatibilities.newInputs}
/>
)}
{/* Output changes - two column layout */}
{hasOutputChanges && (
<TwoColumnSection
title="Output Changes"
leftIcon={
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
}
leftTitle="Removed"
leftItems={incompatibilities.missingOutputs}
rightIcon={
<PlusCircleIcon
className="h-4 w-4 text-green-500"
weight="fill"
/>
}
rightTitle="Added"
rightItems={incompatibilities.newOutputs}
/>
)}
{hasTypeMismatches && (
<SingleColumnSection
icon={
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
}
title="Type Changed"
description="These connected inputs have a different type:"
items={incompatibilities.inputTypeMismatches.map(
(m) => `${m.name} (${m.oldType} → ${m.newType})`,
)}
/>
)}
{hasNewRequired && (
<SingleColumnSection
icon={
<PlusCircleIcon
className="h-4 w-4 text-amber-500"
weight="fill"
/>
}
title="New Required Inputs"
description="These inputs are now required:"
items={incompatibilities.newRequiredInputs}
/>
)}
<Alert variant="warning">
<AlertDescription>
If you proceed, you&apos;ll need to remove the broken connections
before you can save or run your agent.
</AlertDescription>
</Alert>
<Dialog.Footer>
<Button variant="ghost" size="small" onClick={onClose}>
Cancel
</Button>
<Button
variant="primary"
size="small"
onClick={onConfirm}
className="border-amber-700 bg-amber-600 hover:bg-amber-700"
>
Update Anyway
</Button>
</Dialog.Footer>
</div>
</Dialog.Content>
</Dialog>
);
}
type TwoColumnSectionProps = {
title: string;
leftIcon: React.ReactNode;
leftTitle: string;
leftItems: string[];
rightIcon: React.ReactNode;
rightTitle: string;
rightItems: string[];
};
function TwoColumnSection({
title,
leftIcon,
leftTitle,
leftItems,
rightIcon,
rightTitle,
rightItems,
}: TwoColumnSectionProps) {
return (
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
<span className="font-medium">{title}</span>
<div className="mt-2 grid grid-cols-2 items-start gap-4">
{/* Left column - Breaking changes */}
<div className="min-w-0">
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
{leftIcon}
<span>{leftTitle}</span>
</div>
<ul className="mt-1.5 space-y-1">
{leftItems.length > 0 ? (
leftItems.map((item) => (
<li
key={item}
className="text-sm text-gray-700 dark:text-gray-300"
>
<code className="rounded bg-red-50 px-1 py-0.5 font-mono text-xs text-red-700 dark:bg-red-900/30 dark:text-red-300">
{item}
</code>
</li>
))
) : (
<li className="text-sm italic text-gray-400 dark:text-gray-500">
None
</li>
)}
</ul>
</div>
{/* Right column - Possible solutions */}
<div className="min-w-0">
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
{rightIcon}
<span>{rightTitle}</span>
</div>
<ul className="mt-1.5 space-y-1">
{rightItems.length > 0 ? (
rightItems.map((item) => (
<li
key={item}
className="text-sm text-gray-700 dark:text-gray-300"
>
<code className="rounded bg-green-50 px-1 py-0.5 font-mono text-xs text-green-700 dark:bg-green-900/30 dark:text-green-300">
{item}
</code>
</li>
))
) : (
<li className="text-sm italic text-gray-400 dark:text-gray-500">
None
</li>
)}
</ul>
</div>
</div>
</div>
);
}
type SingleColumnSectionProps = {
icon: React.ReactNode;
title: string;
description: string;
items: string[];
};
function SingleColumnSection({
icon,
title,
description,
items,
}: SingleColumnSectionProps) {
return (
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
<div className="flex items-center gap-2">
{icon}
<span className="font-medium">{title}</span>
</div>
<p className="mt-1 text-sm text-gray-500 dark:text-gray-400">
{description}
</p>
<ul className="mt-2 space-y-1">
{items.map((item) => (
<li
key={item}
className="ml-4 list-disc text-sm text-gray-700 dark:text-gray-300"
>
<code className="rounded bg-gray-100 px-1 py-0.5 font-mono text-xs dark:bg-gray-800">
{item}
</code>
</li>
))}
</ul>
</div>
);
}

View File

@@ -0,0 +1,107 @@
import React from "react";
import { InfoIcon, WarningIcon } from "@phosphor-icons/react";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
import { IncompatibilityInfo } from "@/app/(platform)/build/hooks/useSubAgentUpdate/types";
type ResolutionModeBarProps = {
incompatibilities: IncompatibilityInfo | null;
};
export function ResolutionModeBar({
incompatibilities,
}: ResolutionModeBarProps): React.ReactElement {
const renderIncompatibilities = () => {
if (!incompatibilities) return <span>No incompatibilities</span>;
const sections: React.ReactNode[] = [];
if (incompatibilities.missingInputs.length > 0) {
sections.push(
<div key="missing-inputs" className="mb-1">
<span className="font-semibold">Missing inputs: </span>
{incompatibilities.missingInputs.map((name, i) => (
<React.Fragment key={name}>
<code className="font-mono">{name}</code>
{i < incompatibilities.missingInputs.length - 1 && ", "}
</React.Fragment>
))}
</div>,
);
}
if (incompatibilities.missingOutputs.length > 0) {
sections.push(
<div key="missing-outputs" className="mb-1">
<span className="font-semibold">Missing outputs: </span>
{incompatibilities.missingOutputs.map((name, i) => (
<React.Fragment key={name}>
<code className="font-mono">{name}</code>
{i < incompatibilities.missingOutputs.length - 1 && ", "}
</React.Fragment>
))}
</div>,
);
}
if (incompatibilities.newRequiredInputs.length > 0) {
sections.push(
<div key="new-required" className="mb-1">
<span className="font-semibold">New required inputs: </span>
{incompatibilities.newRequiredInputs.map((name, i) => (
<React.Fragment key={name}>
<code className="font-mono">{name}</code>
{i < incompatibilities.newRequiredInputs.length - 1 && ", "}
</React.Fragment>
))}
</div>,
);
}
if (incompatibilities.inputTypeMismatches.length > 0) {
sections.push(
<div key="type-mismatches" className="mb-1">
<span className="font-semibold">Type changed: </span>
{incompatibilities.inputTypeMismatches.map((m, i) => (
<React.Fragment key={m.name}>
<code className="font-mono">{m.name}</code>
<span className="text-gray-400">
{" "}
({m.oldType} → {m.newType})
</span>
{i < incompatibilities.inputTypeMismatches.length - 1 && ", "}
</React.Fragment>
))}
</div>,
);
}
return <>{sections}</>;
};
return (
<div className="flex items-center justify-between gap-2 rounded-t-xl bg-amber-50 px-3 py-2 dark:bg-amber-900/30">
<div className="flex items-center gap-2">
<WarningIcon className="h-4 w-4 text-amber-600 dark:text-amber-400" />
<span className="text-sm text-amber-700 dark:text-amber-300">
Remove incompatible connections
</span>
<Tooltip>
<TooltipTrigger asChild>
<InfoIcon className="h-4 w-4 cursor-help text-amber-500" />
</TooltipTrigger>
<TooltipContent className="max-w-sm">
<p className="mb-2 font-semibold">Incompatible changes:</p>
<div className="text-xs">{renderIncompatibilities()}</div>
<p className="mt-2 text-xs text-gray-400">
{(incompatibilities?.newRequiredInputs.length ?? 0) > 0
? "Replace / delete"
: "Delete"}{" "}
the red connections to continue
</p>
</TooltipContent>
</Tooltip>
</div>
</div>
);
}

View File

@@ -0,0 +1,194 @@
import { useState, useCallback, useEffect } from "react";
import { useShallow } from "zustand/react/shallow";
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
import {
useNodeStore,
NodeResolutionData,
} from "@/app/(platform)/build/stores/nodeStore";
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
import {
useSubAgentUpdate,
createUpdatedAgentNodeInputs,
getBrokenEdgeIDs,
} from "@/app/(platform)/build/hooks/useSubAgentUpdate";
import { GraphInputSchema, GraphOutputSchema } from "@/lib/autogpt-server-api";
import { CustomNodeData } from "../../CustomNode";
// Stable empty set to avoid creating new references in selectors
const EMPTY_SET: Set<string> = new Set();
type UseSubAgentUpdateParams = {
nodeID: string;
nodeData: CustomNodeData;
};
export function useSubAgentUpdateState({
nodeID,
nodeData,
}: UseSubAgentUpdateParams) {
const [showIncompatibilityDialog, setShowIncompatibilityDialog] =
useState(false);
// Get store actions
const updateNodeData = useNodeStore(
useShallow((state) => state.updateNodeData),
);
const setNodeResolutionMode = useNodeStore(
useShallow((state) => state.setNodeResolutionMode),
);
const isNodeInResolutionMode = useNodeStore(
useShallow((state) => state.isNodeInResolutionMode),
);
const setBrokenEdgeIDs = useNodeStore(
useShallow((state) => state.setBrokenEdgeIDs),
);
// Get this node's broken edge IDs from the per-node map
// Use EMPTY_SET as fallback to maintain referential stability
const brokenEdgeIDs = useNodeStore(
(state) => state.brokenEdgeIDs.get(nodeID) || EMPTY_SET,
);
const getNodeResolutionData = useNodeStore(
useShallow((state) => state.getNodeResolutionData),
);
const connectedEdges = useEdgeStore(
useShallow((state) => state.getNodeEdges(nodeID)),
);
const availableSubGraphs = useGraphStore(
useShallow((state) => state.availableSubGraphs),
);
// Extract agent-specific data
const graphID = nodeData.hardcodedValues?.graph_id as string | undefined;
const graphVersion = nodeData.hardcodedValues?.graph_version as
| number
| undefined;
const currentInputSchema = nodeData.hardcodedValues?.input_schema as
| GraphInputSchema
| undefined;
const currentOutputSchema = nodeData.hardcodedValues?.output_schema as
| GraphOutputSchema
| undefined;
// Use the sub-agent update hook
const updateInfo = useSubAgentUpdate(
nodeID,
graphID,
graphVersion,
currentInputSchema,
currentOutputSchema,
connectedEdges,
availableSubGraphs,
);
const isInResolutionMode = isNodeInResolutionMode(nodeID);
// Handle update button click
const handleUpdateClick = useCallback(() => {
if (!updateInfo.hasUpdate || !updateInfo.latestGraph) return;
if (updateInfo.isCompatible) {
// Compatible update - apply directly
const newHardcodedValues = createUpdatedAgentNodeInputs(
nodeData.hardcodedValues,
updateInfo.latestGraph,
);
updateNodeData(nodeID, { hardcodedValues: newHardcodedValues });
} else {
// Incompatible update - show dialog
setShowIncompatibilityDialog(true);
}
}, [
updateInfo.hasUpdate,
updateInfo.latestGraph,
updateInfo.isCompatible,
nodeData.hardcodedValues,
updateNodeData,
nodeID,
]);
// Handle confirming an incompatible update
function handleConfirmIncompatibleUpdate() {
if (!updateInfo.latestGraph || !updateInfo.incompatibilities) return;
const latestGraph = updateInfo.latestGraph;
// Get the new schemas from the latest graph version
const newInputSchema =
(latestGraph.input_schema as Record<string, unknown>) || {};
const newOutputSchema =
(latestGraph.output_schema as Record<string, unknown>) || {};
// Create the updated hardcoded values but DON'T apply them yet
// We'll apply them when resolution is complete
const pendingHardcodedValues = createUpdatedAgentNodeInputs(
nodeData.hardcodedValues,
latestGraph,
);
// Get broken edge IDs and store them for this node
const brokenIds = getBrokenEdgeIDs(
connectedEdges,
updateInfo.incompatibilities,
nodeID,
);
setBrokenEdgeIDs(nodeID, brokenIds);
// Enter resolution mode with both old and new schemas
// DON'T apply the update yet - keep old schema so connections remain visible
const resolutionData: NodeResolutionData = {
incompatibilities: updateInfo.incompatibilities,
pendingUpdate: {
input_schema: newInputSchema,
output_schema: newOutputSchema,
},
currentSchema: {
input_schema: (currentInputSchema as Record<string, unknown>) || {},
output_schema: (currentOutputSchema as Record<string, unknown>) || {},
},
pendingHardcodedValues,
};
setNodeResolutionMode(nodeID, true, resolutionData);
setShowIncompatibilityDialog(false);
}
// Check if resolution is complete (all broken edges removed)
const resolutionData = getNodeResolutionData(nodeID);
// Auto-check resolution on edge changes
useEffect(() => {
if (!isInResolutionMode) return;
// Check if any broken edges still exist
const remainingBroken = Array.from(brokenEdgeIDs).filter((edgeId) =>
connectedEdges.some((e) => e.id === edgeId),
);
if (remainingBroken.length === 0) {
// Resolution complete - now apply the pending update
if (resolutionData?.pendingHardcodedValues) {
updateNodeData(nodeID, {
hardcodedValues: resolutionData.pendingHardcodedValues,
});
}
// setNodeResolutionMode will clean up this node's broken edges automatically
setNodeResolutionMode(nodeID, false);
}
}, [
isInResolutionMode,
brokenEdgeIDs,
connectedEdges,
resolutionData,
nodeID,
]);
return {
updateInfo,
isInResolutionMode,
resolutionData,
showIncompatibilityDialog,
setShowIncompatibilityDialog,
handleUpdateClick,
handleConfirmIncompatibleUpdate,
};
}

View File

@@ -1,4 +1,6 @@
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import { NodeResolutionData } from "@/app/(platform)/build/stores/nodeStore";
import { RJSFSchema } from "@rjsf/utils";
export const nodeStyleBasedOnStatus: Record<AgentExecutionStatus, string> = {
INCOMPLETE: "ring-slate-300 bg-slate-300",
@@ -9,3 +11,48 @@ export const nodeStyleBasedOnStatus: Record<AgentExecutionStatus, string> = {
TERMINATED: "ring-orange-300 bg-orange-300 ",
FAILED: "ring-red-300 bg-red-300",
};
/**
* Merges schemas during resolution mode to include removed inputs/outputs
* that still have connections, so users can see and delete them.
*/
export function mergeSchemaForResolution(
currentSchema: Record<string, unknown>,
newSchema: Record<string, unknown>,
resolutionData: NodeResolutionData,
type: "input" | "output",
): Record<string, unknown> {
const newProps = (newSchema.properties as RJSFSchema) || {};
const currentProps = (currentSchema.properties as RJSFSchema) || {};
const mergedProps = { ...newProps };
const incomp = resolutionData.incompatibilities;
if (type === "input") {
// Add back missing inputs that have connections
incomp.missingInputs.forEach((inputName: string) => {
if (currentProps[inputName]) {
mergedProps[inputName] = currentProps[inputName];
}
});
// Add back inputs with type mismatches (keep old type so connection works visually)
incomp.inputTypeMismatches.forEach(
(mismatch: { name: string; oldType: string; newType: string }) => {
if (currentProps[mismatch.name]) {
mergedProps[mismatch.name] = currentProps[mismatch.name];
}
},
);
} else {
// Add back missing outputs that have connections
incomp.missingOutputs.forEach((outputName: string) => {
if (currentProps[outputName]) {
mergedProps[outputName] = currentProps[outputName];
}
});
}
return {
...newSchema,
properties: mergedProps,
};
}

View File

@@ -0,0 +1,58 @@
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import { CustomNodeData } from "./CustomNode";
import { BlockUIType } from "../../../types";
import { useMemo } from "react";
import { mergeSchemaForResolution } from "./helpers";
export const useCustomNode = ({
data,
nodeId,
}: {
data: CustomNodeData;
nodeId: string;
}) => {
const isInResolutionMode = useNodeStore((state) =>
state.nodesInResolutionMode.has(nodeId),
);
const resolutionData = useNodeStore((state) =>
state.nodeResolutionData.get(nodeId),
);
const isAgent = data.uiType === BlockUIType.AGENT;
const currentInputSchema = isAgent
? (data.hardcodedValues.input_schema ?? {})
: data.inputSchema;
const currentOutputSchema = isAgent
? (data.hardcodedValues.output_schema ?? {})
: data.outputSchema;
const inputSchema = useMemo(() => {
if (isAgent && isInResolutionMode && resolutionData) {
return mergeSchemaForResolution(
resolutionData.currentSchema.input_schema,
resolutionData.pendingUpdate.input_schema,
resolutionData,
"input",
);
}
return currentInputSchema;
}, [isAgent, isInResolutionMode, resolutionData, currentInputSchema]);
const outputSchema = useMemo(() => {
if (isAgent && isInResolutionMode && resolutionData) {
return mergeSchemaForResolution(
resolutionData.currentSchema.output_schema,
resolutionData.pendingUpdate.output_schema,
resolutionData,
"output",
);
}
return currentOutputSchema;
}, [isAgent, isInResolutionMode, resolutionData, currentOutputSchema]);
return {
inputSchema,
outputSchema,
};
};

View File

@@ -5,20 +5,16 @@ import { useNodeStore } from "../../../stores/nodeStore";
import { BlockUIType } from "../../types";
import { FormRenderer } from "@/components/renderers/InputRenderer/FormRenderer";
export const FormCreator = React.memo(
interface FormCreatorProps {
({
jsonSchema,
nodeId,
uiType,
showHandles = true,
className,
}: {
jsonSchema: RJSFSchema;
nodeId: string;
uiType: BlockUIType;
showHandles?: boolean;
className?: string;
}) => {
}
export const FormCreator: React.FC<FormCreatorProps> = React.memo(
({ jsonSchema, nodeId, uiType, showHandles = true, className }) => {
const updateNodeData = useNodeStore((state) => state.updateNodeData);
const getHardCodedValues = useNodeStore(

View File

@@ -14,6 +14,8 @@ import {
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
import { getTypeDisplayInfo } from "./helpers";
import { BlockUIType } from "../../types";
import { cn } from "@/lib/utils";
import { useBrokenOutputs } from "./useBrokenOutputs";
export const OutputHandler = ({
outputSchema,
@@ -27,6 +29,9 @@ export const OutputHandler = ({
const { isOutputConnected } = useEdgeStore();
const properties = outputSchema?.properties || {};
const [isOutputVisible, setIsOutputVisible] = useState(true);
const brokenOutputs = useBrokenOutputs(nodeId);
console.log("brokenOutputs", brokenOutputs);
const showHandles = uiType !== BlockUIType.OUTPUT; const showHandles = uiType !== BlockUIType.OUTPUT;
@@ -44,6 +49,7 @@ export const OutputHandler = ({
const shouldShow = isConnected || isOutputVisible;
const { displayType, colorClass, hexColor } =
getTypeDisplayInfo(fieldSchema);
const isBroken = brokenOutputs.has(fullKey);
return shouldShow ? (
<div key={fullKey} className="flex flex-col items-end gap-2">
@@ -64,15 +70,29 @@ export const OutputHandler = ({
</Tooltip>
</TooltipProvider>
)}
<Text variant="body" className="text-slate-700">
<Text
variant="body"
className={cn(
"text-slate-700",
isBroken && "text-red-500 line-through",
)}
>
{fieldTitle}
</Text>
<Text variant="small" as="span" className={colorClass}>
<Text
variant="small"
as="span"
className={cn(
colorClass,
isBroken && "!text-red-500 line-through",
)}
>
({displayType})
</Text>
{showHandles && (
<OutputNodeHandle
isBroken={isBroken}
field_name={fullKey}
nodeId={nodeId}
hexColor={hexColor}

View File

@@ -0,0 +1,23 @@
import { useMemo } from "react";
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
/**
* Hook to get the set of broken output names for a node in resolution mode.
*/
export function useBrokenOutputs(nodeID: string): Set<string> {
// Subscribe to the actual state values, not just methods
const isInResolution = useNodeStore((state) =>
state.nodesInResolutionMode.has(nodeID),
);
const resolutionData = useNodeStore((state) =>
state.nodeResolutionData.get(nodeID),
);
return useMemo(() => {
if (!isInResolution || !resolutionData) {
return new Set<string>();
}
return new Set(resolutionData.incompatibilities.missingOutputs);
}, [isInResolution, resolutionData]);
}

View File

@@ -25,7 +25,7 @@ export const RightSidebar = () => {
>
<div className="mb-4">
<h2 className="text-lg font-semibold text-slate-800 dark:text-slate-200">
Flow Debug Panel
Graph Debug Panel
</h2>
</div>
@@ -65,7 +65,7 @@ export const RightSidebar = () => {
{l.source_id}[{l.source_name}] → {l.sink_id}[{l.sink_name}]
</div>
<div className="mt-1 text-slate-500 dark:text-slate-400">
edge_id: {l.id}
edge.id: {l.id}
</div>
</div>
))}

View File

@@ -12,7 +12,14 @@ import {
PopoverContent,
PopoverTrigger,
} from "@/components/__legacy__/ui/popover";
import { Block, BlockUIType, SpecialBlockID } from "@/lib/autogpt-server-api";
import {
Block,
BlockIORootSchema,
BlockUIType,
GraphInputSchema,
GraphOutputSchema,
SpecialBlockID,
} from "@/lib/autogpt-server-api";
import { MagnifyingGlassIcon, PlusIcon } from "@radix-ui/react-icons";
import { IconToyBrick } from "@/components/__legacy__/ui/icons";
import { getPrimaryCategoryColor } from "@/lib/utils";
@@ -24,8 +31,10 @@ import {
import { GraphMeta } from "@/lib/autogpt-server-api";
import jaro from "jaro-winkler";
type _Block = Block & {
type _Block = Omit<Block, "inputSchema" | "outputSchema"> & {
uiKey?: string;
inputSchema: BlockIORootSchema | GraphInputSchema;
outputSchema: BlockIORootSchema | GraphOutputSchema;
hardcodedValues?: Record<string, any>;
_cached?: {
blockName: string;

View File

@@ -2,7 +2,7 @@ import React from "react";
import { cn } from "@/lib/utils";
import { Button } from "@/components/__legacy__/ui/button";
import { LogOut } from "lucide-react";
import { ClockIcon } from "@phosphor-icons/react";
import { ClockIcon, WarningIcon } from "@phosphor-icons/react";
import { IconPlay, IconSquare } from "@/components/__legacy__/ui/icons";
interface Props {
@@ -13,6 +13,7 @@ interface Props {
isRunning: boolean;
isDisabled: boolean;
className?: string;
resolutionModeActive?: boolean;
}
export const BuildActionBar: React.FC<Props> = ({
@@ -23,9 +24,30 @@ export const BuildActionBar: React.FC<Props> = ({
isRunning,
isDisabled,
className,
resolutionModeActive = false,
}) => {
const buttonClasses =
"flex items-center gap-2 text-sm font-medium md:text-lg";
// Show resolution mode message instead of action buttons
if (resolutionModeActive) {
return (
<div
className={cn(
"flex w-fit select-none items-center justify-center p-4",
className,
)}
>
<div className="flex items-center gap-3 rounded-lg border border-amber-300 bg-amber-50 px-4 py-3 dark:border-amber-700 dark:bg-amber-900/30">
<WarningIcon className="size-5 text-amber-600 dark:text-amber-400" />
<span className="text-sm font-medium text-amber-800 dark:text-amber-200">
Remove incompatible connections to continue
</span>
</div>
</div>
);
}
return (
<div
className={cn(

View File

@@ -60,10 +60,16 @@ export function CustomEdge({
targetY - 5,
);
const { deleteElements } = useReactFlow<Node, CustomEdge>();
const { visualizeBeads } = useContext(BuilderContext) ?? {
const builderContext = useContext(BuilderContext);
const { visualizeBeads } = builderContext ?? {
visualizeBeads: "no",
};
// Check if this edge is broken (during resolution mode)
const isBroken =
builderContext?.resolutionMode?.active &&
builderContext?.resolutionMode?.brokenEdgeIds?.includes(id);
const onEdgeRemoveClick = () => {
deleteElements({ edges: [{ id }] });
};
@@ -171,12 +177,27 @@ export function CustomEdge({
const middle = getPointForT(0.5);
// Determine edge color - red for broken edges
const baseColor = data?.edgeColor ?? "#555555";
const edgeColor = isBroken ? "#ef4444" : baseColor;
// Add opacity to hex color (99 = 60% opacity, 80 = 50% opacity)
const strokeColor = isBroken
? `${edgeColor}99`
: selected
? edgeColor
: `${edgeColor}80`;
return (
<>
<BaseEdge
path={svgPath}
markerEnd={markerEnd}
className={`data-sentry-unmask transition-all duration-200 ${data?.isStatic ? "[stroke-dasharray:5_3]" : "[stroke-dasharray:0]"} [stroke-width:${data?.isStatic ? 2.5 : 2}px] hover:[stroke-width:${data?.isStatic ? 3.5 : 3}px] ${selected ? `[stroke:${data?.edgeColor ?? "#555555"}]` : `[stroke:${data?.edgeColor ?? "#555555"}80] hover:[stroke:${data?.edgeColor ?? "#555555"}]`}`}
style={{
stroke: strokeColor,
strokeWidth: data?.isStatic ? 2.5 : 2,
strokeDasharray: data?.isStatic ? "5 3" : undefined,
}}
className="data-sentry-unmask transition-all duration-200"
/>
<path
d={svgPath}

View File

@@ -18,6 +18,8 @@ import {
BlockIOSubSchema,
BlockUIType,
Category,
GraphInputSchema,
GraphOutputSchema,
NodeExecutionResult,
} from "@/lib/autogpt-server-api";
import {
@@ -62,14 +64,21 @@ import { NodeGenericInputField, NodeTextBoxInput } from "../NodeInputs";
import NodeOutputs from "../NodeOutputs";
import OutputModalComponent from "../OutputModalComponent";
import "./customnode.css";
import { SubAgentUpdateBar } from "./SubAgentUpdateBar";
import { IncompatibilityDialog } from "./IncompatibilityDialog";
import {
useSubAgentUpdate,
createUpdatedAgentNodeInputs,
getBrokenEdgeIDs,
} from "../../../hooks/useSubAgentUpdate";
export type ConnectionData = Array<{
export type ConnectedEdge = {
edge_id: string;
id: string;
source: string;
sourceHandle: string;
target: string;
targetHandle: string;
}>;
};
export type CustomNodeData = {
blockType: string;
@@ -80,7 +89,7 @@ export type CustomNodeData = {
inputSchema: BlockIORootSchema;
outputSchema: BlockIORootSchema;
hardcodedValues: { [key: string]: any };
connections: ConnectionData;
connections: ConnectedEdge[];
isOutputOpen: boolean;
status?: NodeExecutionResult["status"];
/** executionResults contains outputs across multiple executions
@@ -127,20 +136,199 @@ export const CustomNode = React.memo(
let subGraphID = "";
if (data.uiType === BlockUIType.AGENT) {
// Display the graph's schema instead AgentExecutorBlock's schema.
data.inputSchema = data.hardcodedValues?.input_schema || {};
data.outputSchema = data.hardcodedValues?.output_schema || {};
subGraphID = data.hardcodedValues?.graph_id || subGraphID;
}
if (!builderContext) {
throw new Error(
"BuilderContext consumer must be inside FlowEditor component",
);
}
const { libraryAgent, setIsAnyModalOpen, getNextNodeId } = builderContext;
const {
libraryAgent,
setIsAnyModalOpen,
getNextNodeId,
availableFlows,
resolutionMode,
enterResolutionMode,
} = builderContext;
// Check if this node is in resolution mode (moved up for schema merge logic)
const isInResolutionMode =
resolutionMode.active && resolutionMode.nodeId === id;
if (data.uiType === BlockUIType.AGENT) {
// Display the graph's schema instead of AgentExecutorBlock's schema.
const currentInputSchema = data.hardcodedValues?.input_schema || {};
const currentOutputSchema = data.hardcodedValues?.output_schema || {};
subGraphID = data.hardcodedValues?.graph_id || subGraphID;
// During resolution mode, merge old connected inputs/outputs with new schema
if (isInResolutionMode && resolutionMode.pendingUpdate) {
const newInputSchema =
(resolutionMode.pendingUpdate.input_schema as BlockIORootSchema) ||
{};
const newOutputSchema =
(resolutionMode.pendingUpdate.output_schema as BlockIORootSchema) ||
{};
// Merge input schemas: start with new schema, add old connected inputs that are missing
const mergedInputProps = { ...newInputSchema.properties };
const incomp = resolutionMode.incompatibilities;
if (incomp && currentInputSchema.properties) {
// Add back missing inputs that have connections (so user can see/delete them)
incomp.missingInputs.forEach((inputName) => {
if (currentInputSchema.properties[inputName]) {
mergedInputProps[inputName] =
currentInputSchema.properties[inputName];
}
});
// Add back inputs with type mismatches (keep old type so connection still works visually)
incomp.inputTypeMismatches.forEach((mismatch) => {
if (currentInputSchema.properties[mismatch.name]) {
mergedInputProps[mismatch.name] =
currentInputSchema.properties[mismatch.name];
}
});
}
// Merge output schemas: start with new schema, add old connected outputs that are missing
const mergedOutputProps = { ...newOutputSchema.properties };
if (incomp && currentOutputSchema.properties) {
incomp.missingOutputs.forEach((outputName) => {
if (currentOutputSchema.properties[outputName]) {
mergedOutputProps[outputName] =
currentOutputSchema.properties[outputName];
}
});
}
data.inputSchema = {
...newInputSchema,
properties: mergedInputProps,
};
data.outputSchema = {
...newOutputSchema,
properties: mergedOutputProps,
};
} else {
data.inputSchema = currentInputSchema;
data.outputSchema = currentOutputSchema;
}
}
const setHardcodedValues = useCallback(
(values: any) => {
updateNodeData(id, { hardcodedValues: values });
},
[id, updateNodeData],
);
// Sub-agent update detection
const isAgentBlock = data.uiType === BlockUIType.AGENT;
const graphId = isAgentBlock ? data.hardcodedValues?.graph_id : undefined;
const graphVersion = isAgentBlock
? data.hardcodedValues?.graph_version
: undefined;
const subAgentUpdate = useSubAgentUpdate(
id,
graphId,
graphVersion,
isAgentBlock
? (data.hardcodedValues?.input_schema as GraphInputSchema)
: undefined,
isAgentBlock
? (data.hardcodedValues?.output_schema as GraphOutputSchema)
: undefined,
data.connections,
availableFlows,
);
const [showIncompatibilityDialog, setShowIncompatibilityDialog] =
useState(false);
// Helper to check if a handle is broken (for resolution mode)
const isInputHandleBroken = useCallback(
(handleName: string): boolean => {
if (!isInResolutionMode || !resolutionMode.incompatibilities) {
return false;
}
const incomp = resolutionMode.incompatibilities;
return (
incomp.missingInputs.includes(handleName) ||
incomp.inputTypeMismatches.some((m) => m.name === handleName)
);
},
[isInResolutionMode, resolutionMode.incompatibilities],
);
const isOutputHandleBroken = useCallback(
(handleName: string): boolean => {
if (!isInResolutionMode || !resolutionMode.incompatibilities) {
return false;
}
return resolutionMode.incompatibilities.missingOutputs.includes(
handleName,
);
},
[isInResolutionMode, resolutionMode.incompatibilities],
);
// Handle update button click
const handleUpdateClick = useCallback(() => {
if (!subAgentUpdate.latestGraph) return;
if (subAgentUpdate.isCompatible) {
// Compatible update - directly apply
const updatedValues = createUpdatedAgentNodeInputs(
data.hardcodedValues,
subAgentUpdate.latestGraph,
);
setHardcodedValues(updatedValues);
toast({
title: "Agent updated",
description: `Updated to version ${subAgentUpdate.latestVersion}`,
});
} else {
// Incompatible update - show dialog
setShowIncompatibilityDialog(true);
}
}, [subAgentUpdate, data.hardcodedValues, setHardcodedValues]);
// Handle confirm incompatible update
const handleConfirmIncompatibleUpdate = useCallback(() => {
if (!subAgentUpdate.latestGraph || !subAgentUpdate.incompatibilities) {
return;
}
// Create the updated values but DON'T apply them yet
const updatedValues = createUpdatedAgentNodeInputs(
data.hardcodedValues,
subAgentUpdate.latestGraph,
);
// Get broken edge IDs
const brokenEdgeIds = getBrokenEdgeIDs(
data.connections,
subAgentUpdate.incompatibilities,
id,
);
// Enter resolution mode with pending update (don't apply schema yet)
enterResolutionMode(
id,
subAgentUpdate.incompatibilities,
brokenEdgeIds,
updatedValues,
);
setShowIncompatibilityDialog(false);
}, [
subAgentUpdate,
data.hardcodedValues,
data.connections,
id,
enterResolutionMode,
]);
     useEffect(() => {
       if (data.executionResults || data.status) {
@@ -156,13 +344,6 @@ export const CustomNode = React.memo(
       setIsAnyModalOpen?.(isModalOpen || isOutputModalOpen);
     }, [isModalOpen, isOutputModalOpen, data, setIsAnyModalOpen]);

-    const setHardcodedValues = useCallback(
-      (values: any) => {
-        updateNodeData(id, { hardcodedValues: values });
-      },
-      [id, updateNodeData],
-    );
-
     const handleTitleEdit = useCallback(() => {
       setIsEditingTitle(true);
       setTimeout(() => {
@@ -255,6 +436,7 @@ export const CustomNode = React.memo(
                     isConnected={isOutputHandleConnected(propKey)}
                     schema={fieldSchema}
                     side="right"
+                    isBroken={isOutputHandleBroken(propKey)}
                   />
                 {"properties" in fieldSchema &&
                   renderHandles(
@@ -385,6 +567,7 @@ export const CustomNode = React.memo(
                 isRequired={isRequired}
                 schema={propSchema}
                 side="left"
+                isBroken={isInputHandleBroken(propKey)}
               />
             ) : (
               propKey !== "credentials" &&
@@ -873,6 +1056,22 @@ export const CustomNode = React.memo(
           <ContextMenuContent />
         </div>
+        {/* Sub-agent Update Bar - shown below header */}
+        {isAgentBlock && (subAgentUpdate.hasUpdate || isInResolutionMode) && (
+          <SubAgentUpdateBar
+            currentVersion={subAgentUpdate.currentVersion}
+            latestVersion={subAgentUpdate.latestVersion}
+            isCompatible={subAgentUpdate.isCompatible}
+            incompatibilities={
+              isInResolutionMode
+                ? resolutionMode.incompatibilities
+                : subAgentUpdate.incompatibilities
+            }
+            onUpdate={handleUpdateClick}
+            isInResolutionMode={isInResolutionMode}
+          />
+        )}

         {/* Body */}
         <div className="mx-5 my-6 rounded-b-xl">
           {/* Input Handles */}
@@ -1044,9 +1243,24 @@ export const CustomNode = React.memo(
     );

     return (
+      <>
         <ContextMenu.Root>
           <ContextMenu.Trigger>{nodeContent()}</ContextMenu.Trigger>
         </ContextMenu.Root>
+        {/* Incompatibility Dialog for sub-agent updates */}
+        {isAgentBlock && subAgentUpdate.incompatibilities && (
+          <IncompatibilityDialog
+            isOpen={showIncompatibilityDialog}
+            onClose={() => setShowIncompatibilityDialog(false)}
+            onConfirm={handleConfirmIncompatibleUpdate}
+            currentVersion={subAgentUpdate.currentVersion}
+            latestVersion={subAgentUpdate.latestVersion}
+            agentName={data.blockType || "Agent"}
+            incompatibilities={subAgentUpdate.incompatibilities}
+          />
+        )}
+      </>
     );
   },
   (prevProps, nextProps) => {
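
The subtle part of the change above is the resolution-mode schema merge: while the user resolves an incompatible update, the node keeps rendering removed-but-connected inputs with their old schema entries so the broken connections stay visible and deletable. A minimal standalone sketch of that behavior, using simplified stand-in schema shapes rather than the real `BlockIORootSchema`:

```typescript
// Simplified stand-in for BlockIORootSchema; illustrative only.
type SchemaSketch = { properties: Record<string, { type: string }> };

// Old schema: "query" is connected but was removed in the new version.
const oldSchema: SchemaSketch = {
  properties: { query: { type: "string" }, limit: { type: "integer" } },
};
const newSchema: SchemaSketch = {
  properties: { limit: { type: "integer" }, offset: { type: "integer" } },
};
const missingConnectedInputs = ["query"];

// Start from the new schema, then re-add connected-but-missing inputs
// with their old definitions so the broken handles stay visible.
const merged: SchemaSketch = { properties: { ...newSchema.properties } };
for (const name of missingConnectedInputs) {
  if (oldSchema.properties[name]) {
    merged.properties[name] = oldSchema.properties[name];
  }
}

console.log(Object.keys(merged.properties)); // ["limit", "offset", "query"]
```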

View File

@@ -0,0 +1,244 @@
import React from "react";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from "@/components/__legacy__/ui/dialog";
import { Button } from "@/components/__legacy__/ui/button";
import { AlertTriangle, XCircle, PlusCircle } from "lucide-react";
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
import { beautifyString } from "@/lib/utils";
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
interface IncompatibilityDialogProps {
isOpen: boolean;
onClose: () => void;
onConfirm: () => void;
currentVersion: number;
latestVersion: number;
agentName: string;
incompatibilities: IncompatibilityInfo;
}
export const IncompatibilityDialog: React.FC<IncompatibilityDialogProps> = ({
isOpen,
onClose,
onConfirm,
currentVersion,
latestVersion,
agentName,
incompatibilities,
}) => {
const hasMissingInputs = incompatibilities.missingInputs.length > 0;
const hasMissingOutputs = incompatibilities.missingOutputs.length > 0;
const hasNewInputs = incompatibilities.newInputs.length > 0;
const hasNewOutputs = incompatibilities.newOutputs.length > 0;
const hasNewRequired = incompatibilities.newRequiredInputs.length > 0;
const hasTypeMismatches = incompatibilities.inputTypeMismatches.length > 0;
const hasInputChanges = hasMissingInputs || hasNewInputs;
const hasOutputChanges = hasMissingOutputs || hasNewOutputs;
return (
<Dialog open={isOpen} onOpenChange={(open) => !open && onClose()}>
<DialogContent className="max-w-lg">
<DialogHeader>
<DialogTitle className="flex items-center gap-2">
<AlertTriangle className="h-5 w-5 text-amber-500" />
Incompatible Update
</DialogTitle>
<DialogDescription>
Updating <strong>{beautifyString(agentName)}</strong> from v
{currentVersion} to v{latestVersion} will break some connections.
</DialogDescription>
</DialogHeader>
<div className="space-y-4 py-2">
{/* Input changes - two column layout */}
{hasInputChanges && (
<TwoColumnSection
title="Input Changes"
leftIcon={<XCircle className="h-4 w-4 text-red-500" />}
leftTitle="Removed"
leftItems={incompatibilities.missingInputs}
rightIcon={<PlusCircle className="h-4 w-4 text-green-500" />}
rightTitle="Added"
rightItems={incompatibilities.newInputs}
/>
)}
{/* Output changes - two column layout */}
{hasOutputChanges && (
<TwoColumnSection
title="Output Changes"
leftIcon={<XCircle className="h-4 w-4 text-red-500" />}
leftTitle="Removed"
leftItems={incompatibilities.missingOutputs}
rightIcon={<PlusCircle className="h-4 w-4 text-green-500" />}
rightTitle="Added"
rightItems={incompatibilities.newOutputs}
/>
)}
{hasTypeMismatches && (
<SingleColumnSection
icon={<XCircle className="h-4 w-4 text-red-500" />}
title="Type Changed"
description="These connected inputs have a different type:"
items={incompatibilities.inputTypeMismatches.map(
(m) => `${m.name} (${m.oldType} → ${m.newType})`,
)}
/>
)}
{hasNewRequired && (
<SingleColumnSection
icon={<PlusCircle className="h-4 w-4 text-amber-500" />}
title="New Required Inputs"
description="These inputs are now required:"
items={incompatibilities.newRequiredInputs}
/>
)}
</div>
<Alert variant="warning">
<AlertDescription>
If you proceed, you&apos;ll need to remove the broken connections
before you can save or run your agent.
</AlertDescription>
</Alert>
<DialogFooter className="gap-2 sm:gap-0">
<Button variant="outline" onClick={onClose}>
Cancel
</Button>
<Button
variant="destructive"
onClick={onConfirm}
className="bg-amber-600 hover:bg-amber-700"
>
Update Anyway
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
);
};
interface TwoColumnSectionProps {
title: string;
leftIcon: React.ReactNode;
leftTitle: string;
leftItems: string[];
rightIcon: React.ReactNode;
rightTitle: string;
rightItems: string[];
}
const TwoColumnSection: React.FC<TwoColumnSectionProps> = ({
title,
leftIcon,
leftTitle,
leftItems,
rightIcon,
rightTitle,
rightItems,
}) => (
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
<span className="font-medium">{title}</span>
<div className="mt-2 grid grid-cols-2 items-start gap-4">
{/* Left column - Breaking changes */}
<div className="min-w-0">
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
{leftIcon}
<span>{leftTitle}</span>
</div>
<ul className="mt-1.5 space-y-1">
{leftItems.length > 0 ? (
leftItems.map((item) => (
<li
key={item}
className="text-sm text-gray-700 dark:text-gray-300"
>
<code className="rounded bg-red-50 px-1 py-0.5 font-mono text-xs text-red-700 dark:bg-red-900/30 dark:text-red-300">
{item}
</code>
</li>
))
) : (
<li className="text-sm italic text-gray-400 dark:text-gray-500">
None
</li>
)}
</ul>
</div>
{/* Right column - Possible solutions */}
<div className="min-w-0">
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
{rightIcon}
<span>{rightTitle}</span>
</div>
<ul className="mt-1.5 space-y-1">
{rightItems.length > 0 ? (
rightItems.map((item) => (
<li
key={item}
className="text-sm text-gray-700 dark:text-gray-300"
>
<code className="rounded bg-green-50 px-1 py-0.5 font-mono text-xs text-green-700 dark:bg-green-900/30 dark:text-green-300">
{item}
</code>
</li>
))
) : (
<li className="text-sm italic text-gray-400 dark:text-gray-500">
None
</li>
)}
</ul>
</div>
</div>
</div>
);
interface SingleColumnSectionProps {
icon: React.ReactNode;
title: string;
description: string;
items: string[];
}
const SingleColumnSection: React.FC<SingleColumnSectionProps> = ({
icon,
title,
description,
items,
}) => (
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
<div className="flex items-center gap-2">
{icon}
<span className="font-medium">{title}</span>
</div>
<p className="mt-1 text-sm text-gray-500 dark:text-gray-400">
{description}
</p>
<ul className="mt-2 space-y-1">
{items.map((item) => (
<li
key={item}
className="ml-4 list-disc text-sm text-gray-700 dark:text-gray-300"
>
<code className="rounded bg-gray-100 px-1 py-0.5 font-mono text-xs dark:bg-gray-800">
{item}
</code>
</li>
))}
</ul>
</div>
);
export default IncompatibilityDialog;

View File

@@ -0,0 +1,130 @@
import React from "react";
import { Button } from "@/components/__legacy__/ui/button";
import { ArrowUp, AlertTriangle, Info } from "lucide-react";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
import { cn } from "@/lib/utils";
interface SubAgentUpdateBarProps {
currentVersion: number;
latestVersion: number;
isCompatible: boolean;
incompatibilities: IncompatibilityInfo | null;
onUpdate: () => void;
isInResolutionMode?: boolean;
}
export const SubAgentUpdateBar: React.FC<SubAgentUpdateBarProps> = ({
currentVersion,
latestVersion,
isCompatible,
incompatibilities,
onUpdate,
isInResolutionMode = false,
}) => {
if (isInResolutionMode) {
return <ResolutionModeBar incompatibilities={incompatibilities} />;
}
return (
<div className="flex items-center justify-between gap-2 rounded-t-lg bg-blue-50 px-3 py-2 dark:bg-blue-900/30">
<div className="flex items-center gap-2">
<ArrowUp className="h-4 w-4 text-blue-600 dark:text-blue-400" />
<span className="text-sm text-blue-700 dark:text-blue-300">
Update available (v{currentVersion} → v{latestVersion})
</span>
{!isCompatible && (
<Tooltip>
<TooltipTrigger asChild>
<AlertTriangle className="h-4 w-4 text-amber-500" />
</TooltipTrigger>
<TooltipContent className="max-w-xs">
<p className="font-medium">Incompatible changes detected</p>
<p className="text-xs text-gray-400">
Click Update to see details
</p>
</TooltipContent>
</Tooltip>
)}
</div>
<Button
size="sm"
variant={isCompatible ? "default" : "outline"}
onClick={onUpdate}
className={cn(
"h-7 text-xs",
!isCompatible && "border-amber-500 text-amber-600 hover:bg-amber-50",
)}
>
Update
</Button>
</div>
);
};
interface ResolutionModeBarProps {
incompatibilities: IncompatibilityInfo | null;
}
const ResolutionModeBar: React.FC<ResolutionModeBarProps> = ({
incompatibilities,
}) => {
const formatIncompatibilities = () => {
if (!incompatibilities) return "No incompatibilities";
const items: string[] = [];
if (incompatibilities.missingInputs.length > 0) {
items.push(
`Missing inputs: ${incompatibilities.missingInputs.join(", ")}`,
);
}
if (incompatibilities.missingOutputs.length > 0) {
items.push(
`Missing outputs: ${incompatibilities.missingOutputs.join(", ")}`,
);
}
if (incompatibilities.newRequiredInputs.length > 0) {
items.push(
`New required inputs: ${incompatibilities.newRequiredInputs.join(", ")}`,
);
}
if (incompatibilities.inputTypeMismatches.length > 0) {
const mismatches = incompatibilities.inputTypeMismatches
.map((m) => `${m.name} (${m.oldType} → ${m.newType})`)
.join(", ");
items.push(`Type changed: ${mismatches}`);
}
return items.join("\n");
};
return (
<div className="flex items-center justify-between gap-2 rounded-t-lg bg-amber-50 px-3 py-2 dark:bg-amber-900/30">
<div className="flex items-center gap-2">
<AlertTriangle className="h-4 w-4 text-amber-600 dark:text-amber-400" />
<span className="text-sm text-amber-700 dark:text-amber-300">
Remove incompatible connections
</span>
<Tooltip>
<TooltipTrigger asChild>
<Info className="h-4 w-4 cursor-help text-amber-500" />
</TooltipTrigger>
<TooltipContent className="max-w-sm whitespace-pre-line">
<p className="font-medium">Incompatible changes:</p>
<p className="mt-1 text-xs">{formatIncompatibilities()}</p>
<p className="mt-2 text-xs text-gray-400">
Delete the red connections to continue
</p>
</TooltipContent>
</Tooltip>
</div>
</div>
);
};
export default SubAgentUpdateBar;

View File

@@ -26,15 +26,17 @@ import {
   applyNodeChanges,
 } from "@xyflow/react";
 import "@xyflow/react/dist/style.css";
-import { CustomNode } from "../CustomNode/CustomNode";
+import { ConnectedEdge, CustomNode } from "../CustomNode/CustomNode";
 import "./flow.css";
 import {
   BlockUIType,
   formatEdgeID,
   GraphExecutionID,
   GraphID,
+  GraphMeta,
   LibraryAgent,
 } from "@/lib/autogpt-server-api";
+import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
 import { Key, storage } from "@/services/storage/local-storage";
 import { findNewlyAddedBlockCoordinates, getTypeColor } from "@/lib/utils";
 import { history } from "../history";
@@ -72,12 +74,30 @@ import { FloatingSafeModeToggle } from "../../FloatingSafeModeToogle";
 // It helps to prevent spamming the history with small movements especially when pressing on a input in a block
 const MINIMUM_MOVE_BEFORE_LOG = 50;

+export type ResolutionModeState = {
+  active: boolean;
+  nodeId: string | null;
+  incompatibilities: IncompatibilityInfo | null;
+  brokenEdgeIds: string[];
+  pendingUpdate: Record<string, unknown> | null; // The hardcoded values to apply after resolution
+};
+
 type BuilderContextType = {
   libraryAgent: LibraryAgent | null;
   visualizeBeads: "no" | "static" | "animate";
   setIsAnyModalOpen: (isOpen: boolean) => void;
   getNextNodeId: () => string;
   getNodeTitle: (nodeID: string) => string | null;
+  availableFlows: GraphMeta[];
+  resolutionMode: ResolutionModeState;
+  enterResolutionMode: (
+    nodeId: string,
+    incompatibilities: IncompatibilityInfo,
+    brokenEdgeIds: string[],
+    pendingUpdate: Record<string, unknown>,
+  ) => void;
+  exitResolutionMode: () => void;
+  applyPendingUpdate: () => void;
 };

 export type NodeDimension = {
@@ -172,6 +192,92 @@ const FlowEditor: React.FC<{
   // It stores the dimension of all nodes with position as well
   const [nodeDimensions, setNodeDimensions] = useState<NodeDimension>({});
+  // Resolution mode state for sub-agent incompatible updates
+  const [resolutionMode, setResolutionMode] = useState<ResolutionModeState>({
+    active: false,
+    nodeId: null,
+    incompatibilities: null,
+    brokenEdgeIds: [],
+    pendingUpdate: null,
+  });
+
+  const enterResolutionMode = useCallback(
+    (
+      nodeId: string,
+      incompatibilities: IncompatibilityInfo,
+      brokenEdgeIds: string[],
+      pendingUpdate: Record<string, unknown>,
+    ) => {
+      setResolutionMode({
+        active: true,
+        nodeId,
+        incompatibilities,
+        brokenEdgeIds,
+        pendingUpdate,
+      });
+    },
+    [],
+  );
+
+  const exitResolutionMode = useCallback(() => {
+    setResolutionMode({
+      active: false,
+      nodeId: null,
+      incompatibilities: null,
+      brokenEdgeIds: [],
+      pendingUpdate: null,
+    });
+  }, []);
+
+  // Apply pending update after resolution mode completes
+  const applyPendingUpdate = useCallback(() => {
+    if (!resolutionMode.nodeId || !resolutionMode.pendingUpdate) return;
+
+    const node = nodes.find((n) => n.id === resolutionMode.nodeId);
+    if (node) {
+      const pendingUpdate = resolutionMode.pendingUpdate as {
+        [key: string]: any;
+      };
+      setNodes((nds) =>
+        nds.map((n) =>
+          n.id === resolutionMode.nodeId
+            ? { ...n, data: { ...n.data, hardcodedValues: pendingUpdate } }
+            : n,
+        ),
+      );
+    }
+    exitResolutionMode();
+    toast({
+      title: "Update complete",
+      description: "Agent has been updated to the new version.",
+    });
+  }, [resolutionMode, nodes, setNodes, exitResolutionMode, toast]);
+
+  // Check if all broken edges have been removed and auto-apply pending update
+  useEffect(() => {
+    if (!resolutionMode.active || resolutionMode.brokenEdgeIds.length === 0) {
+      return;
+    }
+    const currentEdgeIds = new Set(edges.map((e) => e.id));
+    const remainingBrokenEdges = resolutionMode.brokenEdgeIds.filter((id) =>
+      currentEdgeIds.has(id),
+    );
+    if (remainingBrokenEdges.length === 0) {
+      // All broken edges have been removed, apply pending update
+      applyPendingUpdate();
+    } else if (
+      remainingBrokenEdges.length !== resolutionMode.brokenEdgeIds.length
+    ) {
+      // Update the list of broken edges
+      setResolutionMode((prev) => ({
+        ...prev,
+        brokenEdgeIds: remainingBrokenEdges,
+      }));
+    }
+  }, [edges, resolutionMode, applyPendingUpdate]);
   // Set page title with or without graph name
   useEffect(() => {
     document.title = savedAgent
@@ -431,17 +537,19 @@ const FlowEditor: React.FC<{
             ...node.data.connections.filter(
               (conn) =>
                 !removedEdges.some(
-                  (removedEdge) => removedEdge.id === conn.edge_id,
+                  (removedEdge) => removedEdge.id === conn.id,
                 ),
             ),
             // Add node connections for added edges
-            ...addedEdges.map((addedEdge) => ({
-              edge_id: addedEdge.item.id,
+            ...addedEdges.map(
+              (addedEdge): ConnectedEdge => ({
+                id: addedEdge.item.id,
                 source: addedEdge.item.source,
                 target: addedEdge.item.target,
                 sourceHandle: addedEdge.item.sourceHandle!,
                 targetHandle: addedEdge.item.targetHandle!,
-            })),
+              }),
+            ),
           ],
         },
       }));
@@ -467,13 +575,15 @@ const FlowEditor: React.FC<{
       data: {
         ...node.data,
         connections: [
-          ...replaceEdges.map((replaceEdge) => ({
-            edge_id: replaceEdge.item.id,
+          ...replaceEdges.map(
+            (replaceEdge): ConnectedEdge => ({
+              id: replaceEdge.item.id,
               source: replaceEdge.item.source,
               target: replaceEdge.item.target,
               sourceHandle: replaceEdge.item.sourceHandle!,
               targetHandle: replaceEdge.item.targetHandle!,
-          })),
+            }),
+          ),
         ],
       },
     })),
@@ -890,8 +1000,23 @@ const FlowEditor: React.FC<{
       setIsAnyModalOpen,
       getNextNodeId,
       getNodeTitle,
+      availableFlows,
+      resolutionMode,
+      enterResolutionMode,
+      exitResolutionMode,
+      applyPendingUpdate,
     }),
-    [libraryAgent, visualizeBeads, getNextNodeId, getNodeTitle],
+    [
+      libraryAgent,
+      visualizeBeads,
+      getNextNodeId,
+      getNodeTitle,
+      availableFlows,
+      resolutionMode,
+      enterResolutionMode,
+      applyPendingUpdate,
+      exitResolutionMode,
+    ],
   );
   return (
@@ -991,6 +1116,7 @@ const FlowEditor: React.FC<{
               onClickScheduleButton={handleScheduleButton}
               isDisabled={!savedAgent}
               isRunning={isRunning}
+              resolutionModeActive={resolutionMode.active}
             />
           ) : (
             <Alert className="absolute bottom-4 left-1/2 z-20 w-auto -translate-x-1/2 select-none">
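
Taken together, this file implements the resolution flow as a small state machine: `enterResolutionMode` stores the incompatibilities, broken edge IDs, and pending `hardcodedValues`, and the effect above watches the edge list and applies the pending update once every broken edge is gone. A condensed, non-React sketch of that loop for illustration (plain functions instead of React state; all names invented):

```typescript
// Condensed sketch of the resolution loop; types simplified.
type Resolution = { brokenEdgeIds: string[]; pendingUpdate: object } | null;

let resolution: Resolution = null;

function enterResolution(brokenEdgeIds: string[], pendingUpdate: object) {
  resolution = { brokenEdgeIds, pendingUpdate };
}

// Called whenever the edge list changes (the useEffect's job above).
function onEdgesChanged(currentEdgeIds: Set<string>) {
  if (!resolution) return;
  const remaining = resolution.brokenEdgeIds.filter((id) =>
    currentEdgeIds.has(id),
  );
  if (remaining.length === 0) {
    // Every broken edge was deleted by the user: apply and exit.
    applyPending(resolution.pendingUpdate);
    resolution = null;
  } else {
    resolution = { ...resolution, brokenEdgeIds: remaining };
  }
}

function applyPending(update: object) {
  console.log("applying pending update", update);
}
```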

View File

@@ -1,6 +1,11 @@
 import { BlockIOSubSchema } from "@/lib/autogpt-server-api/types";
-import { cn } from "@/lib/utils";
-import { beautifyString, getTypeBgColor, getTypeTextColor } from "@/lib/utils";
+import {
+  cn,
+  beautifyString,
+  getTypeBgColor,
+  getTypeTextColor,
+  getEffectiveType,
+} from "@/lib/utils";
 import { FC, memo, useCallback } from "react";
 import { Handle, Position } from "@xyflow/react";
 import { InformationTooltip } from "@/components/molecules/InformationTooltip/InformationTooltip";
@@ -13,6 +18,7 @@ type HandleProps = {
   side: "left" | "right";
   title?: string;
   className?: string;
+  isBroken?: boolean;
 };
 // Move the constant out of the component to avoid re-creation on every render.
@@ -27,18 +33,23 @@ const TYPE_NAME: Record<string, string> = {
 };

 // Extract and memoize the Dot component so that it doesn't re-render unnecessarily.
-const Dot: FC<{ isConnected: boolean; type?: string }> = memo(
-  ({ isConnected, type }) => {
-    const color = isConnected
+const Dot: FC<{ isConnected: boolean; type?: string; isBroken?: boolean }> =
+  memo(({ isConnected, type, isBroken }) => {
+    const color = isBroken
+      ? "border-red-500 bg-red-100 dark:bg-red-900/30"
+      : isConnected
         ? getTypeBgColor(type || "any")
         : "border-gray-300 dark:border-gray-600";
     return (
       <div
-        className={`${color} m-1 h-4 w-4 rounded-full border-2 bg-white transition-colors duration-100 group-hover:bg-gray-300 dark:bg-slate-800 dark:group-hover:bg-gray-700`}
+        className={cn(
+          "m-1 h-4 w-4 rounded-full border-2 bg-white transition-colors duration-100 group-hover:bg-gray-300 dark:bg-slate-800 dark:group-hover:bg-gray-700",
+          color,
+          isBroken && "opacity-50",
+        )}
       />
     );
-  },
-);
+  });

 Dot.displayName = "Dot";
 const NodeHandle: FC<HandleProps> = ({
@@ -49,24 +60,34 @@ const NodeHandle: FC<HandleProps> = ({
   side,
   title,
   className,
+  isBroken = false,
 }) => {
-  const typeClass = `text-sm ${getTypeTextColor(schema.type || "any")} ${
+  // Extract effective type from schema (handles anyOf/oneOf/allOf wrappers)
+  const effectiveType = getEffectiveType(schema);
+  const typeClass = `text-sm ${getTypeTextColor(effectiveType || "any")} ${
     side === "left" ? "text-left" : "text-right"
   }`;

   const label = (
-    <div className="flex flex-grow flex-row">
+    <div className={cn("flex flex-grow flex-row", isBroken && "opacity-50")}>
       <span
         className={cn(
           "data-sentry-unmask text-m green flex items-end pr-2 text-gray-900 dark:text-gray-100",
           className,
+          isBroken && "text-red-500 line-through",
         )}
       >
         {title || schema.title || beautifyString(keyName.toLowerCase())}
         {isRequired ? "*" : ""}
       </span>
-      <span className={`${typeClass} data-sentry-unmask flex items-end`}>
-        ({TYPE_NAME[schema.type as keyof typeof TYPE_NAME] || "any"})
+      <span
+        className={cn(
+          `${typeClass} data-sentry-unmask flex items-end`,
+          isBroken && "text-red-400",
+        )}
+      >
+        ({TYPE_NAME[effectiveType as keyof typeof TYPE_NAME] || "any"})
       </span>
     </div>
   );
@@ -84,7 +105,7 @@ const NodeHandle: FC<HandleProps> = ({
     return (
       <div
         key={keyName}
-        className="handle-container"
+        className={cn("handle-container", isBroken && "pointer-events-none")}
         onContextMenu={handleContextMenu}
       >
         <Handle
@@ -92,10 +113,15 @@ const NodeHandle: FC<HandleProps> = ({
           data-testid={`input-handle-${keyName}`}
           position={Position.Left}
           id={keyName}
-          className="group -ml-[38px]"
+          className={cn("group -ml-[38px]", isBroken && "cursor-not-allowed")}
+          isConnectable={!isBroken}
         >
           <div className="pointer-events-none flex items-center">
-            <Dot isConnected={isConnected} type={schema.type} />
+            <Dot
+              isConnected={isConnected}
+              type={effectiveType}
+              isBroken={isBroken}
+            />
             {label}
           </div>
         </Handle>
@@ -106,7 +132,10 @@ const NodeHandle: FC<HandleProps> = ({
     return (
       <div
         key={keyName}
-        className="handle-container justify-end"
+        className={cn(
+          "handle-container justify-end",
+          isBroken && "pointer-events-none",
+        )}
         onContextMenu={handleContextMenu}
       >
         <Handle
@@ -114,11 +143,16 @@ const NodeHandle: FC<HandleProps> = ({
           data-testid={`output-handle-${keyName}`}
           position={Position.Right}
           id={keyName}
-          className="group -mr-[38px]"
+          className={cn("group -mr-[38px]", isBroken && "cursor-not-allowed")}
+          isConnectable={!isBroken}
         >
           <div className="pointer-events-none flex items-center">
             {label}
-            <Dot isConnected={isConnected} type={schema.type} />
+            <Dot
+              isConnected={isConnected}
+              type={effectiveType}
+              isBroken={isBroken}
+            />
           </div>
         </Handle>
       </div>
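
`getEffectiveType` lives in `@/lib/utils` and is not shown in this diff; per the comment above it presumably unwraps `anyOf`/`oneOf`/`allOf` schemas down to a single concrete type name. A rough sketch of what such a helper could look like, as an assumption rather than the actual implementation:

```typescript
// Hypothetical sketch of an effective-type resolver; the real
// getEffectiveType in @/lib/utils may differ.
type SubSchemaSketch = {
  type?: string;
  anyOf?: SubSchemaSketch[];
  oneOf?: SubSchemaSketch[];
  allOf?: SubSchemaSketch[];
};

function effectiveTypeSketch(schema?: SubSchemaSketch): string | undefined {
  if (!schema) return undefined;
  if (schema.type) return schema.type;
  // For wrapped schemas, take the first option that is not "null"
  // (a common pattern for optional fields like anyOf: [T, null]).
  const options = schema.anyOf ?? schema.oneOf ?? schema.allOf ?? [];
  for (const option of options) {
    const t = effectiveTypeSketch(option);
    if (t && t !== "null") return t;
  }
  return undefined;
}

// "string", despite the anyOf wrapper:
console.log(
  effectiveTypeSketch({ anyOf: [{ type: "null" }, { type: "string" }] }),
);
```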

View File

@@ -1,5 +1,5 @@
 import {
-  ConnectionData,
+  ConnectedEdge,
   CustomNodeData,
 } from "@/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode";
 import { NodeTableInput } from "@/app/(platform)/build/components/legacy-builder/NodeTableInput";
@@ -65,7 +65,7 @@ type NodeObjectInputTreeProps = {
   selfKey?: string;
   schema: BlockIORootSchema | BlockIOObjectSubSchema;
   object?: { [key: string]: any };
-  connections: ConnectionData;
+  connections: ConnectedEdge[];
   handleInputClick: (key: string) => void;
   handleInputChange: (key: string, value: any) => void;
   errors: { [key: string]: string | undefined };
@@ -585,7 +585,7 @@ const NodeOneOfDiscriminatorField: FC<{
   currentValue?: any;
   defaultValue?: any;
   errors: { [key: string]: string | undefined };
-  connections: ConnectionData;
+  connections: ConnectedEdge[];
   handleInputChange: (key: string, value: any) => void;
   handleInputClick: (key: string) => void;
   className?: string;

View File

@@ -1,15 +1,16 @@
 import { FC, useCallback, useEffect, useState } from "react";
 import NodeHandle from "@/app/(platform)/build/components/legacy-builder/NodeHandle";
-import {
+import type {
   BlockIOTableSubSchema,
   TableCellValue,
   TableRow,
 } from "@/lib/autogpt-server-api/types";
+import type { ConnectedEdge } from "./CustomNode/CustomNode";
 import { cn } from "@/lib/utils";
 import { PlusIcon, XIcon } from "@phosphor-icons/react";
-import { Button } from "../../../../../components/atoms/Button/Button";
-import { Input } from "../../../../../components/atoms/Input/Input";
+import { Button } from "@/components/atoms/Button/Button";
+import { Input } from "@/components/atoms/Input/Input";

 interface NodeTableInputProps {
   /** Unique identifier for the node in the builder graph */
@@ -25,13 +26,7 @@ interface NodeTableInputProps {
   /** Validation errors mapped by field key */
   errors: { [key: string]: string | undefined };
   /** Graph connections between nodes in the builder */
-  connections: {
-    edge_id: string;
-    source: string;
-    sourceHandle: string;
-    target: string;
-    targetHandle: string;
-  }[];
+  connections: ConnectedEdge[];
   /** Callback when table data changes */
   handleInputChange: (key: string, value: TableRow[]) => void;
   /** Callback when input field is clicked (for builder selection) */

View File

@@ -1,6 +1,7 @@
 import { useCallback } from "react";
 import { Node, Edge, useReactFlow } from "@xyflow/react";
 import { Key, storage } from "@/services/storage/local-storage";
+import { ConnectedEdge } from "./CustomNode/CustomNode";

 interface CopyableData {
   nodes: Node[];
@@ -111,13 +112,15 @@ export function useCopyPaste(getNextNodeId: () => string) {
           (edge: Edge) =>
             edge.source === node.id || edge.target === node.id,
         )
-        .map((edge: Edge) => ({
-          edge_id: edge.id,
+        .map(
+          (edge: Edge): ConnectedEdge => ({
+            id: edge.id,
             source: edge.source,
             target: edge.target,
-          sourceHandle: edge.sourceHandle,
-          targetHandle: edge.targetHandle,
-        }));
+            sourceHandle: edge.sourceHandle!,
+            targetHandle: edge.targetHandle!,
+          }),
+        );

       return {
         ...node,

View File

@@ -0,0 +1,104 @@
import { GraphInputSchema } from "@/lib/autogpt-server-api";
import { GraphMetaLike, IncompatibilityInfo } from "./types";
// Helper type for schema properties - the generated types are too loose
type SchemaProperties = Record<string, GraphInputSchema["properties"][string]>;
type SchemaRequired = string[];
// Helper to safely extract schema properties
export function getSchemaProperties(schema: unknown): SchemaProperties {
if (
schema &&
typeof schema === "object" &&
"properties" in schema &&
typeof schema.properties === "object" &&
schema.properties !== null
) {
return schema.properties as SchemaProperties;
}
return {};
}
export function getSchemaRequired(schema: unknown): SchemaRequired {
if (
schema &&
typeof schema === "object" &&
"required" in schema &&
Array.isArray(schema.required)
) {
return schema.required as SchemaRequired;
}
return [];
}
/**
* Creates the updated agent node inputs for a sub-agent node
*/
export function createUpdatedAgentNodeInputs(
currentInputs: Record<string, unknown>,
latestSubGraphVersion: GraphMetaLike,
): Record<string, unknown> {
return {
...currentInputs,
graph_version: latestSubGraphVersion.version,
input_schema: latestSubGraphVersion.input_schema,
output_schema: latestSubGraphVersion.output_schema,
};
}
/** Generic edge type that works with both builders:
* - New builder uses CustomEdge with (formally) optional handles
* - Legacy builder uses ConnectedEdge type with required handles */
export type EdgeLike = {
id: string;
source: string;
target: string;
sourceHandle?: string | null;
targetHandle?: string | null;
};
/**
* Determines which edges are broken after an incompatible update.
* Works with both legacy ConnectedEdge and new CustomEdge.
*/
export function getBrokenEdgeIDs(
connections: EdgeLike[],
incompatibilities: IncompatibilityInfo,
nodeID: string,
): string[] {
const brokenEdgeIDs: string[] = [];
const typeMismatchInputNames = new Set(
incompatibilities.inputTypeMismatches.map((m) => m.name),
);
connections.forEach((conn) => {
// Check if this connection uses a missing input (node is target)
if (
conn.target === nodeID &&
conn.targetHandle &&
incompatibilities.missingInputs.includes(conn.targetHandle)
) {
brokenEdgeIDs.push(conn.id);
}
// Check if this connection uses an input with a type mismatch (node is target)
if (
conn.target === nodeID &&
conn.targetHandle &&
typeMismatchInputNames.has(conn.targetHandle)
) {
brokenEdgeIDs.push(conn.id);
}
// Check if this connection uses a missing output (node is source)
if (
conn.source === nodeID &&
conn.sourceHandle &&
incompatibilities.missingOutputs.includes(conn.sourceHandle)
) {
brokenEdgeIDs.push(conn.id);
}
});
return brokenEdgeIDs;
}
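
A quick usage sketch for `getBrokenEdgeIDs` with hand-made data; the IDs and handle names are invented for illustration, and the shapes follow `EdgeLike` and `IncompatibilityInfo` above:

```typescript
// Invented sample data: node "agent" has one incoming edge per input
// and one outgoing edge from its "results" output.
const connections = [
  { id: "e1", source: "n0", target: "agent", sourceHandle: "out", targetHandle: "query" },
  { id: "e2", source: "agent", target: "n2", sourceHandle: "results", targetHandle: "in" },
  { id: "e3", source: "n0", target: "agent", sourceHandle: "out", targetHandle: "limit" },
];
const incompatibilities = {
  missingInputs: ["query"], // "query" was removed -> e1 breaks
  missingOutputs: ["results"], // "results" was removed -> e2 breaks
  newInputs: [],
  newOutputs: [],
  newRequiredInputs: [],
  inputTypeMismatches: [], // a type change on "limit" would also break e3
};

console.log(getBrokenEdgeIDs(connections, incompatibilities, "agent")); // ["e1", "e2"]
```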

View File

@@ -0,0 +1,2 @@
export { useSubAgentUpdate } from "./useSubAgentUpdate";
export { createUpdatedAgentNodeInputs, getBrokenEdgeIDs } from "./helpers";

View File

@@ -0,0 +1,27 @@
import type { GraphMeta as LegacyGraphMeta } from "@/lib/autogpt-server-api";
import type { GraphMeta as GeneratedGraphMeta } from "@/app/api/__generated__/models/graphMeta";
export type SubAgentUpdateInfo<T extends GraphMetaLike = GraphMetaLike> = {
hasUpdate: boolean;
currentVersion: number;
latestVersion: number;
latestGraph: T | null;
isCompatible: boolean;
incompatibilities: IncompatibilityInfo | null;
};
// Union type for GraphMeta that works with both legacy and new builder
export type GraphMetaLike = LegacyGraphMeta | GeneratedGraphMeta;
export type IncompatibilityInfo = {
missingInputs: string[]; // Connected inputs that no longer exist
missingOutputs: string[]; // Connected outputs that no longer exist
newInputs: string[]; // Inputs that exist in new version but not in current
newOutputs: string[]; // Outputs that exist in new version but not in current
newRequiredInputs: string[]; // New required inputs not in current version or not required
inputTypeMismatches: Array<{
name: string;
oldType: string;
newType: string;
}>; // Connected inputs where the type has changed
};

View File

@@ -0,0 +1,160 @@
import { useMemo } from "react";
import { GraphInputSchema, GraphOutputSchema } from "@/lib/autogpt-server-api";
import { getEffectiveType } from "@/lib/utils";
import { EdgeLike, getSchemaProperties, getSchemaRequired } from "./helpers";
import {
GraphMetaLike,
IncompatibilityInfo,
SubAgentUpdateInfo,
} from "./types";
/**
* Checks if a newer version of a sub-agent is available and determines compatibility
*/
export function useSubAgentUpdate<T extends GraphMetaLike>(
nodeID: string,
graphID: string | undefined,
graphVersion: number | undefined,
currentInputSchema: GraphInputSchema | undefined,
currentOutputSchema: GraphOutputSchema | undefined,
connections: EdgeLike[],
availableGraphs: T[],
): SubAgentUpdateInfo<T> {
// Find the latest version of the same graph
const latestGraph = useMemo(() => {
if (!graphID) return null;
return availableGraphs.find((graph) => graph.id === graphID) || null;
}, [graphID, availableGraphs]);
// Check if there's an update available
const hasUpdate = useMemo(() => {
if (!latestGraph || graphVersion === undefined) return false;
return latestGraph.version! > graphVersion;
}, [latestGraph, graphVersion]);
// Get connected input and output handles for this specific node
const connectedHandles = useMemo(() => {
const inputHandles = new Set<string>();
const outputHandles = new Set<string>();
connections.forEach((conn) => {
// If this node is the target, the targetHandle is an input on this node
if (conn.target === nodeID && conn.targetHandle) {
inputHandles.add(conn.targetHandle);
}
// If this node is the source, the sourceHandle is an output on this node
if (conn.source === nodeID && conn.sourceHandle) {
outputHandles.add(conn.sourceHandle);
}
});
return { inputHandles, outputHandles };
}, [connections, nodeID]);
// Check schema compatibility
const compatibilityResult = useMemo((): {
isCompatible: boolean;
incompatibilities: IncompatibilityInfo | null;
} => {
if (!hasUpdate || !latestGraph) {
return { isCompatible: true, incompatibilities: null };
}
const newInputProps = getSchemaProperties(latestGraph.input_schema);
const newOutputProps = getSchemaProperties(latestGraph.output_schema);
const newRequiredInputs = getSchemaRequired(latestGraph.input_schema);
const currentInputProps = getSchemaProperties(currentInputSchema);
const currentOutputProps = getSchemaProperties(currentOutputSchema);
const currentRequiredInputs = getSchemaRequired(currentInputSchema);
const incompatibilities: IncompatibilityInfo = {
missingInputs: [],
missingOutputs: [],
newInputs: [],
newOutputs: [],
newRequiredInputs: [],
inputTypeMismatches: [],
};
// Check for missing connected inputs and type mismatches
connectedHandles.inputHandles.forEach((inputHandle) => {
if (!(inputHandle in newInputProps)) {
incompatibilities.missingInputs.push(inputHandle);
} else {
// Check for type mismatch on connected inputs
const currentProp = currentInputProps[inputHandle];
const newProp = newInputProps[inputHandle];
const currentType = getEffectiveType(currentProp);
const newType = getEffectiveType(newProp);
if (currentType && newType && currentType !== newType) {
incompatibilities.inputTypeMismatches.push({
name: inputHandle,
oldType: currentType,
newType: newType,
});
}
}
});
// Check for missing connected outputs
connectedHandles.outputHandles.forEach((outputHandle) => {
if (!(outputHandle in newOutputProps)) {
incompatibilities.missingOutputs.push(outputHandle);
}
});
// Check for new required inputs that didn't exist or weren't required before
newRequiredInputs.forEach((requiredInput) => {
const existedBefore = requiredInput in currentInputProps;
const wasRequiredBefore = currentRequiredInputs.includes(
requiredInput as string,
);
if (!existedBefore || !wasRequiredBefore) {
incompatibilities.newRequiredInputs.push(requiredInput as string);
}
});
// Check for new inputs that don't exist in the current version
Object.keys(newInputProps).forEach((inputName) => {
if (!(inputName in currentInputProps)) {
incompatibilities.newInputs.push(inputName);
}
});
// Check for new outputs that don't exist in the current version
Object.keys(newOutputProps).forEach((outputName) => {
if (!(outputName in currentOutputProps)) {
incompatibilities.newOutputs.push(outputName);
}
});
const hasIncompatibilities =
incompatibilities.missingInputs.length > 0 ||
incompatibilities.missingOutputs.length > 0 ||
incompatibilities.newRequiredInputs.length > 0 ||
incompatibilities.inputTypeMismatches.length > 0;
return {
isCompatible: !hasIncompatibilities,
incompatibilities: hasIncompatibilities ? incompatibilities : null,
};
}, [
hasUpdate,
latestGraph,
currentInputSchema,
currentOutputSchema,
connectedHandles,
]);
return {
hasUpdate,
currentVersion: graphVersion || 0,
latestVersion: latestGraph?.version || 0,
latestGraph,
isCompatible: compatibilityResult.isCompatible,
incompatibilities: compatibilityResult.incompatibilities,
};
}
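
To make the classification above concrete, a small worked example with invented schemas; note that a type change only counts as a mismatch when the input is actually connected, and that a merely added input is reported but not breaking:

```typescript
// Invented example of how the checks above classify a schema change.
// Current v1 inputs: { query: string, limit: integer }, only "query" connected.
// Latest v2 inputs:  { limit: string, offset: integer }, "offset" required.
//
// Connected input "query" is gone           -> missingInputs: ["query"]
// "limit" changed type but is not connected -> inputTypeMismatches: []
// "offset" is new and required              -> newInputs + newRequiredInputs
//
// missingInputs (and newRequiredInputs) are non-empty, so isCompatible
// is false and the IncompatibilityDialog path is taken instead of a
// direct apply.
const expectedIncompatibilities = {
  missingInputs: ["query"],
  missingOutputs: [],
  newInputs: ["offset"],
  newOutputs: [],
  newRequiredInputs: ["offset"],
  inputTypeMismatches: [],
};
```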

View File

@@ -1,5 +1,6 @@
 import { create } from "zustand";
 import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
+import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";

 interface GraphStore {
   graphExecutionStatus: AgentExecutionStatus | undefined;
@@ -17,6 +18,10 @@ interface GraphStore {
     outputSchema: Record<string, any> | null,
   ) => void;
+  // Available graphs; used for sub-graph updates
+  availableSubGraphs: GraphMeta[];
+  setAvailableSubGraphs: (graphs: GraphMeta[]) => void;
+
   hasInputs: () => boolean;
   hasCredentials: () => boolean;
   hasOutputs: () => boolean;
@@ -29,6 +34,7 @@ export const useGraphStore = create<GraphStore>((set, get) => ({
   inputSchema: null,
   credentialsInputSchema: null,
   outputSchema: null,
+  availableSubGraphs: [],
   setGraphExecutionStatus: (status: AgentExecutionStatus | undefined) => {
     set({
@@ -46,6 +52,8 @@ export const useGraphStore = create<GraphStore>((set, get) => ({
   setGraphSchemas: (inputSchema, credentialsInputSchema, outputSchema) =>
     set({ inputSchema, credentialsInputSchema, outputSchema }),
+  setAvailableSubGraphs: (graphs) => set({ availableSubGraphs: graphs }),
+
   hasOutputs: () => {
     const { outputSchema } = get();
     return Object.keys(outputSchema?.properties ?? {}).length > 0;

View File

@@ -17,6 +17,25 @@ import {
   ensurePathExists,
   parseHandleIdToPath,
 } from "@/components/renderers/InputRenderer/helpers";
+import { IncompatibilityInfo } from "../hooks/useSubAgentUpdate/types";
+
+// Resolution mode data stored per node
+export type NodeResolutionData = {
+  incompatibilities: IncompatibilityInfo;
+  // The NEW schema from the update (what we're updating TO)
+  pendingUpdate: {
+    input_schema: Record<string, unknown>;
+    output_schema: Record<string, unknown>;
+  };
+  // The OLD schema before the update (what we're updating FROM)
+  // Needed to merge and show removed inputs during resolution
+  currentSchema: {
+    input_schema: Record<string, unknown>;
+    output_schema: Record<string, unknown>;
+  };
+  // The full updated hardcoded values to apply when resolution completes
+  pendingHardcodedValues: Record<string, unknown>;
+};

 // Minimum movement (in pixels) required before logging position change to history
 // Prevents spamming history with small movements when clicking on inputs inside blocks
@@ -65,12 +84,32 @@ type NodeStore = {
     backendId: string,
     errors: { [key: string]: string },
   ) => void;
-  clearAllNodeErrors: () => void; // Add this
   syncHardcodedValuesWithHandleIds: (nodeId: string) => void;
-  // Credentials optional helpers
   setCredentialsOptional: (nodeId: string, optional: boolean) => void;
+  clearAllNodeErrors: () => void;
+
+  nodesInResolutionMode: Set<string>;
+  brokenEdgeIDs: Map<string, Set<string>>;
+  nodeResolutionData: Map<string, NodeResolutionData>;
+  setNodeResolutionMode: (
+    nodeID: string,
+    inResolution: boolean,
+    resolutionData?: NodeResolutionData,
+  ) => void;
+  isNodeInResolutionMode: (nodeID: string) => boolean;
+  getNodeResolutionData: (nodeID: string) => NodeResolutionData | undefined;
+  setBrokenEdgeIDs: (nodeID: string, edgeIDs: string[]) => void;
+  removeBrokenEdgeID: (nodeID: string, edgeID: string) => void;
+  isEdgeBroken: (edgeID: string) => boolean;
+  clearResolutionState: () => void;
+  isInputBroken: (nodeID: string, handleID: string) => boolean;
+  getInputTypeMismatch: (
+    nodeID: string,
+    handleID: string,
+  ) => string | undefined;
 };

 export const useNodeStore = create<NodeStore>((set, get) => ({
@@ -374,4 +413,99 @@ export const useNodeStore = create<NodeStore>((set, get) => ({
     useHistoryStore.getState().pushState(newState);
   },
+
+  // Sub-agent resolution mode state
+  nodesInResolutionMode: new Set<string>(),
+  brokenEdgeIDs: new Map<string, Set<string>>(),
+  nodeResolutionData: new Map<string, NodeResolutionData>(),
+
+  setNodeResolutionMode: (
+    nodeID: string,
+    inResolution: boolean,
+    resolutionData?: NodeResolutionData,
+  ) => {
+    set((state) => {
+      const newNodesSet = new Set(state.nodesInResolutionMode);
+      const newResolutionDataMap = new Map(state.nodeResolutionData);
+      const newBrokenEdgeIDs = new Map(state.brokenEdgeIDs);
+
+      if (inResolution) {
+        newNodesSet.add(nodeID);
+        if (resolutionData) {
+          newResolutionDataMap.set(nodeID, resolutionData);
+        }
+      } else {
+        newNodesSet.delete(nodeID);
+        newResolutionDataMap.delete(nodeID);
+        newBrokenEdgeIDs.delete(nodeID); // Clean up broken edges when exiting resolution mode
+      }
+
+      return {
+        nodesInResolutionMode: newNodesSet,
+        nodeResolutionData: newResolutionDataMap,
+        brokenEdgeIDs: newBrokenEdgeIDs,
+      };
+    });
+  },
+  isNodeInResolutionMode: (nodeID: string) => {
+    return get().nodesInResolutionMode.has(nodeID);
+  },
+  getNodeResolutionData: (nodeID: string) => {
+    return get().nodeResolutionData.get(nodeID);
+  },
+  setBrokenEdgeIDs: (nodeID: string, edgeIDs: string[]) => {
+    set((state) => {
+      const newMap = new Map(state.brokenEdgeIDs);
+      newMap.set(nodeID, new Set(edgeIDs));
+      return { brokenEdgeIDs: newMap };
+    });
+  },
+  removeBrokenEdgeID: (nodeID: string, edgeID: string) => {
+    set((state) => {
+      const newMap = new Map(state.brokenEdgeIDs);
+      const nodeSet = new Set(newMap.get(nodeID) || []);
+      nodeSet.delete(edgeID);
+      newMap.set(nodeID, nodeSet);
+      return { brokenEdgeIDs: newMap };
+    });
+  },
+  isEdgeBroken: (edgeID: string) => {
+    // Check across all nodes
+    const brokenEdgeIDs = get().brokenEdgeIDs;
+    for (const edgeSet of brokenEdgeIDs.values()) {
+      if (edgeSet.has(edgeID)) {
+        return true;
+      }
+    }
+    return false;
+  },
+  clearResolutionState: () => {
+    set({
+      nodesInResolutionMode: new Set<string>(),
+      brokenEdgeIDs: new Map<string, Set<string>>(),
+      nodeResolutionData: new Map<string, NodeResolutionData>(),
+    });
+  },
+
+  // Helper functions for input renderers
+  isInputBroken: (nodeID: string, handleID: string) => {
+    const resolutionData = get().nodeResolutionData.get(nodeID);
+    if (!resolutionData) return false;
+    return resolutionData.incompatibilities.missingInputs.includes(handleID);
+  },
+  getInputTypeMismatch: (nodeID: string, handleID: string) => {
+    const resolutionData = get().nodeResolutionData.get(nodeID);
+    if (!resolutionData) return undefined;
+    const mismatch = resolutionData.incompatibilities.inputTypeMismatches.find(
+      (m) => m.name === handleID,
+    );
+    return mismatch?.newType;
+  },
 }));
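
For orientation, a hedged sketch of how a renderer might consume these store helpers; the hook name and styling values are invented here, while the actual call sites are in the AnyOf/Title field diffs further down:

```typescript
// Sketch: reading resolution state from useNodeStore inside a component.
// nodeId/handleId values come from the renderer's props in practice.
function useBrokenInputStyle(nodeId: string, handleId: string) {
  const isBroken = useNodeStore((state) =>
    state.isInputBroken(nodeId, handleId),
  );
  const newType = useNodeStore((state) =>
    state.getInputTypeMismatch(nodeId, handleId),
  );
  // Strike through removed inputs; surface the new type on mismatches.
  return {
    className: isBroken ? "text-red-500 line-through" : "",
    typeLabel: newType ? `(${newType})` : undefined,
  };
}
```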

View File

@@ -18,7 +18,7 @@ function ErrorPageContent() {
   ) {
     window.location.href = "/login";
   } else {
-    window.location.href = "/marketplace";
+    window.document.location.reload();
   }
 }

View File

@@ -29,7 +29,7 @@ export default function Layout({ children }: { children: React.ReactNode }) {
       href: "/profile/dashboard",
       icon: <StorefrontIcon className="size-5" />,
     },
-    ...(isPaymentEnabled || true
+    ...(isPaymentEnabled
       ? [
           {
             text: "Billing",

View File

@@ -32,7 +32,7 @@ export const extendedButtonVariants = cva(
     ),
   },
   size: {
-    small: "px-3 py-2 text-sm gap-1.5 h-[2.25rem]",
+    small: "px-3 py-2 text-sm gap-1.5 h-[2.25rem] min-w-[5.5rem]",
     large: "px-4 py-3 text-sm gap-2 h-[3.25rem]",
     icon: "p-3 !min-w-0",
   },

View File

@@ -30,8 +30,6 @@ export const FormRenderer = ({
     return generateUiSchemaForCustomFields(preprocessedSchema, uiSchema);
   }, [preprocessedSchema, uiSchema]);

-  console.log("preprocessedSchema", preprocessedSchema);
-
   return (
     <div className={"mb-6 mt-4"}>
       <Form

View File

@@ -2,10 +2,12 @@ import { FieldProps, getUiOptions, getWidget } from "@rjsf/utils";
 import { AnyOfFieldTitle } from "./components/AnyOfFieldTitle";
 import { isEmpty } from "lodash";
 import { useAnyOfField } from "./useAnyOfField";
-import { getHandleId, updateUiOption } from "../../helpers";
+import { cleanUpHandleId, getHandleId, updateUiOption } from "../../helpers";
 import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
 import { ANY_OF_FLAG } from "../../constants";
 import { findCustomFieldId } from "../../registry";
+import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
+import { cn } from "@/lib/utils";

 export const AnyOfField = (props: FieldProps) => {
   const { registry, schema } = props;
@@ -21,6 +23,8 @@ export const AnyOfField = (props: FieldProps) => {
     field_id,
   } = useAnyOfField(props);

+  const isInputBroken = useNodeStore((state) => state.isInputBroken);
+
   const parentCustomFieldId = findCustomFieldId(schema);
   if (parentCustomFieldId) {
     return null;
@@ -43,6 +47,7 @@ export const AnyOfField = (props: FieldProps) => {
   });

   const isHandleConnected = isInputConnected(nodeId, handleId);
+  const isAnyOfInputBroken = isInputBroken(nodeId, cleanUpHandleId(handleId));

   // Now anyOf can render - custom fields if the option schema matches a custom field
   const optionCustomFieldId = optionSchema
@@ -78,7 +83,11 @@ export const AnyOfField = (props: FieldProps) => {
           registry={registry}
           placeholder={props.placeholder}
           autocomplete={props.autocomplete}
-          className="-ml-1 h-[22px] w-fit gap-1 px-1 pl-2 text-xs font-medium"
+          className={cn(
+            "-ml-1 h-[22px] w-fit gap-1 px-1 pl-2 text-xs font-medium",
+            isAnyOfInputBroken &&
+              "border-red-500 bg-red-100 text-red-600 line-through",
+          )}
           autofocus={props.autofocus}
           label=""
           hideLabel={true}
@@ -93,7 +102,7 @@ export const AnyOfField = (props: FieldProps) => {
         selector={selector}
         uiSchema={updatedUiSchema}
       />
-      {!isHandleConnected && optionsSchemaField}
+      {!isHandleConnected && !isAnyOfInputBroken && optionsSchemaField}
     </div>
   );
 };

View File

@@ -13,6 +13,7 @@ import { Text } from "@/components/atoms/Text/Text";
 import { isOptionalType } from "../../../utils/schema-utils";
 import { getTypeDisplayInfo } from "@/app/(platform)/build/components/FlowEditor/nodes/helpers";
 import { cn } from "@/lib/utils";
+import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";

 interface customFieldProps extends FieldProps {
   selector: JSX.Element;
@@ -51,6 +52,13 @@ export const AnyOfFieldTitle = (props: customFieldProps) => {
     shouldShowTypeSelector(schema) && !isArrayItem && !isHandleConnected;
   const shoudlShowType = isHandleConnected || (isOptional && type);

+  const isInputBroken = useNodeStore((state) =>
+    state.isInputBroken(nodeId, cleanUpHandleId(uiOptions.handleId)),
+  );
+  const inputMismatch = useNodeStore((state) =>
+    state.getInputTypeMismatch(nodeId, cleanUpHandleId(uiOptions.handleId)),
+  );
+
   return (
     <div className="flex items-center gap-2">
       <TitleFieldTemplate
@@ -62,8 +70,16 @@ export const AnyOfFieldTitle = (props: customFieldProps) => {
         uiSchema={uiSchema}
       />
       {shoudlShowType && (
-        <Text variant="small" className={cn("text-zinc-700", colorClass)}>
-          {isOptional ? `(${displayType})` : "(any)"}
+        <Text
+          variant="small"
+          className={cn(
+            "text-zinc-700",
+            isInputBroken && "line-through",
+            colorClass,
+            inputMismatch && "rounded-md bg-red-100 px-1 !text-red-500",
+          )}
+        >
+          {isOptional ? `(${inputMismatch || displayType})` : "(any)"}
         </Text>
       )}
       {shouldShowSelector && selector}

View File

@@ -9,8 +9,9 @@ import { Text } from "@/components/atoms/Text/Text";
 import { getTypeDisplayInfo } from "@/app/(platform)/build/components/FlowEditor/nodes/helpers";
 import { isAnyOfSchema } from "../../utils/schema-utils";
 import { cn } from "@/lib/utils";
-import { isArrayItem } from "../../helpers";
+import { cleanUpHandleId, isArrayItem } from "../../helpers";
 import { InputNodeHandle } from "@/app/(platform)/build/components/FlowEditor/handlers/NodeHandle";
+import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
 export default function TitleField(props: TitleFieldProps) {
   const { id, title, required, schema, registry, uiSchema } = props;
@@ -26,6 +27,11 @@ export default function TitleField(props: TitleFieldProps) {
   const smallText = isArrayItemFlag || additional;
   const showHandle = uiOptions.showHandles ?? showHandles;
+  const isInputBroken = useNodeStore((state) =>
+    state.isInputBroken(nodeId, cleanUpHandleId(uiOptions.handleId)),
+  );
   return (
     <div className="flex items-center">
       {showHandle !== false && (
@@ -34,7 +40,11 @@ export default function TitleField(props: TitleFieldProps) {
       <Text
         variant={isArrayItemFlag ? "small" : "body"}
         id={id}
-        className={cn("line-clamp-1", smallText && "text-sm text-zinc-700")}
+        className={cn(
+          "line-clamp-1",
+          smallText && "text-sm text-zinc-700",
+          isInputBroken && "text-red-500 line-through",
+        )}
       >
         {title}
       </Text>
@@ -44,7 +54,7 @@ export default function TitleField(props: TitleFieldProps) {
       {!isAnyOf && (
         <Text
           variant="small"
-          className={cn("ml-2", colorClass)}
+          className={cn("ml-2", isInputBroken && "line-through", colorClass)}
           id={description_id}
         >
           ({displayType})

View File

@@ -30,6 +30,8 @@ export function updateUiOption<T extends Record<string, any>>(
 }
 export const cleanUpHandleId = (handleId: string) => {
+  if (!handleId) return "";
   let newHandleId = handleId;
   if (handleId.includes(ANY_OF_FLAG)) {
     newHandleId = newHandleId.replace(ANY_OF_FLAG, "");
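The new guard matters because `cleanUpHandleId` is now called with `uiOptions.handleId`, which can be `undefined` at runtime despite the `string` annotation. A quick illustration (hypothetical inputs):

```ts
// With the guard, falsy handle IDs normalize to an empty key instead of
// crashing on the string operations below the guard:
cleanUpHandleId(""); // => ""
cleanUpHandleId(undefined as unknown as string); // => "" (previously: TypeError on .includes)
```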

View File

@@ -233,13 +233,14 @@ export default function useAgentGraph(
         title: `${block.name} ${node.id}`,
         inputSchema: block.inputSchema,
         outputSchema: block.outputSchema,
+        isOutputStatic: block.staticOutput,
         hardcodedValues: node.input_default,
         uiType: block.uiType,
         metadata: metadata,
         connections: graph.links
           .filter((l) => [l.source_id, l.sink_id].includes(node.id))
           .map((link) => ({
-            edge_id: formatEdgeID(link),
+            id: formatEdgeID(link),
             source: link.source_id,
             sourceHandle: link.source_name,
             target: link.sink_id,
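The `edge_id` → `id` rename lets edge-keyed consumers use the connection objects directly; a hypothetical downstream use for illustration (the `customNode` variable and edge shape are assumptions, not code from this PR):

```ts
// Hypothetical: edges keyed by `id` can now be passed through without remapping.
const edges = customNode.connections.map((c) => ({
  id: c.id, // was `edge_id`
  source: c.source,
  sourceHandle: c.sourceHandle,
  target: c.target,
}));
```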

View File

@@ -245,8 +245,8 @@ export type BlockIONullSubSchema = BlockIOSubSchemaMeta & {
 // At the time of writing, combined schemas only occur on the first nested level in a
 // block schema. It is typed this way to make the use of these objects less tedious.
 type BlockIOCombinedTypeSubSchema = BlockIOSubSchemaMeta & {
-  type: never;
-  const: never;
+  type?: never;
+  const?: never;
 } & (
   | {
       allOf: [BlockIOSimpleTypeSubSchema];
@@ -368,8 +368,8 @@ export type GraphMeta = {
   recommended_schedule_cron: string | null;
   forked_from_id?: GraphID | null;
   forked_from_version?: number | null;
-  input_schema: GraphIOSchema;
-  output_schema: GraphIOSchema;
+  input_schema: GraphInputSchema;
+  output_schema: GraphOutputSchema;
   credentials_input_schema: CredentialsInputSchema;
 } & (
   | {
@@ -385,19 +385,51 @@ export type GraphMeta = {
 export type GraphID = Brand<string, "GraphID">;
 /* Derived from backend/data/graph.py:Graph._generate_schema() */
-export type GraphIOSchema = {
+export type GraphInputSchema = {
   type: "object";
-  properties: Record<string, GraphIOSubSchema>;
-  required: (keyof BlockIORootSchema["properties"])[];
+  properties: Record<string, GraphInputSubSchema>;
+  required: (keyof GraphInputSchema["properties"])[];
 };
-export type GraphIOSubSchema = Omit<
-  BlockIOSubSchemaMeta,
-  "placeholder" | "depends_on" | "hidden"
-> & {
-  type: never; // bodge to avoid type checking hell; doesn't exist at runtime
-  default?: string;
+export type GraphInputSubSchema = GraphOutputSubSchema &
+  (
+    | { type?: never; default: any | null } // AgentInputBlock (generic Any type)
+    | { type: "string"; format: "short-text"; default: string | null } // AgentShortTextInputBlock
+    | { type: "string"; format: "long-text"; default: string | null } // AgentLongTextInputBlock
+    | { type: "integer"; default: number | null } // AgentNumberInputBlock
+    | { type: "string"; format: "date"; default: string | null } // AgentDateInputBlock
+    | { type: "string"; format: "time"; default: string | null } // AgentTimeInputBlock
+    | { type: "string"; format: "file"; default: string | null } // AgentFileInputBlock
+    | { type: "string"; enum: string[]; default: string | null } // AgentDropdownInputBlock
+    | { type: "boolean"; default: boolean } // AgentToggleInputBlock
+    | {
+        // AgentTableInputBlock
+        type: "array";
+        format: "table";
+        items: {
+          type: "object";
+          properties: Record<string, { type: "string" }>;
+        };
+        default: Array<Record<string, string>> | null;
+      }
+    | {
+        // AgentGoogleDriveFileInputBlock
+        type: "object";
+        format: "google-drive-picker";
+        google_drive_picker_config?: GoogleDrivePickerConfig;
+        default: GoogleDriveFile | null;
+      }
+  );
+export type GraphOutputSchema = {
+  type: "object";
+  properties: Record<string, GraphOutputSubSchema>;
+  required: (keyof GraphOutputSchema["properties"])[];
+};
+export type GraphOutputSubSchema = {
+  // TODO: typed outputs based on the incoming edges?
+  title: string;
+  description?: string;
+  advanced: boolean;
   secret: boolean;
+  metadata?: any;
 };
 export type CredentialsInputSchema = {
@@ -440,8 +472,8 @@ export type GraphUpdateable = Omit<
   is_active?: boolean;
   nodes: NodeCreatable[];
   links: LinkCreatable[];
-  input_schema?: GraphIOSchema;
-  output_schema?: GraphIOSchema;
+  input_schema?: GraphInputSchema;
+  output_schema?: GraphOutputSchema;
 };
 export type GraphCreatable = _GraphCreatableInner & {
@@ -497,8 +529,8 @@ export type LibraryAgent = {
   name: string;
   description: string;
   instructions?: string | null;
-  input_schema: GraphIOSchema;
-  output_schema: GraphIOSchema;
+  input_schema: GraphInputSchema;
+  output_schema: GraphOutputSchema;
   credentials_input_schema: CredentialsInputSchema;
   new_output: boolean;
   can_access_graph: boolean;
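Because `GraphInputSubSchema` is now a tagged union over the input-block kinds, consumers can narrow on `type`/`format` instead of working around the old `type: never` bodge. A minimal sketch (the function name and rendering choices are hypothetical):

```ts
// Sketch: narrowing the new union to read a default value (illustrative only).
function formatDefault(sub: GraphInputSubSchema): string {
  if (sub.type === "boolean") {
    return sub.default ? "on" : "off"; // AgentToggleInputBlock: default is boolean
  }
  if (sub.type === "string" && "enum" in sub) {
    return sub.default ?? sub.enum[0]; // AgentDropdownInputBlock: enum is string[]
  }
  return String(sub.default ?? ""); // every variant declares `default`
}
```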

View File

@@ -6,7 +6,10 @@ import { NodeDimension } from "@/app/(platform)/build/components/legacy-builder/
 import {
   BlockIOObjectSubSchema,
   BlockIORootSchema,
+  BlockIOSubSchema,
   Category,
+  GraphInputSubSchema,
+  GraphOutputSubSchema,
 } from "@/lib/autogpt-server-api/types";
 export function cn(...inputs: ClassValue[]) {
@@ -76,8 +79,8 @@ export function getTypeBgColor(type: string | null): string {
   );
 }
-export function getTypeColor(type: string | null): string {
-  if (type === null) return "#6b7280";
+export function getTypeColor(type: string | undefined): string {
+  if (!type) return "#6b7280";
   return (
     {
       string: "#22c55e",
@@ -88,11 +91,59 @@ export function getTypeColor(type: string | null): string {
       array: "#6366f1",
       null: "#6b7280",
       any: "#6b7280",
+      "": "#6b7280",
     }[type] || "#6b7280"
   );
 }
+/**
+ * Extracts the effective type from a JSON schema, handling anyOf/oneOf/allOf wrappers.
+ * Returns the first non-null type found in the schema structure.
+ */
+export function getEffectiveType(
+  schema:
+    | BlockIOSubSchema
+    | GraphInputSubSchema
+    | GraphOutputSubSchema
+    | null
+    | undefined,
+): string | undefined {
+  if (!schema) return undefined;
+  // Direct type property
+  if ("type" in schema && schema.type) {
+    return String(schema.type);
+  }
+  // Handle allOf - typically a single-item wrapper
+  if (
+    "allOf" in schema &&
+    Array.isArray(schema.allOf) &&
+    schema.allOf.length > 0
+  ) {
+    return getEffectiveType(schema.allOf[0]);
+  }
+  // Handle anyOf - e.g. [{ type: "string" }, { type: "null" }]
+  if ("anyOf" in schema && Array.isArray(schema.anyOf)) {
+    for (const item of schema.anyOf) {
+      if ("type" in item && item.type !== "null") {
+        return String(item.type);
+      }
+    }
+  }
+  // Handle oneOf
+  if ("oneOf" in schema && Array.isArray(schema.oneOf)) {
+    for (const item of schema.oneOf) {
+      if ("type" in item && item.type !== "null") {
+        return String(item.type);
+      }
+    }
+  }
+  return undefined;
+}
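A few spot checks of what the helper returns, following directly from the branches above (inputs are hypothetical schemas; type annotations elided for brevity):

```ts
getEffectiveType({ type: "string" }); // => "string"
getEffectiveType({ allOf: [{ type: "object" }] }); // => "object" (unwraps single-item allOf)
getEffectiveType({ anyOf: [{ type: "null" }, { type: "integer" }] }); // => "integer" (skips null)
getEffectiveType({ anyOf: [{ type: "null" }] }); // => undefined (only null variants)
getEffectiveType(undefined); // => undefined
```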
 export function beautifyString(name: string): string {
   // Regular expression to identify places to split, considering acronyms
   const result = name