From e80e4d9cbb81b8719b12f0b434ed9c19c7e41754 Mon Sep 17 00:00:00 2001 From: Nicholas Tindle Date: Thu, 15 Jan 2026 13:02:48 -0700 Subject: [PATCH 01/32] ci: update dev from gitbook (#11757) gitbook changes via ui --- > [!NOTE] > **Docs sync from GitBook** > > - Updates `docs/home/README.md` with a new Developer Platform landing page (cards, links to Platform, Integrations, Contribute, Discord, GitHub) and metadata/cover settings > - Adds `docs/home/SUMMARY.md` defining the table of contents linking to `README.md` > - No application/runtime code changes > > Written by [Cursor Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit 446c71fec89fd7d295cdc74378a2644ea65fb4b4. This will update automatically on new commits. Configure [here](https://cursor.com/dashboard?tab=bugbot). --------- Co-authored-by: Claude Opus 4.5 Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> --- docs/{ => home}/.gitbook/assets/AGPT_Platform.png | Bin docs/{ => home}/.gitbook/assets/Banner_image.png | Bin docs/{ => home}/.gitbook/assets/Contribute.png | Bin docs/{ => home}/.gitbook/assets/Integrations.png | Bin .../assets/Screenshot 2025-08-11 at 12.21.17 PM.png | Bin docs/{ => home}/.gitbook/assets/api-reference.jpg | Bin docs/{ => home}/.gitbook/assets/hosted.jpg | Bin docs/{ => home}/.gitbook/assets/no-code.jpg | Bin ...icate-prediction-yt6p2d3gjhrma0ctdsv8vp1t70.jpeg | Bin docs/{ => home}/README.md | 0 docs/{ => home}/SUMMARY.md | 0 11 files changed, 0 insertions(+), 0 deletions(-) rename docs/{ => home}/.gitbook/assets/AGPT_Platform.png (100%) rename docs/{ => home}/.gitbook/assets/Banner_image.png (100%) rename docs/{ => home}/.gitbook/assets/Contribute.png (100%) rename docs/{ => home}/.gitbook/assets/Integrations.png (100%) rename docs/{ => home}/.gitbook/assets/Screenshot 2025-08-11 at 12.21.17 PM.png (100%) rename docs/{ => home}/.gitbook/assets/api-reference.jpg (100%) rename docs/{ => home}/.gitbook/assets/hosted.jpg (100%) rename docs/{ => 
home}/.gitbook/assets/no-code.jpg (100%) rename docs/{ => home}/.gitbook/assets/replicate-prediction-yt6p2d3gjhrma0ctdsv8vp1t70.jpeg (100%) rename docs/{ => home}/README.md (100%) rename docs/{ => home}/SUMMARY.md (100%) diff --git a/docs/.gitbook/assets/AGPT_Platform.png b/docs/home/.gitbook/assets/AGPT_Platform.png similarity index 100% rename from docs/.gitbook/assets/AGPT_Platform.png rename to docs/home/.gitbook/assets/AGPT_Platform.png diff --git a/docs/.gitbook/assets/Banner_image.png b/docs/home/.gitbook/assets/Banner_image.png similarity index 100% rename from docs/.gitbook/assets/Banner_image.png rename to docs/home/.gitbook/assets/Banner_image.png diff --git a/docs/.gitbook/assets/Contribute.png b/docs/home/.gitbook/assets/Contribute.png similarity index 100% rename from docs/.gitbook/assets/Contribute.png rename to docs/home/.gitbook/assets/Contribute.png diff --git a/docs/.gitbook/assets/Integrations.png b/docs/home/.gitbook/assets/Integrations.png similarity index 100% rename from docs/.gitbook/assets/Integrations.png rename to docs/home/.gitbook/assets/Integrations.png diff --git a/docs/.gitbook/assets/Screenshot 2025-08-11 at 12.21.17 PM.png b/docs/home/.gitbook/assets/Screenshot 2025-08-11 at 12.21.17 PM.png similarity index 100% rename from docs/.gitbook/assets/Screenshot 2025-08-11 at 12.21.17 PM.png rename to docs/home/.gitbook/assets/Screenshot 2025-08-11 at 12.21.17 PM.png diff --git a/docs/.gitbook/assets/api-reference.jpg b/docs/home/.gitbook/assets/api-reference.jpg similarity index 100% rename from docs/.gitbook/assets/api-reference.jpg rename to docs/home/.gitbook/assets/api-reference.jpg diff --git a/docs/.gitbook/assets/hosted.jpg b/docs/home/.gitbook/assets/hosted.jpg similarity index 100% rename from docs/.gitbook/assets/hosted.jpg rename to docs/home/.gitbook/assets/hosted.jpg diff --git a/docs/.gitbook/assets/no-code.jpg b/docs/home/.gitbook/assets/no-code.jpg similarity index 100% rename from docs/.gitbook/assets/no-code.jpg rename 
to docs/home/.gitbook/assets/no-code.jpg diff --git a/docs/.gitbook/assets/replicate-prediction-yt6p2d3gjhrma0ctdsv8vp1t70.jpeg b/docs/home/.gitbook/assets/replicate-prediction-yt6p2d3gjhrma0ctdsv8vp1t70.jpeg similarity index 100% rename from docs/.gitbook/assets/replicate-prediction-yt6p2d3gjhrma0ctdsv8vp1t70.jpeg rename to docs/home/.gitbook/assets/replicate-prediction-yt6p2d3gjhrma0ctdsv8vp1t70.jpeg diff --git a/docs/README.md b/docs/home/README.md similarity index 100% rename from docs/README.md rename to docs/home/README.md diff --git a/docs/SUMMARY.md b/docs/home/SUMMARY.md similarity index 100% rename from docs/SUMMARY.md rename to docs/home/SUMMARY.md From 8b83bb8647dfad90f5904ed55a55fa45686ef9d1 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Fri, 16 Jan 2026 02:47:19 -0600 Subject: [PATCH 02/32] feat(backend): unified hybrid search with embedding backfill for all content types (#11767) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary This PR extends the embedding system to support **blocks** and **documentation** content types in addition to store agents, and introduces **unified hybrid search** across all content types using a single `UnifiedContentEmbedding` table. ### Key Changes 1. **Unified Hybrid Search Architecture** - Added `search` tsvector column to `UnifiedContentEmbedding` table - New `unified_hybrid_search()` function searches across all content types (agents, blocks, docs) - Updated `hybrid_search()` for store agents to use `UnifiedContentEmbedding.search` - Removed deprecated `search` column from `StoreListingVersion` table 2. **Pluggable Content Handler Architecture** - Created abstract `ContentHandler` base class for extensibility - Implemented handlers: `StoreAgentHandler`, `BlockHandler`, `DocumentationHandler` - Registry pattern for easy addition of new content types 3. 
**Block Embeddings** - Discovers all blocks using `get_blocks()` - Extracts searchable text from: name, description, categories, input/output schemas 4. **Documentation Embeddings** - Scans `/docs/` directory for `.md` and `.mdx` files - Extracts title from first `#` heading or uses filename as fallback 5. **Hybrid Search Graceful Degradation** - Falls back to lexical-only search if query embedding generation fails - Redistributes semantic weight proportionally to other components - Logs warning instead of throwing error 6. **Database Migrations** - `20260115200000_add_unified_search_tsvector`: Adds search column to UnifiedContentEmbedding with auto-update trigger - `20260115210000_remove_storelistingversion_search`: Removes deprecated search column and updates StoreAgent view 7. **Orphan Cleanup** - `cleanup_orphaned_embeddings()` removes embeddings for deleted content - Always runs after backfill, even at 100% coverage ### Review Comments Addressed - ✅ SQL parameter index bug when user_id provided (embeddings.py) - ✅ Early return skipping cleanup at 100% coverage (scheduler.py) - ✅ Inconsistent return structure across code paths (scheduler.py) - ✅ SQL UNION syntax error - added parentheses for ORDER BY/LIMIT (hybrid_search.py) - ✅ Version numeric ordering in aggregations (migration) - ✅ Embedding dimension uses EMBEDDING_DIM constant ### Files Changed - `backend/api/features/store/content_handlers.py` (NEW): Handler architecture - `backend/api/features/store/embeddings.py`: Refactored to use handlers - `backend/api/features/store/hybrid_search.py`: Unified search + graceful degradation - `backend/executor/scheduler.py`: Process all content types, consistent returns - `migrations/20260115200000_add_unified_search_tsvector/`: Add tsvector to unified table - `migrations/20260115210000_remove_storelistingversion_search/`: Remove old search column - `schema.prisma`: Updated UnifiedContentEmbedding and StoreListingVersion models - `*_test.py`: Added tests for 
unified_hybrid_search ## Test Plan 1. ✅ All tests passing on Python 3.11, 3.12, 3.13 2. ✅ Types check passing 3. ✅ CodeRabbit and Sentry reviews addressed 4. Deploy to staging and verify: - Backfill job processes all content types - Search results include blocks and docs - Search works without OpenAI API (graceful degradation) 🤖 Generated with [Claude Code](https://claude.ai/code) --------- Co-authored-by: Swifty Co-authored-by: Claude Opus 4.5 --- .dockerignore | 3 + autogpt_platform/backend/Dockerfile | 1 + .../api/features/store/content_handlers.py | 431 ++++++++++ .../content_handlers_integration_test.py | 215 +++++ .../features/store/content_handlers_test.py | 324 ++++++++ .../backend/api/features/store/embeddings.py | 584 +++++++++++--- .../api/features/store/embeddings_e2e_test.py | 666 ++++++++++++++++ .../features/store/embeddings_schema_test.py | 126 ++- .../api/features/store/embeddings_test.py | 102 ++- .../api/features/store/hybrid_search.py | 752 ++++++++++++------ .../api/features/store/hybrid_search_test.py | 371 ++++++++- .../backend/api/features/store/model.py | 20 + .../backend/api/features/store/routes.py | 99 +++ .../features/store/semantic_search_test.py | 272 +++++++ .../backend/backend/executor/database.py | 3 + .../backend/backend/executor/scheduler.py | 118 ++- .../migration.sql | 2 + .../migration.sql | 35 + .../migration.sql | 90 +++ autogpt_platform/backend/schema.prisma | 5 +- .../frontend/src/app/api/openapi.json | 115 +++ 21 files changed, 3810 insertions(+), 524 deletions(-) create mode 100644 autogpt_platform/backend/backend/api/features/store/content_handlers.py create mode 100644 autogpt_platform/backend/backend/api/features/store/content_handlers_integration_test.py create mode 100644 autogpt_platform/backend/backend/api/features/store/content_handlers_test.py create mode 100644 autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py create mode 100644 
autogpt_platform/backend/backend/api/features/store/semantic_search_test.py create mode 100644 autogpt_platform/backend/migrations/20260115200000_add_unified_search_tsvector/migration.sql create mode 100644 autogpt_platform/backend/migrations/20260115210000_remove_storelistingversion_search/migration.sql diff --git a/.dockerignore b/.dockerignore index c9524ce700..9b744e7f9b 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,6 +1,9 @@ # Ignore everything by default, selectively add things to context * +# Documentation (for embeddings/search) +!docs/ + # Platform - Libs !autogpt_platform/autogpt_libs/autogpt_libs/ !autogpt_platform/autogpt_libs/pyproject.toml diff --git a/autogpt_platform/backend/Dockerfile b/autogpt_platform/backend/Dockerfile index b3389d1787..103226d079 100644 --- a/autogpt_platform/backend/Dockerfile +++ b/autogpt_platform/backend/Dockerfile @@ -100,6 +100,7 @@ COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migration FROM server_dependencies AS server COPY autogpt_platform/backend /app/autogpt_platform/backend +COPY docs /app/docs RUN poetry install --no-ansi --only-root ENV PORT=8000 diff --git a/autogpt_platform/backend/backend/api/features/store/content_handlers.py b/autogpt_platform/backend/backend/api/features/store/content_handlers.py new file mode 100644 index 0000000000..1758ebf067 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/store/content_handlers.py @@ -0,0 +1,431 @@ +""" +Content Type Handlers for Unified Embeddings + +Pluggable system for different content sources (store agents, blocks, docs). +Each handler knows how to fetch and process its content type for embedding. 
+""" + +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from prisma.enums import ContentType + +from backend.data.db import query_raw_with_schema + +logger = logging.getLogger(__name__) + + +@dataclass +class ContentItem: + """Represents a piece of content to be embedded.""" + + content_id: str # Unique identifier (DB ID or file path) + content_type: ContentType + searchable_text: str # Combined text for embedding + metadata: dict[str, Any] # Content-specific metadata + user_id: str | None = None # For user-scoped content + + +class ContentHandler(ABC): + """Base handler for fetching and processing content for embeddings.""" + + @property + @abstractmethod + def content_type(self) -> ContentType: + """The ContentType this handler manages.""" + pass + + @abstractmethod + async def get_missing_items(self, batch_size: int) -> list[ContentItem]: + """ + Fetch items that don't have embeddings yet. + + Args: + batch_size: Maximum number of items to return + + Returns: + List of ContentItem objects ready for embedding + """ + pass + + @abstractmethod + async def get_stats(self) -> dict[str, int]: + """ + Get statistics about embedding coverage. 
+ + Returns: + Dict with keys: total, with_embeddings, without_embeddings + """ + pass + + +class StoreAgentHandler(ContentHandler): + """Handler for marketplace store agent listings.""" + + @property + def content_type(self) -> ContentType: + return ContentType.STORE_AGENT + + async def get_missing_items(self, batch_size: int) -> list[ContentItem]: + """Fetch approved store listings without embeddings.""" + from backend.api.features.store.embeddings import build_searchable_text + + missing = await query_raw_with_schema( + """ + SELECT + slv.id, + slv.name, + slv.description, + slv."subHeading", + slv.categories + FROM {schema_prefix}"StoreListingVersion" slv + LEFT JOIN {schema_prefix}"UnifiedContentEmbedding" uce + ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType" + WHERE slv."submissionStatus" = 'APPROVED' + AND slv."isDeleted" = false + AND uce."contentId" IS NULL + LIMIT $1 + """, + batch_size, + ) + + return [ + ContentItem( + content_id=row["id"], + content_type=ContentType.STORE_AGENT, + searchable_text=build_searchable_text( + name=row["name"], + description=row["description"], + sub_heading=row["subHeading"], + categories=row["categories"] or [], + ), + metadata={ + "name": row["name"], + "categories": row["categories"] or [], + }, + user_id=None, # Store agents are public + ) + for row in missing + ] + + async def get_stats(self) -> dict[str, int]: + """Get statistics about store agent embedding coverage.""" + # Count approved versions + approved_result = await query_raw_with_schema( + """ + SELECT COUNT(*) as count + FROM {schema_prefix}"StoreListingVersion" + WHERE "submissionStatus" = 'APPROVED' + AND "isDeleted" = false + """ + ) + total_approved = approved_result[0]["count"] if approved_result else 0 + + # Count versions with embeddings + embedded_result = await query_raw_with_schema( + """ + SELECT COUNT(*) as count + FROM {schema_prefix}"StoreListingVersion" slv + JOIN 
{schema_prefix}"UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType" + WHERE slv."submissionStatus" = 'APPROVED' + AND slv."isDeleted" = false + """ + ) + with_embeddings = embedded_result[0]["count"] if embedded_result else 0 + + return { + "total": total_approved, + "with_embeddings": with_embeddings, + "without_embeddings": total_approved - with_embeddings, + } + + +class BlockHandler(ContentHandler): + """Handler for block definitions (Python classes).""" + + @property + def content_type(self) -> ContentType: + return ContentType.BLOCK + + async def get_missing_items(self, batch_size: int) -> list[ContentItem]: + """Fetch blocks without embeddings.""" + from backend.data.block import get_blocks + + # Get all available blocks + all_blocks = get_blocks() + + # Check which ones have embeddings + if not all_blocks: + return [] + + block_ids = list(all_blocks.keys()) + + # Query for existing embeddings + placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))]) + existing_result = await query_raw_with_schema( + f""" + SELECT "contentId" + FROM {{schema_prefix}}"UnifiedContentEmbedding" + WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType" + AND "contentId" = ANY(ARRAY[{placeholders}]) + """, + *block_ids, + ) + + existing_ids = {row["contentId"] for row in existing_result} + missing_blocks = [ + (block_id, block_cls) + for block_id, block_cls in all_blocks.items() + if block_id not in existing_ids + ] + + # Convert to ContentItem + items = [] + for block_id, block_cls in missing_blocks[:batch_size]: + try: + block_instance = block_cls() + + # Build searchable text from block metadata + parts = [] + if hasattr(block_instance, "name") and block_instance.name: + parts.append(block_instance.name) + if ( + hasattr(block_instance, "description") + and block_instance.description + ): + parts.append(block_instance.description) + if hasattr(block_instance, "categories") and 
block_instance.categories: + # Convert BlockCategory enum to strings + parts.append( + " ".join(str(cat.value) for cat in block_instance.categories) + ) + + # Add input/output schema info + if hasattr(block_instance, "input_schema"): + schema = block_instance.input_schema + if hasattr(schema, "model_json_schema"): + schema_dict = schema.model_json_schema() + if "properties" in schema_dict: + for prop_name, prop_info in schema_dict[ + "properties" + ].items(): + if "description" in prop_info: + parts.append( + f"{prop_name}: {prop_info['description']}" + ) + + searchable_text = " ".join(parts) + + # Convert categories set of enums to list of strings for JSON serialization + categories = getattr(block_instance, "categories", set()) + categories_list = ( + [cat.value for cat in categories] if categories else [] + ) + + items.append( + ContentItem( + content_id=block_id, + content_type=ContentType.BLOCK, + searchable_text=searchable_text, + metadata={ + "name": getattr(block_instance, "name", ""), + "categories": categories_list, + }, + user_id=None, # Blocks are public + ) + ) + except Exception as e: + logger.warning(f"Failed to process block {block_id}: {e}") + continue + + return items + + async def get_stats(self) -> dict[str, int]: + """Get statistics about block embedding coverage.""" + from backend.data.block import get_blocks + + all_blocks = get_blocks() + total_blocks = len(all_blocks) + + if total_blocks == 0: + return {"total": 0, "with_embeddings": 0, "without_embeddings": 0} + + block_ids = list(all_blocks.keys()) + placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))]) + + embedded_result = await query_raw_with_schema( + f""" + SELECT COUNT(*) as count + FROM {{schema_prefix}}"UnifiedContentEmbedding" + WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType" + AND "contentId" = ANY(ARRAY[{placeholders}]) + """, + *block_ids, + ) + + with_embeddings = embedded_result[0]["count"] if embedded_result else 0 + + return { + "total": 
total_blocks, + "with_embeddings": with_embeddings, + "without_embeddings": total_blocks - with_embeddings, + } + + +class DocumentationHandler(ContentHandler): + """Handler for documentation files (.md/.mdx).""" + + @property + def content_type(self) -> ContentType: + return ContentType.DOCUMENTATION + + def _get_docs_root(self) -> Path: + """Get the documentation root directory.""" + # content_handlers.py is at: backend/backend/api/features/store/content_handlers.py + # Need to go up to project root then into docs/ + # In container: /app/autogpt_platform/backend/backend/api/features/store -> /app/docs + # In development: /repo/autogpt_platform/backend/backend/api/features/store -> /repo/docs + this_file = Path( + __file__ + ) # .../backend/backend/api/features/store/content_handlers.py + project_root = ( + this_file.parent.parent.parent.parent.parent.parent.parent + ) # -> /app or /repo + docs_root = project_root / "docs" + return docs_root + + def _extract_title_and_content(self, file_path: Path) -> tuple[str, str]: + """Extract title and content from markdown file.""" + try: + content = file_path.read_text(encoding="utf-8") + + # Try to extract title from first # heading + lines = content.split("\n") + title = "" + body_lines = [] + + for line in lines: + if line.startswith("# ") and not title: + title = line[2:].strip() + else: + body_lines.append(line) + + # If no title found, use filename + if not title: + title = file_path.stem.replace("-", " ").replace("_", " ").title() + + body = "\n".join(body_lines) + + return title, body + except Exception as e: + logger.warning(f"Failed to read {file_path}: {e}") + return file_path.stem, "" + + async def get_missing_items(self, batch_size: int) -> list[ContentItem]: + """Fetch documentation files without embeddings.""" + docs_root = self._get_docs_root() + + if not docs_root.exists(): + logger.warning(f"Documentation root not found: {docs_root}") + return [] + + # Find all .md and .mdx files + all_docs = 
list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx")) + + # Get relative paths for content IDs + doc_paths = [str(doc.relative_to(docs_root)) for doc in all_docs] + + if not doc_paths: + return [] + + # Check which ones have embeddings + placeholders = ",".join([f"${i+1}" for i in range(len(doc_paths))]) + existing_result = await query_raw_with_schema( + f""" + SELECT "contentId" + FROM {{schema_prefix}}"UnifiedContentEmbedding" + WHERE "contentType" = 'DOCUMENTATION'::{{schema_prefix}}"ContentType" + AND "contentId" = ANY(ARRAY[{placeholders}]) + """, + *doc_paths, + ) + + existing_ids = {row["contentId"] for row in existing_result} + missing_docs = [ + (doc_path, doc_file) + for doc_path, doc_file in zip(doc_paths, all_docs) + if doc_path not in existing_ids + ] + + # Convert to ContentItem + items = [] + for doc_path, doc_file in missing_docs[:batch_size]: + try: + title, content = self._extract_title_and_content(doc_file) + + # Build searchable text + searchable_text = f"{title} {content}" + + items.append( + ContentItem( + content_id=doc_path, + content_type=ContentType.DOCUMENTATION, + searchable_text=searchable_text, + metadata={ + "title": title, + "path": doc_path, + }, + user_id=None, # Documentation is public + ) + ) + except Exception as e: + logger.warning(f"Failed to process doc {doc_path}: {e}") + continue + + return items + + async def get_stats(self) -> dict[str, int]: + """Get statistics about documentation embedding coverage.""" + docs_root = self._get_docs_root() + + if not docs_root.exists(): + return {"total": 0, "with_embeddings": 0, "without_embeddings": 0} + + # Count all .md and .mdx files + all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx")) + total_docs = len(all_docs) + + if total_docs == 0: + return {"total": 0, "with_embeddings": 0, "without_embeddings": 0} + + doc_paths = [str(doc.relative_to(docs_root)) for doc in all_docs] + placeholders = ",".join([f"${i+1}" for i in range(len(doc_paths))]) + + 
embedded_result = await query_raw_with_schema( + f""" + SELECT COUNT(*) as count + FROM {{schema_prefix}}"UnifiedContentEmbedding" + WHERE "contentType" = 'DOCUMENTATION'::{{schema_prefix}}"ContentType" + AND "contentId" = ANY(ARRAY[{placeholders}]) + """, + *doc_paths, + ) + + with_embeddings = embedded_result[0]["count"] if embedded_result else 0 + + return { + "total": total_docs, + "with_embeddings": with_embeddings, + "without_embeddings": total_docs - with_embeddings, + } + + +# Content handler registry +CONTENT_HANDLERS: dict[ContentType, ContentHandler] = { + ContentType.STORE_AGENT: StoreAgentHandler(), + ContentType.BLOCK: BlockHandler(), + ContentType.DOCUMENTATION: DocumentationHandler(), +} diff --git a/autogpt_platform/backend/backend/api/features/store/content_handlers_integration_test.py b/autogpt_platform/backend/backend/api/features/store/content_handlers_integration_test.py new file mode 100644 index 0000000000..b53a9e80b0 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/store/content_handlers_integration_test.py @@ -0,0 +1,215 @@ +""" +Integration tests for content handlers using real DB. + +Run with: poetry run pytest backend/api/features/store/content_handlers_integration_test.py -xvs + +These tests use the real database but mock OpenAI calls. 
+""" + +from unittest.mock import patch + +import pytest + +from backend.api.features.store.content_handlers import ( + CONTENT_HANDLERS, + BlockHandler, + DocumentationHandler, + StoreAgentHandler, +) +from backend.api.features.store.embeddings import ( + EMBEDDING_DIM, + backfill_all_content_types, + ensure_content_embedding, + get_embedding_stats, +) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_store_agent_handler_real_db(): + """Test StoreAgentHandler with real database queries.""" + handler = StoreAgentHandler() + + # Get stats from real DB + stats = await handler.get_stats() + + # Stats should have correct structure + assert "total" in stats + assert "with_embeddings" in stats + assert "without_embeddings" in stats + assert stats["total"] >= 0 + assert stats["with_embeddings"] >= 0 + assert stats["without_embeddings"] >= 0 + + # Get missing items (max 1 to keep test fast) + items = await handler.get_missing_items(batch_size=1) + + # Items should be list (may be empty if all have embeddings) + assert isinstance(items, list) + + if items: + item = items[0] + assert item.content_id is not None + assert item.content_type.value == "STORE_AGENT" + assert item.searchable_text != "" + assert item.user_id is None + + +@pytest.mark.asyncio(loop_scope="session") +async def test_block_handler_real_db(): + """Test BlockHandler with real database queries.""" + handler = BlockHandler() + + # Get stats from real DB + stats = await handler.get_stats() + + # Stats should have correct structure + assert "total" in stats + assert "with_embeddings" in stats + assert "without_embeddings" in stats + assert stats["total"] >= 0 # Should have at least some blocks + assert stats["with_embeddings"] >= 0 + assert stats["without_embeddings"] >= 0 + + # Get missing items (max 1 to keep test fast) + items = await handler.get_missing_items(batch_size=1) + + # Items should be list + assert isinstance(items, list) + + if items: + item = items[0] + assert item.content_id is 
not None # Should be block UUID + assert item.content_type.value == "BLOCK" + assert item.searchable_text != "" + assert item.user_id is None + + +@pytest.mark.asyncio(loop_scope="session") +async def test_documentation_handler_real_fs(): + """Test DocumentationHandler with real filesystem.""" + handler = DocumentationHandler() + + # Get stats from real filesystem + stats = await handler.get_stats() + + # Stats should have correct structure + assert "total" in stats + assert "with_embeddings" in stats + assert "without_embeddings" in stats + assert stats["total"] >= 0 + assert stats["with_embeddings"] >= 0 + assert stats["without_embeddings"] >= 0 + + # Get missing items (max 1 to keep test fast) + items = await handler.get_missing_items(batch_size=1) + + # Items should be list + assert isinstance(items, list) + + if items: + item = items[0] + assert item.content_id is not None # Should be relative path + assert item.content_type.value == "DOCUMENTATION" + assert item.searchable_text != "" + assert item.user_id is None + + +@pytest.mark.asyncio(loop_scope="session") +async def test_get_embedding_stats_all_types(): + """Test get_embedding_stats aggregates all content types.""" + stats = await get_embedding_stats() + + # Should have structure with by_type and totals + assert "by_type" in stats + assert "totals" in stats + + # Check each content type is present + by_type = stats["by_type"] + assert "STORE_AGENT" in by_type + assert "BLOCK" in by_type + assert "DOCUMENTATION" in by_type + + # Check totals are aggregated + totals = stats["totals"] + assert totals["total"] >= 0 + assert totals["with_embeddings"] >= 0 + assert totals["without_embeddings"] >= 0 + assert "coverage_percent" in totals + + +@pytest.mark.asyncio(loop_scope="session") +@patch("backend.api.features.store.embeddings.generate_embedding") +async def test_ensure_content_embedding_blocks(mock_generate): + """Test creating embeddings for blocks (mocked OpenAI).""" + # Mock OpenAI to return fake 
embedding + mock_generate.return_value = [0.1] * EMBEDDING_DIM + + # Get one block without embedding + handler = BlockHandler() + items = await handler.get_missing_items(batch_size=1) + + if not items: + pytest.skip("No blocks without embeddings") + + item = items[0] + + # Try to create embedding (OpenAI mocked) + result = await ensure_content_embedding( + content_type=item.content_type, + content_id=item.content_id, + searchable_text=item.searchable_text, + metadata=item.metadata, + user_id=item.user_id, + ) + + # Should succeed with mocked OpenAI + assert result is True + mock_generate.assert_called_once() + + +@pytest.mark.asyncio(loop_scope="session") +@patch("backend.api.features.store.embeddings.generate_embedding") +async def test_backfill_all_content_types_dry_run(mock_generate): + """Test backfill_all_content_types processes all handlers in order.""" + # Mock OpenAI to return fake embedding + mock_generate.return_value = [0.1] * EMBEDDING_DIM + + # Run backfill with batch_size=1 to process max 1 per type + result = await backfill_all_content_types(batch_size=1) + + # Should have results for all content types + assert "by_type" in result + assert "totals" in result + + by_type = result["by_type"] + assert "BLOCK" in by_type + assert "STORE_AGENT" in by_type + assert "DOCUMENTATION" in by_type + + # Each type should have correct structure + for content_type, type_result in by_type.items(): + assert "processed" in type_result + assert "success" in type_result + assert "failed" in type_result + + # Totals should aggregate + totals = result["totals"] + assert totals["processed"] >= 0 + assert totals["success"] >= 0 + assert totals["failed"] >= 0 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_content_handler_registry(): + """Test all handlers are registered in correct order.""" + from prisma.enums import ContentType + + # All three types should be registered + assert ContentType.STORE_AGENT in CONTENT_HANDLERS + assert ContentType.BLOCK in 
CONTENT_HANDLERS + assert ContentType.DOCUMENTATION in CONTENT_HANDLERS + + # Check handler types + assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler) + assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler) + assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler) diff --git a/autogpt_platform/backend/backend/api/features/store/content_handlers_test.py b/autogpt_platform/backend/backend/api/features/store/content_handlers_test.py new file mode 100644 index 0000000000..83c8ee3a4a --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/store/content_handlers_test.py @@ -0,0 +1,324 @@ +""" +E2E tests for content handlers (blocks, store agents, documentation). + +Tests the full flow: discovering content → generating embeddings → storing. +""" + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from prisma.enums import ContentType + +from backend.api.features.store.content_handlers import ( + CONTENT_HANDLERS, + BlockHandler, + DocumentationHandler, + StoreAgentHandler, +) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_store_agent_handler_get_missing_items(mocker): + """Test StoreAgentHandler fetches approved agents without embeddings.""" + handler = StoreAgentHandler() + + # Mock database query + mock_missing = [ + { + "id": "agent-1", + "name": "Test Agent", + "description": "A test agent", + "subHeading": "Test heading", + "categories": ["AI", "Testing"], + } + ] + + with patch( + "backend.api.features.store.content_handlers.query_raw_with_schema", + return_value=mock_missing, + ): + items = await handler.get_missing_items(batch_size=10) + + assert len(items) == 1 + assert items[0].content_id == "agent-1" + assert items[0].content_type == ContentType.STORE_AGENT + assert "Test Agent" in items[0].searchable_text + assert "A test agent" in items[0].searchable_text + assert items[0].metadata["name"] == "Test Agent" + assert 
items[0].user_id is None + + +@pytest.mark.asyncio(loop_scope="session") +async def test_store_agent_handler_get_stats(mocker): + """Test StoreAgentHandler returns correct stats.""" + handler = StoreAgentHandler() + + # Mock approved count query + mock_approved = [{"count": 50}] + # Mock embedded count query + mock_embedded = [{"count": 30}] + + with patch( + "backend.api.features.store.content_handlers.query_raw_with_schema", + side_effect=[mock_approved, mock_embedded], + ): + stats = await handler.get_stats() + + assert stats["total"] == 50 + assert stats["with_embeddings"] == 30 + assert stats["without_embeddings"] == 20 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_block_handler_get_missing_items(mocker): + """Test BlockHandler discovers blocks without embeddings.""" + handler = BlockHandler() + + # Mock get_blocks to return test blocks + mock_block_class = MagicMock() + mock_block_instance = MagicMock() + mock_block_instance.name = "Calculator Block" + mock_block_instance.description = "Performs calculations" + mock_block_instance.categories = [MagicMock(value="MATH")] + mock_block_instance.input_schema.model_json_schema.return_value = { + "properties": {"expression": {"description": "Math expression to evaluate"}} + } + mock_block_class.return_value = mock_block_instance + + mock_blocks = {"block-uuid-1": mock_block_class} + + # Mock existing embeddings query (no embeddings exist) + mock_existing = [] + + with patch( + "backend.data.block.get_blocks", + return_value=mock_blocks, + ): + with patch( + "backend.api.features.store.content_handlers.query_raw_with_schema", + return_value=mock_existing, + ): + items = await handler.get_missing_items(batch_size=10) + + assert len(items) == 1 + assert items[0].content_id == "block-uuid-1" + assert items[0].content_type == ContentType.BLOCK + assert "Calculator Block" in items[0].searchable_text + assert "Performs calculations" in items[0].searchable_text + assert "MATH" in items[0].searchable_text + 
assert "expression: Math expression" in items[0].searchable_text + assert items[0].user_id is None + + +@pytest.mark.asyncio(loop_scope="session") +async def test_block_handler_get_stats(mocker): + """Test BlockHandler returns correct stats.""" + handler = BlockHandler() + + # Mock get_blocks + mock_blocks = { + "block-1": MagicMock(), + "block-2": MagicMock(), + "block-3": MagicMock(), + } + + # Mock embedded count query (2 blocks have embeddings) + mock_embedded = [{"count": 2}] + + with patch( + "backend.data.block.get_blocks", + return_value=mock_blocks, + ): + with patch( + "backend.api.features.store.content_handlers.query_raw_with_schema", + return_value=mock_embedded, + ): + stats = await handler.get_stats() + + assert stats["total"] == 3 + assert stats["with_embeddings"] == 2 + assert stats["without_embeddings"] == 1 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_documentation_handler_get_missing_items(tmp_path, mocker): + """Test DocumentationHandler discovers docs without embeddings.""" + handler = DocumentationHandler() + + # Create temporary docs directory with test files + docs_root = tmp_path / "docs" + docs_root.mkdir() + + (docs_root / "guide.md").write_text("# Getting Started\n\nThis is a guide.") + (docs_root / "api.mdx").write_text("# API Reference\n\nAPI documentation.") + + # Mock _get_docs_root to return temp dir + with patch.object(handler, "_get_docs_root", return_value=docs_root): + # Mock existing embeddings query (no embeddings exist) + with patch( + "backend.api.features.store.content_handlers.query_raw_with_schema", + return_value=[], + ): + items = await handler.get_missing_items(batch_size=10) + + assert len(items) == 2 + + # Check guide.md + guide_item = next( + (item for item in items if item.content_id == "guide.md"), None + ) + assert guide_item is not None + assert guide_item.content_type == ContentType.DOCUMENTATION + assert "Getting Started" in guide_item.searchable_text + assert "This is a guide" in 
guide_item.searchable_text + assert guide_item.metadata["title"] == "Getting Started" + assert guide_item.user_id is None + + # Check api.mdx + api_item = next( + (item for item in items if item.content_id == "api.mdx"), None + ) + assert api_item is not None + assert "API Reference" in api_item.searchable_text + + +@pytest.mark.asyncio(loop_scope="session") +async def test_documentation_handler_get_stats(tmp_path, mocker): + """Test DocumentationHandler returns correct stats.""" + handler = DocumentationHandler() + + # Create temporary docs directory + docs_root = tmp_path / "docs" + docs_root.mkdir() + (docs_root / "doc1.md").write_text("# Doc 1") + (docs_root / "doc2.md").write_text("# Doc 2") + (docs_root / "doc3.mdx").write_text("# Doc 3") + + # Mock embedded count query (1 doc has embedding) + mock_embedded = [{"count": 1}] + + with patch.object(handler, "_get_docs_root", return_value=docs_root): + with patch( + "backend.api.features.store.content_handlers.query_raw_with_schema", + return_value=mock_embedded, + ): + stats = await handler.get_stats() + + assert stats["total"] == 3 + assert stats["with_embeddings"] == 1 + assert stats["without_embeddings"] == 2 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_documentation_handler_title_extraction(tmp_path): + """Test DocumentationHandler extracts title from markdown heading.""" + handler = DocumentationHandler() + + # Test with heading + doc_with_heading = tmp_path / "with_heading.md" + doc_with_heading.write_text("# My Title\n\nContent here") + title, content = handler._extract_title_and_content(doc_with_heading) + assert title == "My Title" + assert "# My Title" not in content + assert "Content here" in content + + # Test without heading + doc_without_heading = tmp_path / "no-heading.md" + doc_without_heading.write_text("Just content, no heading") + title, content = handler._extract_title_and_content(doc_without_heading) + assert title == "No Heading" # Uses filename + assert "Just content" in 
content + + +@pytest.mark.asyncio(loop_scope="session") +async def test_content_handlers_registry(): + """Test all content types are registered.""" + assert ContentType.STORE_AGENT in CONTENT_HANDLERS + assert ContentType.BLOCK in CONTENT_HANDLERS + assert ContentType.DOCUMENTATION in CONTENT_HANDLERS + + assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler) + assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler) + assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_block_handler_handles_missing_attributes(): + """Test BlockHandler gracefully handles blocks with missing attributes.""" + handler = BlockHandler() + + # Mock block with minimal attributes + mock_block_class = MagicMock() + mock_block_instance = MagicMock() + mock_block_instance.name = "Minimal Block" + # No description, categories, or schema + del mock_block_instance.description + del mock_block_instance.categories + del mock_block_instance.input_schema + mock_block_class.return_value = mock_block_instance + + mock_blocks = {"block-minimal": mock_block_class} + + with patch( + "backend.data.block.get_blocks", + return_value=mock_blocks, + ): + with patch( + "backend.api.features.store.content_handlers.query_raw_with_schema", + return_value=[], + ): + items = await handler.get_missing_items(batch_size=10) + + assert len(items) == 1 + assert items[0].searchable_text == "Minimal Block" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_block_handler_skips_failed_blocks(): + """Test BlockHandler skips blocks that fail to instantiate.""" + handler = BlockHandler() + + # Mock one good block and one bad block + good_block = MagicMock() + good_instance = MagicMock() + good_instance.name = "Good Block" + good_instance.description = "Works fine" + good_instance.categories = [] + good_block.return_value = good_instance + + bad_block = MagicMock() + 
bad_block.side_effect = Exception("Instantiation failed") + + mock_blocks = {"good-block": good_block, "bad-block": bad_block} + + with patch( + "backend.data.block.get_blocks", + return_value=mock_blocks, + ): + with patch( + "backend.api.features.store.content_handlers.query_raw_with_schema", + return_value=[], + ): + items = await handler.get_missing_items(batch_size=10) + + # Should only get the good block + assert len(items) == 1 + assert items[0].content_id == "good-block" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_documentation_handler_missing_docs_directory(): + """Test DocumentationHandler handles missing docs directory gracefully.""" + handler = DocumentationHandler() + + # Mock _get_docs_root to return non-existent path + fake_path = Path("/nonexistent/docs") + with patch.object(handler, "_get_docs_root", return_value=fake_path): + items = await handler.get_missing_items(batch_size=10) + assert items == [] + + stats = await handler.get_stats() + assert stats["total"] == 0 + assert stats["with_embeddings"] == 0 + assert stats["without_embeddings"] == 0 diff --git a/autogpt_platform/backend/backend/api/features/store/embeddings.py b/autogpt_platform/backend/backend/api/features/store/embeddings.py index 70f4360c0c..84742817c7 100644 --- a/autogpt_platform/backend/backend/api/features/store/embeddings.py +++ b/autogpt_platform/backend/backend/api/features/store/embeddings.py @@ -14,6 +14,7 @@ import prisma from prisma.enums import ContentType from tiktoken import encoding_for_model +from backend.api.features.store.content_handlers import CONTENT_HANDLERS from backend.data.db import execute_raw_with_schema, query_raw_with_schema from backend.util.clients import get_openai_client from backend.util.json import dumps @@ -23,6 +24,9 @@ logger = logging.getLogger(__name__) # OpenAI embedding model configuration EMBEDDING_MODEL = "text-embedding-3-small" +# Embedding dimension for the model above +# text-embedding-3-small: 1536, 
text-embedding-3-large: 3072 +EMBEDDING_DIM = 1536 # OpenAI embedding token limit (8,191 with 1 token buffer for safety) EMBEDDING_MAX_TOKENS = 8191 @@ -369,55 +373,69 @@ async def delete_content_embedding( async def get_embedding_stats() -> dict[str, Any]: """ - Get statistics about embedding coverage. + Get statistics about embedding coverage for all content types. - Returns counts of: - - Total approved listing versions - - Versions with embeddings - - Versions without embeddings + Returns stats per content type and overall totals. """ try: - # Count approved versions - approved_result = await query_raw_with_schema( - """ - SELECT COUNT(*) as count - FROM {schema_prefix}"StoreListingVersion" - WHERE "submissionStatus" = 'APPROVED' - AND "isDeleted" = false - """ - ) - total_approved = approved_result[0]["count"] if approved_result else 0 + stats_by_type = {} + total_items = 0 + total_with_embeddings = 0 + total_without_embeddings = 0 - # Count versions with embeddings - embedded_result = await query_raw_with_schema( - """ - SELECT COUNT(*) as count - FROM {schema_prefix}"StoreListingVersion" slv - JOIN {schema_prefix}"UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType" - WHERE slv."submissionStatus" = 'APPROVED' - AND slv."isDeleted" = false - """ - ) - with_embeddings = embedded_result[0]["count"] if embedded_result else 0 + # Aggregate stats from all handlers + for content_type, handler in CONTENT_HANDLERS.items(): + try: + stats = await handler.get_stats() + stats_by_type[content_type.value] = { + "total": stats["total"], + "with_embeddings": stats["with_embeddings"], + "without_embeddings": stats["without_embeddings"], + "coverage_percent": ( + round(stats["with_embeddings"] / stats["total"] * 100, 1) + if stats["total"] > 0 + else 0 + ), + } + + total_items += stats["total"] + total_with_embeddings += stats["with_embeddings"] + total_without_embeddings += stats["without_embeddings"] + 
+ except Exception as e: + logger.error(f"Failed to get stats for {content_type.value}: {e}") + stats_by_type[content_type.value] = { + "total": 0, + "with_embeddings": 0, + "without_embeddings": 0, + "coverage_percent": 0, + "error": str(e), + } return { - "total_approved": total_approved, - "with_embeddings": with_embeddings, - "without_embeddings": total_approved - with_embeddings, - "coverage_percent": ( - round(with_embeddings / total_approved * 100, 1) - if total_approved > 0 - else 0 - ), + "by_type": stats_by_type, + "totals": { + "total": total_items, + "with_embeddings": total_with_embeddings, + "without_embeddings": total_without_embeddings, + "coverage_percent": ( + round(total_with_embeddings / total_items * 100, 1) + if total_items > 0 + else 0 + ), + }, } except Exception as e: logger.error(f"Failed to get embedding stats: {e}") return { - "total_approved": 0, - "with_embeddings": 0, - "without_embeddings": 0, - "coverage_percent": 0, + "by_type": {}, + "totals": { + "total": 0, + "with_embeddings": 0, + "without_embeddings": 0, + "coverage_percent": 0, + }, "error": str(e), } @@ -426,73 +444,118 @@ async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]: """ Generate embeddings for approved listings that don't have them. + BACKWARD COMPATIBILITY: Maintained for existing usage. + This now delegates to backfill_all_content_types() to process all content types. 
+ Args: - batch_size: Number of embeddings to generate in one call + batch_size: Number of embeddings to generate per content type Returns: - Dict with success/failure counts + Dict with success/failure counts aggregated across all content types """ - try: - # Find approved versions without embeddings - missing = await query_raw_with_schema( - """ - SELECT - slv.id, - slv.name, - slv.description, - slv."subHeading", - slv.categories - FROM {schema_prefix}"StoreListingVersion" slv - LEFT JOIN {schema_prefix}"UnifiedContentEmbedding" uce - ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType" - WHERE slv."submissionStatus" = 'APPROVED' - AND slv."isDeleted" = false - AND uce."contentId" IS NULL - LIMIT $1 - """, - batch_size, - ) + # Delegate to the new generic backfill system + result = await backfill_all_content_types(batch_size) - if not missing: - return { + # Return in the old format for backward compatibility + return result["totals"] + + +async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]: + """ + Generate embeddings for all content types using registered handlers. + + Processes content types in order: BLOCK → STORE_AGENT → DOCUMENTATION. + This ensures foundational content (blocks) are searchable first. 
+ + Args: + batch_size: Number of embeddings to generate per content type + + Returns: + Dict with stats per content type and overall totals + """ + results_by_type = {} + total_processed = 0 + total_success = 0 + total_failed = 0 + + # Process content types in explicit order + processing_order = [ + ContentType.BLOCK, + ContentType.STORE_AGENT, + ContentType.DOCUMENTATION, + ] + + for content_type in processing_order: + handler = CONTENT_HANDLERS.get(content_type) + if not handler: + logger.warning(f"No handler registered for {content_type.value}") + continue + try: + logger.info(f"Processing {content_type.value} content type...") + + # Get missing items from handler + missing_items = await handler.get_missing_items(batch_size) + + if not missing_items: + results_by_type[content_type.value] = { + "processed": 0, + "success": 0, + "failed": 0, + "message": "No missing embeddings", + } + continue + + # Process embeddings concurrently for better performance + embedding_tasks = [ + ensure_content_embedding( + content_type=item.content_type, + content_id=item.content_id, + searchable_text=item.searchable_text, + metadata=item.metadata, + user_id=item.user_id, + ) + for item in missing_items + ] + + results = await asyncio.gather(*embedding_tasks, return_exceptions=True) + + success = sum(1 for result in results if result is True) + failed = len(results) - success + + results_by_type[content_type.value] = { + "processed": len(missing_items), + "success": success, + "failed": failed, + "message": f"Backfilled {success} embeddings, {failed} failed", + } + + total_processed += len(missing_items) + total_success += success + total_failed += failed + + logger.info( + f"{content_type.value}: processed {len(missing_items)}, " + f"success {success}, failed {failed}" + ) + + except Exception as e: + logger.error(f"Failed to process {content_type.value}: {e}") + results_by_type[content_type.value] = { "processed": 0, "success": 0, "failed": 0, - "message": "No missing 
embeddings", + "error": str(e), } - # Process embeddings concurrently for better performance - embedding_tasks = [ - ensure_embedding( - version_id=row["id"], - name=row["name"], - description=row["description"], - sub_heading=row["subHeading"], - categories=row["categories"] or [], - ) - for row in missing - ] - - results = await asyncio.gather(*embedding_tasks, return_exceptions=True) - - success = sum(1 for result in results if result is True) - failed = len(results) - success - - return { - "processed": len(missing), - "success": success, - "failed": failed, - "message": f"Backfilled {success} embeddings, {failed} failed", - } - - except Exception as e: - logger.error(f"Failed to backfill embeddings: {e}") - return { - "processed": 0, - "success": 0, - "failed": 0, - "error": str(e), - } + return { + "by_type": results_by_type, + "totals": { + "processed": total_processed, + "success": total_success, + "failed": total_failed, + "message": f"Overall: {total_success} succeeded, {total_failed} failed", + }, + } async def embed_query(query: str) -> list[float] | None: @@ -566,3 +629,334 @@ async def ensure_content_embedding( except Exception as e: logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}") return False + + +async def cleanup_orphaned_embeddings() -> dict[str, Any]: + """ + Clean up embeddings for content that no longer exists or is no longer valid. 
+ + Compares current content with embeddings in database and removes orphaned records: + - STORE_AGENT: Removes embeddings for rejected/deleted store listings + - BLOCK: Removes embeddings for blocks no longer registered + - DOCUMENTATION: Removes embeddings for deleted doc files + + Returns: + Dict with cleanup statistics per content type + """ + results_by_type = {} + total_deleted = 0 + + # Cleanup orphaned embeddings for all content types + cleanup_types = [ + ContentType.STORE_AGENT, + ContentType.BLOCK, + ContentType.DOCUMENTATION, + ] + + for content_type in cleanup_types: + try: + handler = CONTENT_HANDLERS.get(content_type) + if not handler: + logger.warning(f"No handler registered for {content_type}") + results_by_type[content_type.value] = { + "deleted": 0, + "error": "No handler registered", + } + continue + + # Get all current content IDs from handler + if content_type == ContentType.STORE_AGENT: + # Get IDs of approved store listing versions from non-deleted listings + valid_agents = await query_raw_with_schema( + """ + SELECT slv.id + FROM {schema_prefix}"StoreListingVersion" slv + JOIN {schema_prefix}"StoreListing" sl ON slv."storeListingId" = sl.id + WHERE slv."submissionStatus" = 'APPROVED' + AND slv."isDeleted" = false + AND sl."isDeleted" = false + """, + ) + current_ids = {row["id"] for row in valid_agents} + elif content_type == ContentType.BLOCK: + from backend.data.block import get_blocks + + current_ids = set(get_blocks().keys()) + elif content_type == ContentType.DOCUMENTATION: + from pathlib import Path + + # embeddings.py is at: backend/backend/api/features/store/embeddings.py + # Need to go up to project root then into docs/ + this_file = Path(__file__) + project_root = ( + this_file.parent.parent.parent.parent.parent.parent.parent + ) + docs_root = project_root / "docs" + if docs_root.exists(): + all_docs = list(docs_root.rglob("*.md")) + list( + docs_root.rglob("*.mdx") + ) + current_ids = {str(doc.relative_to(docs_root)) for doc in 
all_docs} + else: + current_ids = set() + else: + # Skip unknown content types to avoid accidental deletion + logger.warning( + f"Skipping cleanup for unknown content type: {content_type}" + ) + results_by_type[content_type.value] = { + "deleted": 0, + "error": "Unknown content type - skipped for safety", + } + continue + + # Get all embedding IDs from database + db_embeddings = await query_raw_with_schema( + """ + SELECT "contentId" + FROM {schema_prefix}"UnifiedContentEmbedding" + WHERE "contentType" = $1::{schema_prefix}"ContentType" + """, + content_type, + ) + + db_ids = {row["contentId"] for row in db_embeddings} + + # Find orphaned embeddings (in DB but not in current content) + orphaned_ids = db_ids - current_ids + + if not orphaned_ids: + logger.info(f"{content_type.value}: No orphaned embeddings found") + results_by_type[content_type.value] = { + "deleted": 0, + "message": "No orphaned embeddings", + } + continue + + # Delete orphaned embeddings in batch for better performance + orphaned_list = list(orphaned_ids) + try: + await execute_raw_with_schema( + """ + DELETE FROM {schema_prefix}"UnifiedContentEmbedding" + WHERE "contentType" = $1::{schema_prefix}"ContentType" + AND "contentId" = ANY($2::text[]) + """, + content_type, + orphaned_list, + ) + deleted = len(orphaned_list) + except Exception as e: + logger.error(f"Failed to batch delete orphaned embeddings: {e}") + deleted = 0 + + logger.info( + f"{content_type.value}: Deleted {deleted}/{len(orphaned_ids)} orphaned embeddings" + ) + results_by_type[content_type.value] = { + "deleted": deleted, + "orphaned": len(orphaned_ids), + "message": f"Deleted {deleted} orphaned embeddings", + } + + total_deleted += deleted + + except Exception as e: + logger.error(f"Failed to cleanup {content_type.value}: {e}") + results_by_type[content_type.value] = { + "deleted": 0, + "error": str(e), + } + + return { + "by_type": results_by_type, + "totals": { + "deleted": total_deleted, + "message": f"Deleted {total_deleted} 
orphaned embeddings", + }, + } + + +async def semantic_search( + query: str, + content_types: list[ContentType] | None = None, + user_id: str | None = None, + limit: int = 20, + min_similarity: float = 0.5, +) -> list[dict[str, Any]]: + """ + Semantic search across content types using embeddings. + + Performs vector similarity search on UnifiedContentEmbedding table. + Used directly for blocks/docs/library agents, or as the semantic component + within hybrid_search for store agents. + + If embedding generation fails, falls back to lexical search on searchableText. + + Args: + query: Search query string + content_types: List of ContentType to search. Defaults to [BLOCK, STORE_AGENT, DOCUMENTATION] + user_id: Optional user ID for searching private content (library agents) + limit: Maximum number of results to return (default: 20) + min_similarity: Minimum cosine similarity threshold (0-1, default: 0.5) + + Returns: + List of search results with the following structure: + [ + { + "content_id": str, + "content_type": str, # "BLOCK", "STORE_AGENT", "DOCUMENTATION", or "LIBRARY_AGENT" + "searchable_text": str, + "metadata": dict, + "similarity": float, # Cosine similarity score (0-1) + }, + ... 
+ ] + + Examples: + # Search blocks only + results = await semantic_search("calculate", content_types=[ContentType.BLOCK]) + + # Search blocks and documentation + results = await semantic_search( + "how to use API", + content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION] + ) + + # Search all public content (default) + results = await semantic_search("AI agent") + + # Search user's library agents + results = await semantic_search( + "my custom agent", + content_types=[ContentType.LIBRARY_AGENT], + user_id="user123" + ) + """ + # Default to searching all public content types + if content_types is None: + content_types = [ + ContentType.BLOCK, + ContentType.STORE_AGENT, + ContentType.DOCUMENTATION, + ] + + # Validate inputs + if not content_types: + return [] # Empty content_types would cause invalid SQL (IN ()) + + query = query.strip() + if not query: + return [] + + if limit < 1: + limit = 1 + if limit > 100: + limit = 100 + + # Generate query embedding + query_embedding = await embed_query(query) + + if query_embedding is not None: + # Semantic search with embeddings + embedding_str = embedding_to_vector_string(query_embedding) + + # Build params in order: limit, then user_id (if provided), then content types + params: list[Any] = [limit] + user_filter = "" + if user_id is not None: + user_filter = 'AND "userId" = ${}'.format(len(params) + 1) + params.append(user_id) + + # Add content type parameters and build placeholders dynamically + content_type_start_idx = len(params) + 1 + content_type_placeholders = ", ".join( + f'${content_type_start_idx + i}::{{{{schema_prefix}}}}"ContentType"' + for i in range(len(content_types)) + ) + params.extend([ct.value for ct in content_types]) + + sql = f""" + SELECT + "contentId" as content_id, + "contentType" as content_type, + "searchableText" as searchable_text, + metadata, + 1 - (embedding <=> '{embedding_str}'::vector) as similarity + FROM {{{{schema_prefix}}}}"UnifiedContentEmbedding" + WHERE "contentType" IN 
({content_type_placeholders}) + {user_filter} + AND 1 - (embedding <=> '{embedding_str}'::vector) >= ${len(params) + 1} + ORDER BY similarity DESC + LIMIT $1 + """ + params.append(min_similarity) + + try: + results = await query_raw_with_schema( + sql, *params, set_public_search_path=True + ) + return [ + { + "content_id": row["content_id"], + "content_type": row["content_type"], + "searchable_text": row["searchable_text"], + "metadata": row["metadata"], + "similarity": float(row["similarity"]), + } + for row in results + ] + except Exception as e: + logger.error(f"Semantic search failed: {e}") + # Fall through to lexical search below + + # Fallback to lexical search if embeddings unavailable + logger.warning("Falling back to lexical search (embeddings unavailable)") + + params_lexical: list[Any] = [limit] + user_filter = "" + if user_id is not None: + user_filter = 'AND "userId" = ${}'.format(len(params_lexical) + 1) + params_lexical.append(user_id) + + # Add content type parameters and build placeholders dynamically + content_type_start_idx = len(params_lexical) + 1 + content_type_placeholders_lexical = ", ".join( + f'${content_type_start_idx + i}::{{{{schema_prefix}}}}"ContentType"' + for i in range(len(content_types)) + ) + params_lexical.extend([ct.value for ct in content_types]) + + sql_lexical = f""" + SELECT + "contentId" as content_id, + "contentType" as content_type, + "searchableText" as searchable_text, + metadata, + 0.0 as similarity + FROM {{{{schema_prefix}}}}"UnifiedContentEmbedding" + WHERE "contentType" IN ({content_type_placeholders_lexical}) + {user_filter} + AND "searchableText" ILIKE ${len(params_lexical) + 1} + ORDER BY "updatedAt" DESC + LIMIT $1 + """ + params_lexical.append(f"%{query}%") + + try: + results = await query_raw_with_schema( + sql_lexical, *params_lexical, set_public_search_path=True + ) + return [ + { + "content_id": row["content_id"], + "content_type": row["content_type"], + "searchable_text": row["searchable_text"], + 
"metadata": row["metadata"], + "similarity": 0.0, # Lexical search doesn't provide similarity + } + for row in results + ] + except Exception as e: + logger.error(f"Lexical search failed: {e}") + return [] diff --git a/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py b/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py new file mode 100644 index 0000000000..bae5b97cd6 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py @@ -0,0 +1,666 @@ +""" +End-to-end database tests for embeddings and hybrid search. + +These tests hit the actual database to verify SQL queries work correctly. +Tests cover: +1. Embedding storage (store_content_embedding) +2. Embedding retrieval (get_content_embedding) +3. Embedding deletion (delete_content_embedding) +4. Unified hybrid search across content types +5. Store agent hybrid search +""" + +import uuid +from typing import AsyncGenerator + +import pytest +from prisma.enums import ContentType + +from backend.api.features.store import embeddings +from backend.api.features.store.embeddings import EMBEDDING_DIM +from backend.api.features.store.hybrid_search import ( + hybrid_search, + unified_hybrid_search, +) + +# ============================================================================ +# Test Fixtures +# ============================================================================ + + +@pytest.fixture +def test_content_id() -> str: + """Generate unique content ID for test isolation.""" + return f"test-content-{uuid.uuid4()}" + + +@pytest.fixture +def test_user_id() -> str: + """Generate unique user ID for test isolation.""" + return f"test-user-{uuid.uuid4()}" + + +@pytest.fixture +def mock_embedding() -> list[float]: + """Generate a mock embedding vector.""" + # Create a normalized embedding vector + import math + + raw = [float(i % 10) / 10.0 for i in range(EMBEDDING_DIM)] + # Normalize to unit length (required for cosine similarity) + magnitude 
= math.sqrt(sum(x * x for x in raw)) + return [x / magnitude for x in raw] + + +@pytest.fixture +def similar_embedding() -> list[float]: + """Generate an embedding similar to mock_embedding.""" + import math + + # Similar but slightly different values + raw = [float(i % 10) / 10.0 + 0.01 for i in range(EMBEDDING_DIM)] + magnitude = math.sqrt(sum(x * x for x in raw)) + return [x / magnitude for x in raw] + + +@pytest.fixture +def different_embedding() -> list[float]: + """Generate an embedding very different from mock_embedding.""" + import math + + # Reversed pattern to be maximally different + raw = [float((EMBEDDING_DIM - i) % 10) / 10.0 for i in range(EMBEDDING_DIM)] + magnitude = math.sqrt(sum(x * x for x in raw)) + return [x / magnitude for x in raw] + + +@pytest.fixture +async def cleanup_embeddings( + server, +) -> AsyncGenerator[list[tuple[ContentType, str, str | None]], None]: + """ + Fixture that tracks created embeddings and cleans them up after tests. + + Yields a list to which tests can append (content_type, content_id, user_id) tuples. 
+ """ + created_embeddings: list[tuple[ContentType, str, str | None]] = [] + yield created_embeddings + + # Cleanup all created embeddings + for content_type, content_id, user_id in created_embeddings: + try: + await embeddings.delete_content_embedding(content_type, content_id, user_id) + except Exception: + pass # Ignore cleanup errors + + +# ============================================================================ +# store_content_embedding Tests +# ============================================================================ + + +@pytest.mark.asyncio(loop_scope="session") +async def test_store_content_embedding_store_agent( + server, + test_content_id: str, + mock_embedding: list[float], + cleanup_embeddings: list, +): + """Test storing embedding for STORE_AGENT content type.""" + # Track for cleanup + cleanup_embeddings.append((ContentType.STORE_AGENT, test_content_id, None)) + + result = await embeddings.store_content_embedding( + content_type=ContentType.STORE_AGENT, + content_id=test_content_id, + embedding=mock_embedding, + searchable_text="AI assistant for productivity tasks", + metadata={"name": "Test Agent", "categories": ["productivity"]}, + user_id=None, # Store agents are public + ) + + assert result is True + + # Verify it was stored + stored = await embeddings.get_content_embedding( + ContentType.STORE_AGENT, test_content_id, user_id=None + ) + assert stored is not None + assert stored["contentId"] == test_content_id + assert stored["contentType"] == "STORE_AGENT" + assert stored["searchableText"] == "AI assistant for productivity tasks" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_store_content_embedding_block( + server, + test_content_id: str, + mock_embedding: list[float], + cleanup_embeddings: list, +): + """Test storing embedding for BLOCK content type.""" + cleanup_embeddings.append((ContentType.BLOCK, test_content_id, None)) + + result = await embeddings.store_content_embedding( + content_type=ContentType.BLOCK, + 
content_id=test_content_id, + embedding=mock_embedding, + searchable_text="HTTP request block for API calls", + metadata={"name": "HTTP Request Block"}, + user_id=None, # Blocks are public + ) + + assert result is True + + stored = await embeddings.get_content_embedding( + ContentType.BLOCK, test_content_id, user_id=None + ) + assert stored is not None + assert stored["contentType"] == "BLOCK" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_store_content_embedding_documentation( + server, + test_content_id: str, + mock_embedding: list[float], + cleanup_embeddings: list, +): + """Test storing embedding for DOCUMENTATION content type.""" + cleanup_embeddings.append((ContentType.DOCUMENTATION, test_content_id, None)) + + result = await embeddings.store_content_embedding( + content_type=ContentType.DOCUMENTATION, + content_id=test_content_id, + embedding=mock_embedding, + searchable_text="Getting started guide for AutoGPT platform", + metadata={"title": "Getting Started", "url": "/docs/getting-started"}, + user_id=None, # Docs are public + ) + + assert result is True + + stored = await embeddings.get_content_embedding( + ContentType.DOCUMENTATION, test_content_id, user_id=None + ) + assert stored is not None + assert stored["contentType"] == "DOCUMENTATION" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_store_content_embedding_upsert( + server, + test_content_id: str, + mock_embedding: list[float], + cleanup_embeddings: list, +): + """Test that storing embedding twice updates instead of duplicates.""" + cleanup_embeddings.append((ContentType.BLOCK, test_content_id, None)) + + # Store first time + result1 = await embeddings.store_content_embedding( + content_type=ContentType.BLOCK, + content_id=test_content_id, + embedding=mock_embedding, + searchable_text="Original text", + metadata={"version": 1}, + user_id=None, + ) + assert result1 is True + + # Store again with different text (upsert) + result2 = await 
embeddings.store_content_embedding( + content_type=ContentType.BLOCK, + content_id=test_content_id, + embedding=mock_embedding, + searchable_text="Updated text", + metadata={"version": 2}, + user_id=None, + ) + assert result2 is True + + # Verify only one record with updated text + stored = await embeddings.get_content_embedding( + ContentType.BLOCK, test_content_id, user_id=None + ) + assert stored is not None + assert stored["searchableText"] == "Updated text" + + +# ============================================================================ +# get_content_embedding Tests +# ============================================================================ + + +@pytest.mark.asyncio(loop_scope="session") +async def test_get_content_embedding_not_found(server): + """Test retrieving non-existent embedding returns None.""" + result = await embeddings.get_content_embedding( + ContentType.STORE_AGENT, "non-existent-id", user_id=None + ) + assert result is None + + +@pytest.mark.asyncio(loop_scope="session") +async def test_get_content_embedding_with_metadata( + server, + test_content_id: str, + mock_embedding: list[float], + cleanup_embeddings: list, +): + """Test that metadata is correctly stored and retrieved.""" + cleanup_embeddings.append((ContentType.STORE_AGENT, test_content_id, None)) + + metadata = { + "name": "Test Agent", + "subHeading": "A test agent", + "categories": ["ai", "productivity"], + "customField": 123, + } + + await embeddings.store_content_embedding( + content_type=ContentType.STORE_AGENT, + content_id=test_content_id, + embedding=mock_embedding, + searchable_text="test", + metadata=metadata, + user_id=None, + ) + + stored = await embeddings.get_content_embedding( + ContentType.STORE_AGENT, test_content_id, user_id=None + ) + + assert stored is not None + assert stored["metadata"]["name"] == "Test Agent" + assert stored["metadata"]["categories"] == ["ai", "productivity"] + assert stored["metadata"]["customField"] == 123 + + +# 
============================================================================ +# delete_content_embedding Tests +# ============================================================================ + + +@pytest.mark.asyncio(loop_scope="session") +async def test_delete_content_embedding( + server, + test_content_id: str, + mock_embedding: list[float], +): + """Test deleting embedding removes it from database.""" + # Store embedding + await embeddings.store_content_embedding( + content_type=ContentType.BLOCK, + content_id=test_content_id, + embedding=mock_embedding, + searchable_text="To be deleted", + metadata=None, + user_id=None, + ) + + # Verify it exists + stored = await embeddings.get_content_embedding( + ContentType.BLOCK, test_content_id, user_id=None + ) + assert stored is not None + + # Delete it + result = await embeddings.delete_content_embedding( + ContentType.BLOCK, test_content_id, user_id=None + ) + assert result is True + + # Verify it's gone + stored = await embeddings.get_content_embedding( + ContentType.BLOCK, test_content_id, user_id=None + ) + assert stored is None + + +@pytest.mark.asyncio(loop_scope="session") +async def test_delete_content_embedding_not_found(server): + """Test deleting non-existent embedding doesn't error.""" + result = await embeddings.delete_content_embedding( + ContentType.BLOCK, "non-existent-id", user_id=None + ) + # Should succeed even if nothing to delete + assert result is True + + +# ============================================================================ +# unified_hybrid_search Tests +# ============================================================================ + + +@pytest.mark.asyncio(loop_scope="session") +async def test_unified_hybrid_search_finds_matching_content( + server, + mock_embedding: list[float], + cleanup_embeddings: list, +): + """Test unified search finds content matching the query.""" + # Create unique content IDs + agent_id = f"test-agent-{uuid.uuid4()}" + block_id = f"test-block-{uuid.uuid4()}" + 
doc_id = f"test-doc-{uuid.uuid4()}" + + cleanup_embeddings.append((ContentType.STORE_AGENT, agent_id, None)) + cleanup_embeddings.append((ContentType.BLOCK, block_id, None)) + cleanup_embeddings.append((ContentType.DOCUMENTATION, doc_id, None)) + + # Store embeddings for different content types + await embeddings.store_content_embedding( + content_type=ContentType.STORE_AGENT, + content_id=agent_id, + embedding=mock_embedding, + searchable_text="AI writing assistant for blog posts", + metadata={"name": "Writing Assistant"}, + user_id=None, + ) + + await embeddings.store_content_embedding( + content_type=ContentType.BLOCK, + content_id=block_id, + embedding=mock_embedding, + searchable_text="Text generation block for creative writing", + metadata={"name": "Text Generator"}, + user_id=None, + ) + + await embeddings.store_content_embedding( + content_type=ContentType.DOCUMENTATION, + content_id=doc_id, + embedding=mock_embedding, + searchable_text="How to use writing blocks in AutoGPT", + metadata={"title": "Writing Guide"}, + user_id=None, + ) + + # Search for "writing" - should find all three + results, total = await unified_hybrid_search( + query="writing", + page=1, + page_size=20, + ) + + # Should find at least our test content (may find others too) + content_ids = [r["content_id"] for r in results] + assert agent_id in content_ids or total >= 1 # Lexical search should find it + + +@pytest.mark.asyncio(loop_scope="session") +async def test_unified_hybrid_search_filter_by_content_type( + server, + mock_embedding: list[float], + cleanup_embeddings: list, +): + """Test unified search can filter by content type.""" + agent_id = f"test-agent-{uuid.uuid4()}" + block_id = f"test-block-{uuid.uuid4()}" + + cleanup_embeddings.append((ContentType.STORE_AGENT, agent_id, None)) + cleanup_embeddings.append((ContentType.BLOCK, block_id, None)) + + # Store both types with same searchable text + await embeddings.store_content_embedding( + content_type=ContentType.STORE_AGENT, + 
content_id=agent_id, + embedding=mock_embedding, + searchable_text="unique_search_term_xyz123", + metadata={}, + user_id=None, + ) + + await embeddings.store_content_embedding( + content_type=ContentType.BLOCK, + content_id=block_id, + embedding=mock_embedding, + searchable_text="unique_search_term_xyz123", + metadata={}, + user_id=None, + ) + + # Search only for BLOCK type + results, total = await unified_hybrid_search( + query="unique_search_term_xyz123", + content_types=[ContentType.BLOCK], + page=1, + page_size=20, + ) + + # All results should be BLOCK type + for r in results: + assert r["content_type"] == "BLOCK" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_unified_hybrid_search_empty_query(server): + """Test unified search with empty query returns empty results.""" + results, total = await unified_hybrid_search( + query="", + page=1, + page_size=20, + ) + + assert results == [] + assert total == 0 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_unified_hybrid_search_pagination( + server, + mock_embedding: list[float], + cleanup_embeddings: list, +): + """Test unified search pagination works correctly.""" + # Create multiple items + content_ids = [] + for i in range(5): + content_id = f"test-pagination-{uuid.uuid4()}" + content_ids.append(content_id) + cleanup_embeddings.append((ContentType.BLOCK, content_id, None)) + + await embeddings.store_content_embedding( + content_type=ContentType.BLOCK, + content_id=content_id, + embedding=mock_embedding, + searchable_text=f"pagination test item number {i}", + metadata={"index": i}, + user_id=None, + ) + + # Get first page + page1_results, total1 = await unified_hybrid_search( + query="pagination test", + content_types=[ContentType.BLOCK], + page=1, + page_size=2, + ) + + # Get second page + page2_results, total2 = await unified_hybrid_search( + query="pagination test", + content_types=[ContentType.BLOCK], + page=2, + page_size=2, + ) + + # Total should be consistent + assert total1 
== total2 + + # Pages should have different content (if we have enough results) + if len(page1_results) > 0 and len(page2_results) > 0: + page1_ids = {r["content_id"] for r in page1_results} + page2_ids = {r["content_id"] for r in page2_results} + # No overlap between pages + assert page1_ids.isdisjoint(page2_ids) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_unified_hybrid_search_min_score_filtering( + server, + mock_embedding: list[float], + cleanup_embeddings: list, +): + """Test unified search respects min_score threshold.""" + content_id = f"test-minscore-{uuid.uuid4()}" + cleanup_embeddings.append((ContentType.BLOCK, content_id, None)) + + await embeddings.store_content_embedding( + content_type=ContentType.BLOCK, + content_id=content_id, + embedding=mock_embedding, + searchable_text="completely unrelated content about bananas", + metadata={}, + user_id=None, + ) + + # Search with very high min_score - should filter out low relevance + results_high, _ = await unified_hybrid_search( + query="quantum computing algorithms", + content_types=[ContentType.BLOCK], + min_score=0.9, # Very high threshold + page=1, + page_size=20, + ) + + # Search with low min_score + results_low, _ = await unified_hybrid_search( + query="quantum computing algorithms", + content_types=[ContentType.BLOCK], + min_score=0.01, # Very low threshold + page=1, + page_size=20, + ) + + # High threshold should have fewer or equal results + assert len(results_high) <= len(results_low) + + +# ============================================================================ +# hybrid_search (Store Agents) Tests +# ============================================================================ + + +@pytest.mark.asyncio(loop_scope="session") +async def test_hybrid_search_store_agents_sql_valid(server): + """Test that hybrid_search SQL executes without errors.""" + # This test verifies the SQL is syntactically correct + # even if no results are found + results, total = await hybrid_search( + 
query="test agent", + page=1, + page_size=20, + ) + + # Should not raise - verifies SQL is valid + assert isinstance(results, list) + assert isinstance(total, int) + assert total >= 0 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_hybrid_search_with_filters(server): + """Test hybrid_search with various filter options.""" + # Test with all filter types + results, total = await hybrid_search( + query="productivity", + featured=True, + creators=["test-creator"], + category="productivity", + page=1, + page_size=10, + ) + + # Should not raise - verifies filter SQL is valid + assert isinstance(results, list) + assert isinstance(total, int) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_hybrid_search_pagination(server): + """Test hybrid_search pagination.""" + # Page 1 + results1, total1 = await hybrid_search( + query="agent", + page=1, + page_size=5, + ) + + # Page 2 + results2, total2 = await hybrid_search( + query="agent", + page=2, + page_size=5, + ) + + # Verify SQL executes without error + assert isinstance(results1, list) + assert isinstance(results2, list) + assert isinstance(total1, int) + assert isinstance(total2, int) + + # If page 1 has results, total should be > 0 + # Note: total from page 2 may be 0 if no results on that page (COUNT(*) OVER limitation) + if results1: + assert total1 > 0 + + +# ============================================================================ +# SQL Validity Tests (verify queries don't break) +# ============================================================================ + + +@pytest.mark.asyncio(loop_scope="session") +async def test_all_content_types_searchable(server): + """Test that all content types can be searched without SQL errors.""" + for content_type in [ + ContentType.STORE_AGENT, + ContentType.BLOCK, + ContentType.DOCUMENTATION, + ]: + results, total = await unified_hybrid_search( + query="test", + content_types=[content_type], + page=1, + page_size=10, + ) + + # Should not raise + 
assert isinstance(results, list) + assert isinstance(total, int) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_multiple_content_types_searchable(server): + """Test searching multiple content types at once.""" + results, total = await unified_hybrid_search( + query="test", + content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION], + page=1, + page_size=20, + ) + + # Should not raise + assert isinstance(results, list) + assert isinstance(total, int) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_search_all_content_types_default(server): + """Test searching all content types (default behavior).""" + results, total = await unified_hybrid_search( + query="test", + content_types=None, # Should search all + page=1, + page_size=20, + ) + + # Should not raise + assert isinstance(results, list) + assert isinstance(total, int) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/autogpt_platform/backend/backend/api/features/store/embeddings_schema_test.py b/autogpt_platform/backend/backend/api/features/store/embeddings_schema_test.py index 441fd9961a..7ba200fda0 100644 --- a/autogpt_platform/backend/backend/api/features/store/embeddings_schema_test.py +++ b/autogpt_platform/backend/backend/api/features/store/embeddings_schema_test.py @@ -4,12 +4,13 @@ Integration tests for embeddings with schema handling. These tests verify that embeddings operations work correctly across different database schemas. 
""" -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from prisma.enums import ContentType from backend.api.features.store import embeddings +from backend.api.features.store.embeddings import EMBEDDING_DIM # Schema prefix tests removed - functionality moved to db.raw_with_schema() helper @@ -28,7 +29,7 @@ async def test_store_content_embedding_with_schema(): result = await embeddings.store_content_embedding( content_type=ContentType.STORE_AGENT, content_id="test-id", - embedding=[0.1] * 1536, + embedding=[0.1] * EMBEDDING_DIM, searchable_text="test text", metadata={"test": "data"}, user_id=None, @@ -125,84 +126,69 @@ async def test_delete_content_embedding_with_schema(): @pytest.mark.asyncio(loop_scope="session") @pytest.mark.integration async def test_get_embedding_stats_with_schema(): - """Test embedding statistics with proper schema handling.""" - with patch("backend.data.db.get_database_schema") as mock_schema: - mock_schema.return_value = "platform" + """Test embedding statistics with proper schema handling via content handlers.""" + # Mock handler to return stats + mock_handler = MagicMock() + mock_handler.get_stats = AsyncMock( + return_value={ + "total": 100, + "with_embeddings": 80, + "without_embeddings": 20, + } + ) - with patch("prisma.get_client") as mock_get_client: - mock_client = AsyncMock() - # Mock both query results - mock_client.query_raw.side_effect = [ - [{"count": 100}], # total_approved - [{"count": 80}], # with_embeddings - ] - mock_get_client.return_value = mock_client + with patch( + "backend.api.features.store.embeddings.CONTENT_HANDLERS", + {ContentType.STORE_AGENT: mock_handler}, + ): + result = await embeddings.get_embedding_stats() - result = await embeddings.get_embedding_stats() + # Verify handler was called + mock_handler.get_stats.assert_called_once() - # Verify both queries were called - assert mock_client.query_raw.call_count == 2 - - # Get both SQL queries - 
first_call = mock_client.query_raw.call_args_list[0] - second_call = mock_client.query_raw.call_args_list[1] - - first_sql = first_call[0][0] - second_sql = second_call[0][0] - - # Verify schema prefix in both queries - assert '"platform"."StoreListingVersion"' in first_sql - assert '"platform"."StoreListingVersion"' in second_sql - assert '"platform"."UnifiedContentEmbedding"' in second_sql - - # Verify results - assert result["total_approved"] == 100 - assert result["with_embeddings"] == 80 - assert result["without_embeddings"] == 20 - assert result["coverage_percent"] == 80.0 + # Verify new result structure + assert "by_type" in result + assert "totals" in result + assert result["totals"]["total"] == 100 + assert result["totals"]["with_embeddings"] == 80 + assert result["totals"]["without_embeddings"] == 20 + assert result["totals"]["coverage_percent"] == 80.0 @pytest.mark.asyncio(loop_scope="session") @pytest.mark.integration async def test_backfill_missing_embeddings_with_schema(): - """Test backfilling embeddings with proper schema handling.""" - with patch("backend.data.db.get_database_schema") as mock_schema: - mock_schema.return_value = "platform" + """Test backfilling embeddings via content handlers.""" + from backend.api.features.store.content_handlers import ContentItem - with patch("prisma.get_client") as mock_get_client: - mock_client = AsyncMock() - # Mock missing embeddings query - mock_client.query_raw.return_value = [ - { - "id": "version-1", - "name": "Test Agent", - "description": "Test description", - "subHeading": "Test heading", - "categories": ["test"], - } - ] - mock_get_client.return_value = mock_client + # Create mock content item + mock_item = ContentItem( + content_id="version-1", + content_type=ContentType.STORE_AGENT, + searchable_text="Test Agent Test description", + metadata={"name": "Test Agent"}, + ) + # Mock handler + mock_handler = MagicMock() + mock_handler.get_missing_items = AsyncMock(return_value=[mock_item]) + + with patch( 
+ "backend.api.features.store.embeddings.CONTENT_HANDLERS", + {ContentType.STORE_AGENT: mock_handler}, + ): + with patch( + "backend.api.features.store.embeddings.generate_embedding", + return_value=[0.1] * EMBEDDING_DIM, + ): with patch( - "backend.api.features.store.embeddings.ensure_embedding" - ) as mock_ensure: - mock_ensure.return_value = True - + "backend.api.features.store.embeddings.store_content_embedding", + return_value=True, + ): result = await embeddings.backfill_missing_embeddings(batch_size=10) - # Verify the query was called - assert mock_client.query_raw.called - - # Get the SQL query - call_args = mock_client.query_raw.call_args - sql_query = call_args[0][0] - - # Verify schema prefix in query - assert '"platform"."StoreListingVersion"' in sql_query - assert '"platform"."UnifiedContentEmbedding"' in sql_query - - # Verify ensure_embedding was called - assert mock_ensure.called + # Verify handler was called + mock_handler.get_missing_items.assert_called_once_with(10) # Verify results assert result["processed"] == 1 @@ -226,7 +212,7 @@ async def test_ensure_content_embedding_with_schema(): with patch( "backend.api.features.store.embeddings.generate_embedding" ) as mock_generate: - mock_generate.return_value = [0.1] * 1536 + mock_generate.return_value = [0.1] * EMBEDDING_DIM with patch( "backend.api.features.store.embeddings.store_content_embedding" @@ -260,7 +246,7 @@ async def test_backward_compatibility_store_embedding(): result = await embeddings.store_embedding( version_id="test-version-id", - embedding=[0.1] * 1536, + embedding=[0.1] * EMBEDDING_DIM, tx=None, ) @@ -315,7 +301,7 @@ async def test_schema_handling_error_cases(): result = await embeddings.store_content_embedding( content_type=ContentType.STORE_AGENT, content_id="test-id", - embedding=[0.1] * 1536, + embedding=[0.1] * EMBEDDING_DIM, searchable_text="test", metadata=None, user_id=None, diff --git a/autogpt_platform/backend/backend/api/features/store/embeddings_test.py 
b/autogpt_platform/backend/backend/api/features/store/embeddings_test.py index 98329abb19..a17e393472 100644 --- a/autogpt_platform/backend/backend/api/features/store/embeddings_test.py +++ b/autogpt_platform/backend/backend/api/features/store/embeddings_test.py @@ -63,7 +63,7 @@ async def test_generate_embedding_success(): result = await embeddings.generate_embedding("test text") assert result is not None - assert len(result) == 1536 + assert len(result) == embeddings.EMBEDDING_DIM assert result[0] == 0.1 mock_client.embeddings.create.assert_called_once_with( @@ -110,7 +110,7 @@ async def test_generate_embedding_text_truncation(): mock_client = MagicMock() mock_response = MagicMock() mock_response.data = [MagicMock()] - mock_response.data[0].embedding = [0.1] * 1536 + mock_response.data[0].embedding = [0.1] * embeddings.EMBEDDING_DIM # Use AsyncMock for async embeddings.create method mock_client.embeddings.create = AsyncMock(return_value=mock_response) @@ -297,72 +297,92 @@ async def test_ensure_embedding_generation_fails(mock_get, mock_generate): @pytest.mark.asyncio(loop_scope="session") async def test_get_embedding_stats(): """Test embedding statistics retrieval.""" - # Mock approved count query and embedded count query - mock_approved_result = [{"count": 100}] - mock_embedded_result = [{"count": 75}] + # Mock handler stats for each content type + mock_handler = MagicMock() + mock_handler.get_stats = AsyncMock( + return_value={ + "total": 100, + "with_embeddings": 75, + "without_embeddings": 25, + } + ) + # Patch the CONTENT_HANDLERS where it's used (in embeddings module) with patch( - "backend.api.features.store.embeddings.query_raw_with_schema", - side_effect=[mock_approved_result, mock_embedded_result], + "backend.api.features.store.embeddings.CONTENT_HANDLERS", + {ContentType.STORE_AGENT: mock_handler}, ): result = await embeddings.get_embedding_stats() - assert result["total_approved"] == 100 - assert result["with_embeddings"] == 75 - assert 
result["without_embeddings"] == 25 - assert result["coverage_percent"] == 75.0 + assert "by_type" in result + assert "totals" in result + assert result["totals"]["total"] == 100 + assert result["totals"]["with_embeddings"] == 75 + assert result["totals"]["without_embeddings"] == 25 + assert result["totals"]["coverage_percent"] == 75.0 @pytest.mark.asyncio(loop_scope="session") -@patch("backend.api.features.store.embeddings.ensure_embedding") -async def test_backfill_missing_embeddings_success(mock_ensure): +@patch("backend.api.features.store.embeddings.store_content_embedding") +async def test_backfill_missing_embeddings_success(mock_store): """Test backfill with successful embedding generation.""" - # Mock missing embeddings query - mock_missing = [ - { - "id": "version-1", - "name": "Agent 1", - "description": "Description 1", - "subHeading": "Heading 1", - "categories": ["AI"], - }, - { - "id": "version-2", - "name": "Agent 2", - "description": "Description 2", - "subHeading": "Heading 2", - "categories": ["Productivity"], - }, + # Mock ContentItem from handlers + from backend.api.features.store.content_handlers import ContentItem + + mock_items = [ + ContentItem( + content_id="version-1", + content_type=ContentType.STORE_AGENT, + searchable_text="Agent 1 Description 1", + metadata={"name": "Agent 1"}, + ), + ContentItem( + content_id="version-2", + content_type=ContentType.STORE_AGENT, + searchable_text="Agent 2 Description 2", + metadata={"name": "Agent 2"}, + ), ] - # Mock ensure_embedding to succeed for first, fail for second - mock_ensure.side_effect = [True, False] + # Mock handler to return missing items + mock_handler = MagicMock() + mock_handler.get_missing_items = AsyncMock(return_value=mock_items) + + # Mock store_content_embedding to succeed for first, fail for second + mock_store.side_effect = [True, False] with patch( - "backend.api.features.store.embeddings.query_raw_with_schema", - return_value=mock_missing, + 
"backend.api.features.store.embeddings.CONTENT_HANDLERS", + {ContentType.STORE_AGENT: mock_handler}, ): - result = await embeddings.backfill_missing_embeddings(batch_size=5) + with patch( + "backend.api.features.store.embeddings.generate_embedding", + return_value=[0.1] * embeddings.EMBEDDING_DIM, + ): + result = await embeddings.backfill_missing_embeddings(batch_size=5) - assert result["processed"] == 2 - assert result["success"] == 1 - assert result["failed"] == 1 - assert mock_ensure.call_count == 2 + assert result["processed"] == 2 + assert result["success"] == 1 + assert result["failed"] == 1 + assert mock_store.call_count == 2 @pytest.mark.asyncio(loop_scope="session") async def test_backfill_missing_embeddings_no_missing(): """Test backfill when no embeddings are missing.""" + # Mock handler to return no missing items + mock_handler = MagicMock() + mock_handler.get_missing_items = AsyncMock(return_value=[]) + with patch( - "backend.api.features.store.embeddings.query_raw_with_schema", - return_value=[], + "backend.api.features.store.embeddings.CONTENT_HANDLERS", + {ContentType.STORE_AGENT: mock_handler}, ): result = await embeddings.backfill_missing_embeddings(batch_size=5) assert result["processed"] == 0 assert result["success"] == 0 assert result["failed"] == 0 - assert result["message"] == "No missing embeddings" @pytest.mark.asyncio(loop_scope="session") diff --git a/autogpt_platform/backend/backend/api/features/store/hybrid_search.py b/autogpt_platform/backend/backend/api/features/store/hybrid_search.py index fbbbe62cb3..ff7903ffee 100644 --- a/autogpt_platform/backend/backend/api/features/store/hybrid_search.py +++ b/autogpt_platform/backend/backend/api/features/store/hybrid_search.py @@ -1,16 +1,18 @@ """ -Hybrid Search for Store Agents +Unified Hybrid Search Combines semantic (embedding) search with lexical (tsvector) search -for improved relevance in marketplace agent discovery. 
+for improved relevance across all content types (agents, blocks, docs). """ import logging from dataclasses import dataclass -from datetime import datetime from typing import Any, Literal +from prisma.enums import ContentType + from backend.api.features.store.embeddings import ( + EMBEDDING_DIM, embed_query, embedding_to_vector_string, ) @@ -20,17 +22,299 @@ logger = logging.getLogger(__name__) @dataclass -class HybridSearchWeights: - """Weights for combining search signals.""" +class UnifiedSearchWeights: + """Weights for unified search (no popularity signal).""" - semantic: float = 0.30 # Embedding cosine similarity - lexical: float = 0.30 # tsvector ts_rank_cd score - category: float = 0.20 # Category match boost - recency: float = 0.10 # Newer agents ranked higher - popularity: float = 0.10 # Agent usage/runs (PageRank-like) + semantic: float = 0.40 # Embedding cosine similarity + lexical: float = 0.40 # tsvector ts_rank_cd score + category: float = 0.10 # Category match boost (for types that have categories) + recency: float = 0.10 # Newer content ranked higher def __post_init__(self): """Validate weights are non-negative and sum to approximately 1.0.""" + total = self.semantic + self.lexical + self.category + self.recency + + if any( + w < 0 for w in [self.semantic, self.lexical, self.category, self.recency] + ): + raise ValueError("All weights must be non-negative") + + if not (0.99 <= total <= 1.01): + raise ValueError(f"Weights must sum to ~1.0, got {total:.3f}") + + +# Default weights for unified search +DEFAULT_UNIFIED_WEIGHTS = UnifiedSearchWeights() + +# Minimum relevance score thresholds +DEFAULT_MIN_SCORE = 0.15 # For unified search (more permissive) +DEFAULT_STORE_AGENT_MIN_SCORE = 0.20 # For store agent search (original threshold) + + +async def unified_hybrid_search( + query: str, + content_types: list[ContentType] | None = None, + category: str | None = None, + page: int = 1, + page_size: int = 20, + weights: UnifiedSearchWeights | None = None, 
+ min_score: float | None = None, + user_id: str | None = None, +) -> tuple[list[dict[str, Any]], int]: + """ + Unified hybrid search across all content types. + + Searches UnifiedContentEmbedding using both semantic (vector) and lexical (tsvector) signals. + + Args: + query: Search query string + content_types: List of content types to search. Defaults to all public types. + category: Filter by category (for content types that support it) + page: Page number (1-indexed) + page_size: Results per page + weights: Custom weights for search signals + min_score: Minimum relevance score threshold (0-1) + user_id: User ID for searching private content (library agents) + + Returns: + Tuple of (results list, total count) + """ + # Validate inputs + query = query.strip() + if not query: + return [], 0 + + if page < 1: + page = 1 + if page_size < 1: + page_size = 1 + if page_size > 100: + page_size = 100 + + if content_types is None: + content_types = [ + ContentType.STORE_AGENT, + ContentType.BLOCK, + ContentType.DOCUMENTATION, + ] + + if weights is None: + weights = DEFAULT_UNIFIED_WEIGHTS + if min_score is None: + min_score = DEFAULT_MIN_SCORE + + offset = (page - 1) * page_size + + # Generate query embedding + query_embedding = await embed_query(query) + + # Graceful degradation if embedding unavailable + if query_embedding is None or not query_embedding: + logger.warning( + "Failed to generate query embedding - falling back to lexical-only search. " + "Check that openai_internal_api_key is configured and OpenAI API is accessible." 
+ ) + query_embedding = [0.0] * EMBEDDING_DIM + # Redistribute semantic weight to lexical + total_non_semantic = weights.lexical + weights.category + weights.recency + if total_non_semantic > 0: + factor = 1.0 / total_non_semantic + weights = UnifiedSearchWeights( + semantic=0.0, + lexical=weights.lexical * factor, + category=weights.category * factor, + recency=weights.recency * factor, + ) + else: + weights = UnifiedSearchWeights( + semantic=0.0, lexical=1.0, category=0.0, recency=0.0 + ) + + # Build parameters + params: list[Any] = [] + param_idx = 1 + + # Query for lexical search + params.append(query) + query_param = f"${param_idx}" + param_idx += 1 + + # Query lowercase for category matching + params.append(query.lower()) + query_lower_param = f"${param_idx}" + param_idx += 1 + + # Embedding + embedding_str = embedding_to_vector_string(query_embedding) + params.append(embedding_str) + embedding_param = f"${param_idx}" + param_idx += 1 + + # Content types + content_type_values = [ct.value for ct in content_types] + params.append(content_type_values) + content_types_param = f"${param_idx}" + param_idx += 1 + + # User ID filter (for private content) + user_filter = "" + if user_id is not None: + params.append(user_id) + user_filter = f'AND (uce."userId" = ${param_idx} OR uce."userId" IS NULL)' + param_idx += 1 + else: + user_filter = 'AND uce."userId" IS NULL' + + # Weights + params.append(weights.semantic) + w_semantic = f"${param_idx}" + param_idx += 1 + + params.append(weights.lexical) + w_lexical = f"${param_idx}" + param_idx += 1 + + params.append(weights.category) + w_category = f"${param_idx}" + param_idx += 1 + + params.append(weights.recency) + w_recency = f"${param_idx}" + param_idx += 1 + + # Min score + params.append(min_score) + min_score_param = f"${param_idx}" + param_idx += 1 + + # Pagination + params.append(page_size) + limit_param = f"${param_idx}" + param_idx += 1 + + params.append(offset) + offset_param = f"${param_idx}" + param_idx += 1 + + 
# Unified search query on UnifiedContentEmbedding + sql_query = f""" + WITH candidates AS ( + -- Lexical matches (uses GIN index on search column) + SELECT uce.id, uce."contentType", uce."contentId" + FROM {{schema_prefix}}"UnifiedContentEmbedding" uce + WHERE uce."contentType" = ANY({content_types_param}::{{schema_prefix}}"ContentType"[]) + {user_filter} + AND uce.search @@ plainto_tsquery('english', {query_param}) + + UNION + + -- Semantic matches (uses HNSW index on embedding) + ( + SELECT uce.id, uce."contentType", uce."contentId" + FROM {{schema_prefix}}"UnifiedContentEmbedding" uce + WHERE uce."contentType" = ANY({content_types_param}::{{schema_prefix}}"ContentType"[]) + {user_filter} + ORDER BY uce.embedding <=> {embedding_param}::vector + LIMIT 200 + ) + ), + search_scores AS ( + SELECT + uce."contentType" as content_type, + uce."contentId" as content_id, + uce."searchableText" as searchable_text, + uce.metadata, + uce."updatedAt" as updated_at, + -- Semantic score: cosine similarity (1 - distance) + COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score, + -- Lexical score: ts_rank_cd + COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw, + -- Category match from metadata + CASE + WHEN uce.metadata ? 
'categories' AND EXISTS ( + SELECT 1 FROM jsonb_array_elements_text(uce.metadata->'categories') cat + WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%' + ) + THEN 1.0 + ELSE 0.0 + END as category_score, + -- Recency score: linear decay over 90 days + GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - uce."updatedAt")) / (90 * 24 * 3600)) as recency_score + FROM candidates c + INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce ON c.id = uce.id + ), + max_lexical AS ( + SELECT GREATEST(MAX(lexical_raw), 0.001) as max_val FROM search_scores + ), + normalized AS ( + SELECT + ss.*, + ss.lexical_raw / ml.max_val as lexical_score + FROM search_scores ss + CROSS JOIN max_lexical ml + ), + scored AS ( + SELECT + content_type, + content_id, + searchable_text, + metadata, + updated_at, + semantic_score, + lexical_score, + category_score, + recency_score, + ( + {w_semantic} * semantic_score + + {w_lexical} * lexical_score + + {w_category} * category_score + + {w_recency} * recency_score + ) as combined_score + FROM normalized + ), + filtered AS ( + SELECT + *, + COUNT(*) OVER () as total_count + FROM scored + WHERE combined_score >= {min_score_param} + ) + SELECT * FROM filtered + ORDER BY combined_score DESC + LIMIT {limit_param} OFFSET {offset_param} + """ + + results = await query_raw_with_schema( + sql_query, *params, set_public_search_path=True + ) + + total = results[0]["total_count"] if results else 0 + + # Clean up results + for result in results: + result.pop("total_count", None) + + logger.info(f"Unified hybrid search: {len(results)} results, {total} total") + + return results, total + + +# ============================================================================ +# Store Agent specific search (with full metadata) +# ============================================================================ + + +@dataclass +class StoreAgentSearchWeights: + """Weights for store agent search including popularity.""" + + semantic: float = 0.30 + lexical: float = 0.30 + 
category: float = 0.20 + recency: float = 0.10 + popularity: float = 0.10 + + def __post_init__(self): total = ( self.semantic + self.lexical @@ -38,7 +322,6 @@ class HybridSearchWeights: + self.recency + self.popularity ) - if any( w < 0 for w in [ @@ -50,46 +333,11 @@ class HybridSearchWeights: ] ): raise ValueError("All weights must be non-negative") - if not (0.99 <= total <= 1.01): raise ValueError(f"Weights must sum to ~1.0, got {total:.3f}") -DEFAULT_WEIGHTS = HybridSearchWeights() - -# Minimum relevance score threshold - agents below this are filtered out -# With weights (0.30 semantic + 0.30 lexical + 0.20 category + 0.10 recency + 0.10 popularity): -# - 0.20 means at least ~60% semantic match OR strong lexical match required -# - Ensures only genuinely relevant results are returned -# - Recency/popularity alone (0.10 each) won't pass the threshold -DEFAULT_MIN_SCORE = 0.20 - - -@dataclass -class HybridSearchResult: - """A single search result with score breakdown.""" - - slug: str - agent_name: str - agent_image: str - creator_username: str - creator_avatar: str - sub_heading: str - description: str - runs: int - rating: float - categories: list[str] - featured: bool - is_available: bool - updated_at: datetime - - # Score breakdown (for debugging/tuning) - combined_score: float - semantic_score: float = 0.0 - lexical_score: float = 0.0 - category_score: float = 0.0 - recency_score: float = 0.0 - popularity_score: float = 0.0 +DEFAULT_STORE_AGENT_WEIGHTS = StoreAgentSearchWeights() async def hybrid_search( @@ -102,276 +350,263 @@ async def hybrid_search( ) = None, page: int = 1, page_size: int = 20, - weights: HybridSearchWeights | None = None, + weights: StoreAgentSearchWeights | None = None, min_score: float | None = None, ) -> tuple[list[dict[str, Any]], int]: """ - Perform hybrid search combining semantic and lexical signals. + Hybrid search for store agents with full metadata. 
- Args: - query: Search query string - featured: Filter for featured agents only - creators: Filter by creator usernames - category: Filter by category - sorted_by: Sort order (relevance uses hybrid scoring) - page: Page number (1-indexed) - page_size: Results per page - weights: Custom weights for search signals - min_score: Minimum relevance score threshold (0-1). Results below - this score are filtered out. Defaults to DEFAULT_MIN_SCORE. - - Returns: - Tuple of (results list, total count). Returns empty list if no - results meet the minimum relevance threshold. + Uses UnifiedContentEmbedding for search, joins to StoreAgent for metadata. """ - # Validate inputs query = query.strip() if not query: - return [], 0 # Empty query returns no results + return [], 0 if page < 1: page = 1 if page_size < 1: page_size = 1 - if page_size > 100: # Cap at reasonable limit to prevent performance issues + if page_size > 100: page_size = 100 if weights is None: - weights = DEFAULT_WEIGHTS + weights = DEFAULT_STORE_AGENT_WEIGHTS if min_score is None: - min_score = DEFAULT_MIN_SCORE + min_score = ( + DEFAULT_STORE_AGENT_MIN_SCORE # Use original threshold for store agents + ) offset = (page - 1) * page_size # Generate query embedding query_embedding = await embed_query(query) - # Build WHERE clause conditions - where_parts: list[str] = ["sa.is_available = true"] + # Graceful degradation + if query_embedding is None or not query_embedding: + logger.warning( + "Failed to generate query embedding - falling back to lexical-only search." 
+ ) + query_embedding = [0.0] * EMBEDDING_DIM + total_non_semantic = ( + weights.lexical + weights.category + weights.recency + weights.popularity + ) + if total_non_semantic > 0: + factor = 1.0 / total_non_semantic + weights = StoreAgentSearchWeights( + semantic=0.0, + lexical=weights.lexical * factor, + category=weights.category * factor, + recency=weights.recency * factor, + popularity=weights.popularity * factor, + ) + else: + weights = StoreAgentSearchWeights( + semantic=0.0, lexical=1.0, category=0.0, recency=0.0, popularity=0.0 + ) + + # Build parameters params: list[Any] = [] - param_index = 1 + param_idx = 1 - # Add search query for lexical matching params.append(query) - query_param = f"${param_index}" - param_index += 1 + query_param = f"${param_idx}" + param_idx += 1 - # Add lowercased query for category matching params.append(query.lower()) - query_lower_param = f"${param_index}" - param_index += 1 + query_lower_param = f"${param_idx}" + param_idx += 1 + + embedding_str = embedding_to_vector_string(query_embedding) + params.append(embedding_str) + embedding_param = f"${param_idx}" + param_idx += 1 + + # Build WHERE clause for StoreAgent filters + where_parts = ["sa.is_available = true"] if featured: where_parts.append("sa.featured = true") if creators: - where_parts.append(f"sa.creator_username = ANY(${param_index})") params.append(creators) - param_index += 1 + where_parts.append(f"sa.creator_username = ANY(${param_idx})") + param_idx += 1 if category: - where_parts.append(f"${param_index} = ANY(sa.categories)") params.append(category) - param_index += 1 + where_parts.append(f"${param_idx} = ANY(sa.categories)") + param_idx += 1 - # Safe: where_parts only contains hardcoded strings with $N parameter placeholders - # No user input is concatenated directly into the SQL string where_clause = " AND ".join(where_parts) - # Embedding is required for hybrid search - fail fast if unavailable - if query_embedding is None or not query_embedding: - # Log 
detailed error server-side - logger.error( - "Failed to generate query embedding. " - "Check that openai_internal_api_key is configured and OpenAI API is accessible." - ) - # Raise generic error to client - raise ValueError("Search service temporarily unavailable") - - # Add embedding parameter - embedding_str = embedding_to_vector_string(query_embedding) - params.append(embedding_str) - embedding_param = f"${param_index}" - param_index += 1 - - # Add weight parameters for SQL calculation + # Weights params.append(weights.semantic) - weight_semantic_param = f"${param_index}" - param_index += 1 + w_semantic = f"${param_idx}" + param_idx += 1 params.append(weights.lexical) - weight_lexical_param = f"${param_index}" - param_index += 1 + w_lexical = f"${param_idx}" + param_idx += 1 params.append(weights.category) - weight_category_param = f"${param_index}" - param_index += 1 + w_category = f"${param_idx}" + param_idx += 1 params.append(weights.recency) - weight_recency_param = f"${param_index}" - param_index += 1 + w_recency = f"${param_idx}" + param_idx += 1 params.append(weights.popularity) - weight_popularity_param = f"${param_index}" - param_index += 1 + w_popularity = f"${param_idx}" + param_idx += 1 - # Add min_score parameter params.append(min_score) - min_score_param = f"${param_index}" - param_index += 1 + min_score_param = f"${param_idx}" + param_idx += 1 - # Optimized hybrid search query: - # 1. Direct join to UnifiedContentEmbedding via contentId=storeListingVersionId (no redundant JOINs) - # 2. UNION approach (deduplicates agents matching both branches) - # 3. COUNT(*) OVER() to get total count in single query - # 4. Optimized category matching with EXISTS + unnest - # 5. Pre-calculated max values for lexical and popularity normalization - # 6. Simplified recency calculation with linear decay - # 7. 
Logarithmic popularity scaling to prevent viral agents from dominating + params.append(page_size) + limit_param = f"${param_idx}" + param_idx += 1 + + params.append(offset) + offset_param = f"${param_idx}" + param_idx += 1 + + # Query using UnifiedContentEmbedding for search, StoreAgent for metadata sql_query = f""" - WITH candidates AS ( - -- Lexical matches (uses GIN index on search column) - SELECT sa."storeListingVersionId" - FROM {{schema_prefix}}"StoreAgent" sa - WHERE {where_clause} - AND sa.search @@ plainto_tsquery('english', {query_param}) + WITH candidates AS ( + -- Lexical matches via UnifiedContentEmbedding.search + SELECT uce."contentId" as "storeListingVersionId" + FROM {{schema_prefix}}"UnifiedContentEmbedding" uce + INNER JOIN {{schema_prefix}}"StoreAgent" sa + ON uce."contentId" = sa."storeListingVersionId" + WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType" + AND uce."userId" IS NULL + AND uce.search @@ plainto_tsquery('english', {query_param}) + AND {where_clause} - UNION + UNION - -- Semantic matches (uses HNSW index on embedding with KNN) - SELECT "storeListingVersionId" - FROM ( - SELECT sa."storeListingVersionId", uce.embedding - FROM {{schema_prefix}}"StoreAgent" sa - INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce - ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType" - WHERE {where_clause} - ORDER BY uce.embedding <=> {embedding_param}::vector - LIMIT 200 - ) semantic_results - ), - search_scores AS ( - SELECT - sa.slug, - sa.agent_name, - sa.agent_image, - sa.creator_username, - sa.creator_avatar, - sa.sub_heading, - sa.description, - sa.runs, - sa.rating, - sa.categories, - sa.featured, - sa.is_available, - sa.updated_at, - -- Semantic score: cosine similarity (1 - distance) - COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score, - -- Lexical score: ts_rank_cd (will be normalized later) - 
COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw, - -- Category match: optimized with unnest for better performance - CASE - WHEN EXISTS ( - SELECT 1 FROM unnest(sa.categories) cat - WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%' - ) - THEN 1.0 - ELSE 0.0 - END as category_score, - -- Recency score: linear decay over 90 days (simpler than exponential) - GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score, - -- Popularity raw: agent runs count (will be normalized with log scaling) - sa.runs as popularity_raw - FROM candidates c + -- Semantic matches via UnifiedContentEmbedding.embedding + SELECT uce."contentId" as "storeListingVersionId" + FROM ( + SELECT uce."contentId", uce.embedding + FROM {{schema_prefix}}"UnifiedContentEmbedding" uce INNER JOIN {{schema_prefix}}"StoreAgent" sa - ON c."storeListingVersionId" = sa."storeListingVersionId" - LEFT JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce - ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType" - ), - max_lexical AS ( - SELECT MAX(lexical_raw) as max_val FROM search_scores - ), - max_popularity AS ( - SELECT MAX(popularity_raw) as max_val FROM search_scores - ), - normalized AS ( - SELECT - ss.*, - -- Normalize lexical score by pre-calculated max - CASE - WHEN ml.max_val > 0 - THEN ss.lexical_raw / ml.max_val - ELSE 0 - END as lexical_score, - -- Normalize popularity with logarithmic scaling to prevent viral agents from dominating - -- LOG(1 + runs) / LOG(1 + max_runs) ensures score is 0-1 range - CASE - WHEN mp.max_val > 0 AND ss.popularity_raw > 0 - THEN LN(1 + ss.popularity_raw) / LN(1 + mp.max_val) - ELSE 0 - END as popularity_score - FROM search_scores ss - CROSS JOIN max_lexical ml - CROSS JOIN max_popularity mp - ), - scored AS ( - SELECT - slug, - agent_name, - agent_image, - creator_username, - creator_avatar, - sub_heading, - description, 
- runs, - rating, - categories, - featured, - is_available, - updated_at, - semantic_score, - lexical_score, - category_score, - recency_score, - popularity_score, - ( - {weight_semantic_param} * semantic_score + - {weight_lexical_param} * lexical_score + - {weight_category_param} * category_score + - {weight_recency_param} * recency_score + - {weight_popularity_param} * popularity_score - ) as combined_score - FROM normalized - ), - filtered AS ( - SELECT - *, - COUNT(*) OVER () as total_count - FROM scored - WHERE combined_score >= {min_score_param} - ) - SELECT * FROM filtered - ORDER BY combined_score DESC - LIMIT ${param_index} OFFSET ${param_index + 1} + ON uce."contentId" = sa."storeListingVersionId" + WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType" + AND uce."userId" IS NULL + AND {where_clause} + ORDER BY uce.embedding <=> {embedding_param}::vector + LIMIT 200 + ) uce + ), + search_scores AS ( + SELECT + sa.slug, + sa.agent_name, + sa.agent_image, + sa.creator_username, + sa.creator_avatar, + sa.sub_heading, + sa.description, + sa.runs, + sa.rating, + sa.categories, + sa.featured, + sa.is_available, + sa.updated_at, + -- Semantic score + COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score, + -- Lexical score (raw, will normalize) + COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw, + -- Category match + CASE + WHEN EXISTS ( + SELECT 1 FROM unnest(sa.categories) cat + WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%' + ) + THEN 1.0 + ELSE 0.0 + END as category_score, + -- Recency + GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score, + -- Popularity (raw) + sa.runs as popularity_raw + FROM candidates c + INNER JOIN {{schema_prefix}}"StoreAgent" sa + ON c."storeListingVersionId" = sa."storeListingVersionId" + INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce + ON sa."storeListingVersionId" = 
uce."contentId" + AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType" + ), + max_vals AS ( + SELECT + GREATEST(MAX(lexical_raw), 0.001) as max_lexical, + GREATEST(MAX(popularity_raw), 1) as max_popularity + FROM search_scores + ), + normalized AS ( + SELECT + ss.*, + ss.lexical_raw / mv.max_lexical as lexical_score, + CASE + WHEN ss.popularity_raw > 0 + THEN LN(1 + ss.popularity_raw) / LN(1 + mv.max_popularity) + ELSE 0 + END as popularity_score + FROM search_scores ss + CROSS JOIN max_vals mv + ), + scored AS ( + SELECT + slug, + agent_name, + agent_image, + creator_username, + creator_avatar, + sub_heading, + description, + runs, + rating, + categories, + featured, + is_available, + updated_at, + semantic_score, + lexical_score, + category_score, + recency_score, + popularity_score, + ( + {w_semantic} * semantic_score + + {w_lexical} * lexical_score + + {w_category} * category_score + + {w_recency} * recency_score + + {w_popularity} * popularity_score + ) as combined_score + FROM normalized + ), + filtered AS ( + SELECT *, COUNT(*) OVER () as total_count + FROM scored + WHERE combined_score >= {min_score_param} + ) + SELECT * FROM filtered + ORDER BY combined_score DESC + LIMIT {limit_param} OFFSET {offset_param} """ - # Add pagination params - params.extend([page_size, offset]) - - # Execute search query - includes total_count via window function results = await query_raw_with_schema( sql_query, *params, set_public_search_path=True ) - # Extract total count from first result (all rows have same count) total = results[0]["total_count"] if results else 0 - # Remove total_count from results before returning for result in results: result.pop("total_count", None) - # Log without sensitive query content - logger.info(f"Hybrid search: {len(results)} results, {total} total") + logger.info(f"Hybrid search (store agents): {len(results)} results, {total} total") return results, total @@ -381,13 +616,10 @@ async def hybrid_search_simple( page: int = 1, 
page_size: int = 20, ) -> tuple[list[dict[str, Any]], int]: - """ - Simplified hybrid search for common use cases. + """Simplified hybrid search for store agents.""" + return await hybrid_search(query=query, page=page, page_size=page_size) - Uses default weights and no filters. - """ - return await hybrid_search( - query=query, - page=page, - page_size=page_size, - ) + +# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights +# for existing code that expects the popularity parameter +HybridSearchWeights = StoreAgentSearchWeights diff --git a/autogpt_platform/backend/backend/api/features/store/hybrid_search_test.py b/autogpt_platform/backend/backend/api/features/store/hybrid_search_test.py index 6a5cd7ad6d..70f692ec02 100644 --- a/autogpt_platform/backend/backend/api/features/store/hybrid_search_test.py +++ b/autogpt_platform/backend/backend/api/features/store/hybrid_search_test.py @@ -7,8 +7,15 @@ These tests verify that hybrid search works correctly across different database from unittest.mock import patch import pytest +from prisma.enums import ContentType -from backend.api.features.store.hybrid_search import HybridSearchWeights, hybrid_search +from backend.api.features.store import embeddings +from backend.api.features.store.hybrid_search import ( + HybridSearchWeights, + UnifiedSearchWeights, + hybrid_search, + unified_hybrid_search, +) @pytest.mark.asyncio(loop_scope="session") @@ -49,7 +56,7 @@ async def test_hybrid_search_with_schema_handling(): with patch( "backend.api.features.store.hybrid_search.embed_query" ) as mock_embed: - mock_embed.return_value = [0.1] * 1536 # Mock embedding + mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM # Mock embedding results, total = await hybrid_search( query=query, @@ -85,7 +92,7 @@ async def test_hybrid_search_with_public_schema(): with patch( "backend.api.features.store.hybrid_search.embed_query" ) as mock_embed: - mock_embed.return_value = [0.1] * 1536 + mock_embed.return_value = 
[0.1] * embeddings.EMBEDDING_DIM results, total = await hybrid_search( query="test", @@ -116,7 +123,7 @@ async def test_hybrid_search_with_custom_schema(): with patch( "backend.api.features.store.hybrid_search.embed_query" ) as mock_embed: - mock_embed.return_value = [0.1] * 1536 + mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM results, total = await hybrid_search( query="test", @@ -134,22 +141,52 @@ async def test_hybrid_search_with_custom_schema(): @pytest.mark.asyncio(loop_scope="session") @pytest.mark.integration async def test_hybrid_search_without_embeddings(): - """Test hybrid search fails fast when embeddings are unavailable.""" - # Patch where the function is used, not where it's defined - with patch("backend.api.features.store.hybrid_search.embed_query") as mock_embed: - # Simulate embedding failure - mock_embed.return_value = None + """Test hybrid search gracefully degrades when embeddings are unavailable.""" + # Mock database to return some results + mock_results = [ + { + "slug": "test-agent", + "agent_name": "Test Agent", + "agent_image": "test.png", + "creator_username": "creator", + "creator_avatar": "avatar.png", + "sub_heading": "Test heading", + "description": "Test description", + "runs": 100, + "rating": 4.5, + "categories": ["AI"], + "featured": False, + "is_available": True, + "updated_at": "2025-01-01T00:00:00Z", + "semantic_score": 0.0, # Zero because no embedding + "lexical_score": 0.5, + "category_score": 0.0, + "recency_score": 0.1, + "popularity_score": 0.2, + "combined_score": 0.3, + "total_count": 1, + } + ] - # Should raise ValueError with helpful message - with pytest.raises(ValueError) as exc_info: - await hybrid_search( + with patch("backend.api.features.store.hybrid_search.embed_query") as mock_embed: + with patch( + "backend.api.features.store.hybrid_search.query_raw_with_schema" + ) as mock_query: + # Simulate embedding failure + mock_embed.return_value = None + mock_query.return_value = mock_results + + # Should 
NOT raise - graceful degradation + results, total = await hybrid_search( query="test", page=1, page_size=20, ) - # Verify error message is generic (doesn't leak implementation details) - assert "Search service temporarily unavailable" in str(exc_info.value) + # Verify it returns results even without embeddings + assert len(results) == 1 + assert results[0]["slug"] == "test-agent" + assert total == 1 @pytest.mark.asyncio(loop_scope="session") @@ -164,7 +201,7 @@ async def test_hybrid_search_with_filters(): with patch( "backend.api.features.store.hybrid_search.embed_query" ) as mock_embed: - mock_embed.return_value = [0.1] * 1536 + mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM # Test with featured filter results, total = await hybrid_search( @@ -204,7 +241,7 @@ async def test_hybrid_search_weights(): with patch( "backend.api.features.store.hybrid_search.embed_query" ) as mock_embed: - mock_embed.return_value = [0.1] * 1536 + mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM results, total = await hybrid_search( query="test", @@ -248,7 +285,7 @@ async def test_hybrid_search_min_score_filtering(): with patch( "backend.api.features.store.hybrid_search.embed_query" ) as mock_embed: - mock_embed.return_value = [0.1] * 1536 + mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM # Test with custom min_score results, total = await hybrid_search( @@ -283,7 +320,7 @@ async def test_hybrid_search_pagination(): with patch( "backend.api.features.store.hybrid_search.embed_query" ) as mock_embed: - mock_embed.return_value = [0.1] * 1536 + mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM # Test page 2 with page_size 10 results, total = await hybrid_search( @@ -317,7 +354,7 @@ async def test_hybrid_search_error_handling(): with patch( "backend.api.features.store.hybrid_search.embed_query" ) as mock_embed: - mock_embed.return_value = [0.1] * 1536 + mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM # Should raise exception with 
pytest.raises(Exception) as exc_info: @@ -330,5 +367,301 @@ async def test_hybrid_search_error_handling(): assert "Database connection error" in str(exc_info.value) +# ============================================================================= +# Unified Hybrid Search Tests +# ============================================================================= + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_unified_hybrid_search_basic(): + """Test basic unified hybrid search across all content types.""" + mock_results = [ + { + "content_type": "STORE_AGENT", + "content_id": "agent-1", + "searchable_text": "Test Agent Description", + "metadata": {"name": "Test Agent"}, + "updated_at": "2025-01-01T00:00:00Z", + "semantic_score": 0.7, + "lexical_score": 0.8, + "category_score": 0.5, + "recency_score": 0.3, + "combined_score": 0.6, + "total_count": 2, + }, + { + "content_type": "BLOCK", + "content_id": "block-1", + "searchable_text": "Test Block Description", + "metadata": {"name": "Test Block"}, + "updated_at": "2025-01-01T00:00:00Z", + "semantic_score": 0.6, + "lexical_score": 0.7, + "category_score": 0.4, + "recency_score": 0.2, + "combined_score": 0.5, + "total_count": 2, + }, + ] + + with patch( + "backend.api.features.store.hybrid_search.query_raw_with_schema" + ) as mock_query: + with patch( + "backend.api.features.store.hybrid_search.embed_query" + ) as mock_embed: + mock_query.return_value = mock_results + mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM + + results, total = await unified_hybrid_search( + query="test", + page=1, + page_size=20, + ) + + assert len(results) == 2 + assert total == 2 + assert results[0]["content_type"] == "STORE_AGENT" + assert results[1]["content_type"] == "BLOCK" + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_unified_hybrid_search_filter_by_content_type(): + """Test unified search filtering by specific content types.""" + mock_results = [ + 
{ + "content_type": "BLOCK", + "content_id": "block-1", + "searchable_text": "Test Block", + "metadata": {}, + "updated_at": "2025-01-01T00:00:00Z", + "semantic_score": 0.7, + "lexical_score": 0.8, + "category_score": 0.0, + "recency_score": 0.3, + "combined_score": 0.5, + "total_count": 1, + }, + ] + + with patch( + "backend.api.features.store.hybrid_search.query_raw_with_schema" + ) as mock_query: + with patch( + "backend.api.features.store.hybrid_search.embed_query" + ) as mock_embed: + mock_query.return_value = mock_results + mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM + + results, total = await unified_hybrid_search( + query="test", + content_types=[ContentType.BLOCK], + page=1, + page_size=20, + ) + + # Verify content_types parameter was passed correctly + call_args = mock_query.call_args + params = call_args[0][1:] + # The content types should be in the params as a list + assert ["BLOCK"] in params + + assert len(results) == 1 + assert total == 1 + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_unified_hybrid_search_with_user_id(): + """Test unified search with user_id for private content.""" + mock_results = [ + { + "content_type": "STORE_AGENT", + "content_id": "agent-1", + "searchable_text": "My Private Agent", + "metadata": {}, + "updated_at": "2025-01-01T00:00:00Z", + "semantic_score": 0.7, + "lexical_score": 0.8, + "category_score": 0.0, + "recency_score": 0.3, + "combined_score": 0.6, + "total_count": 1, + }, + ] + + with patch( + "backend.api.features.store.hybrid_search.query_raw_with_schema" + ) as mock_query: + with patch( + "backend.api.features.store.hybrid_search.embed_query" + ) as mock_embed: + mock_query.return_value = mock_results + mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM + + results, total = await unified_hybrid_search( + query="test", + user_id="user-123", + page=1, + page_size=20, + ) + + # Verify SQL contains user_id filter + call_args = mock_query.call_args + 
sql_template = call_args[0][0] + params = call_args[0][1:] + + assert 'uce."userId"' in sql_template + assert "user-123" in params + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_unified_hybrid_search_custom_weights(): + """Test unified search with custom weights.""" + custom_weights = UnifiedSearchWeights( + semantic=0.6, + lexical=0.2, + category=0.1, + recency=0.1, + ) + + with patch( + "backend.api.features.store.hybrid_search.query_raw_with_schema" + ) as mock_query: + with patch( + "backend.api.features.store.hybrid_search.embed_query" + ) as mock_embed: + mock_query.return_value = [] + mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM + + results, total = await unified_hybrid_search( + query="test", + weights=custom_weights, + page=1, + page_size=20, + ) + + # Verify custom weights are in parameters + call_args = mock_query.call_args + params = call_args[0][1:] + + assert 0.6 in params # semantic weight + assert 0.2 in params # lexical weight + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_unified_hybrid_search_graceful_degradation(): + """Test unified search gracefully degrades when embeddings unavailable.""" + mock_results = [ + { + "content_type": "DOCUMENTATION", + "content_id": "doc-1", + "searchable_text": "API Documentation", + "metadata": {}, + "updated_at": "2025-01-01T00:00:00Z", + "semantic_score": 0.0, # Zero because no embedding + "lexical_score": 0.8, + "category_score": 0.0, + "recency_score": 0.2, + "combined_score": 0.5, + "total_count": 1, + }, + ] + + with patch( + "backend.api.features.store.hybrid_search.query_raw_with_schema" + ) as mock_query: + with patch( + "backend.api.features.store.hybrid_search.embed_query" + ) as mock_embed: + mock_query.return_value = mock_results + mock_embed.return_value = None # Embedding failure + + # Should NOT raise - graceful degradation + results, total = await unified_hybrid_search( + query="test", + page=1, + 
page_size=20, + ) + + assert len(results) == 1 + assert total == 1 + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_unified_hybrid_search_empty_query(): + """Test unified search with empty query returns empty results.""" + results, total = await unified_hybrid_search( + query="", + page=1, + page_size=20, + ) + + assert results == [] + assert total == 0 + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_unified_hybrid_search_pagination(): + """Test unified search pagination.""" + with patch( + "backend.api.features.store.hybrid_search.query_raw_with_schema" + ) as mock_query: + with patch( + "backend.api.features.store.hybrid_search.embed_query" + ) as mock_embed: + mock_query.return_value = [] + mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM + + results, total = await unified_hybrid_search( + query="test", + page=3, + page_size=15, + ) + + # Verify pagination parameters (last two params are LIMIT and OFFSET) + call_args = mock_query.call_args + params = call_args[0] + + limit = params[-2] + offset = params[-1] + + assert limit == 15 # page_size + assert offset == 30 # (page - 1) * page_size = (3 - 1) * 15 + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.integration +async def test_unified_hybrid_search_schema_prefix(): + """Test unified search uses schema_prefix placeholder.""" + with patch( + "backend.api.features.store.hybrid_search.query_raw_with_schema" + ) as mock_query: + with patch( + "backend.api.features.store.hybrid_search.embed_query" + ) as mock_embed: + mock_query.return_value = [] + mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM + + await unified_hybrid_search( + query="test", + page=1, + page_size=20, + ) + + call_args = mock_query.call_args + sql_template = call_args[0][0] + + # Verify schema_prefix placeholder is used for table references + assert "{schema_prefix}" in sql_template + assert '"UnifiedContentEmbedding"' in sql_template 
+ + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) diff --git a/autogpt_platform/backend/backend/api/features/store/model.py b/autogpt_platform/backend/backend/api/features/store/model.py index 077135217a..a3310b96fc 100644 --- a/autogpt_platform/backend/backend/api/features/store/model.py +++ b/autogpt_platform/backend/backend/api/features/store/model.py @@ -221,3 +221,23 @@ class ReviewSubmissionRequest(pydantic.BaseModel): is_approved: bool comments: str # External comments visible to creator internal_comments: str | None = None # Private admin notes + + +class UnifiedSearchResult(pydantic.BaseModel): + """A single result from unified hybrid search across all content types.""" + + content_type: str # STORE_AGENT, BLOCK, DOCUMENTATION + content_id: str + searchable_text: str + metadata: dict | None = None + updated_at: datetime.datetime | None = None + combined_score: float | None = None + semantic_score: float | None = None + lexical_score: float | None = None + + +class UnifiedSearchResponse(pydantic.BaseModel): + """Response model for unified search across all content types.""" + + results: list[UnifiedSearchResult] + pagination: Pagination diff --git a/autogpt_platform/backend/backend/api/features/store/routes.py b/autogpt_platform/backend/backend/api/features/store/routes.py index 7816b25d5a..2f3c7bfb04 100644 --- a/autogpt_platform/backend/backend/api/features/store/routes.py +++ b/autogpt_platform/backend/backend/api/features/store/routes.py @@ -7,12 +7,15 @@ from typing import Literal import autogpt_libs.auth import fastapi import fastapi.responses +import prisma.enums import backend.data.graph import backend.util.json +from backend.util.models import Pagination from . import cache as store_cache from . import db as store_db +from . import hybrid_search as store_hybrid_search from . import image_gen as store_image_gen from . import media as store_media from . 
import model as store_model @@ -146,6 +149,102 @@ async def get_agents( return agents +############################################## +############### Search Endpoints ############# +############################################## + + +@router.get( + "/search", + summary="Unified search across all content types", + tags=["store", "public"], + response_model=store_model.UnifiedSearchResponse, +) +async def unified_search( + query: str, + content_types: list[str] | None = fastapi.Query( + default=None, + description="Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.", + ), + page: int = 1, + page_size: int = 20, + user_id: str | None = fastapi.Security( + autogpt_libs.auth.get_optional_user_id, use_cache=False + ), +): + """ + Search across all content types (store agents, blocks, documentation) using hybrid search. + + Combines semantic (embedding-based) and lexical (text-based) search for best results. + + Args: + query: The search query string + content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION) + page: Page number for pagination (default 1) + page_size: Number of results per page (default 20) + user_id: Optional authenticated user ID (for user-scoped content in future) + + Returns: + UnifiedSearchResponse: Paginated list of search results with relevance scores + """ + if page < 1: + raise fastapi.HTTPException( + status_code=422, detail="Page must be greater than 0" + ) + + if page_size < 1: + raise fastapi.HTTPException( + status_code=422, detail="Page size must be greater than 0" + ) + + # Convert string content types to enum + content_type_enums: list[prisma.enums.ContentType] | None = None + if content_types: + try: + content_type_enums = [prisma.enums.ContentType(ct) for ct in content_types] + except ValueError as e: + raise fastapi.HTTPException( + status_code=422, + detail=f"Invalid content type. Valid values: STORE_AGENT, BLOCK, DOCUMENTATION. 
Error: {e}", + ) + + # Perform unified hybrid search + results, total = await store_hybrid_search.unified_hybrid_search( + query=query, + content_types=content_type_enums, + user_id=user_id, + page=page, + page_size=page_size, + ) + + # Convert results to response model + search_results = [ + store_model.UnifiedSearchResult( + content_type=r["content_type"], + content_id=r["content_id"], + searchable_text=r.get("searchable_text", ""), + metadata=r.get("metadata"), + updated_at=r.get("updated_at"), + combined_score=r.get("combined_score"), + semantic_score=r.get("semantic_score"), + lexical_score=r.get("lexical_score"), + ) + for r in results + ] + + total_pages = (total + page_size - 1) // page_size if total > 0 else 0 + + return store_model.UnifiedSearchResponse( + results=search_results, + pagination=Pagination( + total_items=total, + total_pages=total_pages, + current_page=page, + page_size=page_size, + ), + ) + + @router.get( "/agents/{username}/{agent_name}", summary="Get specific agent", diff --git a/autogpt_platform/backend/backend/api/features/store/semantic_search_test.py b/autogpt_platform/backend/backend/api/features/store/semantic_search_test.py new file mode 100644 index 0000000000..b52f924ee8 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/store/semantic_search_test.py @@ -0,0 +1,272 @@ +"""Tests for the semantic_search function.""" + +import pytest +from prisma.enums import ContentType + +from backend.api.features.store.embeddings import EMBEDDING_DIM, semantic_search + + +@pytest.mark.asyncio +async def test_search_blocks_only(mocker): + """Test searching only BLOCK content type.""" + # Mock embed_query to return a test embedding + mock_embedding = [0.1] * EMBEDDING_DIM + mocker.patch( + "backend.api.features.store.embeddings.embed_query", + return_value=mock_embedding, + ) + + # Mock query_raw_with_schema to return test results + mock_results = [ + { + "content_id": "block-123", + "content_type": "BLOCK", + "searchable_text": 
"Calculator Block - Performs arithmetic operations", + "metadata": {"name": "Calculator", "categories": ["Math"]}, + "similarity": 0.85, + } + ] + mocker.patch( + "backend.api.features.store.embeddings.query_raw_with_schema", + return_value=mock_results, + ) + + results = await semantic_search( + query="calculate numbers", + content_types=[ContentType.BLOCK], + ) + + assert len(results) == 1 + assert results[0]["content_type"] == "BLOCK" + assert results[0]["content_id"] == "block-123" + assert results[0]["similarity"] == 0.85 + + +@pytest.mark.asyncio +async def test_search_multiple_content_types(mocker): + """Test searching multiple content types simultaneously.""" + mock_embedding = [0.1] * EMBEDDING_DIM + mocker.patch( + "backend.api.features.store.embeddings.embed_query", + return_value=mock_embedding, + ) + + mock_results = [ + { + "content_id": "block-123", + "content_type": "BLOCK", + "searchable_text": "Calculator Block", + "metadata": {}, + "similarity": 0.85, + }, + { + "content_id": "doc-456", + "content_type": "DOCUMENTATION", + "searchable_text": "How to use Calculator", + "metadata": {}, + "similarity": 0.75, + }, + ] + mocker.patch( + "backend.api.features.store.embeddings.query_raw_with_schema", + return_value=mock_results, + ) + + results = await semantic_search( + query="calculator", + content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION], + ) + + assert len(results) == 2 + assert results[0]["content_type"] == "BLOCK" + assert results[1]["content_type"] == "DOCUMENTATION" + + +@pytest.mark.asyncio +async def test_search_with_min_similarity_threshold(mocker): + """Test that results below min_similarity are filtered out.""" + mock_embedding = [0.1] * EMBEDDING_DIM + mocker.patch( + "backend.api.features.store.embeddings.embed_query", + return_value=mock_embedding, + ) + + # Only return results above 0.7 similarity + mock_results = [ + { + "content_id": "block-123", + "content_type": "BLOCK", + "searchable_text": "Calculator Block", + 
"metadata": {}, + "similarity": 0.85, + } + ] + mocker.patch( + "backend.api.features.store.embeddings.query_raw_with_schema", + return_value=mock_results, + ) + + results = await semantic_search( + query="calculate", + content_types=[ContentType.BLOCK], + min_similarity=0.7, + ) + + assert len(results) == 1 + assert results[0]["similarity"] >= 0.7 + + +@pytest.mark.asyncio +async def test_search_fallback_to_lexical(mocker): + """Test fallback to lexical search when embeddings fail.""" + # Mock embed_query to return None (embeddings unavailable) + mocker.patch( + "backend.api.features.store.embeddings.embed_query", + return_value=None, + ) + + mock_lexical_results = [ + { + "content_id": "block-123", + "content_type": "BLOCK", + "searchable_text": "Calculator Block performs calculations", + "metadata": {}, + "similarity": 0.0, + } + ] + mocker.patch( + "backend.api.features.store.embeddings.query_raw_with_schema", + return_value=mock_lexical_results, + ) + + results = await semantic_search( + query="calculator", + content_types=[ContentType.BLOCK], + ) + + assert len(results) == 1 + assert results[0]["similarity"] == 0.0 # Lexical search returns 0 similarity + + +@pytest.mark.asyncio +async def test_search_empty_query(): + """Test that empty query returns no results.""" + results = await semantic_search(query="") + assert results == [] + + results = await semantic_search(query=" ") + assert results == [] + + +@pytest.mark.asyncio +async def test_search_with_user_id_filter(mocker): + """Test searching with user_id filter for private content.""" + mock_embedding = [0.1] * EMBEDDING_DIM + mocker.patch( + "backend.api.features.store.embeddings.embed_query", + return_value=mock_embedding, + ) + + mock_results = [ + { + "content_id": "agent-789", + "content_type": "LIBRARY_AGENT", + "searchable_text": "My Custom Agent", + "metadata": {}, + "similarity": 0.9, + } + ] + mocker.patch( + "backend.api.features.store.embeddings.query_raw_with_schema", + 
return_value=mock_results, + ) + + results = await semantic_search( + query="custom agent", + content_types=[ContentType.LIBRARY_AGENT], + user_id="user-123", + ) + + assert len(results) == 1 + assert results[0]["content_type"] == "LIBRARY_AGENT" + + +@pytest.mark.asyncio +async def test_search_limit_parameter(mocker): + """Test that limit parameter correctly limits results.""" + mock_embedding = [0.1] * EMBEDDING_DIM + mocker.patch( + "backend.api.features.store.embeddings.embed_query", + return_value=mock_embedding, + ) + + # Return 5 results + mock_results = [ + { + "content_id": f"block-{i}", + "content_type": "BLOCK", + "searchable_text": f"Block {i}", + "metadata": {}, + "similarity": 0.8, + } + for i in range(5) + ] + mocker.patch( + "backend.api.features.store.embeddings.query_raw_with_schema", + return_value=mock_results, + ) + + results = await semantic_search( + query="block", + content_types=[ContentType.BLOCK], + limit=5, + ) + + assert len(results) == 5 + + +@pytest.mark.asyncio +async def test_search_default_content_types(mocker): + """Test that default content_types includes BLOCK, STORE_AGENT, and DOCUMENTATION.""" + mock_embedding = [0.1] * EMBEDDING_DIM + mocker.patch( + "backend.api.features.store.embeddings.embed_query", + return_value=mock_embedding, + ) + + mock_query_raw = mocker.patch( + "backend.api.features.store.embeddings.query_raw_with_schema", + return_value=[], + ) + + await semantic_search(query="test") + + # Check that the SQL query includes all three default content types + call_args = mock_query_raw.call_args + assert "BLOCK" in str(call_args) + assert "STORE_AGENT" in str(call_args) + assert "DOCUMENTATION" in str(call_args) + + +@pytest.mark.asyncio +async def test_search_handles_database_error(mocker): + """Test that database errors are handled gracefully.""" + mock_embedding = [0.1] * EMBEDDING_DIM + mocker.patch( + "backend.api.features.store.embeddings.embed_query", + return_value=mock_embedding, + ) + + # Simulate database 
error + mocker.patch( + "backend.api.features.store.embeddings.query_raw_with_schema", + side_effect=Exception("Database connection failed"), + ) + + results = await semantic_search( + query="test", + content_types=[ContentType.BLOCK], + ) + + # Should return empty list on error + assert results == [] diff --git a/autogpt_platform/backend/backend/executor/database.py b/autogpt_platform/backend/backend/executor/database.py index f10b285450..ac381bbd67 100644 --- a/autogpt_platform/backend/backend/executor/database.py +++ b/autogpt_platform/backend/backend/executor/database.py @@ -9,6 +9,7 @@ from backend.api.features.library.db import ( from backend.api.features.store.db import get_store_agent_details, get_store_agents from backend.api.features.store.embeddings import ( backfill_missing_embeddings, + cleanup_orphaned_embeddings, get_embedding_stats, ) from backend.data import db @@ -221,6 +222,7 @@ class DatabaseManager(AppService): # Store Embeddings get_embedding_stats = _(get_embedding_stats) backfill_missing_embeddings = _(backfill_missing_embeddings) + cleanup_orphaned_embeddings = _(cleanup_orphaned_embeddings) # Summary data - async get_user_execution_summary_data = _(get_user_execution_summary_data) @@ -276,6 +278,7 @@ class DatabaseManagerClient(AppServiceClient): # Store Embeddings get_embedding_stats = _(d.get_embedding_stats) backfill_missing_embeddings = _(d.backfill_missing_embeddings) + cleanup_orphaned_embeddings = _(d.cleanup_orphaned_embeddings) class DatabaseManagerAsyncClient(AppServiceClient): diff --git a/autogpt_platform/backend/backend/executor/scheduler.py b/autogpt_platform/backend/backend/executor/scheduler.py index 3845c04ab6..6c50acfc07 100644 --- a/autogpt_platform/backend/backend/executor/scheduler.py +++ b/autogpt_platform/backend/backend/executor/scheduler.py @@ -28,6 +28,7 @@ from backend.data.auth.oauth import cleanup_expired_oauth_tokens from backend.data.block import BlockInput from backend.data.execution import 
GraphExecutionWithNodes from backend.data.model import CredentialsMetaInput +from backend.data.onboarding import increment_onboarding_runs from backend.executor import utils as execution_utils from backend.monitoring import ( NotificationJobArgs, @@ -156,6 +157,7 @@ async def _execute_graph(**kwargs): inputs=args.input_data, graph_credentials_inputs=args.input_credentials, ) + await increment_onboarding_runs(args.user_id) elapsed = asyncio.get_event_loop().time() - start_time logger.info( f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} " @@ -255,14 +257,14 @@ def execution_accuracy_alerts(): def ensure_embeddings_coverage(): """ - Ensure approved store agents have embeddings for hybrid search. + Ensure all content types (store agents, blocks, docs) have embeddings for search. - Processes ALL missing embeddings in batches of 10 until 100% coverage. - Missing embeddings = agents invisible in hybrid search. + Processes ALL missing embeddings in batches of 10 per content type until 100% coverage. + Missing embeddings = content invisible in hybrid search. Schedule: Runs every 6 hours (balanced between coverage and API costs). 
- - Catches agents approved between scheduled runs - - Batch size 10: gradual processing to avoid rate limits + - Catches new content added between scheduled runs + - Batch size 10 per content type: gradual processing to avoid rate limits - Manual trigger available via execute_ensure_embeddings_coverage endpoint """ db_client = get_database_manager_client() @@ -273,51 +275,91 @@ def ensure_embeddings_coverage(): logger.error( f"Failed to get embedding stats: {stats['error']} - skipping backfill" ) - return {"processed": 0, "success": 0, "failed": 0, "error": stats["error"]} + return { + "backfill": {"processed": 0, "success": 0, "failed": 0}, + "cleanup": {"deleted": 0}, + "error": stats["error"], + } - if stats["without_embeddings"] == 0: - logger.info("All approved agents have embeddings, skipping backfill") - return {"processed": 0, "success": 0, "failed": 0} - - logger.info( - f"Found {stats['without_embeddings']} agents without embeddings " - f"({stats['coverage_percent']}% coverage) - processing all" - ) + # Extract totals from new stats structure + totals = stats.get("totals", {}) + without_embeddings = totals.get("without_embeddings", 0) + coverage_percent = totals.get("coverage_percent", 0) total_processed = 0 total_success = 0 total_failed = 0 - # Process in batches until no more missing embeddings - while True: - result = db_client.backfill_missing_embeddings(batch_size=10) + if without_embeddings == 0: + logger.info("All content has embeddings, skipping backfill") + else: + # Log per-content-type stats for visibility + by_type = stats.get("by_type", {}) + for content_type, type_stats in by_type.items(): + if type_stats.get("without_embeddings", 0) > 0: + logger.info( + f"{content_type}: {type_stats['without_embeddings']} items without embeddings " + f"({type_stats['coverage_percent']}% coverage)" + ) - total_processed += result["processed"] - total_success += result["success"] - total_failed += result["failed"] + logger.info( + f"Total: 
{without_embeddings} items without embeddings " + f"({coverage_percent}% coverage) - processing all" + ) - if result["processed"] == 0: - # No more missing embeddings - break + # Process in batches until no more missing embeddings + while True: + result = db_client.backfill_missing_embeddings(batch_size=10) - if result["success"] == 0 and result["processed"] > 0: - # All attempts in this batch failed - stop to avoid infinite loop - logger.error( - f"All {result['processed']} embedding attempts failed - stopping backfill" - ) - break + total_processed += result["processed"] + total_success += result["success"] + total_failed += result["failed"] - # Small delay between batches to avoid rate limits - time.sleep(1) + if result["processed"] == 0: + # No more missing embeddings + break + + if result["success"] == 0 and result["processed"] > 0: + # All attempts in this batch failed - stop to avoid infinite loop + logger.error( + f"All {result['processed']} embedding attempts failed - stopping backfill" + ) + break + + # Small delay between batches to avoid rate limits + time.sleep(1) + + logger.info( + f"Embedding backfill completed: {total_success}/{total_processed} succeeded, " + f"{total_failed} failed" + ) + + # Clean up orphaned embeddings for blocks and docs + logger.info("Running cleanup for orphaned embeddings (blocks/docs)...") + cleanup_result = db_client.cleanup_orphaned_embeddings() + cleanup_totals = cleanup_result.get("totals", {}) + cleanup_deleted = cleanup_totals.get("deleted", 0) + + if cleanup_deleted > 0: + logger.info(f"Cleanup completed: deleted {cleanup_deleted} orphaned embeddings") + by_type = cleanup_result.get("by_type", {}) + for content_type, type_result in by_type.items(): + if type_result.get("deleted", 0) > 0: + logger.info( + f"{content_type}: deleted {type_result['deleted']} orphaned embeddings" + ) + else: + logger.info("Cleanup completed: no orphaned embeddings found") - logger.info( - f"Embedding backfill completed: 
{total_success}/{total_processed} succeeded, " - f"{total_failed} failed" - ) return { - "processed": total_processed, - "success": total_success, - "failed": total_failed, + "backfill": { + "processed": total_processed, + "success": total_success, + "failed": total_failed, + }, + "cleanup": { + "deleted": cleanup_deleted, + }, } diff --git a/autogpt_platform/backend/migrations/20260109181714_add_docs_embedding/migration.sql b/autogpt_platform/backend/migrations/20260109181714_add_docs_embedding/migration.sql index 9c4bcff5e1..0897e1865a 100644 --- a/autogpt_platform/backend/migrations/20260109181714_add_docs_embedding/migration.sql +++ b/autogpt_platform/backend/migrations/20260109181714_add_docs_embedding/migration.sql @@ -43,4 +43,6 @@ CREATE UNIQUE INDEX "UnifiedContentEmbedding_contentType_contentId_userId_key" O -- CreateIndex -- HNSW index for fast vector similarity search on embeddings -- Uses cosine distance operator (<=>), which matches the query in hybrid_search.py +-- Note: Drop first in case Prisma created a btree index (Prisma doesn't support HNSW) +DROP INDEX IF EXISTS "UnifiedContentEmbedding_embedding_idx"; CREATE INDEX "UnifiedContentEmbedding_embedding_idx" ON "UnifiedContentEmbedding" USING hnsw ("embedding" public.vector_cosine_ops); diff --git a/autogpt_platform/backend/migrations/20260115200000_add_unified_search_tsvector/migration.sql b/autogpt_platform/backend/migrations/20260115200000_add_unified_search_tsvector/migration.sql new file mode 100644 index 0000000000..be3062cbfb --- /dev/null +++ b/autogpt_platform/backend/migrations/20260115200000_add_unified_search_tsvector/migration.sql @@ -0,0 +1,35 @@ +-- Add tsvector search column to UnifiedContentEmbedding for unified full-text search +-- This enables hybrid search (semantic + lexical) across all content types + +-- Add search column (IF NOT EXISTS for idempotency) +ALTER TABLE "UnifiedContentEmbedding" ADD COLUMN IF NOT EXISTS "search" tsvector DEFAULT ''::tsvector; + +-- Create GIN 
index for fast full-text search +-- No @@index in schema.prisma - Prisma may generate DROP INDEX on migrate dev +-- If that happens, just let it drop and this migration will recreate it, or manually re-run: +-- CREATE INDEX IF NOT EXISTS "UnifiedContentEmbedding_search_idx" ON "UnifiedContentEmbedding" USING GIN ("search"); +DROP INDEX IF EXISTS "UnifiedContentEmbedding_search_idx"; +CREATE INDEX "UnifiedContentEmbedding_search_idx" ON "UnifiedContentEmbedding" USING GIN ("search"); + +-- Drop existing trigger/function if exists +DROP TRIGGER IF EXISTS "update_unified_tsvector" ON "UnifiedContentEmbedding"; +DROP FUNCTION IF EXISTS update_unified_tsvector_column(); + +-- Create function to auto-update tsvector from searchableText +CREATE OR REPLACE FUNCTION update_unified_tsvector_column() RETURNS TRIGGER AS $$ +BEGIN + NEW.search := to_tsvector('english', COALESCE(NEW."searchableText", '')); + RETURN NEW; +END; +$$ LANGUAGE plpgsql SECURITY DEFINER SET search_path = platform, pg_temp; + +-- Create trigger to auto-update search column on insert/update +CREATE TRIGGER "update_unified_tsvector" +BEFORE INSERT OR UPDATE ON "UnifiedContentEmbedding" +FOR EACH ROW +EXECUTE FUNCTION update_unified_tsvector_column(); + +-- Backfill existing rows +UPDATE "UnifiedContentEmbedding" +SET search = to_tsvector('english', COALESCE("searchableText", '')) +WHERE search IS NULL OR search = ''::tsvector; diff --git a/autogpt_platform/backend/migrations/20260115210000_remove_storelistingversion_search/migration.sql b/autogpt_platform/backend/migrations/20260115210000_remove_storelistingversion_search/migration.sql new file mode 100644 index 0000000000..4550340330 --- /dev/null +++ b/autogpt_platform/backend/migrations/20260115210000_remove_storelistingversion_search/migration.sql @@ -0,0 +1,90 @@ +-- Remove the old search column from StoreListingVersion +-- This column has been replaced by UnifiedContentEmbedding.search +-- which provides unified hybrid search across all content 
types + +-- First drop the dependent view +DROP VIEW IF EXISTS "StoreAgent"; + +-- Drop the trigger and function for old search column +-- The original trigger was created in 20251016093049_add_full_text_search +DROP TRIGGER IF EXISTS "update_tsvector" ON "StoreListingVersion"; +DROP FUNCTION IF EXISTS update_tsvector_column(); + +-- Drop the index +DROP INDEX IF EXISTS "StoreListingVersion_search_idx"; + +-- NOTE: Keeping search column for now to allow easy revert if needed +-- Uncomment to fully remove once migration is verified in production: +-- ALTER TABLE "StoreListingVersion" DROP COLUMN IF EXISTS "search"; + +-- Recreate the StoreAgent view WITHOUT the search column +-- (Search now handled by UnifiedContentEmbedding) +CREATE OR REPLACE VIEW "StoreAgent" AS +WITH latest_versions AS ( + SELECT + "storeListingId", + MAX(version) AS max_version + FROM "StoreListingVersion" + WHERE "submissionStatus" = 'APPROVED' + GROUP BY "storeListingId" +), +agent_versions AS ( + SELECT + "storeListingId", + array_agg(DISTINCT version::text ORDER BY version::text) AS versions + FROM "StoreListingVersion" + WHERE "submissionStatus" = 'APPROVED' + GROUP BY "storeListingId" +), +agent_graph_versions AS ( + SELECT + "storeListingId", + array_agg(DISTINCT "agentGraphVersion"::text ORDER BY "agentGraphVersion"::text) AS graph_versions + FROM "StoreListingVersion" + WHERE "submissionStatus" = 'APPROVED' + GROUP BY "storeListingId" +) +SELECT + sl.id AS listing_id, + slv.id AS "storeListingVersionId", + slv."createdAt" AS updated_at, + sl.slug, + COALESCE(slv.name, '') AS agent_name, + slv."videoUrl" AS agent_video, + slv."agentOutputDemoUrl" AS agent_output_demo, + COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image, + slv."isFeatured" AS featured, + p.username AS creator_username, + p."avatarUrl" AS creator_avatar, + slv."subHeading" AS sub_heading, + slv.description, + slv.categories, + COALESCE(ar.run_count, 0::bigint) AS runs, + COALESCE(rs.avg_rating, 0.0)::double 
precision AS rating, + COALESCE(av.versions, ARRAY[slv.version::text]) AS versions, + COALESCE(agv.graph_versions, ARRAY[slv."agentGraphVersion"::text]) AS "agentGraphVersions", + slv."agentGraphId", + slv."isAvailable" AS is_available, + COALESCE(sl."useForOnboarding", false) AS "useForOnboarding" +FROM "StoreListing" sl +JOIN latest_versions lv + ON sl.id = lv."storeListingId" +JOIN "StoreListingVersion" slv + ON slv."storeListingId" = lv."storeListingId" + AND slv.version = lv.max_version + AND slv."submissionStatus" = 'APPROVED' +JOIN "AgentGraph" a + ON slv."agentGraphId" = a.id + AND slv."agentGraphVersion" = a.version +LEFT JOIN "Profile" p + ON sl."owningUserId" = p."userId" +LEFT JOIN "mv_review_stats" rs + ON sl.id = rs."storeListingId" +LEFT JOIN "mv_agent_run_counts" ar + ON a.id = ar."agentGraphId" +LEFT JOIN agent_versions av + ON sl.id = av."storeListingId" +LEFT JOIN agent_graph_versions agv + ON sl.id = agv."storeListingId" +WHERE sl."isDeleted" = false + AND sl."hasApprovedVersion" = true; diff --git a/autogpt_platform/backend/schema.prisma b/autogpt_platform/backend/schema.prisma index 6e1a02f367..b7dc98524a 100644 --- a/autogpt_platform/backend/schema.prisma +++ b/autogpt_platform/backend/schema.prisma @@ -937,7 +937,7 @@ model StoreListingVersion { // Old versions can be made unavailable by the author if desired isAvailable Boolean @default(true) - search Unsupported("tsvector")? @default(dbgenerated("''::tsvector")) + // Note: search column removed - now using UnifiedContentEmbedding.search // Version workflow state submissionStatus SubmissionStatus @default(DRAFT) @@ -1002,6 +1002,7 @@ model UnifiedContentEmbedding { // Search data embedding Unsupported("vector(1536)") // pgvector embedding (extension in platform schema) searchableText String // Combined text for search and fallback + search Unsupported("tsvector")? 
@default(dbgenerated("''::tsvector")) // Full-text search (auto-populated by trigger) metadata Json @default("{}") // Content-specific metadata @@unique([contentType, contentId, userId], map: "UnifiedContentEmbedding_contentType_contentId_userId_key") @@ -1009,6 +1010,8 @@ model UnifiedContentEmbedding { @@index([userId]) @@index([contentType, userId]) @@index([embedding], map: "UnifiedContentEmbedding_embedding_idx") + // NO @@index for search - GIN index "UnifiedContentEmbedding_search_idx" created via SQL migration + // Prisma may generate DROP INDEX on migrate dev - that's okay, migration recreates it } model StoreListingReview { diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 6f9a87216b..776ba2321a 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -5621,6 +5621,69 @@ "security": [{ "HTTPBearerJWT": [] }] } }, + "/api/store/search": { + "get": { + "tags": ["v2", "store", "public"], + "summary": "Unified search across all content types", + "description": "Search across all content types (store agents, blocks, documentation) using hybrid search.\n\nCombines semantic (embedding-based) and lexical (text-based) search for best results.\n\nArgs:\n query: The search query string\n content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION)\n page: Page number for pagination (default 1)\n page_size: Number of results per page (default 20)\n user_id: Optional authenticated user ID (for user-scoped content in future)\n\nReturns:\n UnifiedSearchResponse: Paginated list of search results with relevance scores", + "operationId": "getV2Unified search across all content types", + "security": [{ "HTTPBearer": [] }], + "parameters": [ + { + "name": "query", + "in": "query", + "required": true, + "schema": { "type": "string", "title": "Query" } + }, + { + "name": "content_types", + "in": "query", + 
"required": false, + "schema": { + "anyOf": [ + { "type": "array", "items": { "type": "string" } }, + { "type": "null" } + ], + "description": "Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.", + "title": "Content Types" + }, + "description": "Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all." + }, + { + "name": "page", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 1, "title": "Page" } + }, + { + "name": "page_size", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 20, "title": "Page Size" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UnifiedSearchResponse" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, "/api/store/submissions": { "get": { "tags": ["v2", "store", "private"], @@ -10899,6 +10962,57 @@ "required": ["name", "graph_id", "graph_version", "trigger_config"], "title": "TriggeredPresetSetupRequest" }, + "UnifiedSearchResponse": { + "properties": { + "results": { + "items": { "$ref": "#/components/schemas/UnifiedSearchResult" }, + "type": "array", + "title": "Results" + }, + "pagination": { "$ref": "#/components/schemas/Pagination" } + }, + "type": "object", + "required": ["results", "pagination"], + "title": "UnifiedSearchResponse", + "description": "Response model for unified search across all content types." 
+ }, + "UnifiedSearchResult": { + "properties": { + "content_type": { "type": "string", "title": "Content Type" }, + "content_id": { "type": "string", "title": "Content Id" }, + "searchable_text": { "type": "string", "title": "Searchable Text" }, + "metadata": { + "anyOf": [ + { "additionalProperties": true, "type": "object" }, + { "type": "null" } + ], + "title": "Metadata" + }, + "updated_at": { + "anyOf": [ + { "type": "string", "format": "date-time" }, + { "type": "null" } + ], + "title": "Updated At" + }, + "combined_score": { + "anyOf": [{ "type": "number" }, { "type": "null" }], + "title": "Combined Score" + }, + "semantic_score": { + "anyOf": [{ "type": "number" }, { "type": "null" }], + "title": "Semantic Score" + }, + "lexical_score": { + "anyOf": [{ "type": "number" }, { "type": "null" }], + "title": "Lexical Score" + } + }, + "type": "object", + "required": ["content_type", "content_id", "searchable_text"], + "title": "UnifiedSearchResult", + "description": "A single result from unified hybrid search across all content types." 
+ }, "UpdateAppLogoRequest": { "properties": { "logo_url": { @@ -11863,6 +11977,7 @@ "in": "header", "name": "X-Postmark-Webhook-Token" }, + "HTTPBearer": { "type": "http", "scheme": "bearer" }, "HTTPBearerJWT": { "type": "http", "scheme": "bearer", From aa5a039c5e5f4b2e61658b1aee86a0e73b8f0ba1 Mon Sep 17 00:00:00 2001 From: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com> Date: Fri, 16 Jan 2026 16:40:21 +0530 Subject: [PATCH 03/32] feat(frontend): add special rendering for NOTE UI type in FieldTemplate (#11771) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Changes 🏗️ Added support for Note blocks in the FieldTemplate component by: - Importing the BlockUIType enum from the build components types - Extracting the uiType from the registry.formContext - Adding a conditional rendering check that returns children directly when the uiType is BlockUIType.NOTE This change allows Note blocks to render without the standard field template wrapper, providing a cleaner display for note-type content. 
![Screenshot 2026-01-15 at 1.01.03 PM.png](https://app.graphite.com/user-attachments/assets/7d654eed-abbe-4ec3-9c80-24a77a8373e3.png) ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Created a Note block and verified it renders correctly without field template wrapper - [x] Confirmed other block types still render with proper field template - [x] Verified that Note blocks maintain proper functionality in the node graph --- .../build/components/FlowEditor/nodes/OutputHandler.tsx | 2 -- .../InputRenderer/base/standard/FieldTemplate.tsx | 7 ++++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/OutputHandler.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/OutputHandler.tsx index 417fb9d9c1..665a5bb7be 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/OutputHandler.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/OutputHandler.tsx @@ -31,8 +31,6 @@ export const OutputHandler = ({ const [isOutputVisible, setIsOutputVisible] = useState(true); const brokenOutputs = useBrokenOutputs(nodeId); - console.log("brokenOutputs", brokenOutputs); - const showHandles = uiType !== BlockUIType.OUTPUT; const renderOutputHandles = ( diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/base/standard/FieldTemplate.tsx b/autogpt_platform/frontend/src/components/renderers/InputRenderer/base/standard/FieldTemplate.tsx index 56dfcefa71..705afbcaaa 100644 --- a/autogpt_platform/frontend/src/components/renderers/InputRenderer/base/standard/FieldTemplate.tsx +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/base/standard/FieldTemplate.tsx @@ -17,6 +17,7 @@ import { import { useNodeStore } from 
"@/app/(platform)/build/stores/nodeStore"; import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore"; import { FieldError } from "./FieldError"; +import { BlockUIType } from "@/app/(platform)/build/components/types"; export default function FieldTemplate(props: FieldTemplateProps) { const { @@ -39,7 +40,7 @@ export default function FieldTemplate(props: FieldTemplateProps) { onRemoveProperty, readonly, } = props; - const { nodeId } = registry.formContext; + const { nodeId, uiType } = registry.formContext; const { isInputConnected } = useEdgeStore(); const showAdvanced = useNodeStore( @@ -50,6 +51,10 @@ export default function FieldTemplate(props: FieldTemplateProps) { return
{children}
; } + if (uiType === BlockUIType.NOTE) { + return children; + } + const uiOptions = getUiOptions(uiSchema); const TitleFieldTemplate = getTemplate( "TitleFieldTemplate", From 8b1720e61d62013cf3c451f644a235ae8eaf6860 Mon Sep 17 00:00:00 2001 From: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com> Date: Fri, 16 Jan 2026 16:44:00 +0530 Subject: [PATCH 04/32] feat(frontend): improve graph validation error handling and node navigation (#11779) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Changes 🏗️ - Enhanced error handling for graph validation failures with detailed user feedback - Added automatic viewport navigation to the first node with errors when validation fails - Improved node title display to prioritize agent_name from hardcoded values - Removed console.log debugging statement from OutputHandler - Added ApiError import and improved error type handling - Reorganized imports for better code organization ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Create a graph with intentional validation errors and verify error messages display correctly - [x] Verify the viewport automatically navigates to the first node with errors - [x] Check that node titles correctly display customized names or agent names - [x] Test error recovery by fixing validation errors and successfully running the graph --- .../RunInputDialog/useRunInputDialog.ts | 69 +++++++++++++++++-- .../CustomNode/components/NodeHeader.tsx | 10 +-- 2 files changed, 68 insertions(+), 11 deletions(-) diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts index ddd77bae48..f19cb96205 100644 --- 
a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts @@ -1,7 +1,8 @@ import { useGraphStore } from "@/app/(platform)/build/stores/graphStore"; import { usePostV1ExecuteGraphAgent } from "@/app/api/__generated__/endpoints/graphs/graphs"; -import { useToast } from "@/components/molecules/Toast/use-toast"; + import { + ApiError, CredentialsMetaInput, GraphExecutionMeta, } from "@/lib/autogpt-server-api"; @@ -9,6 +10,9 @@ import { parseAsInteger, parseAsString, useQueryStates } from "nuqs"; import { useMemo, useState } from "react"; import { uiSchema } from "../../../FlowEditor/nodes/uiSchema"; import { isCredentialFieldSchema } from "@/components/renderers/InputRenderer/custom/CredentialField/helpers"; +import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore"; +import { useToast } from "@/components/molecules/Toast/use-toast"; +import { useReactFlow } from "@xyflow/react"; export const useRunInputDialog = ({ setIsOpen, @@ -31,6 +35,7 @@ export const useRunInputDialog = ({ flowVersion: parseAsInteger, }); const { toast } = useToast(); + const { setViewport } = useReactFlow(); const { mutateAsync: executeGraph, isPending: isExecutingGraph } = usePostV1ExecuteGraphAgent({ @@ -42,13 +47,63 @@ export const useRunInputDialog = ({ }); }, onError: (error) => { - // Reset running state on error + if (error instanceof ApiError && error.isGraphValidationError()) { + const errorData = error.response?.detail; + Object.entries(errorData.node_errors).forEach( + ([nodeId, nodeErrors]) => { + useNodeStore + .getState() + .updateNodeErrors( + nodeId, + nodeErrors as { [key: string]: string }, + ); + }, + ); + toast({ + title: errorData?.message || "Graph validation failed", + description: + "Please fix the validation errors on the highlighted nodes and try again.", + 
variant: "destructive", + }); + setIsOpen(false); + + const firstBackendId = Object.keys(errorData.node_errors)[0]; + + if (firstBackendId) { + const firstErrorNode = useNodeStore + .getState() + .nodes.find( + (n) => + n.data.metadata?.backend_id === firstBackendId || + n.id === firstBackendId, + ); + + if (firstErrorNode) { + setTimeout(() => { + setViewport( + { + x: + -firstErrorNode.position.x * 0.8 + + window.innerWidth / 2 - + 150, + y: -firstErrorNode.position.y * 0.8 + 50, + zoom: 0.8, + }, + { duration: 500 }, + ); + }, 50); + } + } + } else { + toast({ + title: "Error running graph", + description: + (error as Error).message || "An unexpected error occurred.", + variant: "destructive", + }); + setIsOpen(false); + } setIsGraphRunning(false); - toast({ - title: (error.detail as string) ?? "An unexpected error occurred.", - description: "An unexpected error occurred.", - variant: "destructive", - }); }, }, }); diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx index e13aa37a31..d9f3d108f4 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx @@ -20,11 +20,13 @@ type Props = { export const NodeHeader = ({ data, nodeId }: Props) => { const updateNodeData = useNodeStore((state) => state.updateNodeData); - const title = (data.metadata?.customized_name as string) || data.title; + const title = + (data.metadata?.customized_name as string) || + data.hardcodedValues.agent_name || + data.title; + const [isEditingTitle, setIsEditingTitle] = useState(false); - const [editedTitle, setEditedTitle] = useState( - beautifyString(title).replace("Block", "").trim(), - ); + const 
[editedTitle, setEditedTitle] = useState(title); const handleTitleEdit = () => { updateNodeData(nodeId, { From b08851f5d77ab63cc2d55b11968c9ed2389ae09a Mon Sep 17 00:00:00 2001 From: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com> Date: Fri, 16 Jan 2026 18:32:36 +0530 Subject: [PATCH 05/32] feat(frontend): improve GoogleDrivePickerField with input mode support and array field spacing (#11780) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Changes 🏗️ - Added a placeholder UI for Google Drive Picker in INPUT block type - Improved detection of Google Drive file objects in schema validation - Extracted `isGoogleDrivePickerSchema` function for better code organization - Added spacing between array field elements with a gap-2 class - Added debug logging for preprocessed schema in FormRenderer ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Verified Google Drive Picker shows placeholder in INPUT blocks - [x] Confirmed array field elements have proper spacing - [x] Tested that Google Drive file objects are properly detected --- .../renderers/InputRenderer/FormRenderer.tsx | 2 ++ .../base/array/ArrayFieldTemplate.tsx | 2 +- .../GoogleDrivePickerField.tsx | 13 ++++++- .../InputRenderer/custom/custom-registry.ts | 12 +++---- .../InputRenderer/utils/schema-utils.ts | 35 +++++++++++++++++++ 5 files changed, 55 insertions(+), 9 deletions(-) diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/FormRenderer.tsx b/autogpt_platform/frontend/src/components/renderers/InputRenderer/FormRenderer.tsx index c3a20d8cd2..fc388cc343 100644 --- a/autogpt_platform/frontend/src/components/renderers/InputRenderer/FormRenderer.tsx +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/FormRenderer.tsx @@ -30,6 +30,8 @@ export const FormRenderer = ({ 
return generateUiSchemaForCustomFields(preprocessedSchema, uiSchema); }, [preprocessedSchema, uiSchema]); + console.log("preprocessedSchema", preprocessedSchema); + return (
{!fromAnyOf && ( -
+
{ - const { schema, uiSchema, onChange, fieldPathId, formData } = props; + const { schema, uiSchema, onChange, fieldPathId, formData, registry } = props; const uiOptions = getUiOptions(uiSchema); const config: GoogleDrivePickerConfig = schema.google_drive_picker_config; + const uiType = registry.formContext?.uiType; + + if (uiType === BlockUIType.INPUT) { + return ( +
+ Select files when you run the graph +
+ ); + } + return (
{ - return ( - "google_drive_picker_config" in schema || - ("format" in schema && schema.format === "google-drive-picker") - ); - }, + matcher: isGoogleDrivePickerSchema, component: GoogleDrivePickerField, }, { diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/utils/schema-utils.ts b/autogpt_platform/frontend/src/components/renderers/InputRenderer/utils/schema-utils.ts index fecf2d77d1..c7dd9680d7 100644 --- a/autogpt_platform/frontend/src/components/renderers/InputRenderer/utils/schema-utils.ts +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/utils/schema-utils.ts @@ -55,3 +55,38 @@ export function isMultiSelectSchema(schema: RJSFSchema | undefined): boolean { ) ); } + +const isGoogleDriveFileObject = (obj: RJSFSchema): boolean => { + if (obj.type !== "object" || !obj.properties) { + return false; + } + const props = obj.properties; + const hasId = "id" in props; + const hasMimeType = "mimeType" in props || "mime_type" in props; + const hasIconUrl = "iconUrl" in props || "icon_url" in props; + const hasIsFolder = "isFolder" in props || "is_folder" in props; + return hasId && hasMimeType && (hasIconUrl || hasIsFolder); +}; + +export const isGoogleDrivePickerSchema = ( + schema: RJSFSchema | undefined, +): boolean => { + if (!schema) { + return false; + } + + // highest priority + if ( + "google_drive_picker_config" in schema || + ("format" in schema && schema.format === "google-drive-picker") + ) { + return true; + } + + // In the Input type block, we do not add the format for the GoogleFile field, so we need to include this extra check. 
+ if (isGoogleDriveFileObject(schema)) { + return true; + } + + return false; +}; From ec03a13e263c05e4426f63d353613460060f2b56 Mon Sep 17 00:00:00 2001 From: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com> Date: Fri, 16 Jan 2026 19:04:57 +0530 Subject: [PATCH 06/32] fix(frontend): improve history tracking, error handling (#11786) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Changes 🏗️ - **Improved Error Handling**: Enhanced error handling in `useRunInputDialog.ts` to properly handle cases where node errors are empty or undefined - **Fixed Node Collision Resolution**: Updated `Flow.tsx` to use the current state from the store instead of stale props - **Enhanced History Management**: - Added proper state tracking for edge removal operations - Improved undo/redo functionality to prevent duplicate states - Fixed edge case where history wasn't properly tracked during node dragging - **UI Improvements**: - Fixed potential null reference in NodeHeader when accessing agent_name - Added placeholder for GoogleDrivePicker in INPUT mode - Fixed spacing in ArrayFieldTemplate - **Bug Fixes**: - Added proper state tracking before modifying nodes/edges - Fixed history tracking to avoid redundant states - Improved collision detection and resolution ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Test undo/redo functionality after adding, removing, and moving nodes - [x] Test edge creation and deletion with history tracking - [x] Verify error handling when graph validation fails - [x] Test Google Drive picker in different UI modes - [x] Verify node collision resolution works correctly --- .../RunInputDialog/useRunInputDialog.ts | 36 ++++++--- .../build/components/FlowEditor/Flow/Flow.tsx | 6 +- .../FlowEditor/edges/useCustomEdge.ts | 14 ++++
.../CustomNode/components/NodeHeader.tsx | 2 +- .../app/(platform)/build/stores/edgeStore.ts | 39 ++++++---- .../(platform)/build/stores/historyStore.ts | 48 ++++++++++-- .../app/(platform)/build/stores/nodeStore.ts | 77 ++++++++++++------- 7 files changed, 161 insertions(+), 61 deletions(-) diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts index f19cb96205..358fd3ae7e 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts @@ -48,17 +48,29 @@ export const useRunInputDialog = ({ }, onError: (error) => { if (error instanceof ApiError && error.isGraphValidationError()) { - const errorData = error.response?.detail; - Object.entries(errorData.node_errors).forEach( - ([nodeId, nodeErrors]) => { - useNodeStore - .getState() - .updateNodeErrors( - nodeId, - nodeErrors as { [key: string]: string }, - ); - }, - ); + const errorData = error.response?.detail || { + node_errors: {}, + message: undefined, + }; + const nodeErrors = errorData.node_errors || {}; + + if (Object.keys(nodeErrors).length > 0) { + Object.entries(nodeErrors).forEach( + ([nodeId, nodeErrorsForNode]) => { + useNodeStore + .getState() + .updateNodeErrors( + nodeId, + nodeErrorsForNode as { [key: string]: string }, + ); + }, + ); + } else { + useNodeStore.getState().nodes.forEach((node) => { + useNodeStore.getState().updateNodeErrors(node.id, {}); + }); + } + toast({ title: errorData?.message || "Graph validation failed", description: @@ -67,7 +79,7 @@ export const useRunInputDialog = ({ }); setIsOpen(false); - const firstBackendId = Object.keys(errorData.node_errors)[0]; + const 
firstBackendId = Object.keys(nodeErrors)[0]; if (firstBackendId) { const firstErrorNode = useNodeStore diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx index 29fd984b1d..87ae4300b8 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx @@ -55,14 +55,16 @@ export const Flow = () => { const edgeTypes = useMemo(() => ({ custom: CustomEdge }), []); const onNodeDragStop = useCallback(() => { + const currentNodes = useNodeStore.getState().nodes; setNodes( - resolveCollisions(nodes, { + resolveCollisions(currentNodes, { maxIterations: Infinity, overlapThreshold: 0.5, margin: 15, }), ); - }, [setNodes, nodes]); + }, [setNodes]); + const { edges, onConnect, onEdgesChange } = useCustomEdge(); // for loading purpose diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/edges/useCustomEdge.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/edges/useCustomEdge.ts index bf4ba3a418..d8571749d3 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/edges/useCustomEdge.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/edges/useCustomEdge.ts @@ -6,6 +6,7 @@ import { import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore"; import { useCallback } from "react"; import { useNodeStore } from "../../../stores/nodeStore"; +import { useHistoryStore } from "../../../stores/historyStore"; import { CustomEdge } from "./CustomEdge"; export const useCustomEdge = () => { @@ -51,7 +52,20 @@ export const useCustomEdge = () => { const onEdgesChange = useCallback( (changes: EdgeChange[]) => { + const hasRemoval = changes.some((change) => change.type === "remove"); + + const prevState = 
hasRemoval + ? { + nodes: useNodeStore.getState().nodes, + edges: edges, + } + : null; + setEdges(applyEdgeChanges(changes, edges)); + + if (prevState) { + useHistoryStore.getState().pushState(prevState); + } }, [edges, setEdges], ); diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx index d9f3d108f4..c4659b8dcf 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx @@ -22,7 +22,7 @@ export const NodeHeader = ({ data, nodeId }: Props) => { const updateNodeData = useNodeStore((state) => state.updateNodeData); const title = (data.metadata?.customized_name as string) || - data.hardcodedValues.agent_name || + data.hardcodedValues?.agent_name || data.title; const [isEditingTitle, setIsEditingTitle] = useState(false); diff --git a/autogpt_platform/frontend/src/app/(platform)/build/stores/edgeStore.ts b/autogpt_platform/frontend/src/app/(platform)/build/stores/edgeStore.ts index 7b17eecfb3..6a45b9e1e2 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/stores/edgeStore.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/stores/edgeStore.ts @@ -5,6 +5,8 @@ import { customEdgeToLink, linkToCustomEdge } from "../components/helper"; import { MarkerType } from "@xyflow/react"; import { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult"; import { cleanUpHandleId } from "@/components/renderers/InputRenderer/helpers"; +import { useHistoryStore } from "./historyStore"; +import { useNodeStore } from "./nodeStore"; type EdgeStore = { edges: CustomEdge[]; @@ -53,25 +55,36 @@ export const useEdgeStore = create((set, get) => ({ id, }; - set((state) => { 
- const exists = state.edges.some( - (e) => - e.source === newEdge.source && - e.target === newEdge.target && - e.sourceHandle === newEdge.sourceHandle && - e.targetHandle === newEdge.targetHandle, - ); - if (exists) return state; - return { edges: [...state.edges, newEdge] }; - }); + const exists = get().edges.some( + (e) => + e.source === newEdge.source && + e.target === newEdge.target && + e.sourceHandle === newEdge.sourceHandle && + e.targetHandle === newEdge.targetHandle, + ); + if (exists) return newEdge; + const prevState = { + nodes: useNodeStore.getState().nodes, + edges: get().edges, + }; + + set((state) => ({ edges: [...state.edges, newEdge] })); + useHistoryStore.getState().pushState(prevState); return newEdge; }, - removeEdge: (edgeId) => + removeEdge: (edgeId) => { + const prevState = { + nodes: useNodeStore.getState().nodes, + edges: get().edges, + }; + set((state) => ({ edges: state.edges.filter((e) => e.id !== edgeId), - })), + })); + useHistoryStore.getState().pushState(prevState); + }, upsertMany: (edges) => set((state) => { diff --git a/autogpt_platform/frontend/src/app/(platform)/build/stores/historyStore.ts b/autogpt_platform/frontend/src/app/(platform)/build/stores/historyStore.ts index 4eea5741a4..3a67bb8dcd 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/stores/historyStore.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/stores/historyStore.ts @@ -37,6 +37,15 @@ export const useHistoryStore = create((set, get) => ({ return; } + const actualCurrentState = { + nodes: useNodeStore.getState().nodes, + edges: useEdgeStore.getState().edges, + }; + + if (isEqual(state, actualCurrentState)) { + return; + } + set((prev) => ({ past: [...prev.past.slice(-MAX_HISTORY + 1), state], future: [], @@ -55,18 +64,25 @@ export const useHistoryStore = create((set, get) => ({ undo: () => { const { past, future } = get(); - if (past.length <= 1) return; + if (past.length === 0) return; - const currentState = past[past.length - 1]; + 
const actualCurrentState = { + nodes: useNodeStore.getState().nodes, + edges: useEdgeStore.getState().edges, + }; - const previousState = past[past.length - 2]; + const previousState = past[past.length - 1]; + + if (isEqual(actualCurrentState, previousState)) { + return; + } useNodeStore.getState().setNodes(previousState.nodes); useEdgeStore.getState().setEdges(previousState.edges); set({ - past: past.slice(0, -1), - future: [currentState, ...future], + past: past.length > 1 ? past.slice(0, -1) : past, + future: [actualCurrentState, ...future], }); }, @@ -74,18 +90,36 @@ export const useHistoryStore = create((set, get) => ({ const { past, future } = get(); if (future.length === 0) return; + const actualCurrentState = { + nodes: useNodeStore.getState().nodes, + edges: useEdgeStore.getState().edges, + }; + const nextState = future[0]; useNodeStore.getState().setNodes(nextState.nodes); useEdgeStore.getState().setEdges(nextState.edges); + const lastPast = past[past.length - 1]; + const shouldPushToPast = + !lastPast || !isEqual(actualCurrentState, lastPast); + set({ - past: [...past, nextState], + past: shouldPushToPast ? 
[...past, actualCurrentState] : past, future: future.slice(1), }); }, - canUndo: () => get().past.length > 1, + canUndo: () => { + const { past } = get(); + if (past.length === 0) return false; + + const actualCurrentState = { + nodes: useNodeStore.getState().nodes, + edges: useEdgeStore.getState().edges, + }; + return !isEqual(actualCurrentState, past[past.length - 1]); + }, canRedo: () => get().future.length > 0, clear: () => set({ past: [{ nodes: [], edges: [] }], future: [] }), diff --git a/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts b/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts index cb41da9463..5502a8780d 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts @@ -1,6 +1,7 @@ import { create } from "zustand"; import { NodeChange, XYPosition, applyNodeChanges } from "@xyflow/react"; import { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode"; +import { CustomEdge } from "../components/FlowEditor/edges/CustomEdge"; import { BlockInfo } from "@/app/api/__generated__/models/blockInfo"; import { convertBlockInfoIntoCustomNodeData, @@ -44,6 +45,8 @@ const MINIMUM_MOVE_BEFORE_LOG = 50; // Track initial positions when drag starts (outside store to avoid re-renders) const dragStartPositions: Record = {}; +let dragStartState: { nodes: CustomNode[]; edges: CustomEdge[] } | null = null; + type NodeStore = { nodes: CustomNode[]; nodeCounter: number; @@ -124,14 +127,20 @@ export const useNodeStore = create((set, get) => ({ nodeCounter: state.nodeCounter + 1, })), onNodesChange: (changes) => { - const prevState = { - nodes: get().nodes, - edges: useEdgeStore.getState().edges, - }; - - // Track initial positions when drag starts changes.forEach((change) => { if (change.type === "position" && change.dragging === true) { + if (!dragStartState) { + const currentNodes = get().nodes; + const 
currentEdges = useEdgeStore.getState().edges; + dragStartState = { + nodes: currentNodes.map((n) => ({ + ...n, + position: { ...n.position }, + data: { ...n.data }, + })), + edges: currentEdges.map((e) => ({ ...e })), + }; + } if (!dragStartPositions[change.id]) { const node = get().nodes.find((n) => n.id === change.id); if (node) { @@ -141,12 +150,17 @@ export const useNodeStore = create((set, get) => ({ } }); - // Check if we should track this change in history - let shouldTrack = changes.some( - (change) => change.type === "remove" || change.type === "add", - ); + let shouldTrack = changes.some((change) => change.type === "remove"); + let stateToTrack: { nodes: CustomNode[]; edges: CustomEdge[] } | null = + null; + + if (shouldTrack) { + stateToTrack = { + nodes: get().nodes, + edges: useEdgeStore.getState().edges, + }; + } - // For position changes, only track if movement exceeds threshold if (!shouldTrack) { changes.forEach((change) => { if (change.type === "position" && change.dragging === false) { @@ -158,20 +172,23 @@ export const useNodeStore = create((set, get) => ({ ); if (distanceMoved > MINIMUM_MOVE_BEFORE_LOG) { shouldTrack = true; + stateToTrack = dragStartState; } } - // Clean up tracked position after drag ends delete dragStartPositions[change.id]; } }); + if (Object.keys(dragStartPositions).length === 0) { + dragStartState = null; + } } set((state) => ({ nodes: applyNodeChanges(changes, state.nodes), })); - if (shouldTrack) { - useHistoryStore.getState().pushState(prevState); + if (shouldTrack && stateToTrack) { + useHistoryStore.getState().pushState(stateToTrack); } }, @@ -185,6 +202,11 @@ export const useNodeStore = create((set, get) => ({ hardcodedValues?: Record, position?: XYPosition, ) => { + const prevState = { + nodes: get().nodes, + edges: useEdgeStore.getState().edges, + }; + const customNodeData = convertBlockInfoIntoCustomNodeData( block, hardcodedValues, @@ -218,21 +240,24 @@ export const useNodeStore = create((set, get) => ({ 
set((state) => ({ nodes: [...state.nodes, customNode], })); + + useHistoryStore.getState().pushState(prevState); + return customNode; }, updateNodeData: (nodeId, data) => { + const prevState = { + nodes: get().nodes, + edges: useEdgeStore.getState().edges, + }; + set((state) => ({ nodes: state.nodes.map((n) => n.id === nodeId ? { ...n, data: { ...n.data, ...data } } : n, ), })); - const newState = { - nodes: get().nodes, - edges: useEdgeStore.getState().edges, - }; - - useHistoryStore.getState().pushState(newState); + useHistoryStore.getState().pushState(prevState); }, toggleAdvanced: (nodeId: string) => set((state) => ({ @@ -391,6 +416,11 @@ export const useNodeStore = create((set, get) => ({ }, setCredentialsOptional: (nodeId: string, optional: boolean) => { + const prevState = { + nodes: get().nodes, + edges: useEdgeStore.getState().edges, + }; + set((state) => ({ nodes: state.nodes.map((n) => n.id === nodeId @@ -408,12 +438,7 @@ export const useNodeStore = create((set, get) => ({ ), })); - const newState = { - nodes: get().nodes, - edges: useEdgeStore.getState().edges, - }; - - useHistoryStore.getState().pushState(newState); + useHistoryStore.getState().pushState(prevState); }, // Sub-agent resolution mode state From 5ff669e9996c3825fd8929db9521839575b35396 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Fri, 16 Jan 2026 08:28:36 -0600 Subject: [PATCH 07/32] fix(backend): Make Redis connection lazy in cache module (#11775) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Makes Redis connection lazy in the cache module - connection is only established when `shared_cache=True` is actually used - Fixes DatabaseManager failing to start because it imports `onboarding.py` which imports `cache.py`, triggering Redis connection at module load time even though it only uses in-memory caching ## Root Cause Commit `b01ea3fcb` (merged today) added `increment_onboarding_runs` to DatabaseManager, which imports from 
`onboarding.py`. That module imports `@cached` decorator from `cache.py`, which was creating a Redis connection at module import time: ```python # Old code - ran at import time! redis = Redis(connection_pool=_get_cache_pool()) ``` Since `onboarding.py` only uses `@cached(shared_cache=False)` (in-memory caching), it doesn't actually need Redis. But the import triggered the connection attempt. ## Changes - Wrapped Redis connection in a singleton class with lazy initialization - Connection is only established when `_get_redis()` is first called (i.e., when `shared_cache=True` is used) - Services using only in-memory caching can now import `cache.py` without Redis configuration ## Test plan - [ ] Services using `shared_cache=False` work without Redis configured - [ ] Services using `shared_cache=True` still work correctly with Redis - [ ] Existing cache tests pass 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.5 --- .../backend/backend/util/cache.py | 71 +++++++++++-------- 1 file changed, 40 insertions(+), 31 deletions(-) diff --git a/autogpt_platform/backend/backend/util/cache.py b/autogpt_platform/backend/backend/util/cache.py index 757ba45b42..38c3667951 100644 --- a/autogpt_platform/backend/backend/util/cache.py +++ b/autogpt_platform/backend/backend/util/cache.py @@ -16,7 +16,7 @@ import pickle import threading import time from dataclasses import dataclass -from functools import wraps +from functools import cache, wraps from typing import Any, Callable, ParamSpec, Protocol, TypeVar, cast, runtime_checkable from redis import ConnectionPool, Redis @@ -38,29 +38,34 @@ settings = Settings() # maxmemory 2gb # Set memory limit (adjust based on your needs) # save "" # Disable persistence if using Redis purely for caching -# Create a dedicated Redis connection pool for caching (binary mode for pickle) -_cache_pool: ConnectionPool | None = None - -@conn_retry("Redis", "Acquiring cache connection pool") +@cache def 
_get_cache_pool() -> ConnectionPool: - """Get or create a connection pool for cache operations.""" - global _cache_pool - if _cache_pool is None: - _cache_pool = ConnectionPool( - host=settings.config.redis_host, - port=settings.config.redis_port, - password=settings.config.redis_password or None, - decode_responses=False, # Binary mode for pickle - max_connections=50, - socket_keepalive=True, - socket_connect_timeout=5, - retry_on_timeout=True, - ) - return _cache_pool + """Get or create a connection pool for cache operations (lazy, thread-safe).""" + return ConnectionPool( + host=settings.config.redis_host, + port=settings.config.redis_port, + password=settings.config.redis_password or None, + decode_responses=False, # Binary mode for pickle + max_connections=50, + socket_keepalive=True, + socket_connect_timeout=5, + retry_on_timeout=True, + ) -redis = Redis(connection_pool=_get_cache_pool()) +@cache +@conn_retry("Redis", "Acquiring cache connection") +def _get_redis() -> Redis: + """ + Get the lazily-initialized Redis client for shared cache operations. + Uses @cache for thread-safe singleton behavior - connection is only + established when first accessed, allowing services that only use + in-memory caching to work without Redis configuration. 
+ """ + r = Redis(connection_pool=_get_cache_pool()) + r.ping() # Verify connection + return r @dataclass @@ -179,9 +184,9 @@ def cached( try: if refresh_ttl_on_get: # Use GETEX to get value and refresh expiry atomically - cached_bytes = redis.getex(redis_key, ex=ttl_seconds) + cached_bytes = _get_redis().getex(redis_key, ex=ttl_seconds) else: - cached_bytes = redis.get(redis_key) + cached_bytes = _get_redis().get(redis_key) if cached_bytes and isinstance(cached_bytes, bytes): return pickle.loads(cached_bytes) @@ -195,7 +200,7 @@ def cached( """Set value in Redis with TTL.""" try: pickled_value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL) - redis.setex(redis_key, ttl_seconds, pickled_value) + _get_redis().setex(redis_key, ttl_seconds, pickled_value) except Exception as e: logger.error( f"Redis error storing cache for {target_func.__name__}: {e}" @@ -333,14 +338,18 @@ def cached( if pattern: # Clear entries matching pattern keys = list( - redis.scan_iter(f"cache:{target_func.__name__}:{pattern}") + _get_redis().scan_iter( + f"cache:{target_func.__name__}:{pattern}" + ) ) else: # Clear all cache keys - keys = list(redis.scan_iter(f"cache:{target_func.__name__}:*")) + keys = list( + _get_redis().scan_iter(f"cache:{target_func.__name__}:*") + ) if keys: - pipeline = redis.pipeline() + pipeline = _get_redis().pipeline() for key in keys: pipeline.delete(key) pipeline.execute() @@ -355,7 +364,9 @@ def cached( def cache_info() -> dict[str, int | None]: if shared_cache: - cache_keys = list(redis.scan_iter(f"cache:{target_func.__name__}:*")) + cache_keys = list( + _get_redis().scan_iter(f"cache:{target_func.__name__}:*") + ) return { "size": len(cache_keys), "maxsize": None, # Redis manages its own size @@ -373,10 +384,8 @@ def cached( key = _make_hashable_key(args, kwargs) if shared_cache: redis_key = _make_redis_key(key, target_func.__name__) - if redis.exists(redis_key): - redis.delete(redis_key) - return True - return False + deleted_count = cast(int, 
_get_redis().delete(redis_key)) + return deleted_count > 0 else: if key in cache_storage: del cache_storage[key] From 4a9b13acb6bce3e94be0f23938dadb78bb8c47ef Mon Sep 17 00:00:00 2001 From: Swifty Date: Fri, 16 Jan 2026 16:15:39 +0100 Subject: [PATCH 08/32] feat(frontend): extract frontend changes from hackathon/copilot branch (#11717) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Frontend changes extracted from the hackathon/copilot branch for the copilot feature development. ### Changes 🏗️ - New Chat system with contextual components (`Chat`, `ChatDrawer`, `ChatContainer`, `ChatMessage`, etc.) - Form renderer system with RJSF v6 integration and new input renderers - Enhanced credentials management with improved OAuth flow and credential selection - New output renderers for various content types (Code, Image, JSON, Markdown, Text, Video) - Scrollable tabs component for better UI organization - Marketplace update notifications and publishing workflow improvements - Draft recovery feature with IndexedDB persistence - Safe mode toggle functionality - Various UI/UX improvements across the platform ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [ ] Test new Chat components functionality - [ ] Verify form renderer with various input types - [ ] Test credential management flows - [ ] Verify output renderers display correctly - [ ] Test draft recovery feature #### For configuration changes: - [x] `.env.default` is updated or already compatible with my changes - [x] `docker-compose.yml` is updated or already compatible with my changes - [x] I have included a list of my configuration changes in the PR description (under **Changes**) --------- Co-authored-by: Lluis Agusti --- .../AgentOnboardingCredentials.tsx | 2 +- .../app/(no-navbar)/onboarding/5-run/page.tsx | 2 +- 
.../auth/integrations/setup-wizard/page.tsx | 2 +- .../components/AgentOutputs/AgentOutputs.tsx | 10 +- .../NodeOutput/components/ContentRenderer.tsx | 4 +- .../NodeDataViewer/NodeDataViewer.tsx | 8 +- .../NodeDataViewer/useNodeDataViewer.ts | 6 +- .../components/WebhookDisclaimer.tsx | 10 +- .../legacy-builder/ExpandableOutputDialog.tsx | 4 +- .../components/legacy-builder/NodeInputs.tsx | 2 +- .../(platform)/chat/components/Chat/Chat.tsx | 134 +++++ .../AgentCarouselMessage.tsx | 56 ++- .../AgentInputsSetup/AgentInputsSetup.tsx | 246 ++++++++++ .../AgentInputsSetup/useAgentInputsSetup.ts | 38 ++ .../AuthPromptWidget/AuthPromptWidget.tsx | 23 +- .../ChatContainer/ChatContainer.tsx | 88 ++++ .../createStreamEventDispatcher.ts | 8 +- .../components}/ChatContainer/helpers.ts | 119 ++++- .../useChatContainer.handlers.ts | 31 +- .../ChatContainer/useChatContainer.ts | 206 ++++++++ .../ChatCredentialsSetup.tsx | 149 ++++++ .../useChatCredentialsSetup.ts | 0 .../ChatErrorState/ChatErrorState.tsx | 0 .../Chat/components/ChatInput/ChatInput.tsx | 64 +++ .../components}/ChatInput/useChatInput.ts | 20 +- .../ChatLoadingState/ChatLoadingState.tsx | 19 + .../components/ChatMessage/ChatMessage.tsx | 341 +++++++++++++ .../components}/ChatMessage/useChatMessage.ts | 15 +- .../ExecutionStartedMessage.tsx | 38 +- .../MarkdownContent/MarkdownContent.tsx | 55 +-- .../MessageBubble/MessageBubble.tsx | 56 +++ .../components/MessageList/MessageList.tsx | 121 +++++ .../components}/MessageList/useMessageList.ts | 0 .../NoResultsMessage/NoResultsMessage.tsx | 22 +- .../QuickActionsWelcome.tsx | 94 ++++ .../SessionsDrawer/SessionsDrawer.tsx | 136 ++++++ .../StreamingMessage/StreamingMessage.tsx | 42 ++ .../StreamingMessage/useStreamingMessage.ts | 0 .../ThinkingMessage/ThinkingMessage.tsx | 70 +++ .../ToolCallMessage/ToolCallMessage.tsx | 24 + .../ToolResponseMessage.tsx | 260 ++++++++++ .../chat/{ => components/Chat}/helpers.ts | 7 - .../Chat/useChat.ts} | 31 +- 
.../chat/components/Chat/useChatDrawer.ts | 17 + .../{ => components/Chat}/useChatSession.ts | 33 +- .../chat/components/Chat/useChatStream.ts | 371 ++++++++++++++ .../chat/components/Chat/usePageContext.ts | 98 ++++ .../ChatContainer/ChatContainer.tsx | 68 --- .../ChatContainer/useChatContainer.ts | 130 ----- .../ChatCredentialsSetup.tsx | 153 ------ .../chat/components/ChatInput/ChatInput.tsx | 63 --- .../ChatLoadingState/ChatLoadingState.tsx | 31 -- .../components/ChatMessage/ChatMessage.tsx | 194 -------- .../MessageBubble/MessageBubble.tsx | 28 -- .../components/MessageList/MessageList.tsx | 61 --- .../QuickActionsWelcome.tsx | 51 -- .../StreamingMessage/StreamingMessage.tsx | 42 -- .../ToolCallMessage/ToolCallMessage.tsx | 49 -- .../ToolResponseMessage.tsx | 52 -- .../frontend/src/app/(platform)/chat/page.tsx | 64 +-- .../src/app/(platform)/chat/useChatStream.ts | 204 -------- .../frontend/src/app/(platform)/layout.tsx | 2 +- .../NewAgentLibraryView.tsx | 2 +- .../AgentInputsReadOnly.tsx | 4 +- .../modals/RunAgentInputs/RunAgentInputs.tsx | 35 +- .../CredentialsGroupedView.tsx | 2 +- .../RunAgentModal/components/helpers.ts | 2 +- .../modals/RunAgentModal/useAgentRunModal.tsx | 2 +- .../SelectedRunView/components/RunOutputs.tsx | 4 +- .../SelectedTemplateView.tsx | 2 +- .../SelectedTriggerView.tsx | 2 +- .../components/agent-run-draft-view.tsx | 2 +- .../components/agent-run-output-view.tsx | 4 +- .../chat/sessions/[sessionId]/stream/route.ts | 87 +++- autogpt_platform/frontend/src/app/globals.css | 46 ++ autogpt_platform/frontend/src/app/layout.tsx | 2 +- .../src/components/atoms/Input/Input.tsx | 7 +- .../CredentialsInput}/CredentialsInput.tsx | 2 + .../APIKeyCredentialsModal.tsx | 0 .../useAPIKeyCredentialsModal.ts | 0 .../CredentialRow/CredentialRow.tsx | 0 .../CredentialsAccordionView.tsx | 0 .../CredentialsFlatView.tsx | 0 .../CredentialsSelect/CredentialsSelect.tsx | 0 .../DeleteConfirmationModal.tsx | 0 .../HotScopedCredentialsModal.tsx | 0 
.../OAuthWaitingModal/OAuthWaitingModal.tsx | 0 .../PasswordCredentialsModal.tsx | 0 .../contextual/CredentialsInput}/helpers.ts | 0 .../CredentialsInput}/useCredentialsInput.ts | 0 .../GoogleDrivePicker/GoogleDrivePicker.tsx | 2 +- .../components/OutputActions.tsx | 106 ++++ .../OutputRenderers/components/OutputItem.tsx | 30 ++ .../contextual/OutputRenderers/index.ts | 20 + .../renderers/CodeRenderer.tsx | 135 ++++++ .../renderers/ImageRenderer.tsx | 209 ++++++++ .../renderers/JSONRenderer.tsx | 204 ++++++++ .../renderers/MarkdownRenderer.tsx | 456 ++++++++++++++++++ .../renderers/TextRenderer.tsx | 71 +++ .../renderers/VideoRenderer.tsx | 169 +++++++ .../contextual/OutputRenderers/types.ts | 60 +++ .../contextual/OutputRenderers/utils/copy.ts | 115 +++++ .../OutputRenderers/utils/download.ts | 74 +++ .../RunAgentInputs/RunAgentInputs.tsx | 389 +++++++++++++++ .../RunAgentInputs/useRunAgentInputs.ts | 19 + .../CredentialField/CredentialField.tsx | 2 +- autogpt_platform/frontend/src/lib/utils.ts | 6 + 107 files changed, 5130 insertions(+), 1416 deletions(-) create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/Chat.tsx rename autogpt_platform/frontend/src/app/(platform)/chat/components/{ => Chat/components}/AgentCarouselMessage/AgentCarouselMessage.tsx (65%) create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentInputsSetup/AgentInputsSetup.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentInputsSetup/useAgentInputsSetup.ts rename autogpt_platform/frontend/src/app/(platform)/chat/components/{ => Chat/components}/AuthPromptWidget/AuthPromptWidget.tsx (84%) create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/ChatContainer.tsx rename autogpt_platform/frontend/src/app/(platform)/chat/components/{ => Chat/components}/ChatContainer/createStreamEventDispatcher.ts (95%) rename 
autogpt_platform/frontend/src/app/(platform)/chat/components/{ => Chat/components}/ChatContainer/helpers.ts (68%) rename autogpt_platform/frontend/src/app/(platform)/chat/components/{ => Chat/components}/ChatContainer/useChatContainer.handlers.ts (85%) create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/useChatContainer.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatCredentialsSetup/ChatCredentialsSetup.tsx rename autogpt_platform/frontend/src/app/(platform)/chat/components/{ => Chat/components}/ChatCredentialsSetup/useChatCredentialsSetup.ts (100%) rename autogpt_platform/frontend/src/app/(platform)/chat/components/{ => Chat/components}/ChatErrorState/ChatErrorState.tsx (100%) create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatInput/ChatInput.tsx rename autogpt_platform/frontend/src/app/(platform)/chat/components/{ => Chat/components}/ChatInput/useChatInput.ts (65%) create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatLoadingState/ChatLoadingState.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatMessage/ChatMessage.tsx rename autogpt_platform/frontend/src/app/(platform)/chat/components/{ => Chat/components}/ChatMessage/useChatMessage.ts (88%) rename autogpt_platform/frontend/src/app/(platform)/chat/components/{ => Chat/components}/ExecutionStartedMessage/ExecutionStartedMessage.tsx (67%) rename autogpt_platform/frontend/src/app/(platform)/chat/components/{ => Chat/components}/MarkdownContent/MarkdownContent.tsx (80%) create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageBubble/MessageBubble.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageList/MessageList.tsx rename 
autogpt_platform/frontend/src/app/(platform)/chat/components/{ => Chat/components}/MessageList/useMessageList.ts (100%) rename autogpt_platform/frontend/src/app/(platform)/chat/components/{ => Chat/components}/NoResultsMessage/NoResultsMessage.tsx (72%) create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/QuickActionsWelcome/QuickActionsWelcome.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/SessionsDrawer/SessionsDrawer.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/StreamingMessage/StreamingMessage.tsx rename autogpt_platform/frontend/src/app/(platform)/chat/components/{ => Chat/components}/StreamingMessage/useStreamingMessage.ts (100%) create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ThinkingMessage/ThinkingMessage.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ToolCallMessage/ToolCallMessage.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ToolResponseMessage/ToolResponseMessage.tsx rename autogpt_platform/frontend/src/app/(platform)/chat/{ => components/Chat}/helpers.ts (92%) rename autogpt_platform/frontend/src/app/(platform)/chat/{useChatPage.ts => components/Chat/useChat.ts} (78%) create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatDrawer.ts rename autogpt_platform/frontend/src/app/(platform)/chat/{ => components/Chat}/useChatSession.ts (89%) create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatStream.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/usePageContext.ts delete mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/ChatContainer.tsx delete mode 100644 
autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/useChatContainer.ts delete mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/ChatCredentialsSetup/ChatCredentialsSetup.tsx delete mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/ChatInput/ChatInput.tsx delete mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/ChatLoadingState/ChatLoadingState.tsx delete mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/ChatMessage/ChatMessage.tsx delete mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/MessageBubble/MessageBubble.tsx delete mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/MessageList/MessageList.tsx delete mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/QuickActionsWelcome/QuickActionsWelcome.tsx delete mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/StreamingMessage/StreamingMessage.tsx delete mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/ToolCallMessage/ToolCallMessage.tsx delete mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/components/ToolResponseMessage/ToolResponseMessage.tsx delete mode 100644 autogpt_platform/frontend/src/app/(platform)/chat/useChatStream.ts rename autogpt_platform/frontend/src/{app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs => components/contextual/CredentialsInput}/CredentialsInput.tsx (99%) rename autogpt_platform/frontend/src/{app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs => components/contextual/CredentialsInput}/components/APIKeyCredentialsModal/APIKeyCredentialsModal.tsx (100%) rename autogpt_platform/frontend/src/{app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs => 
components/contextual/CredentialsInput}/components/APIKeyCredentialsModal/useAPIKeyCredentialsModal.ts (100%) rename autogpt_platform/frontend/src/{app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs => components/contextual/CredentialsInput}/components/CredentialRow/CredentialRow.tsx (100%) rename autogpt_platform/frontend/src/{app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs => components/contextual/CredentialsInput}/components/CredentialsAccordionView/CredentialsAccordionView.tsx (100%) rename autogpt_platform/frontend/src/{app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs => components/contextual/CredentialsInput}/components/CredentialsFlatView/CredentialsFlatView.tsx (100%) rename autogpt_platform/frontend/src/{app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs => components/contextual/CredentialsInput}/components/CredentialsSelect/CredentialsSelect.tsx (100%) rename autogpt_platform/frontend/src/{app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs => components/contextual/CredentialsInput}/components/DeleteConfirmationModal/DeleteConfirmationModal.tsx (100%) rename autogpt_platform/frontend/src/{app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs => components/contextual/CredentialsInput}/components/HotScopedCredentialsModal/HotScopedCredentialsModal.tsx (100%) rename autogpt_platform/frontend/src/{app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs => components/contextual/CredentialsInput}/components/OAuthWaitingModal/OAuthWaitingModal.tsx (100%) rename autogpt_platform/frontend/src/{app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs => 
components/contextual/CredentialsInput}/components/PasswordCredentialsModal/PasswordCredentialsModal.tsx (100%) rename autogpt_platform/frontend/src/{app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs => components/contextual/CredentialsInput}/helpers.ts (100%) rename autogpt_platform/frontend/src/{app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs => components/contextual/CredentialsInput}/useCredentialsInput.ts (100%) create mode 100644 autogpt_platform/frontend/src/components/contextual/OutputRenderers/components/OutputActions.tsx create mode 100644 autogpt_platform/frontend/src/components/contextual/OutputRenderers/components/OutputItem.tsx create mode 100644 autogpt_platform/frontend/src/components/contextual/OutputRenderers/index.ts create mode 100644 autogpt_platform/frontend/src/components/contextual/OutputRenderers/renderers/CodeRenderer.tsx create mode 100644 autogpt_platform/frontend/src/components/contextual/OutputRenderers/renderers/ImageRenderer.tsx create mode 100644 autogpt_platform/frontend/src/components/contextual/OutputRenderers/renderers/JSONRenderer.tsx create mode 100644 autogpt_platform/frontend/src/components/contextual/OutputRenderers/renderers/MarkdownRenderer.tsx create mode 100644 autogpt_platform/frontend/src/components/contextual/OutputRenderers/renderers/TextRenderer.tsx create mode 100644 autogpt_platform/frontend/src/components/contextual/OutputRenderers/renderers/VideoRenderer.tsx create mode 100644 autogpt_platform/frontend/src/components/contextual/OutputRenderers/types.ts create mode 100644 autogpt_platform/frontend/src/components/contextual/OutputRenderers/utils/copy.ts create mode 100644 autogpt_platform/frontend/src/components/contextual/OutputRenderers/utils/download.ts create mode 100644 autogpt_platform/frontend/src/components/contextual/RunAgentInputs/RunAgentInputs.tsx create mode 100644 
autogpt_platform/frontend/src/components/contextual/RunAgentInputs/useRunAgentInputs.ts diff --git a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/components/AgentOnboardingCredentials/AgentOnboardingCredentials.tsx b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/components/AgentOnboardingCredentials/AgentOnboardingCredentials.tsx index 72e296fd88..f0bb652a06 100644 --- a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/components/AgentOnboardingCredentials/AgentOnboardingCredentials.tsx +++ b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/components/AgentOnboardingCredentials/AgentOnboardingCredentials.tsx @@ -1,6 +1,6 @@ -import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput"; import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput"; import { GraphMeta } from "@/app/api/__generated__/models/graphMeta"; +import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput"; import { useState } from "react"; import { getSchemaDefaultCredentials } from "../../helpers"; import { areAllCredentialsSet, getCredentialFields } from "./helpers"; diff --git a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/page.tsx b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/page.tsx index 30e1b67090..db04278d80 100644 --- a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/page.tsx +++ b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/page.tsx @@ -1,12 +1,12 @@ "use client"; -import { RunAgentInputs } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentInputs/RunAgentInputs"; import { Card, CardContent, CardHeader, CardTitle, } from "@/components/__legacy__/ui/card"; +import { RunAgentInputs } from "@/components/contextual/RunAgentInputs/RunAgentInputs"; import 
{ ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard"; import { CircleNotchIcon } from "@phosphor-icons/react/dist/ssr"; import { Play } from "lucide-react"; diff --git a/autogpt_platform/frontend/src/app/(platform)/auth/integrations/setup-wizard/page.tsx b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/setup-wizard/page.tsx index f7d4935907..9e2e637ef6 100644 --- a/autogpt_platform/frontend/src/app/(platform)/auth/integrations/setup-wizard/page.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/setup-wizard/page.tsx @@ -1,11 +1,11 @@ "use client"; -import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput"; import { useGetOauthGetOauthAppInfo } from "@/app/api/__generated__/endpoints/oauth/oauth"; import { okData } from "@/app/api/helpers"; import { Button } from "@/components/atoms/Button/Button"; import { Text } from "@/components/atoms/Text/Text"; import { AuthCard } from "@/components/auth/AuthCard"; +import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput"; import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard"; import type { BlockIOCredentialsSubSchema, diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/AgentOutputs/AgentOutputs.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/AgentOutputs/AgentOutputs.tsx index 20493b2ca0..cfea5d9452 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/AgentOutputs/AgentOutputs.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/AgentOutputs/AgentOutputs.tsx @@ -1,11 +1,6 @@ import { BlockUIType } from "@/app/(platform)/build/components/types"; import { useGraphStore } from "@/app/(platform)/build/stores/graphStore"; import { useNodeStore } from 
"@/app/(platform)/build/stores/nodeStore"; -import { - globalRegistry, - OutputActions, - OutputItem, -} from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers"; import { Label } from "@/components/__legacy__/ui/label"; import { ScrollArea } from "@/components/__legacy__/ui/scroll-area"; import { @@ -23,6 +18,11 @@ import { TooltipProvider, TooltipTrigger, } from "@/components/atoms/Tooltip/BaseTooltip"; +import { + globalRegistry, + OutputActions, + OutputItem, +} from "@/components/contextual/OutputRenderers"; import { BookOpenIcon } from "@phosphor-icons/react"; import { useMemo } from "react"; import { useShallow } from "zustand/react/shallow"; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/ContentRenderer.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/ContentRenderer.tsx index 9cb1a62e3d..6571bc7b6f 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/ContentRenderer.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/ContentRenderer.tsx @@ -1,7 +1,7 @@ "use client"; -import type { OutputMetadata } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers"; -import { globalRegistry } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers"; +import type { OutputMetadata } from "@/components/contextual/OutputRenderers"; +import { globalRegistry } from "@/components/contextual/OutputRenderers"; export const TextRenderer: React.FC<{ value: any; diff --git 
a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/NodeDataViewer.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/NodeDataViewer.tsx index 31b89315d6..0858db8f0e 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/NodeDataViewer.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/NodeDataViewer.tsx @@ -1,7 +1,3 @@ -import { - OutputActions, - OutputItem, -} from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers"; import { ScrollArea } from "@/components/__legacy__/ui/scroll-area"; import { Button } from "@/components/atoms/Button/Button"; import { Text } from "@/components/atoms/Text/Text"; @@ -11,6 +7,10 @@ import { TooltipProvider, TooltipTrigger, } from "@/components/atoms/Tooltip/BaseTooltip"; +import { + OutputActions, + OutputItem, +} from "@/components/contextual/OutputRenderers"; import { Dialog } from "@/components/molecules/Dialog/Dialog"; import { beautifyString } from "@/lib/utils"; import { diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/useNodeDataViewer.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/useNodeDataViewer.ts index 1adec625a0..d3c555970c 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/useNodeDataViewer.ts +++ 
b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/useNodeDataViewer.ts @@ -1,6 +1,6 @@ -import type { OutputMetadata } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers"; -import { globalRegistry } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers"; -import { downloadOutputs } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers/utils/download"; +import type { OutputMetadata } from "@/components/contextual/OutputRenderers"; +import { globalRegistry } from "@/components/contextual/OutputRenderers"; +import { downloadOutputs } from "@/components/contextual/OutputRenderers/utils/download"; import { useToast } from "@/components/molecules/Toast/use-toast"; import { beautifyString } from "@/lib/utils"; import React, { useMemo, useState } from "react"; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/WebhookDisclaimer.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/WebhookDisclaimer.tsx index 044bf994ad..1fdee05a2a 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/WebhookDisclaimer.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/WebhookDisclaimer.tsx @@ -1,10 +1,10 @@ -import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert"; -import { Text } from "@/components/atoms/Text/Text"; -import Link from "next/link"; import { useGetV2GetLibraryAgentByGraphId } from "@/app/api/__generated__/endpoints/library/library"; import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; -import { useQueryStates, parseAsString } 
from "nuqs"; -import { isValidUUID } from "@/app/(platform)/chat/helpers"; +import { Text } from "@/components/atoms/Text/Text"; +import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert"; +import { isValidUUID } from "@/lib/utils"; +import Link from "next/link"; +import { parseAsString, useQueryStates } from "nuqs"; export const WebhookDisclaimer = ({ nodeId }: { nodeId: string }) => { const [{ flowID }] = useQueryStates({ diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/ExpandableOutputDialog.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/ExpandableOutputDialog.tsx index 98edbca2fb..1ccb3d1261 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/ExpandableOutputDialog.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/ExpandableOutputDialog.tsx @@ -1,9 +1,9 @@ -import type { OutputMetadata } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers"; +import type { OutputMetadata } from "@/components/contextual/OutputRenderers"; import { globalRegistry, OutputActions, OutputItem, -} from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers"; +} from "@/components/contextual/OutputRenderers"; import { Dialog } from "@/components/molecules/Dialog/Dialog"; import { beautifyString } from "@/lib/utils"; import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/NodeInputs.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/NodeInputs.tsx index 36df180c8c..51fed5bef1 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/NodeInputs.tsx +++ 
b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/NodeInputs.tsx @@ -3,7 +3,6 @@ import { CustomNodeData, } from "@/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode"; import { NodeTableInput } from "@/app/(platform)/build/components/legacy-builder/NodeTableInput"; -import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput"; import { Button } from "@/components/__legacy__/ui/button"; import { Calendar } from "@/components/__legacy__/ui/calendar"; import { LocalValuedInput } from "@/components/__legacy__/ui/input"; @@ -28,6 +27,7 @@ import { SelectValue, } from "@/components/__legacy__/ui/select"; import { Switch } from "@/components/atoms/Switch/Switch"; +import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput"; import { GoogleDrivePickerInput } from "@/components/contextual/GoogleDrivePicker/GoogleDrivePickerInput"; import { BlockIOArraySubSchema, diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/Chat.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/Chat.tsx new file mode 100644 index 0000000000..461c885dc3 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/Chat.tsx @@ -0,0 +1,134 @@ +"use client"; + +import { Button } from "@/components/atoms/Button/Button"; +import { Text } from "@/components/atoms/Text/Text"; +import { cn } from "@/lib/utils"; +import { List } from "@phosphor-icons/react"; +import React, { useState } from "react"; +import { ChatContainer } from "./components/ChatContainer/ChatContainer"; +import { ChatErrorState } from "./components/ChatErrorState/ChatErrorState"; +import { ChatLoadingState } from "./components/ChatLoadingState/ChatLoadingState"; +import { SessionsDrawer } from "./components/SessionsDrawer/SessionsDrawer"; +import { useChat } from "./useChat"; + +export interface 
ChatProps { + className?: string; + headerTitle?: React.ReactNode; + showHeader?: boolean; + showSessionInfo?: boolean; + showNewChatButton?: boolean; + onNewChat?: () => void; + headerActions?: React.ReactNode; +} + +export function Chat({ + className, + headerTitle = "AutoGPT Copilot", + showHeader = true, + showSessionInfo = true, + showNewChatButton = true, + onNewChat, + headerActions, +}: ChatProps) { + const { + messages, + isLoading, + isCreating, + error, + sessionId, + createSession, + clearSession, + loadSession, + } = useChat(); + + const [isSessionsDrawerOpen, setIsSessionsDrawerOpen] = useState(false); + + const handleNewChat = () => { + clearSession(); + onNewChat?.(); + }; + + const handleSelectSession = async (sessionId: string) => { + try { + await loadSession(sessionId); + } catch (err) { + console.error("Failed to load session:", err); + } + }; + + return ( +
+ {/* Header */} + {showHeader && ( +
+
+
+ + {typeof headerTitle === "string" ? ( + + {headerTitle} + + ) : ( + headerTitle + )} +
+
+ {showSessionInfo && sessionId && ( + <> + {showNewChatButton && ( + + )} + + )} + {headerActions} +
+
+
+ )} + + {/* Main Content */} +
+ {/* Loading State - show when explicitly loading/creating OR when we don't have a session yet and no error */} + {(isLoading || isCreating || (!sessionId && !error)) && ( + + )} + + {/* Error State */} + {error && !isLoading && ( + + )} + + {/* Session Content */} + {sessionId && !isLoading && !error && ( + + )} +
+ + {/* Sessions Drawer */} + setIsSessionsDrawerOpen(false)} + onSelectSession={handleSelectSession} + currentSessionId={sessionId} + /> +
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/AgentCarouselMessage/AgentCarouselMessage.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentCarouselMessage/AgentCarouselMessage.tsx similarity index 65% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/AgentCarouselMessage/AgentCarouselMessage.tsx rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentCarouselMessage/AgentCarouselMessage.tsx index 125f834e05..582b24de5e 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/AgentCarouselMessage/AgentCarouselMessage.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentCarouselMessage/AgentCarouselMessage.tsx @@ -1,15 +1,16 @@ -import React from "react"; -import { Text } from "@/components/atoms/Text/Text"; import { Button } from "@/components/atoms/Button/Button"; import { Card } from "@/components/atoms/Card/Card"; -import { List, Robot, ArrowRight } from "@phosphor-icons/react"; +import { Text } from "@/components/atoms/Text/Text"; import { cn } from "@/lib/utils"; +import { ArrowRight, List, Robot } from "@phosphor-icons/react"; +import Image from "next/image"; export interface Agent { id: string; name: string; description: string; version?: number; + image_url?: string; } export interface AgentCarouselMessageProps { @@ -30,7 +31,7 @@ export function AgentCarouselMessage({ return (
@@ -40,13 +41,10 @@ export function AgentCarouselMessage({
- + Found {displayCount} {displayCount === 1 ? "Agent" : "Agents"} - + Select an agent to view details or run it
@@ -57,40 +55,49 @@ export function AgentCarouselMessage({ {agents.map((agent) => (
-
- +
+ {agent.image_url ? ( + {`${agent.name} + ) : ( +
+ +
+ )}
{agent.name} {agent.version && ( - + v{agent.version} )}
- + {agent.description} {onSelectAgent && (
{totalCount && totalCount > agents.length && ( - + Showing {agents.length} of {totalCount} results )} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentInputsSetup/AgentInputsSetup.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentInputsSetup/AgentInputsSetup.tsx new file mode 100644 index 0000000000..3ef71eca09 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentInputsSetup/AgentInputsSetup.tsx @@ -0,0 +1,246 @@ +"use client"; + +import { Button } from "@/components/atoms/Button/Button"; +import { Card } from "@/components/atoms/Card/Card"; +import { Text } from "@/components/atoms/Text/Text"; +import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput"; +import { RunAgentInputs } from "@/components/contextual/RunAgentInputs/RunAgentInputs"; + +import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; +import { + BlockIOCredentialsSubSchema, + BlockIOSubSchema, +} from "@/lib/autogpt-server-api/types"; +import { cn, isEmpty } from "@/lib/utils"; +import { PlayIcon, WarningIcon } from "@phosphor-icons/react"; +import { useMemo } from "react"; +import { useAgentInputsSetup } from "./useAgentInputsSetup"; + +type LibraryAgentInputSchemaProperties = LibraryAgent["input_schema"] extends { + properties: infer P; +} + ? P extends Record + ? P + : Record + : Record; + +type LibraryAgentCredentialsInputSchemaProperties = + LibraryAgent["credentials_input_schema"] extends { + properties: infer P; + } + ? P extends Record + ? 
P + : Record + : Record; + +interface Props { + agentName?: string; + inputSchema: LibraryAgentInputSchemaProperties | Record; + credentialsSchema?: + | LibraryAgentCredentialsInputSchemaProperties + | Record; + message: string; + requiredFields?: string[]; + onRun: ( + inputs: Record, + credentials: Record, + ) => void; + onCancel?: () => void; + className?: string; +} + +export function AgentInputsSetup({ + agentName, + inputSchema, + credentialsSchema, + message, + requiredFields, + onRun, + onCancel, + className, +}: Props) { + const { inputValues, setInputValue, credentialsValues, setCredentialsValue } = + useAgentInputsSetup(); + + const inputSchemaObj = useMemo(() => { + if (!inputSchema) return { properties: {}, required: [] }; + if ("properties" in inputSchema && "type" in inputSchema) { + return inputSchema as { + properties: Record; + required?: string[]; + }; + } + return { properties: inputSchema as Record, required: [] }; + }, [inputSchema]); + + const credentialsSchemaObj = useMemo(() => { + if (!credentialsSchema) return { properties: {}, required: [] }; + if ("properties" in credentialsSchema && "type" in credentialsSchema) { + return credentialsSchema as { + properties: Record; + required?: string[]; + }; + } + return { + properties: credentialsSchema as Record, + required: [], + }; + }, [credentialsSchema]); + + const agentInputFields = useMemo(() => { + const properties = inputSchemaObj.properties || {}; + return Object.fromEntries( + Object.entries(properties).filter( + ([_, subSchema]: [string, any]) => !subSchema.hidden, + ), + ); + }, [inputSchemaObj]); + + const agentCredentialsInputFields = useMemo(() => { + return credentialsSchemaObj.properties || {}; + }, [credentialsSchemaObj]); + + const inputFields = Object.entries(agentInputFields); + const credentialFields = Object.entries(agentCredentialsInputFields); + + const defaultsFromSchema = useMemo(() => { + const defaults: Record = {}; + Object.entries(agentInputFields).forEach(([key, 
schema]) => { + if ("default" in schema && schema.default !== undefined) { + defaults[key] = schema.default; + } + }); + return defaults; + }, [agentInputFields]); + + const defaultsFromCredentialsSchema = useMemo(() => { + const defaults: Record = {}; + Object.entries(agentCredentialsInputFields).forEach(([key, schema]) => { + if ("default" in schema && schema.default !== undefined) { + defaults[key] = schema.default; + } + }); + return defaults; + }, [agentCredentialsInputFields]); + + const mergedInputValues = useMemo(() => { + return { ...defaultsFromSchema, ...inputValues }; + }, [defaultsFromSchema, inputValues]); + + const mergedCredentialsValues = useMemo(() => { + return { ...defaultsFromCredentialsSchema, ...credentialsValues }; + }, [defaultsFromCredentialsSchema, credentialsValues]); + + const allRequiredInputsAreSet = useMemo(() => { + const requiredInputs = new Set( + requiredFields || (inputSchemaObj.required as string[]) || [], + ); + const nonEmptyInputs = new Set( + Object.keys(mergedInputValues).filter( + (k) => !isEmpty(mergedInputValues[k]), + ), + ); + const missing = [...requiredInputs].filter( + (input) => !nonEmptyInputs.has(input), + ); + return missing.length === 0; + }, [inputSchemaObj.required, mergedInputValues, requiredFields]); + + const allCredentialsAreSet = useMemo(() => { + const requiredCredentials = new Set( + (credentialsSchemaObj.required as string[]) || [], + ); + if (requiredCredentials.size === 0) { + return true; + } + const missing = [...requiredCredentials].filter((key) => { + const cred = mergedCredentialsValues[key]; + return !cred || !cred.id; + }); + return missing.length === 0; + }, [credentialsSchemaObj.required, mergedCredentialsValues]); + + const canRun = allRequiredInputsAreSet && allCredentialsAreSet; + + function handleRun() { + if (canRun) { + onRun(mergedInputValues, mergedCredentialsValues); + } + } + + return ( + +
+
+ +
+
+ + {agentName ? `Configure ${agentName}` : "Agent Configuration"} + + + {message} + + + {inputFields.length > 0 && ( +
+ {inputFields.map(([key, inputSubSchema]) => ( + setInputValue(key, value)} + /> + ))} +
+ )} + + {credentialFields.length > 0 && ( +
+ {credentialFields.map(([key, schema]) => { + const requiredCredentials = new Set( + (credentialsSchemaObj.required as string[]) || [], + ); + return ( + + setCredentialsValue(key, value) + } + siblingInputs={mergedInputValues} + isOptional={!requiredCredentials.has(key)} + /> + ); + })} +
+ )} + +
+ + {onCancel && ( + + )} +
+
+
+
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentInputsSetup/useAgentInputsSetup.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentInputsSetup/useAgentInputsSetup.ts new file mode 100644 index 0000000000..e36a3f3c5d --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AgentInputsSetup/useAgentInputsSetup.ts @@ -0,0 +1,38 @@ +import type { CredentialsMetaInput } from "@/lib/autogpt-server-api/types"; +import { useState } from "react"; + +export function useAgentInputsSetup() { + const [inputValues, setInputValues] = useState>({}); + const [credentialsValues, setCredentialsValues] = useState< + Record + >({}); + + function setInputValue(key: string, value: any) { + setInputValues((prev) => ({ + ...prev, + [key]: value, + })); + } + + function setCredentialsValue(key: string, value?: CredentialsMetaInput) { + if (value) { + setCredentialsValues((prev) => ({ + ...prev, + [key]: value, + })); + } else { + setCredentialsValues((prev) => { + const next = { ...prev }; + delete next[key]; + return next; + }); + } + } + + return { + inputValues, + setInputValue, + credentialsValues, + setCredentialsValue, + }; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/AuthPromptWidget/AuthPromptWidget.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AuthPromptWidget/AuthPromptWidget.tsx similarity index 84% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/AuthPromptWidget/AuthPromptWidget.tsx rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AuthPromptWidget/AuthPromptWidget.tsx index 885a06e92a..33f02e660f 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/AuthPromptWidget/AuthPromptWidget.tsx +++ 
b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/AuthPromptWidget/AuthPromptWidget.tsx @@ -1,10 +1,9 @@ "use client"; -import React from "react"; -import { useRouter } from "next/navigation"; import { Button } from "@/components/atoms/Button/Button"; -import { SignInIcon, UserPlusIcon, ShieldIcon } from "@phosphor-icons/react"; import { cn } from "@/lib/utils"; +import { ShieldIcon, SignInIcon, UserPlusIcon } from "@phosphor-icons/react"; +import { useRouter } from "next/navigation"; export interface AuthPromptWidgetProps { message: string; @@ -54,8 +53,8 @@ export function AuthPromptWidget({ return (
-

+

Authentication Required

-

+

Sign in to set up and manage agents

-
-

- {message} -

+
+

{message}

{agentInfo && ( -
+

Ready to set up:{" "} {agentInfo.name} @@ -114,7 +111,7 @@ export function AuthPromptWidget({

-
+
Your chat session will be preserved after signing in
diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/ChatContainer.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/ChatContainer.tsx new file mode 100644 index 0000000000..6f7a0e8f51 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/ChatContainer.tsx @@ -0,0 +1,88 @@ +import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse"; +import { cn } from "@/lib/utils"; +import { useCallback } from "react"; +import { usePageContext } from "../../usePageContext"; +import { ChatInput } from "../ChatInput/ChatInput"; +import { MessageList } from "../MessageList/MessageList"; +import { QuickActionsWelcome } from "../QuickActionsWelcome/QuickActionsWelcome"; +import { useChatContainer } from "./useChatContainer"; + +export interface ChatContainerProps { + sessionId: string | null; + initialMessages: SessionDetailResponse["messages"]; + className?: string; +} + +export function ChatContainer({ + sessionId, + initialMessages, + className, +}: ChatContainerProps) { + const { messages, streamingChunks, isStreaming, sendMessage } = + useChatContainer({ + sessionId, + initialMessages, + }); + const { capturePageContext } = usePageContext(); + + // Wrap sendMessage to automatically capture page context + const sendMessageWithContext = useCallback( + async (content: string, isUserMessage: boolean = true) => { + const context = capturePageContext(); + await sendMessage(content, isUserMessage, context); + }, + [sendMessage, capturePageContext], + ); + + const quickActions = [ + "Find agents for social media management", + "Show me agents for content creation", + "Help me automate my business", + "What can you help me with?", + ]; + + return ( +
+ {/* Messages or Welcome Screen */} +
+ {messages.length === 0 ? ( + + ) : ( + + )} +
+ + {/* Input - Always visible */} +
+ +
+
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/createStreamEventDispatcher.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/createStreamEventDispatcher.ts similarity index 95% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/createStreamEventDispatcher.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/createStreamEventDispatcher.ts index b8421c3386..844f126d49 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/createStreamEventDispatcher.ts +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/createStreamEventDispatcher.ts @@ -1,14 +1,14 @@ import { toast } from "sonner"; -import type { StreamChunk } from "@/app/(platform)/chat/useChatStream"; +import { StreamChunk } from "../../useChatStream"; import type { HandlerDependencies } from "./useChatContainer.handlers"; import { + handleError, + handleLoginNeeded, + handleStreamEnd, handleTextChunk, handleTextEnded, handleToolCallStart, handleToolResponse, - handleLoginNeeded, - handleStreamEnd, - handleError, } from "./useChatContainer.handlers"; export function createStreamEventDispatcher( diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/helpers.ts similarity index 68% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/helpers.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/helpers.ts index 3a94dab1ea..cd05563369 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/helpers.ts @@ -1,5 
+1,24 @@ -import type { ChatMessageData } from "@/app/(platform)/chat/components/ChatMessage/useChatMessage"; import type { ToolResult } from "@/types/chat"; +import type { ChatMessageData } from "../ChatMessage/useChatMessage"; + +export function removePageContext(content: string): string { + // Remove "Page URL: ..." pattern at start of line (case insensitive, handles various formats) + let cleaned = content.replace(/^\s*Page URL:\s*[^\n\r]*/gim, ""); + + // Find "User Message:" marker at start of line to preserve the actual user message + const userMessageMatch = cleaned.match(/^\s*User Message:\s*([\s\S]*)$/im); + if (userMessageMatch) { + // If we found "User Message:", extract everything after it + cleaned = userMessageMatch[1]; + } else { + // If no "User Message:" marker, remove "Page Content:" and everything after it at start of line + cleaned = cleaned.replace(/^\s*Page Content:[\s\S]*$/gim, ""); + } + + // Clean up extra whitespace and newlines + cleaned = cleaned.replace(/\n\s*\n\s*\n+/g, "\n\n").trim(); + return cleaned; +} export function createUserMessage(content: string): ChatMessageData { return { @@ -63,6 +82,7 @@ export function isAgentArray(value: unknown): value is Array<{ name: string; description: string; version?: number; + image_url?: string; }> { if (!Array.isArray(value)) { return false; @@ -77,7 +97,8 @@ export function isAgentArray(value: unknown): value is Array<{ typeof item.name === "string" && "description" in item && typeof item.description === "string" && - (!("version" in item) || typeof item.version === "number"), + (!("version" in item) || typeof item.version === "number") && + (!("image_url" in item) || typeof item.image_url === "string"), ); } @@ -232,6 +253,7 @@ export function isSetupInfo(value: unknown): value is { export function extractCredentialsNeeded( parsedResult: Record, + toolName: string = "run_agent", ): ChatMessageData | null { try { const setupInfo = parsedResult?.setup_info as @@ -244,7 +266,7 @@ export 
function extractCredentialsNeeded( | Record> | undefined; if (missingCreds && Object.keys(missingCreds).length > 0) { - const agentName = (setupInfo?.agent_name as string) || "this agent"; + const agentName = (setupInfo?.agent_name as string) || "this block"; const credentials = Object.values(missingCreds).map((credInfo) => ({ provider: (credInfo.provider as string) || "unknown", providerName: @@ -264,7 +286,7 @@ export function extractCredentialsNeeded( })); return { type: "credentials_needed", - toolName: "run_agent", + toolName, credentials, message: `To run ${agentName}, you need to add ${credentials.length === 1 ? "credentials" : `${credentials.length} credentials`}.`, agentName, @@ -277,3 +299,92 @@ export function extractCredentialsNeeded( return null; } } + +export function extractInputsNeeded( + parsedResult: Record, + toolName: string = "run_agent", +): ChatMessageData | null { + try { + const setupInfo = parsedResult?.setup_info as + | Record + | undefined; + const requirements = setupInfo?.requirements as + | Record + | undefined; + const inputs = requirements?.inputs as + | Array> + | undefined; + const credentials = requirements?.credentials as + | Array> + | undefined; + + if (!inputs || inputs.length === 0) { + return null; + } + + const agentName = (setupInfo?.agent_name as string) || "this agent"; + const agentId = parsedResult?.graph_id as string | undefined; + const graphVersion = parsedResult?.graph_version as number | undefined; + + const properties: Record = {}; + const requiredProps: string[] = []; + inputs.forEach((input) => { + const name = input.name as string; + if (name) { + properties[name] = { + title: input.name as string, + description: (input.description as string) || "", + type: (input.type as string) || "string", + default: input.default, + enum: input.options, + format: input.format, + }; + if ((input.required as boolean) === true) { + requiredProps.push(name); + } + } + }); + + const inputSchema: Record = { + type: "object", + 
properties, + }; + if (requiredProps.length > 0) { + inputSchema.required = requiredProps; + } + + const credentialsSchema: Record = {}; + if (credentials && credentials.length > 0) { + credentials.forEach((cred) => { + const id = cred.id as string; + if (id) { + credentialsSchema[id] = { + type: "object", + properties: {}, + credentials_provider: [cred.provider as string], + credentials_types: [(cred.type as string) || "api_key"], + credentials_scopes: cred.scopes as string[] | undefined, + }; + } + }); + } + + return { + type: "inputs_needed", + toolName, + agentName, + agentId, + graphVersion, + inputSchema, + credentialsSchema: + Object.keys(credentialsSchema).length > 0 + ? credentialsSchema + : undefined, + message: `Please provide the required inputs to run ${agentName}.`, + timestamp: new Date(), + }; + } catch (err) { + console.error("Failed to extract inputs from setup info:", err); + return null; + } +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/useChatContainer.handlers.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/useChatContainer.handlers.ts similarity index 85% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/useChatContainer.handlers.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/useChatContainer.handlers.ts index fdbecb5d61..064b847064 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/useChatContainer.handlers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/useChatContainer.handlers.ts @@ -1,13 +1,18 @@ -import type { Dispatch, SetStateAction, MutableRefObject } from "react"; -import type { StreamChunk } from "@/app/(platform)/chat/useChatStream"; -import type { ChatMessageData } from "@/app/(platform)/chat/components/ChatMessage/useChatMessage"; -import { 
parseToolResponse, extractCredentialsNeeded } from "./helpers"; +import type { Dispatch, MutableRefObject, SetStateAction } from "react"; +import { StreamChunk } from "../../useChatStream"; +import type { ChatMessageData } from "../ChatMessage/useChatMessage"; +import { + extractCredentialsNeeded, + extractInputsNeeded, + parseToolResponse, +} from "./helpers"; export interface HandlerDependencies { setHasTextChunks: Dispatch>; setStreamingChunks: Dispatch>; streamingChunksRef: MutableRefObject; setMessages: Dispatch>; + setIsStreamingInitiated: Dispatch>; sessionId: string; } @@ -100,11 +105,18 @@ export function handleToolResponse( parsedResult = null; } if ( - chunk.tool_name === "run_agent" && + (chunk.tool_name === "run_agent" || chunk.tool_name === "run_block") && chunk.success && parsedResult?.type === "setup_requirements" ) { - const credentialsMessage = extractCredentialsNeeded(parsedResult); + const inputsMessage = extractInputsNeeded(parsedResult, chunk.tool_name); + if (inputsMessage) { + deps.setMessages((prev) => [...prev, inputsMessage]); + } + const credentialsMessage = extractCredentialsNeeded( + parsedResult, + chunk.tool_name, + ); if (credentialsMessage) { deps.setMessages((prev) => [...prev, credentialsMessage]); } @@ -197,10 +209,15 @@ export function handleStreamEnd( deps.setStreamingChunks([]); deps.streamingChunksRef.current = []; deps.setHasTextChunks(false); + deps.setIsStreamingInitiated(false); console.log("[Stream End] Stream complete, messages in local state"); } -export function handleError(chunk: StreamChunk, _deps: HandlerDependencies) { +export function handleError(chunk: StreamChunk, deps: HandlerDependencies) { const errorMessage = chunk.message || chunk.content || "An error occurred"; console.error("Stream error:", errorMessage); + deps.setIsStreamingInitiated(false); + deps.setHasTextChunks(false); + deps.setStreamingChunks([]); + deps.streamingChunksRef.current = []; } diff --git 
a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/useChatContainer.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/useChatContainer.ts new file mode 100644 index 0000000000..8e7dee7718 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatContainer/useChatContainer.ts @@ -0,0 +1,206 @@ +import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse"; +import { useCallback, useMemo, useRef, useState } from "react"; +import { toast } from "sonner"; +import { useChatStream } from "../../useChatStream"; +import type { ChatMessageData } from "../ChatMessage/useChatMessage"; +import { createStreamEventDispatcher } from "./createStreamEventDispatcher"; +import { + createUserMessage, + filterAuthMessages, + isToolCallArray, + isValidMessage, + parseToolResponse, + removePageContext, +} from "./helpers"; + +interface Args { + sessionId: string | null; + initialMessages: SessionDetailResponse["messages"]; +} + +export function useChatContainer({ sessionId, initialMessages }: Args) { + const [messages, setMessages] = useState([]); + const [streamingChunks, setStreamingChunks] = useState([]); + const [hasTextChunks, setHasTextChunks] = useState(false); + const [isStreamingInitiated, setIsStreamingInitiated] = useState(false); + const streamingChunksRef = useRef([]); + const { error, sendMessage: sendStreamMessage } = useChatStream(); + const isStreaming = isStreamingInitiated || hasTextChunks; + + const allMessages = useMemo(() => { + const processedInitialMessages: ChatMessageData[] = []; + // Map to track tool calls by their ID so we can look up tool names for tool responses + const toolCallMap = new Map(); + + for (const msg of initialMessages) { + if (!isValidMessage(msg)) { + console.warn("Invalid message structure from backend:", msg); + continue; + } + + let content = String(msg.content || ""); + 
const role = String(msg.role || "assistant").toLowerCase(); + const toolCalls = msg.tool_calls; + const timestamp = msg.timestamp + ? new Date(msg.timestamp as string) + : undefined; + + // Remove page context from user messages when loading existing sessions + if (role === "user") { + content = removePageContext(content); + // Skip user messages that become empty after removing page context + if (!content.trim()) { + continue; + } + processedInitialMessages.push({ + type: "message", + role: "user", + content, + timestamp, + }); + continue; + } + + // Handle assistant messages first (before tool messages) to build tool call map + if (role === "assistant") { + // Strip tags from content + content = content + .replace(/[\s\S]*?<\/thinking>/gi, "") + .trim(); + + // If assistant has tool calls, create tool_call messages for each + if (toolCalls && isToolCallArray(toolCalls) && toolCalls.length > 0) { + for (const toolCall of toolCalls) { + const toolName = toolCall.function.name; + const toolId = toolCall.id; + // Store tool name for later lookup + toolCallMap.set(toolId, toolName); + + try { + const args = JSON.parse(toolCall.function.arguments || "{}"); + processedInitialMessages.push({ + type: "tool_call", + toolId, + toolName, + arguments: args, + timestamp, + }); + } catch (err) { + console.warn("Failed to parse tool call arguments:", err); + processedInitialMessages.push({ + type: "tool_call", + toolId, + toolName, + arguments: {}, + timestamp, + }); + } + } + // Only add assistant message if there's content after stripping thinking tags + if (content.trim()) { + processedInitialMessages.push({ + type: "message", + role: "assistant", + content, + timestamp, + }); + } + } else if (content.trim()) { + // Assistant message without tool calls, but with content + processedInitialMessages.push({ + type: "message", + role: "assistant", + content, + timestamp, + }); + } + continue; + } + + // Handle tool messages - look up tool name from tool call map + if (role === 
"tool") { + const toolCallId = (msg.tool_call_id as string) || ""; + const toolName = toolCallMap.get(toolCallId) || "unknown"; + const toolResponse = parseToolResponse( + content, + toolCallId, + toolName, + timestamp, + ); + if (toolResponse) { + processedInitialMessages.push(toolResponse); + } + continue; + } + + // Handle other message types (system, etc.) + if (content.trim()) { + processedInitialMessages.push({ + type: "message", + role: role as "user" | "assistant" | "system", + content, + timestamp, + }); + } + } + + return [...processedInitialMessages, ...messages]; + }, [initialMessages, messages]); + + const sendMessage = useCallback( + async function sendMessage( + content: string, + isUserMessage: boolean = true, + context?: { url: string; content: string }, + ) { + if (!sessionId) { + console.error("Cannot send message: no session ID"); + return; + } + if (isUserMessage) { + const userMessage = createUserMessage(content); + setMessages((prev) => [...filterAuthMessages(prev), userMessage]); + } else { + setMessages((prev) => filterAuthMessages(prev)); + } + setStreamingChunks([]); + streamingChunksRef.current = []; + setHasTextChunks(false); + setIsStreamingInitiated(true); + const dispatcher = createStreamEventDispatcher({ + setHasTextChunks, + setStreamingChunks, + streamingChunksRef, + setMessages, + sessionId, + setIsStreamingInitiated, + }); + try { + await sendStreamMessage( + sessionId, + content, + dispatcher, + isUserMessage, + context, + ); + } catch (err) { + console.error("Failed to send message:", err); + setIsStreamingInitiated(false); + const errorMessage = + err instanceof Error ? 
err.message : "Failed to send message"; + toast.error("Failed to send message", { + description: errorMessage, + }); + } + }, + [sessionId, sendStreamMessage], + ); + + return { + messages: allMessages, + streamingChunks, + isStreaming, + error, + sendMessage, + }; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatCredentialsSetup/ChatCredentialsSetup.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatCredentialsSetup/ChatCredentialsSetup.tsx new file mode 100644 index 0000000000..4b9da57286 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatCredentialsSetup/ChatCredentialsSetup.tsx @@ -0,0 +1,149 @@ +import { Text } from "@/components/atoms/Text/Text"; +import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput"; +import type { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api"; +import { cn } from "@/lib/utils"; +import { CheckIcon, RobotIcon, WarningIcon } from "@phosphor-icons/react"; +import { useEffect, useRef } from "react"; +import { useChatCredentialsSetup } from "./useChatCredentialsSetup"; + +export interface CredentialInfo { + provider: string; + providerName: string; + credentialType: "api_key" | "oauth2" | "user_password" | "host_scoped"; + title: string; + scopes?: string[]; +} + +interface Props { + credentials: CredentialInfo[]; + agentName?: string; + message: string; + onAllCredentialsComplete: () => void; + onCancel: () => void; + className?: string; +} + +function createSchemaFromCredentialInfo( + credential: CredentialInfo, +): BlockIOCredentialsSubSchema { + return { + type: "object", + properties: {}, + credentials_provider: [credential.provider], + credentials_types: [credential.credentialType], + credentials_scopes: credential.scopes, + discriminator: undefined, + discriminator_mapping: undefined, + discriminator_values: undefined, + }; +} + +export function 
ChatCredentialsSetup({ + credentials, + agentName: _agentName, + message, + onAllCredentialsComplete, + onCancel: _onCancel, +}: Props) { + const { selectedCredentials, isAllComplete, handleCredentialSelect } = + useChatCredentialsSetup(credentials); + + // Track if we've already called completion to prevent double calls + const hasCalledCompleteRef = useRef(false); + + // Reset the completion flag when credentials change (new credential setup flow) + useEffect( + function resetCompletionFlag() { + hasCalledCompleteRef.current = false; + }, + [credentials], + ); + + // Auto-call completion when all credentials are configured + useEffect( + function autoCompleteWhenReady() { + if (isAllComplete && !hasCalledCompleteRef.current) { + hasCalledCompleteRef.current = true; + onAllCredentialsComplete(); + } + }, + [isAllComplete, onAllCredentialsComplete], + ); + + return ( +
+
+
+
+ +
+
+ +
+
+
+
+
+ + Credentials Required + + + {message} + +
+ +
+ {credentials.map((cred, index) => { + const schema = createSchemaFromCredentialInfo(cred); + const isSelected = !!selectedCredentials[cred.provider]; + + return ( +
+
+ {isSelected ? ( + + ) : ( + + )} + + {cred.providerName} + +
+ + + handleCredentialSelect(cred.provider, credMeta) + } + /> +
+ ); + })} +
+
+
+
+
+
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatCredentialsSetup/useChatCredentialsSetup.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatCredentialsSetup/useChatCredentialsSetup.ts similarity index 100% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/ChatCredentialsSetup/useChatCredentialsSetup.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatCredentialsSetup/useChatCredentialsSetup.ts diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatErrorState/ChatErrorState.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatErrorState/ChatErrorState.tsx similarity index 100% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/ChatErrorState/ChatErrorState.tsx rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatErrorState/ChatErrorState.tsx diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatInput/ChatInput.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatInput/ChatInput.tsx new file mode 100644 index 0000000000..3101174a11 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatInput/ChatInput.tsx @@ -0,0 +1,64 @@ +import { Input } from "@/components/atoms/Input/Input"; +import { cn } from "@/lib/utils"; +import { ArrowUpIcon } from "@phosphor-icons/react"; +import { useChatInput } from "./useChatInput"; + +export interface ChatInputProps { + onSend: (message: string) => void; + disabled?: boolean; + placeholder?: string; + className?: string; +} + +export function ChatInput({ + onSend, + disabled = false, + placeholder = "Type your message...", + className, +}: ChatInputProps) { + const inputId = "chat-input"; + const { value, setValue, handleKeyDown, handleSend } = useChatInput({ + onSend, 
+ disabled, + maxRows: 5, + inputId, + }); + + return ( +
+ setValue(e.target.value)} + onKeyDown={handleKeyDown} + placeholder={placeholder} + disabled={disabled} + rows={1} + wrapperClassName="mb-0 relative" + className="pr-12" + /> + + Press Enter to send, Shift+Enter for new line + + + +
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatInput/useChatInput.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatInput/useChatInput.ts similarity index 65% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/ChatInput/useChatInput.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatInput/useChatInput.ts index 2efae95483..08cf565daa 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatInput/useChatInput.ts +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatInput/useChatInput.ts @@ -1,21 +1,22 @@ -import { KeyboardEvent, useCallback, useState, useRef, useEffect } from "react"; +import { KeyboardEvent, useCallback, useEffect, useState } from "react"; interface UseChatInputArgs { onSend: (message: string) => void; disabled?: boolean; maxRows?: number; + inputId?: string; } export function useChatInput({ onSend, disabled = false, maxRows = 5, + inputId = "chat-input", }: UseChatInputArgs) { const [value, setValue] = useState(""); - const textareaRef = useRef(null); useEffect(() => { - const textarea = textareaRef.current; + const textarea = document.getElementById(inputId) as HTMLTextAreaElement; if (!textarea) return; textarea.style.height = "auto"; const lineHeight = parseInt( @@ -27,23 +28,25 @@ export function useChatInput({ textarea.style.height = `${newHeight}px`; textarea.style.overflowY = textarea.scrollHeight > maxHeight ? 
"auto" : "hidden"; - }, [value, maxRows]); + }, [value, maxRows, inputId]); const handleSend = useCallback(() => { if (disabled || !value.trim()) return; onSend(value.trim()); setValue(""); - if (textareaRef.current) { - textareaRef.current.style.height = "auto"; + const textarea = document.getElementById(inputId) as HTMLTextAreaElement; + if (textarea) { + textarea.style.height = "auto"; } - }, [value, onSend, disabled]); + }, [value, onSend, disabled, inputId]); const handleKeyDown = useCallback( - (event: KeyboardEvent) => { + (event: KeyboardEvent) => { if (event.key === "Enter" && !event.shiftKey) { event.preventDefault(); handleSend(); } + // Shift+Enter allows default behavior (new line) - no need to handle explicitly }, [handleSend], ); @@ -53,6 +56,5 @@ export function useChatInput({ setValue, handleKeyDown, handleSend, - textareaRef, }; } diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatLoadingState/ChatLoadingState.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatLoadingState/ChatLoadingState.tsx new file mode 100644 index 0000000000..c0cdb33c50 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatLoadingState/ChatLoadingState.tsx @@ -0,0 +1,19 @@ +import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner"; +import { cn } from "@/lib/utils"; + +export interface ChatLoadingStateProps { + message?: string; + className?: string; +} + +export function ChatLoadingState({ className }: ChatLoadingStateProps) { + return ( +
+
+ +
+
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatMessage/ChatMessage.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatMessage/ChatMessage.tsx new file mode 100644 index 0000000000..69a1ab63fb --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatMessage/ChatMessage.tsx @@ -0,0 +1,341 @@ +"use client"; + +import { useGetV2GetUserProfile } from "@/app/api/__generated__/endpoints/store/store"; +import Avatar, { + AvatarFallback, + AvatarImage, +} from "@/components/atoms/Avatar/Avatar"; +import { Button } from "@/components/atoms/Button/Button"; +import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; +import { cn } from "@/lib/utils"; +import { + ArrowClockwise, + CheckCircleIcon, + CheckIcon, + CopyIcon, + RobotIcon, +} from "@phosphor-icons/react"; +import { useRouter } from "next/navigation"; +import { useCallback, useState } from "react"; +import { getToolActionPhrase } from "../../helpers"; +import { AgentCarouselMessage } from "../AgentCarouselMessage/AgentCarouselMessage"; +import { AuthPromptWidget } from "../AuthPromptWidget/AuthPromptWidget"; +import { ChatCredentialsSetup } from "../ChatCredentialsSetup/ChatCredentialsSetup"; +import { ExecutionStartedMessage } from "../ExecutionStartedMessage/ExecutionStartedMessage"; +import { MarkdownContent } from "../MarkdownContent/MarkdownContent"; +import { MessageBubble } from "../MessageBubble/MessageBubble"; +import { NoResultsMessage } from "../NoResultsMessage/NoResultsMessage"; +import { ToolCallMessage } from "../ToolCallMessage/ToolCallMessage"; +import { ToolResponseMessage } from "../ToolResponseMessage/ToolResponseMessage"; +import { useChatMessage, type ChatMessageData } from "./useChatMessage"; +export interface ChatMessageProps { + message: ChatMessageData; + className?: string; + onDismissLogin?: () => void; + onDismissCredentials?: () => void; + 
onSendMessage?: (content: string, isUserMessage?: boolean) => void; + agentOutput?: ChatMessageData; +} + +export function ChatMessage({ + message, + className, + onDismissCredentials, + onSendMessage, + agentOutput, +}: ChatMessageProps) { + const { user } = useSupabase(); + const router = useRouter(); + const [copied, setCopied] = useState(false); + const { + isUser, + isToolCall, + isToolResponse, + isLoginNeeded, + isCredentialsNeeded, + } = useChatMessage(message); + + const { data: profile } = useGetV2GetUserProfile({ + query: { + select: (res) => (res.status === 200 ? res.data : null), + enabled: isUser && !!user, + queryKey: ["/api/store/profile", user?.id], + }, + }); + + const handleAllCredentialsComplete = useCallback( + function handleAllCredentialsComplete() { + // Send a user message that explicitly asks to retry the setup + // This ensures the LLM calls get_required_setup_info again and proceeds with execution + if (onSendMessage) { + onSendMessage( + "I've configured the required credentials. 
Please check if everything is ready and proceed with setting up the agent.", + ); + } + // Optionally dismiss the credentials prompt + if (onDismissCredentials) { + onDismissCredentials(); + } + }, + [onSendMessage, onDismissCredentials], + ); + + function handleCancelCredentials() { + // Dismiss the credentials prompt + if (onDismissCredentials) { + onDismissCredentials(); + } + } + + const handleCopy = useCallback(async () => { + if (message.type !== "message") return; + + try { + await navigator.clipboard.writeText(message.content); + setCopied(true); + setTimeout(() => setCopied(false), 2000); + } catch (error) { + console.error("Failed to copy:", error); + } + }, [message]); + + const handleTryAgain = useCallback(() => { + if (message.type !== "message" || !onSendMessage) return; + onSendMessage(message.content, message.role === "user"); + }, [message, onSendMessage]); + + const handleViewExecution = useCallback(() => { + if (message.type === "execution_started" && message.libraryAgentLink) { + router.push(message.libraryAgentLink); + } + }, [message, router]); + + // Render credentials needed messages + if (isCredentialsNeeded && message.type === "credentials_needed") { + return ( + + ); + } + + // Render login needed messages + if (isLoginNeeded && message.type === "login_needed") { + // If user is already logged in, show success message instead of auth prompt + if (user) { + return ( +
+
+
+
+
+ +
+
+

+ Successfully Authenticated +

+

+ You're now signed in and ready to continue +

+
+
+
+
+
+ ); + } + + // Show auth prompt if not logged in + return ( +
+ +
+ ); + } + + // Render tool call messages + if (isToolCall && message.type === "tool_call") { + return ( +
+ +
+ ); + } + + // Render no_results messages - use dedicated component, not ToolResponseMessage + if (message.type === "no_results") { + return ( +
+ +
+ ); + } + + // Render agent_carousel messages - use dedicated component, not ToolResponseMessage + if (message.type === "agent_carousel") { + return ( +
+ +
+ ); + } + + // Render execution_started messages - use dedicated component, not ToolResponseMessage + if (message.type === "execution_started") { + return ( +
+ +
+ ); + } + + // Render tool response messages (but skip agent_output if it's being rendered inside assistant message) + if (isToolResponse && message.type === "tool_response") { + // Check if this is an agent_output that should be rendered inside assistant message + if (message.result) { + let parsedResult: Record | null = null; + try { + parsedResult = + typeof message.result === "string" + ? JSON.parse(message.result) + : (message.result as Record); + } catch { + parsedResult = null; + } + if (parsedResult?.type === "agent_output") { + // Skip rendering - this will be rendered inside the assistant message + return null; + } + } + + return ( +
+ +
+ ); + } + + // Render regular chat messages + if (message.type === "message") { + return ( +
+
+ {!isUser && ( +
+
+ +
+
+ )} + +
+ + + {agentOutput && + agentOutput.type === "tool_response" && + !isUser && ( +
+ +
+ )} +
+
+ {isUser && onSendMessage && ( + + )} + +
+
+ + {isUser && ( +
+ + + + {profile?.username?.charAt(0)?.toUpperCase() || "U"} + + +
+ )} +
+
+ ); + } + + // Fallback for unknown message types + return null; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatMessage/useChatMessage.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatMessage/useChatMessage.ts similarity index 88% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/ChatMessage/useChatMessage.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatMessage/useChatMessage.ts index ae4f48f35b..9a597d4b26 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatMessage/useChatMessage.ts +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ChatMessage/useChatMessage.ts @@ -1,5 +1,5 @@ -import { formatDistanceToNow } from "date-fns"; import type { ToolArguments, ToolResult } from "@/types/chat"; +import { formatDistanceToNow } from "date-fns"; export type ChatMessageData = | { @@ -65,6 +65,7 @@ export type ChatMessageData = name: string; description: string; version?: number; + image_url?: string; }>; totalCount?: number; timestamp?: string | Date; @@ -77,6 +78,17 @@ export type ChatMessageData = message?: string; libraryAgentLink?: string; timestamp?: string | Date; + } + | { + type: "inputs_needed"; + toolName: string; + agentName?: string; + agentId?: string; + graphVersion?: number; + inputSchema: Record; + credentialsSchema?: Record; + message: string; + timestamp?: string | Date; }; export function useChatMessage(message: ChatMessageData) { @@ -96,5 +108,6 @@ export function useChatMessage(message: ChatMessageData) { isNoResults: message.type === "no_results", isAgentCarousel: message.type === "agent_carousel", isExecutionStarted: message.type === "execution_started", + isInputsNeeded: message.type === "inputs_needed", }; } diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ExecutionStartedMessage/ExecutionStartedMessage.tsx 
b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ExecutionStartedMessage/ExecutionStartedMessage.tsx similarity index 67% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/ExecutionStartedMessage/ExecutionStartedMessage.tsx rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ExecutionStartedMessage/ExecutionStartedMessage.tsx index 77c7f8fe9b..1ac3b440e0 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ExecutionStartedMessage/ExecutionStartedMessage.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ExecutionStartedMessage/ExecutionStartedMessage.tsx @@ -1,8 +1,7 @@ -import React from "react"; -import { Text } from "@/components/atoms/Text/Text"; import { Button } from "@/components/atoms/Button/Button"; -import { CheckCircle, Play, ArrowSquareOut } from "@phosphor-icons/react"; +import { Text } from "@/components/atoms/Text/Text"; import { cn } from "@/lib/utils"; +import { ArrowSquareOut, CheckCircle, Play } from "@phosphor-icons/react"; export interface ExecutionStartedMessageProps { executionId: string; @@ -22,7 +21,7 @@ export function ExecutionStartedMessage({ return (
@@ -32,48 +31,33 @@ export function ExecutionStartedMessage({
- + Execution Started - + {message}
{/* Details */} -
+
{agentName && (
- + Agent: - + {agentName}
)}
- + Execution ID: - + {executionId.slice(0, 16)}...
@@ -94,7 +78,7 @@ export function ExecutionStartedMessage({
)} -
+
Your agent is now running. You can monitor its progress in the monitor diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/MarkdownContent/MarkdownContent.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MarkdownContent/MarkdownContent.tsx similarity index 80% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/MarkdownContent/MarkdownContent.tsx rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MarkdownContent/MarkdownContent.tsx index e82e29c438..51a0794090 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/MarkdownContent/MarkdownContent.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MarkdownContent/MarkdownContent.tsx @@ -1,9 +1,9 @@ "use client"; +import { cn } from "@/lib/utils"; import React from "react"; import ReactMarkdown from "react-markdown"; import remarkGfm from "remark-gfm"; -import { cn } from "@/lib/utils"; interface MarkdownContentProps { content: string; @@ -41,7 +41,7 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) { if (isInline) { return ( {children} @@ -49,17 +49,14 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) { ); } return ( - + {children} ); }, pre: ({ children, ...props }) => (
               {children}
@@ -70,7 +67,7 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
               href={href}
               target="_blank"
               rel="noopener noreferrer"
-              className="text-purple-600 underline decoration-1 underline-offset-2 hover:text-purple-700 dark:text-purple-400 dark:hover:text-purple-300"
+              className="text-purple-600 underline decoration-1 underline-offset-2 hover:text-purple-700"
               {...props}
             >
               {children}
@@ -126,7 +123,7 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
               return (
                 
@@ -136,57 +133,42 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
           },
           blockquote: ({ children, ...props }) => (
             
{children}
), h1: ({ children, ...props }) => ( -

+

{children}

), h2: ({ children, ...props }) => ( -

+

{children}

), h3: ({ children, ...props }) => (

{children}

), h4: ({ children, ...props }) => ( -

+

{children}

), h5: ({ children, ...props }) => ( -
+
{children}
), h6: ({ children, ...props }) => ( -
+
{children}
), @@ -196,15 +178,12 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {

), hr: ({ ...props }) => ( -
+
), table: ({ children, ...props }) => (
{children} @@ -213,7 +192,7 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) { ), th: ({ children, ...props }) => (
{children} @@ -221,7 +200,7 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) { ), td: ({ children, ...props }) => ( {children} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageBubble/MessageBubble.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageBubble/MessageBubble.tsx new file mode 100644 index 0000000000..98b50f3d28 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageBubble/MessageBubble.tsx @@ -0,0 +1,56 @@ +import { cn } from "@/lib/utils"; +import { ReactNode } from "react"; + +export interface MessageBubbleProps { + children: ReactNode; + variant: "user" | "assistant"; + className?: string; +} + +export function MessageBubble({ + children, + variant, + className, +}: MessageBubbleProps) { + const userTheme = { + bg: "bg-slate-900", + border: "border-slate-800", + gradient: "from-slate-900/30 via-slate-800/20 to-transparent", + text: "text-slate-50", + }; + + const assistantTheme = { + bg: "bg-slate-50/20", + border: "border-slate-100", + gradient: "from-slate-200/20 via-slate-300/10 to-transparent", + text: "text-slate-900", + }; + + const theme = variant === "user" ? userTheme : assistantTheme; + + return ( +
+ {/* Gradient flare background */} +
+
+ {children} +
+
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageList/MessageList.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageList/MessageList.tsx new file mode 100644 index 0000000000..22b51c0a92 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageList/MessageList.tsx @@ -0,0 +1,121 @@ +"use client"; + +import { cn } from "@/lib/utils"; +import { ChatMessage } from "../ChatMessage/ChatMessage"; +import type { ChatMessageData } from "../ChatMessage/useChatMessage"; +import { StreamingMessage } from "../StreamingMessage/StreamingMessage"; +import { ThinkingMessage } from "../ThinkingMessage/ThinkingMessage"; +import { useMessageList } from "./useMessageList"; + +export interface MessageListProps { + messages: ChatMessageData[]; + streamingChunks?: string[]; + isStreaming?: boolean; + className?: string; + onStreamComplete?: () => void; + onSendMessage?: (content: string) => void; +} + +export function MessageList({ + messages, + streamingChunks = [], + isStreaming = false, + className, + onStreamComplete, + onSendMessage, +}: MessageListProps) { + const { messagesEndRef, messagesContainerRef } = useMessageList({ + messageCount: messages.length, + isStreaming, + }); + + return ( +
+
+ {/* Render all persisted messages */} + {messages.map((message, index) => { + // Check if current message is an agent_output tool_response + // and if previous message is an assistant message + let agentOutput: ChatMessageData | undefined; + + if (message.type === "tool_response" && message.result) { + let parsedResult: Record | null = null; + try { + parsedResult = + typeof message.result === "string" + ? JSON.parse(message.result) + : (message.result as Record); + } catch { + parsedResult = null; + } + if (parsedResult?.type === "agent_output") { + const prevMessage = messages[index - 1]; + if ( + prevMessage && + prevMessage.type === "message" && + prevMessage.role === "assistant" + ) { + // This agent output will be rendered inside the previous assistant message + // Skip rendering this message separately + return null; + } + } + } + + // Check if next message is an agent_output tool_response to include in current assistant message + if (message.type === "message" && message.role === "assistant") { + const nextMessage = messages[index + 1]; + if ( + nextMessage && + nextMessage.type === "tool_response" && + nextMessage.result + ) { + let parsedResult: Record | null = null; + try { + parsedResult = + typeof nextMessage.result === "string" + ? JSON.parse(nextMessage.result) + : (nextMessage.result as Record); + } catch { + parsedResult = null; + } + if (parsedResult?.type === "agent_output") { + agentOutput = nextMessage; + } + } + } + + return ( + + ); + })} + + {/* Render thinking message when streaming but no chunks yet */} + {isStreaming && streamingChunks.length === 0 && } + + {/* Render streaming message if active */} + {isStreaming && streamingChunks.length > 0 && ( + + )} + + {/* Invisible div to scroll to */} +
+
+
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/MessageList/useMessageList.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageList/useMessageList.ts similarity index 100% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/MessageList/useMessageList.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/MessageList/useMessageList.ts diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/NoResultsMessage/NoResultsMessage.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/NoResultsMessage/NoResultsMessage.tsx similarity index 72% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/NoResultsMessage/NoResultsMessage.tsx rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/NoResultsMessage/NoResultsMessage.tsx index 38eac24bce..b6adc8b93c 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/NoResultsMessage/NoResultsMessage.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/NoResultsMessage/NoResultsMessage.tsx @@ -1,7 +1,6 @@ -import React from "react"; import { Text } from "@/components/atoms/Text/Text"; -import { MagnifyingGlass, X } from "@phosphor-icons/react"; import { cn } from "@/lib/utils"; +import { MagnifyingGlass, X } from "@phosphor-icons/react"; export interface NoResultsMessageProps { message: string; @@ -17,26 +16,26 @@ export function NoResultsMessage({ return (
{/* Icon */}
-
+
-
+
{/* Content */}
- + No Results Found - + {message}
@@ -44,17 +43,14 @@ export function NoResultsMessage({ {/* Suggestions */} {suggestions.length > 0 && (
- + Try these suggestions: -
    +
      {suggestions.map((suggestion, index) => (
    • {suggestion} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/QuickActionsWelcome/QuickActionsWelcome.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/QuickActionsWelcome/QuickActionsWelcome.tsx new file mode 100644 index 0000000000..dd76fd9fb6 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/QuickActionsWelcome/QuickActionsWelcome.tsx @@ -0,0 +1,94 @@ +"use client"; + +import { Text } from "@/components/atoms/Text/Text"; +import { cn } from "@/lib/utils"; + +export interface QuickActionsWelcomeProps { + title: string; + description: string; + actions: string[]; + onActionClick: (action: string) => void; + disabled?: boolean; + className?: string; +} + +export function QuickActionsWelcome({ + title, + description, + actions, + onActionClick, + disabled = false, + className, +}: QuickActionsWelcomeProps) { + return ( +
      +
      +
      + + {title} + + + {description} + +
      +
      + {actions.map((action) => { + // Use slate theme for all cards + const theme = { + bg: "bg-slate-50/10", + border: "border-slate-100", + hoverBg: "hover:bg-slate-50/20", + hoverBorder: "hover:border-slate-200", + gradient: "from-slate-200/20 via-slate-300/10 to-transparent", + text: "text-slate-900", + hoverText: "group-hover:text-slate-900", + }; + + return ( + + ); + })} +
      +
      +
      + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/SessionsDrawer/SessionsDrawer.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/SessionsDrawer/SessionsDrawer.tsx new file mode 100644 index 0000000000..74aa709a46 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/SessionsDrawer/SessionsDrawer.tsx @@ -0,0 +1,136 @@ +"use client"; + +import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat"; +import { Text } from "@/components/atoms/Text/Text"; +import { scrollbarStyles } from "@/components/styles/scrollbars"; +import { cn } from "@/lib/utils"; +import { X } from "@phosphor-icons/react"; +import { formatDistanceToNow } from "date-fns"; +import { Drawer } from "vaul"; + +interface SessionsDrawerProps { + isOpen: boolean; + onClose: () => void; + onSelectSession: (sessionId: string) => void; + currentSessionId?: string | null; +} + +export function SessionsDrawer({ + isOpen, + onClose, + onSelectSession, + currentSessionId, +}: SessionsDrawerProps) { + const { data, isLoading } = useGetV2ListSessions( + { limit: 100 }, + { + query: { + enabled: isOpen, + }, + }, + ); + + const sessions = + data?.status === 200 + ? data.data.sessions.filter((session) => { + // Filter out sessions without messages (sessions that were never updated) + // If updated_at equals created_at, the session was created but never had messages + return session.updated_at !== session.created_at; + }) + : []; + + function handleSelectSession(sessionId: string) { + onSelectSession(sessionId); + onClose(); + } + + return ( + !open && onClose()} + direction="right" + > + + + +
      +
      + + Chat Sessions + + +
      +
      + +
      + {isLoading ? ( +
      + + Loading sessions... + +
      + ) : sessions.length === 0 ? ( +
      + + No sessions found + +
      + ) : ( +
      + {sessions.map((session) => { + const isActive = session.id === currentSessionId; + const updatedAt = session.updated_at + ? formatDistanceToNow(new Date(session.updated_at), { + addSuffix: true, + }) + : ""; + + return ( + + ); + })} +
      + )} +
      +
      +
      +
      + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/StreamingMessage/StreamingMessage.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/StreamingMessage/StreamingMessage.tsx new file mode 100644 index 0000000000..2a6e3d5822 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/StreamingMessage/StreamingMessage.tsx @@ -0,0 +1,42 @@ +import { cn } from "@/lib/utils"; +import { RobotIcon } from "@phosphor-icons/react"; +import { MarkdownContent } from "../MarkdownContent/MarkdownContent"; +import { MessageBubble } from "../MessageBubble/MessageBubble"; +import { useStreamingMessage } from "./useStreamingMessage"; + +export interface StreamingMessageProps { + chunks: string[]; + className?: string; + onComplete?: () => void; +} + +export function StreamingMessage({ + chunks, + className, + onComplete, +}: StreamingMessageProps) { + const { displayText } = useStreamingMessage({ chunks, onComplete }); + + return ( +
      +
      +
      +
      + +
      +
      + +
      + + + +
      +
      +
      + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/StreamingMessage/useStreamingMessage.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/StreamingMessage/useStreamingMessage.ts similarity index 100% rename from autogpt_platform/frontend/src/app/(platform)/chat/components/StreamingMessage/useStreamingMessage.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/StreamingMessage/useStreamingMessage.ts diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ThinkingMessage/ThinkingMessage.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ThinkingMessage/ThinkingMessage.tsx new file mode 100644 index 0000000000..d8adddf416 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ThinkingMessage/ThinkingMessage.tsx @@ -0,0 +1,70 @@ +import { cn } from "@/lib/utils"; +import { RobotIcon } from "@phosphor-icons/react"; +import { useEffect, useRef, useState } from "react"; +import { MessageBubble } from "../MessageBubble/MessageBubble"; + +export interface ThinkingMessageProps { + className?: string; +} + +export function ThinkingMessage({ className }: ThinkingMessageProps) { + const [showSlowLoader, setShowSlowLoader] = useState(false); + const timerRef = useRef(null); + + useEffect(() => { + if (timerRef.current === null) { + timerRef.current = setTimeout(() => { + setShowSlowLoader(true); + }, 8000); + } + + return () => { + if (timerRef.current) { + clearTimeout(timerRef.current); + timerRef.current = null; + } + }; + }, []); + + return ( +
      +
      +
      +
      + +
      +
      + +
      + +
      + {showSlowLoader ? ( +
      +
      +

      + Taking a bit longer to think, wait a moment please +

      +
      + ) : ( + + Thinking... + + )} +
      + +
      +
      +
      + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ToolCallMessage/ToolCallMessage.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ToolCallMessage/ToolCallMessage.tsx new file mode 100644 index 0000000000..97590ae0cf --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ToolCallMessage/ToolCallMessage.tsx @@ -0,0 +1,24 @@ +import { Text } from "@/components/atoms/Text/Text"; +import { cn } from "@/lib/utils"; +import { WrenchIcon } from "@phosphor-icons/react"; +import { getToolActionPhrase } from "../../helpers"; + +export interface ToolCallMessageProps { + toolName: string; + className?: string; +} + +export function ToolCallMessage({ toolName, className }: ToolCallMessageProps) { + return ( +
      + + + {getToolActionPhrase(toolName)}... + +
      + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ToolResponseMessage/ToolResponseMessage.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ToolResponseMessage/ToolResponseMessage.tsx new file mode 100644 index 0000000000..b84204c3ff --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/components/ToolResponseMessage/ToolResponseMessage.tsx @@ -0,0 +1,260 @@ +import { Text } from "@/components/atoms/Text/Text"; +import "@/components/contextual/OutputRenderers"; +import { + globalRegistry, + OutputItem, +} from "@/components/contextual/OutputRenderers"; +import { cn } from "@/lib/utils"; +import type { ToolResult } from "@/types/chat"; +import { WrenchIcon } from "@phosphor-icons/react"; +import { getToolActionPhrase } from "../../helpers"; + +export interface ToolResponseMessageProps { + toolName: string; + result?: ToolResult; + success?: boolean; + className?: string; +} + +export function ToolResponseMessage({ + toolName, + result, + success: _success = true, + className, +}: ToolResponseMessageProps) { + if (!result) { + return ( +
      + + + {getToolActionPhrase(toolName)}... + +
      + ); + } + + let parsedResult: Record | null = null; + try { + parsedResult = + typeof result === "string" + ? JSON.parse(result) + : (result as Record); + } catch { + parsedResult = null; + } + + if (parsedResult && typeof parsedResult === "object") { + const responseType = parsedResult.type as string | undefined; + + if (responseType === "agent_output") { + const execution = parsedResult.execution as + | { + outputs?: Record; + } + | null + | undefined; + const outputs = execution?.outputs || {}; + const message = parsedResult.message as string | undefined; + + return ( +
      +
      + + + {getToolActionPhrase(toolName)} + +
      + {message && ( +
      + + {message} + +
      + )} + {Object.keys(outputs).length > 0 && ( +
      + {Object.entries(outputs).map(([outputName, values]) => + values.map((value, index) => { + const renderer = globalRegistry.getRenderer(value); + if (renderer) { + return ( + + ); + } + return ( +
      + + {outputName} + +
      +                        {JSON.stringify(value, null, 2)}
      +                      
      +
      + ); + }), + )} +
      + )} +
      + ); + } + + if (responseType === "block_output" && parsedResult.outputs) { + const outputs = parsedResult.outputs as Record; + + return ( +
      +
      + + + {getToolActionPhrase(toolName)} + +
      +
      + {Object.entries(outputs).map(([outputName, values]) => + values.map((value, index) => { + const renderer = globalRegistry.getRenderer(value); + if (renderer) { + return ( + + ); + } + return ( +
      + + {outputName} + +
      +                      {JSON.stringify(value, null, 2)}
      +                    
      +
      + ); + }), + )} +
      +
      + ); + } + + // Handle other response types with a message field (e.g., understanding_updated) + if (parsedResult.message && typeof parsedResult.message === "string") { + // Format tool name from snake_case to Title Case + const formattedToolName = toolName + .split("_") + .map((word) => word.charAt(0).toUpperCase() + word.slice(1)) + .join(" "); + + // Clean up message - remove incomplete user_name references + let cleanedMessage = parsedResult.message; + // Remove "Updated understanding with: user_name" pattern if user_name is just a placeholder + cleanedMessage = cleanedMessage.replace( + /Updated understanding with:\s*user_name\.?\s*/gi, + "", + ); + // Remove standalone user_name references + cleanedMessage = cleanedMessage.replace(/\buser_name\b\.?\s*/gi, ""); + cleanedMessage = cleanedMessage.trim(); + + // Only show message if it has content after cleaning + if (!cleanedMessage) { + return ( +
      + + + {formattedToolName} + +
      + ); + } + + return ( +
      +
      + + + {formattedToolName} + +
      +
      + + {cleanedMessage} + +
      +
      + ); + } + } + + const renderer = globalRegistry.getRenderer(result); + if (renderer) { + return ( +
      +
      + + + {getToolActionPhrase(toolName)} + +
      + +
      + ); + } + + return ( +
      + + + {getToolActionPhrase(toolName)}... + +
      + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/helpers.ts similarity index 92% rename from autogpt_platform/frontend/src/app/(platform)/chat/helpers.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/helpers.ts index 5a1e5eb93f..0fade56b73 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/helpers.ts @@ -64,10 +64,3 @@ export function getToolCompletionPhrase(toolName: string): string { `Finished ${toolName.replace(/_/g, " ").replace("...", "")}` ); } - -/** Validate UUID v4 format */ -export function isValidUUID(value: string): boolean { - const uuidRegex = - /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i; - return uuidRegex.test(value); -} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/useChatPage.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChat.ts similarity index 78% rename from autogpt_platform/frontend/src/app/(platform)/chat/useChatPage.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChat.ts index 4f1db5471a..8445d68b3f 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/useChatPage.ts +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChat.ts @@ -1,17 +1,12 @@ "use client"; -import { useEffect, useRef } from "react"; -import { useRouter, useSearchParams } from "next/navigation"; -import { toast } from "sonner"; -import { useChatSession } from "@/app/(platform)/chat/useChatSession"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; -import { useChatStream } from "@/app/(platform)/chat/useChatStream"; +import { useEffect, useRef } from "react"; +import { toast } from "sonner"; +import { useChatSession } from "./useChatSession"; +import { useChatStream } from "./useChatStream"; -export 
function useChatPage() { - const router = useRouter(); - const searchParams = useSearchParams(); - const urlSessionId = - searchParams.get("session_id") || searchParams.get("session"); +export function useChat() { const hasCreatedSessionRef = useRef(false); const hasClaimedSessionRef = useRef(false); const { user } = useSupabase(); @@ -25,29 +20,24 @@ export function useChatPage() { isCreating, error, createSession, - refreshSession, claimSession, clearSession: clearSessionBase, + loadSession, } = useChatSession({ - urlSessionId, + urlSessionId: null, autoCreate: false, }); useEffect( function autoCreateSession() { - if ( - !urlSessionId && - !hasCreatedSessionRef.current && - !isCreating && - !sessionIdFromHook - ) { + if (!hasCreatedSessionRef.current && !isCreating && !sessionIdFromHook) { hasCreatedSessionRef.current = true; createSession().catch((_err) => { hasCreatedSessionRef.current = false; }); } }, - [urlSessionId, isCreating, sessionIdFromHook, createSession], + [isCreating, sessionIdFromHook, createSession], ); useEffect( @@ -111,7 +101,6 @@ export function useChatPage() { clearSessionBase(); hasCreatedSessionRef.current = false; hasClaimedSessionRef.current = false; - router.push("/chat"); } return { @@ -121,8 +110,8 @@ export function useChatPage() { isCreating, error, createSession, - refreshSession, clearSession, + loadSession, sessionId: sessionIdFromHook, }; } diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatDrawer.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatDrawer.ts new file mode 100644 index 0000000000..62e1a5a569 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatDrawer.ts @@ -0,0 +1,17 @@ +"use client"; + +import { create } from "zustand"; + +interface ChatDrawerState { + isOpen: boolean; + open: () => void; + close: () => void; + toggle: () => void; +} + +export const useChatDrawer = create((set) => ({ + isOpen: false, + open: () 
=> set({ isOpen: true }), + close: () => set({ isOpen: false }), + toggle: () => set((state) => ({ isOpen: !state.isOpen })), +})); diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/useChatSession.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatSession.ts similarity index 89% rename from autogpt_platform/frontend/src/app/(platform)/chat/useChatSession.ts rename to autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatSession.ts index 99f4efc093..a54dc9e32a 100644 --- a/autogpt_platform/frontend/src/app/(platform)/chat/useChatSession.ts +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatSession.ts @@ -1,17 +1,18 @@ -import { useCallback, useEffect, useState, useRef, useMemo } from "react"; -import { useQueryClient } from "@tanstack/react-query"; -import { toast } from "sonner"; import { - usePostV2CreateSession, + getGetV2GetSessionQueryKey, + getGetV2GetSessionQueryOptions, postV2CreateSession, useGetV2GetSession, usePatchV2SessionAssignUser, - getGetV2GetSessionQueryKey, + usePostV2CreateSession, } from "@/app/api/__generated__/endpoints/chat/chat"; import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse"; -import { storage, Key } from "@/services/storage/local-storage"; -import { isValidUUID } from "@/app/(platform)/chat/helpers"; import { okData } from "@/app/api/helpers"; +import { isValidUUID } from "@/lib/utils"; +import { Key, storage } from "@/services/storage/local-storage"; +import { useQueryClient } from "@tanstack/react-query"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { toast } from "sonner"; interface UseChatSessionArgs { urlSessionId?: string | null; @@ -155,10 +156,22 @@ export function useChatSession({ async function loadSession(id: string) { try { setError(null); + // Invalidate the query cache for this session to force a fresh fetch + await queryClient.invalidateQueries({ + 
queryKey: getGetV2GetSessionQueryKey(id), + }); + // Set sessionId after invalidation to ensure the hook refetches setSessionId(id); storage.set(Key.CHAT_SESSION_ID, id); - const result = await refetch(); - if (!result.data || result.isError) { + // Force fetch with fresh data (bypass cache) + const queryOptions = getGetV2GetSessionQueryOptions(id, { + query: { + staleTime: 0, // Force fresh fetch + retry: 1, + }, + }); + const result = await queryClient.fetchQuery(queryOptions); + if (!result || ("status" in result && result.status !== 200)) { console.warn("Session not found on server, clearing local state"); storage.clean(Key.CHAT_SESSION_ID); setSessionId(null); @@ -171,7 +184,7 @@ export function useChatSession({ throw error; } }, - [refetch], + [queryClient], ); const refreshSession = useCallback( diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatStream.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatStream.ts new file mode 100644 index 0000000000..1471a13a71 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/useChatStream.ts @@ -0,0 +1,371 @@ +import type { ToolArguments, ToolResult } from "@/types/chat"; +import { useCallback, useEffect, useRef, useState } from "react"; +import { toast } from "sonner"; + +const MAX_RETRIES = 3; +const INITIAL_RETRY_DELAY = 1000; + +export interface StreamChunk { + type: + | "text_chunk" + | "text_ended" + | "tool_call" + | "tool_call_start" + | "tool_response" + | "login_needed" + | "need_login" + | "credentials_needed" + | "error" + | "usage" + | "stream_end"; + timestamp?: string; + content?: string; + message?: string; + tool_id?: string; + tool_name?: string; + arguments?: ToolArguments; + result?: ToolResult; + success?: boolean; + idx?: number; + session_id?: string; + agent_info?: { + graph_id: string; + name: string; + trigger_type: string; + }; + provider?: string; + provider_name?: string; + credential_type?: 
string; + scopes?: string[]; + title?: string; + [key: string]: unknown; +} + +type VercelStreamChunk = + | { type: "start"; messageId: string } + | { type: "finish" } + | { type: "text-start"; id: string } + | { type: "text-delta"; id: string; delta: string } + | { type: "text-end"; id: string } + | { type: "tool-input-start"; toolCallId: string; toolName: string } + | { + type: "tool-input-available"; + toolCallId: string; + toolName: string; + input: ToolArguments; + } + | { + type: "tool-output-available"; + toolCallId: string; + toolName?: string; + output: ToolResult; + success?: boolean; + } + | { + type: "usage"; + promptTokens: number; + completionTokens: number; + totalTokens: number; + } + | { + type: "error"; + errorText: string; + code?: string; + details?: Record; + }; + +const LEGACY_STREAM_TYPES = new Set([ + "text_chunk", + "text_ended", + "tool_call", + "tool_call_start", + "tool_response", + "login_needed", + "need_login", + "credentials_needed", + "error", + "usage", + "stream_end", +]); + +function isLegacyStreamChunk( + chunk: StreamChunk | VercelStreamChunk, +): chunk is StreamChunk { + return LEGACY_STREAM_TYPES.has(chunk.type as StreamChunk["type"]); +} + +function normalizeStreamChunk( + chunk: StreamChunk | VercelStreamChunk, +): StreamChunk | null { + if (isLegacyStreamChunk(chunk)) { + return chunk; + } + switch (chunk.type) { + case "text-delta": + return { type: "text_chunk", content: chunk.delta }; + case "text-end": + return { type: "text_ended" }; + case "tool-input-available": + return { + type: "tool_call_start", + tool_id: chunk.toolCallId, + tool_name: chunk.toolName, + arguments: chunk.input, + }; + case "tool-output-available": + return { + type: "tool_response", + tool_id: chunk.toolCallId, + tool_name: chunk.toolName, + result: chunk.output, + success: chunk.success ?? 
true, + }; + case "usage": + return { + type: "usage", + promptTokens: chunk.promptTokens, + completionTokens: chunk.completionTokens, + totalTokens: chunk.totalTokens, + }; + case "error": + return { + type: "error", + message: chunk.errorText, + code: chunk.code, + details: chunk.details, + }; + case "finish": + return { type: "stream_end" }; + case "start": + case "text-start": + case "tool-input-start": + return null; + } +} + +export function useChatStream() { + const [isStreaming, setIsStreaming] = useState(false); + const [error, setError] = useState(null); + const retryCountRef = useRef(0); + const retryTimeoutRef = useRef(null); + const abortControllerRef = useRef(null); + + const stopStreaming = useCallback(() => { + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + abortControllerRef.current = null; + } + if (retryTimeoutRef.current) { + clearTimeout(retryTimeoutRef.current); + retryTimeoutRef.current = null; + } + setIsStreaming(false); + }, []); + + useEffect(() => { + return () => { + stopStreaming(); + }; + }, [stopStreaming]); + + const sendMessage = useCallback( + async ( + sessionId: string, + message: string, + onChunk: (chunk: StreamChunk) => void, + isUserMessage: boolean = true, + context?: { url: string; content: string }, + isRetry: boolean = false, + ) => { + stopStreaming(); + + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + if (abortController.signal.aborted) { + return Promise.reject(new Error("Request aborted")); + } + + if (!isRetry) { + retryCountRef.current = 0; + } + setIsStreaming(true); + setError(null); + + try { + const url = `/api/chat/sessions/${sessionId}/stream`; + const body = JSON.stringify({ + message, + is_user_message: isUserMessage, + context: context || null, + }); + + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "text/event-stream", + }, + body, + signal: 
abortController.signal, + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error(errorText || `HTTP ${response.status}`); + } + + if (!response.body) { + throw new Error("Response body is null"); + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + + return new Promise((resolve, reject) => { + let didDispatchStreamEnd = false; + + function dispatchStreamEnd() { + if (didDispatchStreamEnd) return; + didDispatchStreamEnd = true; + onChunk({ type: "stream_end" }); + } + + const cleanup = () => { + reader.cancel().catch(() => { + // Ignore cancel errors + }); + }; + + async function readStream() { + try { + while (true) { + const { done, value } = await reader.read(); + + if (done) { + cleanup(); + dispatchStreamEnd(); + retryCountRef.current = 0; + stopStreaming(); + resolve(); + return; + } + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() || ""; + + for (const line of lines) { + if (line.startsWith("data: ")) { + const data = line.slice(6); + if (data === "[DONE]") { + cleanup(); + dispatchStreamEnd(); + retryCountRef.current = 0; + stopStreaming(); + resolve(); + return; + } + + try { + const rawChunk = JSON.parse(data) as + | StreamChunk + | VercelStreamChunk; + const chunk = normalizeStreamChunk(rawChunk); + if (!chunk) { + continue; + } + + // Call the chunk handler + onChunk(chunk); + + // Handle stream lifecycle + if (chunk.type === "stream_end") { + didDispatchStreamEnd = true; + cleanup(); + retryCountRef.current = 0; + stopStreaming(); + resolve(); + return; + } else if (chunk.type === "error") { + cleanup(); + reject( + new Error( + chunk.message || chunk.content || "Stream error", + ), + ); + return; + } + } catch (err) { + // Skip invalid JSON lines + console.warn("Failed to parse SSE chunk:", err, data); + } + } + } + } + } catch (err) { + if (err instanceof Error && err.name === "AbortError") { + 
cleanup(); + return; + } + + const streamError = + err instanceof Error ? err : new Error("Failed to read stream"); + + if (retryCountRef.current < MAX_RETRIES) { + retryCountRef.current += 1; + const retryDelay = + INITIAL_RETRY_DELAY * Math.pow(2, retryCountRef.current - 1); + + toast.info("Connection interrupted", { + description: `Retrying in ${retryDelay / 1000} seconds...`, + }); + + retryTimeoutRef.current = setTimeout(() => { + sendMessage( + sessionId, + message, + onChunk, + isUserMessage, + context, + true, + ).catch((_err) => { + // Retry failed + }); + }, retryDelay); + } else { + setError(streamError); + toast.error("Connection Failed", { + description: + "Unable to connect to chat service. Please try again.", + }); + cleanup(); + dispatchStreamEnd(); + retryCountRef.current = 0; + stopStreaming(); + reject(streamError); + } + } + } + + readStream(); + }); + } catch (err) { + const streamError = + err instanceof Error ? err : new Error("Failed to start stream"); + setError(streamError); + setIsStreaming(false); + throw streamError; + } + }, + [stopStreaming], + ); + + return { + isStreaming, + error, + sendMessage, + stopStreaming, + }; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/usePageContext.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/usePageContext.ts new file mode 100644 index 0000000000..c567422a5c --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/chat/components/Chat/usePageContext.ts @@ -0,0 +1,98 @@ +import { useCallback } from "react"; + +export interface PageContext { + url: string; + content: string; +} + +const MAX_CONTENT_CHARS = 10000; + +/** + * Hook to capture the current page context (URL + full page content) + * Privacy-hardened: removes sensitive inputs and enforces content size limits + */ +export function usePageContext() { + const capturePageContext = useCallback((): PageContext => { + if (typeof window === "undefined" || typeof document === 
"undefined") { + return { url: "", content: "" }; + } + + const url = window.location.href; + + // Clone document to avoid modifying the original + const clone = document.cloneNode(true) as Document; + + // Remove script, style, and noscript elements + const scripts = clone.querySelectorAll("script, style, noscript"); + scripts.forEach((el) => el.remove()); + + // Remove sensitive elements and their content + const sensitiveSelectors = [ + "input", + "textarea", + "[contenteditable]", + 'input[type="password"]', + 'input[type="email"]', + 'input[type="tel"]', + 'input[type="search"]', + 'input[type="hidden"]', + "form", + "[data-sensitive]", + "[data-sensitive='true']", + ]; + + sensitiveSelectors.forEach((selector) => { + const elements = clone.querySelectorAll(selector); + elements.forEach((el) => { + // For form elements, remove the entire element + if (el.tagName === "FORM") { + el.remove(); + } else { + // For inputs and textareas, clear their values but keep the element structure + if ( + el instanceof HTMLInputElement || + el instanceof HTMLTextAreaElement + ) { + el.value = ""; + el.textContent = ""; + } else { + // For other sensitive elements, remove them entirely + el.remove(); + } + } + }); + }); + + // Strip any remaining input values that might have been missed + const allInputs = clone.querySelectorAll("input, textarea"); + allInputs.forEach((el) => { + if (el instanceof HTMLInputElement || el instanceof HTMLTextAreaElement) { + el.value = ""; + el.textContent = ""; + } + }); + + // Get text content from body + const body = clone.body; + const content = body?.textContent || body?.innerText || ""; + + // Clean up whitespace + let cleanedContent = content + .replace(/\s+/g, " ") + .replace(/\n\s*\n/g, "\n") + .trim(); + + // Enforce maximum content size + if (cleanedContent.length > MAX_CONTENT_CHARS) { + cleanedContent = + cleanedContent.substring(0, MAX_CONTENT_CHARS) + "... 
[truncated]"; + } + + return { + url, + content: cleanedContent, + }; + }, []); + + return { capturePageContext }; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/ChatContainer.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/ChatContainer.tsx deleted file mode 100644 index 32f9d6c6eb..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/ChatContainer.tsx +++ /dev/null @@ -1,68 +0,0 @@ -import { cn } from "@/lib/utils"; -import { ChatInput } from "@/app/(platform)/chat/components/ChatInput/ChatInput"; -import { MessageList } from "@/app/(platform)/chat/components/MessageList/MessageList"; -import { QuickActionsWelcome } from "@/app/(platform)/chat/components/QuickActionsWelcome/QuickActionsWelcome"; -import { useChatContainer } from "./useChatContainer"; -import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse"; - -export interface ChatContainerProps { - sessionId: string | null; - initialMessages: SessionDetailResponse["messages"]; - onRefreshSession: () => Promise; - className?: string; -} - -export function ChatContainer({ - sessionId, - initialMessages, - onRefreshSession, - className, -}: ChatContainerProps) { - const { messages, streamingChunks, isStreaming, sendMessage } = - useChatContainer({ - sessionId, - initialMessages, - onRefreshSession, - }); - - const quickActions = [ - "Find agents for social media management", - "Show me agents for content creation", - "Help me automate my business", - "What can you help me with?", - ]; - - return ( -
      - {/* Messages or Welcome Screen */} - {messages.length === 0 ? ( - - ) : ( - - )} - - {/* Input - Always visible */} -
      - -
      -
      - ); -} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/useChatContainer.ts b/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/useChatContainer.ts deleted file mode 100644 index c75ad587a5..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatContainer/useChatContainer.ts +++ /dev/null @@ -1,130 +0,0 @@ -import { useState, useCallback, useRef, useMemo } from "react"; -import { toast } from "sonner"; -import { useChatStream } from "@/app/(platform)/chat/useChatStream"; -import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse"; -import type { ChatMessageData } from "@/app/(platform)/chat/components/ChatMessage/useChatMessage"; -import { - parseToolResponse, - isValidMessage, - isToolCallArray, - createUserMessage, - filterAuthMessages, -} from "./helpers"; -import { createStreamEventDispatcher } from "./createStreamEventDispatcher"; - -interface UseChatContainerArgs { - sessionId: string | null; - initialMessages: SessionDetailResponse["messages"]; - onRefreshSession: () => Promise; -} - -export function useChatContainer({ - sessionId, - initialMessages, -}: UseChatContainerArgs) { - const [messages, setMessages] = useState([]); - const [streamingChunks, setStreamingChunks] = useState([]); - const [hasTextChunks, setHasTextChunks] = useState(false); - const streamingChunksRef = useRef([]); - const { error, sendMessage: sendStreamMessage } = useChatStream(); - const isStreaming = hasTextChunks; - - const allMessages = useMemo(() => { - const processedInitialMessages = initialMessages - .filter((msg: Record) => { - if (!isValidMessage(msg)) { - console.warn("Invalid message structure from backend:", msg); - return false; - } - const content = String(msg.content || "").trim(); - const toolCalls = msg.tool_calls; - return ( - content.length > 0 || - (toolCalls && Array.isArray(toolCalls) && toolCalls.length > 0) - ); - }) - 
.map((msg: Record) => { - const content = String(msg.content || ""); - const role = String(msg.role || "assistant").toLowerCase(); - const toolCalls = msg.tool_calls; - if ( - role === "assistant" && - toolCalls && - isToolCallArray(toolCalls) && - toolCalls.length > 0 - ) { - return null; - } - if (role === "tool") { - const timestamp = msg.timestamp - ? new Date(msg.timestamp as string) - : undefined; - const toolResponse = parseToolResponse( - content, - (msg.tool_call_id as string) || "", - "unknown", - timestamp, - ); - if (!toolResponse) { - return null; - } - return toolResponse; - } - return { - type: "message", - role: role as "user" | "assistant" | "system", - content, - timestamp: msg.timestamp - ? new Date(msg.timestamp as string) - : undefined, - }; - }) - .filter((msg): msg is ChatMessageData => msg !== null); - - return [...processedInitialMessages, ...messages]; - }, [initialMessages, messages]); - - const sendMessage = useCallback( - async function sendMessage(content: string, isUserMessage: boolean = true) { - if (!sessionId) { - console.error("Cannot send message: no session ID"); - return; - } - if (isUserMessage) { - const userMessage = createUserMessage(content); - setMessages((prev) => [...filterAuthMessages(prev), userMessage]); - } else { - setMessages((prev) => filterAuthMessages(prev)); - } - setStreamingChunks([]); - streamingChunksRef.current = []; - setHasTextChunks(false); - const dispatcher = createStreamEventDispatcher({ - setHasTextChunks, - setStreamingChunks, - streamingChunksRef, - setMessages, - sessionId, - }); - try { - await sendStreamMessage(sessionId, content, dispatcher, isUserMessage); - } catch (err) { - console.error("Failed to send message:", err); - const errorMessage = - err instanceof Error ? 
err.message : "Failed to send message"; - toast.error("Failed to send message", { - description: errorMessage, - }); - } - }, - [sessionId, sendStreamMessage], - ); - - return { - messages: allMessages, - streamingChunks, - isStreaming, - error, - sendMessage, - }; -} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatCredentialsSetup/ChatCredentialsSetup.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatCredentialsSetup/ChatCredentialsSetup.tsx deleted file mode 100644 index 3868e17a10..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatCredentialsSetup/ChatCredentialsSetup.tsx +++ /dev/null @@ -1,153 +0,0 @@ -import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput"; -import { Card } from "@/components/atoms/Card/Card"; -import { Text } from "@/components/atoms/Text/Text"; -import type { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api"; -import { cn } from "@/lib/utils"; -import { CheckIcon, KeyIcon, WarningIcon } from "@phosphor-icons/react"; -import { useEffect, useRef } from "react"; -import { useChatCredentialsSetup } from "./useChatCredentialsSetup"; - -export interface CredentialInfo { - provider: string; - providerName: string; - credentialType: "api_key" | "oauth2" | "user_password" | "host_scoped"; - title: string; - scopes?: string[]; -} - -interface Props { - credentials: CredentialInfo[]; - agentName?: string; - message: string; - onAllCredentialsComplete: () => void; - onCancel: () => void; - className?: string; -} - -function createSchemaFromCredentialInfo( - credential: CredentialInfo, -): BlockIOCredentialsSubSchema { - return { - type: "object", - properties: {}, - credentials_provider: [credential.provider], - credentials_types: [credential.credentialType], - credentials_scopes: credential.scopes, - discriminator: undefined, - discriminator_mapping: 
undefined, - discriminator_values: undefined, - }; -} - -export function ChatCredentialsSetup({ - credentials, - agentName: _agentName, - message, - onAllCredentialsComplete, - onCancel: _onCancel, - className, -}: Props) { - const { selectedCredentials, isAllComplete, handleCredentialSelect } = - useChatCredentialsSetup(credentials); - - // Track if we've already called completion to prevent double calls - const hasCalledCompleteRef = useRef(false); - - // Reset the completion flag when credentials change (new credential setup flow) - useEffect( - function resetCompletionFlag() { - hasCalledCompleteRef.current = false; - }, - [credentials], - ); - - // Auto-call completion when all credentials are configured - useEffect( - function autoCompleteWhenReady() { - if (isAllComplete && !hasCalledCompleteRef.current) { - hasCalledCompleteRef.current = true; - onAllCredentialsComplete(); - } - }, - [isAllComplete, onAllCredentialsComplete], - ); - - return ( - -
      -
      - -
      -
      - - Credentials Required - - - {message} - - -
      - {credentials.map((cred, index) => { - const schema = createSchemaFromCredentialInfo(cred); - const isSelected = !!selectedCredentials[cred.provider]; - - return ( -
      -
      -
      - {isSelected ? ( - - ) : ( - - )} - - {cred.providerName} - -
      -
      - - - handleCredentialSelect(cred.provider, credMeta) - } - /> -
      - ); - })} -
      -
      -
      -
      - ); -} diff --git a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatInput/ChatInput.tsx b/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatInput/ChatInput.tsx deleted file mode 100644 index f1caceef70..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/chat/components/ChatInput/ChatInput.tsx +++ /dev/null @@ -1,63 +0,0 @@ -import { cn } from "@/lib/utils"; -import { PaperPlaneRightIcon } from "@phosphor-icons/react"; -import { Button } from "@/components/atoms/Button/Button"; -import { useChatInput } from "./useChatInput"; - -export interface ChatInputProps { - onSend: (message: string) => void; - disabled?: boolean; - placeholder?: string; - className?: string; -} - -export function ChatInput({ - onSend, - disabled = false, - placeholder = "Type your message...", - className, -}: ChatInputProps) { - const { value, setValue, handleKeyDown, handleSend, textareaRef } = - useChatInput({ - onSend, - disabled, - maxRows: 5, - }); - - return ( -
      -