Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-16 10:38:01 -05:00)

Compare commits: feat/copil...dev (9 commits)
| Author | SHA1 | Date |
|---|---|---|
| | e55f05c7a8 | |
| | 4a9b13acb6 | |
| | 5ff669e999 | |
| | ec03a13e26 | |
| | b08851f5d7 | |
| | 8b1720e61d | |
| | aa5a039c5e | |
| | 8b83bb8647 | |
| | e80e4d9cbb | |
@@ -1,6 +1,9 @@
# Ignore everything by default, selectively add things to context
*

# Documentation (for embeddings/search)
!docs/

# Platform - Libs
!autogpt_platform/autogpt_libs/autogpt_libs/
!autogpt_platform/autogpt_libs/pyproject.toml
@@ -100,6 +100,7 @@ COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migration
FROM server_dependencies AS server

COPY autogpt_platform/backend /app/autogpt_platform/backend
COPY docs /app/docs
RUN poetry install --no-ansi --only-root

ENV PORT=8000
@@ -299,9 +299,6 @@ async def stream_chat_completion(
        f"new message_count={len(session.messages)}"
    )

    if len(session.messages) > config.max_context_messages:
        raise ValueError(f"Max messages exceeded: {config.max_context_messages}")

    logger.info(
        f"Upserting session: {session.session_id} with user id {session.user_id}, "
        f"message_count={len(session.messages)}"
@@ -8,8 +8,12 @@ from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
from .find_library_agent import FindLibraryAgentTool
from .get_doc_page import GetDocPageTool
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .search_docs import SearchDocsTool

if TYPE_CHECKING:
    from backend.api.features.chat.response_model import StreamToolOutputAvailable
@@ -18,9 +22,13 @@ if TYPE_CHECKING:
TOOL_REGISTRY: dict[str, BaseTool] = {
    "add_understanding": AddUnderstandingTool(),
    "find_agent": FindAgentTool(),
    "find_block": FindBlockTool(),
    "find_library_agent": FindLibraryAgentTool(),
    "run_agent": RunAgentTool(),
    "run_block": RunBlockTool(),
    "agent_output": AgentOutputTool(),
    "search_docs": SearchDocsTool(),
    "get_doc_page": GetDocPageTool(),
}

# Export individual tool instances for backwards compatibility
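
For orientation, a minimal sketch of how a chat handler might dispatch a call through this registry. The `TOOL_REGISTRY` lookup and the `requires_auth` property are taken from the diff; the `execute(...)` entry point and its keyword names are assumptions about `BaseTool`'s public wrapper around `_execute`.

```python
from backend.api.features.chat.tools import TOOL_REGISTRY

async def dispatch_tool_call(name: str, user_id: str | None, session, **tool_args):
    """Hypothetical dispatcher: look up a tool by name and run it."""
    tool = TOOL_REGISTRY.get(name)
    if tool is None:
        raise ValueError(f"Unknown tool: {name}")
    if tool.requires_auth and user_id is None:
        raise PermissionError(f"Tool '{name}' requires an authenticated user")
    # Assumed public entry point; the diff only shows the _execute hook.
    return await tool.execute(user_id=user_id, session=session, **tool_args)
```
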
@@ -0,0 +1,192 @@
import logging
from typing import Any

from prisma.enums import ContentType

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
from backend.api.features.chat.tools.models import (
    BlockInfoSummary,
    BlockInputFieldInfo,
    BlockListResponse,
    ErrorResponse,
    NoResultsResponse,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.data.block import get_block

logger = logging.getLogger(__name__)


class FindBlockTool(BaseTool):
    """Tool for searching available blocks."""

    @property
    def name(self) -> str:
        return "find_block"

    @property
    def description(self) -> str:
        return (
            "Search for available blocks by name or description. "
            "Blocks are reusable components that perform specific tasks like "
            "sending emails, making API calls, processing text, etc. "
            "IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
            "The response includes each block's id, required_inputs, and input_schema."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": (
                        "Search query to find blocks by name or description. "
                        "Use keywords like 'email', 'http', 'text', 'ai', etc."
                    ),
                },
            },
            "required": ["query"],
        }

    @property
    def requires_auth(self) -> bool:
        return True

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        """Search for blocks matching the query.

        Args:
            user_id: User ID (required)
            session: Chat session
            query: Search query

        Returns:
            BlockListResponse: List of matching blocks
            NoResultsResponse: No blocks found
            ErrorResponse: Error message
        """
        query = kwargs.get("query", "").strip()
        session_id = session.session_id

        if not query:
            return ErrorResponse(
                message="Please provide a search query",
                session_id=session_id,
            )

        try:
            # Search for blocks using hybrid search
            results, total = await unified_hybrid_search(
                query=query,
                content_types=[ContentType.BLOCK],
                page=1,
                page_size=10,
            )

            if not results:
                return NoResultsResponse(
                    message=f"No blocks found for '{query}'",
                    suggestions=[
                        "Try broader keywords like 'email', 'http', 'text', 'ai'",
                        "Check spelling of technical terms",
                    ],
                    session_id=session_id,
                )

            # Enrich results with full block information
            blocks: list[BlockInfoSummary] = []
            for result in results:
                block_id = result["content_id"]
                block = get_block(block_id)

                if block:
                    # Get input/output schemas
                    input_schema = {}
                    output_schema = {}
                    try:
                        input_schema = block.input_schema.jsonschema()
                    except Exception:
                        pass
                    try:
                        output_schema = block.output_schema.jsonschema()
                    except Exception:
                        pass

                    # Get categories from block instance
                    categories = []
                    if hasattr(block, "categories") and block.categories:
                        categories = [cat.value for cat in block.categories]

                    # Extract required inputs for easier use
                    required_inputs: list[BlockInputFieldInfo] = []
                    if input_schema:
                        properties = input_schema.get("properties", {})
                        required_fields = set(input_schema.get("required", []))
                        # Get credential field names to exclude from required inputs
                        credentials_fields = set(
                            block.input_schema.get_credentials_fields().keys()
                        )

                        for field_name, field_schema in properties.items():
                            # Skip credential fields - they're handled separately
                            if field_name in credentials_fields:
                                continue

                            required_inputs.append(
                                BlockInputFieldInfo(
                                    name=field_name,
                                    type=field_schema.get("type", "string"),
                                    description=field_schema.get("description", ""),
                                    required=field_name in required_fields,
                                    default=field_schema.get("default"),
                                )
                            )

                    blocks.append(
                        BlockInfoSummary(
                            id=block_id,
                            name=block.name,
                            description=block.description or "",
                            categories=categories,
                            input_schema=input_schema,
                            output_schema=output_schema,
                            required_inputs=required_inputs,
                        )
                    )

            if not blocks:
                return NoResultsResponse(
                    message=f"No blocks found for '{query}'",
                    suggestions=[
                        "Try broader keywords like 'email', 'http', 'text', 'ai'",
                    ],
                    session_id=session_id,
                )

            return BlockListResponse(
                message=(
                    f"Found {len(blocks)} block(s) matching '{query}'. "
                    "To execute a block, use run_block with the block's 'id' field "
                    "and provide 'input_data' matching the block's input_schema."
                ),
                blocks=blocks,
                count=len(blocks),
                query=query,
                session_id=session_id,
            )

        except Exception as e:
            logger.error(f"Error searching blocks: {e}", exc_info=True)
            return ErrorResponse(
                message="Failed to search blocks",
                error=str(e),
                session_id=session_id,
            )
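
The tool description above insists on a find-then-run flow; a sketch of the two tool-call payloads involved (the block id and inputs below are invented placeholders for illustration):

```python
# Step 1: the model searches for a block.
find_call = {"name": "find_block", "arguments": {"query": "send email"}}
# -> BlockListResponse; suppose blocks[0].id is the placeholder UUID below.

# Step 2: the model runs it, echoing the id verbatim from step 1.
run_call = {
    "name": "run_block",
    "arguments": {
        "block_id": "11111111-2222-3333-4444-555555555555",  # from find_block, never guessed
        "input_data": {"to": "user@example.com", "subject": "Hello"},
    },
}
```
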
@@ -0,0 +1,148 @@
"""GetDocPageTool - Fetch full content of a documentation page."""

import logging
from pathlib import Path
from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
    DocPageResponse,
    ErrorResponse,
    ToolResponseBase,
)

logger = logging.getLogger(__name__)

# Base URL for documentation (can be configured)
DOCS_BASE_URL = "https://docs.agpt.co"


class GetDocPageTool(BaseTool):
    """Tool for fetching full content of a documentation page."""

    @property
    def name(self) -> str:
        return "get_doc_page"

    @property
    def description(self) -> str:
        return (
            "Get the full content of a documentation page by its path. "
            "Use this after search_docs to read the complete content of a relevant page."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": (
                        "The path to the documentation file, as returned by search_docs. "
                        "Example: 'platform/block-sdk-guide.md'"
                    ),
                },
            },
            "required": ["path"],
        }

    @property
    def requires_auth(self) -> bool:
        return False  # Documentation is public

    def _get_docs_root(self) -> Path:
        """Get the documentation root directory."""
        this_file = Path(__file__)
        project_root = this_file.parent.parent.parent.parent.parent.parent.parent.parent
        return project_root / "docs"

    def _extract_title(self, content: str, fallback: str) -> str:
        """Extract title from markdown content."""
        lines = content.split("\n")
        for line in lines:
            if line.startswith("# "):
                return line[2:].strip()
        return fallback

    def _make_doc_url(self, path: str) -> str:
        """Create a URL for a documentation page."""
        url_path = path.rsplit(".", 1)[0] if "." in path else path
        return f"{DOCS_BASE_URL}/{url_path}"

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        """Fetch full content of a documentation page.

        Args:
            user_id: User ID (not required for docs)
            session: Chat session
            path: Path to the documentation file

        Returns:
            DocPageResponse: Full document content
            ErrorResponse: Error message
        """
        path = kwargs.get("path", "").strip()
        session_id = session.session_id if session else None

        if not path:
            return ErrorResponse(
                message="Please provide a documentation path.",
                error="Missing path parameter",
                session_id=session_id,
            )

        # Sanitize path to prevent directory traversal
        if ".." in path or path.startswith("/"):
            return ErrorResponse(
                message="Invalid documentation path.",
                error="invalid_path",
                session_id=session_id,
            )

        docs_root = self._get_docs_root()
        full_path = docs_root / path

        if not full_path.exists():
            return ErrorResponse(
                message=f"Documentation page not found: {path}",
                error="not_found",
                session_id=session_id,
            )

        # Ensure the path is within docs root
        try:
            full_path.resolve().relative_to(docs_root.resolve())
        except ValueError:
            return ErrorResponse(
                message="Invalid documentation path.",
                error="invalid_path",
                session_id=session_id,
            )

        try:
            content = full_path.read_text(encoding="utf-8")
            title = self._extract_title(content, path)

            return DocPageResponse(
                message=f"Retrieved documentation page: {title}",
                title=title,
                path=path,
                content=content,
                doc_url=self._make_doc_url(path),
                session_id=session_id,
            )

        except Exception as e:
            logger.error(f"Failed to read documentation page {path}: {e}")
            return ErrorResponse(
                message=f"Failed to read documentation page: {str(e)}",
                error="read_failed",
                session_id=session_id,
            )
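
The two-layer path check above is worth a standalone illustration: the string check rejects obvious `..` traversal, while the `resolve().relative_to()` step also catches escapes the string check misses, such as a symlink that points outside the docs root. A self-contained sketch with throwaway paths:

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    docs_root = Path(tmp) / "docs"
    docs_root.mkdir()
    outside = Path(tmp) / "secret.txt"
    outside.write_text("not docs")

    # A symlink inside docs/ that resolves outside it; "external.md" has no ".."
    link = docs_root / "external.md"
    link.symlink_to(outside)

    try:
        link.resolve().relative_to(docs_root.resolve())
        print("accepted")
    except ValueError:
        print("rejected: resolves outside docs root")  # this branch runs
```
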
@@ -21,6 +21,10 @@ class ResponseType(str, Enum):
    NO_RESULTS = "no_results"
    AGENT_OUTPUT = "agent_output"
    UNDERSTANDING_UPDATED = "understanding_updated"
    BLOCK_LIST = "block_list"
    BLOCK_OUTPUT = "block_output"
    DOC_SEARCH_RESULTS = "doc_search_results"
    DOC_PAGE = "doc_page"


# Base response model
@@ -209,3 +213,83 @@ class UnderstandingUpdatedResponse(ToolResponseBase):
    type: ResponseType = ResponseType.UNDERSTANDING_UPDATED
    updated_fields: list[str] = Field(default_factory=list)
    current_understanding: dict[str, Any] = Field(default_factory=dict)


# Documentation search models
class DocSearchResult(BaseModel):
    """A single documentation search result."""

    title: str
    path: str
    section: str
    snippet: str  # Short excerpt for UI display
    score: float
    doc_url: str | None = None


class DocSearchResultsResponse(ToolResponseBase):
    """Response for search_docs tool."""

    type: ResponseType = ResponseType.DOC_SEARCH_RESULTS
    results: list[DocSearchResult]
    count: int
    query: str


class DocPageResponse(ToolResponseBase):
    """Response for get_doc_page tool."""

    type: ResponseType = ResponseType.DOC_PAGE
    title: str
    path: str
    content: str  # Full document content
    doc_url: str | None = None


# Block models
class BlockInputFieldInfo(BaseModel):
    """Information about a block input field."""

    name: str
    type: str
    description: str = ""
    required: bool = False
    default: Any | None = None


class BlockInfoSummary(BaseModel):
    """Summary of a block for search results."""

    id: str
    name: str
    description: str
    categories: list[str]
    input_schema: dict[str, Any]
    output_schema: dict[str, Any]
    required_inputs: list[BlockInputFieldInfo] = Field(
        default_factory=list,
        description="List of required input fields for this block",
    )


class BlockListResponse(ToolResponseBase):
    """Response for find_block tool."""

    type: ResponseType = ResponseType.BLOCK_LIST
    blocks: list[BlockInfoSummary]
    count: int
    query: str
    usage_hint: str = Field(
        default="To execute a block, call run_block with block_id set to the block's "
        "'id' field and input_data containing the required fields from input_schema."
    )


class BlockOutputResponse(ToolResponseBase):
    """Response for run_block tool."""

    type: ResponseType = ResponseType.BLOCK_OUTPUT
    block_id: str
    block_name: str
    outputs: dict[str, list[Any]]
    success: bool = True
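
For reference, a sketch of what one of these responses serializes to. Field values are invented; `message` and `session_id` are assumed to come from `ToolResponseBase`, which this hunk does not show.

```python
resp = BlockOutputResponse(
    message="Block 'SendEmailBlock' executed successfully",  # hypothetical
    session_id="sess-123",                                   # hypothetical
    block_id="11111111-2222-3333-4444-555555555555",
    block_name="SendEmailBlock",
    outputs={"status": ["sent"]},
)
resp.model_dump()
# -> {"type": "block_output", "block_id": "1111...", "block_name": "SendEmailBlock",
#     "outputs": {"status": ["sent"]}, "success": True, ...}
```
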
@@ -0,0 +1,297 @@
"""Tool for executing blocks directly."""

import logging
from collections import defaultdict
from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError

from .base import BaseTool
from .models import (
    BlockOutputResponse,
    ErrorResponse,
    SetupInfo,
    SetupRequirementsResponse,
    ToolResponseBase,
    UserReadiness,
)

logger = logging.getLogger(__name__)


class RunBlockTool(BaseTool):
    """Tool for executing a block and returning its outputs."""

    @property
    def name(self) -> str:
        return "run_block"

    @property
    def description(self) -> str:
        return (
            "Execute a specific block with the provided input data. "
            "IMPORTANT: You MUST call find_block first to get the block's 'id' - "
            "do NOT guess or make up block IDs. "
            "Use the 'id' from find_block results and provide input_data "
            "matching the block's required_inputs."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "block_id": {
                    "type": "string",
                    "description": (
                        "The block's 'id' field from find_block results. "
                        "NEVER guess this - always get it from find_block first."
                    ),
                },
                "input_data": {
                    "type": "object",
                    "description": (
                        "Input values for the block. Use the 'required_inputs' field "
                        "from find_block to see what fields are needed."
                    ),
                },
            },
            "required": ["block_id", "input_data"],
        }

    @property
    def requires_auth(self) -> bool:
        return True

    async def _check_block_credentials(
        self,
        user_id: str,
        block: Any,
    ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
        """
        Check if user has required credentials for a block.

        Returns:
            tuple[matched_credentials, missing_credentials]
        """
        matched_credentials: dict[str, CredentialsMetaInput] = {}
        missing_credentials: list[CredentialsMetaInput] = []

        # Get credential field info from block's input schema
        credentials_fields_info = block.input_schema.get_credentials_fields_info()

        if not credentials_fields_info:
            return matched_credentials, missing_credentials

        # Get user's available credentials
        creds_manager = IntegrationCredentialsManager()
        available_creds = await creds_manager.store.get_all_creds(user_id)

        for field_name, field_info in credentials_fields_info.items():
            # field_info.provider is a frozenset of acceptable providers
            # field_info.supported_types is a frozenset of acceptable types
            matching_cred = next(
                (
                    cred
                    for cred in available_creds
                    if cred.provider in field_info.provider
                    and cred.type in field_info.supported_types
                ),
                None,
            )

            if matching_cred:
                matched_credentials[field_name] = CredentialsMetaInput(
                    id=matching_cred.id,
                    provider=matching_cred.provider,  # type: ignore
                    type=matching_cred.type,
                    title=matching_cred.title,
                )
            else:
                # Create a placeholder for the missing credential
                provider = next(iter(field_info.provider), "unknown")
                cred_type = next(iter(field_info.supported_types), "api_key")
                missing_credentials.append(
                    CredentialsMetaInput(
                        id=field_name,
                        provider=provider,  # type: ignore
                        type=cred_type,  # type: ignore
                        title=field_name.replace("_", " ").title(),
                    )
                )

        return matched_credentials, missing_credentials

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        """Execute a block with the given input data.

        Args:
            user_id: User ID (required)
            session: Chat session
            block_id: Block UUID to execute
            input_data: Input values for the block

        Returns:
            BlockOutputResponse: Block execution outputs
            SetupRequirementsResponse: Missing credentials
            ErrorResponse: Error message
        """
        block_id = kwargs.get("block_id", "").strip()
        input_data = kwargs.get("input_data", {})
        session_id = session.session_id

        if not block_id:
            return ErrorResponse(
                message="Please provide a block_id",
                session_id=session_id,
            )

        if not isinstance(input_data, dict):
            return ErrorResponse(
                message="input_data must be an object",
                session_id=session_id,
            )

        if not user_id:
            return ErrorResponse(
                message="Authentication required",
                session_id=session_id,
            )

        # Get the block
        block = get_block(block_id)
        if not block:
            return ErrorResponse(
                message=f"Block '{block_id}' not found",
                session_id=session_id,
            )

        logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")

        # Check credentials
        creds_manager = IntegrationCredentialsManager()
        matched_credentials, missing_credentials = await self._check_block_credentials(
            user_id, block
        )

        if missing_credentials:
            # Return setup requirements response with missing credentials
            missing_creds_dict = {c.id: c.model_dump() for c in missing_credentials}

            return SetupRequirementsResponse(
                message=(
                    f"Block '{block.name}' requires credentials that are not configured. "
                    "Please set up the required credentials before running this block."
                ),
                session_id=session_id,
                setup_info=SetupInfo(
                    agent_id=block_id,
                    agent_name=block.name,
                    user_readiness=UserReadiness(
                        has_all_credentials=False,
                        missing_credentials=missing_creds_dict,
                        ready_to_run=False,
                    ),
                    requirements={
                        "credentials": [c.model_dump() for c in missing_credentials],
                        "inputs": self._get_inputs_list(block),
                        "execution_modes": ["immediate"],
                    },
                ),
                graph_id=None,
                graph_version=None,
            )

        try:
            # Fetch actual credentials and prepare kwargs for block execution
            # Create execution context with defaults (blocks may require it)
            exec_kwargs: dict[str, Any] = {
                "user_id": user_id,
                "execution_context": ExecutionContext(),
            }

            for field_name, cred_meta in matched_credentials.items():
                # Inject metadata into input_data (for validation)
                if field_name not in input_data:
                    input_data[field_name] = cred_meta.model_dump()

                # Fetch actual credentials and pass as kwargs (for execution)
                actual_credentials = await creds_manager.get(
                    user_id, cred_meta.id, lock=False
                )
                if actual_credentials:
                    exec_kwargs[field_name] = actual_credentials
                else:
                    return ErrorResponse(
                        message=f"Failed to retrieve credentials for {field_name}",
                        session_id=session_id,
                    )

            # Execute the block and collect outputs
            outputs: dict[str, list[Any]] = defaultdict(list)
            async for output_name, output_data in block.execute(
                input_data,
                **exec_kwargs,
            ):
                outputs[output_name].append(output_data)

            return BlockOutputResponse(
                message=f"Block '{block.name}' executed successfully",
                block_id=block_id,
                block_name=block.name,
                outputs=dict(outputs),
                success=True,
                session_id=session_id,
            )

        except BlockError as e:
            logger.warning(f"Block execution failed: {e}")
            return ErrorResponse(
                message=f"Block execution failed: {e}",
                error=str(e),
                session_id=session_id,
            )
        except Exception as e:
            logger.error(f"Unexpected error executing block: {e}", exc_info=True)
            return ErrorResponse(
                message=f"Failed to execute block: {str(e)}",
                error=str(e),
                session_id=session_id,
            )

    def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
        """Extract non-credential inputs from block schema."""
        inputs_list = []
        schema = block.input_schema.jsonschema()
        properties = schema.get("properties", {})
        required_fields = set(schema.get("required", []))

        # Get credential field names to exclude
        credentials_fields = set(block.input_schema.get_credentials_fields().keys())

        for field_name, field_schema in properties.items():
            # Skip credential fields
            if field_name in credentials_fields:
                continue

            inputs_list.append(
                {
                    "name": field_name,
                    "title": field_schema.get("title", field_name),
                    "type": field_schema.get("type", "string"),
                    "description": field_schema.get("description", ""),
                    "required": field_name in required_fields,
                }
            )

        return inputs_list
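
To make the matching rule in `_check_block_credentials` concrete, a toy sketch of the same predicate on inline data. The dataclasses below stand in for the real credential and field-info types, which this diff only uses, not defines:

```python
from dataclasses import dataclass

@dataclass
class Cred:  # stand-in for a stored credential
    provider: str
    type: str

@dataclass
class FieldInfo:  # stand-in for get_credentials_fields_info() values
    provider: frozenset
    supported_types: frozenset

available = [Cred("github", "oauth2"), Cred("openai", "api_key")]
field = FieldInfo(provider=frozenset({"openai"}), supported_types=frozenset({"api_key"}))

# Same next(...) pattern as the tool: first credential whose provider AND type fit.
match = next(
    (c for c in available
     if c.provider in field.provider and c.type in field.supported_types),
    None,
)
assert match is not None and match.provider == "openai"
```
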
@@ -0,0 +1,208 @@
"""SearchDocsTool - Search documentation using hybrid search."""

import logging
from typing import Any

from prisma.enums import ContentType

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
    DocSearchResult,
    DocSearchResultsResponse,
    ErrorResponse,
    NoResultsResponse,
    ToolResponseBase,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search

logger = logging.getLogger(__name__)

# Base URL for documentation (can be configured)
DOCS_BASE_URL = "https://docs.agpt.co"

# Maximum number of results to return
MAX_RESULTS = 5

# Snippet length for preview
SNIPPET_LENGTH = 200


class SearchDocsTool(BaseTool):
    """Tool for searching AutoGPT platform documentation."""

    @property
    def name(self) -> str:
        return "search_docs"

    @property
    def description(self) -> str:
        return (
            "Search the AutoGPT platform documentation for information about "
            "how to use the platform, build agents, configure blocks, and more. "
            "Returns relevant documentation sections. Use get_doc_page to read full content."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": (
                        "Search query to find relevant documentation. "
                        "Use natural language to describe what you're looking for."
                    ),
                },
            },
            "required": ["query"],
        }

    @property
    def requires_auth(self) -> bool:
        return False  # Documentation is public

    def _create_snippet(self, content: str, max_length: int = SNIPPET_LENGTH) -> str:
        """Create a short snippet from content for preview."""
        # Remove markdown formatting for cleaner snippet
        clean_content = content.replace("#", "").replace("*", "").replace("`", "")
        # Remove extra whitespace
        clean_content = " ".join(clean_content.split())

        if len(clean_content) <= max_length:
            return clean_content

        # Truncate at word boundary
        truncated = clean_content[:max_length]
        last_space = truncated.rfind(" ")
        if last_space > max_length // 2:
            truncated = truncated[:last_space]

        return truncated + "..."

    def _make_doc_url(self, path: str) -> str:
        """Create a URL for a documentation page."""
        # Remove file extension for URL
        url_path = path.rsplit(".", 1)[0] if "." in path else path
        return f"{DOCS_BASE_URL}/{url_path}"

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        """Search documentation and return relevant sections.

        Args:
            user_id: User ID (not required for docs)
            session: Chat session
            query: Search query

        Returns:
            DocSearchResultsResponse: List of matching documentation sections
            NoResultsResponse: No results found
            ErrorResponse: Error message
        """
        query = kwargs.get("query", "").strip()
        session_id = session.session_id if session else None

        if not query:
            return ErrorResponse(
                message="Please provide a search query.",
                error="Missing query parameter",
                session_id=session_id,
            )

        try:
            # Search using hybrid search for DOCUMENTATION content type only
            results, total = await unified_hybrid_search(
                query=query,
                content_types=[ContentType.DOCUMENTATION],
                page=1,
                page_size=MAX_RESULTS * 2,  # Fetch extra for deduplication
                min_score=0.1,  # Lower threshold for docs
            )

            if not results:
                return NoResultsResponse(
                    message=f"No documentation found for '{query}'.",
                    suggestions=[
                        "Try different keywords",
                        "Use more general terms",
                        "Check for typos in your query",
                    ],
                    session_id=session_id,
                )

            # Deduplicate by document path (keep highest scoring section per doc)
            seen_docs: dict[str, dict[str, Any]] = {}
            for result in results:
                metadata = result.get("metadata", {})
                doc_path = metadata.get("path", "")

                if not doc_path:
                    continue

                # Keep the highest scoring result for each document
                if doc_path not in seen_docs:
                    seen_docs[doc_path] = result
                elif result.get("combined_score", 0) > seen_docs[doc_path].get(
                    "combined_score", 0
                ):
                    seen_docs[doc_path] = result

            # Sort by score and take top MAX_RESULTS
            deduplicated = sorted(
                seen_docs.values(),
                key=lambda x: x.get("combined_score", 0),
                reverse=True,
            )[:MAX_RESULTS]

            if not deduplicated:
                return NoResultsResponse(
                    message=f"No documentation found for '{query}'.",
                    suggestions=[
                        "Try different keywords",
                        "Use more general terms",
                    ],
                    session_id=session_id,
                )

            # Build response
            doc_results: list[DocSearchResult] = []
            for result in deduplicated:
                metadata = result.get("metadata", {})
                doc_path = metadata.get("path", "")
                doc_title = metadata.get("doc_title", "")
                section_title = metadata.get("section_title", "")
                searchable_text = result.get("searchable_text", "")
                score = result.get("combined_score", 0)

                doc_results.append(
                    DocSearchResult(
                        title=doc_title or section_title or doc_path,
                        path=doc_path,
                        section=section_title,
                        snippet=self._create_snippet(searchable_text),
                        score=round(score, 3),
                        doc_url=self._make_doc_url(doc_path),
                    )
                )

            return DocSearchResultsResponse(
                message=f"Found {len(doc_results)} relevant documentation sections.",
                results=doc_results,
                count=len(doc_results),
                query=query,
                session_id=session_id,
            )

        except Exception as e:
            logger.error(f"Documentation search failed: {e}")
            return ErrorResponse(
                message=f"Failed to search documentation: {str(e)}",
                error="search_failed",
                session_id=session_id,
            )
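
The word-boundary truncation in `_create_snippet` is easy to check in isolation; a standalone re-statement of its core steps with a made-up input:

```python
text = "word " * 60  # ~300 chars once cleaned

# Same cleanup and truncation steps as _create_snippet (SNIPPET_LENGTH = 200):
clean = " ".join(text.replace("#", "").replace("*", "").replace("`", "").split())
truncated = clean[:200]
last_space = truncated.rfind(" ")
if last_space > 200 // 2:
    truncated = truncated[:last_space]
print(truncated + "...")  # ends on a whole word, never mid-word
```
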
@@ -0,0 +1,610 @@
"""
Content Type Handlers for Unified Embeddings

Pluggable system for different content sources (store agents, blocks, docs).
Each handler knows how to fetch and process its content type for embedding.
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from prisma.enums import ContentType

from backend.data.db import query_raw_with_schema

logger = logging.getLogger(__name__)


@dataclass
class ContentItem:
    """Represents a piece of content to be embedded."""

    content_id: str  # Unique identifier (DB ID or file path)
    content_type: ContentType
    searchable_text: str  # Combined text for embedding
    metadata: dict[str, Any]  # Content-specific metadata
    user_id: str | None = None  # For user-scoped content


class ContentHandler(ABC):
    """Base handler for fetching and processing content for embeddings."""

    @property
    @abstractmethod
    def content_type(self) -> ContentType:
        """The ContentType this handler manages."""
        pass

    @abstractmethod
    async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
        """
        Fetch items that don't have embeddings yet.

        Args:
            batch_size: Maximum number of items to return

        Returns:
            List of ContentItem objects ready for embedding
        """
        pass

    @abstractmethod
    async def get_stats(self) -> dict[str, int]:
        """
        Get statistics about embedding coverage.

        Returns:
            Dict with keys: total, with_embeddings, without_embeddings
        """
        pass


class StoreAgentHandler(ContentHandler):
    """Handler for marketplace store agent listings."""

    @property
    def content_type(self) -> ContentType:
        return ContentType.STORE_AGENT

    async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
        """Fetch approved store listings without embeddings."""
        from backend.api.features.store.embeddings import build_searchable_text

        missing = await query_raw_with_schema(
            """
            SELECT
                slv.id,
                slv.name,
                slv.description,
                slv."subHeading",
                slv.categories
            FROM {schema_prefix}"StoreListingVersion" slv
            LEFT JOIN {schema_prefix}"UnifiedContentEmbedding" uce
                ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
            WHERE slv."submissionStatus" = 'APPROVED'
                AND slv."isDeleted" = false
                AND uce."contentId" IS NULL
            LIMIT $1
            """,
            batch_size,
        )

        return [
            ContentItem(
                content_id=row["id"],
                content_type=ContentType.STORE_AGENT,
                searchable_text=build_searchable_text(
                    name=row["name"],
                    description=row["description"],
                    sub_heading=row["subHeading"],
                    categories=row["categories"] or [],
                ),
                metadata={
                    "name": row["name"],
                    "categories": row["categories"] or [],
                },
                user_id=None,  # Store agents are public
            )
            for row in missing
        ]

    async def get_stats(self) -> dict[str, int]:
        """Get statistics about store agent embedding coverage."""
        # Count approved versions
        approved_result = await query_raw_with_schema(
            """
            SELECT COUNT(*) as count
            FROM {schema_prefix}"StoreListingVersion"
            WHERE "submissionStatus" = 'APPROVED'
                AND "isDeleted" = false
            """
        )
        total_approved = approved_result[0]["count"] if approved_result else 0

        # Count versions with embeddings
        embedded_result = await query_raw_with_schema(
            """
            SELECT COUNT(*) as count
            FROM {schema_prefix}"StoreListingVersion" slv
            JOIN {schema_prefix}"UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
            WHERE slv."submissionStatus" = 'APPROVED'
                AND slv."isDeleted" = false
            """
        )
        with_embeddings = embedded_result[0]["count"] if embedded_result else 0

        return {
            "total": total_approved,
            "with_embeddings": with_embeddings,
            "without_embeddings": total_approved - with_embeddings,
        }


class BlockHandler(ContentHandler):
    """Handler for block definitions (Python classes)."""

    @property
    def content_type(self) -> ContentType:
        return ContentType.BLOCK

    async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
        """Fetch blocks without embeddings."""
        from backend.data.block import get_blocks

        # Get all available blocks
        all_blocks = get_blocks()

        # Check which ones have embeddings
        if not all_blocks:
            return []

        block_ids = list(all_blocks.keys())

        # Query for existing embeddings
        placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
        existing_result = await query_raw_with_schema(
            f"""
            SELECT "contentId"
            FROM {{schema_prefix}}"UnifiedContentEmbedding"
            WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
                AND "contentId" = ANY(ARRAY[{placeholders}])
            """,
            *block_ids,
        )

        existing_ids = {row["contentId"] for row in existing_result}
        missing_blocks = [
            (block_id, block_cls)
            for block_id, block_cls in all_blocks.items()
            if block_id not in existing_ids
        ]

        # Convert to ContentItem
        items = []
        for block_id, block_cls in missing_blocks[:batch_size]:
            try:
                block_instance = block_cls()

                # Build searchable text from block metadata
                parts = []
                if hasattr(block_instance, "name") and block_instance.name:
                    parts.append(block_instance.name)
                if (
                    hasattr(block_instance, "description")
                    and block_instance.description
                ):
                    parts.append(block_instance.description)
                if hasattr(block_instance, "categories") and block_instance.categories:
                    # Convert BlockCategory enum to strings
                    parts.append(
                        " ".join(str(cat.value) for cat in block_instance.categories)
                    )

                # Add input/output schema info
                if hasattr(block_instance, "input_schema"):
                    schema = block_instance.input_schema
                    if hasattr(schema, "model_json_schema"):
                        schema_dict = schema.model_json_schema()
                        if "properties" in schema_dict:
                            for prop_name, prop_info in schema_dict[
                                "properties"
                            ].items():
                                if "description" in prop_info:
                                    parts.append(
                                        f"{prop_name}: {prop_info['description']}"
                                    )

                searchable_text = " ".join(parts)

                # Convert categories set of enums to list of strings for JSON serialization
                categories = getattr(block_instance, "categories", set())
                categories_list = (
                    [cat.value for cat in categories] if categories else []
                )

                items.append(
                    ContentItem(
                        content_id=block_id,
                        content_type=ContentType.BLOCK,
                        searchable_text=searchable_text,
                        metadata={
                            "name": getattr(block_instance, "name", ""),
                            "categories": categories_list,
                        },
                        user_id=None,  # Blocks are public
                    )
                )
            except Exception as e:
                logger.warning(f"Failed to process block {block_id}: {e}")
                continue

        return items

    async def get_stats(self) -> dict[str, int]:
        """Get statistics about block embedding coverage."""
        from backend.data.block import get_blocks

        all_blocks = get_blocks()
        total_blocks = len(all_blocks)

        if total_blocks == 0:
            return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}

        block_ids = list(all_blocks.keys())
        placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])

        embedded_result = await query_raw_with_schema(
            f"""
            SELECT COUNT(*) as count
            FROM {{schema_prefix}}"UnifiedContentEmbedding"
            WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
                AND "contentId" = ANY(ARRAY[{placeholders}])
            """,
            *block_ids,
        )

        with_embeddings = embedded_result[0]["count"] if embedded_result else 0

        return {
            "total": total_blocks,
            "with_embeddings": with_embeddings,
            "without_embeddings": total_blocks - with_embeddings,
        }


@dataclass
class MarkdownSection:
    """Represents a section of a markdown document."""

    title: str  # Section heading text
    content: str  # Section content (including the heading line)
    level: int  # Heading level (1 for #, 2 for ##, etc.)
    index: int  # Section index within the document


class DocumentationHandler(ContentHandler):
    """Handler for documentation files (.md/.mdx).

    Chunks documents by markdown headings to create multiple embeddings per file.
    Each section (## heading) becomes a separate embedding for better retrieval.
    """

    @property
    def content_type(self) -> ContentType:
        return ContentType.DOCUMENTATION

    def _get_docs_root(self) -> Path:
        """Get the documentation root directory."""
        # content_handlers.py is at: backend/backend/api/features/store/content_handlers.py
        # Need to go up to project root then into docs/
        # In container: /app/autogpt_platform/backend/backend/api/features/store -> /app/docs
        # In development: /repo/autogpt_platform/backend/backend/api/features/store -> /repo/docs
        this_file = Path(
            __file__
        )  # .../backend/backend/api/features/store/content_handlers.py
        project_root = (
            this_file.parent.parent.parent.parent.parent.parent.parent
        )  # -> /app or /repo
        docs_root = project_root / "docs"
        return docs_root

    def _extract_doc_title(self, file_path: Path) -> str:
        """Extract the document title from a markdown file."""
        try:
            content = file_path.read_text(encoding="utf-8")
            lines = content.split("\n")

            # Try to extract title from first # heading
            for line in lines:
                if line.startswith("# "):
                    return line[2:].strip()

            # If no title found, use filename
            return file_path.stem.replace("-", " ").replace("_", " ").title()
        except Exception as e:
            logger.warning(f"Failed to read title from {file_path}: {e}")
            return file_path.stem.replace("-", " ").replace("_", " ").title()

    def _chunk_markdown_by_headings(
        self, file_path: Path, min_heading_level: int = 2
    ) -> list[MarkdownSection]:
        """
        Split a markdown file into sections based on headings.

        Args:
            file_path: Path to the markdown file
            min_heading_level: Minimum heading level to split on (default: 2 for ##)

        Returns:
            List of MarkdownSection objects, one per section.
            If no headings found, returns a single section with all content.
        """
        try:
            content = file_path.read_text(encoding="utf-8")
        except Exception as e:
            logger.warning(f"Failed to read {file_path}: {e}")
            return []

        lines = content.split("\n")
        sections: list[MarkdownSection] = []
        current_section_lines: list[str] = []
        current_title = ""
        current_level = 0
        section_index = 0
        doc_title = ""

        for line in lines:
            # Check if line is a heading
            if line.startswith("#"):
                # Count heading level
                level = 0
                for char in line:
                    if char == "#":
                        level += 1
                    else:
                        break

                heading_text = line[level:].strip()

                # Track document title (level 1 heading)
                if level == 1 and not doc_title:
                    doc_title = heading_text
                    # Don't create a section for just the title - add it to first section
                    current_section_lines.append(line)
                    continue

                # Check if this heading should start a new section
                if level >= min_heading_level:
                    # Save previous section if it has content
                    if current_section_lines:
                        section_content = "\n".join(current_section_lines).strip()
                        if section_content:
                            # Use doc title for first section if no specific title
                            title = current_title if current_title else doc_title
                            if not title:
                                title = file_path.stem.replace("-", " ").replace(
                                    "_", " "
                                )
                            sections.append(
                                MarkdownSection(
                                    title=title,
                                    content=section_content,
                                    level=current_level if current_level else 1,
                                    index=section_index,
                                )
                            )
                            section_index += 1

                    # Start new section
                    current_section_lines = [line]
                    current_title = heading_text
                    current_level = level
                else:
                    # Lower level heading (e.g., # when splitting on ##)
                    current_section_lines.append(line)
            else:
                current_section_lines.append(line)

        # Don't forget the last section
        if current_section_lines:
            section_content = "\n".join(current_section_lines).strip()
            if section_content:
                title = current_title if current_title else doc_title
                if not title:
                    title = file_path.stem.replace("-", " ").replace("_", " ")
                sections.append(
                    MarkdownSection(
                        title=title,
                        content=section_content,
                        level=current_level if current_level else 1,
                        index=section_index,
                    )
                )

        # If no sections were created (no headings found), create one section with all content
        if not sections and content.strip():
            title = (
                doc_title
                if doc_title
                else file_path.stem.replace("-", " ").replace("_", " ")
            )
            sections.append(
                MarkdownSection(
                    title=title,
                    content=content.strip(),
                    level=1,
                    index=0,
                )
            )

        return sections

    def _make_section_content_id(self, doc_path: str, section_index: int) -> str:
        """Create a unique content ID for a document section.

        Format: doc_path::section_index
        Example: 'platform/getting-started.md::0'
        """
        return f"{doc_path}::{section_index}"

    def _parse_section_content_id(self, content_id: str) -> tuple[str, int]:
        """Parse a section content ID back into doc_path and section_index.

        Returns: (doc_path, section_index)
        """
        if "::" in content_id:
            parts = content_id.rsplit("::", 1)
            return parts[0], int(parts[1])
        # Legacy format (whole document)
        return content_id, 0

    async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
        """Fetch documentation sections without embeddings.

        Chunks each document by markdown headings and creates embeddings for each section.
        Content IDs use the format: 'path/to/doc.md::section_index'
        """
        docs_root = self._get_docs_root()

        if not docs_root.exists():
            logger.warning(f"Documentation root not found: {docs_root}")
            return []

        # Find all .md and .mdx files
        all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))

        if not all_docs:
            return []

        # Build list of all sections from all documents
        all_sections: list[tuple[str, Path, MarkdownSection]] = []
        for doc_file in all_docs:
            doc_path = str(doc_file.relative_to(docs_root))
            sections = self._chunk_markdown_by_headings(doc_file)
            for section in sections:
                all_sections.append((doc_path, doc_file, section))

        if not all_sections:
            return []

        # Generate content IDs for all sections
        section_content_ids = [
            self._make_section_content_id(doc_path, section.index)
            for doc_path, _, section in all_sections
        ]

        # Check which ones have embeddings
        placeholders = ",".join([f"${i+1}" for i in range(len(section_content_ids))])
        existing_result = await query_raw_with_schema(
            f"""
            SELECT "contentId"
            FROM {{schema_prefix}}"UnifiedContentEmbedding"
            WHERE "contentType" = 'DOCUMENTATION'::{{schema_prefix}}"ContentType"
                AND "contentId" = ANY(ARRAY[{placeholders}])
            """,
            *section_content_ids,
        )

        existing_ids = {row["contentId"] for row in existing_result}

        # Filter to missing sections
        missing_sections = [
            (doc_path, doc_file, section, content_id)
            for (doc_path, doc_file, section), content_id in zip(
                all_sections, section_content_ids
            )
            if content_id not in existing_ids
        ]

        # Convert to ContentItem (up to batch_size)
        items = []
        for doc_path, doc_file, section, content_id in missing_sections[:batch_size]:
            try:
                # Get document title for context
                doc_title = self._extract_doc_title(doc_file)

                # Build searchable text with context
                # Include doc title and section title for better search relevance
                searchable_text = f"{doc_title} - {section.title}\n\n{section.content}"

                items.append(
                    ContentItem(
                        content_id=content_id,
                        content_type=ContentType.DOCUMENTATION,
                        searchable_text=searchable_text,
                        metadata={
                            "doc_title": doc_title,
                            "section_title": section.title,
                            "section_index": section.index,
                            "heading_level": section.level,
                            "path": doc_path,
                        },
                        user_id=None,  # Documentation is public
                    )
                )
            except Exception as e:
                logger.warning(f"Failed to process section {content_id}: {e}")
                continue

        return items

    def _get_all_section_content_ids(self, docs_root: Path) -> set[str]:
        """Get all current section content IDs from the docs directory.

        Used for stats and cleanup to know what sections should exist.
        """
        all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
        content_ids = set()

        for doc_file in all_docs:
            doc_path = str(doc_file.relative_to(docs_root))
            sections = self._chunk_markdown_by_headings(doc_file)
            for section in sections:
                content_ids.add(self._make_section_content_id(doc_path, section.index))

        return content_ids

    async def get_stats(self) -> dict[str, int]:
        """Get statistics about documentation embedding coverage.

        Counts sections (not documents) since each section gets its own embedding.
        """
        docs_root = self._get_docs_root()

        if not docs_root.exists():
            return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}

        # Get all section content IDs
        all_section_ids = self._get_all_section_content_ids(docs_root)
        total_sections = len(all_section_ids)

        if total_sections == 0:
            return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}

        # Count embeddings in database for DOCUMENTATION type
        embedded_result = await query_raw_with_schema(
            """
            SELECT COUNT(*) as count
            FROM {schema_prefix}"UnifiedContentEmbedding"
            WHERE "contentType" = 'DOCUMENTATION'::{schema_prefix}"ContentType"
            """
        )

        with_embeddings = embedded_result[0]["count"] if embedded_result else 0

        return {
            "total": total_sections,
            "with_embeddings": with_embeddings,
            "without_embeddings": total_sections - with_embeddings,
        }


# Content handler registry
CONTENT_HANDLERS: dict[ContentType, ContentHandler] = {
    ContentType.STORE_AGENT: StoreAgentHandler(),
    ContentType.BLOCK: BlockHandler(),
    ContentType.DOCUMENTATION: DocumentationHandler(),
}
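
The heading-based chunking is the most intricate piece here. A self-contained sketch of the same splitting rule on an inline document (a toy, simplified version: split on `##`, keep the `#` title with the first chunk, as the handler does):

```python
doc = """# Getting Started

Intro paragraph.

## Install

pip install autogpt

## Run

Start the server.
"""

chunks, current = [], []
for line in doc.splitlines():
    # A "## " heading closes the running chunk and opens a new one.
    if line.startswith("## ") and current:
        chunks.append("\n".join(current).strip())
        current = []
    current.append(line)
chunks.append("\n".join(current).strip())

# Three sections: title + intro, "## Install ...", "## Run ..."
assert len(chunks) == 3 and chunks[1].startswith("## Install")
```
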
@@ -0,0 +1,215 @@
"""
Integration tests for content handlers using real DB.

Run with: poetry run pytest backend/api/features/store/content_handlers_integration_test.py -xvs

These tests use the real database but mock OpenAI calls.
"""

from unittest.mock import patch

import pytest

from backend.api.features.store.content_handlers import (
    CONTENT_HANDLERS,
    BlockHandler,
    DocumentationHandler,
    StoreAgentHandler,
)
from backend.api.features.store.embeddings import (
    EMBEDDING_DIM,
    backfill_all_content_types,
    ensure_content_embedding,
    get_embedding_stats,
)


@pytest.mark.asyncio(loop_scope="session")
async def test_store_agent_handler_real_db():
    """Test StoreAgentHandler with real database queries."""
    handler = StoreAgentHandler()

    # Get stats from real DB
    stats = await handler.get_stats()

    # Stats should have correct structure
    assert "total" in stats
    assert "with_embeddings" in stats
    assert "without_embeddings" in stats
    assert stats["total"] >= 0
    assert stats["with_embeddings"] >= 0
    assert stats["without_embeddings"] >= 0

    # Get missing items (max 1 to keep test fast)
    items = await handler.get_missing_items(batch_size=1)

    # Items should be list (may be empty if all have embeddings)
    assert isinstance(items, list)

    if items:
        item = items[0]
        assert item.content_id is not None
        assert item.content_type.value == "STORE_AGENT"
        assert item.searchable_text != ""
        assert item.user_id is None


@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_real_db():
    """Test BlockHandler with real database queries."""
    handler = BlockHandler()

    # Get stats from real DB
    stats = await handler.get_stats()

    # Stats should have correct structure
    assert "total" in stats
    assert "with_embeddings" in stats
    assert "without_embeddings" in stats
    assert stats["total"] >= 0  # Should have at least some blocks
    assert stats["with_embeddings"] >= 0
    assert stats["without_embeddings"] >= 0

    # Get missing items (max 1 to keep test fast)
    items = await handler.get_missing_items(batch_size=1)

    # Items should be list
    assert isinstance(items, list)

    if items:
        item = items[0]
        assert item.content_id is not None  # Should be block UUID
        assert item.content_type.value == "BLOCK"
        assert item.searchable_text != ""
        assert item.user_id is None


@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_real_fs():
    """Test DocumentationHandler with real filesystem."""
    handler = DocumentationHandler()

    # Get stats from real filesystem
    stats = await handler.get_stats()

    # Stats should have correct structure
    assert "total" in stats
    assert "with_embeddings" in stats
    assert "without_embeddings" in stats
    assert stats["total"] >= 0
    assert stats["with_embeddings"] >= 0
    assert stats["without_embeddings"] >= 0

    # Get missing items (max 1 to keep test fast)
    items = await handler.get_missing_items(batch_size=1)

    # Items should be list
    assert isinstance(items, list)

    if items:
        item = items[0]
        assert item.content_id is not None  # Should be relative path
        assert item.content_type.value == "DOCUMENTATION"
        assert item.searchable_text != ""
        assert item.user_id is None


@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_stats_all_types():
    """Test get_embedding_stats aggregates all content types."""
    stats = await get_embedding_stats()

    # Should have structure with by_type and totals
    assert "by_type" in stats
    assert "totals" in stats

    # Check each content type is present
    by_type = stats["by_type"]
    assert "STORE_AGENT" in by_type
    assert "BLOCK" in by_type
    assert "DOCUMENTATION" in by_type

    # Check totals are aggregated
    totals = stats["totals"]
    assert totals["total"] >= 0
    assert totals["with_embeddings"] >= 0
    assert totals["without_embeddings"] >= 0
    assert "coverage_percent" in totals


@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
async def test_ensure_content_embedding_blocks(mock_generate):
    """Test creating embeddings for blocks (mocked OpenAI)."""
    # Mock OpenAI to return fake embedding
    mock_generate.return_value = [0.1] * EMBEDDING_DIM

    # Get one block without embedding
    handler = BlockHandler()
    items = await handler.get_missing_items(batch_size=1)

    if not items:
        pytest.skip("No blocks without embeddings")

    item = items[0]

    # Try to create embedding (OpenAI mocked)
    result = await ensure_content_embedding(
        content_type=item.content_type,
        content_id=item.content_id,
        searchable_text=item.searchable_text,
        metadata=item.metadata,
        user_id=item.user_id,
    )

    # Should succeed with mocked OpenAI
    assert result is True
    mock_generate.assert_called_once()


@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
async def test_backfill_all_content_types_dry_run(mock_generate):
    """Test backfill_all_content_types processes all handlers in order."""
    # Mock OpenAI to return fake embedding
    mock_generate.return_value = [0.1] * EMBEDDING_DIM

    # Run backfill with batch_size=1 to process max 1 per type
    result = await backfill_all_content_types(batch_size=1)

    # Should have results for all content types
    assert "by_type" in result
    assert "totals" in result

    by_type = result["by_type"]
    assert "BLOCK" in by_type
    assert "STORE_AGENT" in by_type
    assert "DOCUMENTATION" in by_type

    # Each type should have correct structure
    for content_type, type_result in by_type.items():
        assert "processed" in type_result
        assert "success" in type_result
        assert "failed" in type_result

    # Totals should aggregate
    totals = result["totals"]
    assert totals["processed"] >= 0
    assert totals["success"] >= 0
    assert totals["failed"] >= 0


@pytest.mark.asyncio(loop_scope="session")
async def test_content_handler_registry():
    """Test all handlers are registered in correct order."""
    from prisma.enums import ContentType

    # All three types should be registered
    assert ContentType.STORE_AGENT in CONTENT_HANDLERS
    assert ContentType.BLOCK in CONTENT_HANDLERS
    assert ContentType.DOCUMENTATION in CONTENT_HANDLERS

    # Check handler types
    assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
    assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
    assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
@@ -0,0 +1,381 @@
"""
E2E tests for content handlers (blocks, store agents, documentation).

Tests the full flow: discovering content → generating embeddings → storing.
"""

from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
from prisma.enums import ContentType

from backend.api.features.store.content_handlers import (
    CONTENT_HANDLERS,
    BlockHandler,
    DocumentationHandler,
    StoreAgentHandler,
)


@pytest.mark.asyncio(loop_scope="session")
async def test_store_agent_handler_get_missing_items(mocker):
    """Test StoreAgentHandler fetches approved agents without embeddings."""
    handler = StoreAgentHandler()

    # Mock database query
    mock_missing = [
        {
            "id": "agent-1",
            "name": "Test Agent",
            "description": "A test agent",
            "subHeading": "Test heading",
            "categories": ["AI", "Testing"],
        }
    ]

    with patch(
        "backend.api.features.store.content_handlers.query_raw_with_schema",
        return_value=mock_missing,
    ):
        items = await handler.get_missing_items(batch_size=10)

    assert len(items) == 1
    assert items[0].content_id == "agent-1"
    assert items[0].content_type == ContentType.STORE_AGENT
    assert "Test Agent" in items[0].searchable_text
    assert "A test agent" in items[0].searchable_text
    assert items[0].metadata["name"] == "Test Agent"
    assert items[0].user_id is None


@pytest.mark.asyncio(loop_scope="session")
async def test_store_agent_handler_get_stats(mocker):
    """Test StoreAgentHandler returns correct stats."""
    handler = StoreAgentHandler()

    # Mock approved count query
    mock_approved = [{"count": 50}]
    # Mock embedded count query
    mock_embedded = [{"count": 30}]

    with patch(
        "backend.api.features.store.content_handlers.query_raw_with_schema",
        side_effect=[mock_approved, mock_embedded],
    ):
        stats = await handler.get_stats()

    assert stats["total"] == 50
    assert stats["with_embeddings"] == 30
    assert stats["without_embeddings"] == 20


@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_get_missing_items(mocker):
    """Test BlockHandler discovers blocks without embeddings."""
    handler = BlockHandler()

    # Mock get_blocks to return test blocks
    mock_block_class = MagicMock()
    mock_block_instance = MagicMock()
    mock_block_instance.name = "Calculator Block"
    mock_block_instance.description = "Performs calculations"
    mock_block_instance.categories = [MagicMock(value="MATH")]
    mock_block_instance.input_schema.model_json_schema.return_value = {
        "properties": {"expression": {"description": "Math expression to evaluate"}}
    }
    mock_block_class.return_value = mock_block_instance

    mock_blocks = {"block-uuid-1": mock_block_class}

    # Mock existing embeddings query (no embeddings exist)
    mock_existing = []

    with patch(
        "backend.data.block.get_blocks",
        return_value=mock_blocks,
    ):
        with patch(
            "backend.api.features.store.content_handlers.query_raw_with_schema",
            return_value=mock_existing,
        ):
            items = await handler.get_missing_items(batch_size=10)

    assert len(items) == 1
    assert items[0].content_id == "block-uuid-1"
    assert items[0].content_type == ContentType.BLOCK
    assert "Calculator Block" in items[0].searchable_text
    assert "Performs calculations" in items[0].searchable_text
    assert "MATH" in items[0].searchable_text
    assert "expression: Math expression" in items[0].searchable_text
    assert items[0].user_id is None


@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_get_stats(mocker):
    """Test BlockHandler returns correct stats."""
    handler = BlockHandler()

    # Mock get_blocks
    mock_blocks = {
        "block-1": MagicMock(),
        "block-2": MagicMock(),
        "block-3": MagicMock(),
    }

    # Mock embedded count query (2 blocks have embeddings)
    mock_embedded = [{"count": 2}]

    with patch(
        "backend.data.block.get_blocks",
        return_value=mock_blocks,
    ):
        with patch(
            "backend.api.features.store.content_handlers.query_raw_with_schema",
            return_value=mock_embedded,
        ):
            stats = await handler.get_stats()

    assert stats["total"] == 3
    assert stats["with_embeddings"] == 2
    assert stats["without_embeddings"] == 1


@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_get_missing_items(tmp_path, mocker):
    """Test DocumentationHandler discovers docs without embeddings."""
    handler = DocumentationHandler()

    # Create a temporary docs directory with test files
    docs_root = tmp_path / "docs"
    docs_root.mkdir()

    (docs_root / "guide.md").write_text("# Getting Started\n\nThis is a guide.")
    (docs_root / "api.mdx").write_text("# API Reference\n\nAPI documentation.")

    # Mock _get_docs_root to return the temp dir
    with patch.object(handler, "_get_docs_root", return_value=docs_root):
        # Mock existing embeddings query (no embeddings exist)
        with patch(
            "backend.api.features.store.content_handlers.query_raw_with_schema",
            return_value=[],
        ):
            items = await handler.get_missing_items(batch_size=10)

    assert len(items) == 2

    # Check guide.md (content_id format: doc_path::section_index)
    guide_item = next(
        (item for item in items if item.content_id == "guide.md::0"), None
    )
    assert guide_item is not None
    assert guide_item.content_type == ContentType.DOCUMENTATION
    assert "Getting Started" in guide_item.searchable_text
    assert "This is a guide" in guide_item.searchable_text
    assert guide_item.metadata["doc_title"] == "Getting Started"
    assert guide_item.user_id is None

    # Check api.mdx (content_id format: doc_path::section_index)
    api_item = next(
        (item for item in items if item.content_id == "api.mdx::0"), None
    )
    assert api_item is not None
    assert "API Reference" in api_item.searchable_text


@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_get_stats(tmp_path, mocker):
    """Test DocumentationHandler returns correct stats."""
    handler = DocumentationHandler()

    # Create a temporary docs directory
    docs_root = tmp_path / "docs"
    docs_root.mkdir()
    (docs_root / "doc1.md").write_text("# Doc 1")
    (docs_root / "doc2.md").write_text("# Doc 2")
    (docs_root / "doc3.mdx").write_text("# Doc 3")

    # Mock embedded count query (1 doc has an embedding)
    mock_embedded = [{"count": 1}]

    with patch.object(handler, "_get_docs_root", return_value=docs_root):
        with patch(
            "backend.api.features.store.content_handlers.query_raw_with_schema",
            return_value=mock_embedded,
        ):
            stats = await handler.get_stats()

    assert stats["total"] == 3
    assert stats["with_embeddings"] == 1
    assert stats["without_embeddings"] == 2


@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_title_extraction(tmp_path):
    """Test DocumentationHandler extracts title from markdown heading."""
    handler = DocumentationHandler()

    # Test with heading
    doc_with_heading = tmp_path / "with_heading.md"
    doc_with_heading.write_text("# My Title\n\nContent here")
    title = handler._extract_doc_title(doc_with_heading)
    assert title == "My Title"

    # Test without heading
    doc_without_heading = tmp_path / "no-heading.md"
    doc_without_heading.write_text("Just content, no heading")
    title = handler._extract_doc_title(doc_without_heading)
    assert title == "No Heading"  # Falls back to the filename


@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_markdown_chunking(tmp_path):
    """Test DocumentationHandler chunks markdown by headings."""
    handler = DocumentationHandler()

    # Test document with multiple sections
    doc_with_sections = tmp_path / "sections.md"
    doc_with_sections.write_text(
        "# Document Title\n\n"
        "Intro paragraph.\n\n"
        "## Section One\n\n"
        "Content for section one.\n\n"
        "## Section Two\n\n"
        "Content for section two.\n"
    )
    sections = handler._chunk_markdown_by_headings(doc_with_sections)

    # Should have 3 sections: intro (with doc title), section one, section two
    assert len(sections) == 3
    assert sections[0].title == "Document Title"
    assert sections[0].index == 0
    assert "Intro paragraph" in sections[0].content

    assert sections[1].title == "Section One"
    assert sections[1].index == 1
    assert "Content for section one" in sections[1].content

    assert sections[2].title == "Section Two"
    assert sections[2].index == 2
    assert "Content for section two" in sections[2].content

    # Test document without headings
    doc_no_sections = tmp_path / "no-sections.md"
    doc_no_sections.write_text("Just plain content without any headings.")
    sections = handler._chunk_markdown_by_headings(doc_no_sections)
    assert len(sections) == 1
    assert sections[0].index == 0
    assert "Just plain content" in sections[0].content


@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_section_content_ids():
    """Test DocumentationHandler creates and parses section content IDs."""
    handler = DocumentationHandler()

    # Test making a content ID
    content_id = handler._make_section_content_id("docs/guide.md", 2)
    assert content_id == "docs/guide.md::2"

    # Test parsing a content ID
    doc_path, section_index = handler._parse_section_content_id("docs/guide.md::2")
    assert doc_path == "docs/guide.md"
    assert section_index == 2

    # Test parsing the legacy format (no section index)
    doc_path, section_index = handler._parse_section_content_id("docs/old-format.md")
    assert doc_path == "docs/old-format.md"
    assert section_index == 0


@pytest.mark.asyncio(loop_scope="session")
async def test_content_handlers_registry():
    """Test all content types are registered."""
    assert ContentType.STORE_AGENT in CONTENT_HANDLERS
    assert ContentType.BLOCK in CONTENT_HANDLERS
    assert ContentType.DOCUMENTATION in CONTENT_HANDLERS

    assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
    assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
    assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)


@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_handles_missing_attributes():
    """Test BlockHandler gracefully handles blocks with missing attributes."""
    handler = BlockHandler()

    # Mock block with minimal attributes
    mock_block_class = MagicMock()
    mock_block_instance = MagicMock()
    mock_block_instance.name = "Minimal Block"
    # No description, categories, or schema
    del mock_block_instance.description
    del mock_block_instance.categories
    del mock_block_instance.input_schema
    mock_block_class.return_value = mock_block_instance

    mock_blocks = {"block-minimal": mock_block_class}

    with patch(
        "backend.data.block.get_blocks",
        return_value=mock_blocks,
    ):
        with patch(
            "backend.api.features.store.content_handlers.query_raw_with_schema",
            return_value=[],
        ):
            items = await handler.get_missing_items(batch_size=10)

    assert len(items) == 1
    assert items[0].searchable_text == "Minimal Block"


@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_skips_failed_blocks():
    """Test BlockHandler skips blocks that fail to instantiate."""
    handler = BlockHandler()

    # Mock one good block and one bad block
    good_block = MagicMock()
    good_instance = MagicMock()
    good_instance.name = "Good Block"
    good_instance.description = "Works fine"
    good_instance.categories = []
    good_block.return_value = good_instance

    bad_block = MagicMock()
    bad_block.side_effect = Exception("Instantiation failed")

    mock_blocks = {"good-block": good_block, "bad-block": bad_block}

    with patch(
        "backend.data.block.get_blocks",
        return_value=mock_blocks,
    ):
        with patch(
            "backend.api.features.store.content_handlers.query_raw_with_schema",
            return_value=[],
        ):
            items = await handler.get_missing_items(batch_size=10)

    # Should only get the good block
    assert len(items) == 1
    assert items[0].content_id == "good-block"


@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_missing_docs_directory():
    """Test DocumentationHandler handles a missing docs directory gracefully."""
    handler = DocumentationHandler()

    # Mock _get_docs_root to return a non-existent path
    fake_path = Path("/nonexistent/docs")
    with patch.object(handler, "_get_docs_root", return_value=fake_path):
        items = await handler.get_missing_items(batch_size=10)
        assert items == []

        stats = await handler.get_stats()
        assert stats["total"] == 0
        assert stats["with_embeddings"] == 0
        assert stats["without_embeddings"] == 0
@@ -14,6 +14,7 @@ import prisma
from prisma.enums import ContentType
from tiktoken import encoding_for_model

from backend.api.features.store.content_handlers import CONTENT_HANDLERS
from backend.data.db import execute_raw_with_schema, query_raw_with_schema
from backend.util.clients import get_openai_client
from backend.util.json import dumps
@@ -23,6 +24,9 @@ logger = logging.getLogger(__name__)

# OpenAI embedding model configuration
EMBEDDING_MODEL = "text-embedding-3-small"
# Embedding dimension for the model above
# text-embedding-3-small: 1536, text-embedding-3-large: 3072
EMBEDDING_DIM = 1536
# OpenAI embedding token limit (8,191 with 1 token buffer for safety)
EMBEDDING_MAX_TOKENS = 8191

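# A minimal sketch of how text can be clipped to EMBEDDING_MAX_TOKENS before
# an embeddings call, using the tiktoken import shown above. The helper name
# and its exact behavior are illustrative assumptions; the module's real
# truncation logic is not part of this hunk.
from tiktoken import encoding_for_model

def clip_to_embedding_limit(
    text: str,
    model: str = "text-embedding-3-small",
    max_tokens: int = 8191,
) -> str:
    enc = encoding_for_model(model)  # tokenizer matching the embedding model
    tokens = enc.encode(text)
    if len(tokens) <= max_tokens:
        return text
    return enc.decode(tokens[:max_tokens])  # keep only the leading tokens
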
@@ -369,55 +373,69 @@ async def delete_content_embedding(

async def get_embedding_stats() -> dict[str, Any]:
    """
    Get statistics about embedding coverage.
    Get statistics about embedding coverage for all content types.

    Returns counts of:
    - Total approved listing versions
    - Versions with embeddings
    - Versions without embeddings
    Returns stats per content type and overall totals.
    """
    try:
        # Count approved versions
        approved_result = await query_raw_with_schema(
            """
            SELECT COUNT(*) as count
            FROM {schema_prefix}"StoreListingVersion"
            WHERE "submissionStatus" = 'APPROVED'
            AND "isDeleted" = false
            """
        )
        total_approved = approved_result[0]["count"] if approved_result else 0
        stats_by_type = {}
        total_items = 0
        total_with_embeddings = 0
        total_without_embeddings = 0

        # Count versions with embeddings
        embedded_result = await query_raw_with_schema(
            """
            SELECT COUNT(*) as count
            FROM {schema_prefix}"StoreListingVersion" slv
            JOIN {schema_prefix}"UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
            WHERE slv."submissionStatus" = 'APPROVED'
            AND slv."isDeleted" = false
            """
        )
        with_embeddings = embedded_result[0]["count"] if embedded_result else 0
        # Aggregate stats from all handlers
        for content_type, handler in CONTENT_HANDLERS.items():
            try:
                stats = await handler.get_stats()
                stats_by_type[content_type.value] = {
                    "total": stats["total"],
                    "with_embeddings": stats["with_embeddings"],
                    "without_embeddings": stats["without_embeddings"],
                    "coverage_percent": (
                        round(stats["with_embeddings"] / stats["total"] * 100, 1)
                        if stats["total"] > 0
                        else 0
                    ),
                }

                total_items += stats["total"]
                total_with_embeddings += stats["with_embeddings"]
                total_without_embeddings += stats["without_embeddings"]

            except Exception as e:
                logger.error(f"Failed to get stats for {content_type.value}: {e}")
                stats_by_type[content_type.value] = {
                    "total": 0,
                    "with_embeddings": 0,
                    "without_embeddings": 0,
                    "coverage_percent": 0,
                    "error": str(e),
                }

        return {
            "total_approved": total_approved,
            "with_embeddings": with_embeddings,
            "without_embeddings": total_approved - with_embeddings,
            "coverage_percent": (
                round(with_embeddings / total_approved * 100, 1)
                if total_approved > 0
                else 0
            ),
            "by_type": stats_by_type,
            "totals": {
                "total": total_items,
                "with_embeddings": total_with_embeddings,
                "without_embeddings": total_without_embeddings,
                "coverage_percent": (
                    round(total_with_embeddings / total_items * 100, 1)
                    if total_items > 0
                    else 0
                ),
            },
        }

    except Exception as e:
        logger.error(f"Failed to get embedding stats: {e}")
        return {
            "total_approved": 0,
            "with_embeddings": 0,
            "without_embeddings": 0,
            "coverage_percent": 0,
            "by_type": {},
            "totals": {
                "total": 0,
                "with_embeddings": 0,
                "without_embeddings": 0,
                "coverage_percent": 0,
            },
            "error": str(e),
        }

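# For reference, the by_type/totals structure built above, with illustrative
# numbers (not real data); coverage_percent follows the
# round(with_embeddings / total * 100, 1) formula used in the function:
EXAMPLE_STATS = {
    "by_type": {
        "STORE_AGENT": {"total": 50, "with_embeddings": 30,
                        "without_embeddings": 20, "coverage_percent": 60.0},
        "BLOCK": {"total": 3, "with_embeddings": 2,
                  "without_embeddings": 1, "coverage_percent": 66.7},
        "DOCUMENTATION": {"total": 3, "with_embeddings": 1,
                          "without_embeddings": 2, "coverage_percent": 33.3},
    },
    "totals": {"total": 56, "with_embeddings": 33,
               "without_embeddings": 23, "coverage_percent": 58.9},
}
assert EXAMPLE_STATS["totals"]["coverage_percent"] == round(33 / 56 * 100, 1)
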
@@ -426,73 +444,118 @@ async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]:
    """
    Generate embeddings for approved listings that don't have them.

    BACKWARD COMPATIBILITY: Maintained for existing usage.
    This now delegates to backfill_all_content_types() to process all content types.

    Args:
        batch_size: Number of embeddings to generate in one call
        batch_size: Number of embeddings to generate per content type

    Returns:
        Dict with success/failure counts
        Dict with success/failure counts aggregated across all content types
    """
    try:
        # Find approved versions without embeddings
        missing = await query_raw_with_schema(
            """
            SELECT
                slv.id,
                slv.name,
                slv.description,
                slv."subHeading",
                slv.categories
            FROM {schema_prefix}"StoreListingVersion" slv
            LEFT JOIN {schema_prefix}"UnifiedContentEmbedding" uce
                ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
            WHERE slv."submissionStatus" = 'APPROVED'
            AND slv."isDeleted" = false
            AND uce."contentId" IS NULL
            LIMIT $1
            """,
            batch_size,
        )
    # Delegate to the new generic backfill system
    result = await backfill_all_content_types(batch_size)

        if not missing:
            return {
    # Return in the old format for backward compatibility
    return result["totals"]


async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
    """
    Generate embeddings for all content types using registered handlers.

    Processes content types in order: BLOCK → STORE_AGENT → DOCUMENTATION.
    This ensures foundational content (blocks) is searchable first.

    Args:
        batch_size: Number of embeddings to generate per content type

    Returns:
        Dict with stats per content type and overall totals
    """
    results_by_type = {}
    total_processed = 0
    total_success = 0
    total_failed = 0

    # Process content types in explicit order
    processing_order = [
        ContentType.BLOCK,
        ContentType.STORE_AGENT,
        ContentType.DOCUMENTATION,
    ]

    for content_type in processing_order:
        handler = CONTENT_HANDLERS.get(content_type)
        if not handler:
            logger.warning(f"No handler registered for {content_type.value}")
            continue
        try:
            logger.info(f"Processing {content_type.value} content type...")

            # Get missing items from handler
            missing_items = await handler.get_missing_items(batch_size)

            if not missing_items:
                results_by_type[content_type.value] = {
                    "processed": 0,
                    "success": 0,
                    "failed": 0,
                    "message": "No missing embeddings",
                }
                continue

            # Process embeddings concurrently for better performance
            embedding_tasks = [
                ensure_content_embedding(
                    content_type=item.content_type,
                    content_id=item.content_id,
                    searchable_text=item.searchable_text,
                    metadata=item.metadata,
                    user_id=item.user_id,
                )
                for item in missing_items
            ]

            results = await asyncio.gather(*embedding_tasks, return_exceptions=True)

            success = sum(1 for result in results if result is True)
            failed = len(results) - success

            results_by_type[content_type.value] = {
                "processed": len(missing_items),
                "success": success,
                "failed": failed,
                "message": f"Backfilled {success} embeddings, {failed} failed",
            }

            total_processed += len(missing_items)
            total_success += success
            total_failed += failed

            logger.info(
                f"{content_type.value}: processed {len(missing_items)}, "
                f"success {success}, failed {failed}"
            )

        except Exception as e:
            logger.error(f"Failed to process {content_type.value}: {e}")
            results_by_type[content_type.value] = {
                "processed": 0,
                "success": 0,
                "failed": 0,
                "message": "No missing embeddings",
                "error": str(e),
            }

        # Process embeddings concurrently for better performance
        embedding_tasks = [
            ensure_embedding(
                version_id=row["id"],
                name=row["name"],
                description=row["description"],
                sub_heading=row["subHeading"],
                categories=row["categories"] or [],
            )
            for row in missing
        ]

        results = await asyncio.gather(*embedding_tasks, return_exceptions=True)

        success = sum(1 for result in results if result is True)
        failed = len(results) - success

        return {
            "processed": len(missing),
            "success": success,
            "failed": failed,
            "message": f"Backfilled {success} embeddings, {failed} failed",
        }

    except Exception as e:
        logger.error(f"Failed to backfill embeddings: {e}")
        return {
            "processed": 0,
            "success": 0,
            "failed": 0,
            "error": str(e),
        }
    return {
        "by_type": results_by_type,
        "totals": {
            "processed": total_processed,
            "success": total_success,
            "failed": total_failed,
            "message": f"Overall: {total_success} succeeded, {total_failed} failed",
        },
    }


async def embed_query(query: str) -> list[float] | None:
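# A minimal usage sketch for the new backfill entry point, assuming a
# configured database and OpenAI client; the printed fields mirror the
# result structure built above:
import asyncio

async def _demo_backfill() -> None:
    result = await backfill_all_content_types(batch_size=10)
    for content_type, stats in result["by_type"].items():
        # e.g. "BLOCK Backfilled 1 embeddings, 0 failed"
        print(content_type, stats.get("message", stats))
    print(result["totals"]["message"])

# asyncio.run(_demo_backfill())  # run in an environment with DB access
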
@@ -566,3 +629,334 @@ async def ensure_content_embedding(
    except Exception as e:
        logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
        return False


async def cleanup_orphaned_embeddings() -> dict[str, Any]:
    """
    Clean up embeddings for content that no longer exists or is no longer valid.

    Compares current content with embeddings in the database and removes orphaned records:
    - STORE_AGENT: Removes embeddings for rejected/deleted store listings
    - BLOCK: Removes embeddings for blocks no longer registered
    - DOCUMENTATION: Removes embeddings for deleted doc files

    Returns:
        Dict with cleanup statistics per content type
    """
    results_by_type = {}
    total_deleted = 0

    # Clean up orphaned embeddings for all content types
    cleanup_types = [
        ContentType.STORE_AGENT,
        ContentType.BLOCK,
        ContentType.DOCUMENTATION,
    ]

    for content_type in cleanup_types:
        try:
            handler = CONTENT_HANDLERS.get(content_type)
            if not handler:
                logger.warning(f"No handler registered for {content_type}")
                results_by_type[content_type.value] = {
                    "deleted": 0,
                    "error": "No handler registered",
                }
                continue

            # Get all current content IDs from the handler
            if content_type == ContentType.STORE_AGENT:
                # Get IDs of approved store listing versions from non-deleted listings
                valid_agents = await query_raw_with_schema(
                    """
                    SELECT slv.id
                    FROM {schema_prefix}"StoreListingVersion" slv
                    JOIN {schema_prefix}"StoreListing" sl ON slv."storeListingId" = sl.id
                    WHERE slv."submissionStatus" = 'APPROVED'
                    AND slv."isDeleted" = false
                    AND sl."isDeleted" = false
                    """,
                )
                current_ids = {row["id"] for row in valid_agents}
            elif content_type == ContentType.BLOCK:
                from backend.data.block import get_blocks

                current_ids = set(get_blocks().keys())
            elif content_type == ContentType.DOCUMENTATION:
                # Use DocumentationHandler to get section-based content IDs
                from backend.api.features.store.content_handlers import (
                    DocumentationHandler,
                )

                doc_handler = CONTENT_HANDLERS.get(ContentType.DOCUMENTATION)
                if isinstance(doc_handler, DocumentationHandler):
                    docs_root = doc_handler._get_docs_root()
                    if docs_root.exists():
                        current_ids = doc_handler._get_all_section_content_ids(
                            docs_root
                        )
                    else:
                        current_ids = set()
                else:
                    current_ids = set()
            else:
                # Skip unknown content types to avoid accidental deletion
                logger.warning(
                    f"Skipping cleanup for unknown content type: {content_type}"
                )
                results_by_type[content_type.value] = {
                    "deleted": 0,
                    "error": "Unknown content type - skipped for safety",
                }
                continue

            # Get all embedding IDs from the database
            db_embeddings = await query_raw_with_schema(
                """
                SELECT "contentId"
                FROM {schema_prefix}"UnifiedContentEmbedding"
                WHERE "contentType" = $1::{schema_prefix}"ContentType"
                """,
                content_type,
            )

            db_ids = {row["contentId"] for row in db_embeddings}

            # Find orphaned embeddings (in DB but not in current content)
            orphaned_ids = db_ids - current_ids

            if not orphaned_ids:
                logger.info(f"{content_type.value}: No orphaned embeddings found")
                results_by_type[content_type.value] = {
                    "deleted": 0,
                    "message": "No orphaned embeddings",
                }
                continue

            # Delete orphaned embeddings in batch for better performance
            orphaned_list = list(orphaned_ids)
            try:
                await execute_raw_with_schema(
                    """
                    DELETE FROM {schema_prefix}"UnifiedContentEmbedding"
                    WHERE "contentType" = $1::{schema_prefix}"ContentType"
                    AND "contentId" = ANY($2::text[])
                    """,
                    content_type,
                    orphaned_list,
                )
                deleted = len(orphaned_list)
            except Exception as e:
                logger.error(f"Failed to batch delete orphaned embeddings: {e}")
                deleted = 0

            logger.info(
                f"{content_type.value}: Deleted {deleted}/{len(orphaned_ids)} orphaned embeddings"
            )
            results_by_type[content_type.value] = {
                "deleted": deleted,
                "orphaned": len(orphaned_ids),
                "message": f"Deleted {deleted} orphaned embeddings",
            }

            total_deleted += deleted

        except Exception as e:
            logger.error(f"Failed to cleanup {content_type.value}: {e}")
            results_by_type[content_type.value] = {
                "deleted": 0,
                "error": str(e),
            }

    return {
        "by_type": results_by_type,
        "totals": {
            "deleted": total_deleted,
            "message": f"Deleted {total_deleted} orphaned embeddings",
        },
    }


async def semantic_search(
    query: str,
    content_types: list[ContentType] | None = None,
    user_id: str | None = None,
    limit: int = 20,
    min_similarity: float = 0.5,
) -> list[dict[str, Any]]:
    """
    Semantic search across content types using embeddings.

    Performs vector similarity search on the UnifiedContentEmbedding table.
    Used directly for blocks/docs/library agents, or as the semantic component
    within hybrid_search for store agents.

    If embedding generation fails, falls back to lexical search on searchableText.

    Args:
        query: Search query string
        content_types: List of ContentType to search. Defaults to [BLOCK, STORE_AGENT, DOCUMENTATION]
        user_id: Optional user ID for searching private content (library agents)
        limit: Maximum number of results to return (default: 20)
        min_similarity: Minimum cosine similarity threshold (0-1, default: 0.5)

    Returns:
        List of search results with the following structure:
        [
            {
                "content_id": str,
                "content_type": str,  # "BLOCK", "STORE_AGENT", "DOCUMENTATION", or "LIBRARY_AGENT"
                "searchable_text": str,
                "metadata": dict,
                "similarity": float,  # Cosine similarity score (0-1)
            },
            ...
        ]

    Examples:
        # Search blocks only
        results = await semantic_search("calculate", content_types=[ContentType.BLOCK])

        # Search blocks and documentation
        results = await semantic_search(
            "how to use API",
            content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION]
        )

        # Search all public content (default)
        results = await semantic_search("AI agent")

        # Search the user's library agents
        results = await semantic_search(
            "my custom agent",
            content_types=[ContentType.LIBRARY_AGENT],
            user_id="user123"
        )
    """
    # Default to searching all public content types
    if content_types is None:
        content_types = [
            ContentType.BLOCK,
            ContentType.STORE_AGENT,
            ContentType.DOCUMENTATION,
        ]

    # Validate inputs
    if not content_types:
        return []  # Empty content_types would cause invalid SQL (IN ())

    query = query.strip()
    if not query:
        return []

    if limit < 1:
        limit = 1
    if limit > 100:
        limit = 100

    # Generate the query embedding
    query_embedding = await embed_query(query)

    if query_embedding is not None:
        # Semantic search with embeddings
        embedding_str = embedding_to_vector_string(query_embedding)

        # Build params in order: limit, then user_id (if provided), then content types
        params: list[Any] = [limit]
        user_filter = ""
        if user_id is not None:
            user_filter = 'AND "userId" = ${}'.format(len(params) + 1)
            params.append(user_id)

        # Add content type parameters and build placeholders dynamically
        content_type_start_idx = len(params) + 1
        content_type_placeholders = ", ".join(
            f'${content_type_start_idx + i}::{{{{schema_prefix}}}}"ContentType"'
            for i in range(len(content_types))
        )
        params.extend([ct.value for ct in content_types])

        sql = f"""
            SELECT
                "contentId" as content_id,
                "contentType" as content_type,
                "searchableText" as searchable_text,
                metadata,
                1 - (embedding <=> '{embedding_str}'::vector) as similarity
            FROM {{{{schema_prefix}}}}"UnifiedContentEmbedding"
            WHERE "contentType" IN ({content_type_placeholders})
            {user_filter}
            AND 1 - (embedding <=> '{embedding_str}'::vector) >= ${len(params) + 1}
            ORDER BY similarity DESC
            LIMIT $1
        """
        params.append(min_similarity)

        try:
            results = await query_raw_with_schema(
                sql, *params, set_public_search_path=True
            )
            return [
                {
                    "content_id": row["content_id"],
                    "content_type": row["content_type"],
                    "searchable_text": row["searchable_text"],
                    "metadata": row["metadata"],
                    "similarity": float(row["similarity"]),
                }
                for row in results
            ]
        except Exception as e:
            logger.error(f"Semantic search failed: {e}")
            # Fall through to lexical search below

    # Fall back to lexical search if embeddings are unavailable
    logger.warning("Falling back to lexical search (embeddings unavailable)")

    params_lexical: list[Any] = [limit]
    user_filter = ""
    if user_id is not None:
        user_filter = 'AND "userId" = ${}'.format(len(params_lexical) + 1)
        params_lexical.append(user_id)

    # Add content type parameters and build placeholders dynamically
    content_type_start_idx = len(params_lexical) + 1
    content_type_placeholders_lexical = ", ".join(
        f'${content_type_start_idx + i}::{{{{schema_prefix}}}}"ContentType"'
        for i in range(len(content_types))
    )
    params_lexical.extend([ct.value for ct in content_types])

    sql_lexical = f"""
        SELECT
            "contentId" as content_id,
            "contentType" as content_type,
            "searchableText" as searchable_text,
            metadata,
            0.0 as similarity
        FROM {{{{schema_prefix}}}}"UnifiedContentEmbedding"
        WHERE "contentType" IN ({content_type_placeholders_lexical})
        {user_filter}
        AND "searchableText" ILIKE ${len(params_lexical) + 1}
        ORDER BY "updatedAt" DESC
        LIMIT $1
    """
    params_lexical.append(f"%{query}%")

    try:
        results = await query_raw_with_schema(
            sql_lexical, *params_lexical, set_public_search_path=True
        )
        return [
            {
                "content_id": row["content_id"],
                "content_type": row["content_type"],
                "searchable_text": row["searchable_text"],
                "metadata": row["metadata"],
                "similarity": 0.0,  # Lexical search doesn't provide similarity
            }
            for row in results
        ]
    except Exception as e:
        logger.error(f"Lexical search failed: {e}")
        return []

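# A standalone walkthrough of the positional-parameter numbering used by
# semantic_search above: $1 is the limit, $2 the optional user filter,
# content types come next, and min_similarity is appended last. Values here
# are illustrative only.
from typing import Any

demo_params: list[Any] = [20]               # $1 = limit
demo_user_id = "user123"                    # hypothetical user id
if demo_user_id is not None:
    demo_params.append(demo_user_id)        # $2 = "userId" filter
start_idx = len(demo_params) + 1
demo_types = ["BLOCK", "DOCUMENTATION"]
placeholders = ", ".join(f"${start_idx + i}" for i in range(len(demo_types)))
demo_params.extend(demo_types)              # $3, $4 = content types
demo_params.append(0.5)                     # $5 = min_similarity threshold
assert placeholders == "$3, $4"
assert demo_params == [20, "user123", "BLOCK", "DOCUMENTATION", 0.5]
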
@@ -0,0 +1,666 @@
"""
End-to-end database tests for embeddings and hybrid search.

These tests hit the actual database to verify SQL queries work correctly.
Tests cover:
1. Embedding storage (store_content_embedding)
2. Embedding retrieval (get_content_embedding)
3. Embedding deletion (delete_content_embedding)
4. Unified hybrid search across content types
5. Store agent hybrid search
"""

import uuid
from typing import AsyncGenerator

import pytest
from prisma.enums import ContentType

from backend.api.features.store import embeddings
from backend.api.features.store.embeddings import EMBEDDING_DIM
from backend.api.features.store.hybrid_search import (
    hybrid_search,
    unified_hybrid_search,
)

# ============================================================================
# Test Fixtures
# ============================================================================


@pytest.fixture
def test_content_id() -> str:
    """Generate a unique content ID for test isolation."""
    return f"test-content-{uuid.uuid4()}"


@pytest.fixture
def test_user_id() -> str:
    """Generate a unique user ID for test isolation."""
    return f"test-user-{uuid.uuid4()}"


@pytest.fixture
def mock_embedding() -> list[float]:
    """Generate a mock embedding vector."""
    # Create a normalized embedding vector
    import math

    raw = [float(i % 10) / 10.0 for i in range(EMBEDDING_DIM)]
    # Normalize to unit length (required for cosine similarity)
    magnitude = math.sqrt(sum(x * x for x in raw))
    return [x / magnitude for x in raw]


@pytest.fixture
def similar_embedding() -> list[float]:
    """Generate an embedding similar to mock_embedding."""
    import math

    # Similar but slightly different values
    raw = [float(i % 10) / 10.0 + 0.01 for i in range(EMBEDDING_DIM)]
    magnitude = math.sqrt(sum(x * x for x in raw))
    return [x / magnitude for x in raw]


@pytest.fixture
def different_embedding() -> list[float]:
    """Generate an embedding very different from mock_embedding."""
    import math

    # Reversed pattern to be maximally different
    raw = [float((EMBEDDING_DIM - i) % 10) / 10.0 for i in range(EMBEDDING_DIM)]
    magnitude = math.sqrt(sum(x * x for x in raw))
    return [x / magnitude for x in raw]


@pytest.fixture
async def cleanup_embeddings(
    server,
) -> AsyncGenerator[list[tuple[ContentType, str, str | None]], None]:
    """
    Fixture that tracks created embeddings and cleans them up after tests.

    Yields a list to which tests can append (content_type, content_id, user_id) tuples.
    """
    created_embeddings: list[tuple[ContentType, str, str | None]] = []
    yield created_embeddings

    # Clean up all created embeddings
    for content_type, content_id, user_id in created_embeddings:
        try:
            await embeddings.delete_content_embedding(content_type, content_id, user_id)
        except Exception:
            pass  # Ignore cleanup errors

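# A quick, self-contained check of why the fixtures above normalize to unit
# length: for unit vectors, cosine similarity is just the dot product, which
# is what pgvector's `<=>` cosine-distance operator is derived from
# (similarity = 1 - distance). The small dimension below is illustrative;
# the real fixtures use EMBEDDING_DIM.
import math

def _cosine(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))  # unit vectors: dot == cosine

_raw = [float(i % 10) / 10.0 for i in range(8)]
_mag = math.sqrt(sum(x * x for x in _raw))
_unit = [x / _mag for x in _raw]
assert abs(_cosine(_unit, _unit) - 1.0) < 1e-9  # identical vectors score 1.0
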
# ============================================================================
# store_content_embedding Tests
# ============================================================================


@pytest.mark.asyncio(loop_scope="session")
async def test_store_content_embedding_store_agent(
    server,
    test_content_id: str,
    mock_embedding: list[float],
    cleanup_embeddings: list,
):
    """Test storing an embedding for the STORE_AGENT content type."""
    # Track for cleanup
    cleanup_embeddings.append((ContentType.STORE_AGENT, test_content_id, None))

    result = await embeddings.store_content_embedding(
        content_type=ContentType.STORE_AGENT,
        content_id=test_content_id,
        embedding=mock_embedding,
        searchable_text="AI assistant for productivity tasks",
        metadata={"name": "Test Agent", "categories": ["productivity"]},
        user_id=None,  # Store agents are public
    )

    assert result is True

    # Verify it was stored
    stored = await embeddings.get_content_embedding(
        ContentType.STORE_AGENT, test_content_id, user_id=None
    )
    assert stored is not None
    assert stored["contentId"] == test_content_id
    assert stored["contentType"] == "STORE_AGENT"
    assert stored["searchableText"] == "AI assistant for productivity tasks"


@pytest.mark.asyncio(loop_scope="session")
async def test_store_content_embedding_block(
    server,
    test_content_id: str,
    mock_embedding: list[float],
    cleanup_embeddings: list,
):
    """Test storing an embedding for the BLOCK content type."""
    cleanup_embeddings.append((ContentType.BLOCK, test_content_id, None))

    result = await embeddings.store_content_embedding(
        content_type=ContentType.BLOCK,
        content_id=test_content_id,
        embedding=mock_embedding,
        searchable_text="HTTP request block for API calls",
        metadata={"name": "HTTP Request Block"},
        user_id=None,  # Blocks are public
    )

    assert result is True

    stored = await embeddings.get_content_embedding(
        ContentType.BLOCK, test_content_id, user_id=None
    )
    assert stored is not None
    assert stored["contentType"] == "BLOCK"


@pytest.mark.asyncio(loop_scope="session")
async def test_store_content_embedding_documentation(
    server,
    test_content_id: str,
    mock_embedding: list[float],
    cleanup_embeddings: list,
):
    """Test storing an embedding for the DOCUMENTATION content type."""
    cleanup_embeddings.append((ContentType.DOCUMENTATION, test_content_id, None))

    result = await embeddings.store_content_embedding(
        content_type=ContentType.DOCUMENTATION,
        content_id=test_content_id,
        embedding=mock_embedding,
        searchable_text="Getting started guide for AutoGPT platform",
        metadata={"title": "Getting Started", "url": "/docs/getting-started"},
        user_id=None,  # Docs are public
    )

    assert result is True

    stored = await embeddings.get_content_embedding(
        ContentType.DOCUMENTATION, test_content_id, user_id=None
    )
    assert stored is not None
    assert stored["contentType"] == "DOCUMENTATION"


@pytest.mark.asyncio(loop_scope="session")
async def test_store_content_embedding_upsert(
    server,
    test_content_id: str,
    mock_embedding: list[float],
    cleanup_embeddings: list,
):
    """Test that storing an embedding twice updates instead of duplicating."""
    cleanup_embeddings.append((ContentType.BLOCK, test_content_id, None))

    # Store first time
    result1 = await embeddings.store_content_embedding(
        content_type=ContentType.BLOCK,
        content_id=test_content_id,
        embedding=mock_embedding,
        searchable_text="Original text",
        metadata={"version": 1},
        user_id=None,
    )
    assert result1 is True

    # Store again with different text (upsert)
    result2 = await embeddings.store_content_embedding(
        content_type=ContentType.BLOCK,
        content_id=test_content_id,
        embedding=mock_embedding,
        searchable_text="Updated text",
        metadata={"version": 2},
        user_id=None,
    )
    assert result2 is True

    # Verify only one record with the updated text
    stored = await embeddings.get_content_embedding(
        ContentType.BLOCK, test_content_id, user_id=None
    )
    assert stored is not None
    assert stored["searchableText"] == "Updated text"


# ============================================================================
# get_content_embedding Tests
# ============================================================================


@pytest.mark.asyncio(loop_scope="session")
async def test_get_content_embedding_not_found(server):
    """Test retrieving a non-existent embedding returns None."""
    result = await embeddings.get_content_embedding(
        ContentType.STORE_AGENT, "non-existent-id", user_id=None
    )
    assert result is None


@pytest.mark.asyncio(loop_scope="session")
async def test_get_content_embedding_with_metadata(
    server,
    test_content_id: str,
    mock_embedding: list[float],
    cleanup_embeddings: list,
):
    """Test that metadata is correctly stored and retrieved."""
    cleanup_embeddings.append((ContentType.STORE_AGENT, test_content_id, None))

    metadata = {
        "name": "Test Agent",
        "subHeading": "A test agent",
        "categories": ["ai", "productivity"],
        "customField": 123,
    }

    await embeddings.store_content_embedding(
        content_type=ContentType.STORE_AGENT,
        content_id=test_content_id,
        embedding=mock_embedding,
        searchable_text="test",
        metadata=metadata,
        user_id=None,
    )

    stored = await embeddings.get_content_embedding(
        ContentType.STORE_AGENT, test_content_id, user_id=None
    )

    assert stored is not None
    assert stored["metadata"]["name"] == "Test Agent"
    assert stored["metadata"]["categories"] == ["ai", "productivity"]
    assert stored["metadata"]["customField"] == 123


# ============================================================================
# delete_content_embedding Tests
# ============================================================================


@pytest.mark.asyncio(loop_scope="session")
async def test_delete_content_embedding(
    server,
    test_content_id: str,
    mock_embedding: list[float],
):
    """Test deleting an embedding removes it from the database."""
    # Store embedding
    await embeddings.store_content_embedding(
        content_type=ContentType.BLOCK,
        content_id=test_content_id,
        embedding=mock_embedding,
        searchable_text="To be deleted",
        metadata=None,
        user_id=None,
    )

    # Verify it exists
    stored = await embeddings.get_content_embedding(
        ContentType.BLOCK, test_content_id, user_id=None
    )
    assert stored is not None

    # Delete it
    result = await embeddings.delete_content_embedding(
        ContentType.BLOCK, test_content_id, user_id=None
    )
    assert result is True

    # Verify it's gone
    stored = await embeddings.get_content_embedding(
        ContentType.BLOCK, test_content_id, user_id=None
    )
    assert stored is None


@pytest.mark.asyncio(loop_scope="session")
async def test_delete_content_embedding_not_found(server):
    """Test deleting a non-existent embedding doesn't error."""
    result = await embeddings.delete_content_embedding(
        ContentType.BLOCK, "non-existent-id", user_id=None
    )
    # Should succeed even if there is nothing to delete
    assert result is True

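# The upsert behavior asserted in the tests above is the kind PostgreSQL's
# INSERT ... ON CONFLICT provides. The statement below is a hypothetical
# sketch of that pattern, not the actual SQL inside store_content_embedding:
UPSERT_SQL_SKETCH = """
INSERT INTO {schema_prefix}"UnifiedContentEmbedding"
    ("contentType", "contentId", embedding, "searchableText", metadata, "userId")
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT ("contentType", "contentId") DO UPDATE
SET embedding = EXCLUDED.embedding,
    "searchableText" = EXCLUDED."searchableText",
    metadata = EXCLUDED.metadata
"""
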
# ============================================================================
# unified_hybrid_search Tests
# ============================================================================


@pytest.mark.asyncio(loop_scope="session")
async def test_unified_hybrid_search_finds_matching_content(
    server,
    mock_embedding: list[float],
    cleanup_embeddings: list,
):
    """Test unified search finds content matching the query."""
    # Create unique content IDs
    agent_id = f"test-agent-{uuid.uuid4()}"
    block_id = f"test-block-{uuid.uuid4()}"
    doc_id = f"test-doc-{uuid.uuid4()}"

    cleanup_embeddings.append((ContentType.STORE_AGENT, agent_id, None))
    cleanup_embeddings.append((ContentType.BLOCK, block_id, None))
    cleanup_embeddings.append((ContentType.DOCUMENTATION, doc_id, None))

    # Store embeddings for different content types
    await embeddings.store_content_embedding(
        content_type=ContentType.STORE_AGENT,
        content_id=agent_id,
        embedding=mock_embedding,
        searchable_text="AI writing assistant for blog posts",
        metadata={"name": "Writing Assistant"},
        user_id=None,
    )

    await embeddings.store_content_embedding(
        content_type=ContentType.BLOCK,
        content_id=block_id,
        embedding=mock_embedding,
        searchable_text="Text generation block for creative writing",
        metadata={"name": "Text Generator"},
        user_id=None,
    )

    await embeddings.store_content_embedding(
        content_type=ContentType.DOCUMENTATION,
        content_id=doc_id,
        embedding=mock_embedding,
        searchable_text="How to use writing blocks in AutoGPT",
        metadata={"title": "Writing Guide"},
        user_id=None,
    )

    # Search for "writing" - should find all three
    results, total = await unified_hybrid_search(
        query="writing",
        page=1,
        page_size=20,
    )

    # Should find at least our test content (may find others too)
    content_ids = [r["content_id"] for r in results]
    assert agent_id in content_ids or total >= 1  # Lexical search should find it


@pytest.mark.asyncio(loop_scope="session")
async def test_unified_hybrid_search_filter_by_content_type(
    server,
    mock_embedding: list[float],
    cleanup_embeddings: list,
):
    """Test unified search can filter by content type."""
    agent_id = f"test-agent-{uuid.uuid4()}"
    block_id = f"test-block-{uuid.uuid4()}"

    cleanup_embeddings.append((ContentType.STORE_AGENT, agent_id, None))
    cleanup_embeddings.append((ContentType.BLOCK, block_id, None))

    # Store both types with same searchable text
    await embeddings.store_content_embedding(
        content_type=ContentType.STORE_AGENT,
        content_id=agent_id,
        embedding=mock_embedding,
        searchable_text="unique_search_term_xyz123",
        metadata={},
        user_id=None,
    )

    await embeddings.store_content_embedding(
        content_type=ContentType.BLOCK,
        content_id=block_id,
        embedding=mock_embedding,
        searchable_text="unique_search_term_xyz123",
        metadata={},
        user_id=None,
    )

    # Search only for BLOCK type
    results, total = await unified_hybrid_search(
        query="unique_search_term_xyz123",
        content_types=[ContentType.BLOCK],
        page=1,
        page_size=20,
    )

    # All results should be BLOCK type
    for r in results:
        assert r["content_type"] == "BLOCK"


@pytest.mark.asyncio(loop_scope="session")
async def test_unified_hybrid_search_empty_query(server):
    """Test unified search with empty query returns empty results."""
    results, total = await unified_hybrid_search(
        query="",
        page=1,
        page_size=20,
    )

    assert results == []
    assert total == 0


@pytest.mark.asyncio(loop_scope="session")
async def test_unified_hybrid_search_pagination(
    server,
    mock_embedding: list[float],
    cleanup_embeddings: list,
):
    """Test unified search pagination works correctly."""
    # Create multiple items
    content_ids = []
    for i in range(5):
        content_id = f"test-pagination-{uuid.uuid4()}"
        content_ids.append(content_id)
        cleanup_embeddings.append((ContentType.BLOCK, content_id, None))

        await embeddings.store_content_embedding(
            content_type=ContentType.BLOCK,
            content_id=content_id,
            embedding=mock_embedding,
            searchable_text=f"pagination test item number {i}",
            metadata={"index": i},
            user_id=None,
        )

    # Get first page
    page1_results, total1 = await unified_hybrid_search(
        query="pagination test",
        content_types=[ContentType.BLOCK],
        page=1,
        page_size=2,
    )

    # Get second page
    page2_results, total2 = await unified_hybrid_search(
        query="pagination test",
        content_types=[ContentType.BLOCK],
        page=2,
        page_size=2,
    )

    # Total should be consistent
    assert total1 == total2

    # Pages should have different content (if we have enough results)
    if len(page1_results) > 0 and len(page2_results) > 0:
        page1_ids = {r["content_id"] for r in page1_results}
        page2_ids = {r["content_id"] for r in page2_results}
        # No overlap between pages
        assert page1_ids.isdisjoint(page2_ids)


@pytest.mark.asyncio(loop_scope="session")
async def test_unified_hybrid_search_min_score_filtering(
    server,
    mock_embedding: list[float],
    cleanup_embeddings: list,
):
    """Test unified search respects min_score threshold."""
    content_id = f"test-minscore-{uuid.uuid4()}"
    cleanup_embeddings.append((ContentType.BLOCK, content_id, None))

    await embeddings.store_content_embedding(
        content_type=ContentType.BLOCK,
        content_id=content_id,
        embedding=mock_embedding,
        searchable_text="completely unrelated content about bananas",
        metadata={},
        user_id=None,
    )

    # Search with very high min_score - should filter out low relevance
    results_high, _ = await unified_hybrid_search(
        query="quantum computing algorithms",
        content_types=[ContentType.BLOCK],
        min_score=0.9,  # Very high threshold
        page=1,
        page_size=20,
    )

    # Search with low min_score
    results_low, _ = await unified_hybrid_search(
        query="quantum computing algorithms",
        content_types=[ContentType.BLOCK],
        min_score=0.01,  # Very low threshold
        page=1,
        page_size=20,
    )

    # High threshold should have fewer or equal results
    assert len(results_high) <= len(results_low)

# ============================================================================
|
||||
# hybrid_search (Store Agents) Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_hybrid_search_store_agents_sql_valid(server):
|
||||
"""Test that hybrid_search SQL executes without errors."""
|
||||
# This test verifies the SQL is syntactically correct
|
||||
# even if no results are found
|
||||
results, total = await hybrid_search(
|
||||
query="test agent",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Should not raise - verifies SQL is valid
|
||||
assert isinstance(results, list)
|
||||
assert isinstance(total, int)
|
||||
assert total >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_hybrid_search_with_filters(server):
|
||||
"""Test hybrid_search with various filter options."""
|
||||
# Test with all filter types
|
||||
results, total = await hybrid_search(
|
||||
query="productivity",
|
||||
featured=True,
|
||||
creators=["test-creator"],
|
||||
category="productivity",
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
# Should not raise - verifies filter SQL is valid
|
||||
assert isinstance(results, list)
|
||||
assert isinstance(total, int)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_hybrid_search_pagination(server):
|
||||
"""Test hybrid_search pagination."""
|
||||
# Page 1
|
||||
results1, total1 = await hybrid_search(
|
||||
query="agent",
|
||||
page=1,
|
||||
page_size=5,
|
||||
)
|
||||
|
||||
# Page 2
|
||||
results2, total2 = await hybrid_search(
|
||||
query="agent",
|
||||
page=2,
|
||||
page_size=5,
|
||||
)
|
||||
|
||||
# Verify SQL executes without error
|
||||
assert isinstance(results1, list)
|
||||
assert isinstance(results2, list)
|
||||
assert isinstance(total1, int)
|
||||
assert isinstance(total2, int)
|
||||
|
||||
# If page 1 has results, total should be > 0
|
||||
# Note: total from page 2 may be 0 if no results on that page (COUNT(*) OVER limitation)
|
||||
if results1:
|
||||
assert total1 > 0
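
    # Illustrative note (an assumption about the window-function behaviour, not
    # an assertion): COUNT(*) OVER () is attached to the rows that survive
    # filtering and pagination, so a page past the last match returns zero rows
    # and the caller sees total == 0 - e.g. 3 matches with page_size=5 gives
    # page 1 -> 3 rows / total 3, page 2 -> 0 rows / total 0.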


# ============================================================================
# SQL Validity Tests (verify queries don't break)
# ============================================================================


@pytest.mark.asyncio(loop_scope="session")
async def test_all_content_types_searchable(server):
    """Test that all content types can be searched without SQL errors."""
    for content_type in [
        ContentType.STORE_AGENT,
        ContentType.BLOCK,
        ContentType.DOCUMENTATION,
    ]:
        results, total = await unified_hybrid_search(
            query="test",
            content_types=[content_type],
            page=1,
            page_size=10,
        )

        # Should not raise
        assert isinstance(results, list)
        assert isinstance(total, int)


@pytest.mark.asyncio(loop_scope="session")
async def test_multiple_content_types_searchable(server):
    """Test searching multiple content types at once."""
    results, total = await unified_hybrid_search(
        query="test",
        content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION],
        page=1,
        page_size=20,
    )

    # Should not raise
    assert isinstance(results, list)
    assert isinstance(total, int)


@pytest.mark.asyncio(loop_scope="session")
async def test_search_all_content_types_default(server):
    """Test searching all content types (default behavior)."""
    results, total = await unified_hybrid_search(
        query="test",
        content_types=None,  # Should search all
        page=1,
        page_size=20,
    )

    # Should not raise
    assert isinstance(results, list)
    assert isinstance(total, int)


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
@@ -4,12 +4,13 @@ Integration tests for embeddings with schema handling.
These tests verify that embeddings operations work correctly across different database schemas.
"""

from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from prisma.enums import ContentType

from backend.api.features.store import embeddings
from backend.api.features.store.embeddings import EMBEDDING_DIM

# Schema prefix tests removed - functionality moved to db.raw_with_schema() helper

@@ -28,7 +29,7 @@ async def test_store_content_embedding_with_schema():
    result = await embeddings.store_content_embedding(
        content_type=ContentType.STORE_AGENT,
        content_id="test-id",
        embedding=[0.1] * 1536,
        embedding=[0.1] * EMBEDDING_DIM,
        searchable_text="test text",
        metadata={"test": "data"},
        user_id=None,
@@ -125,84 +126,69 @@ async def test_delete_content_embedding_with_schema():
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_get_embedding_stats_with_schema():
    """Test embedding statistics with proper schema handling."""
    with patch("backend.data.db.get_database_schema") as mock_schema:
        mock_schema.return_value = "platform"
    """Test embedding statistics with proper schema handling via content handlers."""
    # Mock handler to return stats
    mock_handler = MagicMock()
    mock_handler.get_stats = AsyncMock(
        return_value={
            "total": 100,
            "with_embeddings": 80,
            "without_embeddings": 20,
        }
    )

    with patch("prisma.get_client") as mock_get_client:
        mock_client = AsyncMock()
        # Mock both query results
        mock_client.query_raw.side_effect = [
            [{"count": 100}],  # total_approved
            [{"count": 80}],  # with_embeddings
        ]
        mock_get_client.return_value = mock_client
    with patch(
        "backend.api.features.store.embeddings.CONTENT_HANDLERS",
        {ContentType.STORE_AGENT: mock_handler},
    ):
        result = await embeddings.get_embedding_stats()

        result = await embeddings.get_embedding_stats()
    # Verify handler was called
    mock_handler.get_stats.assert_called_once()

        # Verify both queries were called
        assert mock_client.query_raw.call_count == 2

        # Get both SQL queries
        first_call = mock_client.query_raw.call_args_list[0]
        second_call = mock_client.query_raw.call_args_list[1]

        first_sql = first_call[0][0]
        second_sql = second_call[0][0]

        # Verify schema prefix in both queries
        assert '"platform"."StoreListingVersion"' in first_sql
        assert '"platform"."StoreListingVersion"' in second_sql
        assert '"platform"."UnifiedContentEmbedding"' in second_sql

        # Verify results
        assert result["total_approved"] == 100
        assert result["with_embeddings"] == 80
        assert result["without_embeddings"] == 20
        assert result["coverage_percent"] == 80.0
    # Verify new result structure
    assert "by_type" in result
    assert "totals" in result
    assert result["totals"]["total"] == 100
    assert result["totals"]["with_embeddings"] == 80
    assert result["totals"]["without_embeddings"] == 20
    assert result["totals"]["coverage_percent"] == 80.0
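
    # Sketch of the assumed result shape (field names taken from the asserts
    # above; the exact by_type payload keys are an assumption):
    # {
    #     "by_type": {"STORE_AGENT": {"total": 100, "with_embeddings": 80, ...}},
    #     "totals": {"total": 100, "with_embeddings": 80,
    #                "without_embeddings": 20, "coverage_percent": 80.0},
    # }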


@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_backfill_missing_embeddings_with_schema():
    """Test backfilling embeddings with proper schema handling."""
    with patch("backend.data.db.get_database_schema") as mock_schema:
        mock_schema.return_value = "platform"
    """Test backfilling embeddings via content handlers."""
    from backend.api.features.store.content_handlers import ContentItem

    with patch("prisma.get_client") as mock_get_client:
        mock_client = AsyncMock()
        # Mock missing embeddings query
        mock_client.query_raw.return_value = [
            {
                "id": "version-1",
                "name": "Test Agent",
                "description": "Test description",
                "subHeading": "Test heading",
                "categories": ["test"],
            }
        ]
        mock_get_client.return_value = mock_client
    # Create mock content item
    mock_item = ContentItem(
        content_id="version-1",
        content_type=ContentType.STORE_AGENT,
        searchable_text="Test Agent Test description",
        metadata={"name": "Test Agent"},
    )

    # Mock handler
    mock_handler = MagicMock()
    mock_handler.get_missing_items = AsyncMock(return_value=[mock_item])

    with patch(
        "backend.api.features.store.embeddings.CONTENT_HANDLERS",
        {ContentType.STORE_AGENT: mock_handler},
    ):
        with patch(
            "backend.api.features.store.embeddings.generate_embedding",
            return_value=[0.1] * EMBEDDING_DIM,
        ):
            with patch(
                "backend.api.features.store.embeddings.ensure_embedding"
            ) as mock_ensure:
                mock_ensure.return_value = True

                "backend.api.features.store.embeddings.store_content_embedding",
                return_value=True,
            ):
                result = await embeddings.backfill_missing_embeddings(batch_size=10)

        # Verify the query was called
        assert mock_client.query_raw.called

        # Get the SQL query
        call_args = mock_client.query_raw.call_args
        sql_query = call_args[0][0]

        # Verify schema prefix in query
        assert '"platform"."StoreListingVersion"' in sql_query
        assert '"platform"."UnifiedContentEmbedding"' in sql_query

        # Verify ensure_embedding was called
        assert mock_ensure.called
    # Verify handler was called
    mock_handler.get_missing_items.assert_called_once_with(10)

    # Verify results
    assert result["processed"] == 1
@@ -226,7 +212,7 @@ async def test_ensure_content_embedding_with_schema():
    with patch(
        "backend.api.features.store.embeddings.generate_embedding"
    ) as mock_generate:
        mock_generate.return_value = [0.1] * 1536
        mock_generate.return_value = [0.1] * EMBEDDING_DIM

        with patch(
            "backend.api.features.store.embeddings.store_content_embedding"
@@ -260,7 +246,7 @@ async def test_backward_compatibility_store_embedding():

    result = await embeddings.store_embedding(
        version_id="test-version-id",
        embedding=[0.1] * 1536,
        embedding=[0.1] * EMBEDDING_DIM,
        tx=None,
    )

@@ -315,7 +301,7 @@ async def test_schema_handling_error_cases():
    result = await embeddings.store_content_embedding(
        content_type=ContentType.STORE_AGENT,
        content_id="test-id",
        embedding=[0.1] * 1536,
        embedding=[0.1] * EMBEDDING_DIM,
        searchable_text="test",
        metadata=None,
        user_id=None,

@@ -63,7 +63,7 @@ async def test_generate_embedding_success():
    result = await embeddings.generate_embedding("test text")

    assert result is not None
    assert len(result) == 1536
    assert len(result) == embeddings.EMBEDDING_DIM
    assert result[0] == 0.1

    mock_client.embeddings.create.assert_called_once_with(
@@ -110,7 +110,7 @@ async def test_generate_embedding_text_truncation():
    mock_client = MagicMock()
    mock_response = MagicMock()
    mock_response.data = [MagicMock()]
    mock_response.data[0].embedding = [0.1] * 1536
    mock_response.data[0].embedding = [0.1] * embeddings.EMBEDDING_DIM

    # Use AsyncMock for async embeddings.create method
    mock_client.embeddings.create = AsyncMock(return_value=mock_response)
@@ -297,72 +297,92 @@ async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_stats():
    """Test embedding statistics retrieval."""
    # Mock approved count query and embedded count query
    mock_approved_result = [{"count": 100}]
    mock_embedded_result = [{"count": 75}]
    # Mock handler stats for each content type
    mock_handler = MagicMock()
    mock_handler.get_stats = AsyncMock(
        return_value={
            "total": 100,
            "with_embeddings": 75,
            "without_embeddings": 25,
        }
    )

    # Patch the CONTENT_HANDLERS where it's used (in embeddings module)
    with patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        side_effect=[mock_approved_result, mock_embedded_result],
        "backend.api.features.store.embeddings.CONTENT_HANDLERS",
        {ContentType.STORE_AGENT: mock_handler},
    ):
        result = await embeddings.get_embedding_stats()

    assert result["total_approved"] == 100
    assert result["with_embeddings"] == 75
    assert result["without_embeddings"] == 25
    assert result["coverage_percent"] == 75.0
    assert "by_type" in result
    assert "totals" in result
    assert result["totals"]["total"] == 100
    assert result["totals"]["with_embeddings"] == 75
    assert result["totals"]["without_embeddings"] == 25
    assert result["totals"]["coverage_percent"] == 75.0


@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.ensure_embedding")
async def test_backfill_missing_embeddings_success(mock_ensure):
@patch("backend.api.features.store.embeddings.store_content_embedding")
async def test_backfill_missing_embeddings_success(mock_store):
    """Test backfill with successful embedding generation."""
    # Mock missing embeddings query
    mock_missing = [
        {
            "id": "version-1",
            "name": "Agent 1",
            "description": "Description 1",
            "subHeading": "Heading 1",
            "categories": ["AI"],
        },
        {
            "id": "version-2",
            "name": "Agent 2",
            "description": "Description 2",
            "subHeading": "Heading 2",
            "categories": ["Productivity"],
        },
    # Mock ContentItem from handlers
    from backend.api.features.store.content_handlers import ContentItem

    mock_items = [
        ContentItem(
            content_id="version-1",
            content_type=ContentType.STORE_AGENT,
            searchable_text="Agent 1 Description 1",
            metadata={"name": "Agent 1"},
        ),
        ContentItem(
            content_id="version-2",
            content_type=ContentType.STORE_AGENT,
            searchable_text="Agent 2 Description 2",
            metadata={"name": "Agent 2"},
        ),
    ]

    # Mock ensure_embedding to succeed for first, fail for second
    mock_ensure.side_effect = [True, False]
    # Mock handler to return missing items
    mock_handler = MagicMock()
    mock_handler.get_missing_items = AsyncMock(return_value=mock_items)

    # Mock store_content_embedding to succeed for first, fail for second
    mock_store.side_effect = [True, False]

    with patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        return_value=mock_missing,
        "backend.api.features.store.embeddings.CONTENT_HANDLERS",
        {ContentType.STORE_AGENT: mock_handler},
    ):
        result = await embeddings.backfill_missing_embeddings(batch_size=5)
        with patch(
            "backend.api.features.store.embeddings.generate_embedding",
            return_value=[0.1] * embeddings.EMBEDDING_DIM,
        ):
            result = await embeddings.backfill_missing_embeddings(batch_size=5)

    assert result["processed"] == 2
    assert result["success"] == 1
    assert result["failed"] == 1
    assert mock_ensure.call_count == 2
    assert result["processed"] == 2
    assert result["success"] == 1
    assert result["failed"] == 1
    assert mock_store.call_count == 2


@pytest.mark.asyncio(loop_scope="session")
async def test_backfill_missing_embeddings_no_missing():
    """Test backfill when no embeddings are missing."""
    # Mock handler to return no missing items
    mock_handler = MagicMock()
    mock_handler.get_missing_items = AsyncMock(return_value=[])

    with patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        return_value=[],
        "backend.api.features.store.embeddings.CONTENT_HANDLERS",
        {ContentType.STORE_AGENT: mock_handler},
    ):
        result = await embeddings.backfill_missing_embeddings(batch_size=5)

    assert result["processed"] == 0
    assert result["success"] == 0
    assert result["failed"] == 0
    assert result["message"] == "No missing embeddings"


@pytest.mark.asyncio(loop_scope="session")

@@ -1,16 +1,21 @@
"""
Hybrid Search for Store Agents
Unified Hybrid Search

Combines semantic (embedding) search with lexical (tsvector) search
for improved relevance in marketplace agent discovery.
for improved relevance across all content types (agents, blocks, docs).
Includes BM25 reranking for improved lexical relevance.
"""

import logging
import re
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Literal

from prisma.enums import ContentType
from rank_bm25 import BM25Okapi

from backend.api.features.store.embeddings import (
    EMBEDDING_DIM,
    embed_query,
    embedding_to_vector_string,
)
@@ -19,18 +24,385 @@ from backend.data.db import query_raw_with_schema
logger = logging.getLogger(__name__)


@dataclass
class HybridSearchWeights:
    """Weights for combining search signals."""
# ============================================================================
# BM25 Reranking
# ============================================================================

    semantic: float = 0.30  # Embedding cosine similarity
    lexical: float = 0.30  # tsvector ts_rank_cd score
    category: float = 0.20  # Category match boost
    recency: float = 0.10  # Newer agents ranked higher
    popularity: float = 0.10  # Agent usage/runs (PageRank-like)

def tokenize(text: str) -> list[str]:
    """Simple tokenizer for BM25 - lowercase and split on non-alphanumeric."""
    if not text:
        return []
    # Lowercase and split on non-alphanumeric characters
    tokens = re.findall(r"\b\w+\b", text.lower())
    return tokens
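

# Illustrative behaviour (a sketch, not part of the diff):
#   tokenize("Run-Block v2!") -> ["run", "block", "v2"]
#   tokenize("")              -> []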


def bm25_rerank(
    query: str,
    results: list[dict[str, Any]],
    text_field: str = "searchable_text",
    bm25_weight: float = 0.3,
    original_score_field: str = "combined_score",
) -> list[dict[str, Any]]:
    """
    Rerank search results using BM25.

    Combines the original combined_score with BM25 score for improved
    lexical relevance, especially for exact term matches.

    Args:
        query: The search query
        results: List of result dicts with text_field and original_score_field
        text_field: Field name containing the text to score
        bm25_weight: Weight for BM25 score (0-1). Original score gets (1 - bm25_weight)
        original_score_field: Field name containing the original score

    Returns:
        Results list sorted by combined score (BM25 + original)
    """
    if not results or not query:
        return results

    # Extract texts and tokenize
    corpus = [tokenize(r.get(text_field, "") or "") for r in results]

    # Handle edge case where all documents are empty
    if all(len(doc) == 0 for doc in corpus):
        return results

    # Build BM25 index
    bm25 = BM25Okapi(corpus)

    # Score query against corpus
    query_tokens = tokenize(query)
    if not query_tokens:
        return results

    bm25_scores = bm25.get_scores(query_tokens)

    # Normalize BM25 scores to 0-1 range
    max_bm25 = max(bm25_scores) if max(bm25_scores) > 0 else 1.0
    normalized_bm25 = [s / max_bm25 for s in bm25_scores]

    # Combine scores
    original_weight = 1.0 - bm25_weight
    for i, result in enumerate(results):
        original_score = result.get(original_score_field, 0) or 0
        result["bm25_score"] = normalized_bm25[i]
        final_score = (
            original_weight * original_score + bm25_weight * normalized_bm25[i]
        )
        result["final_score"] = final_score
        result["relevance"] = final_score

    # Sort by relevance descending
    results.sort(key=lambda x: x.get("relevance", 0), reverse=True)

    return results
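

# Minimal usage sketch for bm25_rerank (toy data; the scores are invented for
# illustration, not taken from the real pipeline):
#   hits = [
#       {"searchable_text": "send slack message block", "combined_score": 0.42},
#       {"searchable_text": "read rss feed block", "combined_score": 0.45},
#   ]
#   reranked = bm25_rerank(query="slack message", results=hits)
#   # The Slack hit gains a high bm25_score and can overtake the RSS hit on
#   # "relevance" despite its lower original combined_score.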


@dataclass
class UnifiedSearchWeights:
    """Weights for unified search (no popularity signal)."""

    semantic: float = 0.40  # Embedding cosine similarity
    lexical: float = 0.40  # tsvector ts_rank_cd score
    category: float = 0.10  # Category match boost (for types that have categories)
    recency: float = 0.10  # Newer content ranked higher

    def __post_init__(self):
        """Validate weights are non-negative and sum to approximately 1.0."""
        total = self.semantic + self.lexical + self.category + self.recency

        if any(
            w < 0 for w in [self.semantic, self.lexical, self.category, self.recency]
        ):
            raise ValueError("All weights must be non-negative")

        if not (0.99 <= total <= 1.01):
            raise ValueError(f"Weights must sum to ~1.0, got {total:.3f}")


# Default weights for unified search
DEFAULT_UNIFIED_WEIGHTS = UnifiedSearchWeights()

# Minimum relevance score thresholds
DEFAULT_MIN_SCORE = 0.15  # For unified search (more permissive)
DEFAULT_STORE_AGENT_MIN_SCORE = 0.20  # For store agent search (original threshold)
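
# Customisation sketch (illustrative): weights must be non-negative and sum
# to ~1.0, so a lexical-heavy profile could be
#   UnifiedSearchWeights(semantic=0.2, lexical=0.6, category=0.1, recency=0.1)
# while UnifiedSearchWeights(semantic=0.5, lexical=0.5, category=0.2, recency=0.0)
# raises ValueError("Weights must sum to ~1.0, got 1.200").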


async def unified_hybrid_search(
    query: str,
    content_types: list[ContentType] | None = None,
    category: str | None = None,
    page: int = 1,
    page_size: int = 20,
    weights: UnifiedSearchWeights | None = None,
    min_score: float | None = None,
    user_id: str | None = None,
) -> tuple[list[dict[str, Any]], int]:
    """
    Unified hybrid search across all content types.

    Searches UnifiedContentEmbedding using both semantic (vector) and lexical (tsvector) signals.

    Args:
        query: Search query string
        content_types: List of content types to search. Defaults to all public types.
        category: Filter by category (for content types that support it)
        page: Page number (1-indexed)
        page_size: Results per page
        weights: Custom weights for search signals
        min_score: Minimum relevance score threshold (0-1)
        user_id: User ID for searching private content (library agents)

    Returns:
        Tuple of (results list, total count)
    """
    # Validate inputs
    query = query.strip()
    if not query:
        return [], 0

    if page < 1:
        page = 1
    if page_size < 1:
        page_size = 1
    if page_size > 100:
        page_size = 100

    if content_types is None:
        content_types = [
            ContentType.STORE_AGENT,
            ContentType.BLOCK,
            ContentType.DOCUMENTATION,
        ]

    if weights is None:
        weights = DEFAULT_UNIFIED_WEIGHTS
    if min_score is None:
        min_score = DEFAULT_MIN_SCORE

    offset = (page - 1) * page_size

    # Generate query embedding
    query_embedding = await embed_query(query)

    # Graceful degradation if embedding unavailable
    if query_embedding is None or not query_embedding:
        logger.warning(
            "Failed to generate query embedding - falling back to lexical-only search. "
            "Check that openai_internal_api_key is configured and OpenAI API is accessible."
        )
        query_embedding = [0.0] * EMBEDDING_DIM
        # Redistribute the semantic weight proportionally across the remaining signals
        total_non_semantic = weights.lexical + weights.category + weights.recency
        if total_non_semantic > 0:
            factor = 1.0 / total_non_semantic
            weights = UnifiedSearchWeights(
                semantic=0.0,
                lexical=weights.lexical * factor,
                category=weights.category * factor,
                recency=weights.recency * factor,
            )
        else:
            weights = UnifiedSearchWeights(
                semantic=0.0, lexical=1.0, category=0.0, recency=0.0
            )
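
        # Worked example (illustrative): with the defaults (0.40/0.40/0.10/0.10),
        # total_non_semantic = 0.60 and factor ~= 1.667, so the fallback weights
        # become lexical ~= 0.667, category ~= 0.167, recency ~= 0.167 - still
        # summing to 1.0, which keeps __post_init__ validation satisfied.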

    # Build parameters
    params: list[Any] = []
    param_idx = 1

    # Query for lexical search
    params.append(query)
    query_param = f"${param_idx}"
    param_idx += 1

    # Query lowercase for category matching
    params.append(query.lower())
    query_lower_param = f"${param_idx}"
    param_idx += 1

    # Embedding
    embedding_str = embedding_to_vector_string(query_embedding)
    params.append(embedding_str)
    embedding_param = f"${param_idx}"
    param_idx += 1

    # Content types
    content_type_values = [ct.value for ct in content_types]
    params.append(content_type_values)
    content_types_param = f"${param_idx}"
    param_idx += 1

    # User ID filter (for private content)
    user_filter = ""
    if user_id is not None:
        params.append(user_id)
        user_filter = f'AND (uce."userId" = ${param_idx} OR uce."userId" IS NULL)'
        param_idx += 1
    else:
        user_filter = 'AND uce."userId" IS NULL'

    # Weights
    params.append(weights.semantic)
    w_semantic = f"${param_idx}"
    param_idx += 1

    params.append(weights.lexical)
    w_lexical = f"${param_idx}"
    param_idx += 1

    params.append(weights.category)
    w_category = f"${param_idx}"
    param_idx += 1

    params.append(weights.recency)
    w_recency = f"${param_idx}"
    param_idx += 1

    # Min score
    params.append(min_score)
    min_score_param = f"${param_idx}"
    param_idx += 1

    # Pagination
    params.append(page_size)
    limit_param = f"${param_idx}"
    param_idx += 1

    params.append(offset)
    offset_param = f"${param_idx}"
    param_idx += 1

    # Unified search query on UnifiedContentEmbedding
    sql_query = f"""
    WITH candidates AS (
        -- Lexical matches (uses GIN index on search column)
        SELECT uce.id, uce."contentType", uce."contentId"
        FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
        WHERE uce."contentType" = ANY({content_types_param}::{{schema_prefix}}"ContentType"[])
            {user_filter}
            AND uce.search @@ plainto_tsquery('english', {query_param})

        UNION

        -- Semantic matches (uses HNSW index on embedding)
        (
            SELECT uce.id, uce."contentType", uce."contentId"
            FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
            WHERE uce."contentType" = ANY({content_types_param}::{{schema_prefix}}"ContentType"[])
                {user_filter}
            ORDER BY uce.embedding <=> {embedding_param}::vector
            LIMIT 200
        )
    ),
    search_scores AS (
        SELECT
            uce."contentType" as content_type,
            uce."contentId" as content_id,
            uce."searchableText" as searchable_text,
            uce.metadata,
            uce."updatedAt" as updated_at,
            -- Semantic score: cosine similarity (1 - distance)
            COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
            -- Lexical score: ts_rank_cd
            COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
            -- Category match from metadata
            CASE
                WHEN uce.metadata ? 'categories' AND EXISTS (
                    SELECT 1 FROM jsonb_array_elements_text(uce.metadata->'categories') cat
                    WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%'
                )
                THEN 1.0
                ELSE 0.0
            END as category_score,
            -- Recency score: linear decay over 90 days
            GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - uce."updatedAt")) / (90 * 24 * 3600)) as recency_score
        FROM candidates c
        INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce ON c.id = uce.id
    ),
    max_lexical AS (
        SELECT GREATEST(MAX(lexical_raw), 0.001) as max_val FROM search_scores
    ),
    normalized AS (
        SELECT
            ss.*,
            ss.lexical_raw / ml.max_val as lexical_score
        FROM search_scores ss
        CROSS JOIN max_lexical ml
    ),
    scored AS (
        SELECT
            content_type,
            content_id,
            searchable_text,
            metadata,
            updated_at,
            semantic_score,
            lexical_score,
            category_score,
            recency_score,
            (
                {w_semantic} * semantic_score +
                {w_lexical} * lexical_score +
                {w_category} * category_score +
                {w_recency} * recency_score
            ) as combined_score
        FROM normalized
    ),
    filtered AS (
        SELECT *, COUNT(*) OVER () as total_count
        FROM scored
        WHERE combined_score >= {min_score_param}
    )
    SELECT * FROM filtered
    ORDER BY combined_score DESC
    LIMIT {limit_param} OFFSET {offset_param}
    """

    results = await query_raw_with_schema(
        sql_query, *params, set_public_search_path=True
    )

    total = results[0]["total_count"] if results else 0
    # Apply BM25 reranking
    if results:
        results = bm25_rerank(
            query=query,
            results=results,
            text_field="searchable_text",
            bm25_weight=0.3,
            original_score_field="combined_score",
        )

    # Clean up results
    for result in results:
        result.pop("total_count", None)

    logger.info(f"Unified hybrid search: {len(results)} results, {total} total")

    return results, total
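

# Usage sketch (illustrative; identifiers as defined above):
#   results, total = await unified_hybrid_search(
#       query="summarize rss feeds",
#       content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION],
#       page=1,
#       page_size=10,
#   )
#   # Each result dict carries content_type, content_id, metadata, and the
#   # score breakdown (semantic/lexical/category/recency plus bm25/relevance).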


# ============================================================================
# Store Agent specific search (with full metadata)
# ============================================================================


@dataclass
class StoreAgentSearchWeights:
    """Weights for store agent search including popularity."""

    semantic: float = 0.30
    lexical: float = 0.30
    category: float = 0.20
    recency: float = 0.10
    popularity: float = 0.10

    def __post_init__(self):
        total = (
            self.semantic
            + self.lexical
@@ -38,7 +410,6 @@ class HybridSearchWeights:
            + self.recency
            + self.popularity
        )

        if any(
            w < 0
            for w in [
@@ -50,46 +421,11 @@ class HybridSearchWeights:
            ]
        ):
            raise ValueError("All weights must be non-negative")

        if not (0.99 <= total <= 1.01):
            raise ValueError(f"Weights must sum to ~1.0, got {total:.3f}")


DEFAULT_WEIGHTS = HybridSearchWeights()

# Minimum relevance score threshold - agents below this are filtered out
# With weights (0.30 semantic + 0.30 lexical + 0.20 category + 0.10 recency + 0.10 popularity):
# - 0.20 means at least ~60% semantic match OR strong lexical match required
# - Ensures only genuinely relevant results are returned
# - Recency/popularity alone (0.10 each) won't pass the threshold
DEFAULT_MIN_SCORE = 0.20


@dataclass
class HybridSearchResult:
    """A single search result with score breakdown."""

    slug: str
    agent_name: str
    agent_image: str
    creator_username: str
    creator_avatar: str
    sub_heading: str
    description: str
    runs: int
    rating: float
    categories: list[str]
    featured: bool
    is_available: bool
    updated_at: datetime

    # Score breakdown (for debugging/tuning)
    combined_score: float
    semantic_score: float = 0.0
    lexical_score: float = 0.0
    category_score: float = 0.0
    recency_score: float = 0.0
    popularity_score: float = 0.0
DEFAULT_STORE_AGENT_WEIGHTS = StoreAgentSearchWeights()


async def hybrid_search(
@@ -102,276 +438,277 @@ async def hybrid_search(
    ) = None,
    page: int = 1,
    page_size: int = 20,
    weights: HybridSearchWeights | None = None,
    weights: StoreAgentSearchWeights | None = None,
    min_score: float | None = None,
) -> tuple[list[dict[str, Any]], int]:
    """
    Perform hybrid search combining semantic and lexical signals.
    Hybrid search for store agents with full metadata.

    Args:
        query: Search query string
        featured: Filter for featured agents only
        creators: Filter by creator usernames
        category: Filter by category
        sorted_by: Sort order (relevance uses hybrid scoring)
        page: Page number (1-indexed)
        page_size: Results per page
        weights: Custom weights for search signals
        min_score: Minimum relevance score threshold (0-1). Results below
            this score are filtered out. Defaults to DEFAULT_MIN_SCORE.

    Returns:
        Tuple of (results list, total count). Returns empty list if no
        results meet the minimum relevance threshold.
    Uses UnifiedContentEmbedding for search, joins to StoreAgent for metadata.
    """
    # Validate inputs
    query = query.strip()
    if not query:
        return [], 0  # Empty query returns no results
        return [], 0

    if page < 1:
        page = 1
    if page_size < 1:
        page_size = 1
    if page_size > 100:  # Cap at reasonable limit to prevent performance issues
    if page_size > 100:
        page_size = 100

    if weights is None:
        weights = DEFAULT_WEIGHTS
        weights = DEFAULT_STORE_AGENT_WEIGHTS
    if min_score is None:
        min_score = DEFAULT_MIN_SCORE
        min_score = (
            DEFAULT_STORE_AGENT_MIN_SCORE  # Use original threshold for store agents
        )

    offset = (page - 1) * page_size

    # Generate query embedding
    query_embedding = await embed_query(query)

    # Build WHERE clause conditions
    where_parts: list[str] = ["sa.is_available = true"]
    # Graceful degradation
    if query_embedding is None or not query_embedding:
        logger.warning(
            "Failed to generate query embedding - falling back to lexical-only search."
        )
        query_embedding = [0.0] * EMBEDDING_DIM
        total_non_semantic = (
            weights.lexical + weights.category + weights.recency + weights.popularity
        )
        if total_non_semantic > 0:
            factor = 1.0 / total_non_semantic
            weights = StoreAgentSearchWeights(
                semantic=0.0,
                lexical=weights.lexical * factor,
                category=weights.category * factor,
                recency=weights.recency * factor,
                popularity=weights.popularity * factor,
            )
        else:
            weights = StoreAgentSearchWeights(
                semantic=0.0, lexical=1.0, category=0.0, recency=0.0, popularity=0.0
            )

    # Build parameters
    params: list[Any] = []
    param_index = 1
    param_idx = 1

    # Add search query for lexical matching
    params.append(query)
    query_param = f"${param_index}"
    param_index += 1
    query_param = f"${param_idx}"
    param_idx += 1

    # Add lowercased query for category matching
    params.append(query.lower())
    query_lower_param = f"${param_index}"
    param_index += 1
    query_lower_param = f"${param_idx}"
    param_idx += 1

    embedding_str = embedding_to_vector_string(query_embedding)
    params.append(embedding_str)
    embedding_param = f"${param_idx}"
    param_idx += 1

    # Build WHERE clause for StoreAgent filters
    where_parts = ["sa.is_available = true"]

    if featured:
        where_parts.append("sa.featured = true")

    if creators:
        where_parts.append(f"sa.creator_username = ANY(${param_index})")
        params.append(creators)
        param_index += 1
        where_parts.append(f"sa.creator_username = ANY(${param_idx})")
        param_idx += 1

    if category:
        where_parts.append(f"${param_index} = ANY(sa.categories)")
        params.append(category)
        param_index += 1
        where_parts.append(f"${param_idx} = ANY(sa.categories)")
        param_idx += 1

    # Safe: where_parts only contains hardcoded strings with $N parameter placeholders
    # No user input is concatenated directly into the SQL string
    where_clause = " AND ".join(where_parts)

    # Embedding is required for hybrid search - fail fast if unavailable
    if query_embedding is None or not query_embedding:
        # Log detailed error server-side
        logger.error(
            "Failed to generate query embedding. "
            "Check that openai_internal_api_key is configured and OpenAI API is accessible."
        )
        # Raise generic error to client
        raise ValueError("Search service temporarily unavailable")

    # Add embedding parameter
    embedding_str = embedding_to_vector_string(query_embedding)
    params.append(embedding_str)
    embedding_param = f"${param_index}"
    param_index += 1

    # Add weight parameters for SQL calculation
    # Weights
    params.append(weights.semantic)
    weight_semantic_param = f"${param_index}"
    param_index += 1
    w_semantic = f"${param_idx}"
    param_idx += 1

    params.append(weights.lexical)
    weight_lexical_param = f"${param_index}"
    param_index += 1
    w_lexical = f"${param_idx}"
    param_idx += 1

    params.append(weights.category)
    weight_category_param = f"${param_index}"
    param_index += 1
    w_category = f"${param_idx}"
    param_idx += 1

    params.append(weights.recency)
    weight_recency_param = f"${param_index}"
    param_index += 1
    w_recency = f"${param_idx}"
    param_idx += 1

    params.append(weights.popularity)
    weight_popularity_param = f"${param_index}"
    param_index += 1
    w_popularity = f"${param_idx}"
    param_idx += 1

    # Add min_score parameter
    params.append(min_score)
    min_score_param = f"${param_index}"
    param_index += 1
    min_score_param = f"${param_idx}"
    param_idx += 1

    # Optimized hybrid search query:
    # 1. Direct join to UnifiedContentEmbedding via contentId=storeListingVersionId (no redundant JOINs)
    # 2. UNION approach (deduplicates agents matching both branches)
    # 3. COUNT(*) OVER() to get total count in single query
    # 4. Optimized category matching with EXISTS + unnest
    # 5. Pre-calculated max values for lexical and popularity normalization
    # 6. Simplified recency calculation with linear decay
    # 7. Logarithmic popularity scaling to prevent viral agents from dominating
    params.append(page_size)
    limit_param = f"${param_idx}"
    param_idx += 1

    params.append(offset)
    offset_param = f"${param_idx}"
    param_idx += 1

    # Query using UnifiedContentEmbedding for search, StoreAgent for metadata
    sql_query = f"""
    WITH candidates AS (
        -- Lexical matches (uses GIN index on search column)
        SELECT sa."storeListingVersionId"
        FROM {{schema_prefix}}"StoreAgent" sa
        WHERE {where_clause}
            AND sa.search @@ plainto_tsquery('english', {query_param})
    WITH candidates AS (
        -- Lexical matches via UnifiedContentEmbedding.search
        SELECT uce."contentId" as "storeListingVersionId"
        FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
        INNER JOIN {{schema_prefix}}"StoreAgent" sa
            ON uce."contentId" = sa."storeListingVersionId"
        WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
            AND uce."userId" IS NULL
            AND uce.search @@ plainto_tsquery('english', {query_param})
            AND {where_clause}

        UNION
        UNION

        -- Semantic matches (uses HNSW index on embedding with KNN)
        SELECT "storeListingVersionId"
        FROM (
            SELECT sa."storeListingVersionId", uce.embedding
            FROM {{schema_prefix}}"StoreAgent" sa
            INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
                ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
            WHERE {where_clause}
            ORDER BY uce.embedding <=> {embedding_param}::vector
            LIMIT 200
        ) semantic_results
    ),
    search_scores AS (
        SELECT
            sa.slug,
            sa.agent_name,
            sa.agent_image,
            sa.creator_username,
            sa.creator_avatar,
            sa.sub_heading,
            sa.description,
            sa.runs,
            sa.rating,
            sa.categories,
            sa.featured,
            sa.is_available,
            sa.updated_at,
            -- Semantic score: cosine similarity (1 - distance)
            COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
            -- Lexical score: ts_rank_cd (will be normalized later)
            COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
            -- Category match: optimized with unnest for better performance
            CASE
                WHEN EXISTS (
                    SELECT 1 FROM unnest(sa.categories) cat
                    WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%'
                )
                THEN 1.0
                ELSE 0.0
            END as category_score,
            -- Recency score: linear decay over 90 days (simpler than exponential)
            GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score,
            -- Popularity raw: agent runs count (will be normalized with log scaling)
            sa.runs as popularity_raw
        FROM candidates c
        -- Semantic matches via UnifiedContentEmbedding.embedding
        SELECT uce."contentId" as "storeListingVersionId"
        FROM (
            SELECT uce."contentId", uce.embedding
            FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
        INNER JOIN {{schema_prefix}}"StoreAgent" sa
            ON c."storeListingVersionId" = sa."storeListingVersionId"
        LEFT JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
            ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
    ),
    max_lexical AS (
        SELECT MAX(lexical_raw) as max_val FROM search_scores
    ),
    max_popularity AS (
        SELECT MAX(popularity_raw) as max_val FROM search_scores
    ),
    normalized AS (
        SELECT
            ss.*,
            -- Normalize lexical score by pre-calculated max
            CASE
                WHEN ml.max_val > 0
                THEN ss.lexical_raw / ml.max_val
                ELSE 0
            END as lexical_score,
            -- Normalize popularity with logarithmic scaling to prevent viral agents from dominating
            -- LOG(1 + runs) / LOG(1 + max_runs) ensures score is 0-1 range
            CASE
                WHEN mp.max_val > 0 AND ss.popularity_raw > 0
                THEN LN(1 + ss.popularity_raw) / LN(1 + mp.max_val)
                ELSE 0
            END as popularity_score
        FROM search_scores ss
        CROSS JOIN max_lexical ml
        CROSS JOIN max_popularity mp
    ),
    scored AS (
        SELECT
            slug,
            agent_name,
            agent_image,
            creator_username,
            creator_avatar,
            sub_heading,
            description,
            runs,
            rating,
            categories,
            featured,
            is_available,
            updated_at,
            semantic_score,
            lexical_score,
            category_score,
            recency_score,
            popularity_score,
            (
                {weight_semantic_param} * semantic_score +
                {weight_lexical_param} * lexical_score +
                {weight_category_param} * category_score +
                {weight_recency_param} * recency_score +
                {weight_popularity_param} * popularity_score
            ) as combined_score
        FROM normalized
    ),
    filtered AS (
        SELECT
            *,
            COUNT(*) OVER () as total_count
        FROM scored
        WHERE combined_score >= {min_score_param}
    )
    SELECT * FROM filtered
    ORDER BY combined_score DESC
    LIMIT ${param_index} OFFSET ${param_index + 1}
            ON uce."contentId" = sa."storeListingVersionId"
            WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
                AND uce."userId" IS NULL
                AND {where_clause}
            ORDER BY uce.embedding <=> {embedding_param}::vector
            LIMIT 200
        ) uce
    ),
    search_scores AS (
        SELECT
            sa.slug,
            sa.agent_name,
            sa.agent_image,
            sa.creator_username,
            sa.creator_avatar,
            sa.sub_heading,
            sa.description,
            sa.runs,
            sa.rating,
            sa.categories,
            sa.featured,
            sa.is_available,
            sa.updated_at,
            -- Searchable text for BM25 reranking
            COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
            -- Semantic score
            COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
            -- Lexical score (raw, will normalize)
            COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
            -- Category match
            CASE
                WHEN EXISTS (
                    SELECT 1 FROM unnest(sa.categories) cat
                    WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%'
                )
                THEN 1.0
                ELSE 0.0
            END as category_score,
            -- Recency
            GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score,
            -- Popularity (raw)
            sa.runs as popularity_raw
        FROM candidates c
        INNER JOIN {{schema_prefix}}"StoreAgent" sa
            ON c."storeListingVersionId" = sa."storeListingVersionId"
        INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
            ON sa."storeListingVersionId" = uce."contentId"
            AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
    ),
    max_vals AS (
        SELECT
            GREATEST(MAX(lexical_raw), 0.001) as max_lexical,
            GREATEST(MAX(popularity_raw), 1) as max_popularity
        FROM search_scores
    ),
    normalized AS (
        SELECT
            ss.*,
            ss.lexical_raw / mv.max_lexical as lexical_score,
            CASE
                WHEN ss.popularity_raw > 0
                THEN LN(1 + ss.popularity_raw) / LN(1 + mv.max_popularity)
                ELSE 0
            END as popularity_score
        FROM search_scores ss
        CROSS JOIN max_vals mv
    ),
    scored AS (
        SELECT
            slug,
            agent_name,
            agent_image,
            creator_username,
            creator_avatar,
            sub_heading,
            description,
            runs,
            rating,
            categories,
            featured,
            is_available,
            updated_at,
            searchable_text,
            semantic_score,
            lexical_score,
            category_score,
            recency_score,
            popularity_score,
            (
                {w_semantic} * semantic_score +
                {w_lexical} * lexical_score +
                {w_category} * category_score +
                {w_recency} * recency_score +
                {w_popularity} * popularity_score
            ) as combined_score
        FROM normalized
    ),
    filtered AS (
        SELECT *, COUNT(*) OVER () as total_count
        FROM scored
        WHERE combined_score >= {min_score_param}
    )
    SELECT * FROM filtered
    ORDER BY combined_score DESC
    LIMIT {limit_param} OFFSET {offset_param}
    """

    # Add pagination params
    params.extend([page_size, offset])

    # Execute search query - includes total_count via window function
    results = await query_raw_with_schema(
        sql_query, *params, set_public_search_path=True
    )

    # Extract total count from first result (all rows have same count)
    total = results[0]["total_count"] if results else 0

    # Remove total_count from results before returning
    # Apply BM25 reranking
    if results:
        results = bm25_rerank(
            query=query,
            results=results,
            text_field="searchable_text",
            bm25_weight=0.3,
            original_score_field="combined_score",
        )

    for result in results:
        result.pop("total_count", None)
        result.pop("searchable_text", None)

    # Log without sensitive query content
    logger.info(f"Hybrid search: {len(results)} results, {total} total")
    logger.info(f"Hybrid search (store agents): {len(results)} results, {total} total")

    return results, total
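

# Usage sketch (illustrative):
#   results, total = await hybrid_search(
#       query="email automation",
#       featured=True,
#       category="productivity",
#       page=1,
#       page_size=10,
#   )
#   # Results are store-agent dicts (slug, agent_name, runs, rating, ...),
#   # already BM25-reranked, with total_count/searchable_text stripped.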

@@ -381,13 +718,10 @@ async def hybrid_search_simple(
    page: int = 1,
    page_size: int = 20,
) -> tuple[list[dict[str, Any]], int]:
    """
    Simplified hybrid search for common use cases.
    """Simplified hybrid search for store agents."""
    return await hybrid_search(query=query, page=page, page_size=page_size)

    Uses default weights and no filters.
    """
    return await hybrid_search(
        query=query,
        page=page,
        page_size=page_size,
    )

# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
# for existing code that expects the popularity parameter
HybridSearchWeights = StoreAgentSearchWeights
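
# Compatibility sketch (illustrative): existing callers keep working unchanged,
# e.g. HybridSearchWeights(semantic=0.3, lexical=0.3, category=0.2,
# recency=0.1, popularity=0.1) now constructs a StoreAgentSearchWeights.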
|
||||
|
||||
@@ -7,8 +7,15 @@ These tests verify that hybrid search works correctly across different database
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store.hybrid_search import HybridSearchWeights, hybrid_search
|
||||
from backend.api.features.store import embeddings
|
||||
from backend.api.features.store.hybrid_search import (
|
||||
HybridSearchWeights,
|
||||
UnifiedSearchWeights,
|
||||
hybrid_search,
|
||||
unified_hybrid_search,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@@ -49,7 +56,7 @@ async def test_hybrid_search_with_schema_handling():
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536 # Mock embedding
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM # Mock embedding
|
||||
|
||||
results, total = await hybrid_search(
|
||||
query=query,
|
||||
@@ -85,7 +92,7 @@ async def test_hybrid_search_with_public_schema():
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
@@ -116,7 +123,7 @@ async def test_hybrid_search_with_custom_schema():
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
@@ -134,22 +141,52 @@ async def test_hybrid_search_with_custom_schema():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_without_embeddings():
|
||||
"""Test hybrid search fails fast when embeddings are unavailable."""
|
||||
# Patch where the function is used, not where it's defined
|
||||
with patch("backend.api.features.store.hybrid_search.embed_query") as mock_embed:
|
||||
# Simulate embedding failure
|
||||
mock_embed.return_value = None
|
||||
"""Test hybrid search gracefully degrades when embeddings are unavailable."""
|
||||
# Mock database to return some results
|
||||
mock_results = [
|
||||
{
|
||||
"slug": "test-agent",
|
||||
"agent_name": "Test Agent",
|
||||
"agent_image": "test.png",
|
||||
"creator_username": "creator",
|
||||
"creator_avatar": "avatar.png",
|
||||
"sub_heading": "Test heading",
|
||||
"description": "Test description",
|
||||
"runs": 100,
|
||||
"rating": 4.5,
|
||||
"categories": ["AI"],
|
||||
"featured": False,
|
||||
"is_available": True,
|
||||
"updated_at": "2025-01-01T00:00:00Z",
|
||||
"semantic_score": 0.0, # Zero because no embedding
|
||||
"lexical_score": 0.5,
|
||||
"category_score": 0.0,
|
||||
"recency_score": 0.1,
|
||||
"popularity_score": 0.2,
|
||||
"combined_score": 0.3,
|
||||
"total_count": 1,
|
||||
}
|
||||
]
|
||||
|
||||
# Should raise ValueError with helpful message
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await hybrid_search(
|
||||
with patch("backend.api.features.store.hybrid_search.embed_query") as mock_embed:
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
# Simulate embedding failure
|
||||
mock_embed.return_value = None
|
||||
mock_query.return_value = mock_results
|
||||
|
||||
# Should NOT raise - graceful degradation
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify error message is generic (doesn't leak implementation details)
|
||||
assert "Search service temporarily unavailable" in str(exc_info.value)
|
||||
# Verify it returns results even without embeddings
|
||||
assert len(results) == 1
|
||||
assert results[0]["slug"] == "test-agent"
|
||||
assert total == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@@ -164,7 +201,7 @@ async def test_hybrid_search_with_filters():
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
# Test with featured filter
|
||||
results, total = await hybrid_search(
|
||||
@@ -204,7 +241,7 @@ async def test_hybrid_search_weights():
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
@@ -248,7 +285,7 @@ async def test_hybrid_search_min_score_filtering():
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
# Test with custom min_score
|
||||
results, total = await hybrid_search(
|
||||
@@ -274,16 +311,48 @@ async def test_hybrid_search_min_score_filtering():
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_pagination():
-    """Test hybrid search pagination."""
+    """Test hybrid search pagination.
+
+    Pagination happens in SQL (LIMIT/OFFSET), then BM25 reranking is applied
+    to the paginated results.
+    """
+    # Create mock results that SQL would return for a page
+    mock_results = [
+        {
+            "slug": f"agent-{i}",
+            "agent_name": f"Agent {i}",
+            "agent_image": "test.png",
+            "creator_username": "test",
+            "creator_avatar": "avatar.png",
+            "sub_heading": "Test",
+            "description": "Test description",
+            "runs": 100 - i,
+            "rating": 4.5,
+            "categories": ["test"],
+            "featured": False,
+            "is_available": True,
+            "updated_at": "2024-01-01T00:00:00Z",
+            "searchable_text": f"Agent {i} test description",
+            "combined_score": 0.9 - (i * 0.01),
+            "semantic_score": 0.7,
+            "lexical_score": 0.6,
+            "category_score": 0.5,
+            "recency_score": 0.4,
+            "popularity_score": 0.3,
+            "total_count": 25,
+        }
+        for i in range(10)  # SQL returns page_size results
+    ]
+
    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query:
-        mock_query.return_value = []
+        mock_query.return_value = mock_results

        with patch(
            "backend.api.features.store.hybrid_search.embed_query"
        ) as mock_embed:
-            mock_embed.return_value = [0.1] * 1536
+            mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM

            # Test page 2 with page_size 10
            results, total = await hybrid_search(
@@ -292,16 +361,18 @@ async def test_hybrid_search_pagination():
                page_size=10,
            )

-        # Verify pagination parameters
+        # Verify results returned
+        assert len(results) == 10
+        assert total == 25  # Total from SQL COUNT(*) OVER()

        # Verify the SQL query uses page_size and offset
        call_args = mock_query.call_args
        params = call_args[0]

-        # Last two params should be LIMIT and OFFSET
-        limit = params[-2]
-        offset = params[-1]
-
-        assert limit == 10  # page_size
-        assert offset == 10  # (page - 1) * page_size = (2 - 1) * 10
+        # Last two params are page_size and offset
+        page_size_param = params[-2]
+        offset_param = params[-1]
+        assert page_size_param == 10
+        assert offset_param == 10  # (page 2 - 1) * 10
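The BM25 rerank named in the docstring can be done with the rank-bm25 package this change adds to pyproject.toml further down. A rough sketch of reranking one SQL page; the tokenization and field names are illustrative, not the project's actual implementation:

from rank_bm25 import BM25Okapi


def rerank_page(query: str, rows: list[dict]) -> list[dict]:
    # Build a BM25 model over just this page's searchable text.
    corpus = [row["searchable_text"].lower().split() for row in rows]
    bm25 = BM25Okapi(corpus)
    scores = bm25.get_scores(query.lower().split())
    # Sort the page only - total count and pagination already came from SQL.
    return [
        row
        for _, row in sorted(zip(scores, rows), key=lambda p: p[0], reverse=True)
    ]


page = [
    {"searchable_text": "Agent 1 test description"},
    {"searchable_text": "Agent 2 unrelated text"},
]
print(rerank_page("test", page))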
@pytest.mark.asyncio(loop_scope="session")
@@ -317,7 +388,7 @@ async def test_hybrid_search_error_handling():
    with patch(
        "backend.api.features.store.hybrid_search.embed_query"
    ) as mock_embed:
-        mock_embed.return_value = [0.1] * 1536
+        mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM

        # Should raise exception
        with pytest.raises(Exception) as exc_info:
@@ -330,5 +401,326 @@ async def test_hybrid_search_error_handling():
    assert "Database connection error" in str(exc_info.value)
# =============================================================================
# Unified Hybrid Search Tests
# =============================================================================


@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_basic():
    """Test basic unified hybrid search across all content types."""
    mock_results = [
        {
            "content_type": "STORE_AGENT",
            "content_id": "agent-1",
            "searchable_text": "Test Agent Description",
            "metadata": {"name": "Test Agent"},
            "updated_at": "2025-01-01T00:00:00Z",
            "semantic_score": 0.7,
            "lexical_score": 0.8,
            "category_score": 0.5,
            "recency_score": 0.3,
            "combined_score": 0.6,
            "total_count": 2,
        },
        {
            "content_type": "BLOCK",
            "content_id": "block-1",
            "searchable_text": "Test Block Description",
            "metadata": {"name": "Test Block"},
            "updated_at": "2025-01-01T00:00:00Z",
            "semantic_score": 0.6,
            "lexical_score": 0.7,
            "category_score": 0.4,
            "recency_score": 0.2,
            "combined_score": 0.5,
            "total_count": 2,
        },
    ]

    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query:
        with patch(
            "backend.api.features.store.hybrid_search.embed_query"
        ) as mock_embed:
            mock_query.return_value = mock_results
            mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM

            results, total = await unified_hybrid_search(
                query="test",
                page=1,
                page_size=20,
            )

            assert len(results) == 2
            assert total == 2
            assert results[0]["content_type"] == "STORE_AGENT"
            assert results[1]["content_type"] == "BLOCK"
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_filter_by_content_type():
    """Test unified search filtering by specific content types."""
    mock_results = [
        {
            "content_type": "BLOCK",
            "content_id": "block-1",
            "searchable_text": "Test Block",
            "metadata": {},
            "updated_at": "2025-01-01T00:00:00Z",
            "semantic_score": 0.7,
            "lexical_score": 0.8,
            "category_score": 0.0,
            "recency_score": 0.3,
            "combined_score": 0.5,
            "total_count": 1,
        },
    ]

    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query:
        with patch(
            "backend.api.features.store.hybrid_search.embed_query"
        ) as mock_embed:
            mock_query.return_value = mock_results
            mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM

            results, total = await unified_hybrid_search(
                query="test",
                content_types=[ContentType.BLOCK],
                page=1,
                page_size=20,
            )

            # Verify content_types parameter was passed correctly
            call_args = mock_query.call_args
            params = call_args[0][1:]
            # The content types should be in the params as a list
            assert ["BLOCK"] in params

            assert len(results) == 1
            assert total == 1
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_with_user_id():
    """Test unified search with user_id for private content."""
    mock_results = [
        {
            "content_type": "STORE_AGENT",
            "content_id": "agent-1",
            "searchable_text": "My Private Agent",
            "metadata": {},
            "updated_at": "2025-01-01T00:00:00Z",
            "semantic_score": 0.7,
            "lexical_score": 0.8,
            "category_score": 0.0,
            "recency_score": 0.3,
            "combined_score": 0.6,
            "total_count": 1,
        },
    ]

    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query:
        with patch(
            "backend.api.features.store.hybrid_search.embed_query"
        ) as mock_embed:
            mock_query.return_value = mock_results
            mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM

            results, total = await unified_hybrid_search(
                query="test",
                user_id="user-123",
                page=1,
                page_size=20,
            )

            # Verify SQL contains user_id filter
            call_args = mock_query.call_args
            sql_template = call_args[0][0]
            params = call_args[0][1:]

            assert 'uce."userId"' in sql_template
            assert "user-123" in params
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_custom_weights():
    """Test unified search with custom weights."""
    custom_weights = UnifiedSearchWeights(
        semantic=0.6,
        lexical=0.2,
        category=0.1,
        recency=0.1,
    )

    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query:
        with patch(
            "backend.api.features.store.hybrid_search.embed_query"
        ) as mock_embed:
            mock_query.return_value = []
            mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM

            results, total = await unified_hybrid_search(
                query="test",
                weights=custom_weights,
                page=1,
                page_size=20,
            )

            # Verify custom weights are in parameters
            call_args = mock_query.call_args
            params = call_args[0][1:]

            assert 0.6 in params  # semantic weight
            assert 0.2 in params  # lexical weight
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_graceful_degradation():
    """Test unified search gracefully degrades when embeddings unavailable."""
    mock_results = [
        {
            "content_type": "DOCUMENTATION",
            "content_id": "doc-1",
            "searchable_text": "API Documentation",
            "metadata": {},
            "updated_at": "2025-01-01T00:00:00Z",
            "semantic_score": 0.0,  # Zero because no embedding
            "lexical_score": 0.8,
            "category_score": 0.0,
            "recency_score": 0.2,
            "combined_score": 0.5,
            "total_count": 1,
        },
    ]

    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query:
        with patch(
            "backend.api.features.store.hybrid_search.embed_query"
        ) as mock_embed:
            mock_query.return_value = mock_results
            mock_embed.return_value = None  # Embedding failure

            # Should NOT raise - graceful degradation
            results, total = await unified_hybrid_search(
                query="test",
                page=1,
                page_size=20,
            )

            assert len(results) == 1
            assert total == 1
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_empty_query():
    """Test unified search with empty query returns empty results."""
    results, total = await unified_hybrid_search(
        query="",
        page=1,
        page_size=20,
    )

    assert results == []
    assert total == 0
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_pagination():
    """Test unified search pagination with BM25 reranking.

    Pagination happens in SQL (LIMIT/OFFSET), then BM25 reranking is applied
    to the paginated results.
    """
    # Create mock results that SQL would return for a page
    mock_results = [
        {
            "content_type": "STORE_AGENT",
            "content_id": f"agent-{i}",
            "searchable_text": f"Agent {i} description",
            "metadata": {"name": f"Agent {i}"},
            "updated_at": "2025-01-01T00:00:00Z",
            "semantic_score": 0.7,
            "lexical_score": 0.8 - (i * 0.01),
            "category_score": 0.5,
            "recency_score": 0.3,
            "combined_score": 0.6 - (i * 0.01),
            "total_count": 50,
        }
        for i in range(15)  # SQL returns page_size results
    ]

    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query:
        with patch(
            "backend.api.features.store.hybrid_search.embed_query"
        ) as mock_embed:
            mock_query.return_value = mock_results
            mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM

            results, total = await unified_hybrid_search(
                query="test",
                page=3,
                page_size=15,
            )

            # Verify results returned
            assert len(results) == 15
            assert total == 50  # Total from SQL COUNT(*) OVER()

            # Verify the SQL query uses page_size and offset
            call_args = mock_query.call_args
            params = call_args[0]
            # Last two params are page_size and offset
            page_size_param = params[-2]
            offset_param = params[-1]
            assert page_size_param == 15
            assert offset_param == 30  # (page 3 - 1) * 15
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_schema_prefix():
    """Test unified search uses schema_prefix placeholder."""
    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query:
        with patch(
            "backend.api.features.store.hybrid_search.embed_query"
        ) as mock_embed:
            mock_query.return_value = []
            mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM

            await unified_hybrid_search(
                query="test",
                page=1,
                page_size=20,
            )

            call_args = mock_query.call_args
            sql_template = call_args[0][0]

            # Verify schema_prefix placeholder is used for table references
            assert "{schema_prefix}" in sql_template
            assert '"UnifiedContentEmbedding"' in sql_template
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
@@ -221,3 +221,23 @@ class ReviewSubmissionRequest(pydantic.BaseModel):
    is_approved: bool
    comments: str  # External comments visible to creator
    internal_comments: str | None = None  # Private admin notes


class UnifiedSearchResult(pydantic.BaseModel):
    """A single result from unified hybrid search across all content types."""

    content_type: str  # STORE_AGENT, BLOCK, DOCUMENTATION
    content_id: str
    searchable_text: str
    metadata: dict | None = None
    updated_at: datetime.datetime | None = None
    combined_score: float | None = None
    semantic_score: float | None = None
    lexical_score: float | None = None


class UnifiedSearchResponse(pydantic.BaseModel):
    """Response model for unified search across all content types."""

    results: list[UnifiedSearchResult]
    pagination: Pagination
@@ -7,12 +7,15 @@ from typing import Literal
import autogpt_libs.auth
import fastapi
import fastapi.responses
+import prisma.enums

import backend.data.graph
import backend.util.json
+from backend.util.models import Pagination

from . import cache as store_cache
from . import db as store_db
+from . import hybrid_search as store_hybrid_search
from . import image_gen as store_image_gen
from . import media as store_media
from . import model as store_model
@@ -146,6 +149,102 @@ async def get_agents(
    return agents


##############################################
############### Search Endpoints #############
##############################################


@router.get(
    "/search",
    summary="Unified search across all content types",
    tags=["store", "public"],
    response_model=store_model.UnifiedSearchResponse,
)
async def unified_search(
    query: str,
    content_types: list[str] | None = fastapi.Query(
        default=None,
        description="Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.",
    ),
    page: int = 1,
    page_size: int = 20,
    user_id: str | None = fastapi.Security(
        autogpt_libs.auth.get_optional_user_id, use_cache=False
    ),
):
    """
    Search across all content types (store agents, blocks, documentation) using hybrid search.

    Combines semantic (embedding-based) and lexical (text-based) search for best results.

    Args:
        query: The search query string
        content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION)
        page: Page number for pagination (default 1)
        page_size: Number of results per page (default 20)
        user_id: Optional authenticated user ID (for user-scoped content in future)

    Returns:
        UnifiedSearchResponse: Paginated list of search results with relevance scores
    """
    if page < 1:
        raise fastapi.HTTPException(
            status_code=422, detail="Page must be greater than 0"
        )

    if page_size < 1:
        raise fastapi.HTTPException(
            status_code=422, detail="Page size must be greater than 0"
        )

    # Convert string content types to enum
    content_type_enums: list[prisma.enums.ContentType] | None = None
    if content_types:
        try:
            content_type_enums = [prisma.enums.ContentType(ct) for ct in content_types]
        except ValueError as e:
            raise fastapi.HTTPException(
                status_code=422,
                detail=f"Invalid content type. Valid values: STORE_AGENT, BLOCK, DOCUMENTATION. Error: {e}",
            )

    # Perform unified hybrid search
    results, total = await store_hybrid_search.unified_hybrid_search(
        query=query,
        content_types=content_type_enums,
        user_id=user_id,
        page=page,
        page_size=page_size,
    )

    # Convert results to response model
    search_results = [
        store_model.UnifiedSearchResult(
            content_type=r["content_type"],
            content_id=r["content_id"],
            searchable_text=r.get("searchable_text", ""),
            metadata=r.get("metadata"),
            updated_at=r.get("updated_at"),
            combined_score=r.get("combined_score"),
            semantic_score=r.get("semantic_score"),
            lexical_score=r.get("lexical_score"),
        )
        for r in results
    ]

    total_pages = (total + page_size - 1) // page_size if total > 0 else 0

    return store_model.UnifiedSearchResponse(
        results=search_results,
        pagination=Pagination(
            total_items=total,
            total_pages=total_pages,
            current_page=page,
            page_size=page_size,
        ),
    )


@router.get(
    "/agents/{username}/{agent_name}",
    summary="Get specific agent",
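A quick way to exercise the new endpoint manually once a backend is running; the base URL and route prefix here are assumptions about a local dev setup, not confirmed by this diff:

import httpx

# Hypothetical local base URL; adjust host/port/prefix to your deployment.
resp = httpx.get(
    "http://localhost:8006/api/store/search",
    params={"query": "calculator", "content_types": ["BLOCK"], "page": 1, "page_size": 5},
)
resp.raise_for_status()
for item in resp.json()["results"]:
    print(item["content_type"], item["content_id"], item.get("combined_score"))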
@@ -0,0 +1,272 @@
"""Tests for the semantic_search function."""

import pytest
from prisma.enums import ContentType

from backend.api.features.store.embeddings import EMBEDDING_DIM, semantic_search


@pytest.mark.asyncio
async def test_search_blocks_only(mocker):
    """Test searching only BLOCK content type."""
    # Mock embed_query to return a test embedding
    mock_embedding = [0.1] * EMBEDDING_DIM
    mocker.patch(
        "backend.api.features.store.embeddings.embed_query",
        return_value=mock_embedding,
    )

    # Mock query_raw_with_schema to return test results
    mock_results = [
        {
            "content_id": "block-123",
            "content_type": "BLOCK",
            "searchable_text": "Calculator Block - Performs arithmetic operations",
            "metadata": {"name": "Calculator", "categories": ["Math"]},
            "similarity": 0.85,
        }
    ]
    mocker.patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        return_value=mock_results,
    )

    results = await semantic_search(
        query="calculate numbers",
        content_types=[ContentType.BLOCK],
    )

    assert len(results) == 1
    assert results[0]["content_type"] == "BLOCK"
    assert results[0]["content_id"] == "block-123"
    assert results[0]["similarity"] == 0.85
@pytest.mark.asyncio
async def test_search_multiple_content_types(mocker):
    """Test searching multiple content types simultaneously."""
    mock_embedding = [0.1] * EMBEDDING_DIM
    mocker.patch(
        "backend.api.features.store.embeddings.embed_query",
        return_value=mock_embedding,
    )

    mock_results = [
        {
            "content_id": "block-123",
            "content_type": "BLOCK",
            "searchable_text": "Calculator Block",
            "metadata": {},
            "similarity": 0.85,
        },
        {
            "content_id": "doc-456",
            "content_type": "DOCUMENTATION",
            "searchable_text": "How to use Calculator",
            "metadata": {},
            "similarity": 0.75,
        },
    ]
    mocker.patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        return_value=mock_results,
    )

    results = await semantic_search(
        query="calculator",
        content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION],
    )

    assert len(results) == 2
    assert results[0]["content_type"] == "BLOCK"
    assert results[1]["content_type"] == "DOCUMENTATION"
@pytest.mark.asyncio
async def test_search_with_min_similarity_threshold(mocker):
    """Test that results below min_similarity are filtered out."""
    mock_embedding = [0.1] * EMBEDDING_DIM
    mocker.patch(
        "backend.api.features.store.embeddings.embed_query",
        return_value=mock_embedding,
    )

    # Only return results above 0.7 similarity
    mock_results = [
        {
            "content_id": "block-123",
            "content_type": "BLOCK",
            "searchable_text": "Calculator Block",
            "metadata": {},
            "similarity": 0.85,
        }
    ]
    mocker.patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        return_value=mock_results,
    )

    results = await semantic_search(
        query="calculate",
        content_types=[ContentType.BLOCK],
        min_similarity=0.7,
    )

    assert len(results) == 1
    assert results[0]["similarity"] >= 0.7
@pytest.mark.asyncio
async def test_search_fallback_to_lexical(mocker):
    """Test fallback to lexical search when embeddings fail."""
    # Mock embed_query to return None (embeddings unavailable)
    mocker.patch(
        "backend.api.features.store.embeddings.embed_query",
        return_value=None,
    )

    mock_lexical_results = [
        {
            "content_id": "block-123",
            "content_type": "BLOCK",
            "searchable_text": "Calculator Block performs calculations",
            "metadata": {},
            "similarity": 0.0,
        }
    ]
    mocker.patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        return_value=mock_lexical_results,
    )

    results = await semantic_search(
        query="calculator",
        content_types=[ContentType.BLOCK],
    )

    assert len(results) == 1
    assert results[0]["similarity"] == 0.0  # Lexical search returns 0 similarity
@pytest.mark.asyncio
async def test_search_empty_query():
    """Test that empty query returns no results."""
    results = await semantic_search(query="")
    assert results == []

    results = await semantic_search(query=" ")
    assert results == []
@pytest.mark.asyncio
async def test_search_with_user_id_filter(mocker):
    """Test searching with user_id filter for private content."""
    mock_embedding = [0.1] * EMBEDDING_DIM
    mocker.patch(
        "backend.api.features.store.embeddings.embed_query",
        return_value=mock_embedding,
    )

    mock_results = [
        {
            "content_id": "agent-789",
            "content_type": "LIBRARY_AGENT",
            "searchable_text": "My Custom Agent",
            "metadata": {},
            "similarity": 0.9,
        }
    ]
    mocker.patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        return_value=mock_results,
    )

    results = await semantic_search(
        query="custom agent",
        content_types=[ContentType.LIBRARY_AGENT],
        user_id="user-123",
    )

    assert len(results) == 1
    assert results[0]["content_type"] == "LIBRARY_AGENT"
@pytest.mark.asyncio
async def test_search_limit_parameter(mocker):
    """Test that limit parameter correctly limits results."""
    mock_embedding = [0.1] * EMBEDDING_DIM
    mocker.patch(
        "backend.api.features.store.embeddings.embed_query",
        return_value=mock_embedding,
    )

    # Return 5 results
    mock_results = [
        {
            "content_id": f"block-{i}",
            "content_type": "BLOCK",
            "searchable_text": f"Block {i}",
            "metadata": {},
            "similarity": 0.8,
        }
        for i in range(5)
    ]
    mocker.patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        return_value=mock_results,
    )

    results = await semantic_search(
        query="block",
        content_types=[ContentType.BLOCK],
        limit=5,
    )

    assert len(results) == 5
@pytest.mark.asyncio
async def test_search_default_content_types(mocker):
    """Test that default content_types includes BLOCK, STORE_AGENT, and DOCUMENTATION."""
    mock_embedding = [0.1] * EMBEDDING_DIM
    mocker.patch(
        "backend.api.features.store.embeddings.embed_query",
        return_value=mock_embedding,
    )

    mock_query_raw = mocker.patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        return_value=[],
    )

    await semantic_search(query="test")

    # Check that the SQL query includes all three default content types
    call_args = mock_query_raw.call_args
    assert "BLOCK" in str(call_args)
    assert "STORE_AGENT" in str(call_args)
    assert "DOCUMENTATION" in str(call_args)
@pytest.mark.asyncio
async def test_search_handles_database_error(mocker):
    """Test that database errors are handled gracefully."""
    mock_embedding = [0.1] * EMBEDDING_DIM
    mocker.patch(
        "backend.api.features.store.embeddings.embed_query",
        return_value=mock_embedding,
    )

    # Simulate database error
    mocker.patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        side_effect=Exception("Database connection failed"),
    )

    results = await semantic_search(
        query="test",
        content_types=[ContentType.BLOCK],
    )

    # Should return empty list on error
    assert results == []
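Taken together, these tests pin down the semantic_search call shape. A usage sketch, with the row shape inferred from the mocked results above rather than from the implementation:

from prisma.enums import ContentType
from backend.api.features.store.embeddings import semantic_search


async def find_calculator_blocks():
    # Parameters mirror those exercised by the tests above.
    rows = await semantic_search(
        query="calculate numbers",
        content_types=[ContentType.BLOCK],
        min_similarity=0.7,
        limit=5,
    )
    for row in rows:
        # Each row carries content_id/content_type/similarity per the mocks.
        print(row["content_id"], row["similarity"])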
@@ -680,12 +680,23 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
        return False, reviewed_data

    async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
-        # Check for review requirement and get potentially modified input data
-        should_pause, input_data = await self.is_block_exec_need_review(
-            input_data, **kwargs
-        )
-        if should_pause:
-            return
+        # Check for review requirement only if running within a graph execution context
+        # Direct block execution (e.g., from chat) skips the review process
+        has_graph_context = all(
+            key in kwargs
+            for key in (
+                "node_exec_id",
+                "graph_exec_id",
+                "graph_id",
+                "execution_context",
+            )
+        )
+        if has_graph_context:
+            should_pause, input_data = await self.is_block_exec_need_review(
+                input_data, **kwargs
+            )
+            if should_pause:
+                return

        # Validate the input data (original or reviewer-modified) once
        if error := self.input_schema.validate_data(input_data):
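The guard's behavior in isolation: only calls carrying all four executor-supplied keys go through review, so direct block runs (e.g., from the chat run_block tool) skip it. A toy reproduction of the check:

# Toy illustration of the has_graph_context guard above.
GRAPH_KEYS = ("node_exec_id", "graph_exec_id", "graph_id", "execution_context")


def has_graph_context(**kwargs) -> bool:
    return all(key in kwargs for key in GRAPH_KEYS)


print(has_graph_context(user_id="u1"))  # False - direct/chat execution
print(has_graph_context(
    node_exec_id="n", graph_exec_id="g", graph_id="id",
    execution_context={}, user_id="u1",
))  # True - running inside a graph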
@@ -9,6 +9,7 @@ from backend.api.features.library.db import (
from backend.api.features.store.db import get_store_agent_details, get_store_agents
from backend.api.features.store.embeddings import (
    backfill_missing_embeddings,
+    cleanup_orphaned_embeddings,
    get_embedding_stats,
)
from backend.data import db
@@ -221,6 +222,7 @@ class DatabaseManager(AppService):
    # Store Embeddings
    get_embedding_stats = _(get_embedding_stats)
    backfill_missing_embeddings = _(backfill_missing_embeddings)
+    cleanup_orphaned_embeddings = _(cleanup_orphaned_embeddings)

    # Summary data - async
    get_user_execution_summary_data = _(get_user_execution_summary_data)
@@ -276,6 +278,7 @@ class DatabaseManagerClient(AppServiceClient):
    # Store Embeddings
    get_embedding_stats = _(d.get_embedding_stats)
    backfill_missing_embeddings = _(d.backfill_missing_embeddings)
+    cleanup_orphaned_embeddings = _(d.cleanup_orphaned_embeddings)


class DatabaseManagerAsyncClient(AppServiceClient):
@@ -28,6 +28,7 @@ from backend.data.auth.oauth import cleanup_expired_oauth_tokens
from backend.data.block import BlockInput
from backend.data.execution import GraphExecutionWithNodes
from backend.data.model import CredentialsMetaInput
+from backend.data.onboarding import increment_onboarding_runs
from backend.executor import utils as execution_utils
from backend.monitoring import (
    NotificationJobArgs,
@@ -156,6 +157,7 @@ async def _execute_graph(**kwargs):
        inputs=args.input_data,
        graph_credentials_inputs=args.input_credentials,
    )
+    await increment_onboarding_runs(args.user_id)
    elapsed = asyncio.get_event_loop().time() - start_time
    logger.info(
        f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
@@ -255,14 +257,14 @@ def execution_accuracy_alerts():

def ensure_embeddings_coverage():
    """
-    Ensure approved store agents have embeddings for hybrid search.
+    Ensure all content types (store agents, blocks, docs) have embeddings for search.

-    Processes ALL missing embeddings in batches of 10 until 100% coverage.
-    Missing embeddings = agents invisible in hybrid search.
+    Processes ALL missing embeddings in batches of 10 per content type until 100% coverage.
+    Missing embeddings = content invisible in hybrid search.

    Schedule: Runs every 6 hours (balanced between coverage and API costs).
-    - Catches agents approved between scheduled runs
-    - Batch size 10: gradual processing to avoid rate limits
+    - Catches new content added between scheduled runs
+    - Batch size 10 per content type: gradual processing to avoid rate limits
    - Manual trigger available via execute_ensure_embeddings_coverage endpoint
    """
    db_client = get_database_manager_client()
@@ -273,51 +275,91 @@ def ensure_embeddings_coverage():
        logger.error(
            f"Failed to get embedding stats: {stats['error']} - skipping backfill"
        )
-        return {"processed": 0, "success": 0, "failed": 0, "error": stats["error"]}
+        return {
+            "backfill": {"processed": 0, "success": 0, "failed": 0},
+            "cleanup": {"deleted": 0},
+            "error": stats["error"],
+        }

-    if stats["without_embeddings"] == 0:
-        logger.info("All approved agents have embeddings, skipping backfill")
-        return {"processed": 0, "success": 0, "failed": 0}
-
-    logger.info(
-        f"Found {stats['without_embeddings']} agents without embeddings "
-        f"({stats['coverage_percent']}% coverage) - processing all"
-    )
+    # Extract totals from new stats structure
+    totals = stats.get("totals", {})
+    without_embeddings = totals.get("without_embeddings", 0)
+    coverage_percent = totals.get("coverage_percent", 0)

    total_processed = 0
    total_success = 0
    total_failed = 0

-    # Process in batches until no more missing embeddings
-    while True:
-        result = db_client.backfill_missing_embeddings(batch_size=10)
-
-        total_processed += result["processed"]
-        total_success += result["success"]
-        total_failed += result["failed"]
-
-        if result["processed"] == 0:
-            # No more missing embeddings
-            break
-
-        if result["success"] == 0 and result["processed"] > 0:
-            # All attempts in this batch failed - stop to avoid infinite loop
-            logger.error(
-                f"All {result['processed']} embedding attempts failed - stopping backfill"
-            )
-            break
-
-        # Small delay between batches to avoid rate limits
-        time.sleep(1)
+    if without_embeddings == 0:
+        logger.info("All content has embeddings, skipping backfill")
+    else:
+        # Log per-content-type stats for visibility
+        by_type = stats.get("by_type", {})
+        for content_type, type_stats in by_type.items():
+            if type_stats.get("without_embeddings", 0) > 0:
+                logger.info(
+                    f"{content_type}: {type_stats['without_embeddings']} items without embeddings "
+                    f"({type_stats['coverage_percent']}% coverage)"
+                )
+
+        logger.info(
+            f"Total: {without_embeddings} items without embeddings "
+            f"({coverage_percent}% coverage) - processing all"
+        )
+
+        # Process in batches until no more missing embeddings
+        while True:
+            result = db_client.backfill_missing_embeddings(batch_size=10)
+
+            total_processed += result["processed"]
+            total_success += result["success"]
+            total_failed += result["failed"]
+
+            if result["processed"] == 0:
+                # No more missing embeddings
+                break
+
+            if result["success"] == 0 and result["processed"] > 0:
+                # All attempts in this batch failed - stop to avoid infinite loop
+                logger.error(
+                    f"All {result['processed']} embedding attempts failed - stopping backfill"
+                )
+                break
+
+            # Small delay between batches to avoid rate limits
+            time.sleep(1)
+
+    # Clean up orphaned embeddings for blocks and docs
+    logger.info("Running cleanup for orphaned embeddings (blocks/docs)...")
+    cleanup_result = db_client.cleanup_orphaned_embeddings()
+    cleanup_totals = cleanup_result.get("totals", {})
+    cleanup_deleted = cleanup_totals.get("deleted", 0)
+
+    if cleanup_deleted > 0:
+        logger.info(f"Cleanup completed: deleted {cleanup_deleted} orphaned embeddings")
+        by_type = cleanup_result.get("by_type", {})
+        for content_type, type_result in by_type.items():
+            if type_result.get("deleted", 0) > 0:
+                logger.info(
+                    f"{content_type}: deleted {type_result['deleted']} orphaned embeddings"
+                )
+    else:
+        logger.info("Cleanup completed: no orphaned embeddings found")

    logger.info(
        f"Embedding backfill completed: {total_success}/{total_processed} succeeded, "
        f"{total_failed} failed"
    )
    return {
-        "processed": total_processed,
-        "success": total_success,
-        "failed": total_failed,
+        "backfill": {
+            "processed": total_processed,
+            "success": total_success,
+            "failed": total_failed,
+        },
+        "cleanup": {
+            "deleted": cleanup_deleted,
+        },
    }
@@ -560,6 +602,18 @@ class Scheduler(AppService):
        self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
        self.scheduler.start()

+        # Run embedding backfill immediately on startup
+        # This ensures blocks/docs are searchable right away, not after 6 hours
+        # Safe to run on multiple pods - uses upserts and checks for existing embeddings
+        if self.register_system_tasks:
+            logger.info("Running embedding backfill on startup...")
+            try:
+                result = ensure_embeddings_coverage()
+                logger.info(f"Startup embedding backfill complete: {result}")
+            except Exception as e:
+                logger.error(f"Startup embedding backfill failed: {e}")
+                # Don't fail startup - the scheduled job will retry later
+
        # Keep the service running since BackgroundScheduler doesn't block
        super().run_service()
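Note the return-shape change: callers now receive nested backfill/cleanup sections instead of a flat dict. A sketch of reading the new shape defensively (the values here are made up for illustration):

# Reading the nested result of ensure_embeddings_coverage.
result = {
    "backfill": {"processed": 12, "success": 11, "failed": 1},
    "cleanup": {"deleted": 3},
}

backfill = result.get("backfill", {})
print(f"backfill: {backfill.get('success', 0)}/{backfill.get('processed', 0)} ok")
print(f"cleanup: {result.get('cleanup', {}).get('deleted', 0)} orphaned rows removed")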
@@ -16,7 +16,7 @@ import pickle
import threading
import time
from dataclasses import dataclass
-from functools import wraps
+from functools import cache, wraps
from typing import Any, Callable, ParamSpec, Protocol, TypeVar, cast, runtime_checkable

from redis import ConnectionPool, Redis
@@ -38,29 +38,34 @@ settings = Settings()
# maxmemory 2gb # Set memory limit (adjust based on your needs)
# save "" # Disable persistence if using Redis purely for caching

-# Create a dedicated Redis connection pool for caching (binary mode for pickle)
-_cache_pool: ConnectionPool | None = None
-
-
@conn_retry("Redis", "Acquiring cache connection pool")
+@cache
def _get_cache_pool() -> ConnectionPool:
-    """Get or create a connection pool for cache operations."""
-    global _cache_pool
-    if _cache_pool is None:
-        _cache_pool = ConnectionPool(
-            host=settings.config.redis_host,
-            port=settings.config.redis_port,
-            password=settings.config.redis_password or None,
-            decode_responses=False,  # Binary mode for pickle
-            max_connections=50,
-            socket_keepalive=True,
-            socket_connect_timeout=5,
-            retry_on_timeout=True,
-        )
-    return _cache_pool
+    """Get or create a connection pool for cache operations (lazy, thread-safe)."""
+    return ConnectionPool(
+        host=settings.config.redis_host,
+        port=settings.config.redis_port,
+        password=settings.config.redis_password or None,
+        decode_responses=False,  # Binary mode for pickle
+        max_connections=50,
+        socket_keepalive=True,
+        socket_connect_timeout=5,
+        retry_on_timeout=True,
+    )


-redis = Redis(connection_pool=_get_cache_pool())
+@cache
+@conn_retry("Redis", "Acquiring cache connection")
+def _get_redis() -> Redis:
+    """
+    Get the lazily-initialized Redis client for shared cache operations.
+    Uses @cache for thread-safe singleton behavior - connection is only
+    established when first accessed, allowing services that only use
+    in-memory caching to work without Redis configuration.
+    """
+    r = Redis(connection_pool=_get_cache_pool())
+    r.ping()  # Verify connection
+    return r


@dataclass
@@ -179,9 +184,9 @@ def cached(
            try:
                if refresh_ttl_on_get:
                    # Use GETEX to get value and refresh expiry atomically
-                    cached_bytes = redis.getex(redis_key, ex=ttl_seconds)
+                    cached_bytes = _get_redis().getex(redis_key, ex=ttl_seconds)
                else:
-                    cached_bytes = redis.get(redis_key)
+                    cached_bytes = _get_redis().get(redis_key)

                if cached_bytes and isinstance(cached_bytes, bytes):
                    return pickle.loads(cached_bytes)
@@ -195,7 +200,7 @@ def cached(
            """Set value in Redis with TTL."""
            try:
                pickled_value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
-                redis.setex(redis_key, ttl_seconds, pickled_value)
+                _get_redis().setex(redis_key, ttl_seconds, pickled_value)
            except Exception as e:
                logger.error(
                    f"Redis error storing cache for {target_func.__name__}: {e}"
@@ -333,14 +338,18 @@ def cached(
            if pattern:
                # Clear entries matching pattern
                keys = list(
-                    redis.scan_iter(f"cache:{target_func.__name__}:{pattern}")
+                    _get_redis().scan_iter(
+                        f"cache:{target_func.__name__}:{pattern}"
+                    )
                )
            else:
                # Clear all cache keys
-                keys = list(redis.scan_iter(f"cache:{target_func.__name__}:*"))
+                keys = list(
+                    _get_redis().scan_iter(f"cache:{target_func.__name__}:*")
+                )

            if keys:
-                pipeline = redis.pipeline()
+                pipeline = _get_redis().pipeline()
                for key in keys:
                    pipeline.delete(key)
                pipeline.execute()
@@ -355,7 +364,9 @@ def cached(

    def cache_info() -> dict[str, int | None]:
        if shared_cache:
-            cache_keys = list(redis.scan_iter(f"cache:{target_func.__name__}:*"))
+            cache_keys = list(
+                _get_redis().scan_iter(f"cache:{target_func.__name__}:*")
+            )
            return {
                "size": len(cache_keys),
                "maxsize": None,  # Redis manages its own size
@@ -373,10 +384,8 @@ def cached(
        key = _make_hashable_key(args, kwargs)
        if shared_cache:
            redis_key = _make_redis_key(key, target_func.__name__)
-            if redis.exists(redis_key):
-                redis.delete(redis_key)
-                return True
-            return False
+            deleted_count = cast(int, _get_redis().delete(redis_key))
+            return deleted_count > 0
        else:
            if key in cache_storage:
                del cache_storage[key]
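The refactor's core trick: functools.cache turns a zero-argument factory into a lazily created singleton, so no Redis connection is made until the first shared-cache call. Distilled to its essentials, with a dict standing in for the real client:

from functools import cache


@cache
def get_client() -> dict:
    print("connecting once...")  # runs only on the first call
    return {"connected": True}   # stand-in for Redis(connection_pool=...)


a = get_client()  # prints "connecting once..."
b = get_client()  # cached - no reconnect
assert a is b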
@@ -43,4 +43,6 @@ CREATE UNIQUE INDEX "UnifiedContentEmbedding_contentType_contentId_userId_key" O
-- CreateIndex
-- HNSW index for fast vector similarity search on embeddings
-- Uses cosine distance operator (<=>), which matches the query in hybrid_search.py
-- Note: Drop first in case Prisma created a btree index (Prisma doesn't support HNSW)
DROP INDEX IF EXISTS "UnifiedContentEmbedding_embedding_idx";
CREATE INDEX "UnifiedContentEmbedding_embedding_idx" ON "UnifiedContentEmbedding" USING hnsw ("embedding" public.vector_cosine_ops);
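For the planner to use this HNSW index, queries must order by the same cosine operator (<=>) the index was built with. An illustrative query template assembled in Python; the {schema_prefix} placeholder and table/column names come from this change, while the helper itself is a stand-in, not the project's actual query builder:

# Illustrative pgvector query matching the HNSW index's cosine operator.
def build_similarity_sql() -> str:
    return """
        SELECT "contentId", 1 - ("embedding" <=> $1::vector) AS similarity
        FROM {schema_prefix}"UnifiedContentEmbedding"
        ORDER BY "embedding" <=> $1::vector  -- same operator class as the index
        LIMIT $2
    """


print(build_similarity_sql())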
@@ -0,0 +1,35 @@
-- Add tsvector search column to UnifiedContentEmbedding for unified full-text search
-- This enables hybrid search (semantic + lexical) across all content types

-- Add search column (IF NOT EXISTS for idempotency)
ALTER TABLE "UnifiedContentEmbedding" ADD COLUMN IF NOT EXISTS "search" tsvector DEFAULT ''::tsvector;

-- Create GIN index for fast full-text search
-- No @@index in schema.prisma - Prisma may generate DROP INDEX on migrate dev
-- If that happens, just let it drop and this migration will recreate it, or manually re-run:
-- CREATE INDEX IF NOT EXISTS "UnifiedContentEmbedding_search_idx" ON "UnifiedContentEmbedding" USING GIN ("search");
DROP INDEX IF EXISTS "UnifiedContentEmbedding_search_idx";
CREATE INDEX "UnifiedContentEmbedding_search_idx" ON "UnifiedContentEmbedding" USING GIN ("search");

-- Drop existing trigger/function if exists
DROP TRIGGER IF EXISTS "update_unified_tsvector" ON "UnifiedContentEmbedding";
DROP FUNCTION IF EXISTS update_unified_tsvector_column();

-- Create function to auto-update tsvector from searchableText
CREATE OR REPLACE FUNCTION update_unified_tsvector_column() RETURNS TRIGGER AS $$
BEGIN
    NEW.search := to_tsvector('english', COALESCE(NEW."searchableText", ''));
    RETURN NEW;
END;
$$ LANGUAGE plpgsql SECURITY DEFINER SET search_path = platform, pg_temp;

-- Create trigger to auto-update search column on insert/update
CREATE TRIGGER "update_unified_tsvector"
BEFORE INSERT OR UPDATE ON "UnifiedContentEmbedding"
FOR EACH ROW
EXECUTE FUNCTION update_unified_tsvector_column();

-- Backfill existing rows
UPDATE "UnifiedContentEmbedding"
SET search = to_tsvector('english', COALESCE("searchableText", ''))
WHERE search IS NULL OR search = ''::tsvector;
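With the trigger keeping the search column in sync, lexical scoring can stay entirely in SQL. A sketch of the kind of query hybrid search can now run; ts_rank and websearch_to_tsquery are standard Postgres, but their exact use here is an assumption about hybrid_search.py, not confirmed by this diff:

# Sketch of a lexical-score query against the new tsvector column.
def build_lexical_sql() -> str:
    return """
        SELECT "contentId",
               ts_rank("search", websearch_to_tsquery('english', $1)) AS lexical_score
        FROM {schema_prefix}"UnifiedContentEmbedding"
        WHERE "search" @@ websearch_to_tsquery('english', $1)
        ORDER BY lexical_score DESC
        LIMIT $2
    """


print(build_lexical_sql())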
@@ -0,0 +1,90 @@
-- Remove the old search column from StoreListingVersion
-- This column has been replaced by UnifiedContentEmbedding.search
-- which provides unified hybrid search across all content types

-- First drop the dependent view
DROP VIEW IF EXISTS "StoreAgent";

-- Drop the trigger and function for old search column
-- The original trigger was created in 20251016093049_add_full_text_search
DROP TRIGGER IF EXISTS "update_tsvector" ON "StoreListingVersion";
DROP FUNCTION IF EXISTS update_tsvector_column();

-- Drop the index
DROP INDEX IF EXISTS "StoreListingVersion_search_idx";

-- NOTE: Keeping search column for now to allow easy revert if needed
-- Uncomment to fully remove once migration is verified in production:
-- ALTER TABLE "StoreListingVersion" DROP COLUMN IF EXISTS "search";

-- Recreate the StoreAgent view WITHOUT the search column
-- (Search now handled by UnifiedContentEmbedding)
CREATE OR REPLACE VIEW "StoreAgent" AS
WITH latest_versions AS (
    SELECT
        "storeListingId",
        MAX(version) AS max_version
    FROM "StoreListingVersion"
    WHERE "submissionStatus" = 'APPROVED'
    GROUP BY "storeListingId"
),
agent_versions AS (
    SELECT
        "storeListingId",
        array_agg(DISTINCT version::text ORDER BY version::text) AS versions
    FROM "StoreListingVersion"
    WHERE "submissionStatus" = 'APPROVED'
    GROUP BY "storeListingId"
),
agent_graph_versions AS (
    SELECT
        "storeListingId",
        array_agg(DISTINCT "agentGraphVersion"::text ORDER BY "agentGraphVersion"::text) AS graph_versions
    FROM "StoreListingVersion"
    WHERE "submissionStatus" = 'APPROVED'
    GROUP BY "storeListingId"
)
SELECT
    sl.id AS listing_id,
    slv.id AS "storeListingVersionId",
    slv."createdAt" AS updated_at,
    sl.slug,
    COALESCE(slv.name, '') AS agent_name,
    slv."videoUrl" AS agent_video,
    slv."agentOutputDemoUrl" AS agent_output_demo,
    COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
    slv."isFeatured" AS featured,
    p.username AS creator_username,
    p."avatarUrl" AS creator_avatar,
    slv."subHeading" AS sub_heading,
    slv.description,
    slv.categories,
    COALESCE(ar.run_count, 0::bigint) AS runs,
    COALESCE(rs.avg_rating, 0.0)::double precision AS rating,
    COALESCE(av.versions, ARRAY[slv.version::text]) AS versions,
    COALESCE(agv.graph_versions, ARRAY[slv."agentGraphVersion"::text]) AS "agentGraphVersions",
    slv."agentGraphId",
    slv."isAvailable" AS is_available,
    COALESCE(sl."useForOnboarding", false) AS "useForOnboarding"
FROM "StoreListing" sl
JOIN latest_versions lv
    ON sl.id = lv."storeListingId"
JOIN "StoreListingVersion" slv
    ON slv."storeListingId" = lv."storeListingId"
    AND slv.version = lv.max_version
    AND slv."submissionStatus" = 'APPROVED'
JOIN "AgentGraph" a
    ON slv."agentGraphId" = a.id
    AND slv."agentGraphVersion" = a.version
LEFT JOIN "Profile" p
    ON sl."owningUserId" = p."userId"
LEFT JOIN "mv_review_stats" rs
    ON sl.id = rs."storeListingId"
LEFT JOIN "mv_agent_run_counts" ar
    ON a.id = ar."agentGraphId"
LEFT JOIN agent_versions av
    ON sl.id = av."storeListingId"
LEFT JOIN agent_graph_versions agv
    ON sl.id = agv."storeListingId"
WHERE sl."isDeleted" = false
    AND sl."hasApprovedVersion" = true;
autogpt_platform/backend/poetry.lock (generated)
@@ -5339,6 +5339,24 @@ urllib3 = ">=1.26.14,<3"
fastembed = ["fastembed (>=0.7,<0.8)"]
fastembed-gpu = ["fastembed-gpu (>=0.7,<0.8)"]

[[package]]
name = "rank-bm25"
version = "0.2.2"
description = "Various BM25 algorithms for document ranking"
optional = false
python-versions = "*"
groups = ["main"]
files = [
    {file = "rank_bm25-0.2.2-py3-none-any.whl", hash = "sha256:7bd4a95571adadfc271746fa146a4bcfd89c0cf731e49c3d1ad863290adbe8ae"},
    {file = "rank_bm25-0.2.2.tar.gz", hash = "sha256:096ccef76f8188563419aaf384a02f0ea459503fdf77901378d4fd9d87e5e51d"},
]

[package.dependencies]
numpy = "*"

[package.extras]
dev = ["pytest"]

[[package]]
name = "rapidfuzz"
version = "3.13.0"
@@ -7494,4 +7512,4 @@ cffi = ["cffi (>=1.11)"]

[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
-content-hash = "86838b5ae40d606d6e01a14dad8a56c389d890d7a6a0c274a6602cca80f0df84"
+content-hash = "18b92e09596298c82432e4d0a85cb6d80a40b4229bee0a0c15f0529fd6cb21a4"
@@ -46,6 +46,7 @@ poetry = "2.1.1" # CHECK DEPENDABOT SUPPORT BEFORE UPGRADING
postmarker = "^1.0"
praw = "~7.8.1"
prisma = "^0.15.0"
+rank-bm25 = "^0.2.2"
prometheus-client = "^0.22.1"
prometheus-fastapi-instrumentator = "^7.0.0"
psutil = "^7.0.0"
@@ -937,7 +937,7 @@ model StoreListingVersion {
  // Old versions can be made unavailable by the author if desired
  isAvailable Boolean @default(true)

-  search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
+  // Note: search column removed - now using UnifiedContentEmbedding.search

  // Version workflow state
  submissionStatus SubmissionStatus @default(DRAFT)
@@ -1002,6 +1002,7 @@ model UnifiedContentEmbedding {
  // Search data
  embedding      Unsupported("vector(1536)") // pgvector embedding (extension in platform schema)
  searchableText String // Combined text for search and fallback
+  search         Unsupported("tsvector")? @default(dbgenerated("''::tsvector")) // Full-text search (auto-populated by trigger)
  metadata       Json @default("{}") // Content-specific metadata

  @@unique([contentType, contentId, userId], map: "UnifiedContentEmbedding_contentType_contentId_userId_key")
@@ -1009,6 +1010,8 @@ model UnifiedContentEmbedding {
  @@index([userId])
  @@index([contentType, userId])
  @@index([embedding], map: "UnifiedContentEmbedding_embedding_idx")
+  // NO @@index for search - GIN index "UnifiedContentEmbedding_search_idx" created via SQL migration
+  // Prisma may generate DROP INDEX on migrate dev - that's okay, migration recreates it
}

model StoreListingReview {
@@ -1,6 +1,6 @@
-import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput";
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
+import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput";
import { useState } from "react";
import { getSchemaDefaultCredentials } from "../../helpers";
import { areAllCredentialsSet, getCredentialFields } from "./helpers";
@@ -1,12 +1,12 @@
"use client";

-import { RunAgentInputs } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentInputs/RunAgentInputs";
import {
  Card,
  CardContent,
  CardHeader,
  CardTitle,
} from "@/components/__legacy__/ui/card";
+import { RunAgentInputs } from "@/components/contextual/RunAgentInputs/RunAgentInputs";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { CircleNotchIcon } from "@phosphor-icons/react/dist/ssr";
import { Play } from "lucide-react";
@@ -1,11 +1,11 @@
"use client";

-import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput";
import { useGetOauthGetOauthAppInfo } from "@/app/api/__generated__/endpoints/oauth/oauth";
import { okData } from "@/app/api/helpers";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { AuthCard } from "@/components/auth/AuthCard";
+import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import type {
  BlockIOCredentialsSubSchema,
@@ -1,11 +1,6 @@
import { BlockUIType } from "@/app/(platform)/build/components/types";
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
-import {
-  globalRegistry,
-  OutputActions,
-  OutputItem,
-} from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers";
import { Label } from "@/components/__legacy__/ui/label";
import { ScrollArea } from "@/components/__legacy__/ui/scroll-area";
import {
@@ -23,6 +18,11 @@ import {
  TooltipProvider,
  TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
+import {
+  globalRegistry,
+  OutputActions,
+  OutputItem,
+} from "@/components/contextual/OutputRenderers";
import { BookOpenIcon } from "@phosphor-icons/react";
import { useMemo } from "react";
import { useShallow } from "zustand/react/shallow";
@@ -1,7 +1,8 @@
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
import { usePostV1ExecuteGraphAgent } from "@/app/api/__generated__/endpoints/graphs/graphs";
-import { useToast } from "@/components/molecules/Toast/use-toast";

import {
+  ApiError,
  CredentialsMetaInput,
  GraphExecutionMeta,
} from "@/lib/autogpt-server-api";
@@ -9,6 +10,9 @@ import { parseAsInteger, parseAsString, useQueryStates } from "nuqs";
import { useMemo, useState } from "react";
import { uiSchema } from "../../../FlowEditor/nodes/uiSchema";
+import { isCredentialFieldSchema } from "@/components/renderers/InputRenderer/custom/CredentialField/helpers";
+import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
+import { useToast } from "@/components/molecules/Toast/use-toast";
import { useReactFlow } from "@xyflow/react";

export const useRunInputDialog = ({
  setIsOpen,
@@ -31,6 +35,7 @@ export const useRunInputDialog = ({
    flowVersion: parseAsInteger,
  });
  const { toast } = useToast();
+  const { setViewport } = useReactFlow();

  const { mutateAsync: executeGraph, isPending: isExecutingGraph } =
    usePostV1ExecuteGraphAgent({
@@ -42,13 +47,75 @@ export const useRunInputDialog = ({
        });
      },
      onError: (error) => {
-        // Reset running state on error
-        setIsGraphRunning(false);
-        toast({
-          title: (error.detail as string) ?? "An unexpected error occurred.",
-          description: "An unexpected error occurred.",
-          variant: "destructive",
-        });
+        if (error instanceof ApiError && error.isGraphValidationError()) {
+          const errorData = error.response?.detail || {
+            node_errors: {},
+            message: undefined,
+          };
+          const nodeErrors = errorData.node_errors || {};
+
+          if (Object.keys(nodeErrors).length > 0) {
+            Object.entries(nodeErrors).forEach(
+              ([nodeId, nodeErrorsForNode]) => {
+                useNodeStore
+                  .getState()
+                  .updateNodeErrors(
+                    nodeId,
+                    nodeErrorsForNode as { [key: string]: string },
+                  );
+              },
+            );
+          } else {
+            useNodeStore.getState().nodes.forEach((node) => {
+              useNodeStore.getState().updateNodeErrors(node.id, {});
+            });
+          }
+
+          toast({
+            title: errorData?.message || "Graph validation failed",
+            description:
+              "Please fix the validation errors on the highlighted nodes and try again.",
+            variant: "destructive",
+          });
+          setIsOpen(false);
+
+          const firstBackendId = Object.keys(nodeErrors)[0];
+
+          if (firstBackendId) {
+            const firstErrorNode = useNodeStore
+              .getState()
+              .nodes.find(
+                (n) =>
+                  n.data.metadata?.backend_id === firstBackendId ||
+                  n.id === firstBackendId,
+              );
+
+            if (firstErrorNode) {
+              setTimeout(() => {
+                setViewport(
+                  {
+                    x:
+                      -firstErrorNode.position.x * 0.8 +
+                      window.innerWidth / 2 -
+                      150,
+                    y: -firstErrorNode.position.y * 0.8 + 50,
+                    zoom: 0.8,
+                  },
+                  { duration: 500 },
+                );
+              }, 50);
+            }
+          }
+        } else {
+          toast({
+            title: "Error running graph",
+            description:
+              (error as Error).message || "An unexpected error occurred.",
+            variant: "destructive",
+          });
+          setIsOpen(false);
+        }
+        setIsGraphRunning(false);
      },
    },
  });
@@ -55,14 +55,16 @@ export const Flow = () => {
  const edgeTypes = useMemo(() => ({ custom: CustomEdge }), []);

  const onNodeDragStop = useCallback(() => {
    const currentNodes = useNodeStore.getState().nodes;
    setNodes(
      resolveCollisions(nodes, {
      resolveCollisions(currentNodes, {
        maxIterations: Infinity,
        overlapThreshold: 0.5,
        margin: 15,
      }),
    );
  }, [setNodes, nodes]);
  }, [setNodes]);

  const { edges, onConnect, onEdgesChange } = useCustomEdge();

  // for loading purpose

@@ -6,6 +6,7 @@ import {
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
import { useCallback } from "react";
import { useNodeStore } from "../../../stores/nodeStore";
import { useHistoryStore } from "../../../stores/historyStore";
import { CustomEdge } from "./CustomEdge";

export const useCustomEdge = () => {
@@ -51,7 +52,20 @@ export const useCustomEdge = () => {

  const onEdgesChange = useCallback(
    (changes: EdgeChange<CustomEdge>[]) => {
      const hasRemoval = changes.some((change) => change.type === "remove");

      const prevState = hasRemoval
        ? {
            nodes: useNodeStore.getState().nodes,
            edges: edges,
          }
        : null;

      setEdges(applyEdgeChanges(changes, edges));

      if (prevState) {
        useHistoryStore.getState().pushState(prevState);
      }
    },
    [edges, setEdges],
  );

@@ -20,11 +20,13 @@ type Props = {

export const NodeHeader = ({ data, nodeId }: Props) => {
  const updateNodeData = useNodeStore((state) => state.updateNodeData);
  const title = (data.metadata?.customized_name as string) || data.title;
  const title =
    (data.metadata?.customized_name as string) ||
    data.hardcodedValues?.agent_name ||
    data.title;

  const [isEditingTitle, setIsEditingTitle] = useState(false);
  const [editedTitle, setEditedTitle] = useState(
    beautifyString(title).replace("Block", "").trim(),
  );
  const [editedTitle, setEditedTitle] = useState(title);

  const handleTitleEdit = () => {
    updateNodeData(nodeId, {

@@ -1,7 +1,7 @@
"use client";

import type { OutputMetadata } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers";
import { globalRegistry } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers";
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
import { globalRegistry } from "@/components/contextual/OutputRenderers";

export const TextRenderer: React.FC<{
  value: any;

@@ -1,7 +1,3 @@
import {
  OutputActions,
  OutputItem,
} from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers";
import { ScrollArea } from "@/components/__legacy__/ui/scroll-area";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
@@ -11,6 +7,10 @@ import {
  TooltipProvider,
  TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
import {
  OutputActions,
  OutputItem,
} from "@/components/contextual/OutputRenderers";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { beautifyString } from "@/lib/utils";
import {

@@ -1,6 +1,6 @@
import type { OutputMetadata } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers";
import { globalRegistry } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers";
import { downloadOutputs } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers/utils/download";
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
import { globalRegistry } from "@/components/contextual/OutputRenderers";
import { downloadOutputs } from "@/components/contextual/OutputRenderers/utils/download";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { beautifyString } from "@/lib/utils";
import React, { useMemo, useState } from "react";

@@ -1,10 +1,10 @@
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
import { Text } from "@/components/atoms/Text/Text";
import Link from "next/link";
import { useGetV2GetLibraryAgentByGraphId } from "@/app/api/__generated__/endpoints/library/library";
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { useQueryStates, parseAsString } from "nuqs";
import { isValidUUID } from "@/app/(platform)/chat/helpers";
import { Text } from "@/components/atoms/Text/Text";
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
import { isValidUUID } from "@/lib/utils";
import Link from "next/link";
import { parseAsString, useQueryStates } from "nuqs";

export const WebhookDisclaimer = ({ nodeId }: { nodeId: string }) => {
  const [{ flowID }] = useQueryStates({

@@ -31,8 +31,6 @@ export const OutputHandler = ({
  const [isOutputVisible, setIsOutputVisible] = useState(true);
  const brokenOutputs = useBrokenOutputs(nodeId);

  console.log("brokenOutputs", brokenOutputs);

  const showHandles = uiType !== BlockUIType.OUTPUT;

  const renderOutputHandles = (

@@ -1,9 +1,9 @@
import type { OutputMetadata } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers";
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
import {
  globalRegistry,
  OutputActions,
  OutputItem,
} from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers";
} from "@/components/contextual/OutputRenderers";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { beautifyString } from "@/lib/utils";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";

@@ -3,7 +3,6 @@ import {
  CustomNodeData,
} from "@/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode";
import { NodeTableInput } from "@/app/(platform)/build/components/legacy-builder/NodeTableInput";
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput";
import { Button } from "@/components/__legacy__/ui/button";
import { Calendar } from "@/components/__legacy__/ui/calendar";
import { LocalValuedInput } from "@/components/__legacy__/ui/input";
@@ -28,6 +27,7 @@ import {
  SelectValue,
} from "@/components/__legacy__/ui/select";
import { Switch } from "@/components/atoms/Switch/Switch";
import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput";
import { GoogleDrivePickerInput } from "@/components/contextual/GoogleDrivePicker/GoogleDrivePickerInput";
import {
  BlockIOArraySubSchema,

@@ -5,6 +5,8 @@ import { customEdgeToLink, linkToCustomEdge } from "../components/helper";
import { MarkerType } from "@xyflow/react";
import { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult";
import { cleanUpHandleId } from "@/components/renderers/InputRenderer/helpers";
import { useHistoryStore } from "./historyStore";
import { useNodeStore } from "./nodeStore";

type EdgeStore = {
  edges: CustomEdge[];
@@ -53,25 +55,36 @@ export const useEdgeStore = create<EdgeStore>((set, get) => ({
      id,
    };

    set((state) => {
      const exists = state.edges.some(
        (e) =>
          e.source === newEdge.source &&
          e.target === newEdge.target &&
          e.sourceHandle === newEdge.sourceHandle &&
          e.targetHandle === newEdge.targetHandle,
      );
      if (exists) return state;
      return { edges: [...state.edges, newEdge] };
    });
    const exists = get().edges.some(
      (e) =>
        e.source === newEdge.source &&
        e.target === newEdge.target &&
        e.sourceHandle === newEdge.sourceHandle &&
        e.targetHandle === newEdge.targetHandle,
    );
    if (exists) return newEdge;
    const prevState = {
      nodes: useNodeStore.getState().nodes,
      edges: get().edges,
    };

    set((state) => ({ edges: [...state.edges, newEdge] }));
    useHistoryStore.getState().pushState(prevState);

    return newEdge;
  },

  removeEdge: (edgeId) =>
  removeEdge: (edgeId) => {
    const prevState = {
      nodes: useNodeStore.getState().nodes,
      edges: get().edges,
    };

    set((state) => ({
      edges: state.edges.filter((e) => e.id !== edgeId),
    })),
    }));
    useHistoryStore.getState().pushState(prevState);
  },

  upsertMany: (edges) =>
    set((state) => {

@@ -37,6 +37,15 @@ export const useHistoryStore = create<HistoryStore>((set, get) => ({
      return;
    }

    const actualCurrentState = {
      nodes: useNodeStore.getState().nodes,
      edges: useEdgeStore.getState().edges,
    };

    if (isEqual(state, actualCurrentState)) {
      return;
    }

    set((prev) => ({
      past: [...prev.past.slice(-MAX_HISTORY + 1), state],
      future: [],
@@ -55,18 +64,25 @@ export const useHistoryStore = create<HistoryStore>((set, get) => ({

  undo: () => {
    const { past, future } = get();
    if (past.length <= 1) return;
    if (past.length === 0) return;

    const currentState = past[past.length - 1];
    const actualCurrentState = {
      nodes: useNodeStore.getState().nodes,
      edges: useEdgeStore.getState().edges,
    };

    const previousState = past[past.length - 2];
    const previousState = past[past.length - 1];

    if (isEqual(actualCurrentState, previousState)) {
      return;
    }

    useNodeStore.getState().setNodes(previousState.nodes);
    useEdgeStore.getState().setEdges(previousState.edges);

    set({
      past: past.slice(0, -1),
      future: [currentState, ...future],
      past: past.length > 1 ? past.slice(0, -1) : past,
      future: [actualCurrentState, ...future],
    });
  },

@@ -74,18 +90,36 @@
    const { past, future } = get();
    if (future.length === 0) return;

    const actualCurrentState = {
      nodes: useNodeStore.getState().nodes,
      edges: useEdgeStore.getState().edges,
    };

    const nextState = future[0];

    useNodeStore.getState().setNodes(nextState.nodes);
    useEdgeStore.getState().setEdges(nextState.edges);

    const lastPast = past[past.length - 1];
    const shouldPushToPast =
      !lastPast || !isEqual(actualCurrentState, lastPast);

    set({
      past: [...past, nextState],
      past: shouldPushToPast ? [...past, actualCurrentState] : past,
      future: future.slice(1),
    });
  },

  canUndo: () => get().past.length > 1,
  canUndo: () => {
    const { past } = get();
    if (past.length === 0) return false;

    const actualCurrentState = {
      nodes: useNodeStore.getState().nodes,
      edges: useEdgeStore.getState().edges,
    };
    return !isEqual(actualCurrentState, past[past.length - 1]);
  },
  canRedo: () => get().future.length > 0,

  clear: () => set({ past: [{ nodes: [], edges: [] }], future: [] }),

@@ -1,6 +1,7 @@
import { create } from "zustand";
import { NodeChange, XYPosition, applyNodeChanges } from "@xyflow/react";
import { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
import { CustomEdge } from "../components/FlowEditor/edges/CustomEdge";
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
import {
  convertBlockInfoIntoCustomNodeData,
@@ -44,6 +45,8 @@ const MINIMUM_MOVE_BEFORE_LOG = 50;
// Track initial positions when drag starts (outside store to avoid re-renders)
const dragStartPositions: Record<string, XYPosition> = {};

let dragStartState: { nodes: CustomNode[]; edges: CustomEdge[] } | null = null;

type NodeStore = {
  nodes: CustomNode[];
  nodeCounter: number;
@@ -124,14 +127,20 @@ export const useNodeStore = create<NodeStore>((set, get) => ({
      nodeCounter: state.nodeCounter + 1,
    })),
  onNodesChange: (changes) => {
    const prevState = {
      nodes: get().nodes,
      edges: useEdgeStore.getState().edges,
    };

    // Track initial positions when drag starts
    changes.forEach((change) => {
      if (change.type === "position" && change.dragging === true) {
        if (!dragStartState) {
          const currentNodes = get().nodes;
          const currentEdges = useEdgeStore.getState().edges;
          dragStartState = {
            nodes: currentNodes.map((n) => ({
              ...n,
              position: { ...n.position },
              data: { ...n.data },
            })),
            edges: currentEdges.map((e) => ({ ...e })),
          };
        }
        if (!dragStartPositions[change.id]) {
          const node = get().nodes.find((n) => n.id === change.id);
          if (node) {
@@ -141,12 +150,17 @@ export const useNodeStore = create<NodeStore>((set, get) => ({
      }
    });

    // Check if we should track this change in history
    let shouldTrack = changes.some(
      (change) => change.type === "remove" || change.type === "add",
    );
    let shouldTrack = changes.some((change) => change.type === "remove");
    let stateToTrack: { nodes: CustomNode[]; edges: CustomEdge[] } | null =
      null;

    if (shouldTrack) {
      stateToTrack = {
        nodes: get().nodes,
        edges: useEdgeStore.getState().edges,
      };
    }

    // For position changes, only track if movement exceeds threshold
    if (!shouldTrack) {
      changes.forEach((change) => {
        if (change.type === "position" && change.dragging === false) {
@@ -158,20 +172,23 @@ export const useNodeStore = create<NodeStore>((set, get) => ({
          );
          if (distanceMoved > MINIMUM_MOVE_BEFORE_LOG) {
            shouldTrack = true;
            stateToTrack = dragStartState;
          }
        }
        // Clean up tracked position after drag ends
        delete dragStartPositions[change.id];
      }
    });
    if (Object.keys(dragStartPositions).length === 0) {
      dragStartState = null;
    }
    }

    set((state) => ({
      nodes: applyNodeChanges(changes, state.nodes),
    }));

    if (shouldTrack) {
      useHistoryStore.getState().pushState(prevState);
    if (shouldTrack && stateToTrack) {
      useHistoryStore.getState().pushState(stateToTrack);
    }
  },

@@ -185,6 +202,11 @@ export const useNodeStore = create<NodeStore>((set, get) => ({
    hardcodedValues?: Record<string, any>,
    position?: XYPosition,
  ) => {
    const prevState = {
      nodes: get().nodes,
      edges: useEdgeStore.getState().edges,
    };

    const customNodeData = convertBlockInfoIntoCustomNodeData(
      block,
      hardcodedValues,
@@ -218,21 +240,24 @@ export const useNodeStore = create<NodeStore>((set, get) => ({
    set((state) => ({
      nodes: [...state.nodes, customNode],
    }));

    useHistoryStore.getState().pushState(prevState);

    return customNode;
  },
  updateNodeData: (nodeId, data) => {
    const prevState = {
      nodes: get().nodes,
      edges: useEdgeStore.getState().edges,
    };

    set((state) => ({
      nodes: state.nodes.map((n) =>
        n.id === nodeId ? { ...n, data: { ...n.data, ...data } } : n,
      ),
    }));

    const newState = {
      nodes: get().nodes,
      edges: useEdgeStore.getState().edges,
    };

    useHistoryStore.getState().pushState(newState);
    useHistoryStore.getState().pushState(prevState);
  },
  toggleAdvanced: (nodeId: string) =>
    set((state) => ({
@@ -391,6 +416,11 @@ export const useNodeStore = create<NodeStore>((set, get) => ({
  },

  setCredentialsOptional: (nodeId: string, optional: boolean) => {
    const prevState = {
      nodes: get().nodes,
      edges: useEdgeStore.getState().edges,
    };

    set((state) => ({
      nodes: state.nodes.map((n) =>
        n.id === nodeId
@@ -408,12 +438,7 @@ export const useNodeStore = create<NodeStore>((set, get) => ({
      ),
    }));

    const newState = {
      nodes: get().nodes,
      edges: useEdgeStore.getState().edges,
    };

    useHistoryStore.getState().pushState(newState);
    useHistoryStore.getState().pushState(prevState);
  },

  // Sub-agent resolution mode state

@@ -0,0 +1,134 @@
"use client";

import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
import { List } from "@phosphor-icons/react";
import React, { useState } from "react";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatErrorState } from "./components/ChatErrorState/ChatErrorState";
import { ChatLoadingState } from "./components/ChatLoadingState/ChatLoadingState";
import { SessionsDrawer } from "./components/SessionsDrawer/SessionsDrawer";
import { useChat } from "./useChat";

export interface ChatProps {
  className?: string;
  headerTitle?: React.ReactNode;
  showHeader?: boolean;
  showSessionInfo?: boolean;
  showNewChatButton?: boolean;
  onNewChat?: () => void;
  headerActions?: React.ReactNode;
}

export function Chat({
  className,
  headerTitle = "AutoGPT Copilot",
  showHeader = true,
  showSessionInfo = true,
  showNewChatButton = true,
  onNewChat,
  headerActions,
}: ChatProps) {
  const {
    messages,
    isLoading,
    isCreating,
    error,
    sessionId,
    createSession,
    clearSession,
    loadSession,
  } = useChat();

  const [isSessionsDrawerOpen, setIsSessionsDrawerOpen] = useState(false);

  const handleNewChat = () => {
    clearSession();
    onNewChat?.();
  };

  const handleSelectSession = async (sessionId: string) => {
    try {
      await loadSession(sessionId);
    } catch (err) {
      console.error("Failed to load session:", err);
    }
  };

  return (
    <div className={cn("flex h-full flex-col", className)}>
      {/* Header */}
      {showHeader && (
        <header className="shrink-0 border-t border-zinc-200 bg-white p-3">
          <div className="flex items-center justify-between">
            <div className="flex items-center gap-3">
              <button
                aria-label="View sessions"
                onClick={() => setIsSessionsDrawerOpen(true)}
                className="flex size-8 items-center justify-center rounded hover:bg-zinc-100"
              >
                <List width="1.25rem" height="1.25rem" />
              </button>
              {typeof headerTitle === "string" ? (
                <Text variant="h2" className="text-lg font-semibold">
                  {headerTitle}
                </Text>
              ) : (
                headerTitle
              )}
            </div>
            <div className="flex items-center gap-3">
              {showSessionInfo && sessionId && (
                <>
                  {showNewChatButton && (
                    <Button
                      variant="outline"
                      size="small"
                      onClick={handleNewChat}
                    >
                      New Chat
                    </Button>
                  )}
                </>
              )}
              {headerActions}
            </div>
          </div>
        </header>
      )}

      {/* Main Content */}
      <main className="flex min-h-0 flex-1 flex-col overflow-hidden">
        {/* Loading State - show when explicitly loading/creating OR when we don't have a session yet and no error */}
        {(isLoading || isCreating || (!sessionId && !error)) && (
          <ChatLoadingState
            message={isCreating ? "Creating session..." : "Loading..."}
          />
        )}

        {/* Error State */}
        {error && !isLoading && (
          <ChatErrorState error={error} onRetry={createSession} />
        )}

        {/* Session Content */}
        {sessionId && !isLoading && !error && (
          <ChatContainer
            sessionId={sessionId}
            initialMessages={messages}
            className="flex-1"
          />
        )}
      </main>

      {/* Sessions Drawer */}
      <SessionsDrawer
        isOpen={isSessionsDrawerOpen}
        onClose={() => setIsSessionsDrawerOpen(false)}
        onSelectSession={handleSelectSession}
        currentSessionId={sessionId}
      />
    </div>
  );
}
@@ -1,15 +1,16 @@
import React from "react";
import { Text } from "@/components/atoms/Text/Text";
import { Button } from "@/components/atoms/Button/Button";
import { Card } from "@/components/atoms/Card/Card";
import { List, Robot, ArrowRight } from "@phosphor-icons/react";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
import { ArrowRight, List, Robot } from "@phosphor-icons/react";
import Image from "next/image";

export interface Agent {
  id: string;
  name: string;
  description: string;
  version?: number;
  image_url?: string;
}

export interface AgentCarouselMessageProps {
@@ -30,7 +31,7 @@ export function AgentCarouselMessage({
  return (
    <div
      className={cn(
        "mx-4 my-2 flex flex-col gap-4 rounded-lg border border-purple-200 bg-purple-50 p-6 dark:border-purple-900 dark:bg-purple-950",
        "mx-4 my-2 flex flex-col gap-4 rounded-lg border border-purple-200 bg-purple-50 p-6",
        className,
      )}
    >
@@ -40,13 +41,10 @@
          <List size={24} weight="bold" className="text-white" />
        </div>
        <div>
          <Text variant="h3" className="text-purple-900 dark:text-purple-100">
          <Text variant="h3" className="text-purple-900">
            Found {displayCount} {displayCount === 1 ? "Agent" : "Agents"}
          </Text>
          <Text
            variant="small"
            className="text-purple-700 dark:text-purple-300"
          >
          <Text variant="small" className="text-purple-700">
            Select an agent to view details or run it
          </Text>
        </div>
@@ -57,40 +55,49 @@
        {agents.map((agent) => (
          <Card
            key={agent.id}
            className="border border-purple-200 bg-white p-4 dark:border-purple-800 dark:bg-purple-900"
            className="border border-purple-200 bg-white p-4"
          >
            <div className="flex gap-3">
              <div className="flex h-10 w-10 flex-shrink-0 items-center justify-center rounded-lg bg-purple-100 dark:bg-purple-800">
                <Robot size={20} weight="bold" className="text-purple-600" />
              <div className="relative h-10 w-10 flex-shrink-0 overflow-hidden rounded-lg bg-purple-100">
                {agent.image_url ? (
                  <Image
                    src={agent.image_url}
                    alt={`${agent.name} preview image`}
                    fill
                    className="object-cover"
                  />
                ) : (
                  <div className="flex h-full w-full items-center justify-center">
                    <Robot
                      size={20}
                      weight="bold"
                      className="text-purple-600"
                    />
                  </div>
                )}
              </div>
              <div className="flex-1 space-y-2">
                <div>
                  <Text
                    variant="body"
                    className="font-semibold text-purple-900 dark:text-purple-100"
                    className="font-semibold text-purple-900"
                  >
                    {agent.name}
                  </Text>
                  {agent.version && (
                    <Text
                      variant="small"
                      className="text-purple-600 dark:text-purple-400"
                    >
                    <Text variant="small" className="text-purple-600">
                      v{agent.version}
                    </Text>
                  )}
                </div>
                <Text
                  variant="small"
                  className="line-clamp-2 text-purple-700 dark:text-purple-300"
                >
                <Text variant="small" className="line-clamp-2 text-purple-700">
                  {agent.description}
                </Text>
                {onSelectAgent && (
                  <Button
                    onClick={() => onSelectAgent(agent.id)}
                    variant="ghost"
                    className="mt-2 flex items-center gap-1 p-0 text-sm text-purple-600 hover:text-purple-800 dark:text-purple-400 dark:hover:text-purple-200"
                    className="mt-2 flex items-center gap-1 p-0 text-sm text-purple-600 hover:text-purple-800"
                  >
                    View details
                    <ArrowRight size={16} weight="bold" />
@@ -103,10 +110,7 @@
      </div>

      {totalCount && totalCount > agents.length && (
        <Text
          variant="small"
          className="text-center text-purple-600 dark:text-purple-400"
        >
        <Text variant="small" className="text-center text-purple-600">
          Showing {agents.length} of {totalCount} results
        </Text>
      )}
@@ -0,0 +1,246 @@
"use client";

import { Button } from "@/components/atoms/Button/Button";
import { Card } from "@/components/atoms/Card/Card";
import { Text } from "@/components/atoms/Text/Text";
import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput";
import { RunAgentInputs } from "@/components/contextual/RunAgentInputs/RunAgentInputs";

import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import {
  BlockIOCredentialsSubSchema,
  BlockIOSubSchema,
} from "@/lib/autogpt-server-api/types";
import { cn, isEmpty } from "@/lib/utils";
import { PlayIcon, WarningIcon } from "@phosphor-icons/react";
import { useMemo } from "react";
import { useAgentInputsSetup } from "./useAgentInputsSetup";

type LibraryAgentInputSchemaProperties = LibraryAgent["input_schema"] extends {
  properties: infer P;
}
  ? P extends Record<string, BlockIOSubSchema>
    ? P
    : Record<string, BlockIOSubSchema>
  : Record<string, BlockIOSubSchema>;

type LibraryAgentCredentialsInputSchemaProperties =
  LibraryAgent["credentials_input_schema"] extends {
    properties: infer P;
  }
    ? P extends Record<string, BlockIOCredentialsSubSchema>
      ? P
      : Record<string, BlockIOCredentialsSubSchema>
    : Record<string, BlockIOCredentialsSubSchema>;

interface Props {
  agentName?: string;
  inputSchema: LibraryAgentInputSchemaProperties | Record<string, any>;
  credentialsSchema?:
    | LibraryAgentCredentialsInputSchemaProperties
    | Record<string, any>;
  message: string;
  requiredFields?: string[];
  onRun: (
    inputs: Record<string, any>,
    credentials: Record<string, any>,
  ) => void;
  onCancel?: () => void;
  className?: string;
}

export function AgentInputsSetup({
  agentName,
  inputSchema,
  credentialsSchema,
  message,
  requiredFields,
  onRun,
  onCancel,
  className,
}: Props) {
  const { inputValues, setInputValue, credentialsValues, setCredentialsValue } =
    useAgentInputsSetup();

  const inputSchemaObj = useMemo(() => {
    if (!inputSchema) return { properties: {}, required: [] };
    if ("properties" in inputSchema && "type" in inputSchema) {
      return inputSchema as {
        properties: Record<string, any>;
        required?: string[];
      };
    }
    return { properties: inputSchema as Record<string, any>, required: [] };
  }, [inputSchema]);

  const credentialsSchemaObj = useMemo(() => {
    if (!credentialsSchema) return { properties: {}, required: [] };
    if ("properties" in credentialsSchema && "type" in credentialsSchema) {
      return credentialsSchema as {
        properties: Record<string, any>;
        required?: string[];
      };
    }
    return {
      properties: credentialsSchema as Record<string, any>,
      required: [],
    };
  }, [credentialsSchema]);

  const agentInputFields = useMemo(() => {
    const properties = inputSchemaObj.properties || {};
    return Object.fromEntries(
      Object.entries(properties).filter(
        ([_, subSchema]: [string, any]) => !subSchema.hidden,
      ),
    );
  }, [inputSchemaObj]);

  const agentCredentialsInputFields = useMemo(() => {
    return credentialsSchemaObj.properties || {};
  }, [credentialsSchemaObj]);

  const inputFields = Object.entries(agentInputFields);
  const credentialFields = Object.entries(agentCredentialsInputFields);

  const defaultsFromSchema = useMemo(() => {
    const defaults: Record<string, any> = {};
    Object.entries(agentInputFields).forEach(([key, schema]) => {
      if ("default" in schema && schema.default !== undefined) {
        defaults[key] = schema.default;
      }
    });
    return defaults;
  }, [agentInputFields]);

  const defaultsFromCredentialsSchema = useMemo(() => {
    const defaults: Record<string, any> = {};
    Object.entries(agentCredentialsInputFields).forEach(([key, schema]) => {
      if ("default" in schema && schema.default !== undefined) {
        defaults[key] = schema.default;
      }
    });
    return defaults;
  }, [agentCredentialsInputFields]);

  const mergedInputValues = useMemo(() => {
    return { ...defaultsFromSchema, ...inputValues };
  }, [defaultsFromSchema, inputValues]);

  const mergedCredentialsValues = useMemo(() => {
    return { ...defaultsFromCredentialsSchema, ...credentialsValues };
  }, [defaultsFromCredentialsSchema, credentialsValues]);

  const allRequiredInputsAreSet = useMemo(() => {
    const requiredInputs = new Set(
      requiredFields || (inputSchemaObj.required as string[]) || [],
    );
    const nonEmptyInputs = new Set(
      Object.keys(mergedInputValues).filter(
        (k) => !isEmpty(mergedInputValues[k]),
      ),
    );
    const missing = [...requiredInputs].filter(
      (input) => !nonEmptyInputs.has(input),
    );
    return missing.length === 0;
  }, [inputSchemaObj.required, mergedInputValues, requiredFields]);

  const allCredentialsAreSet = useMemo(() => {
    const requiredCredentials = new Set(
      (credentialsSchemaObj.required as string[]) || [],
    );
    if (requiredCredentials.size === 0) {
      return true;
    }
    const missing = [...requiredCredentials].filter((key) => {
      const cred = mergedCredentialsValues[key];
      return !cred || !cred.id;
    });
    return missing.length === 0;
  }, [credentialsSchemaObj.required, mergedCredentialsValues]);

  const canRun = allRequiredInputsAreSet && allCredentialsAreSet;

  function handleRun() {
    if (canRun) {
      onRun(mergedInputValues, mergedCredentialsValues);
    }
  }

  return (
    <Card
      className={cn(
        "mx-4 my-2 overflow-hidden border-blue-200 bg-blue-50",
        className,
      )}
    >
      <div className="flex items-start gap-4 p-6">
        <div className="flex h-12 w-12 flex-shrink-0 items-center justify-center rounded-full bg-blue-500">
          <WarningIcon size={24} weight="bold" className="text-white" />
        </div>
        <div className="flex-1">
          <Text variant="h3" className="mb-2 text-blue-900">
            {agentName ? `Configure ${agentName}` : "Agent Configuration"}
          </Text>
          <Text variant="body" className="mb-4 text-blue-700">
            {message}
          </Text>

          {inputFields.length > 0 && (
            <div className="mb-4 space-y-4">
              {inputFields.map(([key, inputSubSchema]) => (
                <RunAgentInputs
                  key={key}
                  schema={inputSubSchema}
                  value={inputValues[key] ?? inputSubSchema.default}
                  placeholder={inputSubSchema.description}
                  onChange={(value) => setInputValue(key, value)}
                />
              ))}
            </div>
          )}

          {credentialFields.length > 0 && (
            <div className="mb-4 space-y-4">
              {credentialFields.map(([key, schema]) => {
                const requiredCredentials = new Set(
                  (credentialsSchemaObj.required as string[]) || [],
                );
                return (
                  <CredentialsInput
                    key={key}
                    schema={schema}
                    selectedCredentials={credentialsValues[key]}
                    onSelectCredentials={(value) =>
                      setCredentialsValue(key, value)
                    }
                    siblingInputs={mergedInputValues}
                    isOptional={!requiredCredentials.has(key)}
                  />
                );
              })}
            </div>
          )}

          <div className="flex gap-2">
            <Button
              variant="primary"
              size="small"
              onClick={handleRun}
              disabled={!canRun}
            >
              <PlayIcon className="mr-2 h-4 w-4" weight="bold" />
              Run Agent
            </Button>
            {onCancel && (
              <Button variant="outline" size="small" onClick={onCancel}>
                Cancel
              </Button>
            )}
          </div>
        </div>
      </div>
    </Card>
  );
}
@@ -0,0 +1,38 @@
import type { CredentialsMetaInput } from "@/lib/autogpt-server-api/types";
import { useState } from "react";

export function useAgentInputsSetup() {
  const [inputValues, setInputValues] = useState<Record<string, any>>({});
  const [credentialsValues, setCredentialsValues] = useState<
    Record<string, CredentialsMetaInput>
  >({});

  function setInputValue(key: string, value: any) {
    setInputValues((prev) => ({
      ...prev,
      [key]: value,
    }));
  }

  function setCredentialsValue(key: string, value?: CredentialsMetaInput) {
    if (value) {
      setCredentialsValues((prev) => ({
        ...prev,
        [key]: value,
      }));
    } else {
      setCredentialsValues((prev) => {
        const next = { ...prev };
        delete next[key];
        return next;
      });
    }
  }

  return {
    inputValues,
    setInputValue,
    credentialsValues,
    setCredentialsValue,
  };
}
@@ -1,10 +1,9 @@
"use client";

import React from "react";
import { useRouter } from "next/navigation";
import { Button } from "@/components/atoms/Button/Button";
import { SignInIcon, UserPlusIcon, ShieldIcon } from "@phosphor-icons/react";
import { cn } from "@/lib/utils";
import { ShieldIcon, SignInIcon, UserPlusIcon } from "@phosphor-icons/react";
import { useRouter } from "next/navigation";

export interface AuthPromptWidgetProps {
  message: string;
@@ -54,8 +53,8 @@ export function AuthPromptWidget({
  return (
    <div
      className={cn(
        "my-4 overflow-hidden rounded-lg border border-violet-200 dark:border-violet-800",
        "bg-gradient-to-br from-violet-50 to-purple-50 dark:from-violet-950/30 dark:to-purple-950/30",
        "my-4 overflow-hidden rounded-lg border border-violet-200",
        "bg-gradient-to-br from-violet-50 to-purple-50",
        "duration-500 animate-in fade-in-50 slide-in-from-bottom-2",
        className,
      )}
@@ -66,21 +65,19 @@
          <ShieldIcon size={20} weight="fill" className="text-white" />
        </div>
        <div>
          <h3 className="text-lg font-semibold text-neutral-900 dark:text-neutral-100">
          <h3 className="text-lg font-semibold text-neutral-900">
            Authentication Required
          </h3>
          <p className="text-sm text-neutral-600 dark:text-neutral-400">
          <p className="text-sm text-neutral-600">
            Sign in to set up and manage agents
          </p>
        </div>
      </div>

      <div className="mb-5 rounded-md bg-white/50 p-4 dark:bg-neutral-900/50">
        <p className="text-sm text-neutral-700 dark:text-neutral-300">
          {message}
        </p>
      <div className="mb-5 rounded-md bg-white/50 p-4">
        <p className="text-sm text-neutral-700">{message}</p>
        {agentInfo && (
          <div className="mt-3 text-xs text-neutral-600 dark:text-neutral-400">
          <div className="mt-3 text-xs text-neutral-600">
            <p>
              Ready to set up:{" "}
              <span className="font-medium">{agentInfo.name}</span>
@@ -114,7 +111,7 @@ export function AuthPromptWidget({
        </Button>
      </div>

      <div className="mt-4 text-center text-xs text-neutral-500 dark:text-neutral-500">
      <div className="mt-4 text-center text-xs text-neutral-500">
        Your chat session will be preserved after signing in
      </div>
    </div>
@@ -0,0 +1,88 @@
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
import { cn } from "@/lib/utils";
import { useCallback } from "react";
import { usePageContext } from "../../usePageContext";
import { ChatInput } from "../ChatInput/ChatInput";
import { MessageList } from "../MessageList/MessageList";
import { QuickActionsWelcome } from "../QuickActionsWelcome/QuickActionsWelcome";
import { useChatContainer } from "./useChatContainer";

export interface ChatContainerProps {
  sessionId: string | null;
  initialMessages: SessionDetailResponse["messages"];
  className?: string;
}

export function ChatContainer({
  sessionId,
  initialMessages,
  className,
}: ChatContainerProps) {
  const { messages, streamingChunks, isStreaming, sendMessage } =
    useChatContainer({
      sessionId,
      initialMessages,
    });
  const { capturePageContext } = usePageContext();

  // Wrap sendMessage to automatically capture page context
  const sendMessageWithContext = useCallback(
    async (content: string, isUserMessage: boolean = true) => {
      const context = capturePageContext();
      await sendMessage(content, isUserMessage, context);
    },
    [sendMessage, capturePageContext],
  );

  const quickActions = [
    "Find agents for social media management",
    "Show me agents for content creation",
    "Help me automate my business",
    "What can you help me with?",
  ];

  return (
    <div
      className={cn("flex h-full min-h-0 flex-col", className)}
      style={{
        backgroundColor: "#ffffff",
        backgroundImage:
          "radial-gradient(#e5e5e5 0.5px, transparent 0.5px), radial-gradient(#e5e5e5 0.5px, #ffffff 0.5px)",
        backgroundSize: "20px 20px",
        backgroundPosition: "0 0, 10px 10px",
      }}
    >
      {/* Messages or Welcome Screen */}
      <div className="flex min-h-0 flex-1 flex-col overflow-hidden pb-24">
        {messages.length === 0 ? (
          <QuickActionsWelcome
            title="Welcome to AutoGPT Copilot"
            description="Start a conversation to discover and run AI agents."
            actions={quickActions}
            onActionClick={sendMessageWithContext}
            disabled={isStreaming || !sessionId}
          />
        ) : (
          <MessageList
            messages={messages}
            streamingChunks={streamingChunks}
            isStreaming={isStreaming}
            onSendMessage={sendMessageWithContext}
            className="flex-1"
          />
        )}
      </div>

      {/* Input - Always visible */}
      <div className="fixed bottom-0 left-0 right-0 z-50 border-t border-zinc-200 bg-white p-4">
        <ChatInput
          onSend={sendMessageWithContext}
          disabled={isStreaming || !sessionId}
          placeholder={
            sessionId ? "Type your message..." : "Creating session..."
          }
        />
      </div>
    </div>
  );
}
@@ -1,14 +1,14 @@
import { toast } from "sonner";
import type { StreamChunk } from "@/app/(platform)/chat/useChatStream";
import { StreamChunk } from "../../useChatStream";
import type { HandlerDependencies } from "./useChatContainer.handlers";
import {
  handleError,
  handleLoginNeeded,
  handleStreamEnd,
  handleTextChunk,
  handleTextEnded,
  handleToolCallStart,
  handleToolResponse,
  handleLoginNeeded,
  handleStreamEnd,
  handleError,
} from "./useChatContainer.handlers";

export function createStreamEventDispatcher(
@@ -1,5 +1,24 @@
import type { ChatMessageData } from "@/app/(platform)/chat/components/ChatMessage/useChatMessage";
import type { ToolResult } from "@/types/chat";
import type { ChatMessageData } from "../ChatMessage/useChatMessage";

export function removePageContext(content: string): string {
  // Remove "Page URL: ..." pattern at start of line (case insensitive, handles various formats)
  let cleaned = content.replace(/^\s*Page URL:\s*[^\n\r]*/gim, "");

  // Find "User Message:" marker at start of line to preserve the actual user message
  const userMessageMatch = cleaned.match(/^\s*User Message:\s*([\s\S]*)$/im);
  if (userMessageMatch) {
    // If we found "User Message:", extract everything after it
    cleaned = userMessageMatch[1];
  } else {
    // If no "User Message:" marker, remove "Page Content:" and everything after it at start of line
    cleaned = cleaned.replace(/^\s*Page Content:[\s\S]*$/gim, "");
  }

  // Clean up extra whitespace and newlines
  cleaned = cleaned.replace(/\n\s*\n\s*\n+/g, "\n\n").trim();
  return cleaned;
}

export function createUserMessage(content: string): ChatMessageData {
  return {
@@ -63,6 +82,7 @@ export function isAgentArray(value: unknown): value is Array<{
  name: string;
  description: string;
  version?: number;
  image_url?: string;
}> {
  if (!Array.isArray(value)) {
    return false;
@@ -77,7 +97,8 @@ export function isAgentArray(value: unknown): value is Array<{
      typeof item.name === "string" &&
      "description" in item &&
      typeof item.description === "string" &&
      (!("version" in item) || typeof item.version === "number"),
      (!("version" in item) || typeof item.version === "number") &&
      (!("image_url" in item) || typeof item.image_url === "string"),
  );
}

@@ -232,6 +253,7 @@ export function isSetupInfo(value: unknown): value is {

export function extractCredentialsNeeded(
  parsedResult: Record<string, unknown>,
  toolName: string = "run_agent",
): ChatMessageData | null {
  try {
    const setupInfo = parsedResult?.setup_info as
@@ -244,7 +266,7 @@
      | Record<string, Record<string, unknown>>
      | undefined;
    if (missingCreds && Object.keys(missingCreds).length > 0) {
      const agentName = (setupInfo?.agent_name as string) || "this agent";
      const agentName = (setupInfo?.agent_name as string) || "this block";
      const credentials = Object.values(missingCreds).map((credInfo) => ({
        provider: (credInfo.provider as string) || "unknown",
        providerName:
@@ -264,7 +286,7 @@
      }));
      return {
        type: "credentials_needed",
        toolName: "run_agent",
        toolName,
        credentials,
        message: `To run ${agentName}, you need to add ${credentials.length === 1 ? "credentials" : `${credentials.length} credentials`}.`,
        agentName,
@@ -277,3 +299,92 @@
    return null;
  }
}

export function extractInputsNeeded(
  parsedResult: Record<string, unknown>,
  toolName: string = "run_agent",
): ChatMessageData | null {
  try {
    const setupInfo = parsedResult?.setup_info as
      | Record<string, unknown>
      | undefined;
    const requirements = setupInfo?.requirements as
      | Record<string, unknown>
      | undefined;
    const inputs = requirements?.inputs as
      | Array<Record<string, unknown>>
      | undefined;
    const credentials = requirements?.credentials as
      | Array<Record<string, unknown>>
      | undefined;

    if (!inputs || inputs.length === 0) {
      return null;
    }

    const agentName = (setupInfo?.agent_name as string) || "this agent";
    const agentId = parsedResult?.graph_id as string | undefined;
    const graphVersion = parsedResult?.graph_version as number | undefined;

    const properties: Record<string, any> = {};
    const requiredProps: string[] = [];
    inputs.forEach((input) => {
      const name = input.name as string;
      if (name) {
        properties[name] = {
          title: input.name as string,
          description: (input.description as string) || "",
          type: (input.type as string) || "string",
          default: input.default,
          enum: input.options,
          format: input.format,
        };
        if ((input.required as boolean) === true) {
          requiredProps.push(name);
        }
      }
    });

    const inputSchema: Record<string, any> = {
      type: "object",
      properties,
    };
    if (requiredProps.length > 0) {
      inputSchema.required = requiredProps;
    }

    const credentialsSchema: Record<string, any> = {};
    if (credentials && credentials.length > 0) {
      credentials.forEach((cred) => {
        const id = cred.id as string;
        if (id) {
          credentialsSchema[id] = {
            type: "object",
            properties: {},
            credentials_provider: [cred.provider as string],
            credentials_types: [(cred.type as string) || "api_key"],
            credentials_scopes: cred.scopes as string[] | undefined,
          };
        }
      });
    }

    return {
      type: "inputs_needed",
      toolName,
      agentName,
      agentId,
      graphVersion,
      inputSchema,
      credentialsSchema:
        Object.keys(credentialsSchema).length > 0
          ? credentialsSchema
          : undefined,
      message: `Please provide the required inputs to run ${agentName}.`,
      timestamp: new Date(),
    };
  } catch (err) {
    console.error("Failed to extract inputs from setup info:", err);
    return null;
  }
}
@@ -1,13 +1,18 @@
import type { Dispatch, SetStateAction, MutableRefObject } from "react";
import type { StreamChunk } from "@/app/(platform)/chat/useChatStream";
import type { ChatMessageData } from "@/app/(platform)/chat/components/ChatMessage/useChatMessage";
import { parseToolResponse, extractCredentialsNeeded } from "./helpers";
import type { Dispatch, MutableRefObject, SetStateAction } from "react";
import { StreamChunk } from "../../useChatStream";
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
import {
  extractCredentialsNeeded,
  extractInputsNeeded,
  parseToolResponse,
} from "./helpers";

export interface HandlerDependencies {
  setHasTextChunks: Dispatch<SetStateAction<boolean>>;
  setStreamingChunks: Dispatch<SetStateAction<string[]>>;
  streamingChunksRef: MutableRefObject<string[]>;
  setMessages: Dispatch<SetStateAction<ChatMessageData[]>>;
  setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
  sessionId: string;
}

@@ -100,11 +105,18 @@ export function handleToolResponse(
    parsedResult = null;
  }
  if (
    chunk.tool_name === "run_agent" &&
    (chunk.tool_name === "run_agent" || chunk.tool_name === "run_block") &&
    chunk.success &&
    parsedResult?.type === "setup_requirements"
  ) {
    const credentialsMessage = extractCredentialsNeeded(parsedResult);
    const inputsMessage = extractInputsNeeded(parsedResult, chunk.tool_name);
    if (inputsMessage) {
      deps.setMessages((prev) => [...prev, inputsMessage]);
    }
    const credentialsMessage = extractCredentialsNeeded(
      parsedResult,
      chunk.tool_name,
    );
    if (credentialsMessage) {
      deps.setMessages((prev) => [...prev, credentialsMessage]);
    }
@@ -197,10 +209,15 @@ export function handleStreamEnd(
  deps.setStreamingChunks([]);
  deps.streamingChunksRef.current = [];
  deps.setHasTextChunks(false);
  deps.setIsStreamingInitiated(false);
  console.log("[Stream End] Stream complete, messages in local state");
}

export function handleError(chunk: StreamChunk, _deps: HandlerDependencies) {
export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
  const errorMessage = chunk.message || chunk.content || "An error occurred";
  console.error("Stream error:", errorMessage);
  deps.setIsStreamingInitiated(false);
  deps.setHasTextChunks(false);
  deps.setStreamingChunks([]);
  deps.streamingChunksRef.current = [];
}
@@ -0,0 +1,206 @@
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
import { useCallback, useMemo, useRef, useState } from "react";
import { toast } from "sonner";
import { useChatStream } from "../../useChatStream";
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
import {
  createUserMessage,
  filterAuthMessages,
  isToolCallArray,
  isValidMessage,
  parseToolResponse,
  removePageContext,
} from "./helpers";

interface Args {
  sessionId: string | null;
  initialMessages: SessionDetailResponse["messages"];
}

export function useChatContainer({ sessionId, initialMessages }: Args) {
  const [messages, setMessages] = useState<ChatMessageData[]>([]);
  const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
  const [hasTextChunks, setHasTextChunks] = useState(false);
  const [isStreamingInitiated, setIsStreamingInitiated] = useState(false);
  const streamingChunksRef = useRef<string[]>([]);
  const { error, sendMessage: sendStreamMessage } = useChatStream();
  const isStreaming = isStreamingInitiated || hasTextChunks;

  const allMessages = useMemo(() => {
    const processedInitialMessages: ChatMessageData[] = [];
    // Map to track tool calls by their ID so we can look up tool names for tool responses
    const toolCallMap = new Map<string, string>();

    for (const msg of initialMessages) {
      if (!isValidMessage(msg)) {
        console.warn("Invalid message structure from backend:", msg);
        continue;
      }

      let content = String(msg.content || "");
      const role = String(msg.role || "assistant").toLowerCase();
      const toolCalls = msg.tool_calls;
      const timestamp = msg.timestamp
        ? new Date(msg.timestamp as string)
        : undefined;

      // Remove page context from user messages when loading existing sessions
      if (role === "user") {
        content = removePageContext(content);
        // Skip user messages that become empty after removing page context
        if (!content.trim()) {
          continue;
        }
        processedInitialMessages.push({
          type: "message",
          role: "user",
          content,
          timestamp,
        });
        continue;
      }

      // Handle assistant messages first (before tool messages) to build tool call map
      if (role === "assistant") {
        // Strip <thinking> tags from content
        content = content
          .replace(/<thinking>[\s\S]*?<\/thinking>/gi, "")
          .trim();

        // If assistant has tool calls, create tool_call messages for each
        if (toolCalls && isToolCallArray(toolCalls) && toolCalls.length > 0) {
          for (const toolCall of toolCalls) {
            const toolName = toolCall.function.name;
            const toolId = toolCall.id;
            // Store tool name for later lookup
            toolCallMap.set(toolId, toolName);

            try {
              const args = JSON.parse(toolCall.function.arguments || "{}");
              processedInitialMessages.push({
                type: "tool_call",
                toolId,
                toolName,
                arguments: args,
                timestamp,
              });
            } catch (err) {
              console.warn("Failed to parse tool call arguments:", err);
              processedInitialMessages.push({
                type: "tool_call",
                toolId,
                toolName,
                arguments: {},
                timestamp,
              });
            }
          }
          // Only add assistant message if there's content after stripping thinking tags
          if (content.trim()) {
            processedInitialMessages.push({
              type: "message",
              role: "assistant",
              content,
              timestamp,
            });
          }
        } else if (content.trim()) {
          // Assistant message without tool calls, but with content
          processedInitialMessages.push({
            type: "message",
            role: "assistant",
            content,
            timestamp,
          });
        }
        continue;
      }

      // Handle tool messages - look up tool name from tool call map
      if (role === "tool") {
        const toolCallId = (msg.tool_call_id as string) || "";
        const toolName = toolCallMap.get(toolCallId) || "unknown";
        const toolResponse = parseToolResponse(
          content,
          toolCallId,
          toolName,
          timestamp,
        );
        if (toolResponse) {
          processedInitialMessages.push(toolResponse);
        }
        continue;
      }

      // Handle other message types (system, etc.)
      if (content.trim()) {
        processedInitialMessages.push({
          type: "message",
          role: role as "user" | "assistant" | "system",
          content,
          timestamp,
        });
      }
    }

    return [...processedInitialMessages, ...messages];
  }, [initialMessages, messages]);

  const sendMessage = useCallback(
    async function sendMessage(
      content: string,
      isUserMessage: boolean = true,
      context?: { url: string; content: string },
    ) {
      if (!sessionId) {
        console.error("Cannot send message: no session ID");
        return;
      }
      if (isUserMessage) {
        const userMessage = createUserMessage(content);
        setMessages((prev) => [...filterAuthMessages(prev), userMessage]);
      } else {
        setMessages((prev) => filterAuthMessages(prev));
      }
      setStreamingChunks([]);
      streamingChunksRef.current = [];
      setHasTextChunks(false);
      setIsStreamingInitiated(true);
      const dispatcher = createStreamEventDispatcher({
        setHasTextChunks,
        setStreamingChunks,
        streamingChunksRef,
        setMessages,
        sessionId,
        setIsStreamingInitiated,
      });
      try {
        await sendStreamMessage(
          sessionId,
          content,
          dispatcher,
          isUserMessage,
          context,
        );
      } catch (err) {
        console.error("Failed to send message:", err);
        setIsStreamingInitiated(false);
        const errorMessage =
          err instanceof Error ? err.message : "Failed to send message";
        toast.error("Failed to send message", {
          description: errorMessage,
        });
      }
    },
    [sessionId, sendStreamMessage],
  );

  return {
    messages: allMessages,
    streamingChunks,
    isStreaming,
    error,
    sendMessage,
  };
}
@@ -0,0 +1,149 @@
import { Text } from "@/components/atoms/Text/Text";
import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput";
import type { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api";
import { cn } from "@/lib/utils";
import { CheckIcon, RobotIcon, WarningIcon } from "@phosphor-icons/react";
import { useEffect, useRef } from "react";
import { useChatCredentialsSetup } from "./useChatCredentialsSetup";

export interface CredentialInfo {
  provider: string;
  providerName: string;
  credentialType: "api_key" | "oauth2" | "user_password" | "host_scoped";
  title: string;
  scopes?: string[];
}

interface Props {
  credentials: CredentialInfo[];
  agentName?: string;
  message: string;
  onAllCredentialsComplete: () => void;
  onCancel: () => void;
  className?: string;
}

function createSchemaFromCredentialInfo(
  credential: CredentialInfo,
): BlockIOCredentialsSubSchema {
  return {
    type: "object",
    properties: {},
    credentials_provider: [credential.provider],
    credentials_types: [credential.credentialType],
    credentials_scopes: credential.scopes,
    discriminator: undefined,
    discriminator_mapping: undefined,
    discriminator_values: undefined,
  };
}

export function ChatCredentialsSetup({
  credentials,
  agentName: _agentName,
  message,
  onAllCredentialsComplete,
  onCancel: _onCancel,
}: Props) {
  const { selectedCredentials, isAllComplete, handleCredentialSelect } =
    useChatCredentialsSetup(credentials);

  // Track if we've already called completion to prevent double calls
  const hasCalledCompleteRef = useRef(false);

  // Reset the completion flag when credentials change (new credential setup flow)
  useEffect(
    function resetCompletionFlag() {
      hasCalledCompleteRef.current = false;
    },
    [credentials],
  );

  // Auto-call completion when all credentials are configured
  useEffect(
    function autoCompleteWhenReady() {
      if (isAllComplete && !hasCalledCompleteRef.current) {
        hasCalledCompleteRef.current = true;
        onAllCredentialsComplete();
      }
    },
    [isAllComplete, onAllCredentialsComplete],
  );

  return (
    <div className="group relative flex w-full justify-start gap-3 px-4 py-3">
      <div className="flex w-full max-w-3xl gap-3">
        <div className="flex-shrink-0">
          <div className="flex h-7 w-7 items-center justify-center rounded-lg bg-indigo-500">
            <RobotIcon className="h-4 w-4 text-indigo-50" />
          </div>
        </div>

        <div className="flex min-w-0 flex-1 flex-col">
          <div className="group relative min-w-20 overflow-hidden rounded-xl border border-slate-100 bg-slate-50/20 px-6 py-2.5 text-sm leading-relaxed backdrop-blur-xl">
            <div className="absolute inset-0 bg-gradient-to-br from-slate-200/20 via-slate-300/10 to-transparent" />
            <div className="relative z-10 space-y-3 text-slate-900">
              <div>
                <Text variant="h4" className="mb-1 text-slate-900">
                  Credentials Required
                </Text>
                <Text variant="small" className="text-slate-600">
                  {message}
                </Text>
              </div>

              <div className="space-y-3">
                {credentials.map((cred, index) => {
                  const schema = createSchemaFromCredentialInfo(cred);
                  const isSelected = !!selectedCredentials[cred.provider];

                  return (
                    <div
                      key={`${cred.provider}-${index}`}
                      className={cn(
                        "relative rounded-lg border p-3",
                        isSelected
                          ? "border-green-500 bg-green-50/50"
                          : "border-slate-200 bg-white/50",
                      )}
                    >
                      <div className="mb-2 flex items-center gap-2">
                        {isSelected ? (
                          <CheckIcon
                            size={16}
                            className="text-green-500"
                            weight="bold"
                          />
                        ) : (
                          <WarningIcon
                            size={16}
                            className="text-slate-500"
                            weight="bold"
                          />
                        )}
                        <Text
                          variant="small"
                          className="font-semibold text-slate-900"
                        >
                          {cred.providerName}
                        </Text>
                      </div>

                      <CredentialsInput
                        schema={schema}
                        selectedCredentials={selectedCredentials[cred.provider]}
                        onSelectCredentials={(credMeta) =>
                          handleCredentialSelect(cred.provider, credMeta)
                        }
                      />
                    </div>
                  );
                })}
              </div>
            </div>
          </div>
        </div>
      </div>
    </div>
  );
}
@@ -0,0 +1,64 @@
import { Input } from "@/components/atoms/Input/Input";
import { cn } from "@/lib/utils";
import { ArrowUpIcon } from "@phosphor-icons/react";
import { useChatInput } from "./useChatInput";

export interface ChatInputProps {
  onSend: (message: string) => void;
  disabled?: boolean;
  placeholder?: string;
  className?: string;
}

export function ChatInput({
  onSend,
  disabled = false,
  placeholder = "Type your message...",
  className,
}: ChatInputProps) {
  const inputId = "chat-input";
  const { value, setValue, handleKeyDown, handleSend } = useChatInput({
    onSend,
    disabled,
    maxRows: 5,
    inputId,
  });

  return (
    <div className={cn("relative flex-1", className)}>
      <Input
        id={inputId}
        label="Chat message input"
        hideLabel
        type="textarea"
        value={value}
        onChange={(e) => setValue(e.target.value)}
        onKeyDown={handleKeyDown}
        placeholder={placeholder}
        disabled={disabled}
        rows={1}
        wrapperClassName="mb-0 relative"
        className="pr-12"
      />
      <span id="chat-input-hint" className="sr-only">
        Press Enter to send, Shift+Enter for new line
      </span>

      <button
        onClick={handleSend}
        disabled={disabled || !value.trim()}
        className={cn(
          "absolute right-3 top-1/2 flex h-8 w-8 -translate-y-1/2 items-center justify-center rounded-full",
          "border border-zinc-800 bg-zinc-800 text-white",
          "hover:border-zinc-900 hover:bg-zinc-900",
          "disabled:border-zinc-200 disabled:bg-zinc-200 disabled:text-white disabled:opacity-50",
          "transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-neutral-950",
          "disabled:pointer-events-none",
        )}
        aria-label="Send message"
      >
        <ArrowUpIcon className="h-3 w-3" weight="bold" />
      </button>
    </div>
  );
}
@@ -1,21 +1,22 @@
import { KeyboardEvent, useCallback, useState, useRef, useEffect } from "react";
import { KeyboardEvent, useCallback, useEffect, useState } from "react";

interface UseChatInputArgs {
  onSend: (message: string) => void;
  disabled?: boolean;
  maxRows?: number;
  inputId?: string;
}

export function useChatInput({
  onSend,
  disabled = false,
  maxRows = 5,
  inputId = "chat-input",
}: UseChatInputArgs) {
  const [value, setValue] = useState("");
  const textareaRef = useRef<HTMLTextAreaElement>(null);

  useEffect(() => {
    const textarea = textareaRef.current;
    const textarea = document.getElementById(inputId) as HTMLTextAreaElement;
    if (!textarea) return;
    textarea.style.height = "auto";
    const lineHeight = parseInt(
@@ -27,23 +28,25 @@ export function useChatInput({
    textarea.style.height = `${newHeight}px`;
    textarea.style.overflowY =
      textarea.scrollHeight > maxHeight ? "auto" : "hidden";
  }, [value, maxRows]);
  }, [value, maxRows, inputId]);

  const handleSend = useCallback(() => {
    if (disabled || !value.trim()) return;
    onSend(value.trim());
    setValue("");
    if (textareaRef.current) {
      textareaRef.current.style.height = "auto";
    const textarea = document.getElementById(inputId) as HTMLTextAreaElement;
    if (textarea) {
      textarea.style.height = "auto";
    }
  }, [value, onSend, disabled]);
  }, [value, onSend, disabled, inputId]);

  const handleKeyDown = useCallback(
    (event: KeyboardEvent<HTMLTextAreaElement>) => {
    (event: KeyboardEvent<HTMLInputElement | HTMLTextAreaElement>) => {
      if (event.key === "Enter" && !event.shiftKey) {
        event.preventDefault();
        handleSend();
      }
      // Shift+Enter allows default behavior (new line) - no need to handle explicitly
    },
    [handleSend],
  );
@@ -53,6 +56,5 @@ export function useChatInput({
    setValue,
    handleKeyDown,
    handleSend,
    textareaRef,
  };
}
@@ -0,0 +1,19 @@
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { cn } from "@/lib/utils";

export interface ChatLoadingStateProps {
  message?: string;
  className?: string;
}

export function ChatLoadingState({ className }: ChatLoadingStateProps) {
  return (
    <div
      className={cn("flex flex-1 items-center justify-center p-6", className)}
    >
      <div className="flex flex-col items-center gap-4 text-center">
        <LoadingSpinner />
      </div>
    </div>
  );
}
@@ -0,0 +1,341 @@
"use client";

import { useGetV2GetUserProfile } from "@/app/api/__generated__/endpoints/store/store";
import Avatar, {
  AvatarFallback,
  AvatarImage,
} from "@/components/atoms/Avatar/Avatar";
import { Button } from "@/components/atoms/Button/Button";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { cn } from "@/lib/utils";
import {
  ArrowClockwise,
  CheckCircleIcon,
  CheckIcon,
  CopyIcon,
  RobotIcon,
} from "@phosphor-icons/react";
import { useRouter } from "next/navigation";
import { useCallback, useState } from "react";
import { getToolActionPhrase } from "../../helpers";
import { AgentCarouselMessage } from "../AgentCarouselMessage/AgentCarouselMessage";
import { AuthPromptWidget } from "../AuthPromptWidget/AuthPromptWidget";
import { ChatCredentialsSetup } from "../ChatCredentialsSetup/ChatCredentialsSetup";
import { ExecutionStartedMessage } from "../ExecutionStartedMessage/ExecutionStartedMessage";
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
import { MessageBubble } from "../MessageBubble/MessageBubble";
import { NoResultsMessage } from "../NoResultsMessage/NoResultsMessage";
import { ToolCallMessage } from "../ToolCallMessage/ToolCallMessage";
import { ToolResponseMessage } from "../ToolResponseMessage/ToolResponseMessage";
import { useChatMessage, type ChatMessageData } from "./useChatMessage";
export interface ChatMessageProps {
  message: ChatMessageData;
  className?: string;
  onDismissLogin?: () => void;
  onDismissCredentials?: () => void;
  onSendMessage?: (content: string, isUserMessage?: boolean) => void;
  agentOutput?: ChatMessageData;
}

export function ChatMessage({
  message,
  className,
  onDismissCredentials,
  onSendMessage,
  agentOutput,
}: ChatMessageProps) {
  const { user } = useSupabase();
  const router = useRouter();
  const [copied, setCopied] = useState(false);
  const {
    isUser,
    isToolCall,
    isToolResponse,
    isLoginNeeded,
    isCredentialsNeeded,
  } = useChatMessage(message);

  const { data: profile } = useGetV2GetUserProfile({
    query: {
      select: (res) => (res.status === 200 ? res.data : null),
      enabled: isUser && !!user,
      queryKey: ["/api/store/profile", user?.id],
    },
  });

  const handleAllCredentialsComplete = useCallback(
    function handleAllCredentialsComplete() {
      // Send a user message that explicitly asks to retry the setup
      // This ensures the LLM calls get_required_setup_info again and proceeds with execution
      if (onSendMessage) {
        onSendMessage(
          "I've configured the required credentials. Please check if everything is ready and proceed with setting up the agent.",
        );
      }
      // Optionally dismiss the credentials prompt
      if (onDismissCredentials) {
        onDismissCredentials();
      }
    },
    [onSendMessage, onDismissCredentials],
  );

  function handleCancelCredentials() {
    // Dismiss the credentials prompt
    if (onDismissCredentials) {
      onDismissCredentials();
    }
  }

  const handleCopy = useCallback(async () => {
    if (message.type !== "message") return;

    try {
      await navigator.clipboard.writeText(message.content);
      setCopied(true);
      setTimeout(() => setCopied(false), 2000);
    } catch (error) {
      console.error("Failed to copy:", error);
    }
  }, [message]);

  const handleTryAgain = useCallback(() => {
    if (message.type !== "message" || !onSendMessage) return;
    onSendMessage(message.content, message.role === "user");
  }, [message, onSendMessage]);

  const handleViewExecution = useCallback(() => {
    if (message.type === "execution_started" && message.libraryAgentLink) {
      router.push(message.libraryAgentLink);
    }
  }, [message, router]);

  // Render credentials needed messages
  if (isCredentialsNeeded && message.type === "credentials_needed") {
    return (
      <ChatCredentialsSetup
        credentials={message.credentials}
        agentName={message.agentName}
        message={message.message}
        onAllCredentialsComplete={handleAllCredentialsComplete}
        onCancel={handleCancelCredentials}
        className={className}
      />
    );
  }

  // Render login needed messages
  if (isLoginNeeded && message.type === "login_needed") {
    // If user is already logged in, show success message instead of auth prompt
    if (user) {
      return (
        <div className={cn("px-4 py-2", className)}>
          <div className="my-4 overflow-hidden rounded-lg border border-green-200 bg-gradient-to-br from-green-50 to-emerald-50">
            <div className="px-6 py-4">
              <div className="flex items-center gap-3">
                <div className="flex h-10 w-10 items-center justify-center rounded-full bg-green-600">
                  <CheckCircleIcon
                    size={20}
                    weight="fill"
                    className="text-white"
                  />
                </div>
                <div>
                  <h3 className="text-lg font-semibold text-neutral-900">
                    Successfully Authenticated
                  </h3>
                  <p className="text-sm text-neutral-600">
                    You're now signed in and ready to continue
                  </p>
                </div>
              </div>
            </div>
          </div>
        </div>
      );
    }

    // Show auth prompt if not logged in
    return (
      <div className={cn("px-4 py-2", className)}>
        <AuthPromptWidget
          message={message.message}
          sessionId={message.sessionId}
          agentInfo={message.agentInfo}
        />
      </div>
    );
  }

  // Render tool call messages
  if (isToolCall && message.type === "tool_call") {
    return (
      <div className={cn("px-4 py-2", className)}>
        <ToolCallMessage toolName={message.toolName} />
      </div>
    );
  }

  // Render no_results messages - use dedicated component, not ToolResponseMessage
  if (message.type === "no_results") {
    return (
      <div className={cn("px-4 py-2", className)}>
        <NoResultsMessage
          message={message.message}
          suggestions={message.suggestions}
        />
      </div>
    );
  }

  // Render agent_carousel messages - use dedicated component, not ToolResponseMessage
  if (message.type === "agent_carousel") {
    return (
      <div className={cn("px-4 py-2", className)}>
        <AgentCarouselMessage
          agents={message.agents}
          totalCount={message.totalCount}
        />
      </div>
    );
  }

  // Render execution_started messages - use dedicated component, not ToolResponseMessage
  if (message.type === "execution_started") {
    return (
      <div className={cn("px-4 py-2", className)}>
        <ExecutionStartedMessage
          executionId={message.executionId}
          agentName={message.agentName}
          message={message.message}
          onViewExecution={
            message.libraryAgentLink ? handleViewExecution : undefined
          }
        />
      </div>
    );
  }

  // Render tool response messages (but skip agent_output if it's being rendered inside assistant message)
  if (isToolResponse && message.type === "tool_response") {
    // Check if this is an agent_output that should be rendered inside assistant message
    if (message.result) {
      let parsedResult: Record<string, unknown> | null = null;
      try {
        parsedResult =
          typeof message.result === "string"
            ? JSON.parse(message.result)
            : (message.result as Record<string, unknown>);
      } catch {
        parsedResult = null;
      }
      if (parsedResult?.type === "agent_output") {
        // Skip rendering - this will be rendered inside the assistant message
        return null;
      }
    }

    return (
      <div className={cn("px-4 py-2", className)}>
        <ToolResponseMessage
          toolName={getToolActionPhrase(message.toolName)}
          result={message.result}
        />
      </div>
    );
  }

  // Render regular chat messages
  if (message.type === "message") {
    return (
      <div
        className={cn(
          "group relative flex w-full gap-3 px-4 py-3",
          isUser ? "justify-end" : "justify-start",
          className,
        )}
      >
        <div className="flex w-full max-w-3xl gap-3">
          {!isUser && (
            <div className="flex-shrink-0">
              <div className="flex h-7 w-7 items-center justify-center rounded-lg bg-indigo-500">
                <RobotIcon className="h-4 w-4 text-indigo-50" />
              </div>
            </div>
          )}

          <div
            className={cn(
              "flex min-w-0 flex-1 flex-col",
              isUser && "items-end",
            )}
          >
            <MessageBubble variant={isUser ? "user" : "assistant"}>
              <MarkdownContent content={message.content} />
              {agentOutput &&
                agentOutput.type === "tool_response" &&
                !isUser && (
                  <div className="mt-4">
                    <ToolResponseMessage
                      toolName={
                        agentOutput.toolName
                          ? getToolActionPhrase(agentOutput.toolName)
                          : "Agent Output"
                      }
                      result={agentOutput.result}
                    />
                  </div>
                )}
            </MessageBubble>
            <div
              className={cn(
                "mt-1 flex gap-1",
                isUser ? "justify-end" : "justify-start",
              )}
            >
              {isUser && onSendMessage && (
                <Button
                  variant="ghost"
                  size="icon"
                  onClick={handleTryAgain}
                  aria-label="Try again"
                >
                  <ArrowClockwise className="size-3 text-neutral-500" />
                </Button>
              )}
              <Button
                variant="ghost"
                size="icon"
                onClick={handleCopy}
                aria-label="Copy message"
              >
                {copied ? (
                  <CheckIcon className="size-3 text-green-600" />
                ) : (
                  <CopyIcon className="size-3 text-neutral-500" />
                )}
              </Button>
            </div>
          </div>

          {isUser && (
            <div className="flex-shrink-0">
              <Avatar className="h-7 w-7">
                <AvatarImage
                  src={profile?.avatar_url ?? ""}
                  alt={profile?.username ?? "User"}
                />
                <AvatarFallback className="rounded-lg bg-neutral-200 text-neutral-600">
                  {profile?.username?.charAt(0)?.toUpperCase() || "U"}
                </AvatarFallback>
              </Avatar>
            </div>
          )}
        </div>
      </div>
    );
  }

  // Fallback for unknown message types
  return null;
}
@@ -1,5 +1,5 @@
import { formatDistanceToNow } from "date-fns";
import type { ToolArguments, ToolResult } from "@/types/chat";
import { formatDistanceToNow } from "date-fns";

export type ChatMessageData =
  | {
@@ -65,6 +65,7 @@ export type ChatMessageData
        name: string;
        description: string;
        version?: number;
        image_url?: string;
      }>;
      totalCount?: number;
      timestamp?: string | Date;
@@ -77,6 +78,17 @@ export type ChatMessageData
      message?: string;
      libraryAgentLink?: string;
      timestamp?: string | Date;
    }
  | {
      type: "inputs_needed";
      toolName: string;
      agentName?: string;
      agentId?: string;
      graphVersion?: number;
      inputSchema: Record<string, any>;
      credentialsSchema?: Record<string, any>;
      message: string;
      timestamp?: string | Date;
    };

export function useChatMessage(message: ChatMessageData) {
@@ -96,5 +108,6 @@ export function useChatMessage(message: ChatMessageData) {
    isNoResults: message.type === "no_results",
    isAgentCarousel: message.type === "agent_carousel",
    isExecutionStarted: message.type === "execution_started",
    isInputsNeeded: message.type === "inputs_needed",
  };
}
@@ -1,8 +1,7 @@
import React from "react";
import { Text } from "@/components/atoms/Text/Text";
import { Button } from "@/components/atoms/Button/Button";
import { CheckCircle, Play, ArrowSquareOut } from "@phosphor-icons/react";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
import { ArrowSquareOut, CheckCircle, Play } from "@phosphor-icons/react";

export interface ExecutionStartedMessageProps {
  executionId: string;
@@ -22,7 +21,7 @@ export function ExecutionStartedMessage({
  return (
    <div
      className={cn(
        "mx-4 my-2 flex flex-col gap-4 rounded-lg border border-green-200 bg-green-50 p-6 dark:border-green-900 dark:bg-green-950",
        "mx-4 my-2 flex flex-col gap-4 rounded-lg border border-green-200 bg-green-50 p-6",
        className,
      )}
    >
@@ -32,48 +31,33 @@
          <CheckCircle size={24} weight="bold" className="text-white" />
        </div>
        <div className="flex-1">
          <Text
            variant="h3"
            className="mb-1 text-green-900 dark:text-green-100"
          >
          <Text variant="h3" className="mb-1 text-green-900">
            Execution Started
          </Text>
          <Text variant="body" className="text-green-700 dark:text-green-300">
          <Text variant="body" className="text-green-700">
            {message}
          </Text>
        </div>
      </div>

      {/* Details */}
      <div className="rounded-md bg-green-100 p-4 dark:bg-green-900">
      <div className="rounded-md bg-green-100 p-4">
        <div className="space-y-2">
          {agentName && (
            <div className="flex items-center justify-between">
              <Text
                variant="small"
                className="font-semibold text-green-900 dark:text-green-100"
              >
              <Text variant="small" className="font-semibold text-green-900">
                Agent:
              </Text>
              <Text
                variant="body"
                className="text-green-800 dark:text-green-200"
              >
              <Text variant="body" className="text-green-800">
                {agentName}
              </Text>
            </div>
          )}
          <div className="flex items-center justify-between">
            <Text
              variant="small"
              className="font-semibold text-green-900 dark:text-green-100"
            >
            <Text variant="small" className="font-semibold text-green-900">
              Execution ID:
            </Text>
            <Text
              variant="small"
              className="font-mono text-green-800 dark:text-green-200"
            >
            <Text variant="small" className="font-mono text-green-800">
              {executionId.slice(0, 16)}...
            </Text>
          </div>
@@ -94,7 +78,7 @@
        </div>
      )}

      <div className="flex items-center gap-2 text-green-600 dark:text-green-400">
      <div className="flex items-center gap-2 text-green-600">
        <Play size={16} weight="fill" />
        <Text variant="small">
          Your agent is now running. You can monitor its progress in the monitor
@@ -1,9 +1,9 @@
"use client";

import { cn } from "@/lib/utils";
import React from "react";
import ReactMarkdown from "react-markdown";
import remarkGfm from "remark-gfm";
import { cn } from "@/lib/utils";

interface MarkdownContentProps {
  content: string;
@@ -41,7 +41,7 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
          if (isInline) {
            return (
              <code
                className="rounded bg-zinc-100 px-1.5 py-0.5 font-mono text-sm text-zinc-800 dark:bg-zinc-800 dark:text-zinc-200"
                className="rounded bg-zinc-100 px-1.5 py-0.5 font-mono text-sm text-zinc-800"
                {...props}
              >
                {children}
@@ -49,17 +49,14 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
            );
          }
          return (
            <code
              className="font-mono text-sm text-zinc-100 dark:text-zinc-200"
              {...props}
            >
            <code className="font-mono text-sm text-zinc-100" {...props}>
              {children}
            </code>
          );
        },
        pre: ({ children, ...props }) => (
          <pre
            className="my-2 overflow-x-auto rounded-md bg-zinc-900 p-3 dark:bg-zinc-950"
            className="my-2 overflow-x-auto rounded-md bg-zinc-900 p-3"
            {...props}
          >
            {children}
@@ -70,7 +67,7 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
            href={href}
            target="_blank"
            rel="noopener noreferrer"
            className="text-purple-600 underline decoration-1 underline-offset-2 hover:text-purple-700 dark:text-purple-400 dark:hover:text-purple-300"
            className="text-purple-600 underline decoration-1 underline-offset-2 hover:text-purple-700"
            {...props}
          >
            {children}
@@ -126,7 +123,7 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
            return (
              <input
                type="checkbox"
                className="mr-2 h-4 w-4 rounded border-zinc-300 text-purple-600 focus:ring-purple-500 disabled:cursor-not-allowed disabled:opacity-70 dark:border-zinc-600"
                className="mr-2 h-4 w-4 rounded border-zinc-300 text-purple-600 focus:ring-purple-500 disabled:cursor-not-allowed disabled:opacity-70"
                disabled
                {...props}
              />
@@ -136,57 +133,42 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
        },
        blockquote: ({ children, ...props }) => (
          <blockquote
            className="my-2 border-l-4 border-zinc-300 pl-3 italic text-zinc-700 dark:border-zinc-600 dark:text-zinc-300"
            className="my-2 border-l-4 border-zinc-300 pl-3 italic text-zinc-700"
            {...props}
          >
            {children}
          </blockquote>
        ),
        h1: ({ children, ...props }) => (
          <h1
            className="my-2 text-xl font-bold text-zinc-900 dark:text-zinc-100"
            {...props}
          >
          <h1 className="my-2 text-xl font-bold text-zinc-900" {...props}>
            {children}
          </h1>
        ),
        h2: ({ children, ...props }) => (
          <h2
            className="my-2 text-lg font-semibold text-zinc-800 dark:text-zinc-200"
            {...props}
          >
          <h2 className="my-2 text-lg font-semibold text-zinc-800" {...props}>
            {children}
          </h2>
        ),
        h3: ({ children, ...props }) => (
          <h3
            className="my-1 text-base font-semibold text-zinc-800 dark:text-zinc-200"
            className="my-1 text-base font-semibold text-zinc-800"
            {...props}
          >
            {children}
          </h3>
        ),
        h4: ({ children, ...props }) => (
          <h4
            className="my-1 text-sm font-medium text-zinc-700 dark:text-zinc-300"
            {...props}
          >
          <h4 className="my-1 text-sm font-medium text-zinc-700" {...props}>
            {children}
          </h4>
        ),
        h5: ({ children, ...props }) => (
          <h5
            className="my-1 text-sm font-medium text-zinc-700 dark:text-zinc-300"
            {...props}
          >
          <h5 className="my-1 text-sm font-medium text-zinc-700" {...props}>
            {children}
          </h5>
        ),
        h6: ({ children, ...props }) => (
          <h6
            className="my-1 text-xs font-medium text-zinc-600 dark:text-zinc-400"
            {...props}
          >
          <h6 className="my-1 text-xs font-medium text-zinc-600" {...props}>
            {children}
          </h6>
        ),
@@ -196,15 +178,12 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
          </p>
        ),
        hr: ({ ...props }) => (
          <hr
            className="my-3 border-zinc-300 dark:border-zinc-700"
            {...props}
          />
          <hr className="my-3 border-zinc-300" {...props} />
        ),
        table: ({ children, ...props }) => (
          <div className="my-2 overflow-x-auto">
            <table
              className="min-w-full divide-y divide-zinc-200 rounded border border-zinc-200 dark:divide-zinc-700 dark:border-zinc-700"
              className="min-w-full divide-y divide-zinc-200 rounded border border-zinc-200"
              {...props}
            >
              {children}
@@ -213,7 +192,7 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
        ),
        th: ({ children, ...props }) => (
          <th
            className="bg-zinc-50 px-3 py-2 text-left text-xs font-semibold text-zinc-700 dark:bg-zinc-800 dark:text-zinc-300"
            className="bg-zinc-50 px-3 py-2 text-left text-xs font-semibold text-zinc-700"
            {...props}
          >
            {children}
@@ -221,7 +200,7 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
        ),
        td: ({ children, ...props }) => (
          <td
            className="border-t border-zinc-200 px-3 py-2 text-sm dark:border-zinc-700"
            className="border-t border-zinc-200 px-3 py-2 text-sm"
            {...props}
          >
            {children}
@@ -0,0 +1,56 @@
import { cn } from "@/lib/utils";
import { ReactNode } from "react";

export interface MessageBubbleProps {
  children: ReactNode;
  variant: "user" | "assistant";
  className?: string;
}

export function MessageBubble({
  children,
  variant,
  className,
}: MessageBubbleProps) {
  const userTheme = {
    bg: "bg-slate-900",
    border: "border-slate-800",
    gradient: "from-slate-900/30 via-slate-800/20 to-transparent",
    text: "text-slate-50",
  };

  const assistantTheme = {
    bg: "bg-slate-50/20",
    border: "border-slate-100",
    gradient: "from-slate-200/20 via-slate-300/10 to-transparent",
    text: "text-slate-900",
  };

  const theme = variant === "user" ? userTheme : assistantTheme;

  return (
    <div
      className={cn(
        "group relative min-w-20 overflow-hidden rounded-xl border px-6 py-2.5 text-sm leading-relaxed backdrop-blur-xl transition-all duration-500 ease-in-out",
        theme.bg,
        theme.border,
        variant === "user" && "text-right",
        variant === "assistant" && "text-left",
        className,
      )}
    >
      {/* Gradient flare background */}
      <div
        className={cn("absolute inset-0 bg-gradient-to-br", theme.gradient)}
      />
      <div
        className={cn(
          "relative z-10 transition-all duration-500 ease-in-out",
          theme.text,
        )}
      >
        {children}
      </div>
    </div>
  );
}
@@ -0,0 +1,121 @@
"use client";

import { cn } from "@/lib/utils";
import { ChatMessage } from "../ChatMessage/ChatMessage";
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
import { StreamingMessage } from "../StreamingMessage/StreamingMessage";
import { ThinkingMessage } from "../ThinkingMessage/ThinkingMessage";
import { useMessageList } from "./useMessageList";

export interface MessageListProps {
  messages: ChatMessageData[];
  streamingChunks?: string[];
  isStreaming?: boolean;
  className?: string;
  onStreamComplete?: () => void;
  onSendMessage?: (content: string) => void;
}

export function MessageList({
  messages,
  streamingChunks = [],
  isStreaming = false,
  className,
  onStreamComplete,
  onSendMessage,
}: MessageListProps) {
  const { messagesEndRef, messagesContainerRef } = useMessageList({
    messageCount: messages.length,
    isStreaming,
  });

  return (
    <div
      ref={messagesContainerRef}
      className={cn(
        "flex-1 overflow-y-auto",
        "scrollbar-thin scrollbar-track-transparent scrollbar-thumb-zinc-300",
        className,
      )}
    >
      <div className="mx-auto flex max-w-3xl flex-col py-4">
        {/* Render all persisted messages */}
        {messages.map((message, index) => {
          // Check if current message is an agent_output tool_response
          // and if previous message is an assistant message
          let agentOutput: ChatMessageData | undefined;

          if (message.type === "tool_response" && message.result) {
            let parsedResult: Record<string, unknown> | null = null;
            try {
              parsedResult =
                typeof message.result === "string"
                  ? JSON.parse(message.result)
                  : (message.result as Record<string, unknown>);
            } catch {
              parsedResult = null;
            }
            if (parsedResult?.type === "agent_output") {
              const prevMessage = messages[index - 1];
              if (
                prevMessage &&
                prevMessage.type === "message" &&
                prevMessage.role === "assistant"
              ) {
                // This agent output will be rendered inside the previous assistant message
                // Skip rendering this message separately
                return null;
              }
            }
          }

          // Check if next message is an agent_output tool_response to include in current assistant message
          if (message.type === "message" && message.role === "assistant") {
            const nextMessage = messages[index + 1];
            if (
              nextMessage &&
              nextMessage.type === "tool_response" &&
              nextMessage.result
            ) {
              let parsedResult: Record<string, unknown> | null = null;
              try {
                parsedResult =
                  typeof nextMessage.result === "string"
                    ? JSON.parse(nextMessage.result)
                    : (nextMessage.result as Record<string, unknown>);
              } catch {
                parsedResult = null;
              }
              if (parsedResult?.type === "agent_output") {
                agentOutput = nextMessage;
              }
            }
          }

          return (
            <ChatMessage
              key={index}
              message={message}
              onSendMessage={onSendMessage}
              agentOutput={agentOutput}
            />
          );
        })}

        {/* Render thinking message when streaming but no chunks yet */}
        {isStreaming && streamingChunks.length === 0 && <ThinkingMessage />}

        {/* Render streaming message if active */}
        {isStreaming && streamingChunks.length > 0 && (
          <StreamingMessage
            chunks={streamingChunks}
            onComplete={onStreamComplete}
          />
        )}

        {/* Invisible div to scroll to */}
        <div ref={messagesEndRef} />
      </div>
    </div>
  );
}
@@ -1,7 +1,6 @@
import React from "react";
import { Text } from "@/components/atoms/Text/Text";
import { MagnifyingGlass, X } from "@phosphor-icons/react";
import { cn } from "@/lib/utils";
import { MagnifyingGlass, X } from "@phosphor-icons/react";

export interface NoResultsMessageProps {
  message: string;
@@ -17,26 +16,26 @@ export function NoResultsMessage({
  return (
    <div
      className={cn(
        "mx-4 my-2 flex flex-col items-center gap-4 rounded-lg border border-gray-200 bg-gray-50 p-6 dark:border-gray-800 dark:bg-gray-900",
        "mx-4 my-2 flex flex-col items-center gap-4 rounded-lg border border-gray-200 bg-gray-50 p-6",
        className,
      )}
    >
      {/* Icon */}
      <div className="relative flex h-16 w-16 items-center justify-center">
        <div className="flex h-16 w-16 items-center justify-center rounded-full bg-gray-200 dark:bg-gray-700">
        <div className="flex h-16 w-16 items-center justify-center rounded-full bg-gray-200">
          <MagnifyingGlass size={32} weight="bold" className="text-gray-500" />
        </div>
        <div className="absolute -right-1 -top-1 flex h-8 w-8 items-center justify-center rounded-full bg-gray-400 dark:bg-gray-600">
        <div className="absolute -right-1 -top-1 flex h-8 w-8 items-center justify-center rounded-full bg-gray-400">
          <X size={20} weight="bold" className="text-white" />
        </div>
      </div>

      {/* Content */}
      <div className="text-center">
        <Text variant="h3" className="mb-2 text-gray-900 dark:text-gray-100">
        <Text variant="h3" className="mb-2 text-gray-900">
          No Results Found
        </Text>
        <Text variant="body" className="text-gray-700 dark:text-gray-300">
        <Text variant="body" className="text-gray-700">
          {message}
        </Text>
      </div>
@@ -44,17 +43,14 @@ export function NoResultsMessage({
      {/* Suggestions */}
      {suggestions.length > 0 && (
        <div className="w-full space-y-2">
          <Text
            variant="small"
            className="font-semibold text-gray-900 dark:text-gray-100"
          >
          <Text variant="small" className="font-semibold text-gray-900">
            Try these suggestions:
          </Text>
          <ul className="space-y-1 rounded-md bg-gray-100 p-4 dark:bg-gray-800">
          <ul className="space-y-1 rounded-md bg-gray-100 p-4">
            {suggestions.map((suggestion, index) => (
              <li
                key={index}
                className="flex items-start gap-2 text-sm text-gray-700 dark:text-gray-300"
                className="flex items-start gap-2 text-sm text-gray-700"
              >
                <span className="mt-1 text-gray-500">•</span>
                <span>{suggestion}</span>
@@ -0,0 +1,94 @@
"use client";

import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";

export interface QuickActionsWelcomeProps {
  title: string;
  description: string;
  actions: string[];
  onActionClick: (action: string) => void;
  disabled?: boolean;
  className?: string;
}

export function QuickActionsWelcome({
  title,
  description,
  actions,
  onActionClick,
  disabled = false,
  className,
}: QuickActionsWelcomeProps) {
  return (
    <div
      className={cn("flex flex-1 items-center justify-center p-8", className)}
    >
      <div className="w-full max-w-3xl">
        <div className="mb-12 text-center">
          <Text
            variant="h2"
            className="mb-3 text-2xl font-semibold text-zinc-900"
          >
            {title}
          </Text>
          <Text variant="body" className="text-zinc-500">
            {description}
          </Text>
        </div>
        <div className="grid gap-3 sm:grid-cols-2">
          {actions.map((action) => {
            // Use slate theme for all cards
            const theme = {
              bg: "bg-slate-50/10",
              border: "border-slate-100",
              hoverBg: "hover:bg-slate-50/20",
              hoverBorder: "hover:border-slate-200",
              gradient: "from-slate-200/20 via-slate-300/10 to-transparent",
              text: "text-slate-900",
              hoverText: "group-hover:text-slate-900",
            };

            return (
              <button
                key={action}
                onClick={() => onActionClick(action)}
                disabled={disabled}
                className={cn(
                  "group relative overflow-hidden rounded-xl border p-5 text-left backdrop-blur-xl",
                  "transition-all duration-200",
                  theme.bg,
                  theme.border,
                  theme.hoverBg,
                  theme.hoverBorder,
                  "hover:shadow-sm",
                  "focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-white/50 focus-visible:ring-offset-2",
                  "disabled:cursor-not-allowed disabled:opacity-50 disabled:hover:shadow-none",
                )}
              >
                {/* Gradient flare background */}
                <div
                  className={cn(
                    "absolute inset-0 bg-gradient-to-br",
                    theme.gradient,
                  )}
                />

                <Text
                  variant="body"
                  className={cn(
                    "relative z-10 font-medium",
                    theme.text,
                    theme.hoverText,
                  )}
                >
                  {action}
                </Text>
              </button>
            );
          })}
        </div>
      </div>
    </div>
  );
}
@@ -0,0 +1,136 @@
"use client";

import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
import { Text } from "@/components/atoms/Text/Text";
import { scrollbarStyles } from "@/components/styles/scrollbars";
import { cn } from "@/lib/utils";
import { X } from "@phosphor-icons/react";
import { formatDistanceToNow } from "date-fns";
import { Drawer } from "vaul";

interface SessionsDrawerProps {
  isOpen: boolean;
  onClose: () => void;
  onSelectSession: (sessionId: string) => void;
  currentSessionId?: string | null;
}

export function SessionsDrawer({
  isOpen,
  onClose,
  onSelectSession,
  currentSessionId,
}: SessionsDrawerProps) {
  const { data, isLoading } = useGetV2ListSessions(
    { limit: 100 },
    {
      query: {
        enabled: isOpen,
      },
    },
  );

  const sessions =
    data?.status === 200
      ? data.data.sessions.filter((session) => {
          // Filter out sessions without messages (sessions that were never updated)
          // If updated_at equals created_at, the session was created but never had messages
          return session.updated_at !== session.created_at;
        })
      : [];

  function handleSelectSession(sessionId: string) {
    onSelectSession(sessionId);
    onClose();
  }

  return (
    <Drawer.Root
      open={isOpen}
      onOpenChange={(open) => !open && onClose()}
      direction="right"
    >
      <Drawer.Portal>
        <Drawer.Overlay className="fixed inset-0 z-[60] bg-black/10 backdrop-blur-sm" />
        <Drawer.Content
          className={cn(
            "fixed right-0 top-0 z-[70] flex h-full w-96 flex-col border-l border-zinc-200 bg-white",
            scrollbarStyles,
          )}
        >
          <div className="shrink-0 p-4">
            <div className="flex items-center justify-between">
              <Drawer.Title className="text-lg font-semibold">
                Chat Sessions
              </Drawer.Title>
              <button
                aria-label="Close"
                onClick={onClose}
                className="flex size-8 items-center justify-center rounded hover:bg-zinc-100"
              >
                <X width="1.25rem" height="1.25rem" />
              </button>
            </div>
          </div>

          <div className="flex-1 overflow-y-auto p-4">
            {isLoading ? (
              <div className="flex items-center justify-center py-8">
                <Text variant="body" className="text-zinc-500">
                  Loading sessions...
                </Text>
              </div>
            ) : sessions.length === 0 ? (
              <div className="flex items-center justify-center py-8">
                <Text variant="body" className="text-zinc-500">
                  No sessions found
                </Text>
              </div>
            ) : (
              <div className="space-y-2">
                {sessions.map((session) => {
                  const isActive = session.id === currentSessionId;
                  const updatedAt = session.updated_at
                    ? formatDistanceToNow(new Date(session.updated_at), {
                        addSuffix: true,
                      })
                    : "";

                  return (
                    <button
                      key={session.id}
                      onClick={() => handleSelectSession(session.id)}
                      className={cn(
                        "w-full rounded-lg border p-3 text-left transition-colors",
                        isActive
                          ? "border-indigo-500 bg-zinc-50"
                          : "border-zinc-200 bg-zinc-100/50 hover:border-zinc-300 hover:bg-zinc-50",
                      )}
                    >
                      <div className="flex flex-col gap-1">
                        <Text
                          variant="body"
                          className={cn(
                            "font-medium",
                            isActive ? "text-indigo-900" : "text-zinc-900",
                          )}
                        >
                          {session.title || "Untitled Chat"}
                        </Text>
                        <div className="flex items-center gap-2 text-xs text-zinc-500">
                          <span>{session.id.slice(0, 8)}...</span>
                          {updatedAt && <span>•</span>}
                          <span>{updatedAt}</span>
                        </div>
                      </div>
                    </button>
                  );
                })}
              </div>
            )}
          </div>
        </Drawer.Content>
      </Drawer.Portal>
    </Drawer.Root>
  );
}
@@ -0,0 +1,42 @@
import { cn } from "@/lib/utils";
import { RobotIcon } from "@phosphor-icons/react";
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
import { MessageBubble } from "../MessageBubble/MessageBubble";
import { useStreamingMessage } from "./useStreamingMessage";

export interface StreamingMessageProps {
  chunks: string[];
  className?: string;
  onComplete?: () => void;
}

export function StreamingMessage({
  chunks,
  className,
  onComplete,
}: StreamingMessageProps) {
  const { displayText } = useStreamingMessage({ chunks, onComplete });

  return (
    <div
      className={cn(
        "group relative flex w-full justify-start gap-3 px-4 py-3",
        className,
      )}
    >
      <div className="flex w-full max-w-3xl gap-3">
        <div className="flex-shrink-0">
          <div className="flex h-7 w-7 items-center justify-center rounded-lg bg-indigo-600">
            <RobotIcon className="h-4 w-4 text-indigo-50" />
          </div>
        </div>

        <div className="flex min-w-0 flex-1 flex-col">
          <MessageBubble variant="assistant">
            <MarkdownContent content={displayText} />
          </MessageBubble>
        </div>
      </div>
    </div>
  );
}
@@ -0,0 +1,70 @@
import { cn } from "@/lib/utils";
import { RobotIcon } from "@phosphor-icons/react";
import { useEffect, useRef, useState } from "react";
import { MessageBubble } from "../MessageBubble/MessageBubble";

export interface ThinkingMessageProps {
  className?: string;
}

export function ThinkingMessage({ className }: ThinkingMessageProps) {
  const [showSlowLoader, setShowSlowLoader] = useState(false);
  const timerRef = useRef<NodeJS.Timeout | null>(null);

  useEffect(() => {
    if (timerRef.current === null) {
      timerRef.current = setTimeout(() => {
        setShowSlowLoader(true);
      }, 8000);
    }

    return () => {
      if (timerRef.current) {
        clearTimeout(timerRef.current);
        timerRef.current = null;
      }
    };
  }, []);

  return (
    <div
      className={cn(
        "group relative flex w-full justify-start gap-3 px-4 py-3",
        className,
      )}
    >
      <div className="flex w-full max-w-3xl gap-3">
        <div className="flex-shrink-0">
          <div className="flex h-7 w-7 items-center justify-center rounded-lg bg-indigo-500">
            <RobotIcon className="h-4 w-4 text-indigo-50" />
          </div>
        </div>

        <div className="flex min-w-0 flex-1 flex-col">
          <MessageBubble variant="assistant">
            <div className="transition-all duration-500 ease-in-out">
              {showSlowLoader ? (
                <div className="flex flex-col items-center gap-3 py-2">
                  <div className="loader" style={{ flexShrink: 0 }} />
                  <p className="text-sm text-slate-700">
                    Taking a bit longer to think, please wait a moment
                  </p>
                </div>
              ) : (
                <span
                  className="inline-block bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-clip-text text-transparent"
                  style={{
                    backgroundSize: "200% 100%",
                    animation: "shimmer 2s ease-in-out infinite",
                  }}
                >
                  Thinking...
                </span>
              )}
            </div>
          </MessageBubble>
        </div>
      </div>
    </div>
  );
}
@@ -0,0 +1,24 @@
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
import { WrenchIcon } from "@phosphor-icons/react";
import { getToolActionPhrase } from "../../helpers";

export interface ToolCallMessageProps {
  toolName: string;
  className?: string;
}

export function ToolCallMessage({ toolName, className }: ToolCallMessageProps) {
  return (
    <div className={cn("flex items-center justify-center gap-2", className)}>
      <WrenchIcon
        size={14}
        weight="bold"
        className="flex-shrink-0 text-neutral-500"
      />
      <Text variant="small" className="text-neutral-500">
        {getToolActionPhrase(toolName)}...
      </Text>
    </div>
  );
}
@@ -0,0 +1,260 @@
import { Text } from "@/components/atoms/Text/Text";
import "@/components/contextual/OutputRenderers";
import {
  globalRegistry,
  OutputItem,
} from "@/components/contextual/OutputRenderers";
import { cn } from "@/lib/utils";
import type { ToolResult } from "@/types/chat";
import { WrenchIcon } from "@phosphor-icons/react";
import { getToolActionPhrase } from "../../helpers";

export interface ToolResponseMessageProps {
  toolName: string;
  result?: ToolResult;
  success?: boolean;
  className?: string;
}

export function ToolResponseMessage({
  toolName,
  result,
  success: _success = true,
  className,
}: ToolResponseMessageProps) {
  if (!result) {
    return (
      <div className={cn("flex items-center justify-center gap-2", className)}>
        <WrenchIcon
          size={14}
          weight="bold"
          className="flex-shrink-0 text-neutral-500"
        />
        <Text variant="small" className="text-neutral-500">
          {getToolActionPhrase(toolName)}...
        </Text>
      </div>
    );
  }

  let parsedResult: Record<string, unknown> | null = null;
  try {
    parsedResult =
      typeof result === "string"
        ? JSON.parse(result)
        : (result as Record<string, unknown>);
  } catch {
    parsedResult = null;
  }

  if (parsedResult && typeof parsedResult === "object") {
    const responseType = parsedResult.type as string | undefined;

    if (responseType === "agent_output") {
      const execution = parsedResult.execution as
        | {
            outputs?: Record<string, unknown[]>;
          }
        | null
        | undefined;
      const outputs = execution?.outputs || {};
      const message = parsedResult.message as string | undefined;

      return (
        <div className={cn("space-y-4 px-4 py-2", className)}>
          <div className="flex items-center gap-2">
            <WrenchIcon
              size={14}
              weight="bold"
              className="flex-shrink-0 text-neutral-500"
            />
            <Text variant="small" className="text-neutral-500">
              {getToolActionPhrase(toolName)}
            </Text>
          </div>
          {message && (
            <div className="rounded border p-4">
              <Text variant="small" className="text-neutral-600">
                {message}
              </Text>
            </div>
          )}
          {Object.keys(outputs).length > 0 && (
            <div className="space-y-4">
              {Object.entries(outputs).map(([outputName, values]) =>
                values.map((value, index) => {
                  const renderer = globalRegistry.getRenderer(value);
                  if (renderer) {
                    return (
                      <OutputItem
                        key={`${outputName}-${index}`}
                        value={value}
                        renderer={renderer}
                        label={outputName}
                      />
                    );
                  }
                  return (
                    <div
                      key={`${outputName}-${index}`}
                      className="rounded border p-4"
                    >
                      <Text variant="large-medium" className="mb-2 capitalize">
                        {outputName}
                      </Text>
                      <pre className="overflow-auto text-sm">
                        {JSON.stringify(value, null, 2)}
                      </pre>
                    </div>
                  );
                }),
              )}
            </div>
          )}
        </div>
      );
    }

    if (responseType === "block_output" && parsedResult.outputs) {
      const outputs = parsedResult.outputs as Record<string, unknown[]>;

      return (
        <div className={cn("space-y-4 px-4 py-2", className)}>
          <div className="flex items-center gap-2">
            <WrenchIcon
              size={14}
              weight="bold"
              className="flex-shrink-0 text-neutral-500"
            />
            <Text variant="small" className="text-neutral-500">
              {getToolActionPhrase(toolName)}
            </Text>
          </div>
          <div className="space-y-4">
            {Object.entries(outputs).map(([outputName, values]) =>
              values.map((value, index) => {
                const renderer = globalRegistry.getRenderer(value);
                if (renderer) {
                  return (
                    <OutputItem
                      key={`${outputName}-${index}`}
                      value={value}
                      renderer={renderer}
                      label={outputName}
                    />
                  );
                }
                return (
                  <div
                    key={`${outputName}-${index}`}
                    className="rounded border p-4"
                  >
                    <Text variant="large-medium" className="mb-2 capitalize">
                      {outputName}
                    </Text>
                    <pre className="overflow-auto text-sm">
                      {JSON.stringify(value, null, 2)}
                    </pre>
                  </div>
                );
              }),
            )}
          </div>
        </div>
      );
    }

    // Handle other response types with a message field (e.g., understanding_updated)
    if (parsedResult.message && typeof parsedResult.message === "string") {
      // Format tool name from snake_case to Title Case
      const formattedToolName = toolName
        .split("_")
        .map((word) => word.charAt(0).toUpperCase() + word.slice(1))
        .join(" ");

      // Clean up message - remove incomplete user_name references
      let cleanedMessage = parsedResult.message;
      // Remove "Updated understanding with: user_name" pattern if user_name is just a placeholder
      cleanedMessage = cleanedMessage.replace(
        /Updated understanding with:\s*user_name\.?\s*/gi,
        "",
      );
      // Remove standalone user_name references
      cleanedMessage = cleanedMessage.replace(/\buser_name\b\.?\s*/gi, "");
      cleanedMessage = cleanedMessage.trim();

      // Only show message if it has content after cleaning
      if (!cleanedMessage) {
        return (
          <div
            className={cn(
              "flex items-center justify-center gap-2 px-4 py-2",
              className,
            )}
          >
            <WrenchIcon
              size={14}
              weight="bold"
              className="flex-shrink-0 text-neutral-500"
            />
            <Text variant="small" className="text-neutral-500">
              {formattedToolName}
            </Text>
          </div>
        );
      }

      return (
        <div className={cn("space-y-2 px-4 py-2", className)}>
          <div className="flex items-center justify-center gap-2">
            <WrenchIcon
              size={14}
              weight="bold"
              className="flex-shrink-0 text-neutral-500"
            />
            <Text variant="small" className="text-neutral-500">
              {formattedToolName}
            </Text>
          </div>
          <div className="rounded border p-4">
            <Text variant="small" className="text-neutral-600">
              {cleanedMessage}
            </Text>
          </div>
        </div>
      );
    }
  }

  const renderer = globalRegistry.getRenderer(result);
  if (renderer) {
    return (
      <div className={cn("px-4 py-2", className)}>
        <div className="mb-2 flex items-center gap-2">
          <WrenchIcon
            size={14}
            weight="bold"
            className="flex-shrink-0 text-neutral-500"
          />
          <Text variant="small" className="text-neutral-500">
            {getToolActionPhrase(toolName)}
          </Text>
        </div>
        <OutputItem value={result} renderer={renderer} />
      </div>
    );
  }

  return (
    <div className={cn("flex items-center justify-center gap-2", className)}>
      <WrenchIcon
        size={14}
        weight="bold"
        className="flex-shrink-0 text-neutral-500"
      />
      <Text variant="small" className="text-neutral-500">
        {getToolActionPhrase(toolName)}...
      </Text>
    </div>
  );
}
@@ -64,10 +64,3 @@ export function getToolCompletionPhrase(toolName: string): string {
|
||||
`Finished ${toolName.replace(/_/g, " ").replace("...", "")}`
|
||||
);
|
||||
}
|
||||
|
||||
/** Validate UUID v4 format */
|
||||
export function isValidUUID(value: string): boolean {
|
||||
const uuidRegex =
|
||||
/^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i;
|
||||
return uuidRegex.test(value);
|
||||
}
|
||||
@@ -1,17 +1,12 @@
"use client";

import { useEffect, useRef } from "react";
import { useRouter, useSearchParams } from "next/navigation";
import { toast } from "sonner";
import { useChatSession } from "@/app/(platform)/chat/useChatSession";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useChatStream } from "@/app/(platform)/chat/useChatStream";
import { useEffect, useRef } from "react";
import { toast } from "sonner";
import { useChatSession } from "./useChatSession";
import { useChatStream } from "./useChatStream";

export function useChatPage() {
  const router = useRouter();
  const searchParams = useSearchParams();
  const urlSessionId =
    searchParams.get("session_id") || searchParams.get("session");
export function useChat() {
  const hasCreatedSessionRef = useRef(false);
  const hasClaimedSessionRef = useRef(false);
  const { user } = useSupabase();
@@ -25,29 +20,24 @@ export function useChatPage() {
    isCreating,
    error,
    createSession,
    refreshSession,
    claimSession,
    clearSession: clearSessionBase,
    loadSession,
  } = useChatSession({
    urlSessionId,
    urlSessionId: null,
    autoCreate: false,
  });

  useEffect(
    function autoCreateSession() {
      if (
        !urlSessionId &&
        !hasCreatedSessionRef.current &&
        !isCreating &&
        !sessionIdFromHook
      ) {
      if (!hasCreatedSessionRef.current && !isCreating && !sessionIdFromHook) {
        hasCreatedSessionRef.current = true;
        createSession().catch((_err) => {
          hasCreatedSessionRef.current = false;
        });
      }
    },
    [urlSessionId, isCreating, sessionIdFromHook, createSession],
    [isCreating, sessionIdFromHook, createSession],
  );

  useEffect(
@@ -111,7 +101,6 @@ export function useChatPage() {
    clearSessionBase();
    hasCreatedSessionRef.current = false;
    hasClaimedSessionRef.current = false;
    router.push("/chat");
  }

  return {
@@ -121,8 +110,8 @@ export function useChatPage() {
    isCreating,
    error,
    createSession,
    refreshSession,
    clearSession,
    loadSession,
    sessionId: sessionIdFromHook,
  };
}
@@ -0,0 +1,17 @@
"use client";

import { create } from "zustand";

interface ChatDrawerState {
  isOpen: boolean;
  open: () => void;
  close: () => void;
  toggle: () => void;
}

export const useChatDrawer = create<ChatDrawerState>((set) => ({
  isOpen: false,
  open: () => set({ isOpen: true }),
  close: () => set({ isOpen: false }),
  toggle: () => set((state) => ({ isOpen: !state.isOpen })),
}));
@@ -1,17 +1,18 @@
import { useCallback, useEffect, useState, useRef, useMemo } from "react";
import { useQueryClient } from "@tanstack/react-query";
import { toast } from "sonner";
import {
  usePostV2CreateSession,
  getGetV2GetSessionQueryKey,
  getGetV2GetSessionQueryOptions,
  postV2CreateSession,
  useGetV2GetSession,
  usePatchV2SessionAssignUser,
  getGetV2GetSessionQueryKey,
  usePostV2CreateSession,
} from "@/app/api/__generated__/endpoints/chat/chat";
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
import { storage, Key } from "@/services/storage/local-storage";
import { isValidUUID } from "@/app/(platform)/chat/helpers";
import { okData } from "@/app/api/helpers";
import { isValidUUID } from "@/lib/utils";
import { Key, storage } from "@/services/storage/local-storage";
import { useQueryClient } from "@tanstack/react-query";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner";

interface UseChatSessionArgs {
  urlSessionId?: string | null;
@@ -155,10 +156,22 @@ export function useChatSession({
    async function loadSession(id: string) {
      try {
        setError(null);
        // Invalidate the query cache for this session to force a fresh fetch
        await queryClient.invalidateQueries({
          queryKey: getGetV2GetSessionQueryKey(id),
        });
        // Set sessionId after invalidation to ensure the hook refetches
        setSessionId(id);
        storage.set(Key.CHAT_SESSION_ID, id);
        const result = await refetch();
        if (!result.data || result.isError) {
        // Force fetch with fresh data (bypass cache)
        const queryOptions = getGetV2GetSessionQueryOptions(id, {
          query: {
            staleTime: 0, // Force fresh fetch
            retry: 1,
          },
        });
        const result = await queryClient.fetchQuery(queryOptions);
        if (!result || ("status" in result && result.status !== 200)) {
          console.warn("Session not found on server, clearing local state");
          storage.clean(Key.CHAT_SESSION_ID);
          setSessionId(null);
@@ -171,7 +184,7 @@
        throw error;
      }
    },
    [refetch],
    [queryClient],
  );

  const refreshSession = useCallback(
@@ -0,0 +1,371 @@
import type { ToolArguments, ToolResult } from "@/types/chat";
import { useCallback, useEffect, useRef, useState } from "react";
import { toast } from "sonner";

const MAX_RETRIES = 3;
const INITIAL_RETRY_DELAY = 1000;

export interface StreamChunk {
  type:
    | "text_chunk"
    | "text_ended"
    | "tool_call"
    | "tool_call_start"
    | "tool_response"
    | "login_needed"
    | "need_login"
    | "credentials_needed"
    | "error"
    | "usage"
    | "stream_end";
  timestamp?: string;
  content?: string;
  message?: string;
  tool_id?: string;
  tool_name?: string;
  arguments?: ToolArguments;
  result?: ToolResult;
  success?: boolean;
  idx?: number;
  session_id?: string;
  agent_info?: {
    graph_id: string;
    name: string;
    trigger_type: string;
  };
  provider?: string;
  provider_name?: string;
  credential_type?: string;
  scopes?: string[];
  title?: string;
  [key: string]: unknown;
}

type VercelStreamChunk =
  | { type: "start"; messageId: string }
  | { type: "finish" }
  | { type: "text-start"; id: string }
  | { type: "text-delta"; id: string; delta: string }
  | { type: "text-end"; id: string }
  | { type: "tool-input-start"; toolCallId: string; toolName: string }
  | {
      type: "tool-input-available";
      toolCallId: string;
      toolName: string;
      input: ToolArguments;
    }
  | {
      type: "tool-output-available";
      toolCallId: string;
      toolName?: string;
      output: ToolResult;
      success?: boolean;
    }
  | {
      type: "usage";
      promptTokens: number;
      completionTokens: number;
      totalTokens: number;
    }
  | {
      type: "error";
      errorText: string;
      code?: string;
      details?: Record<string, unknown>;
    };

const LEGACY_STREAM_TYPES = new Set<StreamChunk["type"]>([
  "text_chunk",
  "text_ended",
  "tool_call",
  "tool_call_start",
  "tool_response",
  "login_needed",
  "need_login",
  "credentials_needed",
  "error",
  "usage",
  "stream_end",
]);

function isLegacyStreamChunk(
  chunk: StreamChunk | VercelStreamChunk,
): chunk is StreamChunk {
  return LEGACY_STREAM_TYPES.has(chunk.type as StreamChunk["type"]);
}

function normalizeStreamChunk(
  chunk: StreamChunk | VercelStreamChunk,
): StreamChunk | null {
  if (isLegacyStreamChunk(chunk)) {
    return chunk;
  }
  switch (chunk.type) {
    case "text-delta":
      return { type: "text_chunk", content: chunk.delta };
    case "text-end":
      return { type: "text_ended" };
    case "tool-input-available":
      return {
        type: "tool_call_start",
        tool_id: chunk.toolCallId,
        tool_name: chunk.toolName,
        arguments: chunk.input,
      };
    case "tool-output-available":
      return {
        type: "tool_response",
        tool_id: chunk.toolCallId,
        tool_name: chunk.toolName,
        result: chunk.output,
        success: chunk.success ?? true,
      };
    case "usage":
      return {
        type: "usage",
        promptTokens: chunk.promptTokens,
        completionTokens: chunk.completionTokens,
        totalTokens: chunk.totalTokens,
      };
    case "error":
      return {
        type: "error",
        message: chunk.errorText,
        code: chunk.code,
        details: chunk.details,
      };
    case "finish":
      return { type: "stream_end" };
    case "start":
    case "text-start":
    case "tool-input-start":
      return null;
  }
}

export function useChatStream() {
  const [isStreaming, setIsStreaming] = useState(false);
  const [error, setError] = useState<Error | null>(null);
  const retryCountRef = useRef<number>(0);
  const retryTimeoutRef = useRef<NodeJS.Timeout | null>(null);
  const abortControllerRef = useRef<AbortController | null>(null);

  const stopStreaming = useCallback(() => {
    if (abortControllerRef.current) {
      abortControllerRef.current.abort();
      abortControllerRef.current = null;
    }
    if (retryTimeoutRef.current) {
      clearTimeout(retryTimeoutRef.current);
      retryTimeoutRef.current = null;
    }
    setIsStreaming(false);
  }, []);

  useEffect(() => {
    return () => {
      stopStreaming();
    };
  }, [stopStreaming]);

  const sendMessage = useCallback(
    async (
      sessionId: string,
      message: string,
      onChunk: (chunk: StreamChunk) => void,
      isUserMessage: boolean = true,
      context?: { url: string; content: string },
      isRetry: boolean = false,
    ) => {
      stopStreaming();

      const abortController = new AbortController();
      abortControllerRef.current = abortController;

      if (abortController.signal.aborted) {
        return Promise.reject(new Error("Request aborted"));
      }

      if (!isRetry) {
        retryCountRef.current = 0;
      }
      setIsStreaming(true);
      setError(null);

      try {
        const url = `/api/chat/sessions/${sessionId}/stream`;
        const body = JSON.stringify({
          message,
          is_user_message: isUserMessage,
          context: context || null,
        });

        const response = await fetch(url, {
          method: "POST",
          headers: {
            "Content-Type": "application/json",
            Accept: "text/event-stream",
          },
          body,
          signal: abortController.signal,
        });

        if (!response.ok) {
          const errorText = await response.text();
          throw new Error(errorText || `HTTP ${response.status}`);
        }

        if (!response.body) {
          throw new Error("Response body is null");
        }

        const reader = response.body.getReader();
        const decoder = new TextDecoder();
        let buffer = "";

        return new Promise<void>((resolve, reject) => {
          let didDispatchStreamEnd = false;

          function dispatchStreamEnd() {
            if (didDispatchStreamEnd) return;
            didDispatchStreamEnd = true;
            onChunk({ type: "stream_end" });
          }

          const cleanup = () => {
            reader.cancel().catch(() => {
              // Ignore cancel errors
            });
          };

          async function readStream() {
            try {
              while (true) {
                const { done, value } = await reader.read();

                if (done) {
                  cleanup();
                  dispatchStreamEnd();
                  retryCountRef.current = 0;
                  stopStreaming();
                  resolve();
                  return;
                }

                buffer += decoder.decode(value, { stream: true });
                const lines = buffer.split("\n");
                buffer = lines.pop() || "";

                for (const line of lines) {
                  if (line.startsWith("data: ")) {
                    const data = line.slice(6);
                    if (data === "[DONE]") {
                      cleanup();
                      dispatchStreamEnd();
                      retryCountRef.current = 0;
                      stopStreaming();
                      resolve();
                      return;
                    }

                    try {
                      const rawChunk = JSON.parse(data) as
                        | StreamChunk
                        | VercelStreamChunk;
                      const chunk = normalizeStreamChunk(rawChunk);
                      if (!chunk) {
                        continue;
                      }

                      // Call the chunk handler
                      onChunk(chunk);

                      // Handle stream lifecycle
                      if (chunk.type === "stream_end") {
                        didDispatchStreamEnd = true;
                        cleanup();
                        retryCountRef.current = 0;
                        stopStreaming();
                        resolve();
                        return;
                      } else if (chunk.type === "error") {
                        cleanup();
                        reject(
                          new Error(
                            chunk.message || chunk.content || "Stream error",
                          ),
                        );
                        return;
                      }
                    } catch (err) {
                      // Skip invalid JSON lines
                      console.warn("Failed to parse SSE chunk:", err, data);
                    }
                  }
                }
              }
            } catch (err) {
              if (err instanceof Error && err.name === "AbortError") {
                cleanup();
                return;
              }

              const streamError =
                err instanceof Error ? err : new Error("Failed to read stream");

              if (retryCountRef.current < MAX_RETRIES) {
                retryCountRef.current += 1;
                const retryDelay =
                  INITIAL_RETRY_DELAY * Math.pow(2, retryCountRef.current - 1);

                toast.info("Connection interrupted", {
                  description: `Retrying in ${retryDelay / 1000} seconds...`,
                });

                retryTimeoutRef.current = setTimeout(() => {
                  sendMessage(
                    sessionId,
                    message,
                    onChunk,
                    isUserMessage,
                    context,
                    true,
                  ).catch((_err) => {
                    // Retry failed
                  });
                }, retryDelay);
              } else {
                setError(streamError);
                toast.error("Connection Failed", {
                  description:
                    "Unable to connect to chat service. Please try again.",
                });
                cleanup();
                dispatchStreamEnd();
                retryCountRef.current = 0;
                stopStreaming();
                reject(streamError);
              }
            }
          }

          readStream();
        });
      } catch (err) {
        const streamError =
          err instanceof Error ? err : new Error("Failed to start stream");
        setError(streamError);
        setIsStreaming(false);
        throw streamError;
      }
    },
    [stopStreaming],
  );

  return {
    isStreaming,
    error,
    sendMessage,
    stopStreaming,
  };
}
@@ -0,0 +1,98 @@
import { useCallback } from "react";

export interface PageContext {
  url: string;
  content: string;
}

const MAX_CONTENT_CHARS = 10000;

/**
 * Hook to capture the current page context (URL + full page content)
 * Privacy-hardened: removes sensitive inputs and enforces content size limits
 */
export function usePageContext() {
  const capturePageContext = useCallback((): PageContext => {
    if (typeof window === "undefined" || typeof document === "undefined") {
      return { url: "", content: "" };
    }

    const url = window.location.href;

    // Clone document to avoid modifying the original
    const clone = document.cloneNode(true) as Document;

    // Remove script, style, and noscript elements
    const scripts = clone.querySelectorAll("script, style, noscript");
    scripts.forEach((el) => el.remove());

    // Remove sensitive elements and their content
    const sensitiveSelectors = [
      "input",
      "textarea",
      "[contenteditable]",
      'input[type="password"]',
      'input[type="email"]',
      'input[type="tel"]',
      'input[type="search"]',
      'input[type="hidden"]',
      "form",
      "[data-sensitive]",
      "[data-sensitive='true']",
    ];

    sensitiveSelectors.forEach((selector) => {
      const elements = clone.querySelectorAll(selector);
      elements.forEach((el) => {
        // For form elements, remove the entire element
        if (el.tagName === "FORM") {
          el.remove();
        } else {
          // For inputs and textareas, clear their values but keep the element structure
          if (
            el instanceof HTMLInputElement ||
            el instanceof HTMLTextAreaElement
          ) {
            el.value = "";
            el.textContent = "";
          } else {
            // For other sensitive elements, remove them entirely
            el.remove();
          }
        }
      });
    });

    // Strip any remaining input values that might have been missed
    const allInputs = clone.querySelectorAll("input, textarea");
    allInputs.forEach((el) => {
      if (el instanceof HTMLInputElement || el instanceof HTMLTextAreaElement) {
        el.value = "";
        el.textContent = "";
      }
    });

    // Get text content from body
    const body = clone.body;
    const content = body?.textContent || body?.innerText || "";

    // Clean up whitespace
    let cleanedContent = content
      .replace(/\s+/g, " ")
      .replace(/\n\s*\n/g, "\n")
      .trim();

    // Enforce maximum content size
    if (cleanedContent.length > MAX_CONTENT_CHARS) {
      cleanedContent =
        cleanedContent.substring(0, MAX_CONTENT_CHARS) + "... [truncated]";
    }

    return {
      url,
      content: cleanedContent,
    };
  }, []);

  return { capturePageContext };
}
@@ -1,68 +0,0 @@
import { cn } from "@/lib/utils";
import { ChatInput } from "@/app/(platform)/chat/components/ChatInput/ChatInput";
import { MessageList } from "@/app/(platform)/chat/components/MessageList/MessageList";
import { QuickActionsWelcome } from "@/app/(platform)/chat/components/QuickActionsWelcome/QuickActionsWelcome";
import { useChatContainer } from "./useChatContainer";
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";

export interface ChatContainerProps {
  sessionId: string | null;
  initialMessages: SessionDetailResponse["messages"];
  onRefreshSession: () => Promise<void>;
  className?: string;
}

export function ChatContainer({
  sessionId,
  initialMessages,
  onRefreshSession,
  className,
}: ChatContainerProps) {
  const { messages, streamingChunks, isStreaming, sendMessage } =
    useChatContainer({
      sessionId,
      initialMessages,
      onRefreshSession,
    });

  const quickActions = [
    "Find agents for social media management",
    "Show me agents for content creation",
    "Help me automate my business",
    "What can you help me with?",
  ];

  return (
    <div className={cn("flex h-full flex-col", className)}>
      {/* Messages or Welcome Screen */}
      {messages.length === 0 ? (
        <QuickActionsWelcome
          title="Welcome to AutoGPT Chat"
          description="Start a conversation to discover and run AI agents."
          actions={quickActions}
          onActionClick={sendMessage}
          disabled={isStreaming || !sessionId}
        />
      ) : (
        <MessageList
          messages={messages}
          streamingChunks={streamingChunks}
          isStreaming={isStreaming}
          onSendMessage={sendMessage}
          className="flex-1"
        />
      )}

      {/* Input - Always visible */}
      <div className="border-t border-zinc-200 p-4 dark:border-zinc-800">
        <ChatInput
          onSend={sendMessage}
          disabled={isStreaming || !sessionId}
          placeholder={
            sessionId ? "Type your message..." : "Creating session..."
          }
        />
      </div>
    </div>
  );
}
@@ -1,130 +0,0 @@
import { useState, useCallback, useRef, useMemo } from "react";
import { toast } from "sonner";
import { useChatStream } from "@/app/(platform)/chat/useChatStream";
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
import type { ChatMessageData } from "@/app/(platform)/chat/components/ChatMessage/useChatMessage";
import {
  parseToolResponse,
  isValidMessage,
  isToolCallArray,
  createUserMessage,
  filterAuthMessages,
} from "./helpers";
import { createStreamEventDispatcher } from "./createStreamEventDispatcher";

interface UseChatContainerArgs {
  sessionId: string | null;
  initialMessages: SessionDetailResponse["messages"];
  onRefreshSession: () => Promise<void>;
}

export function useChatContainer({
  sessionId,
  initialMessages,
}: UseChatContainerArgs) {
  const [messages, setMessages] = useState<ChatMessageData[]>([]);
  const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
  const [hasTextChunks, setHasTextChunks] = useState(false);
  const streamingChunksRef = useRef<string[]>([]);
  const { error, sendMessage: sendStreamMessage } = useChatStream();
  const isStreaming = hasTextChunks;

  const allMessages = useMemo(() => {
    const processedInitialMessages = initialMessages
      .filter((msg: Record<string, unknown>) => {
        if (!isValidMessage(msg)) {
          console.warn("Invalid message structure from backend:", msg);
          return false;
        }
        const content = String(msg.content || "").trim();
        const toolCalls = msg.tool_calls;
        return (
          content.length > 0 ||
          (toolCalls && Array.isArray(toolCalls) && toolCalls.length > 0)
        );
      })
      .map((msg: Record<string, unknown>) => {
        const content = String(msg.content || "");
        const role = String(msg.role || "assistant").toLowerCase();
        const toolCalls = msg.tool_calls;
        if (
          role === "assistant" &&
          toolCalls &&
          isToolCallArray(toolCalls) &&
          toolCalls.length > 0
        ) {
          return null;
        }
        if (role === "tool") {
          const timestamp = msg.timestamp
            ? new Date(msg.timestamp as string)
            : undefined;
          const toolResponse = parseToolResponse(
            content,
            (msg.tool_call_id as string) || "",
            "unknown",
            timestamp,
          );
          if (!toolResponse) {
            return null;
          }
          return toolResponse;
        }
        return {
          type: "message",
          role: role as "user" | "assistant" | "system",
          content,
          timestamp: msg.timestamp
            ? new Date(msg.timestamp as string)
            : undefined,
        };
      })
      .filter((msg): msg is ChatMessageData => msg !== null);

    return [...processedInitialMessages, ...messages];
  }, [initialMessages, messages]);

  const sendMessage = useCallback(
    async function sendMessage(content: string, isUserMessage: boolean = true) {
      if (!sessionId) {
        console.error("Cannot send message: no session ID");
        return;
      }
      if (isUserMessage) {
        const userMessage = createUserMessage(content);
        setMessages((prev) => [...filterAuthMessages(prev), userMessage]);
      } else {
        setMessages((prev) => filterAuthMessages(prev));
      }
      setStreamingChunks([]);
      streamingChunksRef.current = [];
      setHasTextChunks(false);
      const dispatcher = createStreamEventDispatcher({
        setHasTextChunks,
        setStreamingChunks,
        streamingChunksRef,
        setMessages,
        sessionId,
      });
      try {
        await sendStreamMessage(sessionId, content, dispatcher, isUserMessage);
      } catch (err) {
        console.error("Failed to send message:", err);
        const errorMessage =
          err instanceof Error ? err.message : "Failed to send message";
        toast.error("Failed to send message", {
          description: errorMessage,
        });
      }
    },
    [sessionId, sendStreamMessage],
  );

  return {
    messages: allMessages,
    streamingChunks,
    isStreaming,
    error,
    sendMessage,
  };
}
@@ -1,153 +0,0 @@
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput";
import { Card } from "@/components/atoms/Card/Card";
import { Text } from "@/components/atoms/Text/Text";
import type { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api";
import { cn } from "@/lib/utils";
import { CheckIcon, KeyIcon, WarningIcon } from "@phosphor-icons/react";
import { useEffect, useRef } from "react";
import { useChatCredentialsSetup } from "./useChatCredentialsSetup";

export interface CredentialInfo {
  provider: string;
  providerName: string;
  credentialType: "api_key" | "oauth2" | "user_password" | "host_scoped";
  title: string;
  scopes?: string[];
}

interface Props {
  credentials: CredentialInfo[];
  agentName?: string;
  message: string;
  onAllCredentialsComplete: () => void;
  onCancel: () => void;
  className?: string;
}

function createSchemaFromCredentialInfo(
  credential: CredentialInfo,
): BlockIOCredentialsSubSchema {
  return {
    type: "object",
    properties: {},
    credentials_provider: [credential.provider],
    credentials_types: [credential.credentialType],
    credentials_scopes: credential.scopes,
    discriminator: undefined,
    discriminator_mapping: undefined,
    discriminator_values: undefined,
  };
}

export function ChatCredentialsSetup({
  credentials,
  agentName: _agentName,
  message,
  onAllCredentialsComplete,
  onCancel: _onCancel,
  className,
}: Props) {
  const { selectedCredentials, isAllComplete, handleCredentialSelect } =
    useChatCredentialsSetup(credentials);

  // Track if we've already called completion to prevent double calls
  const hasCalledCompleteRef = useRef(false);

  // Reset the completion flag when credentials change (new credential setup flow)
  useEffect(
    function resetCompletionFlag() {
      hasCalledCompleteRef.current = false;
    },
    [credentials],
  );

  // Auto-call completion when all credentials are configured
  useEffect(
    function autoCompleteWhenReady() {
      if (isAllComplete && !hasCalledCompleteRef.current) {
        hasCalledCompleteRef.current = true;
        onAllCredentialsComplete();
      }
    },
    [isAllComplete, onAllCredentialsComplete],
  );

  return (
    <Card
      className={cn(
        "mx-4 my-2 overflow-hidden border-orange-200 bg-orange-50 dark:border-orange-900 dark:bg-orange-950",
        className,
      )}
    >
      <div className="flex items-start gap-4 p-6">
        <div className="flex h-12 w-12 flex-shrink-0 items-center justify-center rounded-full bg-orange-500">
          <KeyIcon size={24} weight="bold" className="text-white" />
        </div>
        <div className="flex-1">
          <Text
            variant="h3"
            className="mb-2 text-orange-900 dark:text-orange-100"
          >
            Credentials Required
          </Text>
          <Text
            variant="body"
            className="mb-4 text-orange-700 dark:text-orange-300"
          >
            {message}
          </Text>

          <div className="space-y-3">
            {credentials.map((cred, index) => {
              const schema = createSchemaFromCredentialInfo(cred);
              const isSelected = !!selectedCredentials[cred.provider];

              return (
                <div
                  key={`${cred.provider}-${index}`}
                  className={cn(
                    "relative rounded-lg border border-orange-200 bg-white p-4 dark:border-orange-800 dark:bg-orange-900/20",
                    isSelected &&
                      "border-green-500 bg-green-50 dark:border-green-700 dark:bg-green-950/30",
                  )}
                >
                  <div className="mb-2 flex items-center justify-between">
                    <div className="flex items-center gap-2">
                      {isSelected ? (
                        <CheckIcon
                          size={20}
                          className="text-green-500"
                          weight="bold"
                        />
                      ) : (
                        <WarningIcon
                          size={20}
                          className="text-orange-500"
                          weight="bold"
                        />
                      )}
                      <Text
                        variant="body"
                        className="font-semibold text-orange-900 dark:text-orange-100"
                      >
                        {cred.providerName}
                      </Text>
                    </div>
                  </div>

                  <CredentialsInput
                    schema={schema}
                    selectedCredentials={selectedCredentials[cred.provider]}
                    onSelectCredentials={(credMeta) =>
                      handleCredentialSelect(cred.provider, credMeta)
                    }
                  />
                </div>
              );
            })}
          </div>
        </div>
      </div>
    </Card>
  );
}
@@ -1,63 +0,0 @@
import { cn } from "@/lib/utils";
import { PaperPlaneRightIcon } from "@phosphor-icons/react";
import { Button } from "@/components/atoms/Button/Button";
import { useChatInput } from "./useChatInput";

export interface ChatInputProps {
  onSend: (message: string) => void;
  disabled?: boolean;
  placeholder?: string;
  className?: string;
}

export function ChatInput({
  onSend,
  disabled = false,
  placeholder = "Type your message...",
  className,
}: ChatInputProps) {
  const { value, setValue, handleKeyDown, handleSend, textareaRef } =
    useChatInput({
      onSend,
      disabled,
      maxRows: 5,
    });

  return (
    <div className={cn("flex gap-2", className)}>
      <textarea
        ref={textareaRef}
        value={value}
        onChange={(e) => setValue(e.target.value)}
        onKeyDown={handleKeyDown}
        placeholder={placeholder}
        disabled={disabled}
        rows={1}
        autoComplete="off"
        aria-label="Chat message input"
        aria-describedby="chat-input-hint"
        className={cn(
          "flex-1 resize-none rounded-lg border border-neutral-200 bg-white px-4 py-2 text-sm",
          "placeholder:text-neutral-400",
          "focus:border-violet-600 focus:outline-none focus:ring-2 focus:ring-violet-600/20",
          "dark:border-neutral-800 dark:bg-neutral-900 dark:text-neutral-100 dark:placeholder:text-neutral-500",
          "disabled:cursor-not-allowed disabled:opacity-50",
        )}
      />
      <span id="chat-input-hint" className="sr-only">
        Press Enter to send, Shift+Enter for new line
      </span>

      <Button
        variant="primary"
        size="small"
        onClick={handleSend}
        disabled={disabled || !value.trim()}
        className="self-end"
        aria-label="Send message"
      >
        <PaperPlaneRightIcon className="h-4 w-4" weight="fill" />
      </Button>
    </div>
  );
}
@@ -1,31 +0,0 @@
import React from "react";
import { Text } from "@/components/atoms/Text/Text";
import { ArrowClockwiseIcon } from "@phosphor-icons/react";
import { cn } from "@/lib/utils";

export interface ChatLoadingStateProps {
  message?: string;
  className?: string;
}

export function ChatLoadingState({
  message = "Loading...",
  className,
}: ChatLoadingStateProps) {
  return (
    <div
      className={cn("flex flex-1 items-center justify-center p-6", className)}
    >
      <div className="flex flex-col items-center gap-4 text-center">
        <ArrowClockwiseIcon
          size={32}
          weight="bold"
          className="animate-spin text-purple-500"
        />
        <Text variant="body" className="text-zinc-600 dark:text-zinc-400">
          {message}
        </Text>
      </div>
    </div>
  );
}
@@ -1,194 +0,0 @@
"use client";

import { cn } from "@/lib/utils";
import { RobotIcon, UserIcon, CheckCircleIcon } from "@phosphor-icons/react";
import { useCallback } from "react";
import { MessageBubble } from "@/app/(platform)/chat/components/MessageBubble/MessageBubble";
import { MarkdownContent } from "@/app/(platform)/chat/components/MarkdownContent/MarkdownContent";
import { ToolCallMessage } from "@/app/(platform)/chat/components/ToolCallMessage/ToolCallMessage";
import { ToolResponseMessage } from "@/app/(platform)/chat/components/ToolResponseMessage/ToolResponseMessage";
import { AuthPromptWidget } from "@/app/(platform)/chat/components/AuthPromptWidget/AuthPromptWidget";
import { ChatCredentialsSetup } from "@/app/(platform)/chat/components/ChatCredentialsSetup/ChatCredentialsSetup";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useChatMessage, type ChatMessageData } from "./useChatMessage";
import { getToolActionPhrase } from "@/app/(platform)/chat/helpers";
export interface ChatMessageProps {
  message: ChatMessageData;
  className?: string;
  onDismissLogin?: () => void;
  onDismissCredentials?: () => void;
  onSendMessage?: (content: string, isUserMessage?: boolean) => void;
}

export function ChatMessage({
  message,
  className,
  onDismissCredentials,
  onSendMessage,
}: ChatMessageProps) {
  const { user } = useSupabase();
  const {
    formattedTimestamp,
    isUser,
    isAssistant,
    isToolCall,
    isToolResponse,
    isLoginNeeded,
    isCredentialsNeeded,
  } = useChatMessage(message);

  const handleAllCredentialsComplete = useCallback(
    function handleAllCredentialsComplete() {
      // Send a user message that explicitly asks to retry the setup
      // This ensures the LLM calls get_required_setup_info again and proceeds with execution
      if (onSendMessage) {
        onSendMessage(
          "I've configured the required credentials. Please check if everything is ready and proceed with setting up the agent.",
        );
      }
      // Optionally dismiss the credentials prompt
      if (onDismissCredentials) {
        onDismissCredentials();
      }
    },
    [onSendMessage, onDismissCredentials],
  );

  function handleCancelCredentials() {
    // Dismiss the credentials prompt
    if (onDismissCredentials) {
      onDismissCredentials();
    }
  }

  // Render credentials needed messages
  if (isCredentialsNeeded && message.type === "credentials_needed") {
    return (
      <ChatCredentialsSetup
        credentials={message.credentials}
        agentName={message.agentName}
        message={message.message}
        onAllCredentialsComplete={handleAllCredentialsComplete}
        onCancel={handleCancelCredentials}
        className={className}
      />
    );
  }

  // Render login needed messages
  if (isLoginNeeded && message.type === "login_needed") {
    // If user is already logged in, show success message instead of auth prompt
    if (user) {
      return (
        <div className={cn("px-4 py-2", className)}>
          <div className="my-4 overflow-hidden rounded-lg border border-green-200 bg-gradient-to-br from-green-50 to-emerald-50 dark:border-green-800 dark:from-green-950/30 dark:to-emerald-950/30">
            <div className="px-6 py-4">
              <div className="flex items-center gap-3">
                <div className="flex h-10 w-10 items-center justify-center rounded-full bg-green-600">
                  <CheckCircleIcon
                    size={20}
                    weight="fill"
                    className="text-white"
                  />
                </div>
                <div>
                  <h3 className="text-lg font-semibold text-neutral-900 dark:text-neutral-100">
                    Successfully Authenticated
                  </h3>
                  <p className="text-sm text-neutral-600 dark:text-neutral-400">
                    You're now signed in and ready to continue
                  </p>
                </div>
              </div>
            </div>
          </div>
        </div>
      );
    }

    // Show auth prompt if not logged in
    return (
      <div className={cn("px-4 py-2", className)}>
        <AuthPromptWidget
          message={message.message}
          sessionId={message.sessionId}
          agentInfo={message.agentInfo}
          returnUrl="/chat"
        />
      </div>
    );
  }

  // Render tool call messages
  if (isToolCall && message.type === "tool_call") {
    return (
      <div className={cn("px-4 py-2", className)}>
        <ToolCallMessage toolName={message.toolName} />
      </div>
    );
  }

  // Render tool response messages
  if (
    (isToolResponse && message.type === "tool_response") ||
    message.type === "no_results" ||
    message.type === "agent_carousel" ||
    message.type === "execution_started"
  ) {
    return (
      <div className={cn("px-4 py-2", className)}>
        <ToolResponseMessage toolName={getToolActionPhrase(message.toolName)} />
      </div>
    );
  }

  // Render regular chat messages
  if (message.type === "message") {
    return (
      <div
        className={cn(
          "flex gap-3 px-4 py-4",
          isUser && "flex-row-reverse",
          className,
        )}
      >
        {/* Avatar */}
        <div className="flex-shrink-0">
          <div
            className={cn(
              "flex h-8 w-8 items-center justify-center rounded-full",
              isUser && "bg-zinc-200 dark:bg-zinc-700",
              isAssistant && "bg-purple-600 dark:bg-purple-500",
            )}
          >
            {isUser ? (
              <UserIcon className="h-5 w-5 text-zinc-700 dark:text-zinc-200" />
            ) : (
              <RobotIcon className="h-5 w-5 text-white" />
            )}
          </div>
        </div>

        {/* Message Content */}
        <div className={cn("flex max-w-[70%] flex-col", isUser && "items-end")}>
          <MessageBubble variant={isUser ? "user" : "assistant"}>
            <MarkdownContent content={message.content} />
          </MessageBubble>

          {/* Timestamp */}
          <span
            className={cn(
              "mt-1 text-xs text-zinc-500 dark:text-zinc-400",
              isUser && "text-right",
            )}
          >
            {formattedTimestamp}
          </span>
        </div>
      </div>
    );
  }

  // Fallback for unknown message types
  return null;
}
@@ -1,28 +0,0 @@
import { cn } from "@/lib/utils";
import { ReactNode } from "react";

export interface MessageBubbleProps {
  children: ReactNode;
  variant: "user" | "assistant";
  className?: string;
}

export function MessageBubble({
  children,
  variant,
  className,
}: MessageBubbleProps) {
  return (
    <div
      className={cn(
        "rounded-lg px-4 py-3 text-sm",
        variant === "user" && "bg-violet-600 text-white dark:bg-violet-500",
        variant === "assistant" &&
          "border border-neutral-200 bg-white dark:border-neutral-700 dark:bg-neutral-900 dark:text-neutral-100",
        className,
      )}
    >
      {children}
    </div>
  );
}
@@ -1,61 +0,0 @@
import { cn } from "@/lib/utils";
import { ChatMessage } from "../ChatMessage/ChatMessage";
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
import { StreamingMessage } from "../StreamingMessage/StreamingMessage";
import { useMessageList } from "./useMessageList";

export interface MessageListProps {
  messages: ChatMessageData[];
  streamingChunks?: string[];
  isStreaming?: boolean;
  className?: string;
  onStreamComplete?: () => void;
  onSendMessage?: (content: string) => void;
}

export function MessageList({
  messages,
  streamingChunks = [],
  isStreaming = false,
  className,
  onStreamComplete,
  onSendMessage,
}: MessageListProps) {
  const { messagesEndRef, messagesContainerRef } = useMessageList({
    messageCount: messages.length,
    isStreaming,
  });

  return (
    <div
      ref={messagesContainerRef}
      className={cn(
        "flex-1 overflow-y-auto",
        "scrollbar-thin scrollbar-track-transparent scrollbar-thumb-zinc-300 dark:scrollbar-thumb-zinc-700",
        className,
      )}
    >
      <div className="space-y-0">
        {/* Render all persisted messages */}
        {messages.map((message, index) => (
          <ChatMessage
            key={index}
            message={message}
            onSendMessage={onSendMessage}
          />
        ))}

        {/* Render streaming message if active */}
        {isStreaming && streamingChunks.length > 0 && (
          <StreamingMessage
            chunks={streamingChunks}
            onComplete={onStreamComplete}
          />
        )}

        {/* Invisible div to scroll to */}
        <div ref={messagesEndRef} />
      </div>
    </div>
  );
}
@@ -1,51 +0,0 @@
import React from "react";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";

export interface QuickActionsWelcomeProps {
  title: string;
  description: string;
  actions: string[];
  onActionClick: (action: string) => void;
  disabled?: boolean;
  className?: string;
}

export function QuickActionsWelcome({
  title,
  description,
  actions,
  onActionClick,
  disabled = false,
  className,
}: QuickActionsWelcomeProps) {
  return (
    <div
      className={cn("flex flex-1 items-center justify-center p-4", className)}
    >
      <div className="max-w-2xl text-center">
        <Text
          variant="h2"
          className="mb-4 text-3xl font-bold text-zinc-900 dark:text-zinc-100"
        >
          {title}
        </Text>
        <Text variant="body" className="mb-8 text-zinc-600 dark:text-zinc-400">
          {description}
        </Text>
        <div className="grid gap-2 sm:grid-cols-2">
          {actions.map((action) => (
            <button
              key={action}
              onClick={() => onActionClick(action)}
              disabled={disabled}
              className="rounded-lg border border-zinc-200 bg-white p-4 text-left text-sm hover:bg-zinc-50 disabled:cursor-not-allowed disabled:opacity-50 dark:border-zinc-800 dark:bg-zinc-900 dark:hover:bg-zinc-800"
            >
              {action}
            </button>
          ))}
        </div>
      </div>
    </div>
  );
}
@@ -1,42 +0,0 @@
import { cn } from "@/lib/utils";
import { Robot } from "@phosphor-icons/react";
import { MessageBubble } from "@/app/(platform)/chat/components/MessageBubble/MessageBubble";
import { MarkdownContent } from "@/app/(platform)/chat/components/MarkdownContent/MarkdownContent";
import { useStreamingMessage } from "./useStreamingMessage";

export interface StreamingMessageProps {
  chunks: string[];
  className?: string;
  onComplete?: () => void;
}

export function StreamingMessage({
  chunks,
  className,
  onComplete,
}: StreamingMessageProps) {
  const { displayText } = useStreamingMessage({ chunks, onComplete });

  return (
    <div className={cn("flex gap-3 px-4 py-4", className)}>
      {/* Avatar */}
      <div className="flex-shrink-0">
        <div className="flex h-8 w-8 items-center justify-center rounded-full bg-purple-600 dark:bg-purple-500">
          <Robot className="h-5 w-5 text-white" />
        </div>
      </div>

      {/* Message Content */}
      <div className="flex max-w-[70%] flex-col">
        <MessageBubble variant="assistant">
          <MarkdownContent content={displayText} />
        </MessageBubble>

        {/* Timestamp */}
        <span className="mt-1 text-xs text-neutral-500 dark:text-neutral-400">
          Typing...
        </span>
      </div>
    </div>
  );
}
@@ -1,49 +0,0 @@
import React from "react";
import { WrenchIcon } from "@phosphor-icons/react";
import { cn } from "@/lib/utils";
import { getToolActionPhrase } from "@/app/(platform)/chat/helpers";

export interface ToolCallMessageProps {
  toolName: string;
  className?: string;
}

export function ToolCallMessage({ toolName, className }: ToolCallMessageProps) {
  return (
    <div
      className={cn(
        "mx-10 max-w-[70%] overflow-hidden rounded-lg border transition-all duration-200",
        "border-neutral-200 dark:border-neutral-700",
        "bg-white dark:bg-neutral-900",
        "animate-in fade-in-50 slide-in-from-top-1",
        className,
      )}
    >
      {/* Header */}
      <div
        className={cn(
          "flex items-center justify-between px-3 py-2",
          "bg-gradient-to-r from-neutral-50 to-neutral-100 dark:from-neutral-800/20 dark:to-neutral-700/20",
        )}
      >
        <div className="flex items-center gap-2 overflow-hidden">
          <WrenchIcon
            size={16}
            weight="bold"
            className="flex-shrink-0 text-neutral-500 dark:text-neutral-400"
          />
          <span className="relative inline-block overflow-hidden text-sm font-medium text-neutral-700 dark:text-neutral-300">
            {getToolActionPhrase(toolName)}...
            <span
              className={cn(
                "absolute inset-0 bg-gradient-to-r from-transparent via-white/50 to-transparent",
                "dark:via-white/20",
                "animate-shimmer",
              )}
            />
          </span>
        </div>
      </div>
    </div>
  );
}
@@ -1,52 +0,0 @@
import React from "react";
import { WrenchIcon } from "@phosphor-icons/react";
import { cn } from "@/lib/utils";
import { getToolActionPhrase } from "@/app/(platform)/chat/helpers";

export interface ToolResponseMessageProps {
  toolName: string;
  success?: boolean;
  className?: string;
}

export function ToolResponseMessage({
  toolName,
  success = true,
  className,
}: ToolResponseMessageProps) {
  return (
    <div
      className={cn(
        "mx-10 max-w-[70%] overflow-hidden rounded-lg border transition-all duration-200",
        success
          ? "border-neutral-200 dark:border-neutral-700"
          : "border-red-200 dark:border-red-800",
        "bg-white dark:bg-neutral-900",
        "animate-in fade-in-50 slide-in-from-top-1",
        className,
      )}
    >
      {/* Header */}
      <div
        className={cn(
          "flex items-center justify-between px-3 py-2",
          "bg-gradient-to-r",
          success
            ? "from-neutral-50 to-neutral-100 dark:from-neutral-800/20 dark:to-neutral-700/20"
            : "from-red-50 to-red-100 dark:from-red-900/20 dark:to-red-800/20",
        )}
      >
        <div className="flex items-center gap-2">
          <WrenchIcon
            size={16}
            weight="bold"
            className="text-neutral-500 dark:text-neutral-400"
          />
          <span className="text-sm font-medium text-neutral-700 dark:text-neutral-300">
            {getToolActionPhrase(toolName)}...
          </span>
        </div>
      </div>
    </div>
  );
}
@@ -1,30 +1,17 @@
"use client";

import { useChatPage } from "./useChatPage";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatErrorState } from "./components/ChatErrorState/ChatErrorState";
import { ChatLoadingState } from "./components/ChatLoadingState/ChatLoadingState";
import { useGetFlag, Flag } from "@/services/feature-flags/use-get-flag";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useRouter } from "next/navigation";
import { useEffect } from "react";
import { Chat } from "./components/Chat/Chat";

export default function ChatPage() {
  const isChatEnabled = useGetFlag(Flag.CHAT);
  const router = useRouter();
  const {
    messages,
    isLoading,
    isCreating,
    error,
    sessionId,
    createSession,
    clearSession,
    refreshSession,
  } = useChatPage();

  useEffect(() => {
    if (isChatEnabled === false) {
      router.push("/404");
      router.push("/marketplace");
    }
  }, [isChatEnabled, router]);

@@ -34,50 +21,7 @@ export default function ChatPage() {

  return (
    <div className="flex h-full flex-col">
      {/* Header */}
      <header className="border-b border-zinc-200 bg-white p-4 dark:border-zinc-800 dark:bg-zinc-900">
        <div className="container mx-auto flex items-center justify-between">
          <h1 className="text-xl font-semibold">Chat</h1>
          {sessionId && (
            <div className="flex items-center gap-4">
              <span className="text-sm text-zinc-600 dark:text-zinc-400">
                Session: {sessionId.slice(0, 8)}...
              </span>
              <button
                onClick={clearSession}
                className="text-sm text-zinc-600 hover:text-zinc-900 dark:text-zinc-400 dark:hover:text-zinc-100"
              >
                New Chat
              </button>
            </div>
          )}
        </div>
      </header>

      {/* Main Content */}
      <main className="container mx-auto flex flex-1 flex-col overflow-hidden">
        {/* Loading State - show when explicitly loading/creating OR when we don't have a session yet and no error */}
        {(isLoading || isCreating || (!sessionId && !error)) && (
          <ChatLoadingState
            message={isCreating ? "Creating session..." : "Loading..."}
          />
        )}

        {/* Error State */}
        {error && !isLoading && (
          <ChatErrorState error={error} onRetry={createSession} />
        )}

        {/* Session Content */}
        {sessionId && !isLoading && !error && (
          <ChatContainer
            sessionId={sessionId}
            initialMessages={messages}
            onRefreshSession={refreshSession}
            className="flex-1"
          />
        )}
      </main>
      <Chat className="flex-1" />
    </div>
  );
}
@@ -1,204 +0,0 @@
import { useState, useCallback, useRef, useEffect } from "react";
import { toast } from "sonner";
import type { ToolArguments, ToolResult } from "@/types/chat";

const MAX_RETRIES = 3;
const INITIAL_RETRY_DELAY = 1000;

export interface StreamChunk {
  type:
    | "text_chunk"
    | "text_ended"
    | "tool_call"
    | "tool_call_start"
    | "tool_response"
    | "login_needed"
    | "need_login"
    | "credentials_needed"
    | "error"
    | "usage"
    | "stream_end";
  timestamp?: string;
  content?: string;
  message?: string;
  tool_id?: string;
  tool_name?: string;
  arguments?: ToolArguments;
  result?: ToolResult;
  success?: boolean;
  idx?: number;
  session_id?: string;
  agent_info?: {
    graph_id: string;
    name: string;
    trigger_type: string;
  };
  provider?: string;
  provider_name?: string;
  credential_type?: string;
  scopes?: string[];
  title?: string;
  [key: string]: unknown;
}

export function useChatStream() {
  const [isStreaming, setIsStreaming] = useState(false);
  const [error, setError] = useState<Error | null>(null);
  const eventSourceRef = useRef<EventSource | null>(null);
  const retryCountRef = useRef<number>(0);
  const retryTimeoutRef = useRef<NodeJS.Timeout | null>(null);
  const abortControllerRef = useRef<AbortController | null>(null);

  const stopStreaming = useCallback(() => {
    if (abortControllerRef.current) {
      abortControllerRef.current.abort();
      abortControllerRef.current = null;
    }
    if (eventSourceRef.current) {
      eventSourceRef.current.close();
      eventSourceRef.current = null;
    }
    if (retryTimeoutRef.current) {
      clearTimeout(retryTimeoutRef.current);
      retryTimeoutRef.current = null;
    }
    retryCountRef.current = 0;
    setIsStreaming(false);
  }, []);

  useEffect(() => {
    return () => {
      stopStreaming();
    };
  }, [stopStreaming]);

  const sendMessage = useCallback(
    async (
      sessionId: string,
      message: string,
      onChunk: (chunk: StreamChunk) => void,
      isUserMessage: boolean = true,
    ) => {
      stopStreaming();

      const abortController = new AbortController();
      abortControllerRef.current = abortController;

      if (abortController.signal.aborted) {
        return Promise.reject(new Error("Request aborted"));
      }

      retryCountRef.current = 0;
      setIsStreaming(true);
      setError(null);

      try {
        const url = `/api/chat/sessions/${sessionId}/stream?message=${encodeURIComponent(
          message,
        )}&is_user_message=${isUserMessage}`;

        const eventSource = new EventSource(url);
        eventSourceRef.current = eventSource;

        abortController.signal.addEventListener("abort", () => {
          eventSource.close();
          eventSourceRef.current = null;
        });

        return new Promise<void>((resolve, reject) => {
          const cleanup = () => {
            eventSource.removeEventListener("message", messageHandler);
            eventSource.removeEventListener("error", errorHandler);
          };

          const messageHandler = (event: MessageEvent) => {
            try {
              const chunk = JSON.parse(event.data) as StreamChunk;

              if (retryCountRef.current > 0) {
                retryCountRef.current = 0;
              }

              // Call the chunk handler
              onChunk(chunk);

              // Handle stream lifecycle
              if (chunk.type === "stream_end") {
                cleanup();
                stopStreaming();
                resolve();
              } else if (chunk.type === "error") {
                cleanup();
                reject(
                  new Error(chunk.message || chunk.content || "Stream error"),
                );
              }
            } catch (err) {
              const parseError =
                err instanceof Error
                  ? err
                  : new Error("Failed to parse stream chunk");
              setError(parseError);
              cleanup();
              reject(parseError);
            }
          };

          const errorHandler = () => {
            if (eventSourceRef.current) {
              eventSourceRef.current.close();
              eventSourceRef.current = null;
            }

            if (retryCountRef.current < MAX_RETRIES) {
              retryCountRef.current += 1;
              const retryDelay =
                INITIAL_RETRY_DELAY * Math.pow(2, retryCountRef.current - 1);

              toast.info("Connection interrupted", {
                description: `Retrying in ${retryDelay / 1000} seconds...`,
              });

              retryTimeoutRef.current = setTimeout(() => {
                sendMessage(sessionId, message, onChunk, isUserMessage).catch(
                  (_err) => {
                    // Retry failed
                  },
                );
              }, retryDelay);
            } else {
              const streamError = new Error(
                "Stream connection failed after multiple retries",
              );
              setError(streamError);
              toast.error("Connection Failed", {
                description:
                  "Unable to connect to chat service. Please try again.",
              });
              cleanup();
              stopStreaming();
              reject(streamError);
            }
          };

          eventSource.addEventListener("message", messageHandler);
          eventSource.addEventListener("error", errorHandler);
        });
      } catch (err) {
        const streamError =
          err instanceof Error ? err : new Error("Failed to start stream");
        setError(streamError);
        setIsStreaming(false);
        throw streamError;
      }
    },
    [stopStreaming],
  );

  return {
    isStreaming,
    error,
    sendMessage,
    stopStreaming,
  };
}