From b01ea3fcbd11366f49455be6d11759e8a700bf92 Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Thu, 15 Jan 2026 05:08:19 +0100 Subject: [PATCH 001/103] fix(backend/executor): Centralize `increment_runs` calls & make `add_graph_execution` more robust (#11764) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [OPEN-2946: \[Scheduler\] Error executing graph after 19.83s: ClientNotConnectedError: Client is not connected to the query engine, you must call `connect()` before attempting to query data.](https://linear.app/autogpt/issue/OPEN-2946) - Follow-up to #11375 (broken `increment_runs` call) - Follow-up to #11380 (direct `get_graph_execution` call) ### Changes 🏗️ - Move `increment_runs` call from `scheduler._execute_graph` to `executor.utils.add_graph_execution` so it can be made through `DatabaseManager` - Add `increment_onboarding_runs` to `DatabaseManager` - Remove now-redundant `increment_onboarding_runs` calls in other places - Make `add_graph_execution` more resilient - Split up large try/except block - Fix direct `get_graph_execution` call ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - CI + a thorough review --- .../api/features/integrations/router.py | 7 +---- .../api/features/library/routes/presets.py | 3 -- .../backend/backend/api/features/v1.py | 2 -- .../backend/backend/data/onboarding.py | 2 +- .../backend/backend/executor/database.py | 10 +++++++ .../backend/backend/executor/scheduler.py | 2 -- .../backend/backend/executor/utils.py | 30 +++++++++++++++---- 7 files changed, 36 insertions(+), 20 deletions(-) diff --git a/autogpt_platform/backend/backend/api/features/integrations/router.py b/autogpt_platform/backend/backend/api/features/integrations/router.py index f5dd8c092b..36585b14b5 100644 --- a/autogpt_platform/backend/backend/api/features/integrations/router.py +++ b/autogpt_platform/backend/backend/api/features/integrations/router.py @@ -35,11 +35,7 @@ from backend.data.model import ( OAuth2Credentials, UserIntegrations, ) -from backend.data.onboarding import ( - OnboardingStep, - complete_onboarding_step, - increment_runs, -) +from backend.data.onboarding import OnboardingStep, complete_onboarding_step from backend.data.user import get_user_integrations from backend.executor.utils import add_graph_execution from backend.integrations.ayrshare import AyrshareClient, SocialPlatform @@ -378,7 +374,6 @@ async def webhook_ingress_generic( return await complete_onboarding_step(user_id, OnboardingStep.TRIGGER_WEBHOOK) - await increment_runs(user_id) # Execute all triggers concurrently for better performance tasks = [] diff --git a/autogpt_platform/backend/backend/api/features/library/routes/presets.py b/autogpt_platform/backend/backend/api/features/library/routes/presets.py index cd4c04e0f2..98f2cb5f15 100644 --- a/autogpt_platform/backend/backend/api/features/library/routes/presets.py +++ b/autogpt_platform/backend/backend/api/features/library/routes/presets.py @@ -8,7 +8,6 @@ from backend.data.execution import GraphExecutionMeta from backend.data.graph import get_graph from backend.data.integrations import get_webhook from backend.data.model import CredentialsMetaInput -from backend.data.onboarding import increment_runs from backend.executor.utils import add_graph_execution, make_node_credentials_input_map from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.webhooks import get_webhook_manager @@ -403,8 +402,6 @@ async def execute_preset( merged_node_input = preset.inputs | inputs merged_credential_inputs = preset.credentials | credential_inputs - await increment_runs(user_id) - return await add_graph_execution( user_id=user_id, graph_id=preset.graph_id, diff --git a/autogpt_platform/backend/backend/api/features/v1.py b/autogpt_platform/backend/backend/api/features/v1.py index 9b05b4755f..661e8ff7f2 100644 --- a/autogpt_platform/backend/backend/api/features/v1.py +++ b/autogpt_platform/backend/backend/api/features/v1.py @@ -64,7 +64,6 @@ from backend.data.onboarding import ( complete_re_run_agent, get_recommended_agents, get_user_onboarding, - increment_runs, onboarding_enabled, reset_user_onboarding, update_user_onboarding, @@ -975,7 +974,6 @@ async def execute_graph( # Record successful graph execution record_graph_execution(graph_id=graph_id, status="success", user_id=user_id) record_graph_operation(operation="execute", status="success") - await increment_runs(user_id) await complete_re_run_agent(user_id, graph_id) if source == "library": await complete_onboarding_step( diff --git a/autogpt_platform/backend/backend/data/onboarding.py b/autogpt_platform/backend/backend/data/onboarding.py index cc63b89afd..6a842d1022 100644 --- a/autogpt_platform/backend/backend/data/onboarding.py +++ b/autogpt_platform/backend/backend/data/onboarding.py @@ -334,7 +334,7 @@ async def _get_user_timezone(user_id: str) -> str: return get_user_timezone_or_utc(user.timezone if user else None) -async def increment_runs(user_id: str): +async def increment_onboarding_runs(user_id: str): """ Increment a user's run counters and trigger any onboarding milestones. """ diff --git a/autogpt_platform/backend/backend/executor/database.py b/autogpt_platform/backend/backend/executor/database.py index af68bf526d..9848948bff 100644 --- a/autogpt_platform/backend/backend/executor/database.py +++ b/autogpt_platform/backend/backend/executor/database.py @@ -20,6 +20,7 @@ from backend.data.execution import ( get_execution_kv_data, get_execution_outputs_by_node_exec_id, get_frequently_executed_graphs, + get_graph_execution, get_graph_execution_meta, get_graph_executions, get_graph_executions_count, @@ -57,6 +58,7 @@ from backend.data.notifications import ( get_user_notification_oldest_message_in_batch, remove_notifications_from_batch, ) +from backend.data.onboarding import increment_onboarding_runs from backend.data.user import ( get_active_user_ids_in_timerange, get_user_by_id, @@ -140,6 +142,7 @@ class DatabaseManager(AppService): get_child_graph_executions = _(get_child_graph_executions) get_graph_executions = _(get_graph_executions) get_graph_executions_count = _(get_graph_executions_count) + get_graph_execution = _(get_graph_execution) get_graph_execution_meta = _(get_graph_execution_meta) create_graph_execution = _(create_graph_execution) get_node_execution = _(get_node_execution) @@ -204,6 +207,9 @@ class DatabaseManager(AppService): add_store_agent_to_library = _(add_store_agent_to_library) validate_graph_execution_permissions = _(validate_graph_execution_permissions) + # Onboarding + increment_onboarding_runs = _(increment_onboarding_runs) + # Store get_store_agents = _(get_store_agents) get_store_agent_details = _(get_store_agent_details) @@ -274,6 +280,7 @@ class DatabaseManagerAsyncClient(AppServiceClient): get_graph = d.get_graph get_graph_metadata = d.get_graph_metadata get_graph_settings = d.get_graph_settings + get_graph_execution = d.get_graph_execution get_graph_execution_meta = d.get_graph_execution_meta get_node = d.get_node get_node_execution = d.get_node_execution @@ -318,6 +325,9 @@ class DatabaseManagerAsyncClient(AppServiceClient): add_store_agent_to_library = d.add_store_agent_to_library validate_graph_execution_permissions = d.validate_graph_execution_permissions + # Onboarding + increment_onboarding_runs = d.increment_onboarding_runs + # Store get_store_agents = d.get_store_agents get_store_agent_details = d.get_store_agent_details diff --git a/autogpt_platform/backend/backend/executor/scheduler.py b/autogpt_platform/backend/backend/executor/scheduler.py index 06c50bf82e..963c901fd6 100644 --- a/autogpt_platform/backend/backend/executor/scheduler.py +++ b/autogpt_platform/backend/backend/executor/scheduler.py @@ -27,7 +27,6 @@ from backend.data.auth.oauth import cleanup_expired_oauth_tokens from backend.data.block import BlockInput from backend.data.execution import GraphExecutionWithNodes from backend.data.model import CredentialsMetaInput -from backend.data.onboarding import increment_runs from backend.executor import utils as execution_utils from backend.monitoring import ( NotificationJobArgs, @@ -156,7 +155,6 @@ async def _execute_graph(**kwargs): inputs=args.input_data, graph_credentials_inputs=args.input_credentials, ) - await increment_runs(args.user_id) elapsed = asyncio.get_event_loop().time() - start_time logger.info( f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} " diff --git a/autogpt_platform/backend/backend/executor/utils.py b/autogpt_platform/backend/backend/executor/utils.py index 1fb2b9404f..25f0389e99 100644 --- a/autogpt_platform/backend/backend/executor/utils.py +++ b/autogpt_platform/backend/backend/executor/utils.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, JsonValue, ValidationError from backend.data import execution as execution_db from backend.data import graph as graph_db +from backend.data import onboarding as onboarding_db from backend.data import user as user_db from backend.data.block import ( Block, @@ -31,7 +32,6 @@ from backend.data.execution import ( GraphExecutionStats, GraphExecutionWithNodes, NodesInputMasks, - get_graph_execution, ) from backend.data.graph import GraphModel, Node from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput @@ -809,13 +809,14 @@ async def add_graph_execution( edb = execution_db udb = user_db gdb = graph_db + odb = onboarding_db else: - edb = udb = gdb = get_database_manager_async_client() + edb = udb = gdb = odb = get_database_manager_async_client() # Get or create the graph execution if graph_exec_id: # Resume existing execution - graph_exec = await get_graph_execution( + graph_exec = await edb.get_graph_execution( user_id=user_id, execution_id=graph_exec_id, include_node_executions=True, @@ -891,6 +892,7 @@ async def add_graph_execution( ) logger.info(f"Publishing execution {graph_exec.id} to execution queue") + # Publish to execution queue for executor to pick up exec_queue = await get_async_execution_queue() await exec_queue.publish_message( routing_key=GRAPH_EXECUTION_ROUTING_KEY, @@ -899,14 +901,12 @@ async def add_graph_execution( ) logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue") + # Update execution status to QUEUED graph_exec.status = ExecutionStatus.QUEUED await edb.update_graph_execution_stats( graph_exec_id=graph_exec.id, status=graph_exec.status, ) - await get_async_execution_event_bus().publish(graph_exec) - - return graph_exec except BaseException as e: err = str(e) or type(e).__name__ if not graph_exec: @@ -927,6 +927,24 @@ async def add_graph_execution( ) raise + try: + await get_async_execution_event_bus().publish(graph_exec) + logger.info(f"Published update for execution #{graph_exec.id} to event bus") + except Exception as e: + logger.error( + f"Failed to publish execution event for graph exec #{graph_exec.id}: {e}" + ) + + try: + await odb.increment_onboarding_runs(user_id) + logger.info( + f"Incremented user #{user_id} onboarding runs for exec #{graph_exec.id}" + ) + except Exception as e: + logger.error(f"Failed to increment onboarding runs for user #{user_id}: {e}") + + return graph_exec + # ============ Execution Output Helpers ============ # From 5ac941fe2ff09ef5b871b6ca36f17b7a0087d3f5 Mon Sep 17 00:00:00 2001 From: Swifty Date: Thu, 15 Jan 2026 05:17:03 +0100 Subject: [PATCH 002/103] feat(backend): add hybrid search for store listings, docs and blocks (#11721) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds hybrid search functionality combining semantic embeddings with traditional text search for improved store listing discovery. ### Changes 🏗️ - Add `embeddings.py` - OpenAI-based embedding generation and similarity search - Add `hybrid_search.py` - Combines vector similarity with text matching for better search results - Add `backfill_embeddings.py` - Script to generate embeddings for existing store listings - Update `db.py` - Integrate hybrid search into store database queries - Update `schema.prisma` - Add embedding storage fields and indexes - Add migrations for embedding columns and HNSW index for vector search ### Architecture Decisions 🏛️ **Fail-Fast Approach (No Silent Fallbacks)** We explicitly chose NOT to implement graceful degradation when hybrid search fails. Here's why: ✅ **Benefits:** - Errors surface immediately → faster fixes - Tests verify hybrid search actually works (not just fallback) - Consistent search quality for all users - Forces proper infrastructure setup (API keys, database) ❌ **Why Not Fallback:** - Silent degradation hides production issues - Users get inconsistent results without knowing why - Tests can pass even when hybrid search is broken - Reduces operational visibility **How We Prevent Failures:** 1. Embedding generation in approval flow (db.py:1545) 2. Error logging with `logger.error` (not warning) 3. Clear error messages (ValueError explains what's wrong) 4. Comprehensive test coverage (9/9 tests passing) If embeddings fail, it indicates a real infrastructure issue (missing API key, OpenAI down, database issues) that needs immediate attention, not silent degradation. ### Test Coverage ✅ **All tests passing (1625 total):** - 9/9 hybrid_search tests (including fail-fast validation) - 3/3 db search integration tests - Full schema compatibility (public/platform schemas) - Error handling verification ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Test hybrid search returns relevant results - [x] Test embedding generation for new listings - [x] Test backfill script on existing data - [x] Verify search performance with embeddings - [x] Test fail-fast behavior when embeddings unavailable #### For configuration changes: - [x] `.env.default` is updated or already compatible with my changes - [x] `docker-compose.yml` is updated or already compatible with my changes - [x] Configuration: Requires `openai_internal_api_key` in secrets --------- Co-authored-by: Zamil Majdy Co-authored-by: Claude Opus 4.5 --- .github/workflows/platform-backend-ci.yml | 2 +- .github/workflows/platform-frontend-ci.yml | 9 + autogpt_platform/backend/.gitignore | 1 + .../api/features/chat/tools/run_agent_test.py | 12 + .../backend/backend/api/features/store/db.py | 213 +++---- .../backend/api/features/store/embeddings.py | 568 ++++++++++++++++++ .../features/store/embeddings_schema_test.py | 329 ++++++++++ .../api/features/store/embeddings_test.py | 387 ++++++++++++ .../api/features/store/hybrid_search.py | 393 ++++++++++++ .../api/features/store/hybrid_search_test.py | 334 ++++++++++ autogpt_platform/backend/backend/data/db.py | 105 +++- .../backend/backend/data/graph_test.py | 12 + .../backend/backend/executor/database.py | 12 + .../backend/backend/executor/manager_test.py | 12 + .../backend/backend/executor/scheduler.py | 89 ++- .../backend/backend/util/clients.py | 19 + .../migration.sql | 46 ++ .../migration.sql | 71 +++ autogpt_platform/backend/schema.prisma | 119 ++-- 19 files changed, 2567 insertions(+), 166 deletions(-) create mode 100644 autogpt_platform/backend/backend/api/features/store/embeddings.py create mode 100644 autogpt_platform/backend/backend/api/features/store/embeddings_schema_test.py create mode 100644 autogpt_platform/backend/backend/api/features/store/embeddings_test.py create mode 100644 autogpt_platform/backend/backend/api/features/store/hybrid_search.py create mode 100644 autogpt_platform/backend/backend/api/features/store/hybrid_search_test.py create mode 100644 autogpt_platform/backend/migrations/20260109181714_add_docs_embedding/migration.sql create mode 100644 autogpt_platform/backend/migrations/20260112173500_add_supabase_extensions_to_platform_schema/migration.sql diff --git a/.github/workflows/platform-backend-ci.yml b/.github/workflows/platform-backend-ci.yml index da5ab83c1c..f66cce8a37 100644 --- a/.github/workflows/platform-backend-ci.yml +++ b/.github/workflows/platform-backend-ci.yml @@ -176,7 +176,7 @@ jobs: } - name: Run Database Migrations - run: poetry run prisma migrate dev --name updates + run: poetry run prisma migrate deploy env: DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }} DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }} diff --git a/.github/workflows/platform-frontend-ci.yml b/.github/workflows/platform-frontend-ci.yml index 2154fe1385..d0edc7327d 100644 --- a/.github/workflows/platform-frontend-ci.yml +++ b/.github/workflows/platform-frontend-ci.yml @@ -11,6 +11,7 @@ on: - ".github/workflows/platform-frontend-ci.yml" - "autogpt_platform/frontend/**" merge_group: + workflow_dispatch: concurrency: group: ${{ github.workflow }}-${{ github.event_name == 'merge_group' && format('merge-queue-{0}', github.ref) || format('{0}-{1}', github.ref, github.event.pull_request.number || github.sha) }} @@ -151,6 +152,14 @@ jobs: run: | cp ../.env.default ../.env + - name: Copy backend .env and set OpenAI API key + run: | + cp ../backend/.env.default ../backend/.env + echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env + env: + # Used by E2E test data script to generate embeddings for approved store agents + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 diff --git a/autogpt_platform/backend/.gitignore b/autogpt_platform/backend/.gitignore index 95b59cf676..9224c07d9e 100644 --- a/autogpt_platform/backend/.gitignore +++ b/autogpt_platform/backend/.gitignore @@ -18,3 +18,4 @@ load-tests/results/ load-tests/*.json load-tests/*.log load-tests/node_modules/* +migrations/*/rollback*.sql diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py index ebad1a0050..cd4e7d04ba 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py @@ -1,4 +1,5 @@ import uuid +from unittest.mock import AsyncMock, patch import orjson import pytest @@ -17,6 +18,17 @@ setup_test_data = setup_test_data setup_firecrawl_test_data = setup_firecrawl_test_data +@pytest.fixture(scope="session", autouse=True) +def mock_embedding_functions(): + """Mock embedding functions for all tests to avoid database/API dependencies.""" + with patch( + "backend.api.features.store.db.ensure_embedding", + new_callable=AsyncMock, + return_value=True, + ): + yield + + @pytest.mark.asyncio(scope="session") async def test_run_agent(setup_test_data): """Test that the run_agent tool successfully executes an approved agent""" diff --git a/autogpt_platform/backend/backend/api/features/store/db.py b/autogpt_platform/backend/backend/api/features/store/db.py index 8e4310ee02..e6aa3853f6 100644 --- a/autogpt_platform/backend/backend/api/features/store/db.py +++ b/autogpt_platform/backend/backend/api/features/store/db.py @@ -1,8 +1,7 @@ import asyncio import logging -import typing from datetime import datetime, timezone -from typing import Literal +from typing import Any, Literal import fastapi import prisma.enums @@ -10,7 +9,7 @@ import prisma.errors import prisma.models import prisma.types -from backend.data.db import query_raw_with_schema, transaction +from backend.data.db import transaction from backend.data.graph import ( GraphMeta, GraphModel, @@ -30,6 +29,8 @@ from backend.util.settings import Settings from . import exceptions as store_exceptions from . import model as store_model +from .embeddings import ensure_embedding +from .hybrid_search import hybrid_search logger = logging.getLogger(__name__) settings = Settings() @@ -50,128 +51,77 @@ async def get_store_agents( page_size: int = 20, ) -> store_model.StoreAgentsResponse: """ - Get PUBLIC store agents from the StoreAgent view + Get PUBLIC store agents from the StoreAgent view. + + Search behavior: + - With search_query: Uses hybrid search (semantic + lexical) + - Fallback: If embeddings unavailable, gracefully degrades to lexical-only + - Rationale: User-facing endpoint prioritizes availability over accuracy + + Note: Admin operations (approval) use fail-fast to prevent inconsistent state. """ logger.debug( f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}" ) + search_used_hybrid = False + store_agents: list[store_model.StoreAgent] = [] + agents: list[dict[str, Any]] = [] + total = 0 + total_pages = 0 + try: - # If search_query is provided, use full-text search + # If search_query is provided, use hybrid search (embeddings + tsvector) if search_query: - offset = (page - 1) * page_size + # Try hybrid search combining semantic and lexical signals + # Falls back to lexical-only if OpenAI unavailable (user-facing, high SLA) + try: + agents, total = await hybrid_search( + query=search_query, + featured=featured, + creators=creators, + category=category, + sorted_by="relevance", # Use hybrid scoring for relevance + page=page, + page_size=page_size, + ) + search_used_hybrid = True + except Exception as e: + # Log error but fall back to lexical search for better UX + logger.error( + f"Hybrid search failed (likely OpenAI unavailable), " + f"falling back to lexical search: {e}" + ) + # search_used_hybrid remains False, will use fallback path below - # Whitelist allowed order_by columns - ALLOWED_ORDER_BY = { - "rating": "rating DESC, rank DESC", - "runs": "runs DESC, rank DESC", - "name": "agent_name ASC, rank ASC", - "updated_at": "updated_at DESC, rank DESC", - } + # Convert hybrid search results (dict format) if hybrid succeeded + if search_used_hybrid: + total_pages = (total + page_size - 1) // page_size + store_agents: list[store_model.StoreAgent] = [] + for agent in agents: + try: + store_agent = store_model.StoreAgent( + slug=agent["slug"], + agent_name=agent["agent_name"], + agent_image=( + agent["agent_image"][0] if agent["agent_image"] else "" + ), + creator=agent["creator_username"] or "Needs Profile", + creator_avatar=agent["creator_avatar"] or "", + sub_heading=agent["sub_heading"], + description=agent["description"], + runs=agent["runs"], + rating=agent["rating"], + ) + store_agents.append(store_agent) + except Exception as e: + logger.error( + f"Error parsing Store agent from hybrid search results: {e}" + ) + continue - # Validate and get order clause - if sorted_by and sorted_by in ALLOWED_ORDER_BY: - order_by_clause = ALLOWED_ORDER_BY[sorted_by] - else: - order_by_clause = "updated_at DESC, rank DESC" - - # Build WHERE conditions and parameters list - where_parts: list[str] = [] - params: list[typing.Any] = [search_query] # $1 - search term - param_index = 2 # Start at $2 for next parameter - - # Always filter for available agents - where_parts.append("is_available = true") - - if featured: - where_parts.append("featured = true") - - if creators and creators: - # Use ANY with array parameter - where_parts.append(f"creator_username = ANY(${param_index})") - params.append(creators) - param_index += 1 - - if category and category: - where_parts.append(f"${param_index} = ANY(categories)") - params.append(category) - param_index += 1 - - sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1" - - # Add pagination params - params.extend([page_size, offset]) - limit_param = f"${param_index}" - offset_param = f"${param_index + 1}" - - # Execute full-text search query with parameterized values - sql_query = f""" - SELECT - slug, - agent_name, - agent_image, - creator_username, - creator_avatar, - sub_heading, - description, - runs, - rating, - categories, - featured, - is_available, - updated_at, - ts_rank_cd(search, query) AS rank - FROM {{schema_prefix}}"StoreAgent", - plainto_tsquery('english', $1) AS query - WHERE {sql_where_clause} - AND search @@ query - ORDER BY {order_by_clause} - LIMIT {limit_param} OFFSET {offset_param} - """ - - # Count query for pagination - only uses search term parameter - count_query = f""" - SELECT COUNT(*) as count - FROM {{schema_prefix}}"StoreAgent", - plainto_tsquery('english', $1) AS query - WHERE {sql_where_clause} - AND search @@ query - """ - - # Execute both queries with parameters - agents = await query_raw_with_schema(sql_query, *params) - - # For count, use params without pagination (last 2 params) - count_params = params[:-2] - count_result = await query_raw_with_schema(count_query, *count_params) - - total = count_result[0]["count"] if count_result else 0 - total_pages = (total + page_size - 1) // page_size - - # Convert raw results to StoreAgent models - store_agents: list[store_model.StoreAgent] = [] - for agent in agents: - try: - store_agent = store_model.StoreAgent( - slug=agent["slug"], - agent_name=agent["agent_name"], - agent_image=( - agent["agent_image"][0] if agent["agent_image"] else "" - ), - creator=agent["creator_username"] or "Needs Profile", - creator_avatar=agent["creator_avatar"] or "", - sub_heading=agent["sub_heading"], - description=agent["description"], - runs=agent["runs"], - rating=agent["rating"], - ) - store_agents.append(store_agent) - except Exception as e: - logger.error(f"Error parsing Store agent from search results: {e}") - continue - - else: - # Non-search query path (original logic) + if not search_used_hybrid: + # Fallback path - use basic search or no search where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True} if featured: where_clause["featured"] = featured @@ -180,6 +130,14 @@ async def get_store_agents( if category: where_clause["categories"] = {"has": category} + # Add basic text search if search_query provided but hybrid failed + if search_query: + where_clause["OR"] = [ + {"agent_name": {"contains": search_query, "mode": "insensitive"}}, + {"sub_heading": {"contains": search_query, "mode": "insensitive"}}, + {"description": {"contains": search_query, "mode": "insensitive"}}, + ] + order_by = [] if sorted_by == "rating": order_by.append({"rating": "desc"}) @@ -188,7 +146,7 @@ async def get_store_agents( elif sorted_by == "name": order_by.append({"agent_name": "asc"}) - agents = await prisma.models.StoreAgent.prisma().find_many( + db_agents = await prisma.models.StoreAgent.prisma().find_many( where=where_clause, order=order_by, skip=(page - 1) * page_size, @@ -199,7 +157,7 @@ async def get_store_agents( total_pages = (total + page_size - 1) // page_size store_agents: list[store_model.StoreAgent] = [] - for agent in agents: + for agent in db_agents: try: # Create the StoreAgent object safely store_agent = store_model.StoreAgent( @@ -1577,7 +1535,7 @@ async def review_store_submission( ) # Update the AgentGraph with store listing data - await prisma.models.AgentGraph.prisma().update( + await prisma.models.AgentGraph.prisma(tx).update( where={ "graphVersionId": { "id": store_listing_version.agentGraphId, @@ -1592,6 +1550,23 @@ async def review_store_submission( }, ) + # Generate embedding for approved listing (blocking - admin operation) + # Inside transaction: if embedding fails, entire transaction rolls back + embedding_success = await ensure_embedding( + version_id=store_listing_version_id, + name=store_listing_version.name, + description=store_listing_version.description, + sub_heading=store_listing_version.subHeading, + categories=store_listing_version.categories or [], + tx=tx, + ) + if not embedding_success: + raise ValueError( + f"Failed to generate embedding for listing {store_listing_version_id}. " + "This is likely due to OpenAI API being unavailable. " + "Please try again later or contact support if the issue persists." + ) + await prisma.models.StoreListing.prisma(tx).update( where={"id": store_listing_version.StoreListing.id}, data={ diff --git a/autogpt_platform/backend/backend/api/features/store/embeddings.py b/autogpt_platform/backend/backend/api/features/store/embeddings.py new file mode 100644 index 0000000000..70f4360c0c --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/store/embeddings.py @@ -0,0 +1,568 @@ +""" +Unified Content Embeddings Service + +Handles generation and storage of OpenAI embeddings for all content types +(store listings, blocks, documentation, library agents) to enable semantic/hybrid search. +""" + +import asyncio +import logging +import time +from typing import Any + +import prisma +from prisma.enums import ContentType +from tiktoken import encoding_for_model + +from backend.data.db import execute_raw_with_schema, query_raw_with_schema +from backend.util.clients import get_openai_client +from backend.util.json import dumps + +logger = logging.getLogger(__name__) + + +# OpenAI embedding model configuration +EMBEDDING_MODEL = "text-embedding-3-small" +# OpenAI embedding token limit (8,191 with 1 token buffer for safety) +EMBEDDING_MAX_TOKENS = 8191 + + +def build_searchable_text( + name: str, + description: str, + sub_heading: str, + categories: list[str], +) -> str: + """ + Build searchable text from listing version fields. + + Combines relevant fields into a single string for embedding. + """ + parts = [] + + # Name is important - include it + if name: + parts.append(name) + + # Sub-heading provides context + if sub_heading: + parts.append(sub_heading) + + # Description is the main content + if description: + parts.append(description) + + # Categories help with semantic matching + if categories: + parts.append(" ".join(categories)) + + return " ".join(parts) + + +async def generate_embedding(text: str) -> list[float] | None: + """ + Generate embedding for text using OpenAI API. + + Returns None if embedding generation fails. + Fail-fast: no retries to maintain consistency with approval flow. + """ + try: + client = get_openai_client() + if not client: + logger.error("openai_internal_api_key not set, cannot generate embedding") + return None + + # Truncate text to token limit using tiktoken + # Character-based truncation is insufficient because token ratios vary by content type + enc = encoding_for_model(EMBEDDING_MODEL) + tokens = enc.encode(text) + if len(tokens) > EMBEDDING_MAX_TOKENS: + tokens = tokens[:EMBEDDING_MAX_TOKENS] + truncated_text = enc.decode(tokens) + logger.info( + f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens" + ) + else: + truncated_text = text + + start_time = time.time() + response = await client.embeddings.create( + model=EMBEDDING_MODEL, + input=truncated_text, + ) + latency_ms = (time.time() - start_time) * 1000 + + embedding = response.data[0].embedding + logger.info( + f"Generated embedding: {len(embedding)} dims, " + f"{len(tokens)} tokens, {latency_ms:.0f}ms" + ) + return embedding + + except Exception as e: + logger.error(f"Failed to generate embedding: {e}") + return None + + +async def store_embedding( + version_id: str, + embedding: list[float], + tx: prisma.Prisma | None = None, +) -> bool: + """ + Store embedding in the database. + + BACKWARD COMPATIBILITY: Maintained for existing store listing usage. + DEPRECATED: Use ensure_embedding() instead (includes searchable_text). + """ + return await store_content_embedding( + content_type=ContentType.STORE_AGENT, + content_id=version_id, + embedding=embedding, + searchable_text="", # Empty for backward compat; ensure_embedding() populates this + metadata=None, + user_id=None, # Store agents are public + tx=tx, + ) + + +async def store_content_embedding( + content_type: ContentType, + content_id: str, + embedding: list[float], + searchable_text: str, + metadata: dict | None = None, + user_id: str | None = None, + tx: prisma.Prisma | None = None, +) -> bool: + """ + Store embedding in the unified content embeddings table. + + New function for unified content embedding storage. + Uses raw SQL since Prisma doesn't natively support pgvector. + """ + try: + client = tx if tx else prisma.get_client() + + # Convert embedding to PostgreSQL vector format + embedding_str = embedding_to_vector_string(embedding) + metadata_json = dumps(metadata or {}) + + # Upsert the embedding + # WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT + await execute_raw_with_schema( + """ + INSERT INTO {schema_prefix}"UnifiedContentEmbedding" ( + "id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt" + ) + VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW()) + ON CONFLICT ("contentType", "contentId", "userId") + DO UPDATE SET + "embedding" = $4::vector, + "searchableText" = $5, + "metadata" = $6::jsonb, + "updatedAt" = NOW() + WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType" + AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2 + AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL)) + """, + content_type, + content_id, + user_id, + embedding_str, + searchable_text, + metadata_json, + client=client, + set_public_search_path=True, + ) + + logger.info(f"Stored embedding for {content_type}:{content_id}") + return True + + except Exception as e: + logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}") + return False + + +async def get_embedding(version_id: str) -> dict[str, Any] | None: + """ + Retrieve embedding record for a listing version. + + BACKWARD COMPATIBILITY: Maintained for existing store listing usage. + Returns dict with storeListingVersionId, embedding, timestamps or None if not found. + """ + result = await get_content_embedding( + ContentType.STORE_AGENT, version_id, user_id=None + ) + if result: + # Transform to old format for backward compatibility + return { + "storeListingVersionId": result["contentId"], + "embedding": result["embedding"], + "createdAt": result["createdAt"], + "updatedAt": result["updatedAt"], + } + return None + + +async def get_content_embedding( + content_type: ContentType, content_id: str, user_id: str | None = None +) -> dict[str, Any] | None: + """ + Retrieve embedding record for any content type. + + New function for unified content embedding retrieval. + Returns dict with contentType, contentId, embedding, timestamps or None if not found. + """ + try: + result = await query_raw_with_schema( + """ + SELECT + "contentType", + "contentId", + "userId", + "embedding"::text as "embedding", + "searchableText", + "metadata", + "createdAt", + "updatedAt" + FROM {schema_prefix}"UnifiedContentEmbedding" + WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL)) + """, + content_type, + content_id, + user_id, + set_public_search_path=True, + ) + + if result and len(result) > 0: + return result[0] + return None + + except Exception as e: + logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}") + return None + + +async def ensure_embedding( + version_id: str, + name: str, + description: str, + sub_heading: str, + categories: list[str], + force: bool = False, + tx: prisma.Prisma | None = None, +) -> bool: + """ + Ensure an embedding exists for the listing version. + + Creates embedding if missing. Use force=True to regenerate. + Backward-compatible wrapper for store listings. + + Args: + version_id: The StoreListingVersion ID + name: Agent name + description: Agent description + sub_heading: Agent sub-heading + categories: Agent categories + force: Force regeneration even if embedding exists + tx: Optional transaction client + + Returns: + True if embedding exists/was created, False on failure + """ + try: + # Check if embedding already exists + if not force: + existing = await get_embedding(version_id) + if existing and existing.get("embedding"): + logger.debug(f"Embedding for version {version_id} already exists") + return True + + # Build searchable text for embedding + searchable_text = build_searchable_text( + name, description, sub_heading, categories + ) + + # Generate new embedding + embedding = await generate_embedding(searchable_text) + if embedding is None: + logger.warning(f"Could not generate embedding for version {version_id}") + return False + + # Store the embedding with metadata using new function + metadata = { + "name": name, + "subHeading": sub_heading, + "categories": categories, + } + return await store_content_embedding( + content_type=ContentType.STORE_AGENT, + content_id=version_id, + embedding=embedding, + searchable_text=searchable_text, + metadata=metadata, + user_id=None, # Store agents are public + tx=tx, + ) + + except Exception as e: + logger.error(f"Failed to ensure embedding for version {version_id}: {e}") + return False + + +async def delete_embedding(version_id: str) -> bool: + """ + Delete embedding for a listing version. + + BACKWARD COMPATIBILITY: Maintained for existing store listing usage. + Note: This is usually handled automatically by CASCADE delete, + but provided for manual cleanup if needed. + """ + return await delete_content_embedding(ContentType.STORE_AGENT, version_id) + + +async def delete_content_embedding( + content_type: ContentType, content_id: str, user_id: str | None = None +) -> bool: + """ + Delete embedding for any content type. + + New function for unified content embedding deletion. + Note: This is usually handled automatically by CASCADE delete, + but provided for manual cleanup if needed. + + Args: + content_type: The type of content (STORE_AGENT, LIBRARY_AGENT, etc.) + content_id: The unique identifier for the content + user_id: Optional user ID. For public content (STORE_AGENT, BLOCK), pass None. + For user-scoped content (LIBRARY_AGENT), pass the user's ID to avoid + deleting embeddings belonging to other users. + + Returns: + True if deletion succeeded, False otherwise + """ + try: + client = prisma.get_client() + + await execute_raw_with_schema( + """ + DELETE FROM {schema_prefix}"UnifiedContentEmbedding" + WHERE "contentType" = $1::{schema_prefix}"ContentType" + AND "contentId" = $2 + AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL)) + """, + content_type, + content_id, + user_id, + client=client, + ) + + user_str = f" (user: {user_id})" if user_id else "" + logger.info(f"Deleted embedding for {content_type}:{content_id}{user_str}") + return True + + except Exception as e: + logger.error(f"Failed to delete embedding for {content_type}:{content_id}: {e}") + return False + + +async def get_embedding_stats() -> dict[str, Any]: + """ + Get statistics about embedding coverage. + + Returns counts of: + - Total approved listing versions + - Versions with embeddings + - Versions without embeddings + """ + try: + # Count approved versions + approved_result = await query_raw_with_schema( + """ + SELECT COUNT(*) as count + FROM {schema_prefix}"StoreListingVersion" + WHERE "submissionStatus" = 'APPROVED' + AND "isDeleted" = false + """ + ) + total_approved = approved_result[0]["count"] if approved_result else 0 + + # Count versions with embeddings + embedded_result = await query_raw_with_schema( + """ + SELECT COUNT(*) as count + FROM {schema_prefix}"StoreListingVersion" slv + JOIN {schema_prefix}"UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType" + WHERE slv."submissionStatus" = 'APPROVED' + AND slv."isDeleted" = false + """ + ) + with_embeddings = embedded_result[0]["count"] if embedded_result else 0 + + return { + "total_approved": total_approved, + "with_embeddings": with_embeddings, + "without_embeddings": total_approved - with_embeddings, + "coverage_percent": ( + round(with_embeddings / total_approved * 100, 1) + if total_approved > 0 + else 0 + ), + } + + except Exception as e: + logger.error(f"Failed to get embedding stats: {e}") + return { + "total_approved": 0, + "with_embeddings": 0, + "without_embeddings": 0, + "coverage_percent": 0, + "error": str(e), + } + + +async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]: + """ + Generate embeddings for approved listings that don't have them. + + Args: + batch_size: Number of embeddings to generate in one call + + Returns: + Dict with success/failure counts + """ + try: + # Find approved versions without embeddings + missing = await query_raw_with_schema( + """ + SELECT + slv.id, + slv.name, + slv.description, + slv."subHeading", + slv.categories + FROM {schema_prefix}"StoreListingVersion" slv + LEFT JOIN {schema_prefix}"UnifiedContentEmbedding" uce + ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType" + WHERE slv."submissionStatus" = 'APPROVED' + AND slv."isDeleted" = false + AND uce."contentId" IS NULL + LIMIT $1 + """, + batch_size, + ) + + if not missing: + return { + "processed": 0, + "success": 0, + "failed": 0, + "message": "No missing embeddings", + } + + # Process embeddings concurrently for better performance + embedding_tasks = [ + ensure_embedding( + version_id=row["id"], + name=row["name"], + description=row["description"], + sub_heading=row["subHeading"], + categories=row["categories"] or [], + ) + for row in missing + ] + + results = await asyncio.gather(*embedding_tasks, return_exceptions=True) + + success = sum(1 for result in results if result is True) + failed = len(results) - success + + return { + "processed": len(missing), + "success": success, + "failed": failed, + "message": f"Backfilled {success} embeddings, {failed} failed", + } + + except Exception as e: + logger.error(f"Failed to backfill embeddings: {e}") + return { + "processed": 0, + "success": 0, + "failed": 0, + "error": str(e), + } + + +async def embed_query(query: str) -> list[float] | None: + """ + Generate embedding for a search query. + + Same as generate_embedding but with clearer intent. + """ + return await generate_embedding(query) + + +def embedding_to_vector_string(embedding: list[float]) -> str: + """Convert embedding list to PostgreSQL vector string format.""" + return "[" + ",".join(str(x) for x in embedding) + "]" + + +async def ensure_content_embedding( + content_type: ContentType, + content_id: str, + searchable_text: str, + metadata: dict | None = None, + user_id: str | None = None, + force: bool = False, + tx: prisma.Prisma | None = None, +) -> bool: + """ + Ensure an embedding exists for any content type. + + Generic function for creating embeddings for store agents, blocks, docs, etc. + + Args: + content_type: ContentType enum value (STORE_AGENT, BLOCK, etc.) + content_id: Unique identifier for the content + searchable_text: Combined text for embedding generation + metadata: Optional metadata to store with embedding + force: Force regeneration even if embedding exists + tx: Optional transaction client + + Returns: + True if embedding exists/was created, False on failure + """ + try: + # Check if embedding already exists + if not force: + existing = await get_content_embedding(content_type, content_id, user_id) + if existing and existing.get("embedding"): + logger.debug( + f"Embedding for {content_type}:{content_id} already exists" + ) + return True + + # Generate new embedding + embedding = await generate_embedding(searchable_text) + if embedding is None: + logger.warning( + f"Could not generate embedding for {content_type}:{content_id}" + ) + return False + + # Store the embedding + return await store_content_embedding( + content_type=content_type, + content_id=content_id, + embedding=embedding, + searchable_text=searchable_text, + metadata=metadata or {}, + user_id=user_id, + tx=tx, + ) + + except Exception as e: + logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}") + return False diff --git a/autogpt_platform/backend/backend/api/features/store/embeddings_schema_test.py b/autogpt_platform/backend/backend/api/features/store/embeddings_schema_test.py new file mode 100644 index 0000000000..441fd9961a --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/store/embeddings_schema_test.py @@ -0,0 +1,329 @@ +""" +Integration tests for embeddings with schema handling. + +These tests verify that embeddings operations work correctly across different database schemas. +""" + +from unittest.mock import AsyncMock, patch + +import pytest +from prisma.enums import ContentType + +from backend.api.features.store import embeddings + +# Schema prefix tests removed - functionality moved to db.raw_with_schema() helper + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_store_content_embedding_with_schema(): + """Test storing embeddings with proper schema handling.""" + with patch("backend.data.db.get_database_schema") as mock_schema: + mock_schema.return_value = "platform" + + with patch("prisma.get_client") as mock_get_client: + mock_client = AsyncMock() + mock_get_client.return_value = mock_client + + result = await embeddings.store_content_embedding( + content_type=ContentType.STORE_AGENT, + content_id="test-id", + embedding=[0.1] * 1536, + searchable_text="test text", + metadata={"test": "data"}, + user_id=None, + ) + + # Verify the query was called + assert mock_client.execute_raw.called + + # Get the SQL query that was executed + call_args = mock_client.execute_raw.call_args + sql_query = call_args[0][0] + + # Verify schema prefix is in the query + assert '"platform"."UnifiedContentEmbedding"' in sql_query + + # Verify result + assert result is True + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_get_content_embedding_with_schema(): + """Test retrieving embeddings with proper schema handling.""" + with patch("backend.data.db.get_database_schema") as mock_schema: + mock_schema.return_value = "platform" + + with patch("prisma.get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.query_raw.return_value = [ + { + "contentType": "STORE_AGENT", + "contentId": "test-id", + "userId": None, + "embedding": "[0.1, 0.2]", + "searchableText": "test", + "metadata": {}, + "createdAt": "2024-01-01", + "updatedAt": "2024-01-01", + } + ] + mock_get_client.return_value = mock_client + + result = await embeddings.get_content_embedding( + ContentType.STORE_AGENT, + "test-id", + user_id=None, + ) + + # Verify the query was called + assert mock_client.query_raw.called + + # Get the SQL query that was executed + call_args = mock_client.query_raw.call_args + sql_query = call_args[0][0] + + # Verify schema prefix is in the query + assert '"platform"."UnifiedContentEmbedding"' in sql_query + + # Verify result + assert result is not None + assert result["contentId"] == "test-id" + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_delete_content_embedding_with_schema(): + """Test deleting embeddings with proper schema handling.""" + with patch("backend.data.db.get_database_schema") as mock_schema: + mock_schema.return_value = "platform" + + with patch("prisma.get_client") as mock_get_client: + mock_client = AsyncMock() + mock_get_client.return_value = mock_client + + result = await embeddings.delete_content_embedding( + ContentType.STORE_AGENT, + "test-id", + ) + + # Verify the query was called + assert mock_client.execute_raw.called + + # Get the SQL query that was executed + call_args = mock_client.execute_raw.call_args + sql_query = call_args[0][0] + + # Verify schema prefix is in the query + assert '"platform"."UnifiedContentEmbedding"' in sql_query + + # Verify result + assert result is True + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_get_embedding_stats_with_schema(): + """Test embedding statistics with proper schema handling.""" + with patch("backend.data.db.get_database_schema") as mock_schema: + mock_schema.return_value = "platform" + + with patch("prisma.get_client") as mock_get_client: + mock_client = AsyncMock() + # Mock both query results + mock_client.query_raw.side_effect = [ + [{"count": 100}], # total_approved + [{"count": 80}], # with_embeddings + ] + mock_get_client.return_value = mock_client + + result = await embeddings.get_embedding_stats() + + # Verify both queries were called + assert mock_client.query_raw.call_count == 2 + + # Get both SQL queries + first_call = mock_client.query_raw.call_args_list[0] + second_call = mock_client.query_raw.call_args_list[1] + + first_sql = first_call[0][0] + second_sql = second_call[0][0] + + # Verify schema prefix in both queries + assert '"platform"."StoreListingVersion"' in first_sql + assert '"platform"."StoreListingVersion"' in second_sql + assert '"platform"."UnifiedContentEmbedding"' in second_sql + + # Verify results + assert result["total_approved"] == 100 + assert result["with_embeddings"] == 80 + assert result["without_embeddings"] == 20 + assert result["coverage_percent"] == 80.0 + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_backfill_missing_embeddings_with_schema(): + """Test backfilling embeddings with proper schema handling.""" + with patch("backend.data.db.get_database_schema") as mock_schema: + mock_schema.return_value = "platform" + + with patch("prisma.get_client") as mock_get_client: + mock_client = AsyncMock() + # Mock missing embeddings query + mock_client.query_raw.return_value = [ + { + "id": "version-1", + "name": "Test Agent", + "description": "Test description", + "subHeading": "Test heading", + "categories": ["test"], + } + ] + mock_get_client.return_value = mock_client + + with patch( + "backend.api.features.store.embeddings.ensure_embedding" + ) as mock_ensure: + mock_ensure.return_value = True + + result = await embeddings.backfill_missing_embeddings(batch_size=10) + + # Verify the query was called + assert mock_client.query_raw.called + + # Get the SQL query + call_args = mock_client.query_raw.call_args + sql_query = call_args[0][0] + + # Verify schema prefix in query + assert '"platform"."StoreListingVersion"' in sql_query + assert '"platform"."UnifiedContentEmbedding"' in sql_query + + # Verify ensure_embedding was called + assert mock_ensure.called + + # Verify results + assert result["processed"] == 1 + assert result["success"] == 1 + assert result["failed"] == 0 + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_ensure_content_embedding_with_schema(): + """Test ensuring embeddings exist with proper schema handling.""" + with patch("backend.data.db.get_database_schema") as mock_schema: + mock_schema.return_value = "platform" + + with patch( + "backend.api.features.store.embeddings.get_content_embedding" + ) as mock_get: + # Simulate no existing embedding + mock_get.return_value = None + + with patch( + "backend.api.features.store.embeddings.generate_embedding" + ) as mock_generate: + mock_generate.return_value = [0.1] * 1536 + + with patch( + "backend.api.features.store.embeddings.store_content_embedding" + ) as mock_store: + mock_store.return_value = True + + result = await embeddings.ensure_content_embedding( + content_type=ContentType.STORE_AGENT, + content_id="test-id", + searchable_text="test text", + metadata={"test": "data"}, + user_id=None, + force=False, + ) + + # Verify the flow + assert mock_get.called + assert mock_generate.called + assert mock_store.called + assert result is True + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_backward_compatibility_store_embedding(): + """Test backward compatibility wrapper for store_embedding.""" + with patch( + "backend.api.features.store.embeddings.store_content_embedding" + ) as mock_store: + mock_store.return_value = True + + result = await embeddings.store_embedding( + version_id="test-version-id", + embedding=[0.1] * 1536, + tx=None, + ) + + # Verify it calls the new function with correct parameters + assert mock_store.called + call_args = mock_store.call_args + + assert call_args[1]["content_type"] == ContentType.STORE_AGENT + assert call_args[1]["content_id"] == "test-version-id" + assert call_args[1]["user_id"] is None + assert result is True + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_backward_compatibility_get_embedding(): + """Test backward compatibility wrapper for get_embedding.""" + with patch( + "backend.api.features.store.embeddings.get_content_embedding" + ) as mock_get: + mock_get.return_value = { + "contentType": "STORE_AGENT", + "contentId": "test-version-id", + "embedding": "[0.1, 0.2]", + "createdAt": "2024-01-01", + "updatedAt": "2024-01-01", + } + + result = await embeddings.get_embedding("test-version-id") + + # Verify it calls the new function + assert mock_get.called + + # Verify it transforms to old format + assert result is not None + assert result["storeListingVersionId"] == "test-version-id" + assert "embedding" in result + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_schema_handling_error_cases(): + """Test error handling in schema-aware operations.""" + with patch("backend.data.db.get_database_schema") as mock_schema: + mock_schema.return_value = "platform" + + with patch("prisma.get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.execute_raw.side_effect = Exception("Database error") + mock_get_client.return_value = mock_client + + result = await embeddings.store_content_embedding( + content_type=ContentType.STORE_AGENT, + content_id="test-id", + embedding=[0.1] * 1536, + searchable_text="test", + metadata=None, + user_id=None, + ) + + # Should return False on error, not raise + assert result is False + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/autogpt_platform/backend/backend/api/features/store/embeddings_test.py b/autogpt_platform/backend/backend/api/features/store/embeddings_test.py new file mode 100644 index 0000000000..98329abb19 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/store/embeddings_test.py @@ -0,0 +1,387 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import prisma +import pytest +from prisma import Prisma +from prisma.enums import ContentType + +from backend.api.features.store import embeddings + + +@pytest.fixture(autouse=True) +async def setup_prisma(): + """Setup Prisma client for tests.""" + try: + Prisma() + except prisma.errors.ClientAlreadyRegisteredError: + pass + yield + + +@pytest.mark.asyncio(loop_scope="session") +async def test_build_searchable_text(): + """Test searchable text building from listing fields.""" + result = embeddings.build_searchable_text( + name="AI Assistant", + description="A helpful AI assistant for productivity", + sub_heading="Boost your productivity", + categories=["AI", "Productivity"], + ) + + expected = "AI Assistant Boost your productivity A helpful AI assistant for productivity AI Productivity" + assert result == expected + + +@pytest.mark.asyncio(loop_scope="session") +async def test_build_searchable_text_empty_fields(): + """Test searchable text building with empty fields.""" + result = embeddings.build_searchable_text( + name="", description="Test description", sub_heading="", categories=[] + ) + + assert result == "Test description" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_generate_embedding_success(): + """Test successful embedding generation.""" + # Mock OpenAI response + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.data = [MagicMock()] + mock_response.data[0].embedding = [0.1, 0.2, 0.3] * 512 # 1536 dimensions + + # Use AsyncMock for async embeddings.create method + mock_client.embeddings.create = AsyncMock(return_value=mock_response) + + # Patch at the point of use in embeddings.py + with patch( + "backend.api.features.store.embeddings.get_openai_client" + ) as mock_get_client: + mock_get_client.return_value = mock_client + + result = await embeddings.generate_embedding("test text") + + assert result is not None + assert len(result) == 1536 + assert result[0] == 0.1 + + mock_client.embeddings.create.assert_called_once_with( + model="text-embedding-3-small", input="test text" + ) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_generate_embedding_no_api_key(): + """Test embedding generation without API key.""" + # Patch at the point of use in embeddings.py + with patch( + "backend.api.features.store.embeddings.get_openai_client" + ) as mock_get_client: + mock_get_client.return_value = None + + result = await embeddings.generate_embedding("test text") + + assert result is None + + +@pytest.mark.asyncio(loop_scope="session") +async def test_generate_embedding_api_error(): + """Test embedding generation with API error.""" + mock_client = MagicMock() + mock_client.embeddings.create = AsyncMock(side_effect=Exception("API Error")) + + # Patch at the point of use in embeddings.py + with patch( + "backend.api.features.store.embeddings.get_openai_client" + ) as mock_get_client: + mock_get_client.return_value = mock_client + + result = await embeddings.generate_embedding("test text") + + assert result is None + + +@pytest.mark.asyncio(loop_scope="session") +async def test_generate_embedding_text_truncation(): + """Test that long text is properly truncated using tiktoken.""" + from tiktoken import encoding_for_model + + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.data = [MagicMock()] + mock_response.data[0].embedding = [0.1] * 1536 + + # Use AsyncMock for async embeddings.create method + mock_client.embeddings.create = AsyncMock(return_value=mock_response) + + # Patch at the point of use in embeddings.py + with patch( + "backend.api.features.store.embeddings.get_openai_client" + ) as mock_get_client: + mock_get_client.return_value = mock_client + + # Create text that will exceed 8191 tokens + # Use varied characters to ensure token-heavy text: each word is ~1 token + words = [f"word{i}" for i in range(10000)] + long_text = " ".join(words) # ~10000 tokens + + await embeddings.generate_embedding(long_text) + + # Verify text was truncated to 8191 tokens + call_args = mock_client.embeddings.create.call_args + truncated_text = call_args.kwargs["input"] + + # Count actual tokens in truncated text + enc = encoding_for_model("text-embedding-3-small") + actual_tokens = len(enc.encode(truncated_text)) + + # Should be at or just under 8191 tokens + assert actual_tokens <= 8191 + # Should be close to the limit (not over-truncated) + assert actual_tokens >= 8100 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_store_embedding_success(mocker): + """Test successful embedding storage.""" + mock_client = mocker.AsyncMock() + mock_client.execute_raw = mocker.AsyncMock() + + embedding = [0.1, 0.2, 0.3] + + result = await embeddings.store_embedding( + version_id="test-version-id", embedding=embedding, tx=mock_client + ) + + assert result is True + # execute_raw is called twice: once for SET search_path, once for INSERT + assert mock_client.execute_raw.call_count == 2 + + # First call: SET search_path + first_call_args = mock_client.execute_raw.call_args_list[0][0] + assert "SET search_path" in first_call_args[0] + + # Second call: INSERT query with the actual data + second_call_args = mock_client.execute_raw.call_args_list[1][0] + assert "test-version-id" in second_call_args + assert "[0.1,0.2,0.3]" in second_call_args + assert None in second_call_args # userId should be None for store agents + + +@pytest.mark.asyncio(loop_scope="session") +async def test_store_embedding_database_error(mocker): + """Test embedding storage with database error.""" + mock_client = mocker.AsyncMock() + mock_client.execute_raw.side_effect = Exception("Database error") + + embedding = [0.1, 0.2, 0.3] + + result = await embeddings.store_embedding( + version_id="test-version-id", embedding=embedding, tx=mock_client + ) + + assert result is False + + +@pytest.mark.asyncio(loop_scope="session") +async def test_get_embedding_success(): + """Test successful embedding retrieval.""" + mock_result = [ + { + "contentType": "STORE_AGENT", + "contentId": "test-version-id", + "userId": None, + "embedding": "[0.1,0.2,0.3]", + "searchableText": "Test text", + "metadata": {}, + "createdAt": "2024-01-01T00:00:00Z", + "updatedAt": "2024-01-01T00:00:00Z", + } + ] + + with patch( + "backend.api.features.store.embeddings.query_raw_with_schema", + return_value=mock_result, + ): + result = await embeddings.get_embedding("test-version-id") + + assert result is not None + assert result["storeListingVersionId"] == "test-version-id" + assert result["embedding"] == "[0.1,0.2,0.3]" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_get_embedding_not_found(): + """Test embedding retrieval when not found.""" + with patch( + "backend.api.features.store.embeddings.query_raw_with_schema", + return_value=[], + ): + result = await embeddings.get_embedding("test-version-id") + + assert result is None + + +@pytest.mark.asyncio(loop_scope="session") +@patch("backend.api.features.store.embeddings.generate_embedding") +@patch("backend.api.features.store.embeddings.store_embedding") +@patch("backend.api.features.store.embeddings.get_embedding") +async def test_ensure_embedding_already_exists(mock_get, mock_store, mock_generate): + """Test ensure_embedding when embedding already exists.""" + mock_get.return_value = {"embedding": "[0.1,0.2,0.3]"} + + result = await embeddings.ensure_embedding( + version_id="test-id", + name="Test", + description="Test description", + sub_heading="Test heading", + categories=["test"], + ) + + assert result is True + mock_generate.assert_not_called() + mock_store.assert_not_called() + + +@pytest.mark.asyncio(loop_scope="session") +@patch("backend.api.features.store.embeddings.generate_embedding") +@patch("backend.api.features.store.embeddings.store_content_embedding") +@patch("backend.api.features.store.embeddings.get_embedding") +async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate): + """Test ensure_embedding creating new embedding.""" + mock_get.return_value = None + mock_generate.return_value = [0.1, 0.2, 0.3] + mock_store.return_value = True + + result = await embeddings.ensure_embedding( + version_id="test-id", + name="Test", + description="Test description", + sub_heading="Test heading", + categories=["test"], + ) + + assert result is True + mock_generate.assert_called_once_with("Test Test heading Test description test") + mock_store.assert_called_once_with( + content_type=ContentType.STORE_AGENT, + content_id="test-id", + embedding=[0.1, 0.2, 0.3], + searchable_text="Test Test heading Test description test", + metadata={"name": "Test", "subHeading": "Test heading", "categories": ["test"]}, + user_id=None, + tx=None, + ) + + +@pytest.mark.asyncio(loop_scope="session") +@patch("backend.api.features.store.embeddings.generate_embedding") +@patch("backend.api.features.store.embeddings.get_embedding") +async def test_ensure_embedding_generation_fails(mock_get, mock_generate): + """Test ensure_embedding when generation fails.""" + mock_get.return_value = None + mock_generate.return_value = None + + result = await embeddings.ensure_embedding( + version_id="test-id", + name="Test", + description="Test description", + sub_heading="Test heading", + categories=["test"], + ) + + assert result is False + + +@pytest.mark.asyncio(loop_scope="session") +async def test_get_embedding_stats(): + """Test embedding statistics retrieval.""" + # Mock approved count query and embedded count query + mock_approved_result = [{"count": 100}] + mock_embedded_result = [{"count": 75}] + + with patch( + "backend.api.features.store.embeddings.query_raw_with_schema", + side_effect=[mock_approved_result, mock_embedded_result], + ): + result = await embeddings.get_embedding_stats() + + assert result["total_approved"] == 100 + assert result["with_embeddings"] == 75 + assert result["without_embeddings"] == 25 + assert result["coverage_percent"] == 75.0 + + +@pytest.mark.asyncio(loop_scope="session") +@patch("backend.api.features.store.embeddings.ensure_embedding") +async def test_backfill_missing_embeddings_success(mock_ensure): + """Test backfill with successful embedding generation.""" + # Mock missing embeddings query + mock_missing = [ + { + "id": "version-1", + "name": "Agent 1", + "description": "Description 1", + "subHeading": "Heading 1", + "categories": ["AI"], + }, + { + "id": "version-2", + "name": "Agent 2", + "description": "Description 2", + "subHeading": "Heading 2", + "categories": ["Productivity"], + }, + ] + + # Mock ensure_embedding to succeed for first, fail for second + mock_ensure.side_effect = [True, False] + + with patch( + "backend.api.features.store.embeddings.query_raw_with_schema", + return_value=mock_missing, + ): + result = await embeddings.backfill_missing_embeddings(batch_size=5) + + assert result["processed"] == 2 + assert result["success"] == 1 + assert result["failed"] == 1 + assert mock_ensure.call_count == 2 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_backfill_missing_embeddings_no_missing(): + """Test backfill when no embeddings are missing.""" + with patch( + "backend.api.features.store.embeddings.query_raw_with_schema", + return_value=[], + ): + result = await embeddings.backfill_missing_embeddings(batch_size=5) + + assert result["processed"] == 0 + assert result["success"] == 0 + assert result["failed"] == 0 + assert result["message"] == "No missing embeddings" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_embedding_to_vector_string(): + """Test embedding to PostgreSQL vector string conversion.""" + embedding = [0.1, 0.2, 0.3, -0.4] + result = embeddings.embedding_to_vector_string(embedding) + assert result == "[0.1,0.2,0.3,-0.4]" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_embed_query(): + """Test embed_query function (alias for generate_embedding).""" + with patch( + "backend.api.features.store.embeddings.generate_embedding" + ) as mock_generate: + mock_generate.return_value = [0.1, 0.2, 0.3] + + result = await embeddings.embed_query("test query") + + assert result == [0.1, 0.2, 0.3] + mock_generate.assert_called_once_with("test query") diff --git a/autogpt_platform/backend/backend/api/features/store/hybrid_search.py b/autogpt_platform/backend/backend/api/features/store/hybrid_search.py new file mode 100644 index 0000000000..fbbbe62cb3 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/store/hybrid_search.py @@ -0,0 +1,393 @@ +""" +Hybrid Search for Store Agents + +Combines semantic (embedding) search with lexical (tsvector) search +for improved relevance in marketplace agent discovery. +""" + +import logging +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Literal + +from backend.api.features.store.embeddings import ( + embed_query, + embedding_to_vector_string, +) +from backend.data.db import query_raw_with_schema + +logger = logging.getLogger(__name__) + + +@dataclass +class HybridSearchWeights: + """Weights for combining search signals.""" + + semantic: float = 0.30 # Embedding cosine similarity + lexical: float = 0.30 # tsvector ts_rank_cd score + category: float = 0.20 # Category match boost + recency: float = 0.10 # Newer agents ranked higher + popularity: float = 0.10 # Agent usage/runs (PageRank-like) + + def __post_init__(self): + """Validate weights are non-negative and sum to approximately 1.0.""" + total = ( + self.semantic + + self.lexical + + self.category + + self.recency + + self.popularity + ) + + if any( + w < 0 + for w in [ + self.semantic, + self.lexical, + self.category, + self.recency, + self.popularity, + ] + ): + raise ValueError("All weights must be non-negative") + + if not (0.99 <= total <= 1.01): + raise ValueError(f"Weights must sum to ~1.0, got {total:.3f}") + + +DEFAULT_WEIGHTS = HybridSearchWeights() + +# Minimum relevance score threshold - agents below this are filtered out +# With weights (0.30 semantic + 0.30 lexical + 0.20 category + 0.10 recency + 0.10 popularity): +# - 0.20 means at least ~60% semantic match OR strong lexical match required +# - Ensures only genuinely relevant results are returned +# - Recency/popularity alone (0.10 each) won't pass the threshold +DEFAULT_MIN_SCORE = 0.20 + + +@dataclass +class HybridSearchResult: + """A single search result with score breakdown.""" + + slug: str + agent_name: str + agent_image: str + creator_username: str + creator_avatar: str + sub_heading: str + description: str + runs: int + rating: float + categories: list[str] + featured: bool + is_available: bool + updated_at: datetime + + # Score breakdown (for debugging/tuning) + combined_score: float + semantic_score: float = 0.0 + lexical_score: float = 0.0 + category_score: float = 0.0 + recency_score: float = 0.0 + popularity_score: float = 0.0 + + +async def hybrid_search( + query: str, + featured: bool = False, + creators: list[str] | None = None, + category: str | None = None, + sorted_by: ( + Literal["relevance", "rating", "runs", "name", "updated_at"] | None + ) = None, + page: int = 1, + page_size: int = 20, + weights: HybridSearchWeights | None = None, + min_score: float | None = None, +) -> tuple[list[dict[str, Any]], int]: + """ + Perform hybrid search combining semantic and lexical signals. + + Args: + query: Search query string + featured: Filter for featured agents only + creators: Filter by creator usernames + category: Filter by category + sorted_by: Sort order (relevance uses hybrid scoring) + page: Page number (1-indexed) + page_size: Results per page + weights: Custom weights for search signals + min_score: Minimum relevance score threshold (0-1). Results below + this score are filtered out. Defaults to DEFAULT_MIN_SCORE. + + Returns: + Tuple of (results list, total count). Returns empty list if no + results meet the minimum relevance threshold. + """ + # Validate inputs + query = query.strip() + if not query: + return [], 0 # Empty query returns no results + + if page < 1: + page = 1 + if page_size < 1: + page_size = 1 + if page_size > 100: # Cap at reasonable limit to prevent performance issues + page_size = 100 + + if weights is None: + weights = DEFAULT_WEIGHTS + if min_score is None: + min_score = DEFAULT_MIN_SCORE + + offset = (page - 1) * page_size + + # Generate query embedding + query_embedding = await embed_query(query) + + # Build WHERE clause conditions + where_parts: list[str] = ["sa.is_available = true"] + params: list[Any] = [] + param_index = 1 + + # Add search query for lexical matching + params.append(query) + query_param = f"${param_index}" + param_index += 1 + + # Add lowercased query for category matching + params.append(query.lower()) + query_lower_param = f"${param_index}" + param_index += 1 + + if featured: + where_parts.append("sa.featured = true") + + if creators: + where_parts.append(f"sa.creator_username = ANY(${param_index})") + params.append(creators) + param_index += 1 + + if category: + where_parts.append(f"${param_index} = ANY(sa.categories)") + params.append(category) + param_index += 1 + + # Safe: where_parts only contains hardcoded strings with $N parameter placeholders + # No user input is concatenated directly into the SQL string + where_clause = " AND ".join(where_parts) + + # Embedding is required for hybrid search - fail fast if unavailable + if query_embedding is None or not query_embedding: + # Log detailed error server-side + logger.error( + "Failed to generate query embedding. " + "Check that openai_internal_api_key is configured and OpenAI API is accessible." + ) + # Raise generic error to client + raise ValueError("Search service temporarily unavailable") + + # Add embedding parameter + embedding_str = embedding_to_vector_string(query_embedding) + params.append(embedding_str) + embedding_param = f"${param_index}" + param_index += 1 + + # Add weight parameters for SQL calculation + params.append(weights.semantic) + weight_semantic_param = f"${param_index}" + param_index += 1 + + params.append(weights.lexical) + weight_lexical_param = f"${param_index}" + param_index += 1 + + params.append(weights.category) + weight_category_param = f"${param_index}" + param_index += 1 + + params.append(weights.recency) + weight_recency_param = f"${param_index}" + param_index += 1 + + params.append(weights.popularity) + weight_popularity_param = f"${param_index}" + param_index += 1 + + # Add min_score parameter + params.append(min_score) + min_score_param = f"${param_index}" + param_index += 1 + + # Optimized hybrid search query: + # 1. Direct join to UnifiedContentEmbedding via contentId=storeListingVersionId (no redundant JOINs) + # 2. UNION approach (deduplicates agents matching both branches) + # 3. COUNT(*) OVER() to get total count in single query + # 4. Optimized category matching with EXISTS + unnest + # 5. Pre-calculated max values for lexical and popularity normalization + # 6. Simplified recency calculation with linear decay + # 7. Logarithmic popularity scaling to prevent viral agents from dominating + sql_query = f""" + WITH candidates AS ( + -- Lexical matches (uses GIN index on search column) + SELECT sa."storeListingVersionId" + FROM {{schema_prefix}}"StoreAgent" sa + WHERE {where_clause} + AND sa.search @@ plainto_tsquery('english', {query_param}) + + UNION + + -- Semantic matches (uses HNSW index on embedding with KNN) + SELECT "storeListingVersionId" + FROM ( + SELECT sa."storeListingVersionId", uce.embedding + FROM {{schema_prefix}}"StoreAgent" sa + INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce + ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType" + WHERE {where_clause} + ORDER BY uce.embedding <=> {embedding_param}::vector + LIMIT 200 + ) semantic_results + ), + search_scores AS ( + SELECT + sa.slug, + sa.agent_name, + sa.agent_image, + sa.creator_username, + sa.creator_avatar, + sa.sub_heading, + sa.description, + sa.runs, + sa.rating, + sa.categories, + sa.featured, + sa.is_available, + sa.updated_at, + -- Semantic score: cosine similarity (1 - distance) + COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score, + -- Lexical score: ts_rank_cd (will be normalized later) + COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw, + -- Category match: optimized with unnest for better performance + CASE + WHEN EXISTS ( + SELECT 1 FROM unnest(sa.categories) cat + WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%' + ) + THEN 1.0 + ELSE 0.0 + END as category_score, + -- Recency score: linear decay over 90 days (simpler than exponential) + GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score, + -- Popularity raw: agent runs count (will be normalized with log scaling) + sa.runs as popularity_raw + FROM candidates c + INNER JOIN {{schema_prefix}}"StoreAgent" sa + ON c."storeListingVersionId" = sa."storeListingVersionId" + LEFT JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce + ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType" + ), + max_lexical AS ( + SELECT MAX(lexical_raw) as max_val FROM search_scores + ), + max_popularity AS ( + SELECT MAX(popularity_raw) as max_val FROM search_scores + ), + normalized AS ( + SELECT + ss.*, + -- Normalize lexical score by pre-calculated max + CASE + WHEN ml.max_val > 0 + THEN ss.lexical_raw / ml.max_val + ELSE 0 + END as lexical_score, + -- Normalize popularity with logarithmic scaling to prevent viral agents from dominating + -- LOG(1 + runs) / LOG(1 + max_runs) ensures score is 0-1 range + CASE + WHEN mp.max_val > 0 AND ss.popularity_raw > 0 + THEN LN(1 + ss.popularity_raw) / LN(1 + mp.max_val) + ELSE 0 + END as popularity_score + FROM search_scores ss + CROSS JOIN max_lexical ml + CROSS JOIN max_popularity mp + ), + scored AS ( + SELECT + slug, + agent_name, + agent_image, + creator_username, + creator_avatar, + sub_heading, + description, + runs, + rating, + categories, + featured, + is_available, + updated_at, + semantic_score, + lexical_score, + category_score, + recency_score, + popularity_score, + ( + {weight_semantic_param} * semantic_score + + {weight_lexical_param} * lexical_score + + {weight_category_param} * category_score + + {weight_recency_param} * recency_score + + {weight_popularity_param} * popularity_score + ) as combined_score + FROM normalized + ), + filtered AS ( + SELECT + *, + COUNT(*) OVER () as total_count + FROM scored + WHERE combined_score >= {min_score_param} + ) + SELECT * FROM filtered + ORDER BY combined_score DESC + LIMIT ${param_index} OFFSET ${param_index + 1} + """ + + # Add pagination params + params.extend([page_size, offset]) + + # Execute search query - includes total_count via window function + results = await query_raw_with_schema( + sql_query, *params, set_public_search_path=True + ) + + # Extract total count from first result (all rows have same count) + total = results[0]["total_count"] if results else 0 + + # Remove total_count from results before returning + for result in results: + result.pop("total_count", None) + + # Log without sensitive query content + logger.info(f"Hybrid search: {len(results)} results, {total} total") + + return results, total + + +async def hybrid_search_simple( + query: str, + page: int = 1, + page_size: int = 20, +) -> tuple[list[dict[str, Any]], int]: + """ + Simplified hybrid search for common use cases. + + Uses default weights and no filters. + """ + return await hybrid_search( + query=query, + page=page, + page_size=page_size, + ) diff --git a/autogpt_platform/backend/backend/api/features/store/hybrid_search_test.py b/autogpt_platform/backend/backend/api/features/store/hybrid_search_test.py new file mode 100644 index 0000000000..6a5cd7ad6d --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/store/hybrid_search_test.py @@ -0,0 +1,334 @@ +""" +Integration tests for hybrid search with schema handling. + +These tests verify that hybrid search works correctly across different database schemas. +""" + +from unittest.mock import patch + +import pytest + +from backend.api.features.store.hybrid_search import HybridSearchWeights, hybrid_search + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_hybrid_search_with_schema_handling(): + """Test that hybrid search correctly handles database schema prefixes.""" + # Test with a mock query to ensure schema handling works + query = "test agent" + + with patch( + "backend.api.features.store.hybrid_search.query_raw_with_schema" + ) as mock_query: + # Mock the query result + mock_query.return_value = [ + { + "slug": "test/agent", + "agent_name": "Test Agent", + "agent_image": "test.png", + "creator_username": "test", + "creator_avatar": "avatar.png", + "sub_heading": "Test sub-heading", + "description": "Test description", + "runs": 10, + "rating": 4.5, + "categories": ["test"], + "featured": False, + "is_available": True, + "updated_at": "2024-01-01T00:00:00Z", + "combined_score": 0.8, + "semantic_score": 0.7, + "lexical_score": 0.6, + "category_score": 0.5, + "recency_score": 0.4, + "total_count": 1, + } + ] + + with patch( + "backend.api.features.store.hybrid_search.embed_query" + ) as mock_embed: + mock_embed.return_value = [0.1] * 1536 # Mock embedding + + results, total = await hybrid_search( + query=query, + page=1, + page_size=20, + ) + + # Verify the query was called + assert mock_query.called + # Verify the SQL template uses schema_prefix placeholder + call_args = mock_query.call_args + sql_template = call_args[0][0] + assert "{schema_prefix}" in sql_template + + # Verify results + assert len(results) == 1 + assert total == 1 + assert results[0]["slug"] == "test/agent" + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_hybrid_search_with_public_schema(): + """Test hybrid search when using public schema (no prefix needed).""" + with patch("backend.data.db.get_database_schema") as mock_schema: + mock_schema.return_value = "public" + + with patch( + "backend.api.features.store.hybrid_search.query_raw_with_schema" + ) as mock_query: + mock_query.return_value = [] + + with patch( + "backend.api.features.store.hybrid_search.embed_query" + ) as mock_embed: + mock_embed.return_value = [0.1] * 1536 + + results, total = await hybrid_search( + query="test", + page=1, + page_size=20, + ) + + # Verify the mock was set up correctly + assert mock_schema.return_value == "public" + + # Results should work even with empty results + assert results == [] + assert total == 0 + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_hybrid_search_with_custom_schema(): + """Test hybrid search when using custom schema (e.g., 'platform').""" + with patch("backend.data.db.get_database_schema") as mock_schema: + mock_schema.return_value = "platform" + + with patch( + "backend.api.features.store.hybrid_search.query_raw_with_schema" + ) as mock_query: + mock_query.return_value = [] + + with patch( + "backend.api.features.store.hybrid_search.embed_query" + ) as mock_embed: + mock_embed.return_value = [0.1] * 1536 + + results, total = await hybrid_search( + query="test", + page=1, + page_size=20, + ) + + # Verify the mock was set up correctly + assert mock_schema.return_value == "platform" + + assert results == [] + assert total == 0 + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_hybrid_search_without_embeddings(): + """Test hybrid search fails fast when embeddings are unavailable.""" + # Patch where the function is used, not where it's defined + with patch("backend.api.features.store.hybrid_search.embed_query") as mock_embed: + # Simulate embedding failure + mock_embed.return_value = None + + # Should raise ValueError with helpful message + with pytest.raises(ValueError) as exc_info: + await hybrid_search( + query="test", + page=1, + page_size=20, + ) + + # Verify error message is generic (doesn't leak implementation details) + assert "Search service temporarily unavailable" in str(exc_info.value) + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_hybrid_search_with_filters(): + """Test hybrid search with various filters.""" + with patch( + "backend.api.features.store.hybrid_search.query_raw_with_schema" + ) as mock_query: + mock_query.return_value = [] + + with patch( + "backend.api.features.store.hybrid_search.embed_query" + ) as mock_embed: + mock_embed.return_value = [0.1] * 1536 + + # Test with featured filter + results, total = await hybrid_search( + query="test", + featured=True, + creators=["user1", "user2"], + category="productivity", + page=1, + page_size=10, + ) + + # Verify filters were applied in the query + call_args = mock_query.call_args + params = call_args[0][1:] # Skip SQL template + + # Should have query, query_lower, creators array, category + assert len(params) >= 4 + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_hybrid_search_weights(): + """Test hybrid search with custom weights.""" + custom_weights = HybridSearchWeights( + semantic=0.5, + lexical=0.3, + category=0.1, + recency=0.1, + popularity=0.0, + ) + + with patch( + "backend.api.features.store.hybrid_search.query_raw_with_schema" + ) as mock_query: + mock_query.return_value = [] + + with patch( + "backend.api.features.store.hybrid_search.embed_query" + ) as mock_embed: + mock_embed.return_value = [0.1] * 1536 + + results, total = await hybrid_search( + query="test", + weights=custom_weights, + page=1, + page_size=20, + ) + + # Verify custom weights were used in the query + call_args = mock_query.call_args + sql_template = call_args[0][0] + params = call_args[0][1:] # Get all parameters passed + + # Check that SQL uses parameterized weights (not f-string interpolation) + assert "$" in sql_template # Verify parameterization is used + + # Check that custom weights are in the params + assert 0.5 in params # semantic weight + assert 0.3 in params # lexical weight + assert 0.1 in params # category and recency weights + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_hybrid_search_min_score_filtering(): + """Test hybrid search minimum score threshold.""" + with patch( + "backend.api.features.store.hybrid_search.query_raw_with_schema" + ) as mock_query: + # Return results with varying scores + mock_query.return_value = [ + { + "slug": "high-score/agent", + "agent_name": "High Score Agent", + "combined_score": 0.8, + "total_count": 1, + # ... other fields + } + ] + + with patch( + "backend.api.features.store.hybrid_search.embed_query" + ) as mock_embed: + mock_embed.return_value = [0.1] * 1536 + + # Test with custom min_score + results, total = await hybrid_search( + query="test", + min_score=0.5, # High threshold + page=1, + page_size=20, + ) + + # Verify min_score was applied in query + call_args = mock_query.call_args + sql_template = call_args[0][0] + params = call_args[0][1:] # Get all parameters + + # Check that SQL uses parameterized min_score + assert "combined_score >=" in sql_template + assert "$" in sql_template # Verify parameterization + + # Check that custom min_score is in the params + assert 0.5 in params + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_hybrid_search_pagination(): + """Test hybrid search pagination.""" + with patch( + "backend.api.features.store.hybrid_search.query_raw_with_schema" + ) as mock_query: + mock_query.return_value = [] + + with patch( + "backend.api.features.store.hybrid_search.embed_query" + ) as mock_embed: + mock_embed.return_value = [0.1] * 1536 + + # Test page 2 with page_size 10 + results, total = await hybrid_search( + query="test", + page=2, + page_size=10, + ) + + # Verify pagination parameters + call_args = mock_query.call_args + params = call_args[0] + + # Last two params should be LIMIT and OFFSET + limit = params[-2] + offset = params[-1] + + assert limit == 10 # page_size + assert offset == 10 # (page - 1) * page_size = (2 - 1) * 10 + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_hybrid_search_error_handling(): + """Test hybrid search error handling.""" + with patch( + "backend.api.features.store.hybrid_search.query_raw_with_schema" + ) as mock_query: + # Simulate database error + mock_query.side_effect = Exception("Database connection error") + + with patch( + "backend.api.features.store.hybrid_search.embed_query" + ) as mock_embed: + mock_embed.return_value = [0.1] * 1536 + + # Should raise exception + with pytest.raises(Exception) as exc_info: + await hybrid_search( + query="test", + page=1, + page_size=20, + ) + + assert "Database connection error" in str(exc_info.value) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/autogpt_platform/backend/backend/data/db.py b/autogpt_platform/backend/backend/data/db.py index 31a27e9163..ab39881ed5 100644 --- a/autogpt_platform/backend/backend/data/db.py +++ b/autogpt_platform/backend/backend/data/db.py @@ -38,6 +38,20 @@ POOL_TIMEOUT = os.getenv("DB_POOL_TIMEOUT") if POOL_TIMEOUT: DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT) +# Add public schema to search_path for pgvector type access +# The vector extension is in public schema, but search_path is determined by schema parameter +# Extract the schema from DATABASE_URL or default to 'public' (matching get_database_schema()) +parsed_url = urlparse(DATABASE_URL) +url_params = dict(parse_qsl(parsed_url.query)) +db_schema = url_params.get("schema", "public") +# Build search_path, avoiding duplicates if db_schema is already 'public' +search_path_schemas = list( + dict.fromkeys([db_schema, "public"]) +) # Preserves order, removes duplicates +search_path = ",".join(search_path_schemas) +# This allows using ::vector without schema qualification +DATABASE_URL = add_param(DATABASE_URL, "options", f"-c search_path={search_path}") + HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None prisma = Prisma( @@ -108,21 +122,102 @@ def get_database_schema() -> str: return query_params.get("schema", "public") -async def query_raw_with_schema(query_template: str, *args) -> list[dict]: - """Execute raw SQL query with proper schema handling.""" +async def _raw_with_schema( + query_template: str, + *args, + execute: bool = False, + client: Prisma | None = None, + set_public_search_path: bool = False, +) -> list[dict] | int: + """Internal: Execute raw SQL with proper schema handling. + + Use query_raw_with_schema() or execute_raw_with_schema() instead. + + Args: + query_template: SQL query with {schema_prefix} placeholder + *args: Query parameters + execute: If False, executes SELECT query. If True, executes INSERT/UPDATE/DELETE. + client: Optional Prisma client for transactions (only used when execute=True). + set_public_search_path: If True, sets search_path to include public schema. + Needed for pgvector types and other public schema objects. + + Returns: + - list[dict] if execute=False (query results) + - int if execute=True (number of affected rows) + """ schema = get_database_schema() schema_prefix = f'"{schema}".' if schema != "public" else "" formatted_query = query_template.format(schema_prefix=schema_prefix) import prisma as prisma_module - result = await prisma_module.get_client().query_raw( - formatted_query, *args # type: ignore - ) + db_client = client if client else prisma_module.get_client() + + # Set search_path to include public schema if requested + # Prisma doesn't support the 'options' connection parameter, so we set it per-session + # This is idempotent and safe to call multiple times + if set_public_search_path: + await db_client.execute_raw(f"SET search_path = {schema}, public") # type: ignore + + if execute: + result = await db_client.execute_raw(formatted_query, *args) # type: ignore + else: + result = await db_client.query_raw(formatted_query, *args) # type: ignore return result +async def query_raw_with_schema( + query_template: str, *args, set_public_search_path: bool = False +) -> list[dict]: + """Execute raw SQL SELECT query with proper schema handling. + + Args: + query_template: SQL query with {schema_prefix} placeholder + *args: Query parameters + set_public_search_path: If True, sets search_path to include public schema. + Needed for pgvector types and other public schema objects. + + Returns: + List of result rows as dictionaries + + Example: + results = await query_raw_with_schema( + 'SELECT * FROM {schema_prefix}"User" WHERE id = $1', + user_id + ) + """ + return await _raw_with_schema(query_template, *args, execute=False, set_public_search_path=set_public_search_path) # type: ignore + + +async def execute_raw_with_schema( + query_template: str, + *args, + client: Prisma | None = None, + set_public_search_path: bool = False, +) -> int: + """Execute raw SQL command (INSERT/UPDATE/DELETE) with proper schema handling. + + Args: + query_template: SQL query with {schema_prefix} placeholder + *args: Query parameters + client: Optional Prisma client for transactions + set_public_search_path: If True, sets search_path to include public schema. + Needed for pgvector types and other public schema objects. + + Returns: + Number of affected rows + + Example: + await execute_raw_with_schema( + 'INSERT INTO {schema_prefix}"User" (id, name) VALUES ($1, $2)', + user_id, name, + client=tx # Optional transaction client + ) + """ + return await _raw_with_schema(query_template, *args, execute=True, client=client, set_public_search_path=set_public_search_path) # type: ignore + + class BaseDbModel(BaseModel): id: str = Field(default_factory=lambda: str(uuid4())) diff --git a/autogpt_platform/backend/backend/data/graph_test.py b/autogpt_platform/backend/backend/data/graph_test.py index eea7277eb9..8b7eadb887 100644 --- a/autogpt_platform/backend/backend/data/graph_test.py +++ b/autogpt_platform/backend/backend/data/graph_test.py @@ -1,5 +1,6 @@ import json from typing import Any +from unittest.mock import AsyncMock, patch from uuid import UUID import fastapi.exceptions @@ -18,6 +19,17 @@ from backend.usecases.sample import create_test_user from backend.util.test import SpinTestServer +@pytest.fixture(scope="session", autouse=True) +def mock_embedding_functions(): + """Mock embedding functions for all tests to avoid database/API dependencies.""" + with patch( + "backend.api.features.store.db.ensure_embedding", + new_callable=AsyncMock, + return_value=True, + ): + yield + + @pytest.mark.asyncio(loop_scope="session") async def test_graph_creation(server: SpinTestServer, snapshot: Snapshot): """ diff --git a/autogpt_platform/backend/backend/executor/database.py b/autogpt_platform/backend/backend/executor/database.py index 9848948bff..f10b285450 100644 --- a/autogpt_platform/backend/backend/executor/database.py +++ b/autogpt_platform/backend/backend/executor/database.py @@ -7,6 +7,10 @@ from backend.api.features.library.db import ( list_library_agents, ) from backend.api.features.store.db import get_store_agent_details, get_store_agents +from backend.api.features.store.embeddings import ( + backfill_missing_embeddings, + get_embedding_stats, +) from backend.data import db from backend.data.analytics import ( get_accuracy_trends_and_alerts, @@ -214,6 +218,10 @@ class DatabaseManager(AppService): get_store_agents = _(get_store_agents) get_store_agent_details = _(get_store_agent_details) + # Store Embeddings + get_embedding_stats = _(get_embedding_stats) + backfill_missing_embeddings = _(backfill_missing_embeddings) + # Summary data - async get_user_execution_summary_data = _(get_user_execution_summary_data) @@ -265,6 +273,10 @@ class DatabaseManagerClient(AppServiceClient): get_store_agents = _(d.get_store_agents) get_store_agent_details = _(d.get_store_agent_details) + # Store Embeddings + get_embedding_stats = _(d.get_embedding_stats) + backfill_missing_embeddings = _(d.backfill_missing_embeddings) + class DatabaseManagerAsyncClient(AppServiceClient): d = DatabaseManager diff --git a/autogpt_platform/backend/backend/executor/manager_test.py b/autogpt_platform/backend/backend/executor/manager_test.py index bdfdb5d724..69deba4b00 100644 --- a/autogpt_platform/backend/backend/executor/manager_test.py +++ b/autogpt_platform/backend/backend/executor/manager_test.py @@ -1,4 +1,5 @@ import logging +from unittest.mock import AsyncMock, patch import fastapi.responses import pytest @@ -19,6 +20,17 @@ from backend.util.test import SpinTestServer, wait_execution logger = logging.getLogger(__name__) +@pytest.fixture(scope="session", autouse=True) +def mock_embedding_functions(): + """Mock embedding functions for all tests to avoid database/API dependencies.""" + with patch( + "backend.api.features.store.db.ensure_embedding", + new_callable=AsyncMock, + return_value=True, + ): + yield + + async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph: logger.info(f"Creating graph for user {u.id}") return await s.agent_server.test_create_graph(CreateGraph(graph=g), u.id) diff --git a/autogpt_platform/backend/backend/executor/scheduler.py b/autogpt_platform/backend/backend/executor/scheduler.py index 963c901fd6..3845c04ab6 100644 --- a/autogpt_platform/backend/backend/executor/scheduler.py +++ b/autogpt_platform/backend/backend/executor/scheduler.py @@ -2,6 +2,7 @@ import asyncio import logging import os import threading +import time import uuid from enum import Enum from typing import Optional @@ -36,7 +37,7 @@ from backend.monitoring import ( report_execution_accuracy_alerts, report_late_executions, ) -from backend.util.clients import get_scheduler_client +from backend.util.clients import get_database_manager_client, get_scheduler_client from backend.util.cloud_storage import cleanup_expired_files_async from backend.util.exceptions import ( GraphNotFoundError, @@ -252,6 +253,74 @@ def execution_accuracy_alerts(): return report_execution_accuracy_alerts() +def ensure_embeddings_coverage(): + """ + Ensure approved store agents have embeddings for hybrid search. + + Processes ALL missing embeddings in batches of 10 until 100% coverage. + Missing embeddings = agents invisible in hybrid search. + + Schedule: Runs every 6 hours (balanced between coverage and API costs). + - Catches agents approved between scheduled runs + - Batch size 10: gradual processing to avoid rate limits + - Manual trigger available via execute_ensure_embeddings_coverage endpoint + """ + db_client = get_database_manager_client() + stats = db_client.get_embedding_stats() + + # Check for error from get_embedding_stats() first + if "error" in stats: + logger.error( + f"Failed to get embedding stats: {stats['error']} - skipping backfill" + ) + return {"processed": 0, "success": 0, "failed": 0, "error": stats["error"]} + + if stats["without_embeddings"] == 0: + logger.info("All approved agents have embeddings, skipping backfill") + return {"processed": 0, "success": 0, "failed": 0} + + logger.info( + f"Found {stats['without_embeddings']} agents without embeddings " + f"({stats['coverage_percent']}% coverage) - processing all" + ) + + total_processed = 0 + total_success = 0 + total_failed = 0 + + # Process in batches until no more missing embeddings + while True: + result = db_client.backfill_missing_embeddings(batch_size=10) + + total_processed += result["processed"] + total_success += result["success"] + total_failed += result["failed"] + + if result["processed"] == 0: + # No more missing embeddings + break + + if result["success"] == 0 and result["processed"] > 0: + # All attempts in this batch failed - stop to avoid infinite loop + logger.error( + f"All {result['processed']} embedding attempts failed - stopping backfill" + ) + break + + # Small delay between batches to avoid rate limits + time.sleep(1) + + logger.info( + f"Embedding backfill completed: {total_success}/{total_processed} succeeded, " + f"{total_failed} failed" + ) + return { + "processed": total_processed, + "success": total_success, + "failed": total_failed, + } + + # Monitoring functions are now imported from monitoring module @@ -473,6 +542,19 @@ class Scheduler(AppService): jobstore=Jobstores.EXECUTION.value, ) + # Embedding Coverage - Every 6 hours + # Ensures all approved agents have embeddings for hybrid search + # Critical: missing embeddings = agents invisible in search + self.scheduler.add_job( + ensure_embeddings_coverage, + id="ensure_embeddings_coverage", + trigger="interval", + hours=6, + replace_existing=True, + max_instances=1, # Prevent overlapping runs + jobstore=Jobstores.EXECUTION.value, + ) + self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR) self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED) self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES) @@ -630,6 +712,11 @@ class Scheduler(AppService): """Manually trigger execution accuracy alert checking.""" return execution_accuracy_alerts() + @expose + def execute_ensure_embeddings_coverage(self): + """Manually trigger embedding backfill for approved store agents.""" + return ensure_embeddings_coverage() + class SchedulerClient(AppServiceClient): @classmethod diff --git a/autogpt_platform/backend/backend/util/clients.py b/autogpt_platform/backend/backend/util/clients.py index bea8a5e2f9..570e9fa3de 100644 --- a/autogpt_platform/backend/backend/util/clients.py +++ b/autogpt_platform/backend/backend/util/clients.py @@ -10,6 +10,7 @@ from backend.util.settings import Settings settings = Settings() if TYPE_CHECKING: + from openai import AsyncOpenAI from supabase import AClient, Client from backend.data.execution import ( @@ -139,6 +140,24 @@ async def get_async_supabase() -> "AClient": ) +# ============ OpenAI Client ============ # + + +@cached(ttl_seconds=3600) +def get_openai_client() -> "AsyncOpenAI | None": + """ + Get a process-cached async OpenAI client for embeddings. + + Returns None if API key is not configured. + """ + from openai import AsyncOpenAI + + api_key = settings.secrets.openai_internal_api_key + if not api_key: + return None + return AsyncOpenAI(api_key=api_key) + + # ============ Notification Queue Helpers ============ # diff --git a/autogpt_platform/backend/migrations/20260109181714_add_docs_embedding/migration.sql b/autogpt_platform/backend/migrations/20260109181714_add_docs_embedding/migration.sql new file mode 100644 index 0000000000..9c4bcff5e1 --- /dev/null +++ b/autogpt_platform/backend/migrations/20260109181714_add_docs_embedding/migration.sql @@ -0,0 +1,46 @@ +-- CreateExtension +-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first +-- Create in public schema so vector type is available across all schemas +DO $$ +BEGIN + CREATE EXTENSION IF NOT EXISTS "vector" WITH SCHEMA "public"; +EXCEPTION WHEN OTHERS THEN + RAISE NOTICE 'vector extension not available or already exists, skipping'; +END $$; + +-- CreateEnum +CREATE TYPE "ContentType" AS ENUM ('STORE_AGENT', 'BLOCK', 'INTEGRATION', 'DOCUMENTATION', 'LIBRARY_AGENT'); + +-- CreateTable +CREATE TABLE "UnifiedContentEmbedding" ( + "id" TEXT NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + "contentType" "ContentType" NOT NULL, + "contentId" TEXT NOT NULL, + "userId" TEXT, + "embedding" public.vector(1536) NOT NULL, + "searchableText" TEXT NOT NULL, + "metadata" JSONB NOT NULL DEFAULT '{}', + + CONSTRAINT "UnifiedContentEmbedding_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE INDEX "UnifiedContentEmbedding_contentType_idx" ON "UnifiedContentEmbedding"("contentType"); + +-- CreateIndex +CREATE INDEX "UnifiedContentEmbedding_userId_idx" ON "UnifiedContentEmbedding"("userId"); + +-- CreateIndex +CREATE INDEX "UnifiedContentEmbedding_contentType_userId_idx" ON "UnifiedContentEmbedding"("contentType", "userId"); + +-- CreateIndex +-- NULLS NOT DISTINCT ensures only one public (NULL userId) embedding per contentType+contentId +-- Requires PostgreSQL 15+. Supabase uses PostgreSQL 15+. +CREATE UNIQUE INDEX "UnifiedContentEmbedding_contentType_contentId_userId_key" ON "UnifiedContentEmbedding"("contentType", "contentId", "userId") NULLS NOT DISTINCT; + +-- CreateIndex +-- HNSW index for fast vector similarity search on embeddings +-- Uses cosine distance operator (<=>), which matches the query in hybrid_search.py +CREATE INDEX "UnifiedContentEmbedding_embedding_idx" ON "UnifiedContentEmbedding" USING hnsw ("embedding" public.vector_cosine_ops); diff --git a/autogpt_platform/backend/migrations/20260112173500_add_supabase_extensions_to_platform_schema/migration.sql b/autogpt_platform/backend/migrations/20260112173500_add_supabase_extensions_to_platform_schema/migration.sql new file mode 100644 index 0000000000..ca91bc5cab --- /dev/null +++ b/autogpt_platform/backend/migrations/20260112173500_add_supabase_extensions_to_platform_schema/migration.sql @@ -0,0 +1,71 @@ +-- Acknowledge Supabase-managed extensions to prevent drift warnings +-- These extensions are pre-installed by Supabase in specific schemas +-- This migration ensures they exist where available (Supabase) or skips gracefully (CI) + +-- Create schemas (safe in both CI and Supabase) +CREATE SCHEMA IF NOT EXISTS "extensions"; + +-- Extensions that exist in both CI and Supabase +DO $$ +BEGIN + CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA "extensions"; +EXCEPTION WHEN OTHERS THEN + RAISE NOTICE 'pgcrypto extension not available, skipping'; +END $$; + +DO $$ +BEGIN + CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA "extensions"; +EXCEPTION WHEN OTHERS THEN + RAISE NOTICE 'uuid-ossp extension not available, skipping'; +END $$; + +-- Supabase-specific extensions (skip gracefully in CI) +DO $$ +BEGIN + CREATE EXTENSION IF NOT EXISTS "pg_stat_statements" WITH SCHEMA "extensions"; +EXCEPTION WHEN OTHERS THEN + RAISE NOTICE 'pg_stat_statements extension not available, skipping'; +END $$; + +DO $$ +BEGIN + CREATE EXTENSION IF NOT EXISTS "pg_net" WITH SCHEMA "extensions"; +EXCEPTION WHEN OTHERS THEN + RAISE NOTICE 'pg_net extension not available, skipping'; +END $$; + +DO $$ +BEGIN + CREATE EXTENSION IF NOT EXISTS "pgjwt" WITH SCHEMA "extensions"; +EXCEPTION WHEN OTHERS THEN + RAISE NOTICE 'pgjwt extension not available, skipping'; +END $$; + +DO $$ +BEGIN + CREATE SCHEMA IF NOT EXISTS "graphql"; + CREATE EXTENSION IF NOT EXISTS "pg_graphql" WITH SCHEMA "graphql"; +EXCEPTION WHEN OTHERS THEN + RAISE NOTICE 'pg_graphql extension not available, skipping'; +END $$; + +DO $$ +BEGIN + CREATE SCHEMA IF NOT EXISTS "pgsodium"; + CREATE EXTENSION IF NOT EXISTS "pgsodium" WITH SCHEMA "pgsodium"; +EXCEPTION WHEN OTHERS THEN + RAISE NOTICE 'pgsodium extension not available, skipping'; +END $$; + +DO $$ +BEGIN + CREATE SCHEMA IF NOT EXISTS "vault"; + CREATE EXTENSION IF NOT EXISTS "supabase_vault" WITH SCHEMA "vault"; +EXCEPTION WHEN OTHERS THEN + RAISE NOTICE 'supabase_vault extension not available, skipping'; +END $$; + + +-- Return to platform +CREATE SCHEMA IF NOT EXISTS "platform"; \ No newline at end of file diff --git a/autogpt_platform/backend/schema.prisma b/autogpt_platform/backend/schema.prisma index 2f6c109c03..0efca1bf3d 100644 --- a/autogpt_platform/backend/schema.prisma +++ b/autogpt_platform/backend/schema.prisma @@ -1,14 +1,15 @@ datasource db { - provider = "postgresql" - url = env("DATABASE_URL") - directUrl = env("DIRECT_URL") + provider = "postgresql" + url = env("DATABASE_URL") + directUrl = env("DIRECT_URL") + extensions = [pgvector(map: "vector")] } generator client { provider = "prisma-client-py" recursive_type_depth = -1 interface = "asyncio" - previewFeatures = ["views", "fullTextSearch"] + previewFeatures = ["views", "fullTextSearch", "postgresqlExtensions"] partial_type_generator = "backend/data/partial_types.py" } @@ -127,8 +128,8 @@ model BuilderSearchHistory { updatedAt DateTime @default(now()) @updatedAt searchQuery String - filter String[] @default([]) - byCreator String[] @default([]) + filter String[] @default([]) + byCreator String[] @default([]) userId String User User @relation(fields: [userId], references: [id], onDelete: Cascade) @@ -721,26 +722,25 @@ view StoreAgent { storeListingVersionId String updated_at DateTime - slug String - agent_name String - agent_video String? - agent_output_demo String? - agent_image String[] + slug String + agent_name String + agent_video String? + agent_output_demo String? + agent_image String[] - featured Boolean @default(false) - creator_username String? - creator_avatar String? - sub_heading String - description String - categories String[] - search Unsupported("tsvector")? @default(dbgenerated("''::tsvector")) - runs Int - rating Float - versions String[] - agentGraphVersions String[] - agentGraphId String - is_available Boolean @default(true) - useForOnboarding Boolean @default(false) + featured Boolean @default(false) + creator_username String? + creator_avatar String? + sub_heading String + description String + categories String[] + runs Int + rating Float + versions String[] + agentGraphVersions String[] + agentGraphId String + is_available Boolean @default(true) + useForOnboarding Boolean @default(false) // Materialized views used (refreshed every 15 minutes via pg_cron): // - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId @@ -856,14 +856,14 @@ model StoreListingVersion { AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version]) // Content fields - name String - subHeading String - videoUrl String? - agentOutputDemoUrl String? - imageUrls String[] - description String - instructions String? - categories String[] + name String + subHeading String + videoUrl String? + agentOutputDemoUrl String? + imageUrls String[] + description String + instructions String? + categories String[] isFeatured Boolean @default(false) @@ -899,6 +899,9 @@ model StoreListingVersion { // Reviews for this specific version Reviews StoreListingReview[] + // Note: Embeddings now stored in UnifiedContentEmbedding table + // Use contentType=STORE_AGENT and contentId=storeListingVersionId + @@unique([storeListingId, version]) @@index([storeListingId, submissionStatus, isAvailable]) @@index([submissionStatus]) @@ -906,6 +909,42 @@ model StoreListingVersion { @@index([agentGraphId, agentGraphVersion]) // Non-unique index for efficient lookups } +// Content type enum for unified search across store agents, blocks, docs +// Note: BLOCK/INTEGRATION are file-based (Python classes), not DB records +// DOCUMENTATION are file-based (.md files), not DB records +// Only STORE_AGENT and LIBRARY_AGENT are stored in database +enum ContentType { + STORE_AGENT // Database: StoreListingVersion + BLOCK // File-based: Python classes in /backend/blocks/ + INTEGRATION // File-based: Python classes (blocks with credentials) + DOCUMENTATION // File-based: .md/.mdx files + LIBRARY_AGENT // Database: User's personal agents +} + +// Unified embeddings table for all searchable content types +// Supports both public content (userId=null) and user-specific content (userId=userID) +model UnifiedContentEmbedding { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + // Content identification + contentType ContentType + contentId String // DB ID (storeListingVersionId) or file identifier (block.id, file_path) + userId String? // NULL for public content (store, blocks, docs), userId for private content (library agents) + + // Search data + embedding Unsupported("vector(1536)") // pgvector embedding (extension in platform schema) + searchableText String // Combined text for search and fallback + metadata Json @default("{}") // Content-specific metadata + + @@unique([contentType, contentId, userId], map: "UnifiedContentEmbedding_contentType_contentId_userId_key") + @@index([contentType]) + @@index([userId]) + @@index([contentType, userId]) + @@index([embedding], map: "UnifiedContentEmbedding_embedding_idx") +} + model StoreListingReview { id String @id @default(uuid()) createdAt DateTime @default(now()) @@ -998,16 +1037,16 @@ model OAuthApplication { updatedAt DateTime @updatedAt // Application metadata - name String - description String? - logoUrl String? // URL to app logo stored in GCS - clientId String @unique - clientSecret String // Hashed with Scrypt (same as API keys) - clientSecretSalt String // Salt for Scrypt hashing + name String + description String? + logoUrl String? // URL to app logo stored in GCS + clientId String @unique + clientSecret String // Hashed with Scrypt (same as API keys) + clientSecretSalt String // Salt for Scrypt hashing // OAuth configuration redirectUris String[] // Allowed callback URLs - grantTypes String[] @default(["authorization_code", "refresh_token"]) + grantTypes String[] @default(["authorization_code", "refresh_token"]) scopes APIKeyPermission[] // Which permissions the app can request // Application management From 631f1bd50ad6f342b00e0b862f3e62c28e51dc81 Mon Sep 17 00:00:00 2001 From: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com> Date: Thu, 15 Jan 2026 13:17:27 +0530 Subject: [PATCH 003/103] feat(frontend): add interactive tutorial for the new builder interface (#11458) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Changes 🏗️ This PR adds a comprehensive interactive tutorial for the new Builder UI to help users learn how to create agents. Key changes include: - Added a tutorial button to the canvas controls that launches a step-by-step guide - Created a Shepherd.js-based tutorial with multiple steps covering: - Adding blocks from the Block Menu - Understanding input and output handles - Configuring block values - Connecting blocks together - Saving and running agents - Added data-id attributes to key UI elements for tutorial targeting - Implemented tutorial state management with a new tutorialStore - Added helper functions for tutorial navigation and block manipulation - Created CSS styles for tutorial tooltips and highlights - Integrated with the Run Input dialog to support tutorial flow - Added prefetching of tutorial blocks for better performance https://github.com/user-attachments/assets/3db964b3-855c-4fcc-aa5f-6cd74ab33d7d ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Complete the tutorial from start to finish - [x] Test tutorial on different screen sizes - [x] Verify all tutorial steps work correctly - [x] Ensure tutorial can be canceled and restarted - [x] Check that tutorial doesn't interfere with normal builder functionality --- .../BuilderActions/BuilderActions.tsx | 5 +- .../components/AgentOutputs/AgentOutputs.tsx | 1 + .../components/RunGraph/RunGraph.tsx | 1 + .../components/RunGraph/useRunGraph.ts | 28 +- .../RunInputDialog/RunInputDialog.tsx | 54 +++- .../ScheduleGraph/ScheduleGraph.tsx | 1 + .../Flow/components/CustomControl.tsx | 56 +++- .../FlowEditor/handlers/NodeHandle.tsx | 2 + .../CustomNode/components/NodeContainer.tsx | 1 + .../components/NodeOutput/NodeOutput.tsx | 5 +- .../FlowEditor/nodes/FormCreator.tsx | 5 +- .../FlowEditor/nodes/OutputHandler.tsx | 6 +- .../FlowEditor/tutorial/constants.ts | 129 ++++++++ .../FlowEditor/tutorial/helpers/blocks.ts | 89 ++++++ .../FlowEditor/tutorial/helpers/canvas.ts | 83 ++++++ .../tutorial/helpers/connections.ts | 19 ++ .../FlowEditor/tutorial/helpers/dom.ts | 180 ++++++++++++ .../FlowEditor/tutorial/helpers/highlights.ts | 56 ++++ .../FlowEditor/tutorial/helpers/index.ts | 66 +++++ .../FlowEditor/tutorial/helpers/menu.ts | 25 ++ .../FlowEditor/tutorial/helpers/save.ts | 31 ++ .../FlowEditor/tutorial/helpers/state.ts | 49 ++++ .../components/FlowEditor/tutorial/icons.ts | 7 + .../components/FlowEditor/tutorial/index.ts | 81 +++++ .../FlowEditor/tutorial/steps/block-basics.ts | 114 ++++++++ .../FlowEditor/tutorial/steps/block-menu.ts | 198 +++++++++++++ .../FlowEditor/tutorial/steps/completion.ts | 51 ++++ .../tutorial/steps/configure-calculator.ts | 197 +++++++++++++ .../FlowEditor/tutorial/steps/connections.ts | 276 ++++++++++++++++++ .../FlowEditor/tutorial/steps/index.ts | 22 ++ .../FlowEditor/tutorial/steps/run.ts | 97 ++++++ .../FlowEditor/tutorial/steps/save.ts | 71 +++++ .../tutorial/steps/second-calculator.ts | 272 +++++++++++++++++ .../FlowEditor/tutorial/steps/welcome.ts | 33 +++ .../components/FlowEditor/tutorial/styles.ts | 101 +++++++ .../FlowEditor/tutorial/tutorial.css | 149 ++++++++++ .../NewControlPanel/NewBlockMenu/Block.tsx | 6 + .../NewBlockMenu/BlockMenu/BlockMenu.tsx | 13 +- .../BlockMenuSearch/BlockMenuSearch.tsx | 5 +- .../BlockMenuSearchBar/BlockMenuSearchBar.tsx | 1 + .../BlockMenuSidebar/BlockMenuSidebar.tsx | 3 + .../NewControlPanel/NewBlockMenu/MenuItem.tsx | 3 + .../NewSaveControl/NewSaveControl.tsx | 13 +- .../NewControlPanel/UndoRedoButtons.tsx | 14 +- .../legacy-builder/BuildActionBar.tsx | 1 + .../components/legacy-builder/tutorial.ts | 12 +- .../build/stores/controlPanelStore.ts | 12 + .../app/(platform)/build/stores/nodeStore.ts | 2 + .../(platform)/build/stores/tutorialStore.ts | 32 ++ .../molecules/TallyPoup/TallyPopup.tsx | 7 +- .../renderers/InputRenderer/FormRenderer.tsx | 2 +- .../base/standard/TitleField.tsx | 4 +- 52 files changed, 2649 insertions(+), 42 deletions(-) create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/constants.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/helpers/blocks.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/helpers/canvas.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/helpers/connections.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/helpers/dom.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/helpers/highlights.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/helpers/index.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/helpers/menu.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/helpers/save.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/helpers/state.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/icons.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/index.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/steps/block-basics.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/steps/block-menu.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/steps/completion.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/steps/configure-calculator.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/steps/connections.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/steps/index.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/steps/run.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/steps/save.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/steps/second-calculator.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/steps/welcome.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/styles.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/tutorial/tutorial.css create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/stores/tutorialStore.ts diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/BuilderActions.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/BuilderActions.tsx index 64eb624621..86e4a3eb9c 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/BuilderActions.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/BuilderActions.tsx @@ -10,7 +10,10 @@ export const BuilderActions = memo(() => { flowID: parseAsString, }); return ( -
+
diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/AgentOutputs/AgentOutputs.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/AgentOutputs/AgentOutputs.tsx index de56bb46b8..20493b2ca0 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/AgentOutputs/AgentOutputs.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/AgentOutputs/AgentOutputs.tsx @@ -79,6 +79,7 @@ export const AgentOutputs = ({ flowID }: { flowID: string | null }) => { diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/tutorial.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/tutorial.ts index 7adcfd9c1e..418d18782a 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/tutorial.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/tutorial.ts @@ -328,16 +328,16 @@ export const startTutorial = ( title: "Press Run", text: "Start your first flow by pressing the Run button!", attachTo: { - element: '[data-testid="primary-action-run-agent"]', + element: '[data-tutorial-id="primary-action-run-agent"]', on: "top", }, advanceOn: { - selector: '[data-testid="primary-action-run-agent"]', + selector: '[data-tutorial-id="primary-action-run-agent"]', event: "click", }, buttons: [], beforeShowPromise: () => - waitForElement('[data-testid="primary-action-run-agent"]'), + waitForElement('[data-tutorial-id="primary-action-run-agent"]'), when: { hide: () => { setTimeout(() => { @@ -508,16 +508,16 @@ export const startTutorial = ( title: "Press Run Again", text: "Now, press the Run button again to execute the flow with the new Calculator Block added!", attachTo: { - element: '[data-testid="primary-action-run-agent"]', + element: '[data-tutorial-id="primary-action-run-agent"]', on: "top", }, advanceOn: { - selector: '[data-testid="primary-action-run-agent"]', + selector: '[data-tutorial-id="primary-action-run-agent"]', event: "click", }, buttons: [], beforeShowPromise: () => - waitForElement('[data-testid="primary-action-run-agent"]'), + waitForElement('[data-tutorial-id="primary-action-run-agent"]'), when: { hide: () => { setTimeout(() => { diff --git a/autogpt_platform/frontend/src/app/(platform)/build/stores/controlPanelStore.ts b/autogpt_platform/frontend/src/app/(platform)/build/stores/controlPanelStore.ts index 6847d769bd..5dcb11c121 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/stores/controlPanelStore.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/stores/controlPanelStore.ts @@ -3,20 +3,32 @@ import { create } from "zustand"; type ControlPanelStore = { blockMenuOpen: boolean; saveControlOpen: boolean; + forceOpenBlockMenu: boolean; + forceOpenSave: boolean; + setBlockMenuOpen: (open: boolean) => void; setSaveControlOpen: (open: boolean) => void; + setForceOpenBlockMenu: (force: boolean) => void; + setForceOpenSave: (force: boolean) => void; + reset: () => void; }; export const useControlPanelStore = create((set) => ({ blockMenuOpen: false, saveControlOpen: false, + forceOpenBlockMenu: false, + forceOpenSave: false, + setForceOpenBlockMenu: (force) => set({ forceOpenBlockMenu: force }), + setForceOpenSave: (force) => set({ forceOpenSave: force }), setBlockMenuOpen: (open) => set({ blockMenuOpen: open }), setSaveControlOpen: (open) => set({ saveControlOpen: open }), reset: () => set({ blockMenuOpen: false, saveControlOpen: false, + forceOpenBlockMenu: false, + forceOpenSave: false, }), })); diff --git a/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts b/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts index 7f9deaa993..cb41da9463 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts @@ -47,6 +47,7 @@ const dragStartPositions: Record = {}; type NodeStore = { nodes: CustomNode[]; nodeCounter: number; + setNodeCounter: (nodeCounter: number) => void; nodeAdvancedStates: Record; setNodes: (nodes: CustomNode[]) => void; onNodesChange: (changes: NodeChange[]) => void; @@ -116,6 +117,7 @@ export const useNodeStore = create((set, get) => ({ nodes: [], setNodes: (nodes) => set({ nodes }), nodeCounter: 0, + setNodeCounter: (nodeCounter) => set({ nodeCounter }), nodeAdvancedStates: {}, incrementNodeCounter: () => set((state) => ({ diff --git a/autogpt_platform/frontend/src/app/(platform)/build/stores/tutorialStore.ts b/autogpt_platform/frontend/src/app/(platform)/build/stores/tutorialStore.ts new file mode 100644 index 0000000000..581dda44c9 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/build/stores/tutorialStore.ts @@ -0,0 +1,32 @@ +import { create } from "zustand"; + +type TutorialStore = { + isTutorialRunning: boolean; + setIsTutorialRunning: (isTutorialRunning: boolean) => void; + + currentStep: number; + setCurrentStep: (currentStep: number) => void; + + // Force open the run input dialog from the tutorial + forceOpenRunInputDialog: boolean; + setForceOpenRunInputDialog: (forceOpen: boolean) => void; + + // Track input values filled in the dialog + tutorialInputValues: Record; + setTutorialInputValues: (values: Record) => void; +}; + +export const useTutorialStore = create((set) => ({ + isTutorialRunning: false, + setIsTutorialRunning: (isTutorialRunning) => set({ isTutorialRunning }), + + currentStep: 0, + setCurrentStep: (currentStep) => set({ currentStep }), + + forceOpenRunInputDialog: false, + setForceOpenRunInputDialog: (forceOpen) => + set({ forceOpenRunInputDialog: forceOpen }), + + tutorialInputValues: {}, + setTutorialInputValues: (values) => set({ tutorialInputValues: values }), +})); diff --git a/autogpt_platform/frontend/src/components/molecules/TallyPoup/TallyPopup.tsx b/autogpt_platform/frontend/src/components/molecules/TallyPoup/TallyPopup.tsx index d8b7a6027d..be00982084 100644 --- a/autogpt_platform/frontend/src/components/molecules/TallyPoup/TallyPopup.tsx +++ b/autogpt_platform/frontend/src/components/molecules/TallyPoup/TallyPopup.tsx @@ -3,9 +3,14 @@ import React from "react"; import { useTallyPopup } from "./useTallyPopup"; import { Button } from "@/components/atoms/Button/Button"; +import { usePathname, useSearchParams } from "next/navigation"; export function TallyPopupSimple() { const { state, handlers } = useTallyPopup(); + const searchParams = useSearchParams(); + const pathname = usePathname(); + const isNewBuilder = + pathname.includes("build") && searchParams.get("view") === "new"; if (state.isFormVisible) { return null; @@ -13,7 +18,7 @@ export function TallyPopupSimple() { return (
- {state.showTutorial && ( + {state.showTutorial && !isNewBuilder && ( diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/APIKeyCredentialsModal/useAPIKeyCredentialsModal.ts b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/APIKeyCredentialsModal/useAPIKeyCredentialsModal.ts index 391633bed5..72599a2e79 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/APIKeyCredentialsModal/useAPIKeyCredentialsModal.ts +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/APIKeyCredentialsModal/useAPIKeyCredentialsModal.ts @@ -1,11 +1,11 @@ -import { z } from "zod"; -import { useForm, type UseFormReturn } from "react-hook-form"; -import { zodResolver } from "@hookform/resolvers/zod"; import useCredentials from "@/hooks/useCredentials"; import { BlockIOCredentialsSubSchema, CredentialsMetaInput, } from "@/lib/autogpt-server-api/types"; +import { zodResolver } from "@hookform/resolvers/zod"; +import { useForm, type UseFormReturn } from "react-hook-form"; +import { z } from "zod"; export type APIKeyFormValues = { apiKey: string; @@ -40,12 +40,24 @@ export function useAPIKeyCredentialsModal({ expiresAt: z.string().optional(), }); + function getDefaultExpirationDate(): string { + const tomorrow = new Date(); + tomorrow.setDate(tomorrow.getDate() + 1); + tomorrow.setHours(0, 0, 0, 0); + const year = tomorrow.getFullYear(); + const month = String(tomorrow.getMonth() + 1).padStart(2, "0"); + const day = String(tomorrow.getDate()).padStart(2, "0"); + const hours = String(tomorrow.getHours()).padStart(2, "0"); + const minutes = String(tomorrow.getMinutes()).padStart(2, "0"); + return `${year}-${month}-${day}T${hours}:${minutes}`; + } + const form = useForm({ resolver: zodResolver(formSchema), defaultValues: { apiKey: "", title: "", - expiresAt: "", + expiresAt: getDefaultExpirationDate(), }, }); diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialRow/CredentialRow.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialRow/CredentialRow.tsx index 2d0358aacb..dc69c34d93 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialRow/CredentialRow.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialRow/CredentialRow.tsx @@ -7,7 +7,8 @@ import { DropdownMenuTrigger, } from "@/components/molecules/DropdownMenu/DropdownMenu"; import { cn } from "@/lib/utils"; -import { CaretDown, DotsThreeVertical } from "@phosphor-icons/react"; +import { CaretDownIcon, DotsThreeVertical } from "@phosphor-icons/react"; +import { useEffect, useRef, useState } from "react"; import { fallbackIcon, getCredentialDisplayName, @@ -26,7 +27,7 @@ type CredentialRowProps = { provider: string; displayName: string; onSelect: () => void; - onDelete: () => void; + onDelete?: () => void; readOnly?: boolean; showCaret?: boolean; asSelectTrigger?: boolean; @@ -47,11 +48,32 @@ export function CredentialRow({ }: CredentialRowProps) { const ProviderIcon = providerIcons[provider] || fallbackIcon; const isNodeVariant = variant === "node"; + const containerRef = useRef(null); + const [showMaskedKey, setShowMaskedKey] = useState(true); + + useEffect(() => { + const container = containerRef.current; + if (!container) return; + + const resizeObserver = new ResizeObserver((entries) => { + for (const entry of entries) { + const width = entry.contentRect.width; + setShowMaskedKey(width >= 360); + } + }); + + resizeObserver.observe(container); + + return () => { + resizeObserver.disconnect(); + }; + }, []); return (
{getCredentialDisplayName(credential, displayName)} - {!(asSelectTrigger && isNodeVariant) && ( + {!(asSelectTrigger && isNodeVariant) && showMaskedKey && ( {"*".repeat(MASKED_KEY_LENGTH)} )}
- {showCaret && !asSelectTrigger && ( - + {(showCaret || (asSelectTrigger && !readOnly)) && ( + )} - {!readOnly && !showCaret && !asSelectTrigger && ( + {!readOnly && !showCaret && !asSelectTrigger && onDelete && ( + + )} + + {hasSystemCredentials && ( + + + +
+ System credentials +
+
+ +
+ {showTitle && ( +
+ + {displayName} credentials + {isOptional && ( + + (optional) + + )} + + {schema.description && ( + + )} +
+ )} + {credentialsInAccordion.length > 0 && ( + + )} + {isSystemProvider && ( + + )} +
+
+
+
+ )} + + {!showUserCredentialsOutsideAccordion && !isSystemProvider && ( + + )} + + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialsFlatView/CredentialsFlatView.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialsFlatView/CredentialsFlatView.tsx new file mode 100644 index 0000000000..4d220a5359 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialsFlatView/CredentialsFlatView.tsx @@ -0,0 +1,134 @@ +import { Button } from "@/components/atoms/Button/Button"; +import { Text } from "@/components/atoms/Text/Text"; +import { InformationTooltip } from "@/components/molecules/InformationTooltip/InformationTooltip"; +import { + BlockIOCredentialsSubSchema, + CredentialsMetaInput, +} from "@/lib/autogpt-server-api/types"; +import { ExclamationTriangleIcon } from "@radix-ui/react-icons"; +import { CredentialRow } from "../CredentialRow/CredentialRow"; +import { CredentialsSelect } from "../CredentialsSelect/CredentialsSelect"; + +type Credential = { + id: string; + title?: string; + username?: string; + type: string; + provider: string; +}; + +type Props = { + schema: BlockIOCredentialsSubSchema; + provider: string; + displayName: string; + credentials: Credential[]; + selectedCredential?: CredentialsMetaInput; + actionButtonText: string; + isOptional: boolean; + showTitle: boolean; + readOnly: boolean; + variant: "default" | "node"; + onSelectCredential: (credentialId: string) => void; + onClearCredential: () => void; + onAddCredential: () => void; +}; + +export function CredentialsFlatView({ + schema, + provider, + displayName, + credentials, + selectedCredential, + actionButtonText, + isOptional, + showTitle, + readOnly, + variant, + onSelectCredential, + onClearCredential, + onAddCredential, +}: Props) { + const hasCredentials = credentials.length > 0; + + return ( + <> + {showTitle && ( +
+ + + {displayName} credentials + {isOptional && ( + + (optional) + + )} + {!isOptional && !selectedCredential && ( + + + required + + )} + + + {schema.description && ( + + )} +
+ )} + + {hasCredentials ? ( + <> + {(credentials.length > 1 || isOptional) && !readOnly ? ( + + ) : ( +
+ {credentials.map((credential) => ( + onSelectCredential(credential.id)} + readOnly={readOnly} + /> + ))} +
+ )} + {!readOnly && ( + + )} + + ) : ( + !readOnly && ( + + ) + )} + + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialsSelect/CredentialsSelect.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialsSelect/CredentialsSelect.tsx index 6e1ec2afb1..18e772dd00 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialsSelect/CredentialsSelect.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialsSelect/CredentialsSelect.tsx @@ -1,14 +1,4 @@ -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/__legacy__/ui/select"; -import { Text } from "@/components/atoms/Text/Text"; -import { CredentialsMetaInput } from "@/lib/autogpt-server-api/types"; -import { cn } from "@/lib/utils"; -import { useEffect } from "react"; +import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput"; import { getCredentialDisplayName } from "../../helpers"; import { CredentialRow } from "../CredentialRow/CredentialRow"; @@ -42,76 +32,77 @@ export function CredentialsSelect({ allowNone = true, variant = "default", }: Props) { - // Auto-select first credential if none is selected (only if allowNone is false) - useEffect(() => { - if (!allowNone && !selectedCredentials && credentials.length > 0) { - onSelectCredential(credentials[0].id); - } - }, [allowNone, selectedCredentials, credentials, onSelectCredential]); - - const handleValueChange = (value: string) => { + function handleValueChange(e: React.ChangeEvent) { + const value = e.target.value; if (value === "__none__") { onClearCredential?.(); } else { onSelectCredential(value); } - }; + } + + const selectedCredential = selectedCredentials + ? credentials.find((c) => c.id === selectedCredentials.id) + : null; + + const displayCredential = selectedCredential + ? { + id: selectedCredential.id, + title: selectedCredential.title, + username: selectedCredential.username, + type: selectedCredential.type, + provider: selectedCredential.provider, + } + : allowNone + ? { + id: "__none__", + title: "None (skip this credential)", + type: "none", + provider: provider, + } + : { + id: "__placeholder__", + title: "Select credential", + type: "placeholder", + provider: provider, + }; return (
- - {selectedCredentials ? ( - - {}} - onDelete={() => {}} - readOnly={readOnly} - asSelectTrigger={true} - variant={variant} - /> - + {allowNone ? ( + ) : ( - - )} - - - {allowNone && ( - -
- - None (skip this credential) - -
-
+ )} {credentials.map((credential) => ( - -
- - {getCredentialDisplayName(credential, displayName)} - -
-
+ ))} -
- + +
+ {}} + onDelete={() => {}} + readOnly={readOnly} + asSelectTrigger={true} + variant={variant} + /> +
+
); } diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/helpers.ts index 4cca825747..ef965d5382 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/helpers.ts @@ -99,4 +99,30 @@ export function getCredentialDisplayName( } export const OAUTH_TIMEOUT_MS = 5 * 60 * 1000; -export const MASKED_KEY_LENGTH = 30; +export const MASKED_KEY_LENGTH = 15; + +export function isSystemCredential(credential: { + title?: string | null; + is_system?: boolean; +}): boolean { + if (credential.is_system === true) return true; + if (!credential.title) return false; + const titleLower = credential.title.toLowerCase(); + return ( + titleLower.includes("system") || + titleLower.startsWith("use credits for") || + titleLower.includes("use credits") + ); +} + +export function filterSystemCredentials< + T extends { title?: string; is_system?: boolean }, +>(credentials: T[]): T[] { + return credentials.filter((cred) => !isSystemCredential(cred)); +} + +export function getSystemCredentials< + T extends { title?: string; is_system?: boolean }, +>(credentials: T[]): T[] { + return credentials.filter((cred) => isSystemCredential(cred)); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/useCredentialsInput.ts b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/useCredentialsInput.ts index c780ffeffc..8876ddcba9 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/useCredentialsInput.ts +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/useCredentialsInput.ts @@ -6,9 +6,11 @@ import { CredentialsMetaInput, } from "@/lib/autogpt-server-api/types"; import { useQueryClient } from "@tanstack/react-query"; -import { useEffect, useMemo, useState } from "react"; +import { useEffect, useRef, useState } from "react"; import { + filterSystemCredentials, getActionButtonText, + getSystemCredentials, OAUTH_TIMEOUT_MS, OAuthPopupResultMessage, } from "./helpers"; @@ -54,6 +56,7 @@ export function useCredentialsInput({ const api = useBackendAPI(); const queryClient = useQueryClient(); const credentials = useCredentials(schema, siblingInputs); + const hasAttemptedAutoSelect = useRef(false); const deleteCredentialsMutation = useDeleteV1DeleteCredentials({ mutation: { @@ -82,38 +85,51 @@ export function useCredentialsInput({ useEffect(() => { if (readOnly) return; if (!credentials || !("savedCredentials" in credentials)) return; + const availableCreds = credentials.savedCredentials; if ( selectedCredential && - !credentials.savedCredentials.some((c) => c.id === selectedCredential.id) + !availableCreds.some((c) => c.id === selectedCredential.id) ) { onSelectCredential(undefined); + // Reset auto-selection flag so it can run again after unsetting invalid credential + hasAttemptedAutoSelect.current = false; } }, [credentials, selectedCredential, onSelectCredential, readOnly]); - // The available credential, if there is only one - const singleCredential = useMemo(() => { - if (!credentials || !("savedCredentials" in credentials)) { - return null; - } - - return credentials.savedCredentials.length === 1 - ? credentials.savedCredentials[0] - : null; - }, [credentials]); - - // Auto-select the one available credential (only if not optional) + // Auto-select the first available credential on initial mount + // Once a user has made a selection, we don't override it useEffect(() => { if (readOnly) return; - if (isOptional) return; // Don't auto-select when credential is optional - if (singleCredential && !selectedCredential) { - onSelectCredential(singleCredential); + if (!credentials || !("savedCredentials" in credentials)) return; + + // If already selected, don't auto-select + if (selectedCredential?.id) return; + + // Only attempt auto-selection once + if (hasAttemptedAutoSelect.current) return; + hasAttemptedAutoSelect.current = true; + + // If optional, don't auto-select (user can choose "None") + if (isOptional) return; + + const savedCreds = credentials.savedCredentials; + + // Auto-select the first credential if any are available + if (savedCreds.length > 0) { + const cred = savedCreds[0]; + onSelectCredential({ + id: cred.id, + type: cred.type, + provider: credentials.provider, + title: (cred as any).title, + }); } }, [ - singleCredential, - selectedCredential, - onSelectCredential, + credentials, + selectedCredential?.id, readOnly, isOptional, + onSelectCredential, ]); if ( @@ -135,8 +151,13 @@ export function useCredentialsInput({ supportsHostScoped, savedCredentials, oAuthCallback, + isSystemProvider, } = credentials; + // Split credentials into user and system + const userCredentials = filterSystemCredentials(savedCredentials); + const systemCredentials = getSystemCredentials(savedCredentials); + async function handleOAuthLogin() { setOAuthError(null); const { login_url, state_token } = await api.oAuthLogin( @@ -291,7 +312,10 @@ export function useCredentialsInput({ supportsOAuth2, supportsUserPassword, supportsHostScoped, - credentialsToShow: savedCredentials, + isSystemProvider, + userCredentials, + systemCredentials, + allCredentials: savedCredentials, selectedCredential, oAuthError, isAPICredentialsModalOpen, @@ -306,7 +330,7 @@ export function useCredentialsInput({ supportsApiKey, supportsUserPassword, supportsHostScoped, - savedCredentials.length > 0, + userCredentials.length > 0, ), setAPICredentialsModalOpen, setUserPasswordCredentialsModalOpen, diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/RunAgentModal.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/RunAgentModal.tsx index e53f31a349..cd0c666be6 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/RunAgentModal.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/RunAgentModal.tsx @@ -12,7 +12,7 @@ import { TooltipTrigger, } from "@/components/atoms/Tooltip/BaseTooltip"; import { Dialog } from "@/components/molecules/Dialog/Dialog"; -import { useState } from "react"; +import { useEffect, useRef, useState } from "react"; import { ScheduleAgentModal } from "../ScheduleAgentModal/ScheduleAgentModal"; import { ModalHeader } from "./components/ModalHeader/ModalHeader"; import { ModalRunSection } from "./components/ModalRunSection/ModalRunSection"; @@ -82,6 +82,8 @@ export function RunAgentModal({ }); const [isScheduleModalOpen, setIsScheduleModalOpen] = useState(false); + const [hasOverflow, setHasOverflow] = useState(false); + const contentRef = useRef(null); const hasAnySetupFields = Object.keys(agentInputFields || {}).length > 0 || @@ -89,6 +91,43 @@ export function RunAgentModal({ const isTriggerRunType = defaultRunType.includes("trigger"); + useEffect(() => { + if (!isOpen) return; + + function checkOverflow() { + if (!contentRef.current) return; + const scrollableParent = contentRef.current + .closest("[data-dialog-content]") + ?.querySelector('[class*="overflow-y-auto"]'); + if (scrollableParent) { + setHasOverflow( + scrollableParent.scrollHeight > scrollableParent.clientHeight, + ); + } + } + + const timeoutId = setTimeout(checkOverflow, 100); + const resizeObserver = new ResizeObserver(checkOverflow); + if (contentRef.current) { + const scrollableParent = contentRef.current + .closest("[data-dialog-content]") + ?.querySelector('[class*="overflow-y-auto"]'); + if (scrollableParent) { + resizeObserver.observe(scrollableParent); + } + } + + return () => { + clearTimeout(timeoutId); + resizeObserver.disconnect(); + }; + }, [ + isOpen, + hasAnySetupFields, + agentInputFields, + agentCredentialsInputFields, + ]); + function handleInputChange(key: string, value: string) { setInputValues((prev) => ({ ...prev, @@ -134,91 +173,97 @@ export function RunAgentModal({ > {triggerSlot} - {/* Header */} - +
+
+ {/* Header */} + - {/* Content */} - {hasAnySetupFields ? ( -
- - - + {/* Content */} + {hasAnySetupFields ? ( +
+ + + +
+ ) : null}
- ) : null} - -
- {isTriggerRunType ? null : !allRequiredInputsAreSet ? ( - - - - - - - - -

- Please set up all required inputs and credentials before - scheduling -

-
-
-
- ) : ( - - )} - +
+ {isTriggerRunType ? null : !allRequiredInputsAreSet ? ( + + + + + + + + +

+ Please set up all required inputs and credentials + before scheduling +

+
+
+
+ ) : ( + + )} + +
+ -
- -
+ +
diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/CredentialsGroupedView/CredentialsGroupedView.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/CredentialsGroupedView/CredentialsGroupedView.tsx new file mode 100644 index 0000000000..05b2966af7 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/CredentialsGroupedView/CredentialsGroupedView.tsx @@ -0,0 +1,181 @@ +import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput"; +import { + Accordion, + AccordionContent, + AccordionItem, + AccordionTrigger, +} from "@/components/molecules/Accordion/Accordion"; +import { CredentialsProvidersContext } from "@/providers/agent-credentials/credentials-provider"; +import { SlidersHorizontal } from "@phosphor-icons/react"; +import { useContext, useEffect, useMemo, useRef } from "react"; +import { useRunAgentModalContext } from "../../context"; +import { + areSystemCredentialProvidersLoading, + CredentialField, + findSavedCredentialByProviderAndType, + hasMissingRequiredSystemCredentials, + splitCredentialFieldsBySystem, +} from "../helpers"; + +type Props = { + credentialFields: CredentialField[]; + requiredCredentials: Set; +}; + +export function CredentialsGroupedView({ + credentialFields, + requiredCredentials, +}: Props) { + const allProviders = useContext(CredentialsProvidersContext); + const { inputCredentials, setInputCredentialsValue, inputValues } = + useRunAgentModalContext(); + + const { userCredentialFields, systemCredentialFields } = useMemo( + () => + splitCredentialFieldsBySystem( + credentialFields, + allProviders, + inputCredentials, + ), + [credentialFields, allProviders, inputCredentials], + ); + + const hasSystemCredentials = systemCredentialFields.length > 0; + const hasUserCredentials = userCredentialFields.length > 0; + const hasAttemptedAutoSelect = useRef(false); + + const isLoadingProviders = useMemo( + () => + areSystemCredentialProvidersLoading(systemCredentialFields, allProviders), + [systemCredentialFields, allProviders], + ); + + const hasMissingSystemCredentials = useMemo(() => { + if (isLoadingProviders) return false; + return hasMissingRequiredSystemCredentials( + systemCredentialFields, + requiredCredentials, + inputCredentials, + allProviders, + ); + }, [ + isLoadingProviders, + systemCredentialFields, + requiredCredentials, + inputCredentials, + allProviders, + ]); + + useEffect(() => { + if (hasAttemptedAutoSelect.current) return; + if (!hasSystemCredentials) return; + if (isLoadingProviders) return; + + for (const [key, schema] of systemCredentialFields) { + const alreadySelected = inputCredentials?.[key]; + const isRequired = requiredCredentials.has(key); + if (alreadySelected || !isRequired) continue; + + const providerNames = schema.credentials_provider || []; + const credentialTypes = schema.credentials_types || []; + const requiredScopes = schema.credentials_scopes; + const savedCredential = findSavedCredentialByProviderAndType( + providerNames, + credentialTypes, + requiredScopes, + allProviders, + ); + + if (savedCredential) { + setInputCredentialsValue(key, { + id: savedCredential.id, + provider: savedCredential.provider, + type: savedCredential.type, + title: (savedCredential as { title?: string }).title, + }); + } + } + + hasAttemptedAutoSelect.current = true; + }, [ + allProviders, + hasSystemCredentials, + systemCredentialFields, + requiredCredentials, + inputCredentials, + setInputCredentialsValue, + isLoadingProviders, + ]); + + return ( +
+ {hasUserCredentials && ( + <> + {userCredentialFields.map( + ([key, inputSubSchema]: CredentialField) => { + const selectedCred = inputCredentials?.[key]; + + return ( + { + setInputCredentialsValue(key, value); + }} + siblingInputs={inputValues} + isOptional={!requiredCredentials.has(key)} + /> + ); + }, + )} + + )} + + {hasSystemCredentials && ( + + + +
+ System credentials + {hasMissingSystemCredentials && ( + (missing) + )} +
+
+ +
+ {systemCredentialFields.map( + ([key, inputSubSchema]: CredentialField) => { + const selectedCred = inputCredentials?.[key]; + + return ( + { + setInputCredentialsValue(key, value); + }} + siblingInputs={inputValues} + isOptional={!requiredCredentials.has(key)} + /> + ); + }, + )} +
+
+
+
+ )} +
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/ModalRunSection.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/ModalRunSection.tsx index aba4caee7a..7660de7c15 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/ModalRunSection.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/ModalRunSection.tsx @@ -1,8 +1,9 @@ -import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs"; import { Input } from "@/components/atoms/Input/Input"; import { InformationTooltip } from "@/components/molecules/InformationTooltip/InformationTooltip"; +import { useMemo } from "react"; import { RunAgentInputs } from "../../../RunAgentInputs/RunAgentInputs"; import { useRunAgentModalContext } from "../../context"; +import { CredentialsGroupedView } from "../CredentialsGroupedView/CredentialsGroupedView"; import { ModalSection } from "../ModalSection/ModalSection"; import { WebhookTriggerBanner } from "../WebhookTriggerBanner/WebhookTriggerBanner"; @@ -17,15 +18,16 @@ export function ModalRunSection() { inputValues, setInputValue, agentInputFields, - inputCredentials, - setInputCredentialsValue, agentCredentialsInputFields, } = useRunAgentModalContext(); const inputFields = Object.entries(agentInputFields || {}); - const credentialFields = Object.entries(agentCredentialsInputFields || {}); - // Get the list of required credentials from the schema + const credentialFields = useMemo(() => { + if (!agentCredentialsInputFields) return []; + return Object.entries(agentCredentialsInputFields); + }, [agentCredentialsInputFields]); + const requiredCredentials = new Set( (agent.credentials_input_schema?.required as string[]) || [], ); @@ -97,24 +99,10 @@ export function ModalRunSection() { title="Task Credentials" subtitle="These are the credentials the agent will use to perform this task" > -
- {Object.entries(agentCredentialsInputFields || {}).map( - ([key, inputSubSchema]) => ( - - setInputCredentialsValue(key, value) - } - siblingInputs={inputValues} - isOptional={!requiredCredentials.has(key)} - /> - ), - )} -
+ ) : null}
diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/helpers.ts new file mode 100644 index 0000000000..61267f733d --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/helpers.ts @@ -0,0 +1,210 @@ +import { CredentialsProvidersContextType } from "@/providers/agent-credentials/credentials-provider"; +import { getSystemCredentials } from "../../CredentialsInputs/helpers"; + +export type CredentialField = [string, any]; + +type SavedCredential = { + id: string; + provider: string; + type: string; + title?: string | null; +}; + +function hasRequiredScopes( + credential: { scopes?: string[]; type: string }, + requiredScopes?: string[], +) { + if (credential.type !== "oauth2") return true; + if (!requiredScopes || requiredScopes.length === 0) return true; + const grantedScopes = new Set(credential.scopes || []); + for (const scope of requiredScopes) { + if (!grantedScopes.has(scope)) return false; + } + return true; +} + +export function splitCredentialFieldsBySystem( + credentialFields: CredentialField[], + allProviders: CredentialsProvidersContextType | null, + inputCredentials?: Record, +) { + if (!allProviders || credentialFields.length === 0) { + return { + userCredentialFields: [] as CredentialField[], + systemCredentialFields: [] as CredentialField[], + }; + } + + const userFields: CredentialField[] = []; + const systemFields: CredentialField[] = []; + + for (const [key, schema] of credentialFields) { + const providerNames = schema.credentials_provider || []; + const isSystemField = providerNames.some((providerName: string) => { + const providerData = allProviders[providerName]; + return providerData?.isSystemProvider === true; + }); + + if (isSystemField) { + systemFields.push([key, schema]); + } else { + userFields.push([key, schema]); + } + } + + const sortByUnsetFirst = (a: CredentialField, b: CredentialField) => { + const aIsSet = Boolean(inputCredentials?.[a[0]]); + const bIsSet = Boolean(inputCredentials?.[b[0]]); + + if (aIsSet === bIsSet) return 0; + return aIsSet ? 1 : -1; + }; + + return { + userCredentialFields: userFields.sort(sortByUnsetFirst), + systemCredentialFields: systemFields.sort(sortByUnsetFirst), + }; +} + +export function areSystemCredentialProvidersLoading( + systemCredentialFields: CredentialField[], + allProviders: CredentialsProvidersContextType | null, +): boolean { + if (!systemCredentialFields.length) return false; + if (allProviders === null) return true; + + for (const [_, schema] of systemCredentialFields) { + const providerNames = schema.credentials_provider || []; + const hasAllProviders = providerNames.every( + (providerName: string) => allProviders?.[providerName] !== undefined, + ); + if (!hasAllProviders) return true; + } + + return false; +} + +export function hasMissingRequiredSystemCredentials( + systemCredentialFields: CredentialField[], + requiredCredentials: Set, + inputCredentials?: Record, + allProviders?: CredentialsProvidersContextType | null, +) { + if (systemCredentialFields.length === 0) return false; + if (allProviders === null) return false; + + return systemCredentialFields.some(([key, schema]) => { + if (!requiredCredentials.has(key)) return false; + if (inputCredentials?.[key]) return false; + + const providerNames = schema.credentials_provider || []; + const credentialTypes = schema.credentials_types || []; + const requiredScopes = schema.credentials_scopes; + + return !hasAvailableSystemCredential( + providerNames, + credentialTypes, + requiredScopes, + allProviders, + ); + }); +} + +function hasAvailableSystemCredential( + providerNames: string[], + credentialTypes: string[], + requiredScopes: string[] | undefined, + allProviders: CredentialsProvidersContextType | null | undefined, +) { + if (!allProviders) return false; + + for (const providerName of providerNames) { + const providerData = allProviders[providerName]; + if (!providerData) continue; + + const systemCredentials = getSystemCredentials( + providerData.savedCredentials ?? [], + ); + + for (const credential of systemCredentials) { + const typeMatches = + credentialTypes.length === 0 || + credentialTypes.includes(credential.type); + const scopesMatch = hasRequiredScopes(credential, requiredScopes); + + if (!typeMatches) continue; + if (!scopesMatch) continue; + + return true; + } + + const allCredentials = providerData.savedCredentials ?? []; + for (const credential of allCredentials) { + const typeMatches = + credentialTypes.length === 0 || + credentialTypes.includes(credential.type); + const scopesMatch = hasRequiredScopes(credential, requiredScopes); + + if (!typeMatches) continue; + if (!scopesMatch) continue; + + return true; + } + } + + return false; +} + +export function findSavedCredentialByProviderAndType( + providerNames: string[], + credentialTypes: string[], + requiredScopes: string[] | undefined, + allProviders: CredentialsProvidersContextType | null, +): SavedCredential | undefined { + for (const providerName of providerNames) { + const providerData = allProviders?.[providerName]; + if (!providerData) continue; + + const systemCredentials = getSystemCredentials( + providerData.savedCredentials ?? [], + ); + + const matchingCredentials: SavedCredential[] = []; + + for (const credential of systemCredentials) { + const typeMatches = + credentialTypes.length === 0 || + credentialTypes.includes(credential.type); + const scopesMatch = hasRequiredScopes(credential, requiredScopes); + + if (!typeMatches) continue; + if (!scopesMatch) continue; + + matchingCredentials.push(credential as SavedCredential); + } + + if (matchingCredentials.length === 0) { + const allCredentials = providerData.savedCredentials ?? []; + for (const credential of allCredentials) { + const typeMatches = + credentialTypes.length === 0 || + credentialTypes.includes(credential.type); + const scopesMatch = hasRequiredScopes(credential, requiredScopes); + + if (!typeMatches) continue; + if (!scopesMatch) continue; + + matchingCredentials.push(credential as SavedCredential); + } + } + + if (matchingCredentials.length === 1) { + return matchingCredentials[0]; + } + if (matchingCredentials.length > 1) { + return undefined; + } + } + + return undefined; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/useAgentRunModal.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/useAgentRunModal.tsx index eb32083004..3aafd4be50 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/useAgentRunModal.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/useAgentRunModal.tsx @@ -11,9 +11,18 @@ import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; import { LibraryAgentPreset } from "@/app/api/__generated__/models/libraryAgentPreset"; import { useToast } from "@/components/molecules/Toast/use-toast"; import { isEmpty } from "@/lib/utils"; +import { CredentialsProvidersContext } from "@/providers/agent-credentials/credentials-provider"; import { analytics } from "@/services/analytics"; import { useQueryClient } from "@tanstack/react-query"; -import { useCallback, useEffect, useMemo, useState } from "react"; +import { + useCallback, + useContext, + useEffect, + useMemo, + useRef, + useState, +} from "react"; +import { getSystemCredentials } from "../CredentialsInputs/helpers"; import { showExecutionErrorToast } from "./errorHelpers"; export type RunVariant = @@ -42,8 +51,10 @@ export function useAgentRunModal( const [inputCredentials, setInputCredentials] = useState>( callbacks?.initialInputCredentials || {}, ); + const [presetName, setPresetName] = useState(""); const [presetDescription, setPresetDescription] = useState(""); + const hasInitializedSystemCreds = useRef(false); // Determine the default run type based on agent capabilities const defaultRunType: RunVariant = agent.trigger_setup_info @@ -58,6 +69,91 @@ export function useAgentRunModal( setInputCredentials(callbacks?.initialInputCredentials || {}); }, [callbacks?.initialInputValues, callbacks?.initialInputCredentials]); + const allProviders = useContext(CredentialsProvidersContext); + + // Initialize credentials with default system credentials + useEffect(() => { + if (!allProviders || !agent.credentials_input_schema?.properties) return; + if (callbacks?.initialInputCredentials) { + hasInitializedSystemCreds.current = true; + return; + } + if (hasInitializedSystemCreds.current) return; + + const properties = agent.credentials_input_schema.properties as Record< + string, + any + >; + + setInputCredentials((currentCreds) => { + const credsToAdd: Record = {}; + + for (const [key, schema] of Object.entries(properties)) { + if (currentCreds[key]) continue; + + const providerNames = schema.credentials_provider || []; + const supportedTypes = schema.credentials_types || []; + const requiredScopes = schema.credentials_scopes; + + for (const providerName of providerNames) { + const providerData = allProviders[providerName]; + if (!providerData) continue; + + const systemCreds = getSystemCredentials( + providerData.savedCredentials ?? [], + ); + const matchingSystemCreds = systemCreds.filter((cred) => { + if (!supportedTypes.includes(cred.type)) return false; + + if ( + cred.type === "oauth2" && + requiredScopes && + requiredScopes.length > 0 + ) { + const grantedScopes = new Set(cred.scopes || []); + const hasAllRequiredScopes = requiredScopes.every( + (scope: string) => grantedScopes.has(scope), + ); + if (!hasAllRequiredScopes) return false; + } + + return true; + }); + + if (matchingSystemCreds.length === 1) { + const systemCred = matchingSystemCreds[0]; + credsToAdd[key] = { + id: systemCred.id, + type: systemCred.type, + provider: providerName, + title: systemCred.title, + }; + break; + } + } + } + + if (Object.keys(credsToAdd).length > 0) { + hasInitializedSystemCreds.current = true; + return { + ...currentCreds, + ...credsToAdd, + }; + } + + return currentCreds; + }); + }, [ + allProviders, + agent.credentials_input_schema, + callbacks?.initialInputCredentials, + ]); + + // Reset initialization flag when modal closes/opens or agent changes + useEffect(() => { + hasInitializedSystemCreds.current = false; + }, [isOpen, agent.graph_id]); + // API mutations const executeGraphMutation = usePostV1ExecuteGraphAgent({ mutation: { @@ -66,7 +162,6 @@ export function useAgentRunModal( toast({ title: "Agent execution started", }); - // Invalidate runs list for this graph queryClient.invalidateQueries({ queryKey: getGetV1ListGraphExecutionsQueryKey(agent.graph_id), }); @@ -163,14 +258,10 @@ export function useAgentRunModal( }, [agentInputSchema.required, inputValues]); const [allCredentialsAreSet, missingCredentials] = useMemo(() => { - // Only check required credentials from schema, not all properties - // Credentials marked as optional in node metadata won't be in the required array const requiredCredentials = new Set( (agent.credentials_input_schema?.required as string[]) || [], ); - // Check if required credentials have valid id (not just key existence) - // A credential is valid only if it has an id field set const missing = [...requiredCredentials].filter((key) => { const cred = inputCredentials[key]; return !cred || !cred.id; @@ -184,7 +275,6 @@ export function useAgentRunModal( [agentCredentialsInputFields], ); - // Final readiness flag combining inputs + credentials when credentials are shown const allRequiredInputsAreSet = useMemo( () => allRequiredInputsAreSetRaw && @@ -223,7 +313,6 @@ export function useAgentRunModal( defaultRunType === "automatic-trigger" || defaultRunType === "manual-trigger" ) { - // Setup trigger if (!presetName.trim()) { toast({ title: "⚠️ Trigger name required", @@ -244,9 +333,6 @@ export function useAgentRunModal( }, }); } else { - // Manual execution - // Filter out incomplete credentials (optional ones not selected) - // Only send credentials that have a valid id field const validCredentials = Object.fromEntries( Object.entries(inputCredentials).filter(([_, cred]) => cred && cred.id), ); @@ -280,41 +366,24 @@ export function useAgentRunModal( }, [agentInputFields]); return { - // UI state isOpen, setIsOpen, - - // Run mode defaultRunType: defaultRunType as RunVariant, - - // Form: regular inputs inputValues, setInputValues, - - // Form: credentials inputCredentials, setInputCredentials, - - // Preset/trigger labels presetName, presetDescription, setPresetName, setPresetDescription, - - // Validation/readiness allRequiredInputsAreSet, missingInputs, - - // Schemas for rendering agentInputFields, agentCredentialsInputFields, hasInputFields, - - // Async states isExecuting: executeGraphMutation.isPending, isSettingUpTrigger: setupTriggerMutation.isPending, - - // Actions handleRun, }; } diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/AgentSettingsButton.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/AgentSettingsButton.tsx index 11dcbd943f..95fdf826a2 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/AgentSettingsButton.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/AgentSettingsButton.tsx @@ -1,37 +1,17 @@ import { Button } from "@/components/atoms/Button/Button"; +import { Text } from "@/components/atoms/Text/Text"; import { GearIcon } from "@phosphor-icons/react"; -import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; -import { useAgentSafeMode } from "@/hooks/useAgentSafeMode"; - -interface Props { - agent: LibraryAgent; - onSelectSettings: () => void; - selected?: boolean; -} - -export function AgentSettingsButton({ - agent, - onSelectSettings, - selected, -}: Props) { - const { hasHITLBlocks } = useAgentSafeMode(agent); - - if (!hasHITLBlocks) { - return null; - } +export function AgentSettingsButton() { return ( ); } diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/EmptySchedules.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/EmptySchedules.tsx index 97492d8a59..4c781b2896 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/EmptySchedules.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/EmptySchedules.tsx @@ -1,3 +1,5 @@ +"use client"; + import { Text } from "@/components/atoms/Text/Text"; export function EmptySchedules() { diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/EmptyTemplates.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/EmptyTemplates.tsx index c33abe69ad..364b762167 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/EmptyTemplates.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/EmptyTemplates.tsx @@ -1,3 +1,5 @@ +"use client"; + import { Text } from "@/components/atoms/Text/Text"; export function EmptyTemplates() { diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/EmptyTriggers.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/EmptyTriggers.tsx index 0d9dc47fff..06d09ff9a0 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/EmptyTriggers.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/EmptyTriggers.tsx @@ -1,3 +1,5 @@ +"use client"; + import { Text } from "@/components/atoms/Text/Text"; export function EmptyTriggers() { diff --git a/autogpt_platform/frontend/src/components/contextual/MarketplaceBanners/MarketplaceBanners.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/MarketplaceBanners.tsx similarity index 97% rename from autogpt_platform/frontend/src/components/contextual/MarketplaceBanners/MarketplaceBanners.tsx rename to autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/MarketplaceBanners.tsx index 4f826f6e85..00edcc721f 100644 --- a/autogpt_platform/frontend/src/components/contextual/MarketplaceBanners/MarketplaceBanners.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/MarketplaceBanners.tsx @@ -3,7 +3,7 @@ import { Button } from "@/components/atoms/Button/Button"; import { Text } from "@/components/atoms/Text/Text"; -interface MarketplaceBannersProps { +interface Props { hasUpdate?: boolean; latestVersion?: number; hasUnpublishedChanges?: boolean; @@ -21,7 +21,7 @@ export function MarketplaceBanners({ isUpdating, onUpdate, onPublish, -}: MarketplaceBannersProps) { +}: Props) { const renderUpdateBanner = () => { if (hasUpdate && latestVersion) { return ( diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/SectionWrap.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/SectionWrap.tsx index 75571dd856..f88d91bb0d 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/SectionWrap.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/SectionWrap.tsx @@ -1,3 +1,5 @@ +"use client"; + import { cn } from "@/lib/utils"; type Props = { diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/LoadingSelectedContent.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/LoadingSelectedContent.tsx index dc2bb7cac2..bc5548afd0 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/LoadingSelectedContent.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/LoadingSelectedContent.tsx @@ -1,22 +1,16 @@ +import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; import { Skeleton } from "@/components/__legacy__/ui/skeleton"; import { cn } from "@/lib/utils"; -import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; import { AGENT_LIBRARY_SECTION_PADDING_X } from "../../helpers"; import { SelectedViewLayout } from "./SelectedViewLayout"; interface Props { agent: LibraryAgent; - onSelectSettings?: () => void; - selectedSettings?: boolean; } export function LoadingSelectedContent(props: Props) { return ( - +
diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedRunView/SelectedRunView.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedRunView/SelectedRunView.tsx index c66f0e9245..05da986583 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedRunView/SelectedRunView.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedRunView/SelectedRunView.tsx @@ -33,8 +33,6 @@ interface Props { onSelectRun?: (id: string) => void; onClearSelectedRun?: () => void; banner?: React.ReactNode; - onSelectSettings?: () => void; - selectedSettings?: boolean; } export function SelectedRunView({ @@ -43,8 +41,6 @@ export function SelectedRunView({ onSelectRun, onClearSelectedRun, banner, - onSelectSettings, - selectedSettings, }: Props) { const { run, preset, isLoading, responseError, httpError } = useSelectedRunView(agent.graph_id, runId); @@ -84,12 +80,7 @@ export function SelectedRunView({ return (
- +
diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedScheduleView/SelectedScheduleView.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedScheduleView/SelectedScheduleView.tsx index 445394c44a..e0a81dba5f 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedScheduleView/SelectedScheduleView.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedScheduleView/SelectedScheduleView.tsx @@ -21,8 +21,6 @@ interface Props { scheduleId: string; onClearSelectedRun?: () => void; banner?: React.ReactNode; - onSelectSettings?: () => void; - selectedSettings?: boolean; } export function SelectedScheduleView({ @@ -30,8 +28,6 @@ export function SelectedScheduleView({ scheduleId, onClearSelectedRun, banner, - onSelectSettings, - selectedSettings, }: Props) { const { schedule, isLoading, error } = useSelectedScheduleView( agent.graph_id, @@ -76,12 +72,7 @@ export function SelectedScheduleView({ return (
- +
{}}> +
Agent Settings
-
- {!hasHITLBlocks ? ( -
- - This agent doesn't have any human-in-the-loop blocks, so - there are no settings to configure. - -
- ) : ( +
+ {hasHITLBlocks ? (
@@ -59,6 +52,12 @@ export function SelectedSettingsView({ agent, onClearSelectedRun }: Props) { />
+ ) : ( +
+ + This agent doesn't have any configurable settings. + +
)}
diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedTemplateView/SelectedTemplateView.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedTemplateView/SelectedTemplateView.tsx index b5ecb7ae5c..d0c49c2a93 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedTemplateView/SelectedTemplateView.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedTemplateView/SelectedTemplateView.tsx @@ -8,7 +8,7 @@ import { getAgentCredentialsFields, getAgentInputFields, } from "../../modals/AgentInputsReadOnly/helpers"; -import { CredentialsInput } from "../../modals/CredentialsInputs/CredentialsInputs"; +import { CredentialsInput } from "../../modals/CredentialsInputs/CredentialsInput"; import { RunAgentInputs } from "../../modals/RunAgentInputs/RunAgentInputs"; import { LoadingSelectedContent } from "../LoadingSelectedContent"; import { RunDetailCard } from "../RunDetailCard/RunDetailCard"; diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedTriggerView/SelectedTriggerView.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedTriggerView/SelectedTriggerView.tsx index f92c91112e..0d0cdc95cc 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedTriggerView/SelectedTriggerView.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedTriggerView/SelectedTriggerView.tsx @@ -7,7 +7,7 @@ import { getAgentCredentialsFields, getAgentInputFields, } from "../../modals/AgentInputsReadOnly/helpers"; -import { CredentialsInput } from "../../modals/CredentialsInputs/CredentialsInputs"; +import { CredentialsInput } from "../../modals/CredentialsInputs/CredentialsInput"; import { RunAgentInputs } from "../../modals/RunAgentInputs/RunAgentInputs"; import { LoadingSelectedContent } from "../LoadingSelectedContent"; import { RunDetailCard } from "../RunDetailCard/RunDetailCard"; diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedViewLayout.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedViewLayout.tsx index 8a4e46a606..fe824604df 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedViewLayout.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedViewLayout.tsx @@ -1,7 +1,7 @@ -import { Breadcrumbs } from "@/components/molecules/Breadcrumbs/Breadcrumbs"; -import { AgentSettingsButton } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/other/AgentSettingsButton"; import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; +import { Breadcrumbs } from "@/components/molecules/Breadcrumbs/Breadcrumbs"; import { AGENT_LIBRARY_SECTION_PADDING_X } from "../../helpers"; +import { AgentSettingsModal } from "../modals/AgentSettingsModal/AgentSettingsModal"; import { SectionWrap } from "../other/SectionWrap"; interface Props { @@ -9,8 +9,6 @@ interface Props { children: React.ReactNode; banner?: React.ReactNode; additionalBreadcrumb?: { name: string; link?: string }; - onSelectSettings?: () => void; - selectedSettings?: boolean; } export function SelectedViewLayout(props: Props) { @@ -19,8 +17,8 @@ export function SelectedViewLayout(props: Props) {
- {props.banner &&
{props.banner}
} -
+ {props.banner} +
- {props.agent && props.onSelectSettings && ( -
- -
- )} +
+ +
diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-draft-view.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-draft-view.tsx index 5f57032618..1b155543f1 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-draft-view.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-draft-view.tsx @@ -12,7 +12,7 @@ import { } from "@/lib/autogpt-server-api"; import { useBackendAPI } from "@/lib/autogpt-server-api/context"; -import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs"; +import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput"; import { RunAgentInputs } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentInputs/RunAgentInputs"; import { ScheduleTaskDialog } from "@/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/cron-scheduler-dialog"; import ActionButtonGroup from "@/components/__legacy__/action-button-group"; diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/page.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/page.tsx index 9ada590dd8..147c0aef45 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/page.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/page.tsx @@ -1,14 +1,7 @@ "use client"; -import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; import { NewAgentLibraryView } from "./components/NewAgentLibraryView/NewAgentLibraryView"; -import { OldAgentLibraryView } from "./components/OldAgentLibraryView/OldAgentLibraryView"; export default function AgentLibraryPage() { - const isNewLibraryPageEnabled = useGetFlag(Flag.NEW_AGENT_RUNS); - return isNewLibraryPageEnabled ? ( - - ) : ( - - ); + return ; } diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index e601be6626..6f9a87216b 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -2870,6 +2870,28 @@ } } }, + "/api/integrations/providers/system": { + "get": { + "tags": ["v1", "integrations"], + "summary": "List System Providers", + "description": "Get a list of providers that have platform credits (system credentials) available.\n\nThese providers can be used without the user providing their own API keys.", + "operationId": "getV1ListSystemProviders", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "items": { "type": "string" }, + "type": "array", + "title": "Response Getv1Listsystemproviders" + } + } + } + } + } + } + }, "/api/integrations/webhooks/{webhook_id}/ping": { "post": { "tags": ["v1", "integrations"], diff --git a/autogpt_platform/frontend/src/components/atoms/Button/Button.tsx b/autogpt_platform/frontend/src/components/atoms/Button/Button.tsx index de1dec2d25..ab7b90e098 100644 --- a/autogpt_platform/frontend/src/components/atoms/Button/Button.tsx +++ b/autogpt_platform/frontend/src/components/atoms/Button/Button.tsx @@ -20,6 +20,7 @@ export function Button(props: ButtonProps) { rightIcon, children, as = "button", + asChild: _asChild, // Destructure to prevent passing to DOM ...restProps } = props; diff --git a/autogpt_platform/frontend/src/components/contextual/GoogleDrivePicker/GoogleDrivePicker.tsx b/autogpt_platform/frontend/src/components/contextual/GoogleDrivePicker/GoogleDrivePicker.tsx index e0a43b8c77..4decd2dbdb 100644 --- a/autogpt_platform/frontend/src/components/contextual/GoogleDrivePicker/GoogleDrivePicker.tsx +++ b/autogpt_platform/frontend/src/components/contextual/GoogleDrivePicker/GoogleDrivePicker.tsx @@ -1,6 +1,6 @@ "use client"; -import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs"; +import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput"; import { Button } from "@/components/atoms/Button/Button"; import { CircleNotchIcon, FolderOpenIcon } from "@phosphor-icons/react"; import { diff --git a/autogpt_platform/frontend/src/components/molecules/Accordion/Accordion.stories.tsx b/autogpt_platform/frontend/src/components/molecules/Accordion/Accordion.stories.tsx new file mode 100644 index 0000000000..d0fce53e0e --- /dev/null +++ b/autogpt_platform/frontend/src/components/molecules/Accordion/Accordion.stories.tsx @@ -0,0 +1,203 @@ +import type { Meta } from "@storybook/nextjs"; +import { + Accordion, + AccordionContent, + AccordionItem, + AccordionTrigger, +} from "./Accordion"; + +const meta: Meta = { + title: "Molecules/Accordion", + component: Accordion, + parameters: { + layout: "centered", + docs: { + description: { + component: ` +## Accordion Component + +A vertically stacked set of interactive headings that each reveal an associated section of content. + +### ✨ Features + +- **Built on Radix UI** - Uses @radix-ui/react-accordion for accessibility and functionality +- **Single or multiple** - Supports single or multiple items open at once +- **Smooth animations** - Built-in expand/collapse animations +- **Accessible** - Full keyboard navigation and screen reader support +- **Customizable** - Style with Tailwind CSS classes + +### 🎯 Usage + +\`\`\`tsx + + + Is it accessible? + + Yes. It adheres to the WAI-ARIA design pattern. + + + +\`\`\` + +### Props + +**Accordion**: +- **type**: "single" | "multiple" - Whether one or multiple items can be open +- **collapsible**: boolean - When type is "single", allows closing all items +- **defaultValue**: string | string[] - Default open item(s) +- **value**: string | string[] - Controlled open item(s) +- **onValueChange**: (value) => void - Callback when value changes + +**AccordionItem**: +- **value**: string - Unique identifier for the item +- **disabled**: boolean - Whether the item is disabled + +**AccordionTrigger**: +- Standard button props + +**AccordionContent**: +- Standard div props + `, + }, + }, + }, + tags: ["autodocs"], + argTypes: { + type: { + control: "radio", + options: ["single", "multiple"], + description: "Whether one or multiple items can be open at the same time", + table: { + defaultValue: { summary: "single" }, + }, + }, + collapsible: { + control: "boolean", + description: + 'When type is "single", allows closing content when clicking on open trigger', + table: { + defaultValue: { summary: "false" }, + }, + }, + }, +}; + +export default meta; + +export function Default() { + return ( + + + Is it accessible? + + Yes. It adheres to the WAI-ARIA design pattern. + + + + Is it styled? + + Yes. It comes with default styles that match your design system. + + + + Is it animated? + + Yes. It's animated by default with smooth expand/collapse + transitions. + + + + ); +} + +export function Multiple() { + return ( + + + First section + + Multiple items can be open at the same time when type is set to + "multiple". + + + + Second section + + Try opening this one while the first is still open. + + + + Third section + + All three can be open simultaneously. + + + + ); +} + +export function DefaultOpen() { + return ( + + + Closed by default + This item starts closed. + + + Open by default + + This item starts open because defaultValue is set to + "item-2". + + + + Also closed + This item also starts closed. + + + ); +} + +export function WithDisabledItem() { + return ( + + + Available item + This item can be toggled. + + + Disabled item + + This content cannot be accessed because the item is disabled. + + + + Another available item + This item can also be toggled. + + + ); +} + +export function CustomStyled() { + return ( + + + + Custom styled trigger + + + You can customize the styling using className props. + + + + + Blue themed + + + Each item can have different styles. + + + + ); +} diff --git a/autogpt_platform/frontend/src/components/molecules/Accordion/Accordion.tsx b/autogpt_platform/frontend/src/components/molecules/Accordion/Accordion.tsx new file mode 100644 index 0000000000..b071fc1d37 --- /dev/null +++ b/autogpt_platform/frontend/src/components/molecules/Accordion/Accordion.tsx @@ -0,0 +1,8 @@ +"use client"; + +export { + Accordion, + AccordionContent, + AccordionItem, + AccordionTrigger, +} from "@/components/ui/accordion"; diff --git a/autogpt_platform/frontend/src/components/molecules/Dialog/components/DrawerWrap.tsx b/autogpt_platform/frontend/src/components/molecules/Dialog/components/DrawerWrap.tsx index d00817bf59..3bfa321538 100644 --- a/autogpt_platform/frontend/src/components/molecules/Dialog/components/DrawerWrap.tsx +++ b/autogpt_platform/frontend/src/components/molecules/Dialog/components/DrawerWrap.tsx @@ -22,6 +22,9 @@ export function DrawerWrap({ handleClose, isForceOpen, }: Props) { + const accessibleTitle = title ?? "Dialog"; + const hasVisibleTitle = Boolean(title); + const closeBtn = ( + {typeof headerTitle === "string" ? ( + + {headerTitle} + + ) : ( + headerTitle + )} +
+
+ {showSessionInfo && sessionId && ( + <> + {showNewChatButton && ( + + )} + + )} + {headerActions} +
+
+ + )} + + {/* Main Content */} +
+ {/* Loading State - show when explicitly loading/creating OR when we don't have a session yet and no error */} + {(isLoading || isCreating || (!sessionId && !error)) && ( + + )} + + {/* Error State */} + {error && !isLoading && ( + + )} + + {/* Session Content */} + {sessionId && !isLoading && !error && ( + + )} +
+ + {/* Sessions Drawer */} + setIsSessionsDrawerOpen(false)} + onSelectSession={handleSelectSession} + currentSessionId={sessionId} + /> +
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/AgentCarouselMessage/AgentCarouselMessage.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentCarouselMessage/AgentCarouselMessage.tsx similarity index 65% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/AgentCarouselMessage/AgentCarouselMessage.tsx rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentCarouselMessage/AgentCarouselMessage.tsx index 125f834e05..582b24de5e 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/AgentCarouselMessage/AgentCarouselMessage.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentCarouselMessage/AgentCarouselMessage.tsx @@ -1,15 +1,16 @@ -import React from "react"; -import { Text } from "@/components/atoms/Text/Text"; import { Button } from "@/components/atoms/Button/Button"; import { Card } from "@/components/atoms/Card/Card"; -import { List, Robot, ArrowRight } from "@phosphor-icons/react"; +import { Text } from "@/components/atoms/Text/Text"; import { cn } from "@/lib/utils"; +import { ArrowRight, List, Robot } from "@phosphor-icons/react"; +import Image from "next/image"; export interface Agent { id: string; name: string; description: string; version?: number; + image_url?: string; } export interface AgentCarouselMessageProps { @@ -30,7 +31,7 @@ export function AgentCarouselMessage({ return (
@@ -40,13 +41,10 @@ export function AgentCarouselMessage({
- + Found {displayCount} {displayCount === 1 ? "Agent" : "Agents"} - + Select an agent to view details or run it
@@ -57,40 +55,49 @@ export function AgentCarouselMessage({ {agents.map((agent) => (
-
- +
+ {agent.image_url ? ( + {`${agent.name} + ) : ( +
+ +
+ )}
{agent.name} {agent.version && ( - + v{agent.version} )}
- + {agent.description} {onSelectAgent && (
{totalCount && totalCount > agents.length && ( - + Showing {agents.length} of {totalCount} results )} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentInputsSetup/AgentInputsSetup.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentInputsSetup/AgentInputsSetup.tsx new file mode 100644 index 0000000000..3ef71eca09 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentInputsSetup/AgentInputsSetup.tsx @@ -0,0 +1,246 @@ +"use client"; + +import { Button } from "@/components/atoms/Button/Button"; +import { Card } from "@/components/atoms/Card/Card"; +import { Text } from "@/components/atoms/Text/Text"; +import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput"; +import { RunAgentInputs } from "@/components/contextual/RunAgentInputs/RunAgentInputs"; + +import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; +import { + BlockIOCredentialsSubSchema, + BlockIOSubSchema, +} from "@/lib/autogpt-server-api/types"; +import { cn, isEmpty } from "@/lib/utils"; +import { PlayIcon, WarningIcon } from "@phosphor-icons/react"; +import { useMemo } from "react"; +import { useAgentInputsSetup } from "./useAgentInputsSetup"; + +type LibraryAgentInputSchemaProperties = LibraryAgent["input_schema"] extends { + properties: infer P; +} + ? P extends Record + ? P + : Record + : Record; + +type LibraryAgentCredentialsInputSchemaProperties = + LibraryAgent["credentials_input_schema"] extends { + properties: infer P; + } + ? P extends Record + ? P + : Record + : Record; + +interface Props { + agentName?: string; + inputSchema: LibraryAgentInputSchemaProperties | Record; + credentialsSchema?: + | LibraryAgentCredentialsInputSchemaProperties + | Record; + message: string; + requiredFields?: string[]; + onRun: ( + inputs: Record, + credentials: Record, + ) => void; + onCancel?: () => void; + className?: string; +} + +export function AgentInputsSetup({ + agentName, + inputSchema, + credentialsSchema, + message, + requiredFields, + onRun, + onCancel, + className, +}: Props) { + const { inputValues, setInputValue, credentialsValues, setCredentialsValue } = + useAgentInputsSetup(); + + const inputSchemaObj = useMemo(() => { + if (!inputSchema) return { properties: {}, required: [] }; + if ("properties" in inputSchema && "type" in inputSchema) { + return inputSchema as { + properties: Record; + required?: string[]; + }; + } + return { properties: inputSchema as Record, required: [] }; + }, [inputSchema]); + + const credentialsSchemaObj = useMemo(() => { + if (!credentialsSchema) return { properties: {}, required: [] }; + if ("properties" in credentialsSchema && "type" in credentialsSchema) { + return credentialsSchema as { + properties: Record; + required?: string[]; + }; + } + return { + properties: credentialsSchema as Record, + required: [], + }; + }, [credentialsSchema]); + + const agentInputFields = useMemo(() => { + const properties = inputSchemaObj.properties || {}; + return Object.fromEntries( + Object.entries(properties).filter( + ([_, subSchema]: [string, any]) => !subSchema.hidden, + ), + ); + }, [inputSchemaObj]); + + const agentCredentialsInputFields = useMemo(() => { + return credentialsSchemaObj.properties || {}; + }, [credentialsSchemaObj]); + + const inputFields = Object.entries(agentInputFields); + const credentialFields = Object.entries(agentCredentialsInputFields); + + const defaultsFromSchema = useMemo(() => { + const defaults: Record = {}; + Object.entries(agentInputFields).forEach(([key, schema]) => { + if ("default" in schema && schema.default !== undefined) { + defaults[key] = schema.default; + } + }); + return defaults; + }, [agentInputFields]); + + const defaultsFromCredentialsSchema = useMemo(() => { + const defaults: Record = {}; + Object.entries(agentCredentialsInputFields).forEach(([key, schema]) => { + if ("default" in schema && schema.default !== undefined) { + defaults[key] = schema.default; + } + }); + return defaults; + }, [agentCredentialsInputFields]); + + const mergedInputValues = useMemo(() => { + return { ...defaultsFromSchema, ...inputValues }; + }, [defaultsFromSchema, inputValues]); + + const mergedCredentialsValues = useMemo(() => { + return { ...defaultsFromCredentialsSchema, ...credentialsValues }; + }, [defaultsFromCredentialsSchema, credentialsValues]); + + const allRequiredInputsAreSet = useMemo(() => { + const requiredInputs = new Set( + requiredFields || (inputSchemaObj.required as string[]) || [], + ); + const nonEmptyInputs = new Set( + Object.keys(mergedInputValues).filter( + (k) => !isEmpty(mergedInputValues[k]), + ), + ); + const missing = [...requiredInputs].filter( + (input) => !nonEmptyInputs.has(input), + ); + return missing.length === 0; + }, [inputSchemaObj.required, mergedInputValues, requiredFields]); + + const allCredentialsAreSet = useMemo(() => { + const requiredCredentials = new Set( + (credentialsSchemaObj.required as string[]) || [], + ); + if (requiredCredentials.size === 0) { + return true; + } + const missing = [...requiredCredentials].filter((key) => { + const cred = mergedCredentialsValues[key]; + return !cred || !cred.id; + }); + return missing.length === 0; + }, [credentialsSchemaObj.required, mergedCredentialsValues]); + + const canRun = allRequiredInputsAreSet && allCredentialsAreSet; + + function handleRun() { + if (canRun) { + onRun(mergedInputValues, mergedCredentialsValues); + } + } + + return ( + +
+
+ +
+
+ + {agentName ? `Configure ${agentName}` : "Agent Configuration"} + + + {message} + + + {inputFields.length > 0 && ( +
+ {inputFields.map(([key, inputSubSchema]) => ( + setInputValue(key, value)} + /> + ))} +
+ )} + + {credentialFields.length > 0 && ( +
+ {credentialFields.map(([key, schema]) => { + const requiredCredentials = new Set( + (credentialsSchemaObj.required as string[]) || [], + ); + return ( + + setCredentialsValue(key, value) + } + siblingInputs={mergedInputValues} + isOptional={!requiredCredentials.has(key)} + /> + ); + })} +
+ )} + +
+ + {onCancel && ( + + )} +
+
+
+
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentInputsSetup/useAgentInputsSetup.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentInputsSetup/useAgentInputsSetup.ts new file mode 100644 index 0000000000..e36a3f3c5d --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentInputsSetup/useAgentInputsSetup.ts @@ -0,0 +1,38 @@ +import type { CredentialsMetaInput } from "@/lib/autogpt-server-api/types"; +import { useState } from "react"; + +export function useAgentInputsSetup() { + const [inputValues, setInputValues] = useState>({}); + const [credentialsValues, setCredentialsValues] = useState< + Record + >({}); + + function setInputValue(key: string, value: any) { + setInputValues((prev) => ({ + ...prev, + [key]: value, + })); + } + + function setCredentialsValue(key: string, value?: CredentialsMetaInput) { + if (value) { + setCredentialsValues((prev) => ({ + ...prev, + [key]: value, + })); + } else { + setCredentialsValues((prev) => { + const next = { ...prev }; + delete next[key]; + return next; + }); + } + } + + return { + inputValues, + setInputValue, + credentialsValues, + setCredentialsValue, + }; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/AuthPromptWidget/AuthPromptWidget.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AuthPromptWidget/AuthPromptWidget.tsx similarity index 84% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/AuthPromptWidget/AuthPromptWidget.tsx rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AuthPromptWidget/AuthPromptWidget.tsx index 885a06e92a..33f02e660f 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/AuthPromptWidget/AuthPromptWidget.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AuthPromptWidget/AuthPromptWidget.tsx @@ -1,10 +1,9 @@ "use client"; -import React from "react"; -import { useRouter } from "next/navigation"; import { Button } from "@/components/atoms/Button/Button"; -import { SignInIcon, UserPlusIcon, ShieldIcon } from "@phosphor-icons/react"; import { cn } from "@/lib/utils"; +import { ShieldIcon, SignInIcon, UserPlusIcon } from "@phosphor-icons/react"; +import { useRouter } from "next/navigation"; export interface AuthPromptWidgetProps { message: string; @@ -54,8 +53,8 @@ export function AuthPromptWidget({ return (
-

+

Authentication Required

-

+

Sign in to set up and manage agents

-
-

- {message} -

+
+

{message}

{agentInfo && ( -
+

Ready to set up:{" "} {agentInfo.name} @@ -114,7 +111,7 @@ export function AuthPromptWidget({

-
+
Your chat session will be preserved after signing in
diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/ChatContainer.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/ChatContainer.tsx new file mode 100644 index 0000000000..6f7a0e8f51 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/ChatContainer.tsx @@ -0,0 +1,88 @@ +import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse"; +import { cn } from "@/lib/utils"; +import { useCallback } from "react"; +import { usePageContext } from "../../usePageContext"; +import { ChatInput } from "../ChatInput/ChatInput"; +import { MessageList } from "../MessageList/MessageList"; +import { QuickActionsWelcome } from "../QuickActionsWelcome/QuickActionsWelcome"; +import { useChatContainer } from "./useChatContainer"; + +export interface ChatContainerProps { + sessionId: string | null; + initialMessages: SessionDetailResponse["messages"]; + className?: string; +} + +export function ChatContainer({ + sessionId, + initialMessages, + className, +}: ChatContainerProps) { + const { messages, streamingChunks, isStreaming, sendMessage } = + useChatContainer({ + sessionId, + initialMessages, + }); + const { capturePageContext } = usePageContext(); + + // Wrap sendMessage to automatically capture page context + const sendMessageWithContext = useCallback( + async (content: string, isUserMessage: boolean = true) => { + const context = capturePageContext(); + await sendMessage(content, isUserMessage, context); + }, + [sendMessage, capturePageContext], + ); + + const quickActions = [ + "Find agents for social media management", + "Show me agents for content creation", + "Help me automate my business", + "What can you help me with?", + ]; + + return ( +
+ {/* Messages or Welcome Screen */} +
+ {messages.length === 0 ? ( + + ) : ( + + )} +
+ + {/* Input - Always visible */} +
+ +
+
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/createStreamEventDispatcher.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/createStreamEventDispatcher.ts similarity index 95% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/createStreamEventDispatcher.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/createStreamEventDispatcher.ts index b8421c3386..844f126d49 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/createStreamEventDispatcher.ts +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/createStreamEventDispatcher.ts @@ -1,14 +1,14 @@ import { toast } from "sonner"; -import type { StreamChunk } from "@/app/(platform)/chat/useChatStream"; +import { StreamChunk } from "../../useChatStream"; import type { HandlerDependencies } from "./useChatContainer.handlers"; import { + handleError, + handleLoginNeeded, + handleStreamEnd, handleTextChunk, handleTextEnded, handleToolCallStart, handleToolResponse, - handleLoginNeeded, - handleStreamEnd, - handleError, } from "./useChatContainer.handlers"; export function createStreamEventDispatcher( diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/helpers.ts similarity index 68% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/helpers.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/helpers.ts index 3a94dab1ea..cd05563369 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/helpers.ts @@ -1,5 +1,24 @@ -import type { ChatMessageData } from "@/app/(platform)/chat/components/ChatMessage/useChatMessage"; import type { ToolResult } from "@/types/chat"; +import type { ChatMessageData } from "../ChatMessage/useChatMessage"; + +export function removePageContext(content: string): string { + // Remove "Page URL: ..." pattern at start of line (case insensitive, handles various formats) + let cleaned = content.replace(/^\s*Page URL:\s*[^\n\r]*/gim, ""); + + // Find "User Message:" marker at start of line to preserve the actual user message + const userMessageMatch = cleaned.match(/^\s*User Message:\s*([\s\S]*)$/im); + if (userMessageMatch) { + // If we found "User Message:", extract everything after it + cleaned = userMessageMatch[1]; + } else { + // If no "User Message:" marker, remove "Page Content:" and everything after it at start of line + cleaned = cleaned.replace(/^\s*Page Content:[\s\S]*$/gim, ""); + } + + // Clean up extra whitespace and newlines + cleaned = cleaned.replace(/\n\s*\n\s*\n+/g, "\n\n").trim(); + return cleaned; +} export function createUserMessage(content: string): ChatMessageData { return { @@ -63,6 +82,7 @@ export function isAgentArray(value: unknown): value is Array<{ name: string; description: string; version?: number; + image_url?: string; }> { if (!Array.isArray(value)) { return false; @@ -77,7 +97,8 @@ export function isAgentArray(value: unknown): value is Array<{ typeof item.name === "string" && "description" in item && typeof item.description === "string" && - (!("version" in item) || typeof item.version === "number"), + (!("version" in item) || typeof item.version === "number") && + (!("image_url" in item) || typeof item.image_url === "string"), ); } @@ -232,6 +253,7 @@ export function isSetupInfo(value: unknown): value is { export function extractCredentialsNeeded( parsedResult: Record, + toolName: string = "run_agent", ): ChatMessageData | null { try { const setupInfo = parsedResult?.setup_info as @@ -244,7 +266,7 @@ export function extractCredentialsNeeded( | Record> | undefined; if (missingCreds && Object.keys(missingCreds).length > 0) { - const agentName = (setupInfo?.agent_name as string) || "this agent"; + const agentName = (setupInfo?.agent_name as string) || "this block"; const credentials = Object.values(missingCreds).map((credInfo) => ({ provider: (credInfo.provider as string) || "unknown", providerName: @@ -264,7 +286,7 @@ export function extractCredentialsNeeded( })); return { type: "credentials_needed", - toolName: "run_agent", + toolName, credentials, message: `To run ${agentName}, you need to add ${credentials.length === 1 ? "credentials" : `${credentials.length} credentials`}.`, agentName, @@ -277,3 +299,92 @@ export function extractCredentialsNeeded( return null; } } + +export function extractInputsNeeded( + parsedResult: Record, + toolName: string = "run_agent", +): ChatMessageData | null { + try { + const setupInfo = parsedResult?.setup_info as + | Record + | undefined; + const requirements = setupInfo?.requirements as + | Record + | undefined; + const inputs = requirements?.inputs as + | Array> + | undefined; + const credentials = requirements?.credentials as + | Array> + | undefined; + + if (!inputs || inputs.length === 0) { + return null; + } + + const agentName = (setupInfo?.agent_name as string) || "this agent"; + const agentId = parsedResult?.graph_id as string | undefined; + const graphVersion = parsedResult?.graph_version as number | undefined; + + const properties: Record = {}; + const requiredProps: string[] = []; + inputs.forEach((input) => { + const name = input.name as string; + if (name) { + properties[name] = { + title: input.name as string, + description: (input.description as string) || "", + type: (input.type as string) || "string", + default: input.default, + enum: input.options, + format: input.format, + }; + if ((input.required as boolean) === true) { + requiredProps.push(name); + } + } + }); + + const inputSchema: Record = { + type: "object", + properties, + }; + if (requiredProps.length > 0) { + inputSchema.required = requiredProps; + } + + const credentialsSchema: Record = {}; + if (credentials && credentials.length > 0) { + credentials.forEach((cred) => { + const id = cred.id as string; + if (id) { + credentialsSchema[id] = { + type: "object", + properties: {}, + credentials_provider: [cred.provider as string], + credentials_types: [(cred.type as string) || "api_key"], + credentials_scopes: cred.scopes as string[] | undefined, + }; + } + }); + } + + return { + type: "inputs_needed", + toolName, + agentName, + agentId, + graphVersion, + inputSchema, + credentialsSchema: + Object.keys(credentialsSchema).length > 0 + ? credentialsSchema + : undefined, + message: `Please provide the required inputs to run ${agentName}.`, + timestamp: new Date(), + }; + } catch (err) { + console.error("Failed to extract inputs from setup info:", err); + return null; + } +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/useChatContainer.handlers.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/useChatContainer.handlers.ts similarity index 85% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/useChatContainer.handlers.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/useChatContainer.handlers.ts index fdbecb5d61..064b847064 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/useChatContainer.handlers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/useChatContainer.handlers.ts @@ -1,13 +1,18 @@ -import type { Dispatch, SetStateAction, MutableRefObject } from "react"; -import type { StreamChunk } from "@/app/(platform)/chat/useChatStream"; -import type { ChatMessageData } from "@/app/(platform)/chat/components/ChatMessage/useChatMessage"; -import { parseToolResponse, extractCredentialsNeeded } from "./helpers"; +import type { Dispatch, MutableRefObject, SetStateAction } from "react"; +import { StreamChunk } from "../../useChatStream"; +import type { ChatMessageData } from "../ChatMessage/useChatMessage"; +import { + extractCredentialsNeeded, + extractInputsNeeded, + parseToolResponse, +} from "./helpers"; export interface HandlerDependencies { setHasTextChunks: Dispatch>; setStreamingChunks: Dispatch>; streamingChunksRef: MutableRefObject; setMessages: Dispatch>; + setIsStreamingInitiated: Dispatch>; sessionId: string; } @@ -100,11 +105,18 @@ export function handleToolResponse( parsedResult = null; } if ( - chunk.tool_name === "run_agent" && + (chunk.tool_name === "run_agent" || chunk.tool_name === "run_block") && chunk.success && parsedResult?.type === "setup_requirements" ) { - const credentialsMessage = extractCredentialsNeeded(parsedResult); + const inputsMessage = extractInputsNeeded(parsedResult, chunk.tool_name); + if (inputsMessage) { + deps.setMessages((prev) => [...prev, inputsMessage]); + } + const credentialsMessage = extractCredentialsNeeded( + parsedResult, + chunk.tool_name, + ); if (credentialsMessage) { deps.setMessages((prev) => [...prev, credentialsMessage]); } @@ -197,10 +209,15 @@ export function handleStreamEnd( deps.setStreamingChunks([]); deps.streamingChunksRef.current = []; deps.setHasTextChunks(false); + deps.setIsStreamingInitiated(false); console.log("[Stream End] Stream complete, messages in local state"); } -export function handleError(chunk: StreamChunk, _deps: HandlerDependencies) { +export function handleError(chunk: StreamChunk, deps: HandlerDependencies) { const errorMessage = chunk.message || chunk.content || "An error occurred"; console.error("Stream error:", errorMessage); + deps.setIsStreamingInitiated(false); + deps.setHasTextChunks(false); + deps.setStreamingChunks([]); + deps.streamingChunksRef.current = []; } diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/useChatContainer.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/useChatContainer.ts new file mode 100644 index 0000000000..8e7dee7718 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/useChatContainer.ts @@ -0,0 +1,206 @@ +import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse"; +import { useCallback, useMemo, useRef, useState } from "react"; +import { toast } from "sonner"; +import { useChatStream } from "../../useChatStream"; +import type { ChatMessageData } from "../ChatMessage/useChatMessage"; +import { createStreamEventDispatcher } from "./createStreamEventDispatcher"; +import { + createUserMessage, + filterAuthMessages, + isToolCallArray, + isValidMessage, + parseToolResponse, + removePageContext, +} from "./helpers"; + +interface Args { + sessionId: string | null; + initialMessages: SessionDetailResponse["messages"]; +} + +export function useChatContainer({ sessionId, initialMessages }: Args) { + const [messages, setMessages] = useState([]); + const [streamingChunks, setStreamingChunks] = useState([]); + const [hasTextChunks, setHasTextChunks] = useState(false); + const [isStreamingInitiated, setIsStreamingInitiated] = useState(false); + const streamingChunksRef = useRef([]); + const { error, sendMessage: sendStreamMessage } = useChatStream(); + const isStreaming = isStreamingInitiated || hasTextChunks; + + const allMessages = useMemo(() => { + const processedInitialMessages: ChatMessageData[] = []; + // Map to track tool calls by their ID so we can look up tool names for tool responses + const toolCallMap = new Map(); + + for (const msg of initialMessages) { + if (!isValidMessage(msg)) { + console.warn("Invalid message structure from backend:", msg); + continue; + } + + let content = String(msg.content || ""); + const role = String(msg.role || "assistant").toLowerCase(); + const toolCalls = msg.tool_calls; + const timestamp = msg.timestamp + ? new Date(msg.timestamp as string) + : undefined; + + // Remove page context from user messages when loading existing sessions + if (role === "user") { + content = removePageContext(content); + // Skip user messages that become empty after removing page context + if (!content.trim()) { + continue; + } + processedInitialMessages.push({ + type: "message", + role: "user", + content, + timestamp, + }); + continue; + } + + // Handle assistant messages first (before tool messages) to build tool call map + if (role === "assistant") { + // Strip tags from content + content = content + .replace(/[\s\S]*?<\/thinking>/gi, "") + .trim(); + + // If assistant has tool calls, create tool_call messages for each + if (toolCalls && isToolCallArray(toolCalls) && toolCalls.length > 0) { + for (const toolCall of toolCalls) { + const toolName = toolCall.function.name; + const toolId = toolCall.id; + // Store tool name for later lookup + toolCallMap.set(toolId, toolName); + + try { + const args = JSON.parse(toolCall.function.arguments || "{}"); + processedInitialMessages.push({ + type: "tool_call", + toolId, + toolName, + arguments: args, + timestamp, + }); + } catch (err) { + console.warn("Failed to parse tool call arguments:", err); + processedInitialMessages.push({ + type: "tool_call", + toolId, + toolName, + arguments: {}, + timestamp, + }); + } + } + // Only add assistant message if there's content after stripping thinking tags + if (content.trim()) { + processedInitialMessages.push({ + type: "message", + role: "assistant", + content, + timestamp, + }); + } + } else if (content.trim()) { + // Assistant message without tool calls, but with content + processedInitialMessages.push({ + type: "message", + role: "assistant", + content, + timestamp, + }); + } + continue; + } + + // Handle tool messages - look up tool name from tool call map + if (role === "tool") { + const toolCallId = (msg.tool_call_id as string) || ""; + const toolName = toolCallMap.get(toolCallId) || "unknown"; + const toolResponse = parseToolResponse( + content, + toolCallId, + toolName, + timestamp, + ); + if (toolResponse) { + processedInitialMessages.push(toolResponse); + } + continue; + } + + // Handle other message types (system, etc.) + if (content.trim()) { + processedInitialMessages.push({ + type: "message", + role: role as "user" | "assistant" | "system", + content, + timestamp, + }); + } + } + + return [...processedInitialMessages, ...messages]; + }, [initialMessages, messages]); + + const sendMessage = useCallback( + async function sendMessage( + content: string, + isUserMessage: boolean = true, + context?: { url: string; content: string }, + ) { + if (!sessionId) { + console.error("Cannot send message: no session ID"); + return; + } + if (isUserMessage) { + const userMessage = createUserMessage(content); + setMessages((prev) => [...filterAuthMessages(prev), userMessage]); + } else { + setMessages((prev) => filterAuthMessages(prev)); + } + setStreamingChunks([]); + streamingChunksRef.current = []; + setHasTextChunks(false); + setIsStreamingInitiated(true); + const dispatcher = createStreamEventDispatcher({ + setHasTextChunks, + setStreamingChunks, + streamingChunksRef, + setMessages, + sessionId, + setIsStreamingInitiated, + }); + try { + await sendStreamMessage( + sessionId, + content, + dispatcher, + isUserMessage, + context, + ); + } catch (err) { + console.error("Failed to send message:", err); + setIsStreamingInitiated(false); + const errorMessage = + err instanceof Error ? err.message : "Failed to send message"; + toast.error("Failed to send message", { + description: errorMessage, + }); + } + }, + [sessionId, sendStreamMessage], + ); + + return { + messages: allMessages, + streamingChunks, + isStreaming, + error, + sendMessage, + }; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatCredentialsSetup/ChatCredentialsSetup.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatCredentialsSetup/ChatCredentialsSetup.tsx new file mode 100644 index 0000000000..4b9da57286 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatCredentialsSetup/ChatCredentialsSetup.tsx @@ -0,0 +1,149 @@ +import { Text } from "@/components/atoms/Text/Text"; +import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput"; +import type { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api"; +import { cn } from "@/lib/utils"; +import { CheckIcon, RobotIcon, WarningIcon } from "@phosphor-icons/react"; +import { useEffect, useRef } from "react"; +import { useChatCredentialsSetup } from "./useChatCredentialsSetup"; + +export interface CredentialInfo { + provider: string; + providerName: string; + credentialType: "api_key" | "oauth2" | "user_password" | "host_scoped"; + title: string; + scopes?: string[]; +} + +interface Props { + credentials: CredentialInfo[]; + agentName?: string; + message: string; + onAllCredentialsComplete: () => void; + onCancel: () => void; + className?: string; +} + +function createSchemaFromCredentialInfo( + credential: CredentialInfo, +): BlockIOCredentialsSubSchema { + return { + type: "object", + properties: {}, + credentials_provider: [credential.provider], + credentials_types: [credential.credentialType], + credentials_scopes: credential.scopes, + discriminator: undefined, + discriminator_mapping: undefined, + discriminator_values: undefined, + }; +} + +export function ChatCredentialsSetup({ + credentials, + agentName: _agentName, + message, + onAllCredentialsComplete, + onCancel: _onCancel, +}: Props) { + const { selectedCredentials, isAllComplete, handleCredentialSelect } = + useChatCredentialsSetup(credentials); + + // Track if we've already called completion to prevent double calls + const hasCalledCompleteRef = useRef(false); + + // Reset the completion flag when credentials change (new credential setup flow) + useEffect( + function resetCompletionFlag() { + hasCalledCompleteRef.current = false; + }, + [credentials], + ); + + // Auto-call completion when all credentials are configured + useEffect( + function autoCompleteWhenReady() { + if (isAllComplete && !hasCalledCompleteRef.current) { + hasCalledCompleteRef.current = true; + onAllCredentialsComplete(); + } + }, + [isAllComplete, onAllCredentialsComplete], + ); + + return ( +
+
+
+
+ +
+
+ +
+
+
+
+
+ + Credentials Required + + + {message} + +
+ +
+ {credentials.map((cred, index) => { + const schema = createSchemaFromCredentialInfo(cred); + const isSelected = !!selectedCredentials[cred.provider]; + + return ( +
+
+ {isSelected ? ( + + ) : ( + + )} + + {cred.providerName} + +
+ + + handleCredentialSelect(cred.provider, credMeta) + } + /> +
+ ); + })} +
+
+
+
+
+
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatCredentialsSetup/useChatCredentialsSetup.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatCredentialsSetup/useChatCredentialsSetup.ts similarity index 100% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/ChatCredentialsSetup/useChatCredentialsSetup.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatCredentialsSetup/useChatCredentialsSetup.ts diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatErrorState/ChatErrorState.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatErrorState/ChatErrorState.tsx similarity index 100% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/ChatErrorState/ChatErrorState.tsx rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatErrorState/ChatErrorState.tsx diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatInput/ChatInput.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatInput/ChatInput.tsx new file mode 100644 index 0000000000..3101174a11 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatInput/ChatInput.tsx @@ -0,0 +1,64 @@ +import { Input } from "@/components/atoms/Input/Input"; +import { cn } from "@/lib/utils"; +import { ArrowUpIcon } from "@phosphor-icons/react"; +import { useChatInput } from "./useChatInput"; + +export interface ChatInputProps { + onSend: (message: string) => void; + disabled?: boolean; + placeholder?: string; + className?: string; +} + +export function ChatInput({ + onSend, + disabled = false, + placeholder = "Type your message...", + className, +}: ChatInputProps) { + const inputId = "chat-input"; + const { value, setValue, handleKeyDown, handleSend } = useChatInput({ + onSend, + disabled, + maxRows: 5, + inputId, + }); + + return ( +
+ setValue(e.target.value)} + onKeyDown={handleKeyDown} + placeholder={placeholder} + disabled={disabled} + rows={1} + wrapperClassName="mb-0 relative" + className="pr-12" + /> + + Press Enter to send, Shift+Enter for new line + + + +
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatInput/useChatInput.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatInput/useChatInput.ts similarity index 65% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/ChatInput/useChatInput.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatInput/useChatInput.ts index 2efae95483..08cf565daa 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatInput/useChatInput.ts +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatInput/useChatInput.ts @@ -1,21 +1,22 @@ -import { KeyboardEvent, useCallback, useState, useRef, useEffect } from "react"; +import { KeyboardEvent, useCallback, useEffect, useState } from "react"; interface UseChatInputArgs { onSend: (message: string) => void; disabled?: boolean; maxRows?: number; + inputId?: string; } export function useChatInput({ onSend, disabled = false, maxRows = 5, + inputId = "chat-input", }: UseChatInputArgs) { const [value, setValue] = useState(""); - const textareaRef = useRef(null); useEffect(() => { - const textarea = textareaRef.current; + const textarea = document.getElementById(inputId) as HTMLTextAreaElement; if (!textarea) return; textarea.style.height = "auto"; const lineHeight = parseInt( @@ -27,23 +28,25 @@ export function useChatInput({ textarea.style.height = `${newHeight}px`; textarea.style.overflowY = textarea.scrollHeight > maxHeight ? "auto" : "hidden"; - }, [value, maxRows]); + }, [value, maxRows, inputId]); const handleSend = useCallback(() => { if (disabled || !value.trim()) return; onSend(value.trim()); setValue(""); - if (textareaRef.current) { - textareaRef.current.style.height = "auto"; + const textarea = document.getElementById(inputId) as HTMLTextAreaElement; + if (textarea) { + textarea.style.height = "auto"; } - }, [value, onSend, disabled]); + }, [value, onSend, disabled, inputId]); const handleKeyDown = useCallback( - (event: KeyboardEvent) => { + (event: KeyboardEvent) => { if (event.key === "Enter" && !event.shiftKey) { event.preventDefault(); handleSend(); } + // Shift+Enter allows default behavior (new line) - no need to handle explicitly }, [handleSend], ); @@ -53,6 +56,5 @@ export function useChatInput({ setValue, handleKeyDown, handleSend, - textareaRef, }; } diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatLoadingState/ChatLoadingState.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatLoadingState/ChatLoadingState.tsx new file mode 100644 index 0000000000..c0cdb33c50 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatLoadingState/ChatLoadingState.tsx @@ -0,0 +1,19 @@ +import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner"; +import { cn } from "@/lib/utils"; + +export interface ChatLoadingStateProps { + message?: string; + className?: string; +} + +export function ChatLoadingState({ className }: ChatLoadingStateProps) { + return ( +
+
+ +
+
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatMessage/ChatMessage.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatMessage/ChatMessage.tsx new file mode 100644 index 0000000000..69a1ab63fb --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatMessage/ChatMessage.tsx @@ -0,0 +1,341 @@ +"use client"; + +import { useGetV2GetUserProfile } from "@/app/api/__generated__/endpoints/store/store"; +import Avatar, { + AvatarFallback, + AvatarImage, +} from "@/components/atoms/Avatar/Avatar"; +import { Button } from "@/components/atoms/Button/Button"; +import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; +import { cn } from "@/lib/utils"; +import { + ArrowClockwise, + CheckCircleIcon, + CheckIcon, + CopyIcon, + RobotIcon, +} from "@phosphor-icons/react"; +import { useRouter } from "next/navigation"; +import { useCallback, useState } from "react"; +import { getToolActionPhrase } from "../../helpers"; +import { AgentCarouselMessage } from "../AgentCarouselMessage/AgentCarouselMessage"; +import { AuthPromptWidget } from "../AuthPromptWidget/AuthPromptWidget"; +import { ChatCredentialsSetup } from "../ChatCredentialsSetup/ChatCredentialsSetup"; +import { ExecutionStartedMessage } from "../ExecutionStartedMessage/ExecutionStartedMessage"; +import { MarkdownContent } from "../MarkdownContent/MarkdownContent"; +import { MessageBubble } from "../MessageBubble/MessageBubble"; +import { NoResultsMessage } from "../NoResultsMessage/NoResultsMessage"; +import { ToolCallMessage } from "../ToolCallMessage/ToolCallMessage"; +import { ToolResponseMessage } from "../ToolResponseMessage/ToolResponseMessage"; +import { useChatMessage, type ChatMessageData } from "./useChatMessage"; +export interface ChatMessageProps { + message: ChatMessageData; + className?: string; + onDismissLogin?: () => void; + onDismissCredentials?: () => void; + onSendMessage?: (content: string, isUserMessage?: boolean) => void; + agentOutput?: ChatMessageData; +} + +export function ChatMessage({ + message, + className, + onDismissCredentials, + onSendMessage, + agentOutput, +}: ChatMessageProps) { + const { user } = useSupabase(); + const router = useRouter(); + const [copied, setCopied] = useState(false); + const { + isUser, + isToolCall, + isToolResponse, + isLoginNeeded, + isCredentialsNeeded, + } = useChatMessage(message); + + const { data: profile } = useGetV2GetUserProfile({ + query: { + select: (res) => (res.status === 200 ? res.data : null), + enabled: isUser && !!user, + queryKey: ["/api/store/profile", user?.id], + }, + }); + + const handleAllCredentialsComplete = useCallback( + function handleAllCredentialsComplete() { + // Send a user message that explicitly asks to retry the setup + // This ensures the LLM calls get_required_setup_info again and proceeds with execution + if (onSendMessage) { + onSendMessage( + "I've configured the required credentials. Please check if everything is ready and proceed with setting up the agent.", + ); + } + // Optionally dismiss the credentials prompt + if (onDismissCredentials) { + onDismissCredentials(); + } + }, + [onSendMessage, onDismissCredentials], + ); + + function handleCancelCredentials() { + // Dismiss the credentials prompt + if (onDismissCredentials) { + onDismissCredentials(); + } + } + + const handleCopy = useCallback(async () => { + if (message.type !== "message") return; + + try { + await navigator.clipboard.writeText(message.content); + setCopied(true); + setTimeout(() => setCopied(false), 2000); + } catch (error) { + console.error("Failed to copy:", error); + } + }, [message]); + + const handleTryAgain = useCallback(() => { + if (message.type !== "message" || !onSendMessage) return; + onSendMessage(message.content, message.role === "user"); + }, [message, onSendMessage]); + + const handleViewExecution = useCallback(() => { + if (message.type === "execution_started" && message.libraryAgentLink) { + router.push(message.libraryAgentLink); + } + }, [message, router]); + + // Render credentials needed messages + if (isCredentialsNeeded && message.type === "credentials_needed") { + return ( + + ); + } + + // Render login needed messages + if (isLoginNeeded && message.type === "login_needed") { + // If user is already logged in, show success message instead of auth prompt + if (user) { + return ( +
+
+
+
+
+ +
+
+

+ Successfully Authenticated +

+

+ You're now signed in and ready to continue +

+
+
+
+
+
+ ); + } + + // Show auth prompt if not logged in + return ( +
+ +
+ ); + } + + // Render tool call messages + if (isToolCall && message.type === "tool_call") { + return ( +
+ +
+ ); + } + + // Render no_results messages - use dedicated component, not ToolResponseMessage + if (message.type === "no_results") { + return ( +
+ +
+ ); + } + + // Render agent_carousel messages - use dedicated component, not ToolResponseMessage + if (message.type === "agent_carousel") { + return ( +
+ +
+ ); + } + + // Render execution_started messages - use dedicated component, not ToolResponseMessage + if (message.type === "execution_started") { + return ( +
+ +
+ ); + } + + // Render tool response messages (but skip agent_output if it's being rendered inside assistant message) + if (isToolResponse && message.type === "tool_response") { + // Check if this is an agent_output that should be rendered inside assistant message + if (message.result) { + let parsedResult: Record | null = null; + try { + parsedResult = + typeof message.result === "string" + ? JSON.parse(message.result) + : (message.result as Record); + } catch { + parsedResult = null; + } + if (parsedResult?.type === "agent_output") { + // Skip rendering - this will be rendered inside the assistant message + return null; + } + } + + return ( +
+ +
+ ); + } + + // Render regular chat messages + if (message.type === "message") { + return ( +
+
+ {!isUser && ( +
+
+ +
+
+ )} + +
+ + + {agentOutput && + agentOutput.type === "tool_response" && + !isUser && ( +
+ +
+ )} +
+
+ {isUser && onSendMessage && ( + + )} + +
+
+ + {isUser && ( +
+ + + + {profile?.username?.charAt(0)?.toUpperCase() || "U"} + + +
+ )} +
+
+ ); + } + + // Fallback for unknown message types + return null; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatMessage/useChatMessage.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatMessage/useChatMessage.ts similarity index 88% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/ChatMessage/useChatMessage.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatMessage/useChatMessage.ts index ae4f48f35b..9a597d4b26 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatMessage/useChatMessage.ts +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatMessage/useChatMessage.ts @@ -1,5 +1,5 @@ -import { formatDistanceToNow } from "date-fns"; import type { ToolArguments, ToolResult } from "@/types/chat"; +import { formatDistanceToNow } from "date-fns"; export type ChatMessageData = | { @@ -65,6 +65,7 @@ export type ChatMessageData = name: string; description: string; version?: number; + image_url?: string; }>; totalCount?: number; timestamp?: string | Date; @@ -77,6 +78,17 @@ export type ChatMessageData = message?: string; libraryAgentLink?: string; timestamp?: string | Date; + } + | { + type: "inputs_needed"; + toolName: string; + agentName?: string; + agentId?: string; + graphVersion?: number; + inputSchema: Record; + credentialsSchema?: Record; + message: string; + timestamp?: string | Date; }; export function useChatMessage(message: ChatMessageData) { @@ -96,5 +108,6 @@ export function useChatMessage(message: ChatMessageData) { isNoResults: message.type === "no_results", isAgentCarousel: message.type === "agent_carousel", isExecutionStarted: message.type === "execution_started", + isInputsNeeded: message.type === "inputs_needed", }; } diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ExecutionStartedMessage/ExecutionStartedMessage.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ExecutionStartedMessage/ExecutionStartedMessage.tsx similarity index 67% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/ExecutionStartedMessage/ExecutionStartedMessage.tsx rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ExecutionStartedMessage/ExecutionStartedMessage.tsx index 77c7f8fe9b..1ac3b440e0 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ExecutionStartedMessage/ExecutionStartedMessage.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ExecutionStartedMessage/ExecutionStartedMessage.tsx @@ -1,8 +1,7 @@ -import React from "react"; -import { Text } from "@/components/atoms/Text/Text"; import { Button } from "@/components/atoms/Button/Button"; -import { CheckCircle, Play, ArrowSquareOut } from "@phosphor-icons/react"; +import { Text } from "@/components/atoms/Text/Text"; import { cn } from "@/lib/utils"; +import { ArrowSquareOut, CheckCircle, Play } from "@phosphor-icons/react"; export interface ExecutionStartedMessageProps { executionId: string; @@ -22,7 +21,7 @@ export function ExecutionStartedMessage({ return (
@@ -32,48 +31,33 @@ export function ExecutionStartedMessage({
- + Execution Started - + {message}
{/* Details */} -
+
{agentName && (
- + Agent: - + {agentName}
)}
- + Execution ID: - + {executionId.slice(0, 16)}...
@@ -94,7 +78,7 @@ export function ExecutionStartedMessage({
)} -
+
Your agent is now running. You can monitor its progress in the monitor diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/MarkdownContent/MarkdownContent.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MarkdownContent/MarkdownContent.tsx similarity index 80% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/MarkdownContent/MarkdownContent.tsx rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MarkdownContent/MarkdownContent.tsx index e82e29c438..51a0794090 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/MarkdownContent/MarkdownContent.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MarkdownContent/MarkdownContent.tsx @@ -1,9 +1,9 @@ "use client"; +import { cn } from "@/lib/utils"; import React from "react"; import ReactMarkdown from "react-markdown"; import remarkGfm from "remark-gfm"; -import { cn } from "@/lib/utils"; interface MarkdownContentProps { content: string; @@ -41,7 +41,7 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) { if (isInline) { return ( {children} @@ -49,17 +49,14 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) { ); } return ( - + {children} ); }, pre: ({ children, ...props }) => (
               {children}
@@ -70,7 +67,7 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
               href={href}
               target="_blank"
               rel="noopener noreferrer"
-              className="text-purple-600 underline decoration-1 underline-offset-2 hover:text-purple-700 dark:text-purple-400 dark:hover:text-purple-300"
+              className="text-purple-600 underline decoration-1 underline-offset-2 hover:text-purple-700"
               {...props}
             >
               {children}
@@ -126,7 +123,7 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
               return (
                 
@@ -136,57 +133,42 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
           },
           blockquote: ({ children, ...props }) => (
             
{children}
), h1: ({ children, ...props }) => ( -

+

{children}

), h2: ({ children, ...props }) => ( -

+

{children}

), h3: ({ children, ...props }) => (

{children}

), h4: ({ children, ...props }) => ( -

+

{children}

), h5: ({ children, ...props }) => ( -
+
{children}
), h6: ({ children, ...props }) => ( -
+
{children}
), @@ -196,15 +178,12 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {

), hr: ({ ...props }) => ( -
+
), table: ({ children, ...props }) => (
{children} @@ -213,7 +192,7 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) { ), th: ({ children, ...props }) => (
{children} @@ -221,7 +200,7 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) { ), td: ({ children, ...props }) => ( {children} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageBubble/MessageBubble.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageBubble/MessageBubble.tsx new file mode 100644 index 0000000000..98b50f3d28 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageBubble/MessageBubble.tsx @@ -0,0 +1,56 @@ +import { cn } from "@/lib/utils"; +import { ReactNode } from "react"; + +export interface MessageBubbleProps { + children: ReactNode; + variant: "user" | "assistant"; + className?: string; +} + +export function MessageBubble({ + children, + variant, + className, +}: MessageBubbleProps) { + const userTheme = { + bg: "bg-slate-900", + border: "border-slate-800", + gradient: "from-slate-900/30 via-slate-800/20 to-transparent", + text: "text-slate-50", + }; + + const assistantTheme = { + bg: "bg-slate-50/20", + border: "border-slate-100", + gradient: "from-slate-200/20 via-slate-300/10 to-transparent", + text: "text-slate-900", + }; + + const theme = variant === "user" ? userTheme : assistantTheme; + + return ( +
+ {/* Gradient flare background */} +
+
+ {children} +
+
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageList/MessageList.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageList/MessageList.tsx new file mode 100644 index 0000000000..22b51c0a92 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageList/MessageList.tsx @@ -0,0 +1,121 @@ +"use client"; + +import { cn } from "@/lib/utils"; +import { ChatMessage } from "../ChatMessage/ChatMessage"; +import type { ChatMessageData } from "../ChatMessage/useChatMessage"; +import { StreamingMessage } from "../StreamingMessage/StreamingMessage"; +import { ThinkingMessage } from "../ThinkingMessage/ThinkingMessage"; +import { useMessageList } from "./useMessageList"; + +export interface MessageListProps { + messages: ChatMessageData[]; + streamingChunks?: string[]; + isStreaming?: boolean; + className?: string; + onStreamComplete?: () => void; + onSendMessage?: (content: string) => void; +} + +export function MessageList({ + messages, + streamingChunks = [], + isStreaming = false, + className, + onStreamComplete, + onSendMessage, +}: MessageListProps) { + const { messagesEndRef, messagesContainerRef } = useMessageList({ + messageCount: messages.length, + isStreaming, + }); + + return ( +
+
+ {/* Render all persisted messages */} + {messages.map((message, index) => { + // Check if current message is an agent_output tool_response + // and if previous message is an assistant message + let agentOutput: ChatMessageData | undefined; + + if (message.type === "tool_response" && message.result) { + let parsedResult: Record | null = null; + try { + parsedResult = + typeof message.result === "string" + ? JSON.parse(message.result) + : (message.result as Record); + } catch { + parsedResult = null; + } + if (parsedResult?.type === "agent_output") { + const prevMessage = messages[index - 1]; + if ( + prevMessage && + prevMessage.type === "message" && + prevMessage.role === "assistant" + ) { + // This agent output will be rendered inside the previous assistant message + // Skip rendering this message separately + return null; + } + } + } + + // Check if next message is an agent_output tool_response to include in current assistant message + if (message.type === "message" && message.role === "assistant") { + const nextMessage = messages[index + 1]; + if ( + nextMessage && + nextMessage.type === "tool_response" && + nextMessage.result + ) { + let parsedResult: Record | null = null; + try { + parsedResult = + typeof nextMessage.result === "string" + ? JSON.parse(nextMessage.result) + : (nextMessage.result as Record); + } catch { + parsedResult = null; + } + if (parsedResult?.type === "agent_output") { + agentOutput = nextMessage; + } + } + } + + return ( + + ); + })} + + {/* Render thinking message when streaming but no chunks yet */} + {isStreaming && streamingChunks.length === 0 && } + + {/* Render streaming message if active */} + {isStreaming && streamingChunks.length > 0 && ( + + )} + + {/* Invisible div to scroll to */} +
+
+
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/MessageList/useMessageList.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageList/useMessageList.ts similarity index 100% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/MessageList/useMessageList.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageList/useMessageList.ts diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/NoResultsMessage/NoResultsMessage.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/NoResultsMessage/NoResultsMessage.tsx similarity index 72% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/NoResultsMessage/NoResultsMessage.tsx rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/NoResultsMessage/NoResultsMessage.tsx index 38eac24bce..b6adc8b93c 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/NoResultsMessage/NoResultsMessage.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/NoResultsMessage/NoResultsMessage.tsx @@ -1,7 +1,6 @@ -import React from "react"; import { Text } from "@/components/atoms/Text/Text"; -import { MagnifyingGlass, X } from "@phosphor-icons/react"; import { cn } from "@/lib/utils"; +import { MagnifyingGlass, X } from "@phosphor-icons/react"; export interface NoResultsMessageProps { message: string; @@ -17,26 +16,26 @@ export function NoResultsMessage({ return (
{/* Icon */}
-
+
-
+
{/* Content */}
- + No Results Found - + {message}
@@ -44,17 +43,14 @@ export function NoResultsMessage({ {/* Suggestions */} {suggestions.length > 0 && (
- + Try these suggestions: -
    +
      {suggestions.map((suggestion, index) => (
    • {suggestion} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/QuickActionsWelcome/QuickActionsWelcome.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/QuickActionsWelcome/QuickActionsWelcome.tsx new file mode 100644 index 0000000000..dd76fd9fb6 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/QuickActionsWelcome/QuickActionsWelcome.tsx @@ -0,0 +1,94 @@ +"use client"; + +import { Text } from "@/components/atoms/Text/Text"; +import { cn } from "@/lib/utils"; + +export interface QuickActionsWelcomeProps { + title: string; + description: string; + actions: string[]; + onActionClick: (action: string) => void; + disabled?: boolean; + className?: string; +} + +export function QuickActionsWelcome({ + title, + description, + actions, + onActionClick, + disabled = false, + className, +}: QuickActionsWelcomeProps) { + return ( +
      +
      +
      + + {title} + + + {description} + +
      +
      + {actions.map((action) => { + // Use slate theme for all cards + const theme = { + bg: "bg-slate-50/10", + border: "border-slate-100", + hoverBg: "hover:bg-slate-50/20", + hoverBorder: "hover:border-slate-200", + gradient: "from-slate-200/20 via-slate-300/10 to-transparent", + text: "text-slate-900", + hoverText: "group-hover:text-slate-900", + }; + + return ( + + ); + })} +
      +
      +
      + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/SessionsDrawer/SessionsDrawer.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/SessionsDrawer/SessionsDrawer.tsx new file mode 100644 index 0000000000..74aa709a46 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/SessionsDrawer/SessionsDrawer.tsx @@ -0,0 +1,136 @@ +"use client"; + +import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat"; +import { Text } from "@/components/atoms/Text/Text"; +import { scrollbarStyles } from "@/components/styles/scrollbars"; +import { cn } from "@/lib/utils"; +import { X } from "@phosphor-icons/react"; +import { formatDistanceToNow } from "date-fns"; +import { Drawer } from "vaul"; + +interface SessionsDrawerProps { + isOpen: boolean; + onClose: () => void; + onSelectSession: (sessionId: string) => void; + currentSessionId?: string | null; +} + +export function SessionsDrawer({ + isOpen, + onClose, + onSelectSession, + currentSessionId, +}: SessionsDrawerProps) { + const { data, isLoading } = useGetV2ListSessions( + { limit: 100 }, + { + query: { + enabled: isOpen, + }, + }, + ); + + const sessions = + data?.status === 200 + ? data.data.sessions.filter((session) => { + // Filter out sessions without messages (sessions that were never updated) + // If updated_at equals created_at, the session was created but never had messages + return session.updated_at !== session.created_at; + }) + : []; + + function handleSelectSession(sessionId: string) { + onSelectSession(sessionId); + onClose(); + } + + return ( + !open && onClose()} + direction="right" + > + + + +
      +
      + + Chat Sessions + + +
      +
      + +
      + {isLoading ? ( +
      + + Loading sessions... + +
      + ) : sessions.length === 0 ? ( +
      + + No sessions found + +
      + ) : ( +
      + {sessions.map((session) => { + const isActive = session.id === currentSessionId; + const updatedAt = session.updated_at + ? formatDistanceToNow(new Date(session.updated_at), { + addSuffix: true, + }) + : ""; + + return ( + + ); + })} +
      + )} +
      +
      +
      +
      + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/StreamingMessage/StreamingMessage.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/StreamingMessage/StreamingMessage.tsx new file mode 100644 index 0000000000..2a6e3d5822 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/StreamingMessage/StreamingMessage.tsx @@ -0,0 +1,42 @@ +import { cn } from "@/lib/utils"; +import { RobotIcon } from "@phosphor-icons/react"; +import { MarkdownContent } from "../MarkdownContent/MarkdownContent"; +import { MessageBubble } from "../MessageBubble/MessageBubble"; +import { useStreamingMessage } from "./useStreamingMessage"; + +export interface StreamingMessageProps { + chunks: string[]; + className?: string; + onComplete?: () => void; +} + +export function StreamingMessage({ + chunks, + className, + onComplete, +}: StreamingMessageProps) { + const { displayText } = useStreamingMessage({ chunks, onComplete }); + + return ( +
      +
      +
      +
      + +
      +
      + +
      + + + +
      +
      +
      + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/StreamingMessage/useStreamingMessage.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/StreamingMessage/useStreamingMessage.ts similarity index 100% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/StreamingMessage/useStreamingMessage.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/StreamingMessage/useStreamingMessage.ts diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ThinkingMessage/ThinkingMessage.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ThinkingMessage/ThinkingMessage.tsx new file mode 100644 index 0000000000..d8adddf416 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ThinkingMessage/ThinkingMessage.tsx @@ -0,0 +1,70 @@ +import { cn } from "@/lib/utils"; +import { RobotIcon } from "@phosphor-icons/react"; +import { useEffect, useRef, useState } from "react"; +import { MessageBubble } from "../MessageBubble/MessageBubble"; + +export interface ThinkingMessageProps { + className?: string; +} + +export function ThinkingMessage({ className }: ThinkingMessageProps) { + const [showSlowLoader, setShowSlowLoader] = useState(false); + const timerRef = useRef(null); + + useEffect(() => { + if (timerRef.current === null) { + timerRef.current = setTimeout(() => { + setShowSlowLoader(true); + }, 8000); + } + + return () => { + if (timerRef.current) { + clearTimeout(timerRef.current); + timerRef.current = null; + } + }; + }, []); + + return ( +
      +
      +
      +
      + +
      +
      + +
      + +
      + {showSlowLoader ? ( +
      +
      +

      + Taking a bit longer to think, wait a moment please +

      +
      + ) : ( + + Thinking... + + )} +
      + +
      +
      +
      + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ToolCallMessage/ToolCallMessage.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ToolCallMessage/ToolCallMessage.tsx new file mode 100644 index 0000000000..97590ae0cf --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ToolCallMessage/ToolCallMessage.tsx @@ -0,0 +1,24 @@ +import { Text } from "@/components/atoms/Text/Text"; +import { cn } from "@/lib/utils"; +import { WrenchIcon } from "@phosphor-icons/react"; +import { getToolActionPhrase } from "../../helpers"; + +export interface ToolCallMessageProps { + toolName: string; + className?: string; +} + +export function ToolCallMessage({ toolName, className }: ToolCallMessageProps) { + return ( +
      + + + {getToolActionPhrase(toolName)}... + +
      + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ToolResponseMessage/ToolResponseMessage.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ToolResponseMessage/ToolResponseMessage.tsx new file mode 100644 index 0000000000..b84204c3ff --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ToolResponseMessage/ToolResponseMessage.tsx @@ -0,0 +1,260 @@ +import { Text } from "@/components/atoms/Text/Text"; +import "@/components/contextual/OutputRenderers"; +import { + globalRegistry, + OutputItem, +} from "@/components/contextual/OutputRenderers"; +import { cn } from "@/lib/utils"; +import type { ToolResult } from "@/types/chat"; +import { WrenchIcon } from "@phosphor-icons/react"; +import { getToolActionPhrase } from "../../helpers"; + +export interface ToolResponseMessageProps { + toolName: string; + result?: ToolResult; + success?: boolean; + className?: string; +} + +export function ToolResponseMessage({ + toolName, + result, + success: _success = true, + className, +}: ToolResponseMessageProps) { + if (!result) { + return ( +
      + + + {getToolActionPhrase(toolName)}... + +
      + ); + } + + let parsedResult: Record | null = null; + try { + parsedResult = + typeof result === "string" + ? JSON.parse(result) + : (result as Record); + } catch { + parsedResult = null; + } + + if (parsedResult && typeof parsedResult === "object") { + const responseType = parsedResult.type as string | undefined; + + if (responseType === "agent_output") { + const execution = parsedResult.execution as + | { + outputs?: Record; + } + | null + | undefined; + const outputs = execution?.outputs || {}; + const message = parsedResult.message as string | undefined; + + return ( +
      +
      + + + {getToolActionPhrase(toolName)} + +
      + {message && ( +
      + + {message} + +
      + )} + {Object.keys(outputs).length > 0 && ( +
      + {Object.entries(outputs).map(([outputName, values]) => + values.map((value, index) => { + const renderer = globalRegistry.getRenderer(value); + if (renderer) { + return ( + + ); + } + return ( +
      + + {outputName} + +
      +                        {JSON.stringify(value, null, 2)}
      +                      
      +
      + ); + }), + )} +
      + )} +
      + ); + } + + if (responseType === "block_output" && parsedResult.outputs) { + const outputs = parsedResult.outputs as Record; + + return ( +
      +
      + + + {getToolActionPhrase(toolName)} + +
      +
      + {Object.entries(outputs).map(([outputName, values]) => + values.map((value, index) => { + const renderer = globalRegistry.getRenderer(value); + if (renderer) { + return ( + + ); + } + return ( +
      + + {outputName} + +
      +                      {JSON.stringify(value, null, 2)}
      +                    
      +
      + ); + }), + )} +
      +
      + ); + } + + // Handle other response types with a message field (e.g., understanding_updated) + if (parsedResult.message && typeof parsedResult.message === "string") { + // Format tool name from snake_case to Title Case + const formattedToolName = toolName + .split("_") + .map((word) => word.charAt(0).toUpperCase() + word.slice(1)) + .join(" "); + + // Clean up message - remove incomplete user_name references + let cleanedMessage = parsedResult.message; + // Remove "Updated understanding with: user_name" pattern if user_name is just a placeholder + cleanedMessage = cleanedMessage.replace( + /Updated understanding with:\s*user_name\.?\s*/gi, + "", + ); + // Remove standalone user_name references + cleanedMessage = cleanedMessage.replace(/\buser_name\b\.?\s*/gi, ""); + cleanedMessage = cleanedMessage.trim(); + + // Only show message if it has content after cleaning + if (!cleanedMessage) { + return ( +
      + + + {formattedToolName} + +
      + ); + } + + return ( +
      +
      + + + {formattedToolName} + +
      +
      + + {cleanedMessage} + +
      +
      + ); + } + } + + const renderer = globalRegistry.getRenderer(result); + if (renderer) { + return ( +
      +
      + + + {getToolActionPhrase(toolName)} + +
      + +
      + ); + } + + return ( +
      + + + {getToolActionPhrase(toolName)}... + +
      + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/helpers.ts similarity index 92% rename from autogpt_platform/frontend/src/app/(platform)/chat/helpers.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/helpers.ts index 5a1e5eb93f..0fade56b73 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/helpers.ts @@ -64,10 +64,3 @@ export function getToolCompletionPhrase(toolName: string): string { `Finished ${toolName.replace(/_/g, " ").replace("...", "")}` ); } - -/** Validate UUID v4 format */ -export function isValidUUID(value: string): boolean { - const uuidRegex = - /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i; - return uuidRegex.test(value); -} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/useChatPage.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChat.ts similarity index 78% rename from autogpt_platform/frontend/src/app/(platform)/chat/useChatPage.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChat.ts index 4f1db5471a..8445d68b3f 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/useChatPage.ts +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChat.ts @@ -1,17 +1,12 @@ "use client"; -import { useEffect, useRef } from "react"; -import { useRouter, useSearchParams } from "next/navigation"; -import { toast } from "sonner"; -import { useChatSession } from "@/app/(platform)/chat/useChatSession"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; -import { useChatStream } from "@/app/(platform)/chat/useChatStream"; +import { useEffect, useRef } from "react"; +import { toast } from "sonner"; +import { useChatSession } from "./useChatSession"; +import { useChatStream } from "./useChatStream"; -export function useChatPage() { - const router = useRouter(); - const searchParams = useSearchParams(); - const urlSessionId = - searchParams.get("session_id") || searchParams.get("session"); +export function useChat() { const hasCreatedSessionRef = useRef(false); const hasClaimedSessionRef = useRef(false); const { user } = useSupabase(); @@ -25,29 +20,24 @@ export function useChatPage() { isCreating, error, createSession, - refreshSession, claimSession, clearSession: clearSessionBase, + loadSession, } = useChatSession({ - urlSessionId, + urlSessionId: null, autoCreate: false, }); useEffect( function autoCreateSession() { - if ( - !urlSessionId && - !hasCreatedSessionRef.current && - !isCreating && - !sessionIdFromHook - ) { + if (!hasCreatedSessionRef.current && !isCreating && !sessionIdFromHook) { hasCreatedSessionRef.current = true; createSession().catch((_err) => { hasCreatedSessionRef.current = false; }); } }, - [urlSessionId, isCreating, sessionIdFromHook, createSession], + [isCreating, sessionIdFromHook, createSession], ); useEffect( @@ -111,7 +101,6 @@ export function useChatPage() { clearSessionBase(); hasCreatedSessionRef.current = false; hasClaimedSessionRef.current = false; - router.push("/chat"); } return { @@ -121,8 +110,8 @@ export function useChatPage() { isCreating, error, createSession, - refreshSession, clearSession, + loadSession, sessionId: sessionIdFromHook, }; } diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatDrawer.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatDrawer.ts new file mode 100644 index 0000000000..62e1a5a569 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatDrawer.ts @@ -0,0 +1,17 @@ +"use client"; + +import { create } from "zustand"; + +interface ChatDrawerState { + isOpen: boolean; + open: () => void; + close: () => void; + toggle: () => void; +} + +export const useChatDrawer = create((set) => ({ + isOpen: false, + open: () => set({ isOpen: true }), + close: () => set({ isOpen: false }), + toggle: () => set((state) => ({ isOpen: !state.isOpen })), +})); diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/useChatSession.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatSession.ts similarity index 89% rename from autogpt_platform/frontend/src/app/(platform)/chat/useChatSession.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatSession.ts index 99f4efc093..a54dc9e32a 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/useChatSession.ts +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatSession.ts @@ -1,17 +1,18 @@ -import { useCallback, useEffect, useState, useRef, useMemo } from "react"; -import { useQueryClient } from "@tanstack/react-query"; -import { toast } from "sonner"; import { - usePostV2CreateSession, + getGetV2GetSessionQueryKey, + getGetV2GetSessionQueryOptions, postV2CreateSession, useGetV2GetSession, usePatchV2SessionAssignUser, - getGetV2GetSessionQueryKey, + usePostV2CreateSession, } from "@/app/api/__generated__/endpoints/chat/chat"; import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse"; -import { storage, Key } from "@/services/storage/local-storage"; -import { isValidUUID } from "@/app/(platform)/chat/helpers"; import { okData } from "@/app/api/helpers"; +import { isValidUUID } from "@/lib/utils"; +import { Key, storage } from "@/services/storage/local-storage"; +import { useQueryClient } from "@tanstack/react-query"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { toast } from "sonner"; interface UseChatSessionArgs { urlSessionId?: string | null; @@ -155,10 +156,22 @@ export function useChatSession({ async function loadSession(id: string) { try { setError(null); + // Invalidate the query cache for this session to force a fresh fetch + await queryClient.invalidateQueries({ + queryKey: getGetV2GetSessionQueryKey(id), + }); + // Set sessionId after invalidation to ensure the hook refetches setSessionId(id); storage.set(Key.CHAT_SESSION_ID, id); - const result = await refetch(); - if (!result.data || result.isError) { + // Force fetch with fresh data (bypass cache) + const queryOptions = getGetV2GetSessionQueryOptions(id, { + query: { + staleTime: 0, // Force fresh fetch + retry: 1, + }, + }); + const result = await queryClient.fetchQuery(queryOptions); + if (!result || ("status" in result && result.status !== 200)) { console.warn("Session not found on server, clearing local state"); storage.clean(Key.CHAT_SESSION_ID); setSessionId(null); @@ -171,7 +184,7 @@ export function useChatSession({ throw error; } }, - [refetch], + [queryClient], ); const refreshSession = useCallback( diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatStream.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatStream.ts new file mode 100644 index 0000000000..1471a13a71 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatStream.ts @@ -0,0 +1,371 @@ +import type { ToolArguments, ToolResult } from "@/types/chat"; +import { useCallback, useEffect, useRef, useState } from "react"; +import { toast } from "sonner"; + +const MAX_RETRIES = 3; +const INITIAL_RETRY_DELAY = 1000; + +export interface StreamChunk { + type: + | "text_chunk" + | "text_ended" + | "tool_call" + | "tool_call_start" + | "tool_response" + | "login_needed" + | "need_login" + | "credentials_needed" + | "error" + | "usage" + | "stream_end"; + timestamp?: string; + content?: string; + message?: string; + tool_id?: string; + tool_name?: string; + arguments?: ToolArguments; + result?: ToolResult; + success?: boolean; + idx?: number; + session_id?: string; + agent_info?: { + graph_id: string; + name: string; + trigger_type: string; + }; + provider?: string; + provider_name?: string; + credential_type?: string; + scopes?: string[]; + title?: string; + [key: string]: unknown; +} + +type VercelStreamChunk = + | { type: "start"; messageId: string } + | { type: "finish" } + | { type: "text-start"; id: string } + | { type: "text-delta"; id: string; delta: string } + | { type: "text-end"; id: string } + | { type: "tool-input-start"; toolCallId: string; toolName: string } + | { + type: "tool-input-available"; + toolCallId: string; + toolName: string; + input: ToolArguments; + } + | { + type: "tool-output-available"; + toolCallId: string; + toolName?: string; + output: ToolResult; + success?: boolean; + } + | { + type: "usage"; + promptTokens: number; + completionTokens: number; + totalTokens: number; + } + | { + type: "error"; + errorText: string; + code?: string; + details?: Record; + }; + +const LEGACY_STREAM_TYPES = new Set([ + "text_chunk", + "text_ended", + "tool_call", + "tool_call_start", + "tool_response", + "login_needed", + "need_login", + "credentials_needed", + "error", + "usage", + "stream_end", +]); + +function isLegacyStreamChunk( + chunk: StreamChunk | VercelStreamChunk, +): chunk is StreamChunk { + return LEGACY_STREAM_TYPES.has(chunk.type as StreamChunk["type"]); +} + +function normalizeStreamChunk( + chunk: StreamChunk | VercelStreamChunk, +): StreamChunk | null { + if (isLegacyStreamChunk(chunk)) { + return chunk; + } + switch (chunk.type) { + case "text-delta": + return { type: "text_chunk", content: chunk.delta }; + case "text-end": + return { type: "text_ended" }; + case "tool-input-available": + return { + type: "tool_call_start", + tool_id: chunk.toolCallId, + tool_name: chunk.toolName, + arguments: chunk.input, + }; + case "tool-output-available": + return { + type: "tool_response", + tool_id: chunk.toolCallId, + tool_name: chunk.toolName, + result: chunk.output, + success: chunk.success ?? true, + }; + case "usage": + return { + type: "usage", + promptTokens: chunk.promptTokens, + completionTokens: chunk.completionTokens, + totalTokens: chunk.totalTokens, + }; + case "error": + return { + type: "error", + message: chunk.errorText, + code: chunk.code, + details: chunk.details, + }; + case "finish": + return { type: "stream_end" }; + case "start": + case "text-start": + case "tool-input-start": + return null; + } +} + +export function useChatStream() { + const [isStreaming, setIsStreaming] = useState(false); + const [error, setError] = useState(null); + const retryCountRef = useRef(0); + const retryTimeoutRef = useRef(null); + const abortControllerRef = useRef(null); + + const stopStreaming = useCallback(() => { + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + abortControllerRef.current = null; + } + if (retryTimeoutRef.current) { + clearTimeout(retryTimeoutRef.current); + retryTimeoutRef.current = null; + } + setIsStreaming(false); + }, []); + + useEffect(() => { + return () => { + stopStreaming(); + }; + }, [stopStreaming]); + + const sendMessage = useCallback( + async ( + sessionId: string, + message: string, + onChunk: (chunk: StreamChunk) => void, + isUserMessage: boolean = true, + context?: { url: string; content: string }, + isRetry: boolean = false, + ) => { + stopStreaming(); + + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + if (abortController.signal.aborted) { + return Promise.reject(new Error("Request aborted")); + } + + if (!isRetry) { + retryCountRef.current = 0; + } + setIsStreaming(true); + setError(null); + + try { + const url = `/api/chat/sessions/${sessionId}/stream`; + const body = JSON.stringify({ + message, + is_user_message: isUserMessage, + context: context || null, + }); + + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "text/event-stream", + }, + body, + signal: abortController.signal, + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error(errorText || `HTTP ${response.status}`); + } + + if (!response.body) { + throw new Error("Response body is null"); + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + + return new Promise((resolve, reject) => { + let didDispatchStreamEnd = false; + + function dispatchStreamEnd() { + if (didDispatchStreamEnd) return; + didDispatchStreamEnd = true; + onChunk({ type: "stream_end" }); + } + + const cleanup = () => { + reader.cancel().catch(() => { + // Ignore cancel errors + }); + }; + + async function readStream() { + try { + while (true) { + const { done, value } = await reader.read(); + + if (done) { + cleanup(); + dispatchStreamEnd(); + retryCountRef.current = 0; + stopStreaming(); + resolve(); + return; + } + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() || ""; + + for (const line of lines) { + if (line.startsWith("data: ")) { + const data = line.slice(6); + if (data === "[DONE]") { + cleanup(); + dispatchStreamEnd(); + retryCountRef.current = 0; + stopStreaming(); + resolve(); + return; + } + + try { + const rawChunk = JSON.parse(data) as + | StreamChunk + | VercelStreamChunk; + const chunk = normalizeStreamChunk(rawChunk); + if (!chunk) { + continue; + } + + // Call the chunk handler + onChunk(chunk); + + // Handle stream lifecycle + if (chunk.type === "stream_end") { + didDispatchStreamEnd = true; + cleanup(); + retryCountRef.current = 0; + stopStreaming(); + resolve(); + return; + } else if (chunk.type === "error") { + cleanup(); + reject( + new Error( + chunk.message || chunk.content || "Stream error", + ), + ); + return; + } + } catch (err) { + // Skip invalid JSON lines + console.warn("Failed to parse SSE chunk:", err, data); + } + } + } + } + } catch (err) { + if (err instanceof Error && err.name === "AbortError") { + cleanup(); + return; + } + + const streamError = + err instanceof Error ? err : new Error("Failed to read stream"); + + if (retryCountRef.current < MAX_RETRIES) { + retryCountRef.current += 1; + const retryDelay = + INITIAL_RETRY_DELAY * Math.pow(2, retryCountRef.current - 1); + + toast.info("Connection interrupted", { + description: `Retrying in ${retryDelay / 1000} seconds...`, + }); + + retryTimeoutRef.current = setTimeout(() => { + sendMessage( + sessionId, + message, + onChunk, + isUserMessage, + context, + true, + ).catch((_err) => { + // Retry failed + }); + }, retryDelay); + } else { + setError(streamError); + toast.error("Connection Failed", { + description: + "Unable to connect to chat service. Please try again.", + }); + cleanup(); + dispatchStreamEnd(); + retryCountRef.current = 0; + stopStreaming(); + reject(streamError); + } + } + } + + readStream(); + }); + } catch (err) { + const streamError = + err instanceof Error ? err : new Error("Failed to start stream"); + setError(streamError); + setIsStreaming(false); + throw streamError; + } + }, + [stopStreaming], + ); + + return { + isStreaming, + error, + sendMessage, + stopStreaming, + }; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/usePageContext.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/usePageContext.ts new file mode 100644 index 0000000000..c567422a5c --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/usePageContext.ts @@ -0,0 +1,98 @@ +import { useCallback } from "react"; + +export interface PageContext { + url: string; + content: string; +} + +const MAX_CONTENT_CHARS = 10000; + +/** + * Hook to capture the current page context (URL + full page content) + * Privacy-hardened: removes sensitive inputs and enforces content size limits + */ +export function usePageContext() { + const capturePageContext = useCallback((): PageContext => { + if (typeof window === "undefined" || typeof document === "undefined") { + return { url: "", content: "" }; + } + + const url = window.location.href; + + // Clone document to avoid modifying the original + const clone = document.cloneNode(true) as Document; + + // Remove script, style, and noscript elements + const scripts = clone.querySelectorAll("script, style, noscript"); + scripts.forEach((el) => el.remove()); + + // Remove sensitive elements and their content + const sensitiveSelectors = [ + "input", + "textarea", + "[contenteditable]", + 'input[type="password"]', + 'input[type="email"]', + 'input[type="tel"]', + 'input[type="search"]', + 'input[type="hidden"]', + "form", + "[data-sensitive]", + "[data-sensitive='true']", + ]; + + sensitiveSelectors.forEach((selector) => { + const elements = clone.querySelectorAll(selector); + elements.forEach((el) => { + // For form elements, remove the entire element + if (el.tagName === "FORM") { + el.remove(); + } else { + // For inputs and textareas, clear their values but keep the element structure + if ( + el instanceof HTMLInputElement || + el instanceof HTMLTextAreaElement + ) { + el.value = ""; + el.textContent = ""; + } else { + // For other sensitive elements, remove them entirely + el.remove(); + } + } + }); + }); + + // Strip any remaining input values that might have been missed + const allInputs = clone.querySelectorAll("input, textarea"); + allInputs.forEach((el) => { + if (el instanceof HTMLInputElement || el instanceof HTMLTextAreaElement) { + el.value = ""; + el.textContent = ""; + } + }); + + // Get text content from body + const body = clone.body; + const content = body?.textContent || body?.innerText || ""; + + // Clean up whitespace + let cleanedContent = content + .replace(/\s+/g, " ") + .replace(/\n\s*\n/g, "\n") + .trim(); + + // Enforce maximum content size + if (cleanedContent.length > MAX_CONTENT_CHARS) { + cleanedContent = + cleanedContent.substring(0, MAX_CONTENT_CHARS) + "... [truncated]"; + } + + return { + url, + content: cleanedContent, + }; + }, []); + + return { capturePageContext }; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/ChatContainer.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/ChatContainer.tsx deleted file mode 100644 index 32f9d6c6eb..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/ChatContainer.tsx +++ /dev/null @@ -1,68 +0,0 @@ -import { cn } from "@/lib/utils"; -import { ChatInput } from "@/app/(platform)/chat/components/ChatInput/ChatInput"; -import { MessageList } from "@/app/(platform)/chat/components/MessageList/MessageList"; -import { QuickActionsWelcome } from "@/app/(platform)/chat/components/QuickActionsWelcome/QuickActionsWelcome"; -import { useChatContainer } from "./useChatContainer"; -import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse"; - -export interface ChatContainerProps { - sessionId: string | null; - initialMessages: SessionDetailResponse["messages"]; - onRefreshSession: () => Promise; - className?: string; -} - -export function ChatContainer({ - sessionId, - initialMessages, - onRefreshSession, - className, -}: ChatContainerProps) { - const { messages, streamingChunks, isStreaming, sendMessage } = - useChatContainer({ - sessionId, - initialMessages, - onRefreshSession, - }); - - const quickActions = [ - "Find agents for social media management", - "Show me agents for content creation", - "Help me automate my business", - "What can you help me with?", - ]; - - return ( -
      - {/* Messages or Welcome Screen */} - {messages.length === 0 ? ( - - ) : ( - - )} - - {/* Input - Always visible */} -
      - -
      -
      - ); -} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/useChatContainer.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/useChatContainer.ts deleted file mode 100644 index c75ad587a5..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/useChatContainer.ts +++ /dev/null @@ -1,130 +0,0 @@ -import { useState, useCallback, useRef, useMemo } from "react"; -import { toast } from "sonner"; -import { useChatStream } from "@/app/(platform)/chat/useChatStream"; -import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse"; -import type { ChatMessageData } from "@/app/(platform)/chat/components/ChatMessage/useChatMessage"; -import { - parseToolResponse, - isValidMessage, - isToolCallArray, - createUserMessage, - filterAuthMessages, -} from "./helpers"; -import { createStreamEventDispatcher } from "./createStreamEventDispatcher"; - -interface UseChatContainerArgs { - sessionId: string | null; - initialMessages: SessionDetailResponse["messages"]; - onRefreshSession: () => Promise; -} - -export function useChatContainer({ - sessionId, - initialMessages, -}: UseChatContainerArgs) { - const [messages, setMessages] = useState([]); - const [streamingChunks, setStreamingChunks] = useState([]); - const [hasTextChunks, setHasTextChunks] = useState(false); - const streamingChunksRef = useRef([]); - const { error, sendMessage: sendStreamMessage } = useChatStream(); - const isStreaming = hasTextChunks; - - const allMessages = useMemo(() => { - const processedInitialMessages = initialMessages - .filter((msg: Record) => { - if (!isValidMessage(msg)) { - console.warn("Invalid message structure from backend:", msg); - return false; - } - const content = String(msg.content || "").trim(); - const toolCalls = msg.tool_calls; - return ( - content.length > 0 || - (toolCalls && Array.isArray(toolCalls) && toolCalls.length > 0) - ); - }) - .map((msg: Record) => { - const content = String(msg.content || ""); - const role = String(msg.role || "assistant").toLowerCase(); - const toolCalls = msg.tool_calls; - if ( - role === "assistant" && - toolCalls && - isToolCallArray(toolCalls) && - toolCalls.length > 0 - ) { - return null; - } - if (role === "tool") { - const timestamp = msg.timestamp - ? new Date(msg.timestamp as string) - : undefined; - const toolResponse = parseToolResponse( - content, - (msg.tool_call_id as string) || "", - "unknown", - timestamp, - ); - if (!toolResponse) { - return null; - } - return toolResponse; - } - return { - type: "message", - role: role as "user" | "assistant" | "system", - content, - timestamp: msg.timestamp - ? new Date(msg.timestamp as string) - : undefined, - }; - }) - .filter((msg): msg is ChatMessageData => msg !== null); - - return [...processedInitialMessages, ...messages]; - }, [initialMessages, messages]); - - const sendMessage = useCallback( - async function sendMessage(content: string, isUserMessage: boolean = true) { - if (!sessionId) { - console.error("Cannot send message: no session ID"); - return; - } - if (isUserMessage) { - const userMessage = createUserMessage(content); - setMessages((prev) => [...filterAuthMessages(prev), userMessage]); - } else { - setMessages((prev) => filterAuthMessages(prev)); - } - setStreamingChunks([]); - streamingChunksRef.current = []; - setHasTextChunks(false); - const dispatcher = createStreamEventDispatcher({ - setHasTextChunks, - setStreamingChunks, - streamingChunksRef, - setMessages, - sessionId, - }); - try { - await sendStreamMessage(sessionId, content, dispatcher, isUserMessage); - } catch (err) { - console.error("Failed to send message:", err); - const errorMessage = - err instanceof Error ? err.message : "Failed to send message"; - toast.error("Failed to send message", { - description: errorMessage, - }); - } - }, - [sessionId, sendStreamMessage], - ); - - return { - messages: allMessages, - streamingChunks, - isStreaming, - error, - sendMessage, - }; -} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatCredentialsSetup/ChatCredentialsSetup.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatCredentialsSetup/ChatCredentialsSetup.tsx deleted file mode 100644 index 3868e17a10..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatCredentialsSetup/ChatCredentialsSetup.tsx +++ /dev/null @@ -1,153 +0,0 @@ -import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput"; -import { Card } from "@/components/atoms/Card/Card"; -import { Text } from "@/components/atoms/Text/Text"; -import type { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api"; -import { cn } from "@/lib/utils"; -import { CheckIcon, KeyIcon, WarningIcon } from "@phosphor-icons/react"; -import { useEffect, useRef } from "react"; -import { useChatCredentialsSetup } from "./useChatCredentialsSetup"; - -export interface CredentialInfo { - provider: string; - providerName: string; - credentialType: "api_key" | "oauth2" | "user_password" | "host_scoped"; - title: string; - scopes?: string[]; -} - -interface Props { - credentials: CredentialInfo[]; - agentName?: string; - message: string; - onAllCredentialsComplete: () => void; - onCancel: () => void; - className?: string; -} - -function createSchemaFromCredentialInfo( - credential: CredentialInfo, -): BlockIOCredentialsSubSchema { - return { - type: "object", - properties: {}, - credentials_provider: [credential.provider], - credentials_types: [credential.credentialType], - credentials_scopes: credential.scopes, - discriminator: undefined, - discriminator_mapping: undefined, - discriminator_values: undefined, - }; -} - -export function ChatCredentialsSetup({ - credentials, - agentName: _agentName, - message, - onAllCredentialsComplete, - onCancel: _onCancel, - className, -}: Props) { - const { selectedCredentials, isAllComplete, handleCredentialSelect } = - useChatCredentialsSetup(credentials); - - // Track if we've already called completion to prevent double calls - const hasCalledCompleteRef = useRef(false); - - // Reset the completion flag when credentials change (new credential setup flow) - useEffect( - function resetCompletionFlag() { - hasCalledCompleteRef.current = false; - }, - [credentials], - ); - - // Auto-call completion when all credentials are configured - useEffect( - function autoCompleteWhenReady() { - if (isAllComplete && !hasCalledCompleteRef.current) { - hasCalledCompleteRef.current = true; - onAllCredentialsComplete(); - } - }, - [isAllComplete, onAllCredentialsComplete], - ); - - return ( - -
      -
      - -
      -
      - - Credentials Required - - - {message} - - -
      - {credentials.map((cred, index) => { - const schema = createSchemaFromCredentialInfo(cred); - const isSelected = !!selectedCredentials[cred.provider]; - - return ( -
      -
      -
      - {isSelected ? ( - - ) : ( - - )} - - {cred.providerName} - -
      -
      - - - handleCredentialSelect(cred.provider, credMeta) - } - /> -
      - ); - })} -
      -
      -
      -
      - ); -} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatInput/ChatInput.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatInput/ChatInput.tsx deleted file mode 100644 index f1caceef70..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatInput/ChatInput.tsx +++ /dev/null @@ -1,63 +0,0 @@ -import { cn } from "@/lib/utils"; -import { PaperPlaneRightIcon } from "@phosphor-icons/react"; -import { Button } from "@/components/atoms/Button/Button"; -import { useChatInput } from "./useChatInput"; - -export interface ChatInputProps { - onSend: (message: string) => void; - disabled?: boolean; - placeholder?: string; - className?: string; -} - -export function ChatInput({ - onSend, - disabled = false, - placeholder = "Type your message...", - className, -}: ChatInputProps) { - const { value, setValue, handleKeyDown, handleSend, textareaRef } = - useChatInput({ - onSend, - disabled, - maxRows: 5, - }); - - return ( -
      -