mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-16 02:28:09 -05:00
Compare commits
9 Commits
add-vercel
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e80e4d9cbb | ||
|
|
375d33cca9 | ||
|
|
3b1b2fe30c | ||
|
|
af63b3678e | ||
|
|
631f1bd50a | ||
|
|
5ac941fe2f | ||
|
|
b01ea3fcbd | ||
|
|
3b09a94e3f | ||
|
|
61efee4139 |
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -176,7 +176,7 @@ jobs:
|
||||
}
|
||||
|
||||
- name: Run Database Migrations
|
||||
run: poetry run prisma migrate dev --name updates
|
||||
run: poetry run prisma migrate deploy
|
||||
env:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
25
.github/workflows/platform-frontend-ci.yml
vendored
25
.github/workflows/platform-frontend-ci.yml
vendored
@@ -11,6 +11,7 @@ on:
|
||||
- ".github/workflows/platform-frontend-ci.yml"
|
||||
- "autogpt_platform/frontend/**"
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event_name == 'merge_group' && format('merge-queue-{0}', github.ref) || format('{0}-{1}', github.ref, github.event.pull_request.number || github.sha) }}
|
||||
@@ -151,6 +152,14 @@ jobs:
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env and set OpenAI API key
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
@@ -226,13 +235,25 @@ jobs:
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload Playwright artifacts
|
||||
if: failure()
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Upload Playwright test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
|
||||
@@ -6,9 +6,10 @@ start-core:
|
||||
|
||||
# Stop core services
|
||||
stop-core:
|
||||
docker compose stop deps
|
||||
docker compose stop
|
||||
|
||||
reset-db:
|
||||
docker compose stop db
|
||||
rm -rf db/docker/volumes/db/data
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
@@ -60,4 +61,4 @@ help:
|
||||
@echo " run-backend - Run the backend FastAPI server"
|
||||
@echo " run-frontend - Run the frontend Next.js development server"
|
||||
@echo " test-data - Run the test data creator"
|
||||
@echo " load-store-agents - Load store agents from agents/ folder into test database"
|
||||
@echo " load-store-agents - Load store agents from agents/ folder into test database"
|
||||
|
||||
@@ -58,6 +58,13 @@ V0_API_KEY=
|
||||
OPEN_ROUTER_API_KEY=
|
||||
NVIDIA_API_KEY=
|
||||
|
||||
# Langfuse Prompt Management
|
||||
# Used for managing the CoPilot system prompt externally
|
||||
# Get credentials from https://cloud.langfuse.com or your self-hosted instance
|
||||
LANGFUSE_PUBLIC_KEY=
|
||||
LANGFUSE_SECRET_KEY=
|
||||
LANGFUSE_HOST=https://cloud.langfuse.com
|
||||
|
||||
# OAuth Credentials
|
||||
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
|
||||
# e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||
|
||||
1
autogpt_platform/backend/.gitignore
vendored
1
autogpt_platform/backend/.gitignore
vendored
@@ -18,3 +18,4 @@ load-tests/results/
|
||||
load-tests/*.json
|
||||
load-tests/*.log
|
||||
load-tests/node_modules/*
|
||||
migrations/*/rollback*.sql
|
||||
|
||||
@@ -70,7 +70,7 @@ class RunAgentRequest(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
def _create_ephemeral_session(user_id: str | None) -> ChatSession:
|
||||
def _create_ephemeral_session(user_id: str) -> ChatSession:
|
||||
"""Create an ephemeral session for stateless API requests."""
|
||||
return ChatSession.new(user_id)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Configuration management for chat system."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
@@ -12,7 +11,11 @@ class ChatConfig(BaseSettings):
|
||||
|
||||
# OpenAI API Configuration
|
||||
model: str = Field(
|
||||
default="qwen/qwen3-235b-a22b-2507", description="Default model to use"
|
||||
default="anthropic/claude-opus-4.5", description="Default model to use"
|
||||
)
|
||||
title_model: str = Field(
|
||||
default="openai/gpt-4o-mini",
|
||||
description="Model to use for generating session titles (should be fast/cheap)",
|
||||
)
|
||||
api_key: str | None = Field(default=None, description="OpenAI API key")
|
||||
base_url: str | None = Field(
|
||||
@@ -23,12 +26,6 @@ class ChatConfig(BaseSettings):
|
||||
# Session TTL Configuration - 12 hours
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
# System Prompt Configuration
|
||||
system_prompt_path: str = Field(
|
||||
default="prompts/chat_system.md",
|
||||
description="Path to system prompt file relative to chat module",
|
||||
)
|
||||
|
||||
# Streaming Configuration
|
||||
max_context_messages: int = Field(
|
||||
default=50, ge=1, le=200, description="Maximum context messages"
|
||||
@@ -41,6 +38,13 @@ class ChatConfig(BaseSettings):
|
||||
default=3, description="Maximum number of agent schedules"
|
||||
)
|
||||
|
||||
# Langfuse Prompt Management Configuration
|
||||
# Note: Langfuse credentials are in Settings().secrets (settings.py)
|
||||
langfuse_prompt_name: str = Field(
|
||||
default="CoPilot Prompt",
|
||||
description="Name of the prompt in Langfuse to fetch",
|
||||
)
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@classmethod
|
||||
def get_api_key(cls, v):
|
||||
@@ -72,43 +76,11 @@ class ChatConfig(BaseSettings):
|
||||
v = "https://openrouter.ai/api/v1"
|
||||
return v
|
||||
|
||||
def get_system_prompt(self, **template_vars) -> str:
|
||||
"""Load and render the system prompt from file.
|
||||
|
||||
Args:
|
||||
**template_vars: Variables to substitute in the template
|
||||
|
||||
Returns:
|
||||
Rendered system prompt string
|
||||
|
||||
"""
|
||||
# Get the path relative to this module
|
||||
module_dir = Path(__file__).parent
|
||||
prompt_path = module_dir / self.system_prompt_path
|
||||
|
||||
# Check for .j2 extension first (Jinja2 template)
|
||||
j2_path = Path(str(prompt_path) + ".j2")
|
||||
if j2_path.exists():
|
||||
try:
|
||||
from jinja2 import Template
|
||||
|
||||
template = Template(j2_path.read_text())
|
||||
return template.render(**template_vars)
|
||||
except ImportError:
|
||||
# Jinja2 not installed, fall back to reading as plain text
|
||||
return j2_path.read_text()
|
||||
|
||||
# Check for markdown file
|
||||
if prompt_path.exists():
|
||||
content = prompt_path.read_text()
|
||||
|
||||
# Simple variable substitution if Jinja2 is not available
|
||||
for key, value in template_vars.items():
|
||||
placeholder = f"{{{key}}}"
|
||||
content = content.replace(placeholder, str(value))
|
||||
|
||||
return content
|
||||
raise FileNotFoundError(f"System prompt file not found: {prompt_path}")
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
"onboarding": "prompts/onboarding_system.md",
|
||||
}
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
249
autogpt_platform/backend/backend/api/features/chat/db.py
Normal file
249
autogpt_platform/backend/backend/api/features/chat/db.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""Database operations for chat sessions."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import (
|
||||
ChatMessageCreateInput,
|
||||
ChatSessionCreateInput,
|
||||
ChatSessionUpdateInput,
|
||||
ChatSessionWhereInput,
|
||||
)
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
|
||||
"""Get a chat session by ID from the database."""
|
||||
session = await PrismaChatSession.prisma().find_unique(
|
||||
where={"id": session_id},
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
# Sort messages by sequence in Python - Prisma Python client doesn't support
|
||||
# order_by in include clauses (unlike Prisma JS), so we sort after fetching
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
) -> PrismaChatSession:
|
||||
"""Create a new chat session in the database."""
|
||||
data = ChatSessionCreateInput(
|
||||
id=session_id,
|
||||
userId=user_id,
|
||||
credentials=SafeJson({}),
|
||||
successfulAgentRuns=SafeJson({}),
|
||||
successfulAgentSchedules=SafeJson({}),
|
||||
)
|
||||
return await PrismaChatSession.prisma().create(
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
)
|
||||
|
||||
|
||||
async def update_chat_session(
|
||||
session_id: str,
|
||||
credentials: dict[str, Any] | None = None,
|
||||
successful_agent_runs: dict[str, Any] | None = None,
|
||||
successful_agent_schedules: dict[str, Any] | None = None,
|
||||
total_prompt_tokens: int | None = None,
|
||||
total_completion_tokens: int | None = None,
|
||||
title: str | None = None,
|
||||
) -> PrismaChatSession | None:
|
||||
"""Update a chat session's metadata."""
|
||||
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
|
||||
|
||||
if credentials is not None:
|
||||
data["credentials"] = SafeJson(credentials)
|
||||
if successful_agent_runs is not None:
|
||||
data["successfulAgentRuns"] = SafeJson(successful_agent_runs)
|
||||
if successful_agent_schedules is not None:
|
||||
data["successfulAgentSchedules"] = SafeJson(successful_agent_schedules)
|
||||
if total_prompt_tokens is not None:
|
||||
data["totalPromptTokens"] = total_prompt_tokens
|
||||
if total_completion_tokens is not None:
|
||||
data["totalCompletionTokens"] = total_completion_tokens
|
||||
if title is not None:
|
||||
data["title"] = title
|
||||
|
||||
session = await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
# Sort in Python - Prisma Python doesn't support order_by in include clauses
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
|
||||
async def add_chat_message(
|
||||
session_id: str,
|
||||
role: str,
|
||||
sequence: int,
|
||||
content: str | None = None,
|
||||
name: str | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
refusal: str | None = None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
function_call: dict[str, Any] | None = None,
|
||||
) -> PrismaChatMessage:
|
||||
"""Add a message to a chat session."""
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput directly
|
||||
# because Prisma's TypedDict validation rejects optional fields set to None.
|
||||
# We only include fields that have values, then cast at the end.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": role,
|
||||
"sequence": sequence,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if content is not None:
|
||||
data["content"] = content
|
||||
if name is not None:
|
||||
data["name"] = name
|
||||
if tool_call_id is not None:
|
||||
data["toolCallId"] = tool_call_id
|
||||
if refusal is not None:
|
||||
data["refusal"] = refusal
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if tool_calls is not None:
|
||||
data["toolCalls"] = SafeJson(tool_calls)
|
||||
if function_call is not None:
|
||||
data["functionCall"] = SafeJson(function_call)
|
||||
|
||||
# Run message create and session timestamp update in parallel for lower latency
|
||||
_, message = await asyncio.gather(
|
||||
PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
),
|
||||
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
async def add_chat_messages_batch(
|
||||
session_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
start_sequence: int,
|
||||
) -> list[PrismaChatMessage]:
|
||||
"""Add multiple messages to a chat session in a batch.
|
||||
|
||||
Uses a transaction for atomicity - if any message creation fails,
|
||||
the entire batch is rolled back.
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
created_messages = []
|
||||
|
||||
async with transaction() as tx:
|
||||
for i, msg in enumerate(messages):
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput
|
||||
# directly because Prisma's TypedDict validation rejects optional fields
|
||||
# set to None. We only include fields that have values, then cast.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
created = await PrismaChatMessage.prisma(tx).create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
)
|
||||
created_messages.append(created)
|
||||
|
||||
# Update session's updatedAt timestamp within the same transaction.
|
||||
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
|
||||
# separately via update_chat_session() after streaming completes.
|
||||
await PrismaChatSession.prisma(tx).update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
|
||||
return created_messages
|
||||
|
||||
|
||||
async def get_user_chat_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[PrismaChatSession]:
|
||||
"""Get chat sessions for a user, ordered by most recent."""
|
||||
return await PrismaChatSession.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
order={"updatedAt": "desc"},
|
||||
take=limit,
|
||||
skip=offset,
|
||||
)
|
||||
|
||||
|
||||
async def get_user_session_count(user_id: str) -> int:
|
||||
"""Get the total number of chat sessions for a user."""
|
||||
return await PrismaChatSession.prisma().count(where={"userId": user_id})
|
||||
|
||||
|
||||
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
|
||||
"""Delete a chat session and all its messages.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to delete.
|
||||
user_id: If provided, validates that the session belongs to this user
|
||||
before deletion. This prevents unauthorized deletion of other
|
||||
users' sessions.
|
||||
|
||||
Returns:
|
||||
True if deleted successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
# Build typed where clause with optional user_id validation
|
||||
where_clause: ChatSessionWhereInput = {"id": session_id}
|
||||
if user_id is not None:
|
||||
where_clause["userId"] = user_id
|
||||
|
||||
result = await PrismaChatSession.prisma().delete_many(where=where_clause)
|
||||
if result == 0:
|
||||
logger.warning(
|
||||
f"No session deleted for {session_id} "
|
||||
f"(user_id validation: {user_id is not None})"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete chat session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_chat_session_message_count(session_id: str) -> int:
|
||||
"""Get the number of messages in a chat session."""
|
||||
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
||||
return count
|
||||
@@ -1,6 +1,9 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
@@ -16,17 +19,63 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
ChatCompletionMessageToolCallParam,
|
||||
Function,
|
||||
)
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util.exceptions import RedisError
|
||||
from backend.util import json
|
||||
from backend.util.exceptions import DatabaseError, RedisError
|
||||
|
||||
from . import db as chat_db
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
|
||||
"""Parse a JSON field that may be stored as string or already parsed."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, str):
|
||||
return json.loads(value)
|
||||
return value
|
||||
|
||||
|
||||
# Redis cache key prefix for chat sessions
|
||||
CHAT_SESSION_CACHE_PREFIX = "chat:session:"
|
||||
|
||||
|
||||
def _get_session_cache_key(session_id: str) -> str:
|
||||
"""Get the Redis cache key for a chat session."""
|
||||
return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
|
||||
|
||||
|
||||
# Session-level locks to prevent race conditions during concurrent upserts.
|
||||
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
|
||||
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
|
||||
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
|
||||
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
|
||||
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
|
||||
async def _get_session_lock(session_id: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific session to prevent concurrent upserts.
|
||||
|
||||
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
|
||||
when no coroutine holds a reference to them, preventing memory leaks from
|
||||
unbounded growth of session locks.
|
||||
"""
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_session_locks[session_id] = lock
|
||||
return lock
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str | None = None
|
||||
@@ -45,7 +94,8 @@ class Usage(BaseModel):
|
||||
|
||||
class ChatSession(BaseModel):
|
||||
session_id: str
|
||||
user_id: str | None
|
||||
user_id: str
|
||||
title: str | None = None
|
||||
messages: list[ChatMessage]
|
||||
usage: list[Usage]
|
||||
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
|
||||
@@ -55,10 +105,11 @@ class ChatSession(BaseModel):
|
||||
successful_agent_schedules: dict[str, int] = {}
|
||||
|
||||
@staticmethod
|
||||
def new(user_id: str | None) -> "ChatSession":
|
||||
def new(user_id: str) -> "ChatSession":
|
||||
return ChatSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
title=None,
|
||||
messages=[],
|
||||
usage=[],
|
||||
credentials={},
|
||||
@@ -66,6 +117,61 @@ class ChatSession(BaseModel):
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
prisma_session: PrismaChatSession,
|
||||
prisma_messages: list[PrismaChatMessage] | None = None,
|
||||
) -> "ChatSession":
|
||||
"""Convert Prisma models to Pydantic ChatSession."""
|
||||
messages = []
|
||||
if prisma_messages:
|
||||
for msg in prisma_messages:
|
||||
messages.append(
|
||||
ChatMessage(
|
||||
role=msg.role,
|
||||
content=msg.content,
|
||||
name=msg.name,
|
||||
tool_call_id=msg.toolCallId,
|
||||
refusal=msg.refusal,
|
||||
tool_calls=_parse_json_field(msg.toolCalls),
|
||||
function_call=_parse_json_field(msg.functionCall),
|
||||
)
|
||||
)
|
||||
|
||||
# Parse JSON fields from Prisma
|
||||
credentials = _parse_json_field(prisma_session.credentials, default={})
|
||||
successful_agent_runs = _parse_json_field(
|
||||
prisma_session.successfulAgentRuns, default={}
|
||||
)
|
||||
successful_agent_schedules = _parse_json_field(
|
||||
prisma_session.successfulAgentSchedules, default={}
|
||||
)
|
||||
|
||||
# Calculate usage from token counts
|
||||
usage = []
|
||||
if prisma_session.totalPromptTokens or prisma_session.totalCompletionTokens:
|
||||
usage.append(
|
||||
Usage(
|
||||
prompt_tokens=prisma_session.totalPromptTokens or 0,
|
||||
completion_tokens=prisma_session.totalCompletionTokens or 0,
|
||||
total_tokens=(prisma_session.totalPromptTokens or 0)
|
||||
+ (prisma_session.totalCompletionTokens or 0),
|
||||
)
|
||||
)
|
||||
|
||||
return ChatSession(
|
||||
session_id=prisma_session.id,
|
||||
user_id=prisma_session.userId,
|
||||
title=prisma_session.title,
|
||||
messages=messages,
|
||||
usage=usage,
|
||||
credentials=credentials,
|
||||
started_at=prisma_session.createdAt,
|
||||
updated_at=prisma_session.updatedAt,
|
||||
successful_agent_runs=successful_agent_runs,
|
||||
successful_agent_schedules=successful_agent_schedules,
|
||||
)
|
||||
|
||||
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
|
||||
messages = []
|
||||
for message in self.messages:
|
||||
@@ -155,50 +261,337 @@ class ChatSession(BaseModel):
|
||||
return messages
|
||||
|
||||
|
||||
async def get_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> ChatSession | None:
|
||||
"""Get a chat session by ID."""
|
||||
redis_key = f"chat:session:{session_id}"
|
||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from Redis cache."""
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
|
||||
raw_session: bytes | None = await async_redis.get(redis_key)
|
||||
|
||||
if raw_session is None:
|
||||
logger.warning(f"Session {session_id} not found in Redis")
|
||||
return None
|
||||
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
logger.info(
|
||||
f"Loading session {session_id} from cache: "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
)
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
|
||||
raise RedisError(f"Corrupted session data for {session_id}") from e
|
||||
|
||||
if session.user_id is not None and session.user_id != user_id:
|
||||
|
||||
async def _cache_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session in Redis."""
|
||||
redis_key = _get_session_cache_key(session.session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
prisma_session = await chat_db.get_chat_session(session_id)
|
||||
if not prisma_session:
|
||||
return None
|
||||
|
||||
messages = prisma_session.Messages
|
||||
logger.info(
|
||||
f"Loading session {session_id} from DB: "
|
||||
f"has_messages={messages is not None}, "
|
||||
f"message_count={len(messages) if messages else 0}, "
|
||||
f"roles={[m.role for m in messages] if messages else []}"
|
||||
)
|
||||
|
||||
return ChatSession.from_db(prisma_session, messages)
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
session: ChatSession, existing_message_count: int
|
||||
) -> None:
|
||||
"""Save or update a chat session in the database."""
|
||||
# Check if session exists in DB
|
||||
existing = await chat_db.get_chat_session(session.session_id)
|
||||
|
||||
if not existing:
|
||||
# Create new session
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
# Calculate total tokens from usage
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
# Update session metadata
|
||||
await chat_db.update_chat_session(
|
||||
session_id=session.session_id,
|
||||
credentials=session.credentials,
|
||||
successful_agent_runs=session.successful_agent_runs,
|
||||
successful_agent_schedules=session.successful_agent_schedules,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
# Add new messages (only those after existing count)
|
||||
new_messages = session.messages[existing_message_count:]
|
||||
if new_messages:
|
||||
messages_data = []
|
||||
for msg in new_messages:
|
||||
messages_data.append(
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"name": msg.name,
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
"refusal": msg.refusal,
|
||||
"tool_calls": msg.tool_calls,
|
||||
"function_call": msg.function_call,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||
f"roles={[m['role'] for m in messages_data]}, "
|
||||
f"start_sequence={existing_message_count}"
|
||||
)
|
||||
await chat_db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
messages=messages_data,
|
||||
start_sequence=existing_message_count,
|
||||
)
|
||||
|
||||
|
||||
async def get_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None = None,
|
||||
) -> ChatSession | None:
|
||||
"""Get a chat session by ID.
|
||||
|
||||
Checks Redis cache first, falls back to database if not found.
|
||||
Caches database results back to Redis.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to fetch.
|
||||
user_id: If provided, validates that the session belongs to this user.
|
||||
If None, ownership is not validated (admin/system access).
|
||||
"""
|
||||
# Try cache first
|
||||
try:
|
||||
session = await _get_session_from_cache(session_id)
|
||||
if session:
|
||||
# Verify user ownership if user_id was provided for validation
|
||||
if user_id is not None and session.user_id != user_id:
|
||||
logger.warning(
|
||||
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
||||
)
|
||||
return None
|
||||
return session
|
||||
except RedisError:
|
||||
logger.warning(f"Cache error for session {session_id}, trying database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
||||
|
||||
# Fall back to database
|
||||
logger.info(f"Session {session_id} not in cache, checking database")
|
||||
session = await _get_session_from_db(session_id)
|
||||
|
||||
if session is None:
|
||||
logger.warning(f"Session {session_id} not found in cache or database")
|
||||
return None
|
||||
|
||||
# Verify user ownership if user_id was provided for validation
|
||||
if user_id is not None and session.user_id != user_id:
|
||||
logger.warning(
|
||||
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Cache the session from DB
|
||||
try:
|
||||
await _cache_session(session)
|
||||
logger.info(f"Cached session {session_id} from database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache session {session_id}: {e}")
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def upsert_chat_session(
|
||||
session: ChatSession,
|
||||
) -> ChatSession:
|
||||
"""Update a chat session with the given messages."""
|
||||
"""Update a chat session in both cache and database.
|
||||
|
||||
redis_key = f"chat:session:{session.session_id}"
|
||||
Uses session-level locking to prevent race conditions when concurrent
|
||||
operations (e.g., background title update and main stream handler)
|
||||
attempt to upsert the same session simultaneously.
|
||||
|
||||
async_redis = await get_redis_async()
|
||||
resp = await async_redis.setex(
|
||||
redis_key, config.session_ttl, session.model_dump_json()
|
||||
)
|
||||
Raises:
|
||||
DatabaseError: If the database write fails. The cache is still updated
|
||||
as a best-effort optimization, but the error is propagated to ensure
|
||||
callers are aware of the persistence failure.
|
||||
RedisError: If the cache write fails (after successful DB write).
|
||||
"""
|
||||
# Acquire session-specific lock to prevent concurrent upserts
|
||||
lock = await _get_session_lock(session.session_id)
|
||||
|
||||
if not resp:
|
||||
raise RedisError(
|
||||
f"Failed to persist chat session {session.session_id} to Redis: {resp}"
|
||||
async with lock:
|
||||
# Get existing message count from DB for incremental saves
|
||||
existing_message_count = await chat_db.get_chat_session_message_count(
|
||||
session.session_id
|
||||
)
|
||||
|
||||
db_error: Exception | None = None
|
||||
|
||||
# Save to database (primary storage)
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save session {session.session_id} to database: {e}"
|
||||
)
|
||||
db_error = e
|
||||
|
||||
# Save to cache (best-effort, even if DB failed)
|
||||
try:
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
# If DB succeeded but cache failed, raise cache error
|
||||
if db_error is None:
|
||||
raise RedisError(
|
||||
f"Failed to persist chat session {session.session_id} to Redis: {e}"
|
||||
) from e
|
||||
# If both failed, log cache error but raise DB error (more critical)
|
||||
logger.warning(
|
||||
f"Cache write also failed for session {session.session_id}: {e}"
|
||||
)
|
||||
|
||||
# Propagate DB error after attempting cache (prevents data loss)
|
||||
if db_error is not None:
|
||||
raise DatabaseError(
|
||||
f"Failed to persist chat session {session.session_id} to database"
|
||||
) from db_error
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(user_id: str) -> ChatSession:
|
||||
"""Create a new chat session and persist it.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If the database write fails. We fail fast to ensure
|
||||
callers never receive a non-persisted session that only exists
|
||||
in cache (which would be lost when the cache expires).
|
||||
"""
|
||||
session = ChatSession.new(user_id)
|
||||
|
||||
# Create in database first - fail fast if this fails
|
||||
try:
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create session {session.session_id} in database: {e}")
|
||||
raise DatabaseError(
|
||||
f"Failed to create chat session {session.session_id} in database"
|
||||
) from e
|
||||
|
||||
# Cache the session (best-effort optimization, DB is source of truth)
|
||||
try:
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache new session {session.session_id}: {e}")
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def get_user_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[ChatSession], int]:
|
||||
"""Get chat sessions for a user from the database with total count.
|
||||
|
||||
Returns:
|
||||
A tuple of (sessions, total_count) where total_count is the overall
|
||||
number of sessions for the user (not just the current page).
|
||||
"""
|
||||
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
|
||||
total_count = await chat_db.get_user_session_count(user_id)
|
||||
|
||||
sessions = []
|
||||
for prisma_session in prisma_sessions:
|
||||
# Convert without messages for listing (lighter weight)
|
||||
sessions.append(ChatSession.from_db(prisma_session, None))
|
||||
|
||||
return sessions, total_count
|
||||
|
||||
|
||||
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
|
||||
"""Delete a chat session from both cache and database.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to delete.
|
||||
user_id: If provided, validates that the session belongs to this user
|
||||
before deletion. This prevents unauthorized deletion.
|
||||
|
||||
Returns:
|
||||
True if deleted successfully, False otherwise.
|
||||
"""
|
||||
# Delete from database first (with optional user_id validation)
|
||||
# This confirms ownership before invalidating cache
|
||||
deleted = await chat_db.delete_chat_session(session_id, user_id)
|
||||
|
||||
if not deleted:
|
||||
return False
|
||||
|
||||
# Only invalidate cache and clean up lock after DB confirms deletion
|
||||
try:
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
|
||||
|
||||
# Clean up session lock (belt-and-suspenders with WeakValueDictionary)
|
||||
async with _session_locks_mutex:
|
||||
_session_locks.pop(session_id, None)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def update_session_title(session_id: str, title: str) -> bool:
|
||||
"""Update only the title of a chat session.
|
||||
|
||||
This is a lightweight operation that doesn't touch messages, avoiding
|
||||
race conditions with concurrent message updates. Use this for background
|
||||
title generation instead of upsert_chat_session.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to update.
|
||||
title: The new title to set.
|
||||
|
||||
Returns:
|
||||
True if updated successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
result = await chat_db.update_chat_session(session_id=session_id, title=title)
|
||||
if result is None:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
return False
|
||||
|
||||
# Invalidate cache so next fetch gets updated title
|
||||
try:
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update title for session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
@@ -43,9 +43,9 @@ async def test_chatsession_serialization_deserialization():
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_redis_storage():
|
||||
async def test_chatsession_redis_storage(setup_test_user, test_user_id):
|
||||
|
||||
s = ChatSession.new(user_id=None)
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s.messages = messages
|
||||
|
||||
s = await upsert_chat_session(s)
|
||||
@@ -59,12 +59,61 @@ async def test_chatsession_redis_storage():
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_redis_storage_user_id_mismatch():
|
||||
async def test_chatsession_redis_storage_user_id_mismatch(
|
||||
setup_test_user, test_user_id
|
||||
):
|
||||
|
||||
s = ChatSession.new(user_id="abc123")
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s.messages = messages
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
s2 = await get_chat_session(s.session_id, None)
|
||||
s2 = await get_chat_session(s.session_id, "different_user_id")
|
||||
|
||||
assert s2 is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_db_storage(setup_test_user, test_user_id):
|
||||
"""Test that messages are correctly saved to and loaded from DB (not cache)."""
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
# Create session with messages including assistant message
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s.messages = messages # Contains user, assistant, and tool messages
|
||||
assert s.session_id is not None, "Session id is not set"
|
||||
# Upsert to save to both cache and DB
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
# Clear the Redis cache to force DB load
|
||||
redis_key = f"chat:session:{s.session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
|
||||
# Load from DB (cache was cleared)
|
||||
s2 = await get_chat_session(
|
||||
session_id=s.session_id,
|
||||
user_id=s.user_id,
|
||||
)
|
||||
|
||||
assert s2 is not None, "Session not found after loading from DB"
|
||||
assert len(s2.messages) == len(
|
||||
s.messages
|
||||
), f"Message count mismatch: expected {len(s.messages)}, got {len(s2.messages)}"
|
||||
|
||||
# Verify all roles are present
|
||||
roles = [m.role for m in s2.messages]
|
||||
assert "user" in roles, f"User message missing. Roles found: {roles}"
|
||||
assert "assistant" in roles, f"Assistant message missing. Roles found: {roles}"
|
||||
assert "tool" in roles, f"Tool message missing. Roles found: {roles}"
|
||||
|
||||
# Verify message content
|
||||
for orig, loaded in zip(s.messages, s2.messages):
|
||||
assert orig.role == loaded.role, f"Role mismatch: {orig.role} != {loaded.role}"
|
||||
assert (
|
||||
orig.content == loaded.content
|
||||
), f"Content mismatch for {orig.role}: {orig.content} != {loaded.content}"
|
||||
if orig.tool_calls:
|
||||
assert (
|
||||
loaded.tool_calls is not None
|
||||
), f"Tool calls missing for {orig.role} message"
|
||||
assert len(orig.tool_calls) == len(loaded.tool_calls)
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
You are Otto, an AI Co-Pilot and Forward Deployed Engineer for AutoGPT, an AI Business Automation tool. Your mission is to help users quickly find and set up AutoGPT agents to solve their business problems.
|
||||
|
||||
Here are the functions available to you:
|
||||
|
||||
<functions>
|
||||
1. **find_agent** - Search for agents that solve the user's problem
|
||||
2. **run_agent** - Run or schedule an agent (automatically handles setup)
|
||||
</functions>
|
||||
|
||||
## HOW run_agent WORKS
|
||||
|
||||
The `run_agent` tool automatically handles the entire setup flow:
|
||||
|
||||
1. **First call** (no inputs) → Returns available inputs so user can decide what values to use
|
||||
2. **Credentials check** → If missing, UI automatically prompts user to add them (you don't need to mention this)
|
||||
3. **Execution** → Runs when you provide `inputs` OR set `use_defaults=true`
|
||||
|
||||
Parameters:
|
||||
- `username_agent_slug` (required): Agent identifier like "creator/agent-name"
|
||||
- `inputs`: Object with input values for the agent
|
||||
- `use_defaults`: Set to `true` to run with default values (only after user confirms)
|
||||
- `schedule_name` + `cron`: For scheduled execution
|
||||
|
||||
## WORKFLOW
|
||||
|
||||
1. **find_agent** - Search for agents that solve the user's problem
|
||||
2. **run_agent** (first call, no inputs) - Get available inputs for the agent
|
||||
3. **Ask user** what values they want to use OR if they want to use defaults
|
||||
4. **run_agent** (second call) - Either with `inputs={...}` or `use_defaults=true`
|
||||
|
||||
## YOUR APPROACH
|
||||
|
||||
**Step 1: Understand the Problem**
|
||||
- Ask maximum 1-2 targeted questions
|
||||
- Focus on: What business problem are they solving?
|
||||
- Move quickly to searching for solutions
|
||||
|
||||
**Step 2: Find Agents**
|
||||
- Use `find_agent` immediately with relevant keywords
|
||||
- Suggest the best option from search results
|
||||
- Explain briefly how it solves their problem
|
||||
|
||||
**Step 3: Get Agent Inputs**
|
||||
- Call `run_agent(username_agent_slug="creator/agent-name")` without inputs
|
||||
- This returns the available inputs (required and optional)
|
||||
- Present these to the user and ask what values they want
|
||||
|
||||
**Step 4: Run with User's Choice**
|
||||
- If user provides values: `run_agent(username_agent_slug="...", inputs={...})`
|
||||
- If user says "use defaults": `run_agent(username_agent_slug="...", use_defaults=true)`
|
||||
- On success, share the agent link with the user
|
||||
|
||||
**For Scheduled Execution:**
|
||||
- Add `schedule_name` and `cron` parameters
|
||||
- Example: `run_agent(username_agent_slug="...", inputs={...}, schedule_name="Daily Report", cron="0 9 * * *")`
|
||||
|
||||
## FUNCTION CALL FORMAT
|
||||
|
||||
To call a function, use this exact format:
|
||||
`<function_call>function_name(parameter="value")</function_call>`
|
||||
|
||||
Examples:
|
||||
- `<function_call>find_agent(query="social media automation")</function_call>`
|
||||
- `<function_call>run_agent(username_agent_slug="creator/agent-name")</function_call>` (get inputs)
|
||||
- `<function_call>run_agent(username_agent_slug="creator/agent-name", inputs={"topic": "AI news"})</function_call>`
|
||||
- `<function_call>run_agent(username_agent_slug="creator/agent-name", use_defaults=true)</function_call>`
|
||||
|
||||
## KEY RULES
|
||||
|
||||
**What You DON'T Do:**
|
||||
- Don't help with login (frontend handles this)
|
||||
- Don't mention or explain credentials to the user (frontend handles this automatically)
|
||||
- Don't run agents without first showing available inputs to the user
|
||||
- Don't use `use_defaults=true` without user explicitly confirming
|
||||
- Don't write responses longer than 3 sentences
|
||||
|
||||
**What You DO:**
|
||||
- Always call run_agent first without inputs to see what's available
|
||||
- Ask user what values they want OR if they want to use defaults
|
||||
- Keep all responses to maximum 3 sentences
|
||||
- Include the agent link in your response after successful execution
|
||||
|
||||
**Error Handling:**
|
||||
- Authentication needed → "Please sign in via the interface"
|
||||
- Credentials missing → The UI handles this automatically. Focus on asking the user about input values instead.
|
||||
|
||||
## RESPONSE STRUCTURE
|
||||
|
||||
Before responding, wrap your analysis in <thinking> tags to systematically plan your approach:
|
||||
- Extract the key business problem or request from the user's message
|
||||
- Determine what function call (if any) you need to make next
|
||||
- Plan your response to stay under the 3-sentence maximum
|
||||
|
||||
Example interaction:
|
||||
```
|
||||
User: "Run the AI news agent for me"
|
||||
Otto: <function_call>run_agent(username_agent_slug="autogpt/ai-news")</function_call>
|
||||
[Tool returns: Agent accepts inputs - Required: topic. Optional: num_articles (default: 5)]
|
||||
Otto: The AI News agent needs a topic. What topic would you like news about, or should I use the defaults?
|
||||
User: "Use defaults"
|
||||
Otto: <function_call>run_agent(username_agent_slug="autogpt/ai-news", use_defaults=true)</function_call>
|
||||
```
|
||||
|
||||
KEEP ANSWERS TO 3 SENTENCES
|
||||
@@ -1,3 +1,10 @@
|
||||
"""
|
||||
Response models for Vercel AI SDK UI Stream Protocol.
|
||||
|
||||
This module implements the AI SDK UI Stream Protocol (v1) for streaming chat responses.
|
||||
See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
@@ -5,97 +12,133 @@ from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ResponseType(str, Enum):
|
||||
"""Types of streaming responses."""
|
||||
"""Types of streaming responses following AI SDK protocol."""
|
||||
|
||||
TEXT_CHUNK = "text_chunk"
|
||||
TEXT_ENDED = "text_ended"
|
||||
TOOL_CALL = "tool_call"
|
||||
TOOL_CALL_START = "tool_call_start"
|
||||
TOOL_RESPONSE = "tool_response"
|
||||
# Message lifecycle
|
||||
START = "start"
|
||||
FINISH = "finish"
|
||||
|
||||
# Text streaming
|
||||
TEXT_START = "text-start"
|
||||
TEXT_DELTA = "text-delta"
|
||||
TEXT_END = "text-end"
|
||||
|
||||
# Tool interaction
|
||||
TOOL_INPUT_START = "tool-input-start"
|
||||
TOOL_INPUT_AVAILABLE = "tool-input-available"
|
||||
TOOL_OUTPUT_AVAILABLE = "tool-output-available"
|
||||
|
||||
# Other
|
||||
ERROR = "error"
|
||||
USAGE = "usage"
|
||||
STREAM_END = "stream_end"
|
||||
|
||||
|
||||
class StreamBaseResponse(BaseModel):
|
||||
"""Base response model for all streaming responses."""
|
||||
|
||||
type: ResponseType
|
||||
timestamp: str | None = None
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format."""
|
||||
return f"data: {self.model_dump_json()}\n\n"
|
||||
|
||||
|
||||
class StreamTextChunk(StreamBaseResponse):
|
||||
"""Streaming text content from the assistant."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_CHUNK
|
||||
content: str = Field(..., description="Text content chunk")
|
||||
# ========== Message Lifecycle ==========
|
||||
|
||||
|
||||
class StreamToolCallStart(StreamBaseResponse):
|
||||
class StreamStart(StreamBaseResponse):
|
||||
"""Start of a new message."""
|
||||
|
||||
type: ResponseType = ResponseType.START
|
||||
messageId: str = Field(..., description="Unique message ID")
|
||||
|
||||
|
||||
class StreamFinish(StreamBaseResponse):
|
||||
"""End of message/stream."""
|
||||
|
||||
type: ResponseType = ResponseType.FINISH
|
||||
|
||||
|
||||
# ========== Text Streaming ==========
|
||||
|
||||
|
||||
class StreamTextStart(StreamBaseResponse):
|
||||
"""Start of a text block."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_START
|
||||
id: str = Field(..., description="Text block ID")
|
||||
|
||||
|
||||
class StreamTextDelta(StreamBaseResponse):
|
||||
"""Streaming text content delta."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_DELTA
|
||||
id: str = Field(..., description="Text block ID")
|
||||
delta: str = Field(..., description="Text content delta")
|
||||
|
||||
|
||||
class StreamTextEnd(StreamBaseResponse):
|
||||
"""End of a text block."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_END
|
||||
id: str = Field(..., description="Text block ID")
|
||||
|
||||
|
||||
# ========== Tool Interaction ==========
|
||||
|
||||
|
||||
class StreamToolInputStart(StreamBaseResponse):
|
||||
"""Tool call started notification."""
|
||||
|
||||
type: ResponseType = ResponseType.TOOL_CALL_START
|
||||
tool_name: str = Field(..., description="Name of the tool that was executed")
|
||||
tool_id: str = Field(..., description="Unique tool call ID")
|
||||
type: ResponseType = ResponseType.TOOL_INPUT_START
|
||||
toolCallId: str = Field(..., description="Unique tool call ID")
|
||||
toolName: str = Field(..., description="Name of the tool being called")
|
||||
|
||||
|
||||
class StreamToolCall(StreamBaseResponse):
|
||||
"""Tool invocation notification."""
|
||||
class StreamToolInputAvailable(StreamBaseResponse):
|
||||
"""Tool input is ready for execution."""
|
||||
|
||||
type: ResponseType = ResponseType.TOOL_CALL
|
||||
tool_id: str = Field(..., description="Unique tool call ID")
|
||||
tool_name: str = Field(..., description="Name of the tool being called")
|
||||
arguments: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Tool arguments"
|
||||
type: ResponseType = ResponseType.TOOL_INPUT_AVAILABLE
|
||||
toolCallId: str = Field(..., description="Unique tool call ID")
|
||||
toolName: str = Field(..., description="Name of the tool being called")
|
||||
input: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Tool input arguments"
|
||||
)
|
||||
|
||||
|
||||
class StreamToolExecutionResult(StreamBaseResponse):
|
||||
class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
"""Tool execution result."""
|
||||
|
||||
type: ResponseType = ResponseType.TOOL_RESPONSE
|
||||
tool_id: str = Field(..., description="Tool call ID this responds to")
|
||||
tool_name: str = Field(..., description="Name of the tool that was executed")
|
||||
result: str | dict[str, Any] = Field(..., description="Tool execution result")
|
||||
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
|
||||
toolCallId: str = Field(..., description="Tool call ID this responds to")
|
||||
output: str | dict[str, Any] = Field(..., description="Tool execution output")
|
||||
# Additional fields for internal use (not part of AI SDK spec but useful)
|
||||
toolName: str | None = Field(
|
||||
default=None, description="Name of the tool that was executed"
|
||||
)
|
||||
success: bool = Field(
|
||||
default=True, description="Whether the tool execution succeeded"
|
||||
)
|
||||
|
||||
|
||||
# ========== Other ==========
|
||||
|
||||
|
||||
class StreamUsage(StreamBaseResponse):
|
||||
"""Token usage statistics."""
|
||||
|
||||
type: ResponseType = ResponseType.USAGE
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
promptTokens: int = Field(..., description="Number of prompt tokens")
|
||||
completionTokens: int = Field(..., description="Number of completion tokens")
|
||||
totalTokens: int = Field(..., description="Total number of tokens")
|
||||
|
||||
|
||||
class StreamError(StreamBaseResponse):
|
||||
"""Error response."""
|
||||
|
||||
type: ResponseType = ResponseType.ERROR
|
||||
message: str = Field(..., description="Error message")
|
||||
errorText: str = Field(..., description="Error message text")
|
||||
code: str | None = Field(default=None, description="Error code")
|
||||
details: dict[str, Any] | None = Field(
|
||||
default=None, description="Additional error details"
|
||||
)
|
||||
|
||||
|
||||
class StreamTextEnded(StreamBaseResponse):
|
||||
"""Text streaming completed marker."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_ENDED
|
||||
|
||||
|
||||
class StreamEnd(StreamBaseResponse):
|
||||
"""End of stream marker."""
|
||||
|
||||
type: ResponseType = ResponseType.STREAM_END
|
||||
summary: dict[str, Any] | None = Field(
|
||||
default=None, description="Stream summary statistics"
|
||||
)
|
||||
|
||||
@@ -13,12 +13,25 @@ from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _validate_and_get_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> ChatSession:
|
||||
"""Validate session exists and belongs to user."""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
return session
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
tags=["chat"],
|
||||
)
|
||||
@@ -26,6 +39,14 @@ router = APIRouter(
|
||||
# ========== Request/Response Models ==========
|
||||
|
||||
|
||||
class StreamChatRequest(BaseModel):
|
||||
"""Request model for streaming chat with optional context."""
|
||||
|
||||
message: str
|
||||
is_user_message: bool = True
|
||||
context: dict[str, str] | None = None # {url: str, content: str}
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
"""Response model containing information on a newly created chat session."""
|
||||
|
||||
@@ -44,22 +65,77 @@ class SessionDetailResponse(BaseModel):
|
||||
messages: list[dict]
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
"""Response model for a session summary (without messages)."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
title: str | None = None
|
||||
|
||||
|
||||
class ListSessionsResponse(BaseModel):
|
||||
"""Response model for listing chat sessions."""
|
||||
|
||||
sessions: list[SessionSummaryResponse]
|
||||
total: int
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def list_sessions(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
limit: int = Query(default=50, ge=1, le=100),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
) -> ListSessionsResponse:
|
||||
"""
|
||||
List chat sessions for the authenticated user.
|
||||
|
||||
Returns a paginated list of chat sessions belonging to the current user,
|
||||
ordered by most recently updated.
|
||||
|
||||
Args:
|
||||
user_id: The authenticated user's ID.
|
||||
limit: Maximum number of sessions to return (1-100).
|
||||
offset: Number of sessions to skip for pagination.
|
||||
|
||||
Returns:
|
||||
ListSessionsResponse: List of session summaries and total count.
|
||||
"""
|
||||
sessions, total_count = await get_user_sessions(user_id, limit, offset)
|
||||
|
||||
return ListSessionsResponse(
|
||||
sessions=[
|
||||
SessionSummaryResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
title=session.title,
|
||||
)
|
||||
for session in sessions
|
||||
],
|
||||
total=total_count,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions",
|
||||
)
|
||||
async def create_session(
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
user_id: Annotated[str, Depends(auth.get_user_id)],
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new chat session.
|
||||
|
||||
Initiates a new chat session for either an authenticated or anonymous user.
|
||||
Initiates a new chat session for the authenticated user.
|
||||
|
||||
Args:
|
||||
user_id: The optional authenticated user ID parsed from the JWT. If missing, creates an anonymous session.
|
||||
user_id: The authenticated user ID parsed from the JWT (required).
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: Details of the created session.
|
||||
@@ -67,15 +143,15 @@ async def create_session(
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating session with user_id: "
|
||||
f"...{user_id[-8:] if user_id and len(user_id) > 8 else '<redacted>'}"
|
||||
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
|
||||
)
|
||||
|
||||
session = await chat_service.create_chat_session(user_id)
|
||||
session = await create_chat_session(user_id)
|
||||
|
||||
return CreateSessionResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -99,29 +175,88 @@ async def get_session(
|
||||
SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.
|
||||
|
||||
"""
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
logger.info(
|
||||
f"Returning session {session_id}: "
|
||||
f"message_count={len(messages)}, "
|
||||
f"roles={[m.get('role') for m in messages]}"
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
messages=[message.model_dump() for message in session.messages],
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_chat_post(
|
||||
session_id: str,
|
||||
request: StreamChatRequest,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session (POST with context support).
|
||||
|
||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||
- Text fragments as they are generated
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_chat(
|
||||
async def stream_chat_get(
|
||||
session_id: str,
|
||||
message: Annotated[str, Query(min_length=1, max_length=10000)],
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
is_user_message: bool = Query(default=True),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session.
|
||||
Stream chat responses for a session (GET - legacy endpoint).
|
||||
|
||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||
- Text fragments as they are generated
|
||||
@@ -137,14 +272,7 @@ async def stream_chat(
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
# Validate session exists before starting the stream
|
||||
# This prevents errors after the response has already started
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found. ")
|
||||
if session.user_id is None and user_id is not None:
|
||||
session = await chat_service.assign_user_to_session(session_id, user_id)
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
@@ -155,6 +283,8 @@ async def stream_chat(
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
@@ -163,6 +293,7 @@ async def stream_chat(
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
||||
},
|
||||
)
|
||||
|
||||
@@ -201,16 +332,28 @@ async def health_check() -> dict:
|
||||
"""
|
||||
Health check endpoint for the chat service.
|
||||
|
||||
Performs a full cycle test of session creation, assignment, and retrieval. Should always return healthy
|
||||
Performs a full cycle test of session creation and retrieval. Should always return healthy
|
||||
if the service and data layer are operational.
|
||||
|
||||
Returns:
|
||||
dict: A status dictionary indicating health, service name, and API version.
|
||||
|
||||
"""
|
||||
session = await chat_service.create_chat_session(None)
|
||||
await chat_service.assign_user_to_session(session.session_id, "test_user")
|
||||
await chat_service.get_session(session.session_id, "test_user")
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Ensure health check user exists (required for FK constraint)
|
||||
health_check_user_id = "health-check-user"
|
||||
await get_or_create_user(
|
||||
{
|
||||
"sub": health_check_user_id,
|
||||
"email": "health-check@system.local",
|
||||
"user_metadata": {"name": "Health Check User"},
|
||||
}
|
||||
)
|
||||
|
||||
# Create and retrieve session to verify full data layer
|
||||
session = await create_chat_session(health_check_user_id)
|
||||
await get_chat_session(session.session_id, health_check_user_id)
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,18 +4,19 @@ from os import getenv
|
||||
import pytest
|
||||
|
||||
from . import service as chat_service
|
||||
from .model import create_chat_session, get_chat_session, upsert_chat_session
|
||||
from .response_model import (
|
||||
StreamEnd,
|
||||
StreamError,
|
||||
StreamTextChunk,
|
||||
StreamToolExecutionResult,
|
||||
StreamFinish,
|
||||
StreamTextDelta,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion():
|
||||
async def test_stream_chat_completion(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
@@ -23,7 +24,7 @@ async def test_stream_chat_completion():
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await chat_service.create_chat_session()
|
||||
session = await create_chat_session(test_user_id)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
@@ -34,9 +35,9 @@ async def test_stream_chat_completion():
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
if isinstance(chunk, StreamTextChunk):
|
||||
assistant_message += chunk.content
|
||||
if isinstance(chunk, StreamEnd):
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
assistant_message += chunk.delta
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
@@ -45,7 +46,7 @@ async def test_stream_chat_completion():
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion_with_tool_calls():
|
||||
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
@@ -53,8 +54,8 @@ async def test_stream_chat_completion_with_tool_calls():
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await chat_service.create_chat_session()
|
||||
session = await chat_service.upsert_chat_session(session)
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
@@ -68,14 +69,14 @@ async def test_stream_chat_completion_with_tool_calls():
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
|
||||
if isinstance(chunk, StreamEnd):
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
if isinstance(chunk, StreamToolExecutionResult):
|
||||
if isinstance(chunk, StreamToolOutputAvailable):
|
||||
had_tool_calls = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert had_tool_calls, "Tool calls did not occur"
|
||||
session = await chat_service.get_session(session.session_id)
|
||||
session = await get_chat_session(session.session_id)
|
||||
assert session, "Session not found"
|
||||
assert session.usage, "Usage is empty"
|
||||
|
||||
@@ -4,21 +4,32 @@ from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .add_understanding import AddUnderstandingTool
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .find_agent import FindAgentTool
|
||||
from .find_library_agent import FindLibraryAgentTool
|
||||
from .run_agent import RunAgentTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.chat.response_model import StreamToolExecutionResult
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
|
||||
# Initialize tool instances
|
||||
find_agent_tool = FindAgentTool()
|
||||
run_agent_tool = RunAgentTool()
|
||||
# Single source of truth for all tools
|
||||
TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"add_understanding": AddUnderstandingTool(),
|
||||
"find_agent": FindAgentTool(),
|
||||
"find_library_agent": FindLibraryAgentTool(),
|
||||
"run_agent": RunAgentTool(),
|
||||
"agent_output": AgentOutputTool(),
|
||||
}
|
||||
|
||||
# Export tools as OpenAI format
|
||||
# Export individual tool instances for backwards compatibility
|
||||
find_agent_tool = TOOL_REGISTRY["find_agent"]
|
||||
run_agent_tool = TOOL_REGISTRY["run_agent"]
|
||||
|
||||
# Generated from registry for OpenAI API
|
||||
tools: list[ChatCompletionToolParam] = [
|
||||
find_agent_tool.as_openai_tool(),
|
||||
run_agent_tool.as_openai_tool(),
|
||||
tool.as_openai_tool() for tool in TOOL_REGISTRY.values()
|
||||
]
|
||||
|
||||
|
||||
@@ -28,14 +39,9 @@ async def execute_tool(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
tool_call_id: str,
|
||||
) -> "StreamToolExecutionResult":
|
||||
|
||||
tool_map: dict[str, BaseTool] = {
|
||||
"find_agent": find_agent_tool,
|
||||
"run_agent": run_agent_tool,
|
||||
}
|
||||
if tool_name not in tool_map:
|
||||
) -> "StreamToolOutputAvailable":
|
||||
"""Execute a tool by name."""
|
||||
tool = TOOL_REGISTRY.get(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
return await tool_map[tool_name].execute(
|
||||
user_id, session, tool_call_id, **parameters
|
||||
)
|
||||
return await tool.execute(user_id, session, tool_call_id, **parameters)
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import UTC, datetime
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
from prisma.types import ProfileCreateInput
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
@@ -17,7 +18,7 @@ from backend.data.user import get_or_create_user
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
|
||||
def make_session(user_id: str | None = None):
|
||||
def make_session(user_id: str):
|
||||
return ChatSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
@@ -49,13 +50,13 @@ async def setup_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Create a test graph with agent input -> agent output
|
||||
@@ -172,13 +173,13 @@ async def setup_llm_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for LLM tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile for LLM tests",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Create test OpenAI credentials for the user
|
||||
@@ -332,13 +333,13 @@ async def setup_firecrawl_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for Firecrawl tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile for Firecrawl tests",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# NOTE: We deliberately do NOT create Firecrawl credentials for this user
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
"""Tool for capturing user business understanding incrementally."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AddUnderstandingTool(BaseTool):
|
||||
"""Tool for capturing user's business understanding incrementally."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "add_understanding"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Capture and store information about the user's business context,
|
||||
workflows, pain points, and automation goals. Call this tool whenever the user
|
||||
shares information about their business. Each call incrementally adds to the
|
||||
existing understanding - you don't need to provide all fields at once.
|
||||
|
||||
Use this to build a comprehensive profile that helps recommend better agents
|
||||
and automations for the user's specific needs."""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
# Auto-generate from Pydantic model schema
|
||||
schema = BusinessUnderstandingInput.model_json_schema()
|
||||
properties = {}
|
||||
for field_name, field_schema in schema.get("properties", {}).items():
|
||||
prop: dict[str, Any] = {"description": field_schema.get("description", "")}
|
||||
# Handle anyOf for Optional types
|
||||
if "anyOf" in field_schema:
|
||||
for option in field_schema["anyOf"]:
|
||||
if option.get("type") != "null":
|
||||
prop["type"] = option.get("type", "string")
|
||||
if "items" in option:
|
||||
prop["items"] = option["items"]
|
||||
break
|
||||
else:
|
||||
prop["type"] = field_schema.get("type", "string")
|
||||
if "items" in field_schema:
|
||||
prop["items"] = field_schema["items"]
|
||||
properties[field_name] = prop
|
||||
return {"type": "object", "properties": properties, "required": []}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
"""Requires authentication to store user-specific data."""
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""
|
||||
Capture and store business understanding incrementally.
|
||||
|
||||
Each call merges new data with existing understanding:
|
||||
- String fields are overwritten if provided
|
||||
- List fields are appended (with deduplication)
|
||||
"""
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required to save business understanding.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if any data was provided
|
||||
if not any(v is not None for v in kwargs.values()):
|
||||
return ErrorResponse(
|
||||
message="Please provide at least one field to update.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Build input model from kwargs (only include fields defined in the model)
|
||||
valid_fields = set(BusinessUnderstandingInput.model_fields.keys())
|
||||
input_data = BusinessUnderstandingInput(
|
||||
**{k: v for k, v in kwargs.items() if k in valid_fields}
|
||||
)
|
||||
|
||||
# Track which fields were updated
|
||||
updated_fields = [
|
||||
k for k, v in kwargs.items() if k in valid_fields and v is not None
|
||||
]
|
||||
|
||||
# Upsert with merge
|
||||
understanding = await upsert_business_understanding(user_id, input_data)
|
||||
|
||||
# Build current understanding summary (filter out empty values)
|
||||
current_understanding = {
|
||||
k: v
|
||||
for k, v in understanding.model_dump(
|
||||
exclude={"id", "user_id", "created_at", "updated_at"}
|
||||
).items()
|
||||
if v is not None and v != [] and v != ""
|
||||
}
|
||||
|
||||
return UnderstandingUpdatedResponse(
|
||||
message=f"Updated understanding with: {', '.join(updated_fields)}. "
|
||||
"I now have a better picture of your business context.",
|
||||
session_id=session_id,
|
||||
updated_fields=updated_fields,
|
||||
current_understanding=current_understanding,
|
||||
)
|
||||
@@ -0,0 +1,446 @@
|
||||
"""Tool for retrieving agent execution outputs from user's library."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentOutputResponse,
|
||||
ErrorResponse,
|
||||
ExecutionOutputInfo,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .utils import fetch_graph_from_store_slug
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentOutputInput(BaseModel):
|
||||
"""Input parameters for the agent_output tool."""
|
||||
|
||||
agent_name: str = ""
|
||||
library_agent_id: str = ""
|
||||
store_slug: str = ""
|
||||
execution_id: str = ""
|
||||
run_time: str = "latest"
|
||||
|
||||
@field_validator(
|
||||
"agent_name",
|
||||
"library_agent_id",
|
||||
"store_slug",
|
||||
"execution_id",
|
||||
"run_time",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def strip_strings(cls, v: Any) -> Any:
|
||||
"""Strip whitespace from string fields."""
|
||||
return v.strip() if isinstance(v, str) else v
|
||||
|
||||
|
||||
def parse_time_expression(
|
||||
time_expr: str | None,
|
||||
) -> tuple[datetime | None, datetime | None]:
|
||||
"""
|
||||
Parse time expression into datetime range (start, end).
|
||||
|
||||
Supports: "latest", "yesterday", "today", "last week", "last 7 days",
|
||||
"last month", "last 30 days", ISO date "YYYY-MM-DD", ISO datetime.
|
||||
"""
|
||||
if not time_expr or time_expr.lower() == "latest":
|
||||
return None, None
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
expr = time_expr.lower().strip()
|
||||
|
||||
# Relative time expressions lookup
|
||||
relative_times: dict[str, tuple[datetime, datetime]] = {
|
||||
"yesterday": (today_start - timedelta(days=1), today_start),
|
||||
"today": (today_start, now),
|
||||
"last week": (now - timedelta(days=7), now),
|
||||
"last 7 days": (now - timedelta(days=7), now),
|
||||
"last month": (now - timedelta(days=30), now),
|
||||
"last 30 days": (now - timedelta(days=30), now),
|
||||
}
|
||||
if expr in relative_times:
|
||||
return relative_times[expr]
|
||||
|
||||
# Try ISO date format (YYYY-MM-DD)
|
||||
date_match = re.match(r"^(\d{4})-(\d{2})-(\d{2})$", expr)
|
||||
if date_match:
|
||||
try:
|
||||
year, month, day = map(int, date_match.groups())
|
||||
start = datetime(year, month, day, 0, 0, 0, tzinfo=timezone.utc)
|
||||
return start, start + timedelta(days=1)
|
||||
except ValueError:
|
||||
# Invalid date components (e.g., month=13, day=32)
|
||||
pass
|
||||
|
||||
# Try ISO datetime
|
||||
try:
|
||||
parsed = datetime.fromisoformat(expr.replace("Z", "+00:00"))
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed - timedelta(hours=1), parsed + timedelta(hours=1)
|
||||
except ValueError:
|
||||
return None, None
|
||||
|
||||
|
||||
class AgentOutputTool(BaseTool):
|
||||
"""Tool for retrieving execution outputs from user's library agents."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "agent_output"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Retrieve execution outputs from agents in the user's library.
|
||||
|
||||
Identify the agent using one of:
|
||||
- agent_name: Fuzzy search in user's library
|
||||
- library_agent_id: Exact library agent ID
|
||||
- store_slug: Marketplace format 'username/agent-name'
|
||||
|
||||
Select which run to retrieve using:
|
||||
- execution_id: Specific execution ID
|
||||
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
|
||||
"""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_name": {
|
||||
"type": "string",
|
||||
"description": "Agent name to search for in user's library (fuzzy match)",
|
||||
},
|
||||
"library_agent_id": {
|
||||
"type": "string",
|
||||
"description": "Exact library agent ID",
|
||||
},
|
||||
"store_slug": {
|
||||
"type": "string",
|
||||
"description": "Marketplace identifier: 'username/agent-slug'",
|
||||
},
|
||||
"execution_id": {
|
||||
"type": "string",
|
||||
"description": "Specific execution ID to retrieve",
|
||||
},
|
||||
"run_time": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _resolve_agent(
|
||||
self,
|
||||
user_id: str,
|
||||
agent_name: str | None,
|
||||
library_agent_id: str | None,
|
||||
store_slug: str | None,
|
||||
) -> tuple[LibraryAgent | None, str | None]:
|
||||
"""
|
||||
Resolve agent from provided identifiers.
|
||||
Returns (library_agent, error_message).
|
||||
"""
|
||||
# Priority 1: Exact library agent ID
|
||||
if library_agent_id:
|
||||
try:
|
||||
agent = await library_db.get_library_agent(library_agent_id, user_id)
|
||||
return agent, None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get library agent by ID: {e}")
|
||||
return None, f"Library agent '{library_agent_id}' not found"
|
||||
|
||||
# Priority 2: Store slug (username/agent-name)
|
||||
if store_slug and "/" in store_slug:
|
||||
username, agent_slug = store_slug.split("/", 1)
|
||||
graph, _ = await fetch_graph_from_store_slug(username, agent_slug)
|
||||
if not graph:
|
||||
return None, f"Agent '{store_slug}' not found in marketplace"
|
||||
|
||||
# Find in user's library by graph_id
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
|
||||
if not agent:
|
||||
return (
|
||||
None,
|
||||
f"Agent '{store_slug}' is not in your library. "
|
||||
"Add it first to see outputs.",
|
||||
)
|
||||
return agent, None
|
||||
|
||||
# Priority 3: Fuzzy name search in library
|
||||
if agent_name:
|
||||
try:
|
||||
response = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=agent_name,
|
||||
page_size=5,
|
||||
)
|
||||
if not response.agents:
|
||||
return (
|
||||
None,
|
||||
f"No agents matching '{agent_name}' found in your library",
|
||||
)
|
||||
|
||||
# Return best match (first result from search)
|
||||
return response.agents[0], None
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching library agents: {e}")
|
||||
return None, f"Error searching for agent: {e}"
|
||||
|
||||
return (
|
||||
None,
|
||||
"Please specify an agent name, library_agent_id, or store_slug",
|
||||
)
|
||||
|
||||
async def _get_execution(
|
||||
self,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
execution_id: str | None,
|
||||
time_start: datetime | None,
|
||||
time_end: datetime | None,
|
||||
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
|
||||
"""
|
||||
Fetch execution(s) based on filters.
|
||||
Returns (single_execution, available_executions_meta, error_message).
|
||||
"""
|
||||
# If specific execution_id provided, fetch it directly
|
||||
if execution_id:
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
if not execution:
|
||||
return None, [], f"Execution '{execution_id}' not found"
|
||||
return execution, [], None
|
||||
|
||||
# Get completed executions with time filters
|
||||
executions = await execution_db.get_graph_executions(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
statuses=[ExecutionStatus.COMPLETED],
|
||||
created_time_gte=time_start,
|
||||
created_time_lte=time_end,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
if not executions:
|
||||
return None, [], None # No error, just no executions
|
||||
|
||||
# If only one execution, fetch full details
|
||||
if len(executions) == 1:
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
return full_execution, [], None
|
||||
|
||||
# Multiple executions - return latest with full details, plus list of available
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
return full_execution, executions, None
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
agent: LibraryAgent,
|
||||
execution: GraphExecution | None,
|
||||
available_executions: list[GraphExecutionMeta],
|
||||
session_id: str | None,
|
||||
) -> AgentOutputResponse:
|
||||
"""Build the response based on execution data."""
|
||||
library_agent_link = f"/library/agents/{agent.id}"
|
||||
|
||||
if not execution:
|
||||
return AgentOutputResponse(
|
||||
message=f"No completed executions found for agent '{agent.name}'",
|
||||
session_id=session_id,
|
||||
agent_name=agent.name,
|
||||
agent_id=agent.graph_id,
|
||||
library_agent_id=agent.id,
|
||||
library_agent_link=library_agent_link,
|
||||
total_executions=0,
|
||||
)
|
||||
|
||||
execution_info = ExecutionOutputInfo(
|
||||
execution_id=execution.id,
|
||||
status=execution.status.value,
|
||||
started_at=execution.started_at,
|
||||
ended_at=execution.ended_at,
|
||||
outputs=dict(execution.outputs),
|
||||
inputs_summary=execution.inputs if execution.inputs else None,
|
||||
)
|
||||
|
||||
available_list = None
|
||||
if len(available_executions) > 1:
|
||||
available_list = [
|
||||
{
|
||||
"id": e.id,
|
||||
"status": e.status.value,
|
||||
"started_at": e.started_at.isoformat() if e.started_at else None,
|
||||
}
|
||||
for e in available_executions[:5]
|
||||
]
|
||||
|
||||
message = f"Found execution outputs for agent '{agent.name}'"
|
||||
if len(available_executions) > 1:
|
||||
message += (
|
||||
f". Showing latest of {len(available_executions)} matching executions."
|
||||
)
|
||||
|
||||
return AgentOutputResponse(
|
||||
message=message,
|
||||
session_id=session_id,
|
||||
agent_name=agent.name,
|
||||
agent_id=agent.graph_id,
|
||||
library_agent_id=agent.id,
|
||||
library_agent_link=library_agent_link,
|
||||
execution=execution_info,
|
||||
available_executions=available_list,
|
||||
total_executions=len(available_executions) if available_executions else 1,
|
||||
)
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the agent_output tool."""
|
||||
session_id = session.session_id
|
||||
|
||||
# Parse and validate input
|
||||
try:
|
||||
input_data = AgentOutputInput(**kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid input: {e}")
|
||||
return ErrorResponse(
|
||||
message="Invalid input parameters",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Ensure user_id is present (should be guaranteed by requires_auth)
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if at least one identifier is provided
|
||||
if not any(
|
||||
[
|
||||
input_data.agent_name,
|
||||
input_data.library_agent_id,
|
||||
input_data.store_slug,
|
||||
input_data.execution_id,
|
||||
]
|
||||
):
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Please specify at least one of: agent_name, "
|
||||
"library_agent_id, store_slug, or execution_id"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# If only execution_id provided, we need to find the agent differently
|
||||
if (
|
||||
input_data.execution_id
|
||||
and not input_data.agent_name
|
||||
and not input_data.library_agent_id
|
||||
and not input_data.store_slug
|
||||
):
|
||||
# Fetch execution directly to get graph_id
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=input_data.execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
if not execution:
|
||||
return ErrorResponse(
|
||||
message=f"Execution '{input_data.execution_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Find library agent by graph_id
|
||||
agent = await library_db.get_library_agent_by_graph_id(
|
||||
user_id, execution.graph_id
|
||||
)
|
||||
if not agent:
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"Execution found but agent not in your library. "
|
||||
f"Graph ID: {execution.graph_id}"
|
||||
),
|
||||
session_id=session_id,
|
||||
suggestions=["Add the agent to your library to see more details"],
|
||||
)
|
||||
|
||||
return self._build_response(agent, execution, [], session_id)
|
||||
|
||||
# Resolve agent from identifiers
|
||||
agent, error = await self._resolve_agent(
|
||||
user_id=user_id,
|
||||
agent_name=input_data.agent_name or None,
|
||||
library_agent_id=input_data.library_agent_id or None,
|
||||
store_slug=input_data.store_slug or None,
|
||||
)
|
||||
|
||||
if error or not agent:
|
||||
return NoResultsResponse(
|
||||
message=error or "Agent not found",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Check the agent name or ID",
|
||||
"Make sure the agent is in your library",
|
||||
],
|
||||
)
|
||||
|
||||
# Parse time expression
|
||||
time_start, time_end = parse_time_expression(input_data.run_time)
|
||||
|
||||
# Fetch execution(s)
|
||||
execution, available_executions, exec_error = await self._get_execution(
|
||||
user_id=user_id,
|
||||
graph_id=agent.graph_id,
|
||||
execution_id=input_data.execution_id or None,
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
)
|
||||
|
||||
if exec_error:
|
||||
return ErrorResponse(
|
||||
message=exec_error,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return self._build_response(agent, execution, available_executions, session_id)
|
||||
@@ -0,0 +1,151 @@
|
||||
"""Shared agent search functionality for find_agent and find_library_agent tools."""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .models import (
|
||||
AgentInfo,
|
||||
AgentsFoundResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SearchSource = Literal["marketplace", "library"]
|
||||
|
||||
|
||||
async def search_agents(
|
||||
query: str,
|
||||
source: SearchSource,
|
||||
session_id: str | None,
|
||||
user_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""
|
||||
Search for agents in marketplace or user library.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
source: "marketplace" or "library"
|
||||
session_id: Chat session ID
|
||||
user_id: User ID (required for library search)
|
||||
|
||||
Returns:
|
||||
AgentsFoundResponse, NoResultsResponse, or ErrorResponse
|
||||
"""
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query", session_id=session_id
|
||||
)
|
||||
|
||||
if source == "library" and not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required to search library",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agents: list[AgentInfo] = []
|
||||
try:
|
||||
if source == "marketplace":
|
||||
logger.info(f"Searching marketplace for: {query}")
|
||||
results = await store_db.get_store_agents(search_query=query, page_size=5)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=f"{agent.creator}/{agent.slug}",
|
||||
name=agent.agent_name,
|
||||
description=agent.description or "",
|
||||
source="marketplace",
|
||||
in_library=False,
|
||||
creator=agent.creator,
|
||||
category="general",
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
is_featured=False,
|
||||
)
|
||||
)
|
||||
else: # library
|
||||
logger.info(f"Searching user library for: {query}")
|
||||
results = await library_db.list_library_agents(
|
||||
user_id=user_id, # type: ignore[arg-type]
|
||||
search_term=query,
|
||||
page_size=10,
|
||||
)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
)
|
||||
)
|
||||
logger.info(f"Found {len(agents)} agents in {source}")
|
||||
except NotFoundError:
|
||||
pass
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Error searching {source}: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to search {source}. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agents:
|
||||
suggestions = (
|
||||
[
|
||||
"Try more general terms",
|
||||
"Browse categories in the marketplace",
|
||||
"Check spelling",
|
||||
]
|
||||
if source == "marketplace"
|
||||
else [
|
||||
"Try different keywords",
|
||||
"Use find_agent to search the marketplace",
|
||||
"Check your library at /library",
|
||||
]
|
||||
)
|
||||
no_results_msg = (
|
||||
f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
|
||||
if source == "marketplace"
|
||||
else f"No agents matching '{query}' found in your library."
|
||||
)
|
||||
return NoResultsResponse(
|
||||
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
||||
)
|
||||
|
||||
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
|
||||
title += (
|
||||
f"for '{query}'"
|
||||
if source == "marketplace"
|
||||
else f"in your library for '{query}'"
|
||||
)
|
||||
|
||||
message = (
|
||||
"Now you have found some options for the user to choose from. "
|
||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||
"Please ask the user if they would like to use any of these agents."
|
||||
if source == "marketplace"
|
||||
else "Found agents in the user's library. You can provide a link to view an agent at: "
|
||||
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
|
||||
)
|
||||
|
||||
return AgentsFoundResponse(
|
||||
message=message,
|
||||
title=title,
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -6,7 +6,7 @@ from typing import Any
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.response_model import StreamToolExecutionResult
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
|
||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||
|
||||
@@ -53,7 +53,7 @@ class BaseTool:
|
||||
session: ChatSession,
|
||||
tool_call_id: str,
|
||||
**kwargs,
|
||||
) -> StreamToolExecutionResult:
|
||||
) -> StreamToolOutputAvailable:
|
||||
"""Execute the tool with authentication check.
|
||||
|
||||
Args:
|
||||
@@ -69,10 +69,10 @@ class BaseTool:
|
||||
logger.error(
|
||||
f"Attempted tool call for {self.name} but user not authenticated"
|
||||
)
|
||||
return StreamToolExecutionResult(
|
||||
tool_id=tool_call_id,
|
||||
tool_name=self.name,
|
||||
result=NeedLoginResponse(
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=NeedLoginResponse(
|
||||
message=f"Please sign in to use {self.name}",
|
||||
session_id=session.session_id,
|
||||
).model_dump_json(),
|
||||
@@ -81,17 +81,17 @@ class BaseTool:
|
||||
|
||||
try:
|
||||
result = await self._execute(user_id, session, **kwargs)
|
||||
return StreamToolExecutionResult(
|
||||
tool_id=tool_call_id,
|
||||
tool_name=self.name,
|
||||
result=result.model_dump_json(),
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=result.model_dump_json(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in {self.name}: {e}", exc_info=True)
|
||||
return StreamToolExecutionResult(
|
||||
tool_id=tool_call_id,
|
||||
tool_name=self.name,
|
||||
result=ErrorResponse(
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=ErrorResponse(
|
||||
message=f"An error occurred while executing {self.name}",
|
||||
error=str(e),
|
||||
session_id=session.session_id,
|
||||
|
||||
@@ -1,26 +1,16 @@
|
||||
"""Tool for discovering agents from marketplace and user library."""
|
||||
"""Tool for discovering agents from marketplace."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .agent_search import search_agents
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentCarouselResponse,
|
||||
AgentInfo,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from .models import ToolResponseBase
|
||||
|
||||
|
||||
class FindAgentTool(BaseTool):
|
||||
"""Tool for discovering agents based on user needs."""
|
||||
"""Tool for discovering agents from the marketplace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -46,84 +36,11 @@ class FindAgentTool(BaseTool):
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Search for agents in the marketplace.
|
||||
|
||||
Args:
|
||||
user_id: User ID (may be anonymous)
|
||||
session_id: Chat session ID
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
AgentCarouselResponse: List of agents found in the marketplace
|
||||
NoResultsResponse: No agents found in the marketplace
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
)
|
||||
agents = []
|
||||
try:
|
||||
logger.info(f"Searching marketplace for: {query}")
|
||||
store_results = await store_db.get_store_agents(
|
||||
search_query=query,
|
||||
page_size=5,
|
||||
)
|
||||
|
||||
logger.info(f"Find agents tool found {len(store_results.agents)} agents")
|
||||
for agent in store_results.agents:
|
||||
agent_id = f"{agent.creator}/{agent.slug}"
|
||||
logger.info(f"Building agent ID = {agent_id}")
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent_id,
|
||||
name=agent.agent_name,
|
||||
description=agent.description or "",
|
||||
source="marketplace",
|
||||
in_library=False,
|
||||
creator=agent.creator,
|
||||
category="general",
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
is_featured=False,
|
||||
),
|
||||
)
|
||||
except NotFoundError:
|
||||
pass
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Error searching agents: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search for agents. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
if not agents:
|
||||
return NoResultsResponse(
|
||||
message=f"No agents found matching '{query}'. Try different keywords or browse the marketplace. If you have 3 consecutive find_agent tool calls results and found no agents. Please stop trying and ask the user if there is anything else you can help with.",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Browse categories in the marketplace",
|
||||
"Check spelling",
|
||||
],
|
||||
)
|
||||
|
||||
# Return formatted carousel
|
||||
title = (
|
||||
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'"
|
||||
)
|
||||
return AgentCarouselResponse(
|
||||
message="Now you have found some options for the user to choose from. You can add a link to a recommended agent at: /marketplace/agent/agent_id Please ask the user if they would like to use any of these agents. If they do, please call the get_agent_details tool for this agent.",
|
||||
title=title,
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
session_id=session_id,
|
||||
return await search_agents(
|
||||
query=kwargs.get("query", "").strip(),
|
||||
source="marketplace",
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Tool for searching agents in the user's library."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_search import search_agents
|
||||
from .base import BaseTool
|
||||
from .models import ToolResponseBase
|
||||
|
||||
|
||||
class FindLibraryAgentTool(BaseTool):
|
||||
"""Tool for searching agents in the user's library."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "find_library_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search for agents in the user's library. Use this to find agents "
|
||||
"the user has already added to their library, including agents they "
|
||||
"created or added from the marketplace."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query to find agents by name or description.",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
return await search_agents(
|
||||
query=kwargs.get("query", "").strip(),
|
||||
source="library",
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Pydantic models for tool responses."""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
@@ -11,14 +12,15 @@ from backend.data.model import CredentialsMetaInput
|
||||
class ResponseType(str, Enum):
|
||||
"""Types of tool responses."""
|
||||
|
||||
AGENT_CAROUSEL = "agent_carousel"
|
||||
AGENTS_FOUND = "agents_found"
|
||||
AGENT_DETAILS = "agent_details"
|
||||
SETUP_REQUIREMENTS = "setup_requirements"
|
||||
EXECUTION_STARTED = "execution_started"
|
||||
NEED_LOGIN = "need_login"
|
||||
ERROR = "error"
|
||||
NO_RESULTS = "no_results"
|
||||
SUCCESS = "success"
|
||||
AGENT_OUTPUT = "agent_output"
|
||||
UNDERSTANDING_UPDATED = "understanding_updated"
|
||||
|
||||
|
||||
# Base response model
|
||||
@@ -51,14 +53,14 @@ class AgentInfo(BaseModel):
|
||||
graph_id: str | None = None
|
||||
|
||||
|
||||
class AgentCarouselResponse(ToolResponseBase):
|
||||
class AgentsFoundResponse(ToolResponseBase):
|
||||
"""Response for find_agent tool."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_CAROUSEL
|
||||
type: ResponseType = ResponseType.AGENTS_FOUND
|
||||
title: str = "Available Agents"
|
||||
agents: list[AgentInfo]
|
||||
count: int
|
||||
name: str = "agent_carousel"
|
||||
name: str = "agents_found"
|
||||
|
||||
|
||||
class NoResultsResponse(ToolResponseBase):
|
||||
@@ -173,3 +175,37 @@ class ErrorResponse(ToolResponseBase):
|
||||
type: ResponseType = ResponseType.ERROR
|
||||
error: str | None = None
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# Agent output models
|
||||
class ExecutionOutputInfo(BaseModel):
|
||||
"""Summary of a single execution's outputs."""
|
||||
|
||||
execution_id: str
|
||||
status: str
|
||||
started_at: datetime | None = None
|
||||
ended_at: datetime | None = None
|
||||
outputs: dict[str, list[Any]]
|
||||
inputs_summary: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AgentOutputResponse(ToolResponseBase):
|
||||
"""Response for agent_output tool."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_OUTPUT
|
||||
agent_name: str
|
||||
agent_id: str
|
||||
library_agent_id: str | None = None
|
||||
library_agent_link: str | None = None
|
||||
execution: ExecutionOutputInfo | None = None
|
||||
available_executions: list[dict[str, Any]] | None = None
|
||||
total_executions: int = 0
|
||||
|
||||
|
||||
# Business understanding models
|
||||
class UnderstandingUpdatedResponse(ToolResponseBase):
|
||||
"""Response for add_understanding tool."""
|
||||
|
||||
type: ResponseType = ResponseType.UNDERSTANDING_UPDATED
|
||||
updated_fields: list[str] = Field(default_factory=list)
|
||||
current_understanding: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.api.features.chat.config import ChatConfig
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.user import get_user_by_id
|
||||
@@ -57,6 +58,7 @@ class RunAgentInput(BaseModel):
|
||||
"""Input parameters for the run_agent tool."""
|
||||
|
||||
username_agent_slug: str = ""
|
||||
library_agent_id: str = ""
|
||||
inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
use_defaults: bool = False
|
||||
schedule_name: str = ""
|
||||
@@ -64,7 +66,12 @@ class RunAgentInput(BaseModel):
|
||||
timezone: str = "UTC"
|
||||
|
||||
@field_validator(
|
||||
"username_agent_slug", "schedule_name", "cron", "timezone", mode="before"
|
||||
"username_agent_slug",
|
||||
"library_agent_id",
|
||||
"schedule_name",
|
||||
"cron",
|
||||
"timezone",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def strip_strings(cls, v: Any) -> Any:
|
||||
@@ -90,7 +97,7 @@ class RunAgentTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Run or schedule an agent from the marketplace.
|
||||
return """Run or schedule an agent from the marketplace or user's library.
|
||||
|
||||
The tool automatically handles the setup flow:
|
||||
- Returns missing inputs if required fields are not provided
|
||||
@@ -98,6 +105,10 @@ class RunAgentTool(BaseTool):
|
||||
- Executes immediately if all requirements are met
|
||||
- Schedules execution if cron expression is provided
|
||||
|
||||
Identify the agent using either:
|
||||
- username_agent_slug: Marketplace format 'username/agent-name'
|
||||
- library_agent_id: ID of an agent in the user's library
|
||||
|
||||
For scheduled execution, provide: schedule_name, cron, and optionally timezone."""
|
||||
|
||||
@property
|
||||
@@ -109,6 +120,10 @@ class RunAgentTool(BaseTool):
|
||||
"type": "string",
|
||||
"description": "Agent identifier in format 'username/agent-name'",
|
||||
},
|
||||
"library_agent_id": {
|
||||
"type": "string",
|
||||
"description": "Library agent ID from user's library",
|
||||
},
|
||||
"inputs": {
|
||||
"type": "object",
|
||||
"description": "Input values for the agent",
|
||||
@@ -131,7 +146,7 @@ class RunAgentTool(BaseTool):
|
||||
"description": "IANA timezone for schedule (default: UTC)",
|
||||
},
|
||||
},
|
||||
"required": ["username_agent_slug"],
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
@@ -149,10 +164,16 @@ class RunAgentTool(BaseTool):
|
||||
params = RunAgentInput(**kwargs)
|
||||
session_id = session.session_id
|
||||
|
||||
# Validate agent slug format
|
||||
if not params.username_agent_slug or "/" not in params.username_agent_slug:
|
||||
# Validate at least one identifier is provided
|
||||
has_slug = params.username_agent_slug and "/" in params.username_agent_slug
|
||||
has_library_id = bool(params.library_agent_id)
|
||||
|
||||
if not has_slug and not has_library_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide an agent slug in format 'username/agent-name'",
|
||||
message=(
|
||||
"Please provide either a username_agent_slug "
|
||||
"(format 'username/agent-name') or a library_agent_id"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -167,13 +188,41 @@ class RunAgentTool(BaseTool):
|
||||
is_schedule = bool(params.schedule_name or params.cron)
|
||||
|
||||
try:
|
||||
# Step 1: Fetch agent details (always happens first)
|
||||
username, agent_name = params.username_agent_slug.split("/", 1)
|
||||
graph, store_agent = await fetch_graph_from_store_slug(username, agent_name)
|
||||
# Step 1: Fetch agent details
|
||||
graph: GraphModel | None = None
|
||||
library_agent = None
|
||||
|
||||
# Priority: library_agent_id if provided
|
||||
if has_library_id:
|
||||
library_agent = await library_db.get_library_agent(
|
||||
params.library_agent_id, user_id
|
||||
)
|
||||
if not library_agent:
|
||||
return ErrorResponse(
|
||||
message=f"Library agent '{params.library_agent_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
# Get the graph from the library agent
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
graph = await get_graph(
|
||||
library_agent.graph_id,
|
||||
library_agent.graph_version,
|
||||
user_id=user_id,
|
||||
)
|
||||
else:
|
||||
# Fetch from marketplace slug
|
||||
username, agent_name = params.username_agent_slug.split("/", 1)
|
||||
graph, _ = await fetch_graph_from_store_slug(username, agent_name)
|
||||
|
||||
if not graph:
|
||||
identifier = (
|
||||
params.library_agent_id
|
||||
if has_library_id
|
||||
else params.username_agent_slug
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=f"Agent '{params.username_agent_slug}' not found in marketplace",
|
||||
message=f"Agent '{identifier}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
@@ -17,6 +18,17 @@ setup_test_data = setup_test_data
|
||||
setup_firecrawl_test_data = setup_firecrawl_test_data
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def mock_embedding_functions():
|
||||
"""Mock embedding functions for all tests to avoid database/API dependencies."""
|
||||
with patch(
|
||||
"backend.api.features.store.db.ensure_embedding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_run_agent(setup_test_data):
|
||||
"""Test that the run_agent tool successfully executes an approved agent"""
|
||||
@@ -46,11 +58,11 @@ async def test_run_agent(setup_test_data):
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert hasattr(response, "output")
|
||||
# Parse the result JSON to verify the execution started
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
assert "execution_id" in result_data
|
||||
assert "graph_id" in result_data
|
||||
assert result_data["graph_id"] == graph.id
|
||||
@@ -86,11 +98,11 @@ async def test_run_agent_missing_inputs(setup_test_data):
|
||||
|
||||
# Verify that we get an error response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert hasattr(response, "output")
|
||||
# The tool should return an ErrorResponse when setup info indicates not ready
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
assert "message" in result_data
|
||||
|
||||
|
||||
@@ -118,10 +130,10 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
|
||||
|
||||
# Verify that we get an error response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert hasattr(response, "output")
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
assert "message" in result_data
|
||||
# Should get an error about failed setup or not found
|
||||
assert any(
|
||||
@@ -158,12 +170,12 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert hasattr(response, "output")
|
||||
|
||||
# Parse the result JSON to verify the execution started
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should successfully start execution since credentials are available
|
||||
assert "execution_id" in result_data
|
||||
@@ -195,9 +207,9 @@ async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_da
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return agent_details type showing available inputs
|
||||
assert result_data.get("type") == "agent_details"
|
||||
@@ -230,9 +242,9 @@ async def test_run_agent_with_use_defaults(setup_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should execute successfully
|
||||
assert "execution_id" in result_data
|
||||
@@ -260,9 +272,9 @@ async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return setup_requirements type with missing credentials
|
||||
assert result_data.get("type") == "setup_requirements"
|
||||
@@ -292,9 +304,9 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return error
|
||||
assert result_data.get("type") == "error"
|
||||
@@ -305,9 +317,10 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
|
||||
async def test_run_agent_unauthenticated():
|
||||
"""Test that run_agent returns need_login for unauthenticated users."""
|
||||
tool = RunAgentTool()
|
||||
session = make_session(user_id=None)
|
||||
# Session has a user_id (session owner), but we test tool execution without user_id
|
||||
session = make_session(user_id="test-session-owner")
|
||||
|
||||
# Execute without user_id
|
||||
# Execute without user_id to test unauthenticated behavior
|
||||
response = await tool.execute(
|
||||
user_id=None,
|
||||
session_id=str(uuid.uuid4()),
|
||||
@@ -318,9 +331,9 @@ async def test_run_agent_unauthenticated():
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Base tool returns need_login type for unauthenticated users
|
||||
assert result_data.get("type") == "need_login"
|
||||
@@ -350,9 +363,9 @@ async def test_run_agent_schedule_without_cron(setup_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return error about missing cron
|
||||
assert result_data.get("type") == "error"
|
||||
@@ -382,9 +395,9 @@ async def test_run_agent_schedule_without_name(setup_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return error about missing schedule_name
|
||||
assert result_data.get("type") == "error"
|
||||
|
||||
@@ -35,11 +35,7 @@ from backend.data.model import (
|
||||
OAuth2Credentials,
|
||||
UserIntegrations,
|
||||
)
|
||||
from backend.data.onboarding import (
|
||||
OnboardingStep,
|
||||
complete_onboarding_step,
|
||||
increment_runs,
|
||||
)
|
||||
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
||||
from backend.data.user import get_user_integrations
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||
@@ -175,6 +171,7 @@ async def callback(
|
||||
f"Successfully processed OAuth callback for user {user_id} "
|
||||
f"and provider {provider.value}"
|
||||
)
|
||||
|
||||
return CredentialsMetaResponse(
|
||||
id=credentials.id,
|
||||
provider=credentials.provider,
|
||||
@@ -193,6 +190,7 @@ async def list_credentials(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
credentials = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
return [
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
@@ -215,6 +213,7 @@ async def list_credentials_by_provider(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
|
||||
|
||||
return [
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
@@ -378,7 +377,6 @@ async def webhook_ingress_generic(
|
||||
return
|
||||
|
||||
await complete_onboarding_step(user_id, OnboardingStep.TRIGGER_WEBHOOK)
|
||||
await increment_runs(user_id)
|
||||
|
||||
# Execute all triggers concurrently for better performance
|
||||
tasks = []
|
||||
@@ -831,6 +829,18 @@ async def list_providers() -> List[str]:
|
||||
return all_providers
|
||||
|
||||
|
||||
@router.get("/providers/system", response_model=List[str])
|
||||
async def list_system_providers() -> List[str]:
|
||||
"""
|
||||
Get a list of providers that have platform credits (system credentials) available.
|
||||
|
||||
These providers can be used without the user providing their own API keys.
|
||||
"""
|
||||
from backend.integrations.credentials_store import SYSTEM_PROVIDERS
|
||||
|
||||
return list(SYSTEM_PROVIDERS)
|
||||
|
||||
|
||||
@router.get("/providers/names", response_model=ProviderNamesResponse)
|
||||
async def get_provider_names() -> ProviderNamesResponse:
|
||||
"""
|
||||
|
||||
@@ -8,7 +8,6 @@ from backend.data.execution import GraphExecutionMeta
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.integrations import get_webhook
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.onboarding import increment_runs
|
||||
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
@@ -403,8 +402,6 @@ async def execute_preset(
|
||||
merged_node_input = preset.inputs | inputs
|
||||
merged_credential_inputs = preset.credentials | credential_inputs
|
||||
|
||||
await increment_runs(user_id)
|
||||
|
||||
return await add_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=preset.graph_id,
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import fastapi
|
||||
import prisma.enums
|
||||
@@ -10,7 +9,7 @@ import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
from backend.data.db import query_raw_with_schema, transaction
|
||||
from backend.data.db import transaction
|
||||
from backend.data.graph import (
|
||||
GraphMeta,
|
||||
GraphModel,
|
||||
@@ -30,6 +29,8 @@ from backend.util.settings import Settings
|
||||
|
||||
from . import exceptions as store_exceptions
|
||||
from . import model as store_model
|
||||
from .embeddings import ensure_embedding
|
||||
from .hybrid_search import hybrid_search
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
@@ -50,128 +51,77 @@ async def get_store_agents(
|
||||
page_size: int = 20,
|
||||
) -> store_model.StoreAgentsResponse:
|
||||
"""
|
||||
Get PUBLIC store agents from the StoreAgent view
|
||||
Get PUBLIC store agents from the StoreAgent view.
|
||||
|
||||
Search behavior:
|
||||
- With search_query: Uses hybrid search (semantic + lexical)
|
||||
- Fallback: If embeddings unavailable, gracefully degrades to lexical-only
|
||||
- Rationale: User-facing endpoint prioritizes availability over accuracy
|
||||
|
||||
Note: Admin operations (approval) use fail-fast to prevent inconsistent state.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
||||
)
|
||||
|
||||
search_used_hybrid = False
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
agents: list[dict[str, Any]] = []
|
||||
total = 0
|
||||
total_pages = 0
|
||||
|
||||
try:
|
||||
# If search_query is provided, use full-text search
|
||||
# If search_query is provided, use hybrid search (embeddings + tsvector)
|
||||
if search_query:
|
||||
offset = (page - 1) * page_size
|
||||
# Try hybrid search combining semantic and lexical signals
|
||||
# Falls back to lexical-only if OpenAI unavailable (user-facing, high SLA)
|
||||
try:
|
||||
agents, total = await hybrid_search(
|
||||
query=search_query,
|
||||
featured=featured,
|
||||
creators=creators,
|
||||
category=category,
|
||||
sorted_by="relevance", # Use hybrid scoring for relevance
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
search_used_hybrid = True
|
||||
except Exception as e:
|
||||
# Log error but fall back to lexical search for better UX
|
||||
logger.error(
|
||||
f"Hybrid search failed (likely OpenAI unavailable), "
|
||||
f"falling back to lexical search: {e}"
|
||||
)
|
||||
# search_used_hybrid remains False, will use fallback path below
|
||||
|
||||
# Whitelist allowed order_by columns
|
||||
ALLOWED_ORDER_BY = {
|
||||
"rating": "rating DESC, rank DESC",
|
||||
"runs": "runs DESC, rank DESC",
|
||||
"name": "agent_name ASC, rank ASC",
|
||||
"updated_at": "updated_at DESC, rank DESC",
|
||||
}
|
||||
# Convert hybrid search results (dict format) if hybrid succeeded
|
||||
if search_used_hybrid:
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
try:
|
||||
store_agent = store_model.StoreAgent(
|
||||
slug=agent["slug"],
|
||||
agent_name=agent["agent_name"],
|
||||
agent_image=(
|
||||
agent["agent_image"][0] if agent["agent_image"] else ""
|
||||
),
|
||||
creator=agent["creator_username"] or "Needs Profile",
|
||||
creator_avatar=agent["creator_avatar"] or "",
|
||||
sub_heading=agent["sub_heading"],
|
||||
description=agent["description"],
|
||||
runs=agent["runs"],
|
||||
rating=agent["rating"],
|
||||
)
|
||||
store_agents.append(store_agent)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error parsing Store agent from hybrid search results: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Validate and get order clause
|
||||
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
|
||||
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
|
||||
else:
|
||||
order_by_clause = "updated_at DESC, rank DESC"
|
||||
|
||||
# Build WHERE conditions and parameters list
|
||||
where_parts: list[str] = []
|
||||
params: list[typing.Any] = [search_query] # $1 - search term
|
||||
param_index = 2 # Start at $2 for next parameter
|
||||
|
||||
# Always filter for available agents
|
||||
where_parts.append("is_available = true")
|
||||
|
||||
if featured:
|
||||
where_parts.append("featured = true")
|
||||
|
||||
if creators and creators:
|
||||
# Use ANY with array parameter
|
||||
where_parts.append(f"creator_username = ANY(${param_index})")
|
||||
params.append(creators)
|
||||
param_index += 1
|
||||
|
||||
if category and category:
|
||||
where_parts.append(f"${param_index} = ANY(categories)")
|
||||
params.append(category)
|
||||
param_index += 1
|
||||
|
||||
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
|
||||
|
||||
# Add pagination params
|
||||
params.extend([page_size, offset])
|
||||
limit_param = f"${param_index}"
|
||||
offset_param = f"${param_index + 1}"
|
||||
|
||||
# Execute full-text search query with parameterized values
|
||||
sql_query = f"""
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
ts_rank_cd(search, query) AS rank
|
||||
FROM {{schema_prefix}}"StoreAgent",
|
||||
plainto_tsquery('english', $1) AS query
|
||||
WHERE {sql_where_clause}
|
||||
AND search @@ query
|
||||
ORDER BY {order_by_clause}
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
# Count query for pagination - only uses search term parameter
|
||||
count_query = f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {{schema_prefix}}"StoreAgent",
|
||||
plainto_tsquery('english', $1) AS query
|
||||
WHERE {sql_where_clause}
|
||||
AND search @@ query
|
||||
"""
|
||||
|
||||
# Execute both queries with parameters
|
||||
agents = await query_raw_with_schema(sql_query, *params)
|
||||
|
||||
# For count, use params without pagination (last 2 params)
|
||||
count_params = params[:-2]
|
||||
count_result = await query_raw_with_schema(count_query, *count_params)
|
||||
|
||||
total = count_result[0]["count"] if count_result else 0
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
# Convert raw results to StoreAgent models
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
try:
|
||||
store_agent = store_model.StoreAgent(
|
||||
slug=agent["slug"],
|
||||
agent_name=agent["agent_name"],
|
||||
agent_image=(
|
||||
agent["agent_image"][0] if agent["agent_image"] else ""
|
||||
),
|
||||
creator=agent["creator_username"] or "Needs Profile",
|
||||
creator_avatar=agent["creator_avatar"] or "",
|
||||
sub_heading=agent["sub_heading"],
|
||||
description=agent["description"],
|
||||
runs=agent["runs"],
|
||||
rating=agent["rating"],
|
||||
)
|
||||
store_agents.append(store_agent)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing Store agent from search results: {e}")
|
||||
continue
|
||||
|
||||
else:
|
||||
# Non-search query path (original logic)
|
||||
if not search_used_hybrid:
|
||||
# Fallback path - use basic search or no search
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
@@ -180,6 +130,14 @@ async def get_store_agents(
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
# Add basic text search if search_query provided but hybrid failed
|
||||
if search_query:
|
||||
where_clause["OR"] = [
|
||||
{"agent_name": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"sub_heading": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"description": {"contains": search_query, "mode": "insensitive"}},
|
||||
]
|
||||
|
||||
order_by = []
|
||||
if sorted_by == "rating":
|
||||
order_by.append({"rating": "desc"})
|
||||
@@ -188,7 +146,7 @@ async def get_store_agents(
|
||||
elif sorted_by == "name":
|
||||
order_by.append({"agent_name": "asc"})
|
||||
|
||||
agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
db_agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
@@ -199,7 +157,7 @@ async def get_store_agents(
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
for agent in db_agents:
|
||||
try:
|
||||
# Create the StoreAgent object safely
|
||||
store_agent = store_model.StoreAgent(
|
||||
@@ -1577,7 +1535,7 @@ async def review_store_submission(
|
||||
)
|
||||
|
||||
# Update the AgentGraph with store listing data
|
||||
await prisma.models.AgentGraph.prisma().update(
|
||||
await prisma.models.AgentGraph.prisma(tx).update(
|
||||
where={
|
||||
"graphVersionId": {
|
||||
"id": store_listing_version.agentGraphId,
|
||||
@@ -1592,6 +1550,23 @@ async def review_store_submission(
|
||||
},
|
||||
)
|
||||
|
||||
# Generate embedding for approved listing (blocking - admin operation)
|
||||
# Inside transaction: if embedding fails, entire transaction rolls back
|
||||
embedding_success = await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=store_listing_version.name,
|
||||
description=store_listing_version.description,
|
||||
sub_heading=store_listing_version.subHeading,
|
||||
categories=store_listing_version.categories or [],
|
||||
tx=tx,
|
||||
)
|
||||
if not embedding_success:
|
||||
raise ValueError(
|
||||
f"Failed to generate embedding for listing {store_listing_version_id}. "
|
||||
"This is likely due to OpenAI API being unavailable. "
|
||||
"Please try again later or contact support if the issue persists."
|
||||
)
|
||||
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": store_listing_version.StoreListing.id},
|
||||
data={
|
||||
|
||||
@@ -0,0 +1,568 @@
|
||||
"""
|
||||
Unified Content Embeddings Service
|
||||
|
||||
Handles generation and storage of OpenAI embeddings for all content types
|
||||
(store listings, blocks, documentation, library agents) to enable semantic/hybrid search.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import prisma
|
||||
from prisma.enums import ContentType
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from backend.data.db import execute_raw_with_schema, query_raw_with_schema
|
||||
from backend.util.clients import get_openai_client
|
||||
from backend.util.json import dumps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# OpenAI embedding model configuration
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
# OpenAI embedding token limit (8,191 with 1 token buffer for safety)
|
||||
EMBEDDING_MAX_TOKENS = 8191
|
||||
|
||||
|
||||
def build_searchable_text(
|
||||
name: str,
|
||||
description: str,
|
||||
sub_heading: str,
|
||||
categories: list[str],
|
||||
) -> str:
|
||||
"""
|
||||
Build searchable text from listing version fields.
|
||||
|
||||
Combines relevant fields into a single string for embedding.
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# Name is important - include it
|
||||
if name:
|
||||
parts.append(name)
|
||||
|
||||
# Sub-heading provides context
|
||||
if sub_heading:
|
||||
parts.append(sub_heading)
|
||||
|
||||
# Description is the main content
|
||||
if description:
|
||||
parts.append(description)
|
||||
|
||||
# Categories help with semantic matching
|
||||
if categories:
|
||||
parts.append(" ".join(categories))
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
async def generate_embedding(text: str) -> list[float] | None:
|
||||
"""
|
||||
Generate embedding for text using OpenAI API.
|
||||
|
||||
Returns None if embedding generation fails.
|
||||
Fail-fast: no retries to maintain consistency with approval flow.
|
||||
"""
|
||||
try:
|
||||
client = get_openai_client()
|
||||
if not client:
|
||||
logger.error("openai_internal_api_key not set, cannot generate embedding")
|
||||
return None
|
||||
|
||||
# Truncate text to token limit using tiktoken
|
||||
# Character-based truncation is insufficient because token ratios vary by content type
|
||||
enc = encoding_for_model(EMBEDDING_MODEL)
|
||||
tokens = enc.encode(text)
|
||||
if len(tokens) > EMBEDDING_MAX_TOKENS:
|
||||
tokens = tokens[:EMBEDDING_MAX_TOKENS]
|
||||
truncated_text = enc.decode(tokens)
|
||||
logger.info(
|
||||
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
|
||||
)
|
||||
else:
|
||||
truncated_text = text
|
||||
|
||||
start_time = time.time()
|
||||
response = await client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=truncated_text,
|
||||
)
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
|
||||
embedding = response.data[0].embedding
|
||||
logger.info(
|
||||
f"Generated embedding: {len(embedding)} dims, "
|
||||
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
|
||||
)
|
||||
return embedding
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate embedding: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def store_embedding(
|
||||
version_id: str,
|
||||
embedding: list[float],
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Store embedding in the database.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||
DEPRECATED: Use ensure_embedding() instead (includes searchable_text).
|
||||
"""
|
||||
return await store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id=version_id,
|
||||
embedding=embedding,
|
||||
searchable_text="", # Empty for backward compat; ensure_embedding() populates this
|
||||
metadata=None,
|
||||
user_id=None, # Store agents are public
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
|
||||
async def store_content_embedding(
|
||||
content_type: ContentType,
|
||||
content_id: str,
|
||||
embedding: list[float],
|
||||
searchable_text: str,
|
||||
metadata: dict | None = None,
|
||||
user_id: str | None = None,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Store embedding in the unified content embeddings table.
|
||||
|
||||
New function for unified content embedding storage.
|
||||
Uses raw SQL since Prisma doesn't natively support pgvector.
|
||||
"""
|
||||
try:
|
||||
client = tx if tx else prisma.get_client()
|
||||
|
||||
# Convert embedding to PostgreSQL vector format
|
||||
embedding_str = embedding_to_vector_string(embedding)
|
||||
metadata_json = dumps(metadata or {})
|
||||
|
||||
# Upsert the embedding
|
||||
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
|
||||
await execute_raw_with_schema(
|
||||
"""
|
||||
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
|
||||
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
|
||||
)
|
||||
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
|
||||
ON CONFLICT ("contentType", "contentId", "userId")
|
||||
DO UPDATE SET
|
||||
"embedding" = $4::vector,
|
||||
"searchableText" = $5,
|
||||
"metadata" = $6::jsonb,
|
||||
"updatedAt" = NOW()
|
||||
WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
|
||||
AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
|
||||
AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
embedding_str,
|
||||
searchable_text,
|
||||
metadata_json,
|
||||
client=client,
|
||||
set_public_search_path=True,
|
||||
)
|
||||
|
||||
logger.info(f"Stored embedding for {content_type}:{content_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_embedding(version_id: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Retrieve embedding record for a listing version.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||
Returns dict with storeListingVersionId, embedding, timestamps or None if not found.
|
||||
"""
|
||||
result = await get_content_embedding(
|
||||
ContentType.STORE_AGENT, version_id, user_id=None
|
||||
)
|
||||
if result:
|
||||
# Transform to old format for backward compatibility
|
||||
return {
|
||||
"storeListingVersionId": result["contentId"],
|
||||
"embedding": result["embedding"],
|
||||
"createdAt": result["createdAt"],
|
||||
"updatedAt": result["updatedAt"],
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def get_content_embedding(
|
||||
content_type: ContentType, content_id: str, user_id: str | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Retrieve embedding record for any content type.
|
||||
|
||||
New function for unified content embedding retrieval.
|
||||
Returns dict with contentType, contentId, embedding, timestamps or None if not found.
|
||||
"""
|
||||
try:
|
||||
result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT
|
||||
"contentType",
|
||||
"contentId",
|
||||
"userId",
|
||||
"embedding"::text as "embedding",
|
||||
"searchableText",
|
||||
"metadata",
|
||||
"createdAt",
|
||||
"updatedAt"
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
set_public_search_path=True,
|
||||
)
|
||||
|
||||
if result and len(result) > 0:
|
||||
return result[0]
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def ensure_embedding(
|
||||
version_id: str,
|
||||
name: str,
|
||||
description: str,
|
||||
sub_heading: str,
|
||||
categories: list[str],
|
||||
force: bool = False,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Ensure an embedding exists for the listing version.
|
||||
|
||||
Creates embedding if missing. Use force=True to regenerate.
|
||||
Backward-compatible wrapper for store listings.
|
||||
|
||||
Args:
|
||||
version_id: The StoreListingVersion ID
|
||||
name: Agent name
|
||||
description: Agent description
|
||||
sub_heading: Agent sub-heading
|
||||
categories: Agent categories
|
||||
force: Force regeneration even if embedding exists
|
||||
tx: Optional transaction client
|
||||
|
||||
Returns:
|
||||
True if embedding exists/was created, False on failure
|
||||
"""
|
||||
try:
|
||||
# Check if embedding already exists
|
||||
if not force:
|
||||
existing = await get_embedding(version_id)
|
||||
if existing and existing.get("embedding"):
|
||||
logger.debug(f"Embedding for version {version_id} already exists")
|
||||
return True
|
||||
|
||||
# Build searchable text for embedding
|
||||
searchable_text = build_searchable_text(
|
||||
name, description, sub_heading, categories
|
||||
)
|
||||
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
if embedding is None:
|
||||
logger.warning(f"Could not generate embedding for version {version_id}")
|
||||
return False
|
||||
|
||||
# Store the embedding with metadata using new function
|
||||
metadata = {
|
||||
"name": name,
|
||||
"subHeading": sub_heading,
|
||||
"categories": categories,
|
||||
}
|
||||
return await store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id=version_id,
|
||||
embedding=embedding,
|
||||
searchable_text=searchable_text,
|
||||
metadata=metadata,
|
||||
user_id=None, # Store agents are public
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def delete_embedding(version_id: str) -> bool:
|
||||
"""
|
||||
Delete embedding for a listing version.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||
Note: This is usually handled automatically by CASCADE delete,
|
||||
but provided for manual cleanup if needed.
|
||||
"""
|
||||
return await delete_content_embedding(ContentType.STORE_AGENT, version_id)
|
||||
|
||||
|
||||
async def delete_content_embedding(
|
||||
content_type: ContentType, content_id: str, user_id: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Delete embedding for any content type.
|
||||
|
||||
New function for unified content embedding deletion.
|
||||
Note: This is usually handled automatically by CASCADE delete,
|
||||
but provided for manual cleanup if needed.
|
||||
|
||||
Args:
|
||||
content_type: The type of content (STORE_AGENT, LIBRARY_AGENT, etc.)
|
||||
content_id: The unique identifier for the content
|
||||
user_id: Optional user ID. For public content (STORE_AGENT, BLOCK), pass None.
|
||||
For user-scoped content (LIBRARY_AGENT), pass the user's ID to avoid
|
||||
deleting embeddings belonging to other users.
|
||||
|
||||
Returns:
|
||||
True if deletion succeeded, False otherwise
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
await execute_raw_with_schema(
|
||||
"""
|
||||
DELETE FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = $1::{schema_prefix}"ContentType"
|
||||
AND "contentId" = $2
|
||||
AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
client=client,
|
||||
)
|
||||
|
||||
user_str = f" (user: {user_id})" if user_id else ""
|
||||
logger.info(f"Deleted embedding for {content_type}:{content_id}{user_str}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete embedding for {content_type}:{content_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_embedding_stats() -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics about embedding coverage.
|
||||
|
||||
Returns counts of:
|
||||
- Total approved listing versions
|
||||
- Versions with embeddings
|
||||
- Versions without embeddings
|
||||
"""
|
||||
try:
|
||||
# Count approved versions
|
||||
approved_result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {schema_prefix}"StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
AND "isDeleted" = false
|
||||
"""
|
||||
)
|
||||
total_approved = approved_result[0]["count"] if approved_result else 0
|
||||
|
||||
# Count versions with embeddings
|
||||
embedded_result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {schema_prefix}"StoreListingVersion" slv
|
||||
JOIN {schema_prefix}"UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
"""
|
||||
)
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total_approved": total_approved,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_approved - with_embeddings,
|
||||
"coverage_percent": (
|
||||
round(with_embeddings / total_approved * 100, 1)
|
||||
if total_approved > 0
|
||||
else 0
|
||||
),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embedding stats: {e}")
|
||||
return {
|
||||
"total_approved": 0,
|
||||
"with_embeddings": 0,
|
||||
"without_embeddings": 0,
|
||||
"coverage_percent": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]:
|
||||
"""
|
||||
Generate embeddings for approved listings that don't have them.
|
||||
|
||||
Args:
|
||||
batch_size: Number of embeddings to generate in one call
|
||||
|
||||
Returns:
|
||||
Dict with success/failure counts
|
||||
"""
|
||||
try:
|
||||
# Find approved versions without embeddings
|
||||
missing = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT
|
||||
slv.id,
|
||||
slv.name,
|
||||
slv.description,
|
||||
slv."subHeading",
|
||||
slv.categories
|
||||
FROM {schema_prefix}"StoreListingVersion" slv
|
||||
LEFT JOIN {schema_prefix}"UnifiedContentEmbedding" uce
|
||||
ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
AND uce."contentId" IS NULL
|
||||
LIMIT $1
|
||||
""",
|
||||
batch_size,
|
||||
)
|
||||
|
||||
if not missing:
|
||||
return {
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"message": "No missing embeddings",
|
||||
}
|
||||
|
||||
# Process embeddings concurrently for better performance
|
||||
embedding_tasks = [
|
||||
ensure_embedding(
|
||||
version_id=row["id"],
|
||||
name=row["name"],
|
||||
description=row["description"],
|
||||
sub_heading=row["subHeading"],
|
||||
categories=row["categories"] or [],
|
||||
)
|
||||
for row in missing
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*embedding_tasks, return_exceptions=True)
|
||||
|
||||
success = sum(1 for result in results if result is True)
|
||||
failed = len(results) - success
|
||||
|
||||
return {
|
||||
"processed": len(missing),
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"message": f"Backfilled {success} embeddings, {failed} failed",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to backfill embeddings: {e}")
|
||||
return {
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
async def embed_query(query: str) -> list[float] | None:
|
||||
"""
|
||||
Generate embedding for a search query.
|
||||
|
||||
Same as generate_embedding but with clearer intent.
|
||||
"""
|
||||
return await generate_embedding(query)
|
||||
|
||||
|
||||
def embedding_to_vector_string(embedding: list[float]) -> str:
|
||||
"""Convert embedding list to PostgreSQL vector string format."""
|
||||
return "[" + ",".join(str(x) for x in embedding) + "]"
|
||||
|
||||
|
||||
async def ensure_content_embedding(
|
||||
content_type: ContentType,
|
||||
content_id: str,
|
||||
searchable_text: str,
|
||||
metadata: dict | None = None,
|
||||
user_id: str | None = None,
|
||||
force: bool = False,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Ensure an embedding exists for any content type.
|
||||
|
||||
Generic function for creating embeddings for store agents, blocks, docs, etc.
|
||||
|
||||
Args:
|
||||
content_type: ContentType enum value (STORE_AGENT, BLOCK, etc.)
|
||||
content_id: Unique identifier for the content
|
||||
searchable_text: Combined text for embedding generation
|
||||
metadata: Optional metadata to store with embedding
|
||||
force: Force regeneration even if embedding exists
|
||||
tx: Optional transaction client
|
||||
|
||||
Returns:
|
||||
True if embedding exists/was created, False on failure
|
||||
"""
|
||||
try:
|
||||
# Check if embedding already exists
|
||||
if not force:
|
||||
existing = await get_content_embedding(content_type, content_id, user_id)
|
||||
if existing and existing.get("embedding"):
|
||||
logger.debug(
|
||||
f"Embedding for {content_type}:{content_id} already exists"
|
||||
)
|
||||
return True
|
||||
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
if embedding is None:
|
||||
logger.warning(
|
||||
f"Could not generate embedding for {content_type}:{content_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Store the embedding
|
||||
return await store_content_embedding(
|
||||
content_type=content_type,
|
||||
content_id=content_id,
|
||||
embedding=embedding,
|
||||
searchable_text=searchable_text,
|
||||
metadata=metadata or {},
|
||||
user_id=user_id,
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
|
||||
return False
|
||||
@@ -0,0 +1,329 @@
|
||||
"""
|
||||
Integration tests for embeddings with schema handling.
|
||||
|
||||
These tests verify that embeddings operations work correctly across different database schemas.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store import embeddings
|
||||
|
||||
# Schema prefix tests removed - functionality moved to db.raw_with_schema() helper
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_store_content_embedding_with_schema():
|
||||
"""Test storing embeddings with proper schema handling."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch("prisma.get_client") as mock_get_client:
|
||||
mock_client = AsyncMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id="test-id",
|
||||
embedding=[0.1] * 1536,
|
||||
searchable_text="test text",
|
||||
metadata={"test": "data"},
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Verify the query was called
|
||||
assert mock_client.execute_raw.called
|
||||
|
||||
# Get the SQL query that was executed
|
||||
call_args = mock_client.execute_raw.call_args
|
||||
sql_query = call_args[0][0]
|
||||
|
||||
# Verify schema prefix is in the query
|
||||
assert '"platform"."UnifiedContentEmbedding"' in sql_query
|
||||
|
||||
# Verify result
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_get_content_embedding_with_schema():
|
||||
"""Test retrieving embeddings with proper schema handling."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch("prisma.get_client") as mock_get_client:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.query_raw.return_value = [
|
||||
{
|
||||
"contentType": "STORE_AGENT",
|
||||
"contentId": "test-id",
|
||||
"userId": None,
|
||||
"embedding": "[0.1, 0.2]",
|
||||
"searchableText": "test",
|
||||
"metadata": {},
|
||||
"createdAt": "2024-01-01",
|
||||
"updatedAt": "2024-01-01",
|
||||
}
|
||||
]
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.get_content_embedding(
|
||||
ContentType.STORE_AGENT,
|
||||
"test-id",
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Verify the query was called
|
||||
assert mock_client.query_raw.called
|
||||
|
||||
# Get the SQL query that was executed
|
||||
call_args = mock_client.query_raw.call_args
|
||||
sql_query = call_args[0][0]
|
||||
|
||||
# Verify schema prefix is in the query
|
||||
assert '"platform"."UnifiedContentEmbedding"' in sql_query
|
||||
|
||||
# Verify result
|
||||
assert result is not None
|
||||
assert result["contentId"] == "test-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_delete_content_embedding_with_schema():
|
||||
"""Test deleting embeddings with proper schema handling."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch("prisma.get_client") as mock_get_client:
|
||||
mock_client = AsyncMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.delete_content_embedding(
|
||||
ContentType.STORE_AGENT,
|
||||
"test-id",
|
||||
)
|
||||
|
||||
# Verify the query was called
|
||||
assert mock_client.execute_raw.called
|
||||
|
||||
# Get the SQL query that was executed
|
||||
call_args = mock_client.execute_raw.call_args
|
||||
sql_query = call_args[0][0]
|
||||
|
||||
# Verify schema prefix is in the query
|
||||
assert '"platform"."UnifiedContentEmbedding"' in sql_query
|
||||
|
||||
# Verify result
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_get_embedding_stats_with_schema():
|
||||
"""Test embedding statistics with proper schema handling."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch("prisma.get_client") as mock_get_client:
|
||||
mock_client = AsyncMock()
|
||||
# Mock both query results
|
||||
mock_client.query_raw.side_effect = [
|
||||
[{"count": 100}], # total_approved
|
||||
[{"count": 80}], # with_embeddings
|
||||
]
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.get_embedding_stats()
|
||||
|
||||
# Verify both queries were called
|
||||
assert mock_client.query_raw.call_count == 2
|
||||
|
||||
# Get both SQL queries
|
||||
first_call = mock_client.query_raw.call_args_list[0]
|
||||
second_call = mock_client.query_raw.call_args_list[1]
|
||||
|
||||
first_sql = first_call[0][0]
|
||||
second_sql = second_call[0][0]
|
||||
|
||||
# Verify schema prefix in both queries
|
||||
assert '"platform"."StoreListingVersion"' in first_sql
|
||||
assert '"platform"."StoreListingVersion"' in second_sql
|
||||
assert '"platform"."UnifiedContentEmbedding"' in second_sql
|
||||
|
||||
# Verify results
|
||||
assert result["total_approved"] == 100
|
||||
assert result["with_embeddings"] == 80
|
||||
assert result["without_embeddings"] == 20
|
||||
assert result["coverage_percent"] == 80.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_backfill_missing_embeddings_with_schema():
|
||||
"""Test backfilling embeddings with proper schema handling."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch("prisma.get_client") as mock_get_client:
|
||||
mock_client = AsyncMock()
|
||||
# Mock missing embeddings query
|
||||
mock_client.query_raw.return_value = [
|
||||
{
|
||||
"id": "version-1",
|
||||
"name": "Test Agent",
|
||||
"description": "Test description",
|
||||
"subHeading": "Test heading",
|
||||
"categories": ["test"],
|
||||
}
|
||||
]
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.ensure_embedding"
|
||||
) as mock_ensure:
|
||||
mock_ensure.return_value = True
|
||||
|
||||
result = await embeddings.backfill_missing_embeddings(batch_size=10)
|
||||
|
||||
# Verify the query was called
|
||||
assert mock_client.query_raw.called
|
||||
|
||||
# Get the SQL query
|
||||
call_args = mock_client.query_raw.call_args
|
||||
sql_query = call_args[0][0]
|
||||
|
||||
# Verify schema prefix in query
|
||||
assert '"platform"."StoreListingVersion"' in sql_query
|
||||
assert '"platform"."UnifiedContentEmbedding"' in sql_query
|
||||
|
||||
# Verify ensure_embedding was called
|
||||
assert mock_ensure.called
|
||||
|
||||
# Verify results
|
||||
assert result["processed"] == 1
|
||||
assert result["success"] == 1
|
||||
assert result["failed"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_ensure_content_embedding_with_schema():
|
||||
"""Test ensuring embeddings exist with proper schema handling."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_content_embedding"
|
||||
) as mock_get:
|
||||
# Simulate no existing embedding
|
||||
mock_get.return_value = None
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.generate_embedding"
|
||||
) as mock_generate:
|
||||
mock_generate.return_value = [0.1] * 1536
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.store_content_embedding"
|
||||
) as mock_store:
|
||||
mock_store.return_value = True
|
||||
|
||||
result = await embeddings.ensure_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id="test-id",
|
||||
searchable_text="test text",
|
||||
metadata={"test": "data"},
|
||||
user_id=None,
|
||||
force=False,
|
||||
)
|
||||
|
||||
# Verify the flow
|
||||
assert mock_get.called
|
||||
assert mock_generate.called
|
||||
assert mock_store.called
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_backward_compatibility_store_embedding():
|
||||
"""Test backward compatibility wrapper for store_embedding."""
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.store_content_embedding"
|
||||
) as mock_store:
|
||||
mock_store.return_value = True
|
||||
|
||||
result = await embeddings.store_embedding(
|
||||
version_id="test-version-id",
|
||||
embedding=[0.1] * 1536,
|
||||
tx=None,
|
||||
)
|
||||
|
||||
# Verify it calls the new function with correct parameters
|
||||
assert mock_store.called
|
||||
call_args = mock_store.call_args
|
||||
|
||||
assert call_args[1]["content_type"] == ContentType.STORE_AGENT
|
||||
assert call_args[1]["content_id"] == "test-version-id"
|
||||
assert call_args[1]["user_id"] is None
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_backward_compatibility_get_embedding():
|
||||
"""Test backward compatibility wrapper for get_embedding."""
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_content_embedding"
|
||||
) as mock_get:
|
||||
mock_get.return_value = {
|
||||
"contentType": "STORE_AGENT",
|
||||
"contentId": "test-version-id",
|
||||
"embedding": "[0.1, 0.2]",
|
||||
"createdAt": "2024-01-01",
|
||||
"updatedAt": "2024-01-01",
|
||||
}
|
||||
|
||||
result = await embeddings.get_embedding("test-version-id")
|
||||
|
||||
# Verify it calls the new function
|
||||
assert mock_get.called
|
||||
|
||||
# Verify it transforms to old format
|
||||
assert result is not None
|
||||
assert result["storeListingVersionId"] == "test-version-id"
|
||||
assert "embedding" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_schema_handling_error_cases():
|
||||
"""Test error handling in schema-aware operations."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch("prisma.get_client") as mock_get_client:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.execute_raw.side_effect = Exception("Database error")
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id="test-id",
|
||||
embedding=[0.1] * 1536,
|
||||
searchable_text="test",
|
||||
metadata=None,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Should return False on error, not raise
|
||||
assert result is False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
@@ -0,0 +1,387 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import prisma
|
||||
import pytest
|
||||
from prisma import Prisma
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store import embeddings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_prisma():
|
||||
"""Setup Prisma client for tests."""
|
||||
try:
|
||||
Prisma()
|
||||
except prisma.errors.ClientAlreadyRegisteredError:
|
||||
pass
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_build_searchable_text():
|
||||
"""Test searchable text building from listing fields."""
|
||||
result = embeddings.build_searchable_text(
|
||||
name="AI Assistant",
|
||||
description="A helpful AI assistant for productivity",
|
||||
sub_heading="Boost your productivity",
|
||||
categories=["AI", "Productivity"],
|
||||
)
|
||||
|
||||
expected = "AI Assistant Boost your productivity A helpful AI assistant for productivity AI Productivity"
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_build_searchable_text_empty_fields():
|
||||
"""Test searchable text building with empty fields."""
|
||||
result = embeddings.build_searchable_text(
|
||||
name="", description="Test description", sub_heading="", categories=[]
|
||||
)
|
||||
|
||||
assert result == "Test description"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_generate_embedding_success():
|
||||
"""Test successful embedding generation."""
|
||||
# Mock OpenAI response
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [MagicMock()]
|
||||
mock_response.data[0].embedding = [0.1, 0.2, 0.3] * 512 # 1536 dimensions
|
||||
|
||||
# Use AsyncMock for async embeddings.create method
|
||||
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Patch at the point of use in embeddings.py
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_openai_client"
|
||||
) as mock_get_client:
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.generate_embedding("test text")
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 1536
|
||||
assert result[0] == 0.1
|
||||
|
||||
mock_client.embeddings.create.assert_called_once_with(
|
||||
model="text-embedding-3-small", input="test text"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_generate_embedding_no_api_key():
|
||||
"""Test embedding generation without API key."""
|
||||
# Patch at the point of use in embeddings.py
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_openai_client"
|
||||
) as mock_get_client:
|
||||
mock_get_client.return_value = None
|
||||
|
||||
result = await embeddings.generate_embedding("test text")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_generate_embedding_api_error():
|
||||
"""Test embedding generation with API error."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.embeddings.create = AsyncMock(side_effect=Exception("API Error"))
|
||||
|
||||
# Patch at the point of use in embeddings.py
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_openai_client"
|
||||
) as mock_get_client:
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.generate_embedding("test text")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_generate_embedding_text_truncation():
|
||||
"""Test that long text is properly truncated using tiktoken."""
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [MagicMock()]
|
||||
mock_response.data[0].embedding = [0.1] * 1536
|
||||
|
||||
# Use AsyncMock for async embeddings.create method
|
||||
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Patch at the point of use in embeddings.py
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_openai_client"
|
||||
) as mock_get_client:
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# Create text that will exceed 8191 tokens
|
||||
# Use varied characters to ensure token-heavy text: each word is ~1 token
|
||||
words = [f"word{i}" for i in range(10000)]
|
||||
long_text = " ".join(words) # ~10000 tokens
|
||||
|
||||
await embeddings.generate_embedding(long_text)
|
||||
|
||||
# Verify text was truncated to 8191 tokens
|
||||
call_args = mock_client.embeddings.create.call_args
|
||||
truncated_text = call_args.kwargs["input"]
|
||||
|
||||
# Count actual tokens in truncated text
|
||||
enc = encoding_for_model("text-embedding-3-small")
|
||||
actual_tokens = len(enc.encode(truncated_text))
|
||||
|
||||
# Should be at or just under 8191 tokens
|
||||
assert actual_tokens <= 8191
|
||||
# Should be close to the limit (not over-truncated)
|
||||
assert actual_tokens >= 8100
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_embedding_success(mocker):
|
||||
"""Test successful embedding storage."""
|
||||
mock_client = mocker.AsyncMock()
|
||||
mock_client.execute_raw = mocker.AsyncMock()
|
||||
|
||||
embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
result = await embeddings.store_embedding(
|
||||
version_id="test-version-id", embedding=embedding, tx=mock_client
|
||||
)
|
||||
|
||||
assert result is True
|
||||
# execute_raw is called twice: once for SET search_path, once for INSERT
|
||||
assert mock_client.execute_raw.call_count == 2
|
||||
|
||||
# First call: SET search_path
|
||||
first_call_args = mock_client.execute_raw.call_args_list[0][0]
|
||||
assert "SET search_path" in first_call_args[0]
|
||||
|
||||
# Second call: INSERT query with the actual data
|
||||
second_call_args = mock_client.execute_raw.call_args_list[1][0]
|
||||
assert "test-version-id" in second_call_args
|
||||
assert "[0.1,0.2,0.3]" in second_call_args
|
||||
assert None in second_call_args # userId should be None for store agents
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_embedding_database_error(mocker):
|
||||
"""Test embedding storage with database error."""
|
||||
mock_client = mocker.AsyncMock()
|
||||
mock_client.execute_raw.side_effect = Exception("Database error")
|
||||
|
||||
embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
result = await embeddings.store_embedding(
|
||||
version_id="test-version-id", embedding=embedding, tx=mock_client
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_embedding_success():
|
||||
"""Test successful embedding retrieval."""
|
||||
mock_result = [
|
||||
{
|
||||
"contentType": "STORE_AGENT",
|
||||
"contentId": "test-version-id",
|
||||
"userId": None,
|
||||
"embedding": "[0.1,0.2,0.3]",
|
||||
"searchableText": "Test text",
|
||||
"metadata": {},
|
||||
"createdAt": "2024-01-01T00:00:00Z",
|
||||
"updatedAt": "2024-01-01T00:00:00Z",
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await embeddings.get_embedding("test-version-id")
|
||||
|
||||
assert result is not None
|
||||
assert result["storeListingVersionId"] == "test-version-id"
|
||||
assert result["embedding"] == "[0.1,0.2,0.3]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_embedding_not_found():
|
||||
"""Test embedding retrieval when not found."""
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
result = await embeddings.get_embedding("test-version-id")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
@patch("backend.api.features.store.embeddings.store_embedding")
|
||||
@patch("backend.api.features.store.embeddings.get_embedding")
|
||||
async def test_ensure_embedding_already_exists(mock_get, mock_store, mock_generate):
|
||||
"""Test ensure_embedding when embedding already exists."""
|
||||
mock_get.return_value = {"embedding": "[0.1,0.2,0.3]"}
|
||||
|
||||
result = await embeddings.ensure_embedding(
|
||||
version_id="test-id",
|
||||
name="Test",
|
||||
description="Test description",
|
||||
sub_heading="Test heading",
|
||||
categories=["test"],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_generate.assert_not_called()
|
||||
mock_store.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
@patch("backend.api.features.store.embeddings.store_content_embedding")
|
||||
@patch("backend.api.features.store.embeddings.get_embedding")
|
||||
async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate):
|
||||
"""Test ensure_embedding creating new embedding."""
|
||||
mock_get.return_value = None
|
||||
mock_generate.return_value = [0.1, 0.2, 0.3]
|
||||
mock_store.return_value = True
|
||||
|
||||
result = await embeddings.ensure_embedding(
|
||||
version_id="test-id",
|
||||
name="Test",
|
||||
description="Test description",
|
||||
sub_heading="Test heading",
|
||||
categories=["test"],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_generate.assert_called_once_with("Test Test heading Test description test")
|
||||
mock_store.assert_called_once_with(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id="test-id",
|
||||
embedding=[0.1, 0.2, 0.3],
|
||||
searchable_text="Test Test heading Test description test",
|
||||
metadata={"name": "Test", "subHeading": "Test heading", "categories": ["test"]},
|
||||
user_id=None,
|
||||
tx=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
@patch("backend.api.features.store.embeddings.get_embedding")
|
||||
async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
|
||||
"""Test ensure_embedding when generation fails."""
|
||||
mock_get.return_value = None
|
||||
mock_generate.return_value = None
|
||||
|
||||
result = await embeddings.ensure_embedding(
|
||||
version_id="test-id",
|
||||
name="Test",
|
||||
description="Test description",
|
||||
sub_heading="Test heading",
|
||||
categories=["test"],
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_embedding_stats():
|
||||
"""Test embedding statistics retrieval."""
|
||||
# Mock approved count query and embedded count query
|
||||
mock_approved_result = [{"count": 100}]
|
||||
mock_embedded_result = [{"count": 75}]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
side_effect=[mock_approved_result, mock_embedded_result],
|
||||
):
|
||||
result = await embeddings.get_embedding_stats()
|
||||
|
||||
assert result["total_approved"] == 100
|
||||
assert result["with_embeddings"] == 75
|
||||
assert result["without_embeddings"] == 25
|
||||
assert result["coverage_percent"] == 75.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.ensure_embedding")
|
||||
async def test_backfill_missing_embeddings_success(mock_ensure):
|
||||
"""Test backfill with successful embedding generation."""
|
||||
# Mock missing embeddings query
|
||||
mock_missing = [
|
||||
{
|
||||
"id": "version-1",
|
||||
"name": "Agent 1",
|
||||
"description": "Description 1",
|
||||
"subHeading": "Heading 1",
|
||||
"categories": ["AI"],
|
||||
},
|
||||
{
|
||||
"id": "version-2",
|
||||
"name": "Agent 2",
|
||||
"description": "Description 2",
|
||||
"subHeading": "Heading 2",
|
||||
"categories": ["Productivity"],
|
||||
},
|
||||
]
|
||||
|
||||
# Mock ensure_embedding to succeed for first, fail for second
|
||||
mock_ensure.side_effect = [True, False]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
return_value=mock_missing,
|
||||
):
|
||||
result = await embeddings.backfill_missing_embeddings(batch_size=5)
|
||||
|
||||
assert result["processed"] == 2
|
||||
assert result["success"] == 1
|
||||
assert result["failed"] == 1
|
||||
assert mock_ensure.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_backfill_missing_embeddings_no_missing():
|
||||
"""Test backfill when no embeddings are missing."""
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
result = await embeddings.backfill_missing_embeddings(batch_size=5)
|
||||
|
||||
assert result["processed"] == 0
|
||||
assert result["success"] == 0
|
||||
assert result["failed"] == 0
|
||||
assert result["message"] == "No missing embeddings"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_embedding_to_vector_string():
|
||||
"""Test embedding to PostgreSQL vector string conversion."""
|
||||
embedding = [0.1, 0.2, 0.3, -0.4]
|
||||
result = embeddings.embedding_to_vector_string(embedding)
|
||||
assert result == "[0.1,0.2,0.3,-0.4]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_embed_query():
|
||||
"""Test embed_query function (alias for generate_embedding)."""
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.generate_embedding"
|
||||
) as mock_generate:
|
||||
mock_generate.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
result = await embeddings.embed_query("test query")
|
||||
|
||||
assert result == [0.1, 0.2, 0.3]
|
||||
mock_generate.assert_called_once_with("test query")
|
||||
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
Hybrid Search for Store Agents
|
||||
|
||||
Combines semantic (embedding) search with lexical (tsvector) search
|
||||
for improved relevance in marketplace agent discovery.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from backend.api.features.store.embeddings import (
|
||||
embed_query,
|
||||
embedding_to_vector_string,
|
||||
)
|
||||
from backend.data.db import query_raw_with_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HybridSearchWeights:
|
||||
"""Weights for combining search signals."""
|
||||
|
||||
semantic: float = 0.30 # Embedding cosine similarity
|
||||
lexical: float = 0.30 # tsvector ts_rank_cd score
|
||||
category: float = 0.20 # Category match boost
|
||||
recency: float = 0.10 # Newer agents ranked higher
|
||||
popularity: float = 0.10 # Agent usage/runs (PageRank-like)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate weights are non-negative and sum to approximately 1.0."""
|
||||
total = (
|
||||
self.semantic
|
||||
+ self.lexical
|
||||
+ self.category
|
||||
+ self.recency
|
||||
+ self.popularity
|
||||
)
|
||||
|
||||
if any(
|
||||
w < 0
|
||||
for w in [
|
||||
self.semantic,
|
||||
self.lexical,
|
||||
self.category,
|
||||
self.recency,
|
||||
self.popularity,
|
||||
]
|
||||
):
|
||||
raise ValueError("All weights must be non-negative")
|
||||
|
||||
if not (0.99 <= total <= 1.01):
|
||||
raise ValueError(f"Weights must sum to ~1.0, got {total:.3f}")
|
||||
|
||||
|
||||
DEFAULT_WEIGHTS = HybridSearchWeights()
|
||||
|
||||
# Minimum relevance score threshold - agents below this are filtered out
|
||||
# With weights (0.30 semantic + 0.30 lexical + 0.20 category + 0.10 recency + 0.10 popularity):
|
||||
# - 0.20 means at least ~60% semantic match OR strong lexical match required
|
||||
# - Ensures only genuinely relevant results are returned
|
||||
# - Recency/popularity alone (0.10 each) won't pass the threshold
|
||||
DEFAULT_MIN_SCORE = 0.20
|
||||
|
||||
|
||||
@dataclass
|
||||
class HybridSearchResult:
|
||||
"""A single search result with score breakdown."""
|
||||
|
||||
slug: str
|
||||
agent_name: str
|
||||
agent_image: str
|
||||
creator_username: str
|
||||
creator_avatar: str
|
||||
sub_heading: str
|
||||
description: str
|
||||
runs: int
|
||||
rating: float
|
||||
categories: list[str]
|
||||
featured: bool
|
||||
is_available: bool
|
||||
updated_at: datetime
|
||||
|
||||
# Score breakdown (for debugging/tuning)
|
||||
combined_score: float
|
||||
semantic_score: float = 0.0
|
||||
lexical_score: float = 0.0
|
||||
category_score: float = 0.0
|
||||
recency_score: float = 0.0
|
||||
popularity_score: float = 0.0
|
||||
|
||||
|
||||
async def hybrid_search(
|
||||
query: str,
|
||||
featured: bool = False,
|
||||
creators: list[str] | None = None,
|
||||
category: str | None = None,
|
||||
sorted_by: (
|
||||
Literal["relevance", "rating", "runs", "name", "updated_at"] | None
|
||||
) = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
weights: HybridSearchWeights | None = None,
|
||||
min_score: float | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Perform hybrid search combining semantic and lexical signals.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
featured: Filter for featured agents only
|
||||
creators: Filter by creator usernames
|
||||
category: Filter by category
|
||||
sorted_by: Sort order (relevance uses hybrid scoring)
|
||||
page: Page number (1-indexed)
|
||||
page_size: Results per page
|
||||
weights: Custom weights for search signals
|
||||
min_score: Minimum relevance score threshold (0-1). Results below
|
||||
this score are filtered out. Defaults to DEFAULT_MIN_SCORE.
|
||||
|
||||
Returns:
|
||||
Tuple of (results list, total count). Returns empty list if no
|
||||
results meet the minimum relevance threshold.
|
||||
"""
|
||||
# Validate inputs
|
||||
query = query.strip()
|
||||
if not query:
|
||||
return [], 0 # Empty query returns no results
|
||||
|
||||
if page < 1:
|
||||
page = 1
|
||||
if page_size < 1:
|
||||
page_size = 1
|
||||
if page_size > 100: # Cap at reasonable limit to prevent performance issues
|
||||
page_size = 100
|
||||
|
||||
if weights is None:
|
||||
weights = DEFAULT_WEIGHTS
|
||||
if min_score is None:
|
||||
min_score = DEFAULT_MIN_SCORE
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await embed_query(query)
|
||||
|
||||
# Build WHERE clause conditions
|
||||
where_parts: list[str] = ["sa.is_available = true"]
|
||||
params: list[Any] = []
|
||||
param_index = 1
|
||||
|
||||
# Add search query for lexical matching
|
||||
params.append(query)
|
||||
query_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
# Add lowercased query for category matching
|
||||
params.append(query.lower())
|
||||
query_lower_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
if featured:
|
||||
where_parts.append("sa.featured = true")
|
||||
|
||||
if creators:
|
||||
where_parts.append(f"sa.creator_username = ANY(${param_index})")
|
||||
params.append(creators)
|
||||
param_index += 1
|
||||
|
||||
if category:
|
||||
where_parts.append(f"${param_index} = ANY(sa.categories)")
|
||||
params.append(category)
|
||||
param_index += 1
|
||||
|
||||
# Safe: where_parts only contains hardcoded strings with $N parameter placeholders
|
||||
# No user input is concatenated directly into the SQL string
|
||||
where_clause = " AND ".join(where_parts)
|
||||
|
||||
# Embedding is required for hybrid search - fail fast if unavailable
|
||||
if query_embedding is None or not query_embedding:
|
||||
# Log detailed error server-side
|
||||
logger.error(
|
||||
"Failed to generate query embedding. "
|
||||
"Check that openai_internal_api_key is configured and OpenAI API is accessible."
|
||||
)
|
||||
# Raise generic error to client
|
||||
raise ValueError("Search service temporarily unavailable")
|
||||
|
||||
# Add embedding parameter
|
||||
embedding_str = embedding_to_vector_string(query_embedding)
|
||||
params.append(embedding_str)
|
||||
embedding_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
# Add weight parameters for SQL calculation
|
||||
params.append(weights.semantic)
|
||||
weight_semantic_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
params.append(weights.lexical)
|
||||
weight_lexical_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
params.append(weights.category)
|
||||
weight_category_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
params.append(weights.recency)
|
||||
weight_recency_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
params.append(weights.popularity)
|
||||
weight_popularity_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
# Add min_score parameter
|
||||
params.append(min_score)
|
||||
min_score_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
# Optimized hybrid search query:
|
||||
# 1. Direct join to UnifiedContentEmbedding via contentId=storeListingVersionId (no redundant JOINs)
|
||||
# 2. UNION approach (deduplicates agents matching both branches)
|
||||
# 3. COUNT(*) OVER() to get total count in single query
|
||||
# 4. Optimized category matching with EXISTS + unnest
|
||||
# 5. Pre-calculated max values for lexical and popularity normalization
|
||||
# 6. Simplified recency calculation with linear decay
|
||||
# 7. Logarithmic popularity scaling to prevent viral agents from dominating
|
||||
sql_query = f"""
|
||||
WITH candidates AS (
|
||||
-- Lexical matches (uses GIN index on search column)
|
||||
SELECT sa."storeListingVersionId"
|
||||
FROM {{schema_prefix}}"StoreAgent" sa
|
||||
WHERE {where_clause}
|
||||
AND sa.search @@ plainto_tsquery('english', {query_param})
|
||||
|
||||
UNION
|
||||
|
||||
-- Semantic matches (uses HNSW index on embedding with KNN)
|
||||
SELECT "storeListingVersionId"
|
||||
FROM (
|
||||
SELECT sa."storeListingVersionId", uce.embedding
|
||||
FROM {{schema_prefix}}"StoreAgent" sa
|
||||
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
WHERE {where_clause}
|
||||
ORDER BY uce.embedding <=> {embedding_param}::vector
|
||||
LIMIT 200
|
||||
) semantic_results
|
||||
),
|
||||
search_scores AS (
|
||||
SELECT
|
||||
sa.slug,
|
||||
sa.agent_name,
|
||||
sa.agent_image,
|
||||
sa.creator_username,
|
||||
sa.creator_avatar,
|
||||
sa.sub_heading,
|
||||
sa.description,
|
||||
sa.runs,
|
||||
sa.rating,
|
||||
sa.categories,
|
||||
sa.featured,
|
||||
sa.is_available,
|
||||
sa.updated_at,
|
||||
-- Semantic score: cosine similarity (1 - distance)
|
||||
COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
||||
-- Lexical score: ts_rank_cd (will be normalized later)
|
||||
COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||
-- Category match: optimized with unnest for better performance
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM unnest(sa.categories) cat
|
||||
WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%'
|
||||
)
|
||||
THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
-- Recency score: linear decay over 90 days (simpler than exponential)
|
||||
GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score,
|
||||
-- Popularity raw: agent runs count (will be normalized with log scaling)
|
||||
sa.runs as popularity_raw
|
||||
FROM candidates c
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON c."storeListingVersionId" = sa."storeListingVersionId"
|
||||
LEFT JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
),
|
||||
max_lexical AS (
|
||||
SELECT MAX(lexical_raw) as max_val FROM search_scores
|
||||
),
|
||||
max_popularity AS (
|
||||
SELECT MAX(popularity_raw) as max_val FROM search_scores
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
ss.*,
|
||||
-- Normalize lexical score by pre-calculated max
|
||||
CASE
|
||||
WHEN ml.max_val > 0
|
||||
THEN ss.lexical_raw / ml.max_val
|
||||
ELSE 0
|
||||
END as lexical_score,
|
||||
-- Normalize popularity with logarithmic scaling to prevent viral agents from dominating
|
||||
-- LOG(1 + runs) / LOG(1 + max_runs) ensures score is 0-1 range
|
||||
CASE
|
||||
WHEN mp.max_val > 0 AND ss.popularity_raw > 0
|
||||
THEN LN(1 + ss.popularity_raw) / LN(1 + mp.max_val)
|
||||
ELSE 0
|
||||
END as popularity_score
|
||||
FROM search_scores ss
|
||||
CROSS JOIN max_lexical ml
|
||||
CROSS JOIN max_popularity mp
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
category_score,
|
||||
recency_score,
|
||||
popularity_score,
|
||||
(
|
||||
{weight_semantic_param} * semantic_score +
|
||||
{weight_lexical_param} * lexical_score +
|
||||
{weight_category_param} * category_score +
|
||||
{weight_recency_param} * recency_score +
|
||||
{weight_popularity_param} * popularity_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
),
|
||||
filtered AS (
|
||||
SELECT
|
||||
*,
|
||||
COUNT(*) OVER () as total_count
|
||||
FROM scored
|
||||
WHERE combined_score >= {min_score_param}
|
||||
)
|
||||
SELECT * FROM filtered
|
||||
ORDER BY combined_score DESC
|
||||
LIMIT ${param_index} OFFSET ${param_index + 1}
|
||||
"""
|
||||
|
||||
# Add pagination params
|
||||
params.extend([page_size, offset])
|
||||
|
||||
# Execute search query - includes total_count via window function
|
||||
results = await query_raw_with_schema(
|
||||
sql_query, *params, set_public_search_path=True
|
||||
)
|
||||
|
||||
# Extract total count from first result (all rows have same count)
|
||||
total = results[0]["total_count"] if results else 0
|
||||
|
||||
# Remove total_count from results before returning
|
||||
for result in results:
|
||||
result.pop("total_count", None)
|
||||
|
||||
# Log without sensitive query content
|
||||
logger.info(f"Hybrid search: {len(results)} results, {total} total")
|
||||
|
||||
return results, total
|
||||
|
||||
|
||||
async def hybrid_search_simple(
|
||||
query: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Simplified hybrid search for common use cases.
|
||||
|
||||
Uses default weights and no filters.
|
||||
"""
|
||||
return await hybrid_search(
|
||||
query=query,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
@@ -0,0 +1,334 @@
|
||||
"""
|
||||
Integration tests for hybrid search with schema handling.
|
||||
|
||||
These tests verify that hybrid search works correctly across different database schemas.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.store.hybrid_search import HybridSearchWeights, hybrid_search
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_with_schema_handling():
|
||||
"""Test that hybrid search correctly handles database schema prefixes."""
|
||||
# Test with a mock query to ensure schema handling works
|
||||
query = "test agent"
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
# Mock the query result
|
||||
mock_query.return_value = [
|
||||
{
|
||||
"slug": "test/agent",
|
||||
"agent_name": "Test Agent",
|
||||
"agent_image": "test.png",
|
||||
"creator_username": "test",
|
||||
"creator_avatar": "avatar.png",
|
||||
"sub_heading": "Test sub-heading",
|
||||
"description": "Test description",
|
||||
"runs": 10,
|
||||
"rating": 4.5,
|
||||
"categories": ["test"],
|
||||
"featured": False,
|
||||
"is_available": True,
|
||||
"updated_at": "2024-01-01T00:00:00Z",
|
||||
"combined_score": 0.8,
|
||||
"semantic_score": 0.7,
|
||||
"lexical_score": 0.6,
|
||||
"category_score": 0.5,
|
||||
"recency_score": 0.4,
|
||||
"total_count": 1,
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536 # Mock embedding
|
||||
|
||||
results, total = await hybrid_search(
|
||||
query=query,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify the query was called
|
||||
assert mock_query.called
|
||||
# Verify the SQL template uses schema_prefix placeholder
|
||||
call_args = mock_query.call_args
|
||||
sql_template = call_args[0][0]
|
||||
assert "{schema_prefix}" in sql_template
|
||||
|
||||
# Verify results
|
||||
assert len(results) == 1
|
||||
assert total == 1
|
||||
assert results[0]["slug"] == "test/agent"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_with_public_schema():
|
||||
"""Test hybrid search when using public schema (no prefix needed)."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "public"
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
mock_query.return_value = []
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify the mock was set up correctly
|
||||
assert mock_schema.return_value == "public"
|
||||
|
||||
# Results should work even with empty results
|
||||
assert results == []
|
||||
assert total == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_with_custom_schema():
|
||||
"""Test hybrid search when using custom schema (e.g., 'platform')."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
mock_query.return_value = []
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify the mock was set up correctly
|
||||
assert mock_schema.return_value == "platform"
|
||||
|
||||
assert results == []
|
||||
assert total == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_without_embeddings():
|
||||
"""Test hybrid search fails fast when embeddings are unavailable."""
|
||||
# Patch where the function is used, not where it's defined
|
||||
with patch("backend.api.features.store.hybrid_search.embed_query") as mock_embed:
|
||||
# Simulate embedding failure
|
||||
mock_embed.return_value = None
|
||||
|
||||
# Should raise ValueError with helpful message
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await hybrid_search(
|
||||
query="test",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify error message is generic (doesn't leak implementation details)
|
||||
assert "Search service temporarily unavailable" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_with_filters():
|
||||
"""Test hybrid search with various filters."""
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
mock_query.return_value = []
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
|
||||
# Test with featured filter
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
featured=True,
|
||||
creators=["user1", "user2"],
|
||||
category="productivity",
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
# Verify filters were applied in the query
|
||||
call_args = mock_query.call_args
|
||||
params = call_args[0][1:] # Skip SQL template
|
||||
|
||||
# Should have query, query_lower, creators array, category
|
||||
assert len(params) >= 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_weights():
|
||||
"""Test hybrid search with custom weights."""
|
||||
custom_weights = HybridSearchWeights(
|
||||
semantic=0.5,
|
||||
lexical=0.3,
|
||||
category=0.1,
|
||||
recency=0.1,
|
||||
popularity=0.0,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
mock_query.return_value = []
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
weights=custom_weights,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify custom weights were used in the query
|
||||
call_args = mock_query.call_args
|
||||
sql_template = call_args[0][0]
|
||||
params = call_args[0][1:] # Get all parameters passed
|
||||
|
||||
# Check that SQL uses parameterized weights (not f-string interpolation)
|
||||
assert "$" in sql_template # Verify parameterization is used
|
||||
|
||||
# Check that custom weights are in the params
|
||||
assert 0.5 in params # semantic weight
|
||||
assert 0.3 in params # lexical weight
|
||||
assert 0.1 in params # category and recency weights
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_min_score_filtering():
|
||||
"""Test hybrid search minimum score threshold."""
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
# Return results with varying scores
|
||||
mock_query.return_value = [
|
||||
{
|
||||
"slug": "high-score/agent",
|
||||
"agent_name": "High Score Agent",
|
||||
"combined_score": 0.8,
|
||||
"total_count": 1,
|
||||
# ... other fields
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
|
||||
# Test with custom min_score
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
min_score=0.5, # High threshold
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify min_score was applied in query
|
||||
call_args = mock_query.call_args
|
||||
sql_template = call_args[0][0]
|
||||
params = call_args[0][1:] # Get all parameters
|
||||
|
||||
# Check that SQL uses parameterized min_score
|
||||
assert "combined_score >=" in sql_template
|
||||
assert "$" in sql_template # Verify parameterization
|
||||
|
||||
# Check that custom min_score is in the params
|
||||
assert 0.5 in params
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_pagination():
|
||||
"""Test hybrid search pagination."""
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
mock_query.return_value = []
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
|
||||
# Test page 2 with page_size 10
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
page=2,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
# Verify pagination parameters
|
||||
call_args = mock_query.call_args
|
||||
params = call_args[0]
|
||||
|
||||
# Last two params should be LIMIT and OFFSET
|
||||
limit = params[-2]
|
||||
offset = params[-1]
|
||||
|
||||
assert limit == 10 # page_size
|
||||
assert offset == 10 # (page - 1) * page_size = (2 - 1) * 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_error_handling():
|
||||
"""Test hybrid search error handling."""
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
# Simulate database error
|
||||
mock_query.side_effect = Exception("Database connection error")
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
|
||||
# Should raise exception
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await hybrid_search(
|
||||
query="test",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
assert "Database connection error" in str(exc_info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
@@ -64,7 +64,6 @@ from backend.data.onboarding import (
|
||||
complete_re_run_agent,
|
||||
get_recommended_agents,
|
||||
get_user_onboarding,
|
||||
increment_runs,
|
||||
onboarding_enabled,
|
||||
reset_user_onboarding,
|
||||
update_user_onboarding,
|
||||
@@ -975,7 +974,6 @@ async def execute_graph(
|
||||
# Record successful graph execution
|
||||
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
|
||||
record_graph_operation(operation="execute", status="success")
|
||||
await increment_runs(user_id)
|
||||
await complete_re_run_agent(user_id, graph_id)
|
||||
if source == "library":
|
||||
await complete_onboarding_step(
|
||||
|
||||
@@ -38,6 +38,20 @@ POOL_TIMEOUT = os.getenv("DB_POOL_TIMEOUT")
|
||||
if POOL_TIMEOUT:
|
||||
DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)
|
||||
|
||||
# Add public schema to search_path for pgvector type access
|
||||
# The vector extension is in public schema, but search_path is determined by schema parameter
|
||||
# Extract the schema from DATABASE_URL or default to 'public' (matching get_database_schema())
|
||||
parsed_url = urlparse(DATABASE_URL)
|
||||
url_params = dict(parse_qsl(parsed_url.query))
|
||||
db_schema = url_params.get("schema", "public")
|
||||
# Build search_path, avoiding duplicates if db_schema is already 'public'
|
||||
search_path_schemas = list(
|
||||
dict.fromkeys([db_schema, "public"])
|
||||
) # Preserves order, removes duplicates
|
||||
search_path = ",".join(search_path_schemas)
|
||||
# This allows using ::vector without schema qualification
|
||||
DATABASE_URL = add_param(DATABASE_URL, "options", f"-c search_path={search_path}")
|
||||
|
||||
HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None
|
||||
|
||||
prisma = Prisma(
|
||||
@@ -108,21 +122,102 @@ def get_database_schema() -> str:
|
||||
return query_params.get("schema", "public")
|
||||
|
||||
|
||||
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
|
||||
"""Execute raw SQL query with proper schema handling."""
|
||||
async def _raw_with_schema(
|
||||
query_template: str,
|
||||
*args,
|
||||
execute: bool = False,
|
||||
client: Prisma | None = None,
|
||||
set_public_search_path: bool = False,
|
||||
) -> list[dict] | int:
|
||||
"""Internal: Execute raw SQL with proper schema handling.
|
||||
|
||||
Use query_raw_with_schema() or execute_raw_with_schema() instead.
|
||||
|
||||
Args:
|
||||
query_template: SQL query with {schema_prefix} placeholder
|
||||
*args: Query parameters
|
||||
execute: If False, executes SELECT query. If True, executes INSERT/UPDATE/DELETE.
|
||||
client: Optional Prisma client for transactions (only used when execute=True).
|
||||
set_public_search_path: If True, sets search_path to include public schema.
|
||||
Needed for pgvector types and other public schema objects.
|
||||
|
||||
Returns:
|
||||
- list[dict] if execute=False (query results)
|
||||
- int if execute=True (number of affected rows)
|
||||
"""
|
||||
schema = get_database_schema()
|
||||
schema_prefix = f'"{schema}".' if schema != "public" else ""
|
||||
formatted_query = query_template.format(schema_prefix=schema_prefix)
|
||||
|
||||
import prisma as prisma_module
|
||||
|
||||
result = await prisma_module.get_client().query_raw(
|
||||
formatted_query, *args # type: ignore
|
||||
)
|
||||
db_client = client if client else prisma_module.get_client()
|
||||
|
||||
# Set search_path to include public schema if requested
|
||||
# Prisma doesn't support the 'options' connection parameter, so we set it per-session
|
||||
# This is idempotent and safe to call multiple times
|
||||
if set_public_search_path:
|
||||
await db_client.execute_raw(f"SET search_path = {schema}, public") # type: ignore
|
||||
|
||||
if execute:
|
||||
result = await db_client.execute_raw(formatted_query, *args) # type: ignore
|
||||
else:
|
||||
result = await db_client.query_raw(formatted_query, *args) # type: ignore
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def query_raw_with_schema(
|
||||
query_template: str, *args, set_public_search_path: bool = False
|
||||
) -> list[dict]:
|
||||
"""Execute raw SQL SELECT query with proper schema handling.
|
||||
|
||||
Args:
|
||||
query_template: SQL query with {schema_prefix} placeholder
|
||||
*args: Query parameters
|
||||
set_public_search_path: If True, sets search_path to include public schema.
|
||||
Needed for pgvector types and other public schema objects.
|
||||
|
||||
Returns:
|
||||
List of result rows as dictionaries
|
||||
|
||||
Example:
|
||||
results = await query_raw_with_schema(
|
||||
'SELECT * FROM {schema_prefix}"User" WHERE id = $1',
|
||||
user_id
|
||||
)
|
||||
"""
|
||||
return await _raw_with_schema(query_template, *args, execute=False, set_public_search_path=set_public_search_path) # type: ignore
|
||||
|
||||
|
||||
async def execute_raw_with_schema(
|
||||
query_template: str,
|
||||
*args,
|
||||
client: Prisma | None = None,
|
||||
set_public_search_path: bool = False,
|
||||
) -> int:
|
||||
"""Execute raw SQL command (INSERT/UPDATE/DELETE) with proper schema handling.
|
||||
|
||||
Args:
|
||||
query_template: SQL query with {schema_prefix} placeholder
|
||||
*args: Query parameters
|
||||
client: Optional Prisma client for transactions
|
||||
set_public_search_path: If True, sets search_path to include public schema.
|
||||
Needed for pgvector types and other public schema objects.
|
||||
|
||||
Returns:
|
||||
Number of affected rows
|
||||
|
||||
Example:
|
||||
await execute_raw_with_schema(
|
||||
'INSERT INTO {schema_prefix}"User" (id, name) VALUES ($1, $2)',
|
||||
user_id, name,
|
||||
client=tx # Optional transaction client
|
||||
)
|
||||
"""
|
||||
return await _raw_with_schema(query_template, *args, execute=True, client=client, set_public_search_path=set_public_search_path) # type: ignore
|
||||
|
||||
|
||||
class BaseDbModel(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import fastapi.exceptions
|
||||
@@ -18,6 +19,17 @@ from backend.usecases.sample import create_test_user
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def mock_embedding_functions():
|
||||
"""Mock embedding functions for all tests to avoid database/API dependencies."""
|
||||
with patch(
|
||||
"backend.api.features.store.db.ensure_embedding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_graph_creation(server: SpinTestServer, snapshot: Snapshot):
|
||||
"""
|
||||
|
||||
@@ -334,7 +334,7 @@ async def _get_user_timezone(user_id: str) -> str:
|
||||
return get_user_timezone_or_utc(user.timezone if user else None)
|
||||
|
||||
|
||||
async def increment_runs(user_id: str):
|
||||
async def increment_onboarding_runs(user_id: str):
|
||||
"""
|
||||
Increment a user's run counters and trigger any onboarding milestones.
|
||||
"""
|
||||
|
||||
404
autogpt_platform/backend/backend/data/understanding.py
Normal file
404
autogpt_platform/backend/backend/data/understanding.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""Data models and access layer for user business understanding."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import pydantic
|
||||
from prisma.models import CoPilotUnderstanding
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache configuration
|
||||
CACHE_KEY_PREFIX = "understanding"
|
||||
CACHE_TTL_SECONDS = 48 * 60 * 60 # 48 hours
|
||||
|
||||
|
||||
def _cache_key(user_id: str) -> str:
|
||||
"""Generate cache key for user business understanding."""
|
||||
return f"{CACHE_KEY_PREFIX}:{user_id}"
|
||||
|
||||
|
||||
def _json_to_list(value: Any) -> list[str]:
|
||||
"""Convert Json field to list[str], handling None."""
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return cast(list[str], value)
|
||||
return []
|
||||
|
||||
|
||||
class BusinessUnderstandingInput(pydantic.BaseModel):
|
||||
"""Input model for updating business understanding - all fields optional for incremental updates."""
|
||||
|
||||
# User info
|
||||
user_name: Optional[str] = pydantic.Field(None, description="The user's name")
|
||||
job_title: Optional[str] = pydantic.Field(None, description="The user's job title")
|
||||
|
||||
# Business basics
|
||||
business_name: Optional[str] = pydantic.Field(
|
||||
None, description="Name of the user's business"
|
||||
)
|
||||
industry: Optional[str] = pydantic.Field(None, description="Industry or sector")
|
||||
business_size: Optional[str] = pydantic.Field(
|
||||
None, description="Company size (e.g., '1-10', '11-50')"
|
||||
)
|
||||
user_role: Optional[str] = pydantic.Field(
|
||||
None,
|
||||
description="User's role in the organization (e.g., 'decision maker', 'implementer')",
|
||||
)
|
||||
|
||||
# Processes & activities
|
||||
key_workflows: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Key business workflows"
|
||||
)
|
||||
daily_activities: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Daily activities performed"
|
||||
)
|
||||
|
||||
# Pain points & goals
|
||||
pain_points: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Current pain points"
|
||||
)
|
||||
bottlenecks: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Process bottlenecks"
|
||||
)
|
||||
manual_tasks: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Manual/repetitive tasks"
|
||||
)
|
||||
automation_goals: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Desired automation goals"
|
||||
)
|
||||
|
||||
# Current tools
|
||||
current_software: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Software/tools currently used"
|
||||
)
|
||||
existing_automation: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Existing automations"
|
||||
)
|
||||
|
||||
# Additional context
|
||||
additional_notes: Optional[str] = pydantic.Field(
|
||||
None, description="Any additional context"
|
||||
)
|
||||
|
||||
|
||||
class BusinessUnderstanding(pydantic.BaseModel):
|
||||
"""Full business understanding model returned from database."""
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# User info
|
||||
user_name: Optional[str] = None
|
||||
job_title: Optional[str] = None
|
||||
|
||||
# Business basics
|
||||
business_name: Optional[str] = None
|
||||
industry: Optional[str] = None
|
||||
business_size: Optional[str] = None
|
||||
user_role: Optional[str] = None
|
||||
|
||||
# Processes & activities
|
||||
key_workflows: list[str] = pydantic.Field(default_factory=list)
|
||||
daily_activities: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
# Pain points & goals
|
||||
pain_points: list[str] = pydantic.Field(default_factory=list)
|
||||
bottlenecks: list[str] = pydantic.Field(default_factory=list)
|
||||
manual_tasks: list[str] = pydantic.Field(default_factory=list)
|
||||
automation_goals: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
# Current tools
|
||||
current_software: list[str] = pydantic.Field(default_factory=list)
|
||||
existing_automation: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
# Additional context
|
||||
additional_notes: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_record: CoPilotUnderstanding) -> "BusinessUnderstanding":
|
||||
"""Convert database record to Pydantic model."""
|
||||
data = db_record.data if isinstance(db_record.data, dict) else {}
|
||||
business = (
|
||||
data.get("business", {}) if isinstance(data.get("business"), dict) else {}
|
||||
)
|
||||
return cls(
|
||||
id=db_record.id,
|
||||
user_id=db_record.userId,
|
||||
created_at=db_record.createdAt,
|
||||
updated_at=db_record.updatedAt,
|
||||
user_name=data.get("name"),
|
||||
job_title=business.get("job_title"),
|
||||
business_name=business.get("business_name"),
|
||||
industry=business.get("industry"),
|
||||
business_size=business.get("business_size"),
|
||||
user_role=business.get("user_role"),
|
||||
key_workflows=_json_to_list(business.get("key_workflows")),
|
||||
daily_activities=_json_to_list(business.get("daily_activities")),
|
||||
pain_points=_json_to_list(business.get("pain_points")),
|
||||
bottlenecks=_json_to_list(business.get("bottlenecks")),
|
||||
manual_tasks=_json_to_list(business.get("manual_tasks")),
|
||||
automation_goals=_json_to_list(business.get("automation_goals")),
|
||||
current_software=_json_to_list(business.get("current_software")),
|
||||
existing_automation=_json_to_list(business.get("existing_automation")),
|
||||
additional_notes=business.get("additional_notes"),
|
||||
)
|
||||
|
||||
|
||||
def _merge_lists(existing: list | None, new: list | None) -> list | None:
|
||||
"""Merge two lists, removing duplicates while preserving order."""
|
||||
if new is None:
|
||||
return existing
|
||||
if existing is None:
|
||||
return new
|
||||
# Preserve order, add new items that don't exist
|
||||
merged = list(existing)
|
||||
for item in new:
|
||||
if item not in merged:
|
||||
merged.append(item)
|
||||
return merged
|
||||
|
||||
|
||||
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
|
||||
"""Get business understanding from Redis cache."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
cached_data = await redis.get(_cache_key(user_id))
|
||||
if cached_data:
|
||||
return BusinessUnderstanding.model_validate_json(cached_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get understanding from cache: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _set_cache(user_id: str, understanding: BusinessUnderstanding) -> None:
|
||||
"""Set business understanding in Redis cache with TTL."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.setex(
|
||||
_cache_key(user_id),
|
||||
CACHE_TTL_SECONDS,
|
||||
understanding.model_dump_json(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to set understanding in cache: {e}")
|
||||
|
||||
|
||||
async def _delete_cache(user_id: str) -> None:
|
||||
"""Delete business understanding from Redis cache."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(_cache_key(user_id))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete understanding from cache: {e}")
|
||||
|
||||
|
||||
async def get_business_understanding(
|
||||
user_id: str,
|
||||
) -> Optional[BusinessUnderstanding]:
|
||||
"""Get the business understanding for a user.
|
||||
|
||||
Checks cache first, falls back to database if not cached.
|
||||
Results are cached for 48 hours.
|
||||
"""
|
||||
# Try cache first
|
||||
cached = await _get_from_cache(user_id)
|
||||
if cached:
|
||||
logger.debug(f"Business understanding cache hit for user {user_id}")
|
||||
return cached
|
||||
|
||||
# Cache miss - load from database
|
||||
logger.debug(f"Business understanding cache miss for user {user_id}")
|
||||
record = await CoPilotUnderstanding.prisma().find_unique(where={"userId": user_id})
|
||||
if record is None:
|
||||
return None
|
||||
|
||||
understanding = BusinessUnderstanding.from_db(record)
|
||||
|
||||
# Store in cache for next time
|
||||
await _set_cache(user_id, understanding)
|
||||
|
||||
return understanding
|
||||
|
||||
|
||||
async def upsert_business_understanding(
|
||||
user_id: str,
|
||||
input_data: BusinessUnderstandingInput,
|
||||
) -> BusinessUnderstanding:
|
||||
"""
|
||||
Create or update business understanding with incremental merge strategy.
|
||||
|
||||
- String fields: new value overwrites if provided (not None)
|
||||
- List fields: new items are appended to existing (deduplicated)
|
||||
|
||||
Data is stored as: {name: ..., business: {version: 1, ...}}
|
||||
"""
|
||||
# Get existing record for merge
|
||||
existing = await CoPilotUnderstanding.prisma().find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
|
||||
# Get existing data structure or start fresh
|
||||
existing_data: dict[str, Any] = {}
|
||||
if existing and isinstance(existing.data, dict):
|
||||
existing_data = dict(existing.data)
|
||||
|
||||
existing_business: dict[str, Any] = {}
|
||||
if isinstance(existing_data.get("business"), dict):
|
||||
existing_business = dict(existing_data["business"])
|
||||
|
||||
# Business fields (stored inside business object)
|
||||
business_string_fields = [
|
||||
"job_title",
|
||||
"business_name",
|
||||
"industry",
|
||||
"business_size",
|
||||
"user_role",
|
||||
"additional_notes",
|
||||
]
|
||||
business_list_fields = [
|
||||
"key_workflows",
|
||||
"daily_activities",
|
||||
"pain_points",
|
||||
"bottlenecks",
|
||||
"manual_tasks",
|
||||
"automation_goals",
|
||||
"current_software",
|
||||
"existing_automation",
|
||||
]
|
||||
|
||||
# Handle top-level name field
|
||||
if input_data.user_name is not None:
|
||||
existing_data["name"] = input_data.user_name
|
||||
|
||||
# Business string fields - overwrite if provided
|
||||
for field in business_string_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_business[field] = value
|
||||
|
||||
# Business list fields - merge with existing
|
||||
for field in business_list_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_list = _json_to_list(existing_business.get(field))
|
||||
merged = _merge_lists(existing_list, value)
|
||||
existing_business[field] = merged
|
||||
|
||||
# Set version and nest business data
|
||||
existing_business["version"] = 1
|
||||
existing_data["business"] = existing_business
|
||||
|
||||
# Upsert with the merged data
|
||||
record = await CoPilotUnderstanding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "data": SafeJson(existing_data)},
|
||||
"update": {"data": SafeJson(existing_data)},
|
||||
},
|
||||
)
|
||||
|
||||
understanding = BusinessUnderstanding.from_db(record)
|
||||
|
||||
# Update cache with new understanding
|
||||
await _set_cache(user_id, understanding)
|
||||
|
||||
return understanding
|
||||
|
||||
|
||||
async def clear_business_understanding(user_id: str) -> bool:
|
||||
"""Clear/delete business understanding for a user from both DB and cache."""
|
||||
# Delete from cache first
|
||||
await _delete_cache(user_id)
|
||||
|
||||
try:
|
||||
await CoPilotUnderstanding.prisma().delete(where={"userId": user_id})
|
||||
return True
|
||||
except Exception:
|
||||
# Record might not exist
|
||||
return False
|
||||
|
||||
|
||||
def format_understanding_for_prompt(understanding: BusinessUnderstanding) -> str:
|
||||
"""Format business understanding as text for system prompt injection."""
|
||||
sections = []
|
||||
|
||||
# User info section
|
||||
user_info = []
|
||||
if understanding.user_name:
|
||||
user_info.append(f"Name: {understanding.user_name}")
|
||||
if understanding.job_title:
|
||||
user_info.append(f"Job Title: {understanding.job_title}")
|
||||
if user_info:
|
||||
sections.append("## User\n" + "\n".join(user_info))
|
||||
|
||||
# Business section
|
||||
business_info = []
|
||||
if understanding.business_name:
|
||||
business_info.append(f"Company: {understanding.business_name}")
|
||||
if understanding.industry:
|
||||
business_info.append(f"Industry: {understanding.industry}")
|
||||
if understanding.business_size:
|
||||
business_info.append(f"Size: {understanding.business_size}")
|
||||
if understanding.user_role:
|
||||
business_info.append(f"Role Context: {understanding.user_role}")
|
||||
if business_info:
|
||||
sections.append("## Business\n" + "\n".join(business_info))
|
||||
|
||||
# Processes section
|
||||
processes = []
|
||||
if understanding.key_workflows:
|
||||
processes.append(f"Key Workflows: {', '.join(understanding.key_workflows)}")
|
||||
if understanding.daily_activities:
|
||||
processes.append(
|
||||
f"Daily Activities: {', '.join(understanding.daily_activities)}"
|
||||
)
|
||||
if processes:
|
||||
sections.append("## Processes\n" + "\n".join(processes))
|
||||
|
||||
# Pain points section
|
||||
pain_points = []
|
||||
if understanding.pain_points:
|
||||
pain_points.append(f"Pain Points: {', '.join(understanding.pain_points)}")
|
||||
if understanding.bottlenecks:
|
||||
pain_points.append(f"Bottlenecks: {', '.join(understanding.bottlenecks)}")
|
||||
if understanding.manual_tasks:
|
||||
pain_points.append(f"Manual Tasks: {', '.join(understanding.manual_tasks)}")
|
||||
if pain_points:
|
||||
sections.append("## Pain Points\n" + "\n".join(pain_points))
|
||||
|
||||
# Goals section
|
||||
if understanding.automation_goals:
|
||||
sections.append(
|
||||
"## Automation Goals\n"
|
||||
+ "\n".join(f"- {goal}" for goal in understanding.automation_goals)
|
||||
)
|
||||
|
||||
# Current tools section
|
||||
tools_info = []
|
||||
if understanding.current_software:
|
||||
tools_info.append(
|
||||
f"Current Software: {', '.join(understanding.current_software)}"
|
||||
)
|
||||
if understanding.existing_automation:
|
||||
tools_info.append(
|
||||
f"Existing Automation: {', '.join(understanding.existing_automation)}"
|
||||
)
|
||||
if tools_info:
|
||||
sections.append("## Current Tools\n" + "\n".join(tools_info))
|
||||
|
||||
# Additional notes
|
||||
if understanding.additional_notes:
|
||||
sections.append(f"## Additional Context\n{understanding.additional_notes}")
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
return "# User Business Context\n\n" + "\n\n".join(sections)
|
||||
@@ -7,6 +7,10 @@ from backend.api.features.library.db import (
|
||||
list_library_agents,
|
||||
)
|
||||
from backend.api.features.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.api.features.store.embeddings import (
|
||||
backfill_missing_embeddings,
|
||||
get_embedding_stats,
|
||||
)
|
||||
from backend.data import db
|
||||
from backend.data.analytics import (
|
||||
get_accuracy_trends_and_alerts,
|
||||
@@ -20,6 +24,7 @@ from backend.data.execution import (
|
||||
get_execution_kv_data,
|
||||
get_execution_outputs_by_node_exec_id,
|
||||
get_frequently_executed_graphs,
|
||||
get_graph_execution,
|
||||
get_graph_execution_meta,
|
||||
get_graph_executions,
|
||||
get_graph_executions_count,
|
||||
@@ -57,6 +62,7 @@ from backend.data.notifications import (
|
||||
get_user_notification_oldest_message_in_batch,
|
||||
remove_notifications_from_batch,
|
||||
)
|
||||
from backend.data.onboarding import increment_onboarding_runs
|
||||
from backend.data.user import (
|
||||
get_active_user_ids_in_timerange,
|
||||
get_user_by_id,
|
||||
@@ -140,6 +146,7 @@ class DatabaseManager(AppService):
|
||||
get_child_graph_executions = _(get_child_graph_executions)
|
||||
get_graph_executions = _(get_graph_executions)
|
||||
get_graph_executions_count = _(get_graph_executions_count)
|
||||
get_graph_execution = _(get_graph_execution)
|
||||
get_graph_execution_meta = _(get_graph_execution_meta)
|
||||
create_graph_execution = _(create_graph_execution)
|
||||
get_node_execution = _(get_node_execution)
|
||||
@@ -204,10 +211,17 @@ class DatabaseManager(AppService):
|
||||
add_store_agent_to_library = _(add_store_agent_to_library)
|
||||
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
|
||||
|
||||
# Onboarding
|
||||
increment_onboarding_runs = _(increment_onboarding_runs)
|
||||
|
||||
# Store
|
||||
get_store_agents = _(get_store_agents)
|
||||
get_store_agent_details = _(get_store_agent_details)
|
||||
|
||||
# Store Embeddings
|
||||
get_embedding_stats = _(get_embedding_stats)
|
||||
backfill_missing_embeddings = _(backfill_missing_embeddings)
|
||||
|
||||
# Summary data - async
|
||||
get_user_execution_summary_data = _(get_user_execution_summary_data)
|
||||
|
||||
@@ -259,6 +273,10 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
get_store_agents = _(d.get_store_agents)
|
||||
get_store_agent_details = _(d.get_store_agent_details)
|
||||
|
||||
# Store Embeddings
|
||||
get_embedding_stats = _(d.get_embedding_stats)
|
||||
backfill_missing_embeddings = _(d.backfill_missing_embeddings)
|
||||
|
||||
|
||||
class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
@@ -274,6 +292,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
get_graph = d.get_graph
|
||||
get_graph_metadata = d.get_graph_metadata
|
||||
get_graph_settings = d.get_graph_settings
|
||||
get_graph_execution = d.get_graph_execution
|
||||
get_graph_execution_meta = d.get_graph_execution_meta
|
||||
get_node = d.get_node
|
||||
get_node_execution = d.get_node_execution
|
||||
@@ -318,6 +337,9 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
add_store_agent_to_library = d.add_store_agent_to_library
|
||||
validate_graph_execution_permissions = d.validate_graph_execution_permissions
|
||||
|
||||
# Onboarding
|
||||
increment_onboarding_runs = d.increment_onboarding_runs
|
||||
|
||||
# Store
|
||||
get_store_agents = d.get_store_agents
|
||||
get_store_agent_details = d.get_store_agent_details
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import fastapi.responses
|
||||
import pytest
|
||||
@@ -19,6 +20,17 @@ from backend.util.test import SpinTestServer, wait_execution
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def mock_embedding_functions():
|
||||
"""Mock embedding functions for all tests to avoid database/API dependencies."""
|
||||
with patch(
|
||||
"backend.api.features.store.db.ensure_embedding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph:
|
||||
logger.info(f"Creating graph for user {u.id}")
|
||||
return await s.agent_server.test_create_graph(CreateGraph(graph=g), u.id)
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
@@ -27,7 +28,6 @@ from backend.data.auth.oauth import cleanup_expired_oauth_tokens
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.onboarding import increment_runs
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.monitoring import (
|
||||
NotificationJobArgs,
|
||||
@@ -37,7 +37,7 @@ from backend.monitoring import (
|
||||
report_execution_accuracy_alerts,
|
||||
report_late_executions,
|
||||
)
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.clients import get_database_manager_client, get_scheduler_client
|
||||
from backend.util.cloud_storage import cleanup_expired_files_async
|
||||
from backend.util.exceptions import (
|
||||
GraphNotFoundError,
|
||||
@@ -156,7 +156,6 @@ async def _execute_graph(**kwargs):
|
||||
inputs=args.input_data,
|
||||
graph_credentials_inputs=args.input_credentials,
|
||||
)
|
||||
await increment_runs(args.user_id)
|
||||
elapsed = asyncio.get_event_loop().time() - start_time
|
||||
logger.info(
|
||||
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
|
||||
@@ -254,6 +253,74 @@ def execution_accuracy_alerts():
|
||||
return report_execution_accuracy_alerts()
|
||||
|
||||
|
||||
def ensure_embeddings_coverage():
|
||||
"""
|
||||
Ensure approved store agents have embeddings for hybrid search.
|
||||
|
||||
Processes ALL missing embeddings in batches of 10 until 100% coverage.
|
||||
Missing embeddings = agents invisible in hybrid search.
|
||||
|
||||
Schedule: Runs every 6 hours (balanced between coverage and API costs).
|
||||
- Catches agents approved between scheduled runs
|
||||
- Batch size 10: gradual processing to avoid rate limits
|
||||
- Manual trigger available via execute_ensure_embeddings_coverage endpoint
|
||||
"""
|
||||
db_client = get_database_manager_client()
|
||||
stats = db_client.get_embedding_stats()
|
||||
|
||||
# Check for error from get_embedding_stats() first
|
||||
if "error" in stats:
|
||||
logger.error(
|
||||
f"Failed to get embedding stats: {stats['error']} - skipping backfill"
|
||||
)
|
||||
return {"processed": 0, "success": 0, "failed": 0, "error": stats["error"]}
|
||||
|
||||
if stats["without_embeddings"] == 0:
|
||||
logger.info("All approved agents have embeddings, skipping backfill")
|
||||
return {"processed": 0, "success": 0, "failed": 0}
|
||||
|
||||
logger.info(
|
||||
f"Found {stats['without_embeddings']} agents without embeddings "
|
||||
f"({stats['coverage_percent']}% coverage) - processing all"
|
||||
)
|
||||
|
||||
total_processed = 0
|
||||
total_success = 0
|
||||
total_failed = 0
|
||||
|
||||
# Process in batches until no more missing embeddings
|
||||
while True:
|
||||
result = db_client.backfill_missing_embeddings(batch_size=10)
|
||||
|
||||
total_processed += result["processed"]
|
||||
total_success += result["success"]
|
||||
total_failed += result["failed"]
|
||||
|
||||
if result["processed"] == 0:
|
||||
# No more missing embeddings
|
||||
break
|
||||
|
||||
if result["success"] == 0 and result["processed"] > 0:
|
||||
# All attempts in this batch failed - stop to avoid infinite loop
|
||||
logger.error(
|
||||
f"All {result['processed']} embedding attempts failed - stopping backfill"
|
||||
)
|
||||
break
|
||||
|
||||
# Small delay between batches to avoid rate limits
|
||||
time.sleep(1)
|
||||
|
||||
logger.info(
|
||||
f"Embedding backfill completed: {total_success}/{total_processed} succeeded, "
|
||||
f"{total_failed} failed"
|
||||
)
|
||||
return {
|
||||
"processed": total_processed,
|
||||
"success": total_success,
|
||||
"failed": total_failed,
|
||||
}
|
||||
|
||||
|
||||
# Monitoring functions are now imported from monitoring module
|
||||
|
||||
|
||||
@@ -475,6 +542,19 @@ class Scheduler(AppService):
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
# Embedding Coverage - Every 6 hours
|
||||
# Ensures all approved agents have embeddings for hybrid search
|
||||
# Critical: missing embeddings = agents invisible in search
|
||||
self.scheduler.add_job(
|
||||
ensure_embeddings_coverage,
|
||||
id="ensure_embeddings_coverage",
|
||||
trigger="interval",
|
||||
hours=6,
|
||||
replace_existing=True,
|
||||
max_instances=1, # Prevent overlapping runs
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
||||
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
|
||||
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
|
||||
@@ -632,6 +712,11 @@ class Scheduler(AppService):
|
||||
"""Manually trigger execution accuracy alert checking."""
|
||||
return execution_accuracy_alerts()
|
||||
|
||||
@expose
|
||||
def execute_ensure_embeddings_coverage(self):
|
||||
"""Manually trigger embedding backfill for approved store agents."""
|
||||
return ensure_embeddings_coverage()
|
||||
|
||||
|
||||
class SchedulerClient(AppServiceClient):
|
||||
@classmethod
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import BaseModel, JsonValue, ValidationError
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data import onboarding as onboarding_db
|
||||
from backend.data import user as user_db
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
@@ -31,7 +32,6 @@ from backend.data.execution import (
|
||||
GraphExecutionStats,
|
||||
GraphExecutionWithNodes,
|
||||
NodesInputMasks,
|
||||
get_graph_execution,
|
||||
)
|
||||
from backend.data.graph import GraphModel, Node
|
||||
from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput
|
||||
@@ -809,13 +809,14 @@ async def add_graph_execution(
|
||||
edb = execution_db
|
||||
udb = user_db
|
||||
gdb = graph_db
|
||||
odb = onboarding_db
|
||||
else:
|
||||
edb = udb = gdb = get_database_manager_async_client()
|
||||
edb = udb = gdb = odb = get_database_manager_async_client()
|
||||
|
||||
# Get or create the graph execution
|
||||
if graph_exec_id:
|
||||
# Resume existing execution
|
||||
graph_exec = await get_graph_execution(
|
||||
graph_exec = await edb.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=graph_exec_id,
|
||||
include_node_executions=True,
|
||||
@@ -891,6 +892,7 @@ async def add_graph_execution(
|
||||
)
|
||||
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
|
||||
|
||||
# Publish to execution queue for executor to pick up
|
||||
exec_queue = await get_async_execution_queue()
|
||||
await exec_queue.publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
@@ -899,14 +901,12 @@ async def add_graph_execution(
|
||||
)
|
||||
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
|
||||
|
||||
# Update execution status to QUEUED
|
||||
graph_exec.status = ExecutionStatus.QUEUED
|
||||
await edb.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=graph_exec.status,
|
||||
)
|
||||
await get_async_execution_event_bus().publish(graph_exec)
|
||||
|
||||
return graph_exec
|
||||
except BaseException as e:
|
||||
err = str(e) or type(e).__name__
|
||||
if not graph_exec:
|
||||
@@ -927,6 +927,24 @@ async def add_graph_execution(
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
await get_async_execution_event_bus().publish(graph_exec)
|
||||
logger.info(f"Published update for execution #{graph_exec.id} to event bus")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to publish execution event for graph exec #{graph_exec.id}: {e}"
|
||||
)
|
||||
|
||||
try:
|
||||
await odb.increment_onboarding_runs(user_id)
|
||||
logger.info(
|
||||
f"Incremented user #{user_id} onboarding runs for exec #{graph_exec.id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to increment onboarding runs for user #{user_id}: {e}")
|
||||
|
||||
return graph_exec
|
||||
|
||||
|
||||
# ============ Execution Output Helpers ============ #
|
||||
|
||||
|
||||
@@ -245,6 +245,21 @@ DEFAULT_CREDENTIALS = [
|
||||
webshare_proxy_credentials,
|
||||
]
|
||||
|
||||
SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
|
||||
|
||||
# Set of providers that have system credentials available
|
||||
SYSTEM_PROVIDERS = {cred.provider for cred in DEFAULT_CREDENTIALS}
|
||||
|
||||
|
||||
def is_system_credential(credential_id: str) -> bool:
|
||||
"""Check if a credential ID belongs to a system-managed credential."""
|
||||
return credential_id in SYSTEM_CREDENTIAL_IDS
|
||||
|
||||
|
||||
def is_system_provider(provider: str) -> bool:
|
||||
"""Check if a provider has system-managed credentials available."""
|
||||
return provider in SYSTEM_PROVIDERS
|
||||
|
||||
|
||||
class IntegrationCredentialsStore:
|
||||
def __init__(self):
|
||||
|
||||
@@ -10,6 +10,7 @@ from backend.util.settings import Settings
|
||||
settings = Settings()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI
|
||||
from supabase import AClient, Client
|
||||
|
||||
from backend.data.execution import (
|
||||
@@ -139,6 +140,24 @@ async def get_async_supabase() -> "AClient":
|
||||
)
|
||||
|
||||
|
||||
# ============ OpenAI Client ============ #
|
||||
|
||||
|
||||
@cached(ttl_seconds=3600)
|
||||
def get_openai_client() -> "AsyncOpenAI | None":
|
||||
"""
|
||||
Get a process-cached async OpenAI client for embeddings.
|
||||
|
||||
Returns None if API key is not configured.
|
||||
"""
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
api_key = settings.secrets.openai_internal_api_key
|
||||
if not api_key:
|
||||
return None
|
||||
return AsyncOpenAI(api_key=api_key)
|
||||
|
||||
|
||||
# ============ Notification Queue Helpers ============ #
|
||||
|
||||
|
||||
|
||||
@@ -658,6 +658,14 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
|
||||
ayrshare_api_key: str = Field(default="", description="Ayrshare API Key")
|
||||
ayrshare_jwt_key: str = Field(default="", description="Ayrshare private Key")
|
||||
|
||||
# Langfuse prompt management
|
||||
langfuse_public_key: str = Field(default="", description="Langfuse public key")
|
||||
langfuse_secret_key: str = Field(default="", description="Langfuse secret key")
|
||||
langfuse_host: str = Field(
|
||||
default="https://cloud.langfuse.com", description="Langfuse host URL"
|
||||
)
|
||||
|
||||
# Add more secret fields as needed
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
-- CreateExtension
|
||||
-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
|
||||
-- Create in public schema so vector type is available across all schemas
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "vector" WITH SCHEMA "public";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'vector extension not available or already exists, skipping';
|
||||
END $$;
|
||||
|
||||
-- CreateEnum
|
||||
CREATE TYPE "ContentType" AS ENUM ('STORE_AGENT', 'BLOCK', 'INTEGRATION', 'DOCUMENTATION', 'LIBRARY_AGENT');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "UnifiedContentEmbedding" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"contentType" "ContentType" NOT NULL,
|
||||
"contentId" TEXT NOT NULL,
|
||||
"userId" TEXT,
|
||||
"embedding" public.vector(1536) NOT NULL,
|
||||
"searchableText" TEXT NOT NULL,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
|
||||
CONSTRAINT "UnifiedContentEmbedding_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "UnifiedContentEmbedding_contentType_idx" ON "UnifiedContentEmbedding"("contentType");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "UnifiedContentEmbedding_userId_idx" ON "UnifiedContentEmbedding"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "UnifiedContentEmbedding_contentType_userId_idx" ON "UnifiedContentEmbedding"("contentType", "userId");
|
||||
|
||||
-- CreateIndex
|
||||
-- NULLS NOT DISTINCT ensures only one public (NULL userId) embedding per contentType+contentId
|
||||
-- Requires PostgreSQL 15+. Supabase uses PostgreSQL 15+.
|
||||
CREATE UNIQUE INDEX "UnifiedContentEmbedding_contentType_contentId_userId_key" ON "UnifiedContentEmbedding"("contentType", "contentId", "userId") NULLS NOT DISTINCT;
|
||||
|
||||
-- CreateIndex
|
||||
-- HNSW index for fast vector similarity search on embeddings
|
||||
-- Uses cosine distance operator (<=>), which matches the query in hybrid_search.py
|
||||
CREATE INDEX "UnifiedContentEmbedding_embedding_idx" ON "UnifiedContentEmbedding" USING hnsw ("embedding" public.vector_cosine_ops);
|
||||
@@ -0,0 +1,71 @@
|
||||
-- Acknowledge Supabase-managed extensions to prevent drift warnings
|
||||
-- These extensions are pre-installed by Supabase in specific schemas
|
||||
-- This migration ensures they exist where available (Supabase) or skips gracefully (CI)
|
||||
|
||||
-- Create schemas (safe in both CI and Supabase)
|
||||
CREATE SCHEMA IF NOT EXISTS "extensions";
|
||||
|
||||
-- Extensions that exist in both CI and Supabase
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgcrypto extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'uuid-ossp extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
-- Supabase-specific extensions (skip gracefully in CI)
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_stat_statements" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_stat_statements extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_net" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_net extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pgjwt" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgjwt extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "graphql";
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_graphql" WITH SCHEMA "graphql";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_graphql extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "pgsodium";
|
||||
CREATE EXTENSION IF NOT EXISTS "pgsodium" WITH SCHEMA "pgsodium";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgsodium extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "vault";
|
||||
CREATE EXTENSION IF NOT EXISTS "supabase_vault" WITH SCHEMA "vault";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'supabase_vault extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
|
||||
-- Return to platform
|
||||
CREATE SCHEMA IF NOT EXISTS "platform";
|
||||
@@ -0,0 +1,64 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "CoPilotUnderstanding" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"userId" TEXT NOT NULL,
|
||||
"data" JSONB,
|
||||
|
||||
CONSTRAINT "CoPilotUnderstanding_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "ChatSession" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"userId" TEXT NOT NULL,
|
||||
"title" TEXT,
|
||||
"credentials" JSONB NOT NULL DEFAULT '{}',
|
||||
"successfulAgentRuns" JSONB NOT NULL DEFAULT '{}',
|
||||
"successfulAgentSchedules" JSONB NOT NULL DEFAULT '{}',
|
||||
"totalPromptTokens" INTEGER NOT NULL DEFAULT 0,
|
||||
"totalCompletionTokens" INTEGER NOT NULL DEFAULT 0,
|
||||
|
||||
CONSTRAINT "ChatSession_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "ChatMessage" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"sessionId" TEXT NOT NULL,
|
||||
"role" TEXT NOT NULL,
|
||||
"content" TEXT,
|
||||
"name" TEXT,
|
||||
"toolCallId" TEXT,
|
||||
"refusal" TEXT,
|
||||
"toolCalls" JSONB,
|
||||
"functionCall" JSONB,
|
||||
"sequence" INTEGER NOT NULL,
|
||||
|
||||
CONSTRAINT "ChatMessage_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "CoPilotUnderstanding_userId_key" ON "CoPilotUnderstanding"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "CoPilotUnderstanding_userId_idx" ON "CoPilotUnderstanding"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ChatSession_userId_updatedAt_idx" ON "ChatSession"("userId", "updatedAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "ChatMessage_sessionId_sequence_key" ON "ChatMessage"("sessionId", "sequence");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "CoPilotUnderstanding" ADD CONSTRAINT "CoPilotUnderstanding_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "ChatSession" ADD CONSTRAINT "ChatSession_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "ChatMessage" ADD CONSTRAINT "ChatMessage_sessionId_fkey" FOREIGN KEY ("sessionId") REFERENCES "ChatSession"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
201
autogpt_platform/backend/poetry.lock
generated
201
autogpt_platform/backend/poetry.lock
generated
@@ -2777,6 +2777,30 @@ enabler = ["pytest-enabler (>=2.2)"]
|
||||
test = ["pyfakefs", "pytest (>=6,!=8.1.*)"]
|
||||
type = ["pygobject-stubs", "pytest-mypy", "shtab", "types-pywin32"]
|
||||
|
||||
[[package]]
|
||||
name = "langfuse"
|
||||
version = "3.11.2"
|
||||
description = "A client library for accessing langfuse"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "langfuse-3.11.2-py3-none-any.whl", hash = "sha256:84faea9f909694023cc7f0eb45696be190248c8790424f22af57ca4cd7a29f2d"},
|
||||
{file = "langfuse-3.11.2.tar.gz", hash = "sha256:ab5f296a8056815b7288c7f25bc308a5e79f82a8634467b25daffdde99276e09"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
backoff = ">=1.10.0"
|
||||
httpx = ">=0.15.4,<1.0"
|
||||
openai = ">=0.27.8"
|
||||
opentelemetry-api = ">=1.33.1,<2.0.0"
|
||||
opentelemetry-exporter-otlp-proto-http = ">=1.33.1,<2.0.0"
|
||||
opentelemetry-sdk = ">=1.33.1,<2.0.0"
|
||||
packaging = ">=23.2,<26.0"
|
||||
pydantic = ">=1.10.7,<3.0"
|
||||
requests = ">=2,<3"
|
||||
wrapt = ">=1.14,<2.0"
|
||||
|
||||
[[package]]
|
||||
name = "launchdarkly-eventsource"
|
||||
version = "1.3.0"
|
||||
@@ -3468,6 +3492,90 @@ files = [
|
||||
importlib-metadata = ">=6.0,<8.8.0"
|
||||
typing-extensions = ">=4.5.0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-exporter-otlp-proto-common"
|
||||
version = "1.35.0"
|
||||
description = "OpenTelemetry Protobuf encoding"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "opentelemetry_exporter_otlp_proto_common-1.35.0-py3-none-any.whl", hash = "sha256:863465de697ae81279ede660f3918680b4480ef5f69dcdac04f30722ed7b74cc"},
|
||||
{file = "opentelemetry_exporter_otlp_proto_common-1.35.0.tar.gz", hash = "sha256:6f6d8c39f629b9fa5c79ce19a2829dbd93034f8ac51243cdf40ed2196f00d7eb"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
opentelemetry-proto = "1.35.0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-exporter-otlp-proto-http"
|
||||
version = "1.35.0"
|
||||
description = "OpenTelemetry Collector Protobuf over HTTP Exporter"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "opentelemetry_exporter_otlp_proto_http-1.35.0-py3-none-any.whl", hash = "sha256:9a001e3df3c7f160fb31056a28ed7faa2de7df68877ae909516102ae36a54e1d"},
|
||||
{file = "opentelemetry_exporter_otlp_proto_http-1.35.0.tar.gz", hash = "sha256:cf940147f91b450ef5f66e9980d40eb187582eed399fa851f4a7a45bb880de79"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
googleapis-common-protos = ">=1.52,<2.0"
|
||||
opentelemetry-api = ">=1.15,<2.0"
|
||||
opentelemetry-exporter-otlp-proto-common = "1.35.0"
|
||||
opentelemetry-proto = "1.35.0"
|
||||
opentelemetry-sdk = ">=1.35.0,<1.36.0"
|
||||
requests = ">=2.7,<3.0"
|
||||
typing-extensions = ">=4.5.0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-proto"
|
||||
version = "1.35.0"
|
||||
description = "OpenTelemetry Python Proto"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "opentelemetry_proto-1.35.0-py3-none-any.whl", hash = "sha256:98fffa803164499f562718384e703be8d7dfbe680192279a0429cb150a2f8809"},
|
||||
{file = "opentelemetry_proto-1.35.0.tar.gz", hash = "sha256:532497341bd3e1c074def7c5b00172601b28bb83b48afc41a4b779f26eb4ee05"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
protobuf = ">=5.0,<7.0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-sdk"
|
||||
version = "1.35.0"
|
||||
description = "OpenTelemetry Python SDK"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "opentelemetry_sdk-1.35.0-py3-none-any.whl", hash = "sha256:223d9e5f5678518f4842311bb73966e0b6db5d1e0b74e35074c052cd2487f800"},
|
||||
{file = "opentelemetry_sdk-1.35.0.tar.gz", hash = "sha256:2a400b415ab68aaa6f04e8a6a9f6552908fb3090ae2ff78d6ae0c597ac581954"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
opentelemetry-api = "1.35.0"
|
||||
opentelemetry-semantic-conventions = "0.56b0"
|
||||
typing-extensions = ">=4.5.0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-semantic-conventions"
|
||||
version = "0.56b0"
|
||||
description = "OpenTelemetry Semantic Conventions"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "opentelemetry_semantic_conventions-0.56b0-py3-none-any.whl", hash = "sha256:df44492868fd6b482511cc43a942e7194be64e94945f572db24df2e279a001a2"},
|
||||
{file = "opentelemetry_semantic_conventions-0.56b0.tar.gz", hash = "sha256:c114c2eacc8ff6d3908cb328c811eaf64e6d68623840be9224dc829c4fd6c2ea"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
opentelemetry-api = "1.35.0"
|
||||
typing-extensions = ">=4.5.0"
|
||||
|
||||
[[package]]
|
||||
name = "orjson"
|
||||
version = "3.11.3"
|
||||
@@ -6922,6 +7030,97 @@ files = [
|
||||
{file = "websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wrapt"
|
||||
version = "1.17.3"
|
||||
description = "Module for decorators, wrappers and monkey patching."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "wrapt-1.17.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:88bbae4d40d5a46142e70d58bf664a89b6b4befaea7b2ecc14e03cedb8e06c04"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e6b13af258d6a9ad602d57d889f83b9d5543acd471eee12eb51f5b01f8eb1bc2"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd341868a4b6714a5962c1af0bd44f7c404ef78720c7de4892901e540417111c"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f9b2601381be482f70e5d1051a5965c25fb3625455a2bf520b5a077b22afb775"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:343e44b2a8e60e06a7e0d29c1671a0d9951f59174f3709962b5143f60a2a98bd"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:33486899acd2d7d3066156b03465b949da3fd41a5da6e394ec49d271baefcf05"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e6f40a8aa5a92f150bdb3e1c44b7e98fb7113955b2e5394122fa5532fec4b418"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-win32.whl", hash = "sha256:a36692b8491d30a8c75f1dfee65bef119d6f39ea84ee04d9f9311f83c5ad9390"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-win_amd64.whl", hash = "sha256:afd964fd43b10c12213574db492cb8f73b2f0826c8df07a68288f8f19af2ebe6"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-win_arm64.whl", hash = "sha256:af338aa93554be859173c39c85243970dc6a289fa907402289eeae7543e1ae18"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:273a736c4645e63ac582c60a56b0acb529ef07f78e08dc6bfadf6a46b19c0da7"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5531d911795e3f935a9c23eb1c8c03c211661a5060aab167065896bbf62a5f85"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0610b46293c59a3adbae3dee552b648b984176f8562ee0dba099a56cfbe4df1f"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b32888aad8b6e68f83a8fdccbf3165f5469702a7544472bdf41f582970ed3311"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cccf4f81371f257440c88faed6b74f1053eef90807b77e31ca057b2db74edb1"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8a210b158a34164de8bb68b0e7780041a903d7b00c87e906fb69928bf7890d5"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:79573c24a46ce11aab457b472efd8d125e5a51da2d1d24387666cd85f54c05b2"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-win32.whl", hash = "sha256:c31eebe420a9a5d2887b13000b043ff6ca27c452a9a22fa71f35f118e8d4bf89"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-win_amd64.whl", hash = "sha256:0b1831115c97f0663cb77aa27d381237e73ad4f721391a9bfb2fe8bc25fa6e77"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-win_arm64.whl", hash = "sha256:5a7b3c1ee8265eb4c8f1b7d29943f195c00673f5ab60c192eba2d4a7eae5f46a"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ab232e7fdb44cdfbf55fc3afa31bcdb0d8980b9b95c38b6405df2acb672af0e0"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9baa544e6acc91130e926e8c802a17f3b16fbea0fd441b5a60f5cf2cc5c3deba"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6b538e31eca1a7ea4605e44f81a48aa24c4632a277431a6ed3f328835901f4fd"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:042ec3bb8f319c147b1301f2393bc19dba6e176b7da446853406d041c36c7828"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3af60380ba0b7b5aeb329bc4e402acd25bd877e98b3727b0135cb5c2efdaefe9"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0b02e424deef65c9f7326d8c19220a2c9040c51dc165cddb732f16198c168396"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:74afa28374a3c3a11b3b5e5fca0ae03bef8450d6aa3ab3a1e2c30e3a75d023dc"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-win32.whl", hash = "sha256:4da9f45279fff3543c371d5ababc57a0384f70be244de7759c85a7f989cb4ebe"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-win_amd64.whl", hash = "sha256:e71d5c6ebac14875668a1e90baf2ea0ef5b7ac7918355850c0908ae82bcb297c"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-win_arm64.whl", hash = "sha256:604d076c55e2fdd4c1c03d06dc1a31b95130010517b5019db15365ec4a405fc6"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a47681378a0439215912ef542c45a783484d4dd82bac412b71e59cf9c0e1cea0"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:54a30837587c6ee3cd1a4d1c2ec5d24e77984d44e2f34547e2323ddb4e22eb77"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:16ecf15d6af39246fe33e507105d67e4b81d8f8d2c6598ff7e3ca1b8a37213f7"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6fd1ad24dc235e4ab88cda009e19bf347aabb975e44fd5c2fb22a3f6e4141277"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ed61b7c2d49cee3c027372df5809a59d60cf1b6c2f81ee980a091f3afed6a2d"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:423ed5420ad5f5529db9ce89eac09c8a2f97da18eb1c870237e84c5a5c2d60aa"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e01375f275f010fcbf7f643b4279896d04e571889b8a5b3f848423d91bf07050"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-win32.whl", hash = "sha256:53e5e39ff71b3fc484df8a522c933ea2b7cdd0d5d15ae82e5b23fde87d44cbd8"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-win_amd64.whl", hash = "sha256:1f0b2f40cf341ee8cc1a97d51ff50dddb9fcc73241b9143ec74b30fc4f44f6cb"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-win_arm64.whl", hash = "sha256:7425ac3c54430f5fc5e7b6f41d41e704db073309acfc09305816bc6a0b26bb16"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:cf30f6e3c077c8e6a9a7809c94551203c8843e74ba0c960f4a98cd80d4665d39"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:e228514a06843cae89621384cfe3a80418f3c04aadf8a3b14e46a7be704e4235"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:5ea5eb3c0c071862997d6f3e02af1d055f381b1d25b286b9d6644b79db77657c"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:281262213373b6d5e4bb4353bc36d1ba4084e6d6b5d242863721ef2bf2c2930b"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dc4a8d2b25efb6681ecacad42fca8859f88092d8732b170de6a5dddd80a1c8fa"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:373342dd05b1d07d752cecbec0c41817231f29f3a89aa8b8843f7b95992ed0c7"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d40770d7c0fd5cbed9d84b2c3f2e156431a12c9a37dc6284060fb4bec0b7ffd4"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-win32.whl", hash = "sha256:fbd3c8319de8e1dc79d346929cd71d523622da527cca14e0c1d257e31c2b8b10"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-win_amd64.whl", hash = "sha256:e1a4120ae5705f673727d3253de3ed0e016f7cd78dc463db1b31e2463e1f3cf6"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-win_arm64.whl", hash = "sha256:507553480670cab08a800b9463bdb881b2edeed77dc677b0a5915e6106e91a58"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:ed7c635ae45cfbc1a7371f708727bf74690daedc49b4dba310590ca0bd28aa8a"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:249f88ed15503f6492a71f01442abddd73856a0032ae860de6d75ca62eed8067"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5a03a38adec8066d5a37bea22f2ba6bbf39fcdefbe2d91419ab864c3fb515454"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:5d4478d72eb61c36e5b446e375bbc49ed002430d17cdec3cecb36993398e1a9e"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:223db574bb38637e8230eb14b185565023ab624474df94d2af18f1cdb625216f"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:e405adefb53a435f01efa7ccdec012c016b5a1d3f35459990afc39b6be4d5056"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:88547535b787a6c9ce4086917b6e1d291aa8ed914fdd3a838b3539dc95c12804"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-win32.whl", hash = "sha256:41b1d2bc74c2cac6f9074df52b2efbef2b30bdfe5f40cb78f8ca22963bc62977"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-win_amd64.whl", hash = "sha256:73d496de46cd2cdbdbcce4ae4bcdb4afb6a11234a1df9c085249d55166b95116"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-win_arm64.whl", hash = "sha256:f38e60678850c42461d4202739f9bf1e3a737c7ad283638251e79cc49effb6b6"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:70d86fa5197b8947a2fa70260b48e400bf2ccacdcab97bb7de47e3d1e6312225"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:df7d30371a2accfe4013e90445f6388c570f103d61019b6b7c57e0265250072a"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:caea3e9c79d5f0d2c6d9ab96111601797ea5da8e6d0723f77eabb0d4068d2b2f"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:758895b01d546812d1f42204bd443b8c433c44d090248bf22689df673ccafe00"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:02b551d101f31694fc785e58e0720ef7d9a10c4e62c1c9358ce6f63f23e30a56"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:656873859b3b50eeebe6db8b1455e99d90c26ab058db8e427046dbc35c3140a5"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:a9a2203361a6e6404f80b99234fe7fb37d1fc73487b5a78dc1aa5b97201e0f22"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-win32.whl", hash = "sha256:55cbbc356c2842f39bcc553cf695932e8b30e30e797f961860afb308e6b1bb7c"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-win_amd64.whl", hash = "sha256:ad85e269fe54d506b240d2d7b9f5f2057c2aa9a2ea5b32c66f8902f768117ed2"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:30ce38e66630599e1193798285706903110d4f057aab3168a34b7fdc85569afc"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:65d1d00fbfb3ea5f20add88bbc0f815150dbbde3b026e6c24759466c8b5a9ef9"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a7c06742645f914f26c7f1fa47b8bc4c91d222f76ee20116c43d5ef0912bba2d"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7e18f01b0c3e4a07fe6dfdb00e29049ba17eadbc5e7609a2a3a4af83ab7d710a"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0f5f51a6466667a5a356e6381d362d259125b57f059103dd9fdc8c0cf1d14139"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:59923aa12d0157f6b82d686c3fd8e1166fa8cdfb3e17b42ce3b6147ff81528df"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:46acc57b331e0b3bcb3e1ca3b421d65637915cfcd65eb783cb2f78a511193f9b"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-win32.whl", hash = "sha256:3e62d15d3cfa26e3d0788094de7b64efa75f3a53875cdbccdf78547aed547a81"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-win_amd64.whl", hash = "sha256:1f23fa283f51c890eda8e34e4937079114c74b4c81d2b2f1f1d94948f5cc3d7f"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-win_arm64.whl", hash = "sha256:24c2ed34dc222ed754247a2702b1e1e89fdbaa4016f324b4b8f1a802d4ffe87f"},
|
||||
{file = "wrapt-1.17.3-py3-none-any.whl", hash = "sha256:7171ae35d2c33d326ac19dd8facb1e82e5fd04ef8c6c0e394d7af55a55051c22"},
|
||||
{file = "wrapt-1.17.3.tar.gz", hash = "sha256:f66eb08feaa410fe4eebd17f2a2c8e2e46d3476e9f8c783daa8e09e0faa666d0"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xattr"
|
||||
version = "1.2.0"
|
||||
@@ -7295,4 +7494,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "a93ba0cea3b465cb6ec3e3f258b383b09f84ea352ccfdbfa112902cde5653fc6"
|
||||
content-hash = "86838b5ae40d606d6e01a14dad8a56c389d890d7a6a0c274a6602cca80f0df84"
|
||||
|
||||
@@ -33,6 +33,7 @@ html2text = "^2024.2.26"
|
||||
jinja2 = "^3.1.6"
|
||||
jsonref = "^1.1.0"
|
||||
jsonschema = "^4.25.0"
|
||||
langfuse = "^3.11.0"
|
||||
launchdarkly-server-sdk = "^9.12.0"
|
||||
mem0ai = "^0.1.115"
|
||||
moviepy = "^2.1.2"
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
datasource db {
|
||||
provider = "postgresql"
|
||||
url = env("DATABASE_URL")
|
||||
directUrl = env("DIRECT_URL")
|
||||
provider = "postgresql"
|
||||
url = env("DATABASE_URL")
|
||||
directUrl = env("DIRECT_URL")
|
||||
extensions = [pgvector(map: "vector")]
|
||||
}
|
||||
|
||||
generator client {
|
||||
provider = "prisma-client-py"
|
||||
recursive_type_depth = -1
|
||||
interface = "asyncio"
|
||||
previewFeatures = ["views", "fullTextSearch"]
|
||||
previewFeatures = ["views", "fullTextSearch", "postgresqlExtensions"]
|
||||
partial_type_generator = "backend/data/partial_types.py"
|
||||
}
|
||||
|
||||
@@ -47,12 +48,13 @@ model User {
|
||||
AnalyticsMetrics AnalyticsMetrics[]
|
||||
CreditTransactions CreditTransaction[]
|
||||
UserBalance UserBalance?
|
||||
|
||||
AgentPresets AgentPreset[]
|
||||
LibraryAgents LibraryAgent[]
|
||||
ChatSessions ChatSession[]
|
||||
AgentPresets AgentPreset[]
|
||||
LibraryAgents LibraryAgent[]
|
||||
|
||||
Profile Profile[]
|
||||
UserOnboarding UserOnboarding?
|
||||
CoPilotUnderstanding CoPilotUnderstanding?
|
||||
BuilderSearchHistory BuilderSearchHistory[]
|
||||
StoreListings StoreListing[]
|
||||
StoreListingReviews StoreListingReview[]
|
||||
@@ -121,19 +123,84 @@ model UserOnboarding {
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
}
|
||||
|
||||
model CoPilotUnderstanding {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
userId String @unique
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
data Json?
|
||||
|
||||
@@index([userId])
|
||||
}
|
||||
|
||||
model BuilderSearchHistory {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
searchQuery String
|
||||
filter String[] @default([])
|
||||
byCreator String[] @default([])
|
||||
filter String[] @default([])
|
||||
byCreator String[] @default([])
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
//////////////// CHAT SESSION TABLES ///////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
model ChatSession {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
// Session metadata
|
||||
title String?
|
||||
credentials Json @default("{}") // Map of provider -> credential metadata
|
||||
|
||||
// Rate limiting counters (stored as JSON maps)
|
||||
successfulAgentRuns Json @default("{}") // Map of graph_id -> count
|
||||
successfulAgentSchedules Json @default("{}") // Map of graph_id -> count
|
||||
|
||||
// Usage tracking
|
||||
totalPromptTokens Int @default(0)
|
||||
totalCompletionTokens Int @default(0)
|
||||
|
||||
Messages ChatMessage[]
|
||||
|
||||
@@index([userId, updatedAt])
|
||||
}
|
||||
|
||||
model ChatMessage {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
sessionId String
|
||||
Session ChatSession @relation(fields: [sessionId], references: [id], onDelete: Cascade)
|
||||
|
||||
// Message content
|
||||
role String // "user", "assistant", "system", "tool", "function"
|
||||
content String?
|
||||
name String?
|
||||
toolCallId String?
|
||||
refusal String?
|
||||
toolCalls Json? // List of tool calls for assistant messages
|
||||
functionCall Json? // Deprecated but kept for compatibility
|
||||
|
||||
// Ordering within session
|
||||
sequence Int
|
||||
|
||||
@@unique([sessionId, sequence])
|
||||
}
|
||||
|
||||
// This model describes the Agent Graph/Flow (Multi Agent System).
|
||||
model AgentGraph {
|
||||
id String @default(uuid())
|
||||
@@ -721,26 +788,25 @@ view StoreAgent {
|
||||
storeListingVersionId String
|
||||
updated_at DateTime
|
||||
|
||||
slug String
|
||||
agent_name String
|
||||
agent_video String?
|
||||
agent_output_demo String?
|
||||
agent_image String[]
|
||||
slug String
|
||||
agent_name String
|
||||
agent_video String?
|
||||
agent_output_demo String?
|
||||
agent_image String[]
|
||||
|
||||
featured Boolean @default(false)
|
||||
creator_username String?
|
||||
creator_avatar String?
|
||||
sub_heading String
|
||||
description String
|
||||
categories String[]
|
||||
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
|
||||
runs Int
|
||||
rating Float
|
||||
versions String[]
|
||||
agentGraphVersions String[]
|
||||
agentGraphId String
|
||||
is_available Boolean @default(true)
|
||||
useForOnboarding Boolean @default(false)
|
||||
featured Boolean @default(false)
|
||||
creator_username String?
|
||||
creator_avatar String?
|
||||
sub_heading String
|
||||
description String
|
||||
categories String[]
|
||||
runs Int
|
||||
rating Float
|
||||
versions String[]
|
||||
agentGraphVersions String[]
|
||||
agentGraphId String
|
||||
is_available Boolean @default(true)
|
||||
useForOnboarding Boolean @default(false)
|
||||
|
||||
// Materialized views used (refreshed every 15 minutes via pg_cron):
|
||||
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
|
||||
@@ -856,14 +922,14 @@ model StoreListingVersion {
|
||||
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
|
||||
|
||||
// Content fields
|
||||
name String
|
||||
subHeading String
|
||||
videoUrl String?
|
||||
agentOutputDemoUrl String?
|
||||
imageUrls String[]
|
||||
description String
|
||||
instructions String?
|
||||
categories String[]
|
||||
name String
|
||||
subHeading String
|
||||
videoUrl String?
|
||||
agentOutputDemoUrl String?
|
||||
imageUrls String[]
|
||||
description String
|
||||
instructions String?
|
||||
categories String[]
|
||||
|
||||
isFeatured Boolean @default(false)
|
||||
|
||||
@@ -899,6 +965,9 @@ model StoreListingVersion {
|
||||
// Reviews for this specific version
|
||||
Reviews StoreListingReview[]
|
||||
|
||||
// Note: Embeddings now stored in UnifiedContentEmbedding table
|
||||
// Use contentType=STORE_AGENT and contentId=storeListingVersionId
|
||||
|
||||
@@unique([storeListingId, version])
|
||||
@@index([storeListingId, submissionStatus, isAvailable])
|
||||
@@index([submissionStatus])
|
||||
@@ -906,6 +975,42 @@ model StoreListingVersion {
|
||||
@@index([agentGraphId, agentGraphVersion]) // Non-unique index for efficient lookups
|
||||
}
|
||||
|
||||
// Content type enum for unified search across store agents, blocks, docs
|
||||
// Note: BLOCK/INTEGRATION are file-based (Python classes), not DB records
|
||||
// DOCUMENTATION are file-based (.md files), not DB records
|
||||
// Only STORE_AGENT and LIBRARY_AGENT are stored in database
|
||||
enum ContentType {
|
||||
STORE_AGENT // Database: StoreListingVersion
|
||||
BLOCK // File-based: Python classes in /backend/blocks/
|
||||
INTEGRATION // File-based: Python classes (blocks with credentials)
|
||||
DOCUMENTATION // File-based: .md/.mdx files
|
||||
LIBRARY_AGENT // Database: User's personal agents
|
||||
}
|
||||
|
||||
// Unified embeddings table for all searchable content types
|
||||
// Supports both public content (userId=null) and user-specific content (userId=userID)
|
||||
model UnifiedContentEmbedding {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
// Content identification
|
||||
contentType ContentType
|
||||
contentId String // DB ID (storeListingVersionId) or file identifier (block.id, file_path)
|
||||
userId String? // NULL for public content (store, blocks, docs), userId for private content (library agents)
|
||||
|
||||
// Search data
|
||||
embedding Unsupported("vector(1536)") // pgvector embedding (extension in platform schema)
|
||||
searchableText String // Combined text for search and fallback
|
||||
metadata Json @default("{}") // Content-specific metadata
|
||||
|
||||
@@unique([contentType, contentId, userId], map: "UnifiedContentEmbedding_contentType_contentId_userId_key")
|
||||
@@index([contentType])
|
||||
@@index([userId])
|
||||
@@index([contentType, userId])
|
||||
@@index([embedding], map: "UnifiedContentEmbedding_embedding_idx")
|
||||
}
|
||||
|
||||
model StoreListingReview {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
@@ -998,16 +1103,16 @@ model OAuthApplication {
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
// Application metadata
|
||||
name String
|
||||
description String?
|
||||
logoUrl String? // URL to app logo stored in GCS
|
||||
clientId String @unique
|
||||
clientSecret String // Hashed with Scrypt (same as API keys)
|
||||
clientSecretSalt String // Salt for Scrypt hashing
|
||||
name String
|
||||
description String?
|
||||
logoUrl String? // URL to app logo stored in GCS
|
||||
clientId String @unique
|
||||
clientSecret String // Hashed with Scrypt (same as API keys)
|
||||
clientSecretSalt String // Salt for Scrypt hashing
|
||||
|
||||
// OAuth configuration
|
||||
redirectUris String[] // Allowed callback URLs
|
||||
grantTypes String[] @default(["authorization_code", "refresh_token"])
|
||||
grantTypes String[] @default(["authorization_code", "refresh_token"])
|
||||
scopes APIKeyPermission[] // Which permissions the app can request
|
||||
|
||||
// Application management
|
||||
|
||||
@@ -708,10 +708,7 @@ export function CreateButton() {
|
||||
|
||||
## 🧪 Testing & Storybook
|
||||
|
||||
- End-to-end: [Playwright](https://playwright.dev/docs/intro) (`pnpm test`, `pnpm test-ui`)
|
||||
- [Storybook](https://storybook.js.org/docs) for isolated UI development (`pnpm storybook` / `pnpm build-storybook`)
|
||||
- For Storybook tests in CI, see [`@storybook/test-runner`](https://storybook.js.org/docs/writing-tests/test-runner) (`test-storybook:ci`)
|
||||
- When changing components in `src/components`, update or add stories and visually verify in Storybook/Chromatic
|
||||
- See `TESTING.md` for Playwright setup, E2E data seeding, and Storybook usage.
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ This is the frontend for AutoGPT's next generation
|
||||
This project uses [**pnpm**](https://pnpm.io/) as the package manager via **corepack**. [Corepack](https://github.com/nodejs/corepack) is a Node.js tool that automatically manages package managers without requiring global installations.
|
||||
|
||||
For architecture, conventions, data fetching, feature flags, design system usage, state management, and PR process, see [CONTRIBUTING.md](./CONTRIBUTING.md).
|
||||
For Playwright and Storybook testing setup, see [TESTING.md](./TESTING.md).
|
||||
|
||||
### Prerequisites
|
||||
|
||||
|
||||
57
autogpt_platform/frontend/TESTING.md
Normal file
57
autogpt_platform/frontend/TESTING.md
Normal file
@@ -0,0 +1,57 @@
|
||||
# Frontend Testing 🧪
|
||||
|
||||
## Quick Start (local) 🚀
|
||||
|
||||
1. Start the backend + Supabase stack:
|
||||
- From `autogpt_platform`: `docker compose --profile local up deps_backend -d`
|
||||
- Or run the full stack: `docker compose up -d`
|
||||
2. Seed rich E2E data (creates `test123@gmail.com` with library agents):
|
||||
- From `autogpt_platform/backend`: `poetry run python test/e2e_test_data.py`
|
||||
3. Run Playwright:
|
||||
- From `autogpt_platform/frontend`: `pnpm test` or `pnpm test-ui`
|
||||
|
||||
## How Playwright setup works 🎭
|
||||
|
||||
- Playwright runs from `frontend/playwright.config.ts` with a global setup step.
|
||||
- The global setup creates a user pool via the real signup UI and stores it in `frontend/.auth/user-pool.json`.
|
||||
- Most tests call `getTestUser()` (from `src/tests/utils/auth.ts`) which pulls a random user from that pool.
|
||||
- these users do not contain library agents, it's user that just "signed up" on the platform, hence some tests to make use of users created via script (see below) with more data
|
||||
|
||||
## Test users 👤
|
||||
|
||||
- **User pool (basic users)**
|
||||
Created automatically by the Playwright global setup through `/signup`.
|
||||
Used by `getTestUser()` in `src/tests/utils/auth.ts`.
|
||||
|
||||
- **Rich user with library agents**
|
||||
Created by `backend/test/e2e_test_data.py`.
|
||||
Accessed via `getTestUserWithLibraryAgents()` in `src/tests/credentials/index.ts`.
|
||||
|
||||
Use the rich user when a test needs existing library agents (e.g. `library.spec.ts`).
|
||||
|
||||
## Resetting or wiping the DB 🔁
|
||||
|
||||
If you reset the Docker DB and logins start failing:
|
||||
|
||||
1. Delete `frontend/.auth/user-pool.json` so the pool is regenerated.
|
||||
2. Re-run the E2E data script to recreate the rich user + library agents:
|
||||
- `poetry run python test/e2e_test_data.py`
|
||||
|
||||
## Storybook 📚
|
||||
|
||||
## Flow diagram 🗺️
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Start Docker stack] --> B[Run e2e_test_data.py]
|
||||
B --> C[Run Playwright tests]
|
||||
C --> D[Global setup creates user pool]
|
||||
D --> E{Test needs rich data?}
|
||||
E -->|No| F[getTestUser from user pool]
|
||||
E -->|Yes| G[getTestUserWithLibraryAgents]
|
||||
```
|
||||
|
||||
- `pnpm storybook` – Run Storybook locally
|
||||
- `pnpm build-storybook` – Build a static Storybook
|
||||
- CI runner: `pnpm test-storybook`
|
||||
- When changing components in `src/components`, update or add stories and verify in Storybook/Chromatic.
|
||||
@@ -3,6 +3,13 @@ import { withSentryConfig } from "@sentry/nextjs";
|
||||
/** @type {import('next').NextConfig} */
|
||||
const nextConfig = {
|
||||
productionBrowserSourceMaps: true,
|
||||
// Externalize OpenTelemetry packages to fix Turbopack HMR issues
|
||||
serverExternalPackages: [
|
||||
"@opentelemetry/instrumentation",
|
||||
"@opentelemetry/sdk-node",
|
||||
"import-in-the-middle",
|
||||
"require-in-the-middle",
|
||||
],
|
||||
experimental: {
|
||||
serverActions: {
|
||||
bodySizeLimit: "256mb",
|
||||
|
||||
@@ -32,6 +32,7 @@
|
||||
"@hookform/resolvers": "5.2.2",
|
||||
"@next/third-parties": "15.4.6",
|
||||
"@phosphor-icons/react": "2.1.10",
|
||||
"@radix-ui/react-accordion": "1.2.12",
|
||||
"@radix-ui/react-alert-dialog": "1.1.15",
|
||||
"@radix-ui/react-avatar": "1.1.10",
|
||||
"@radix-ui/react-checkbox": "1.3.3",
|
||||
@@ -117,6 +118,7 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@chromatic-com/storybook": "4.1.2",
|
||||
"@opentelemetry/instrumentation": "0.209.0",
|
||||
"@playwright/test": "1.56.1",
|
||||
"@storybook/addon-a11y": "9.1.5",
|
||||
"@storybook/addon-docs": "9.1.5",
|
||||
@@ -140,6 +142,7 @@
|
||||
"eslint": "8.57.1",
|
||||
"eslint-config-next": "15.5.7",
|
||||
"eslint-plugin-storybook": "9.1.5",
|
||||
"import-in-the-middle": "2.0.2",
|
||||
"msw": "2.11.6",
|
||||
"msw-storybook-addon": "2.0.6",
|
||||
"orval": "7.13.0",
|
||||
@@ -147,7 +150,7 @@
|
||||
"postcss": "8.5.6",
|
||||
"prettier": "3.6.2",
|
||||
"prettier-plugin-tailwindcss": "0.7.1",
|
||||
"require-in-the-middle": "7.5.2",
|
||||
"require-in-the-middle": "8.0.1",
|
||||
"storybook": "9.1.5",
|
||||
"tailwindcss": "3.4.17",
|
||||
"typescript": "5.9.3"
|
||||
@@ -157,5 +160,10 @@
|
||||
"public"
|
||||
]
|
||||
},
|
||||
"pnpm": {
|
||||
"overrides": {
|
||||
"@opentelemetry/instrumentation": "0.209.0"
|
||||
}
|
||||
},
|
||||
"packageManager": "pnpm@10.20.0+sha512.cf9998222162dd85864d0a8102e7892e7ba4ceadebbf5a31f9c2fce48dfce317a9c53b9f6464d1ef9042cba2e02ae02a9f7c143a2b438cd93c91840f0192b9dd"
|
||||
}
|
||||
|
||||
140
autogpt_platform/frontend/pnpm-lock.yaml
generated
140
autogpt_platform/frontend/pnpm-lock.yaml
generated
@@ -4,6 +4,9 @@ settings:
|
||||
autoInstallPeers: true
|
||||
excludeLinksFromLockfile: false
|
||||
|
||||
overrides:
|
||||
'@opentelemetry/instrumentation': 0.209.0
|
||||
|
||||
importers:
|
||||
|
||||
.:
|
||||
@@ -20,6 +23,9 @@ importers:
|
||||
'@phosphor-icons/react':
|
||||
specifier: 2.1.10
|
||||
version: 2.1.10(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
'@radix-ui/react-accordion':
|
||||
specifier: 1.2.12
|
||||
version: 1.2.12(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
'@radix-ui/react-alert-dialog':
|
||||
specifier: 1.1.15
|
||||
version: 1.1.15(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
@@ -270,6 +276,9 @@ importers:
|
||||
'@chromatic-com/storybook':
|
||||
specifier: 4.1.2
|
||||
version: 4.1.2(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2))
|
||||
'@opentelemetry/instrumentation':
|
||||
specifier: 0.209.0
|
||||
version: 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@playwright/test':
|
||||
specifier: 1.56.1
|
||||
version: 1.56.1
|
||||
@@ -339,6 +348,9 @@ importers:
|
||||
eslint-plugin-storybook:
|
||||
specifier: 9.1.5
|
||||
version: 9.1.5(eslint@8.57.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2))(typescript@5.9.3)
|
||||
import-in-the-middle:
|
||||
specifier: 2.0.2
|
||||
version: 2.0.2
|
||||
msw:
|
||||
specifier: 2.11.6
|
||||
version: 2.11.6(@types/node@24.10.0)(typescript@5.9.3)
|
||||
@@ -361,8 +373,8 @@ importers:
|
||||
specifier: 0.7.1
|
||||
version: 0.7.1(prettier@3.6.2)
|
||||
require-in-the-middle:
|
||||
specifier: 7.5.2
|
||||
version: 7.5.2
|
||||
specifier: 8.0.1
|
||||
version: 8.0.1
|
||||
storybook:
|
||||
specifier: 9.1.5
|
||||
version: 9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)
|
||||
@@ -1543,8 +1555,8 @@ packages:
|
||||
'@open-draft/until@2.1.0':
|
||||
resolution: {integrity: sha512-U69T3ItWHvLwGg5eJ0n3I62nWuE6ilHlmz7zM0npLBRvPRd7e6NYmg54vvRtP5mZG7kZqZCFVdsTWo7BPtBujg==}
|
||||
|
||||
'@opentelemetry/api-logs@0.208.0':
|
||||
resolution: {integrity: sha512-CjruKY9V6NMssL/T1kAFgzosF1v9o6oeN+aX5JB/C/xPNtmgIJqcXHG7fA82Ou1zCpWGl4lROQUKwUNE1pMCyg==}
|
||||
'@opentelemetry/api-logs@0.209.0':
|
||||
resolution: {integrity: sha512-xomnUNi7TiAGtOgs0tb54LyrjRZLu9shJGGwkcN7NgtiPYOpNnKLkRJtzZvTjD/w6knSZH9sFZcUSUovYOPg6A==}
|
||||
engines: {node: '>=8.0.0'}
|
||||
|
||||
'@opentelemetry/api@1.9.0':
|
||||
@@ -1695,8 +1707,8 @@ packages:
|
||||
peerDependencies:
|
||||
'@opentelemetry/api': ^1.7.0
|
||||
|
||||
'@opentelemetry/instrumentation@0.208.0':
|
||||
resolution: {integrity: sha512-Eju0L4qWcQS+oXxi6pgh7zvE2byogAkcsVv0OjHF/97iOz1N/aKE6etSGowYkie+YA1uo6DNwdSxaaNnLvcRlA==}
|
||||
'@opentelemetry/instrumentation@0.209.0':
|
||||
resolution: {integrity: sha512-Cwe863ojTCnFlxVuuhG7s6ODkAOzKsAEthKAcI4MDRYz1OmGWYnmSl4X2pbyS+hBxVTdvfZePfoEA01IjqcEyw==}
|
||||
engines: {node: ^18.19.0 || >=20.6.0}
|
||||
peerDependencies:
|
||||
'@opentelemetry/api': ^1.3.0
|
||||
@@ -1810,6 +1822,19 @@ packages:
|
||||
'@radix-ui/primitive@1.1.3':
|
||||
resolution: {integrity: sha512-JTF99U/6XIjCBo0wqkU5sK10glYe27MRRsfwoiq5zzOEZLHU3A3KCMa5X/azekYRCJ0HlwI0crAXS/5dEHTzDg==}
|
||||
|
||||
'@radix-ui/react-accordion@1.2.12':
|
||||
resolution: {integrity: sha512-T4nygeh9YE9dLRPhAHSeOZi7HBXo+0kYIPJXayZfvWOWA0+n3dESrZbjfDPUABkUNym6Hd+f2IR113To8D2GPA==}
|
||||
peerDependencies:
|
||||
'@types/react': '*'
|
||||
'@types/react-dom': '*'
|
||||
react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc
|
||||
react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
'@types/react-dom':
|
||||
optional: true
|
||||
|
||||
'@radix-ui/react-alert-dialog@1.1.15':
|
||||
resolution: {integrity: sha512-oTVLkEw5GpdRe29BqJ0LSDFWI3qu0vR1M0mUkOQWDIUnY/QIkLpgDMWuKxP94c2NAC2LGcgVhG1ImF3jkZ5wXw==}
|
||||
peerDependencies:
|
||||
@@ -2631,7 +2656,7 @@ packages:
|
||||
'@opentelemetry/api': ^1.9.0
|
||||
'@opentelemetry/context-async-hooks': ^1.30.1 || ^2.1.0 || ^2.2.0
|
||||
'@opentelemetry/core': ^1.30.1 || ^2.1.0 || ^2.2.0
|
||||
'@opentelemetry/instrumentation': '>=0.57.1 <1'
|
||||
'@opentelemetry/instrumentation': 0.209.0
|
||||
'@opentelemetry/resources': ^1.30.1 || ^2.1.0 || ^2.2.0
|
||||
'@opentelemetry/sdk-trace-base': ^1.30.1 || ^2.1.0 || ^2.2.0
|
||||
'@opentelemetry/semantic-conventions': ^1.37.0
|
||||
@@ -4957,8 +4982,8 @@ packages:
|
||||
resolution: {integrity: sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ==}
|
||||
engines: {node: '>=6'}
|
||||
|
||||
import-in-the-middle@2.0.1:
|
||||
resolution: {integrity: sha512-bruMpJ7xz+9jwGzrwEhWgvRrlKRYCRDBrfU+ur3FcasYXLJDxTruJ//8g2Noj+QFyRBeqbpj8Bhn4Fbw6HjvhA==}
|
||||
import-in-the-middle@2.0.2:
|
||||
resolution: {integrity: sha512-qet/hkGt3EbNGVtbDfPu0BM+tCqBS8wT1SYrstPaDKoWtshsC6licOemz7DVtpBEyvDNzo8UTBf9/GwWuSDZ9w==}
|
||||
|
||||
imurmurhash@0.1.4:
|
||||
resolution: {integrity: sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==}
|
||||
@@ -6502,10 +6527,6 @@ packages:
|
||||
resolution: {integrity: sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==}
|
||||
engines: {node: '>=0.10.0'}
|
||||
|
||||
require-in-the-middle@7.5.2:
|
||||
resolution: {integrity: sha512-gAZ+kLqBdHarXB64XpAe2VCjB7rIRv+mU8tfRWziHRJ5umKsIHN2tLLv6EtMw7WCdP19S0ERVMldNvxYCHnhSQ==}
|
||||
engines: {node: '>=8.6.0'}
|
||||
|
||||
require-in-the-middle@8.0.1:
|
||||
resolution: {integrity: sha512-QT7FVMXfWOYFbeRBF6nu+I6tr2Tf3u0q8RIEjNob/heKY/nh7drD/k7eeMFmSQgnTtCzLDcCu/XEnpW2wk4xCQ==}
|
||||
engines: {node: '>=9.3.0 || >=8.10.0 <9.0.0'}
|
||||
@@ -8716,7 +8737,7 @@ snapshots:
|
||||
|
||||
'@open-draft/until@2.1.0': {}
|
||||
|
||||
'@opentelemetry/api-logs@0.208.0':
|
||||
'@opentelemetry/api-logs@0.209.0':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
|
||||
@@ -8735,7 +8756,7 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
@@ -8743,7 +8764,7 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
'@types/connect': 3.4.38
|
||||
transitivePeerDependencies:
|
||||
@@ -8752,7 +8773,7 @@ snapshots:
|
||||
'@opentelemetry/instrumentation-dataloader@0.26.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
@@ -8760,7 +8781,7 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -8769,21 +8790,21 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
'@opentelemetry/instrumentation-generic-pool@0.52.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
'@opentelemetry/instrumentation-graphql@0.56.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
@@ -8791,7 +8812,7 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -8800,7 +8821,7 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
forwarded-parse: 2.1.2
|
||||
transitivePeerDependencies:
|
||||
@@ -8809,7 +8830,7 @@ snapshots:
|
||||
'@opentelemetry/instrumentation-ioredis@0.56.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/redis-common': 0.38.2
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -8817,7 +8838,7 @@ snapshots:
|
||||
'@opentelemetry/instrumentation-kafkajs@0.18.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -8825,7 +8846,7 @@ snapshots:
|
||||
'@opentelemetry/instrumentation-knex@0.53.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -8834,7 +8855,7 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -8842,14 +8863,14 @@ snapshots:
|
||||
'@opentelemetry/instrumentation-lru-memoizer@0.53.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
'@opentelemetry/instrumentation-mongodb@0.61.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
@@ -8857,14 +8878,14 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
'@opentelemetry/instrumentation-mysql2@0.55.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
'@opentelemetry/sql-common': 0.41.2(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
@@ -8873,7 +8894,7 @@ snapshots:
|
||||
'@opentelemetry/instrumentation-mysql@0.54.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@types/mysql': 2.15.27
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -8882,7 +8903,7 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
'@opentelemetry/sql-common': 0.41.2(@opentelemetry/api@1.9.0)
|
||||
'@types/pg': 8.15.6
|
||||
@@ -8893,7 +8914,7 @@ snapshots:
|
||||
'@opentelemetry/instrumentation-redis@0.57.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/redis-common': 0.38.2
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
transitivePeerDependencies:
|
||||
@@ -8902,7 +8923,7 @@ snapshots:
|
||||
'@opentelemetry/instrumentation-tedious@0.27.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@types/tedious': 4.0.14
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -8911,16 +8932,16 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
'@opentelemetry/instrumentation@0.208.0(@opentelemetry/api@1.9.0)':
|
||||
'@opentelemetry/instrumentation@0.209.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/api-logs': 0.208.0
|
||||
import-in-the-middle: 2.0.1
|
||||
'@opentelemetry/api-logs': 0.209.0
|
||||
import-in-the-middle: 2.0.2
|
||||
require-in-the-middle: 8.0.1
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -9100,7 +9121,7 @@ snapshots:
|
||||
'@prisma/instrumentation@6.19.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
@@ -9108,6 +9129,23 @@ snapshots:
|
||||
|
||||
'@radix-ui/primitive@1.1.3': {}
|
||||
|
||||
'@radix-ui/react-accordion@1.2.12(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
|
||||
dependencies:
|
||||
'@radix-ui/primitive': 1.1.3
|
||||
'@radix-ui/react-collapsible': 1.1.12(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
'@radix-ui/react-collection': 1.1.7(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
'@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.17)(react@18.3.1)
|
||||
'@radix-ui/react-context': 1.1.2(@types/react@18.3.17)(react@18.3.1)
|
||||
'@radix-ui/react-direction': 1.1.1(@types/react@18.3.17)(react@18.3.1)
|
||||
'@radix-ui/react-id': 1.1.1(@types/react@18.3.17)(react@18.3.1)
|
||||
'@radix-ui/react-primitive': 2.1.3(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
'@radix-ui/react-use-controllable-state': 1.2.2(@types/react@18.3.17)(react@18.3.1)
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
optionalDependencies:
|
||||
'@types/react': 18.3.17
|
||||
'@types/react-dom': 18.3.5(@types/react@18.3.17)
|
||||
|
||||
'@radix-ui/react-alert-dialog@1.1.15(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
|
||||
dependencies:
|
||||
'@radix-ui/primitive': 1.1.3
|
||||
@@ -9932,19 +9970,19 @@ snapshots:
|
||||
- supports-color
|
||||
- webpack
|
||||
|
||||
'@sentry/node-core@10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/instrumentation@0.208.0(@opentelemetry/api@1.9.0))(@opentelemetry/resources@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)':
|
||||
'@sentry/node-core@10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/instrumentation@0.209.0(@opentelemetry/api@1.9.0))(@opentelemetry/resources@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)':
|
||||
dependencies:
|
||||
'@apm-js-collab/tracing-hooks': 0.3.1
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/context-async-hooks': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/resources': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/sdk-trace-base': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
'@sentry/core': 10.27.0
|
||||
'@sentry/opentelemetry': 10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)
|
||||
import-in-the-middle: 2.0.1
|
||||
import-in-the-middle: 2.0.2
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
@@ -9953,7 +9991,7 @@ snapshots:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/context-async-hooks': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation-amqplib': 0.55.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation-connect': 0.52.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation-dataloader': 0.26.0(@opentelemetry/api@1.9.0)
|
||||
@@ -9981,9 +10019,9 @@ snapshots:
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
'@prisma/instrumentation': 6.19.0(@opentelemetry/api@1.9.0)
|
||||
'@sentry/core': 10.27.0
|
||||
'@sentry/node-core': 10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/instrumentation@0.208.0(@opentelemetry/api@1.9.0))(@opentelemetry/resources@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)
|
||||
'@sentry/node-core': 10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/instrumentation@0.209.0(@opentelemetry/api@1.9.0))(@opentelemetry/resources@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)
|
||||
'@sentry/opentelemetry': 10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)
|
||||
import-in-the-middle: 2.0.1
|
||||
import-in-the-middle: 2.0.2
|
||||
minimatch: 9.0.5
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -12792,7 +12830,7 @@ snapshots:
|
||||
parent-module: 1.0.1
|
||||
resolve-from: 4.0.0
|
||||
|
||||
import-in-the-middle@2.0.1:
|
||||
import-in-the-middle@2.0.2:
|
||||
dependencies:
|
||||
acorn: 8.15.0
|
||||
acorn-import-attributes: 1.9.5(acorn@8.15.0)
|
||||
@@ -14631,14 +14669,6 @@ snapshots:
|
||||
|
||||
require-from-string@2.0.2: {}
|
||||
|
||||
require-in-the-middle@7.5.2:
|
||||
dependencies:
|
||||
debug: 4.4.3
|
||||
module-details-from-path: 1.0.4
|
||||
resolve: 1.22.11
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
require-in-the-middle@8.0.1:
|
||||
dependencies:
|
||||
debug: 4.4.3
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs";
|
||||
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput";
|
||||
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
|
||||
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
|
||||
import { useState } from "react";
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
"use client";
|
||||
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { useState, useMemo, useRef } from "react";
|
||||
import { AuthCard } from "@/components/auth/AuthCard";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput";
|
||||
import { useGetOauthGetOauthAppInfo } from "@/app/api/__generated__/endpoints/oauth/oauth";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { AuthCard } from "@/components/auth/AuthCard";
|
||||
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs";
|
||||
import type {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsMetaInput,
|
||||
CredentialsType,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
import { CheckIcon, CircleIcon } from "@phosphor-icons/react";
|
||||
import { useGetOauthGetOauthAppInfo } from "@/app/api/__generated__/endpoints/oauth/oauth";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { useMemo, useRef, useState } from "react";
|
||||
|
||||
// All credential types - we accept any type of credential
|
||||
const ALL_CREDENTIAL_TYPES: CredentialsType[] = [
|
||||
|
||||
@@ -10,7 +10,10 @@ export const BuilderActions = memo(() => {
|
||||
flowID: parseAsString,
|
||||
});
|
||||
return (
|
||||
<div className="absolute bottom-4 left-[50%] z-[100] flex -translate-x-1/2 items-center gap-4 rounded-full bg-white p-2 px-2 shadow-lg">
|
||||
<div
|
||||
data-id="builder-actions"
|
||||
className="absolute bottom-4 left-[50%] z-[100] flex -translate-x-1/2 items-center gap-4 rounded-full bg-white p-2 px-2 shadow-lg"
|
||||
>
|
||||
<AgentOutputs flowID={flowID} />
|
||||
<RunGraph flowID={flowID} />
|
||||
<ScheduleGraph flowID={flowID} />
|
||||
|
||||
@@ -79,6 +79,7 @@ export const AgentOutputs = ({ flowID }: { flowID: string | null }) => {
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
data-id="agent-outputs-button"
|
||||
disabled={!flowID || !hasOutputs()}
|
||||
>
|
||||
<BookOpenIcon className="size-4" />
|
||||
|
||||
@@ -31,6 +31,7 @@ export const RunGraph = ({ flowID }: { flowID: string | null }) => {
|
||||
<Button
|
||||
size="icon"
|
||||
variant={isGraphRunning ? "destructive" : "primary"}
|
||||
data-id={isGraphRunning ? "stop-graph-button" : "run-graph-button"}
|
||||
onClick={isGraphRunning ? handleStopGraph : handleRunGraph}
|
||||
disabled={!flowID || isExecutingGraph || isTerminatingGraph}
|
||||
loading={isExecutingGraph || isTerminatingGraph || isSaving}
|
||||
|
||||
@@ -7,10 +7,11 @@ import { parseAsInteger, parseAsString, useQueryStates } from "nuqs";
|
||||
import { GraphExecutionMeta } from "@/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/use-agent-runs";
|
||||
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
|
||||
import { useShallow } from "zustand/react/shallow";
|
||||
import { useState } from "react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useSaveGraph } from "@/app/(platform)/build/hooks/useSaveGraph";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { ApiError } from "@/lib/autogpt-server-api/helpers"; // Check if this exists
|
||||
import { useTutorialStore } from "@/app/(platform)/build/stores/tutorialStore";
|
||||
|
||||
export const useRunGraph = () => {
|
||||
const { saveGraph, isSaving } = useSaveGraph({
|
||||
@@ -33,6 +34,29 @@ export const useRunGraph = () => {
|
||||
useShallow((state) => state.clearAllNodeErrors),
|
||||
);
|
||||
|
||||
// Tutorial integration - force open dialog when tutorial requests it
|
||||
const forceOpenRunInputDialog = useTutorialStore(
|
||||
(state) => state.forceOpenRunInputDialog,
|
||||
);
|
||||
const setForceOpenRunInputDialog = useTutorialStore(
|
||||
(state) => state.setForceOpenRunInputDialog,
|
||||
);
|
||||
|
||||
// Sync tutorial state with dialog state
|
||||
useEffect(() => {
|
||||
if (forceOpenRunInputDialog && !openRunInputDialog) {
|
||||
setOpenRunInputDialog(true);
|
||||
}
|
||||
}, [forceOpenRunInputDialog, openRunInputDialog]);
|
||||
|
||||
// Reset tutorial state when dialog closes
|
||||
const handleSetOpenRunInputDialog = (isOpen: boolean) => {
|
||||
setOpenRunInputDialog(isOpen);
|
||||
if (!isOpen && forceOpenRunInputDialog) {
|
||||
setForceOpenRunInputDialog(false);
|
||||
}
|
||||
};
|
||||
|
||||
const [{ flowID, flowVersion, flowExecutionID }, setQueryStates] =
|
||||
useQueryStates({
|
||||
flowID: parseAsString,
|
||||
@@ -138,6 +162,6 @@ export const useRunGraph = () => {
|
||||
isExecutingGraph,
|
||||
isTerminatingGraph,
|
||||
openRunInputDialog,
|
||||
setOpenRunInputDialog,
|
||||
setOpenRunInputDialog: handleSetOpenRunInputDialog,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -8,6 +8,8 @@ import { Text } from "@/components/atoms/Text/Text";
|
||||
import { FormRenderer } from "@/components/renderers/InputRenderer/FormRenderer";
|
||||
import { useRunInputDialog } from "./useRunInputDialog";
|
||||
import { CronSchedulerDialog } from "../CronSchedulerDialog/CronSchedulerDialog";
|
||||
import { useTutorialStore } from "@/app/(platform)/build/stores/tutorialStore";
|
||||
import { useEffect } from "react";
|
||||
|
||||
export const RunInputDialog = ({
|
||||
isOpen,
|
||||
@@ -37,6 +39,21 @@ export const RunInputDialog = ({
|
||||
isExecutingGraph,
|
||||
} = useRunInputDialog({ setIsOpen });
|
||||
|
||||
// Tutorial integration - track input values for the tutorial
|
||||
const setTutorialInputValues = useTutorialStore(
|
||||
(state) => state.setTutorialInputValues,
|
||||
);
|
||||
const isTutorialRunning = useTutorialStore(
|
||||
(state) => state.isTutorialRunning,
|
||||
);
|
||||
|
||||
// Update tutorial store when input values change
|
||||
useEffect(() => {
|
||||
if (isTutorialRunning) {
|
||||
setTutorialInputValues(inputValues);
|
||||
}
|
||||
}, [inputValues, isTutorialRunning, setTutorialInputValues]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Dialog
|
||||
@@ -48,16 +65,16 @@ export const RunInputDialog = ({
|
||||
styling={{ maxWidth: "600px", minWidth: "600px" }}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="space-y-6 p-1">
|
||||
<div className="space-y-6 p-1" data-id="run-input-dialog-content">
|
||||
{/* Credentials Section */}
|
||||
{hasCredentials() && (
|
||||
<div>
|
||||
<div data-id="run-input-credentials-section">
|
||||
<div className="mb-4">
|
||||
<Text variant="h4" className="text-gray-900">
|
||||
Credentials
|
||||
</Text>
|
||||
</div>
|
||||
<div className="px-2">
|
||||
<div className="px-2" data-id="run-input-credentials-form">
|
||||
<FormRenderer
|
||||
jsonSchema={credentialsSchema as RJSFSchema}
|
||||
handleChange={(v) => handleCredentialChange(v.formData)}
|
||||
@@ -75,13 +92,13 @@ export const RunInputDialog = ({
|
||||
|
||||
{/* Inputs Section */}
|
||||
{hasInputs() && (
|
||||
<div>
|
||||
<div data-id="run-input-inputs-section">
|
||||
<div className="mb-4">
|
||||
<Text variant="h4" className="text-gray-900">
|
||||
Inputs
|
||||
</Text>
|
||||
</div>
|
||||
<div className="px-2">
|
||||
<div data-id="run-input-inputs-form">
|
||||
<FormRenderer
|
||||
jsonSchema={inputSchema as RJSFSchema}
|
||||
handleChange={(v) => handleInputChange(v.formData)}
|
||||
@@ -97,7 +114,10 @@ export const RunInputDialog = ({
|
||||
)}
|
||||
|
||||
{/* Action Button */}
|
||||
<div className="flex justify-end pt-2">
|
||||
<div
|
||||
className="flex justify-end pt-2"
|
||||
data-id="run-input-actions-section"
|
||||
>
|
||||
{purpose === "run" && (
|
||||
<Button
|
||||
variant="primary"
|
||||
@@ -105,6 +125,7 @@ export const RunInputDialog = ({
|
||||
className="group h-fit min-w-0 gap-2"
|
||||
onClick={handleManualRun}
|
||||
loading={isExecutingGraph}
|
||||
data-id="run-input-manual-run-button"
|
||||
>
|
||||
{!isExecutingGraph && (
|
||||
<PlayIcon className="size-5 transition-transform group-hover:scale-110" />
|
||||
@@ -118,6 +139,7 @@ export const RunInputDialog = ({
|
||||
size="large"
|
||||
className="group h-fit min-w-0 gap-2"
|
||||
onClick={() => setOpenCronSchedulerDialog(true)}
|
||||
data-id="run-input-schedule-button"
|
||||
>
|
||||
<ClockIcon className="size-5 transition-transform group-hover:scale-110" />
|
||||
<span className="font-semibold">Schedule Run</span>
|
||||
|
||||
@@ -26,6 +26,7 @@ export const ScheduleGraph = ({ flowID }: { flowID: string | null }) => {
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
data-id="schedule-graph-button"
|
||||
onClick={handleScheduleGraph}
|
||||
disabled={!flowID}
|
||||
>
|
||||
|
||||
@@ -6,12 +6,17 @@ import {
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import {
|
||||
ChalkboardIcon,
|
||||
CircleNotchIcon,
|
||||
FrameCornersIcon,
|
||||
MinusIcon,
|
||||
PlusIcon,
|
||||
} from "@phosphor-icons/react/dist/ssr";
|
||||
import { LockIcon, LockOpenIcon } from "lucide-react";
|
||||
import { memo } from "react";
|
||||
import { memo, useEffect, useState } from "react";
|
||||
import { useSearchParams, useRouter } from "next/navigation";
|
||||
import { useTutorialStore } from "@/app/(platform)/build/stores/tutorialStore";
|
||||
import { startTutorial, setTutorialLoadingCallback } from "../../tutorial";
|
||||
|
||||
export const CustomControls = memo(
|
||||
({
|
||||
@@ -22,27 +27,65 @@ export const CustomControls = memo(
|
||||
setIsLocked: (isLocked: boolean) => void;
|
||||
}) => {
|
||||
const { zoomIn, zoomOut, fitView } = useReactFlow();
|
||||
const { isTutorialRunning, setIsTutorialRunning } = useTutorialStore();
|
||||
const [isTutorialLoading, setIsTutorialLoading] = useState(false);
|
||||
const searchParams = useSearchParams();
|
||||
const router = useRouter();
|
||||
|
||||
useEffect(() => {
|
||||
setTutorialLoadingCallback(setIsTutorialLoading);
|
||||
return () => setTutorialLoadingCallback(() => {});
|
||||
}, []);
|
||||
|
||||
const handleTutorialClick = () => {
|
||||
if (isTutorialLoading) return;
|
||||
|
||||
const flowId = searchParams.get("flowID");
|
||||
if (flowId) {
|
||||
router.push("/build?view=new");
|
||||
return;
|
||||
}
|
||||
|
||||
startTutorial();
|
||||
setIsTutorialRunning(true);
|
||||
};
|
||||
|
||||
const controls = [
|
||||
{
|
||||
id: "zoom-in-button",
|
||||
icon: <PlusIcon className="size-4" />,
|
||||
label: "Zoom In",
|
||||
onClick: () => zoomIn(),
|
||||
className: "h-10 w-10 border-none",
|
||||
},
|
||||
{
|
||||
id: "zoom-out-button",
|
||||
icon: <MinusIcon className="size-4" />,
|
||||
label: "Zoom Out",
|
||||
onClick: () => zoomOut(),
|
||||
className: "h-10 w-10 border-none",
|
||||
},
|
||||
{
|
||||
id: "tutorial-button",
|
||||
icon: isTutorialLoading ? (
|
||||
<CircleNotchIcon className="size-4 animate-spin" />
|
||||
) : (
|
||||
<ChalkboardIcon className="size-4" />
|
||||
),
|
||||
label: isTutorialLoading ? "Loading Tutorial..." : "Start Tutorial",
|
||||
onClick: handleTutorialClick,
|
||||
className: `h-10 w-10 border-none ${isTutorialRunning || isTutorialLoading ? "bg-zinc-100" : "bg-white"}`,
|
||||
disabled: isTutorialLoading,
|
||||
},
|
||||
{
|
||||
id: "fit-view-button",
|
||||
icon: <FrameCornersIcon className="size-4" />,
|
||||
label: "Fit View",
|
||||
onClick: () => fitView({ padding: 0.2, duration: 800, maxZoom: 1 }),
|
||||
className: "h-10 w-10 border-none",
|
||||
},
|
||||
{
|
||||
id: "lock-button",
|
||||
icon: !isLocked ? (
|
||||
<LockOpenIcon className="size-4" />
|
||||
) : (
|
||||
@@ -55,15 +98,20 @@ export const CustomControls = memo(
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="absolute bottom-4 left-4 z-10 flex flex-col items-center gap-2 rounded-full bg-white px-1 py-2 shadow-lg">
|
||||
{controls.map((control, index) => (
|
||||
<Tooltip key={index} delayDuration={300}>
|
||||
<div
|
||||
data-id="custom-controls"
|
||||
className="absolute bottom-4 left-4 z-10 flex flex-col items-center gap-2 rounded-full bg-white px-1 py-2 shadow-lg"
|
||||
>
|
||||
{controls.map((control) => (
|
||||
<Tooltip key={control.id} delayDuration={0}>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant="icon"
|
||||
size={"small"}
|
||||
onClick={control.onClick}
|
||||
className={control.className}
|
||||
data-id={control.id}
|
||||
disabled={"disabled" in control ? control.disabled : false}
|
||||
>
|
||||
{control.icon}
|
||||
<span className="sr-only">{control.label}</span>
|
||||
|
||||
@@ -3,6 +3,7 @@ import { useGetV2GetSpecificBlocks } from "@/app/api/__generated__/endpoints/def
|
||||
import {
|
||||
useGetV1GetExecutionDetails,
|
||||
useGetV1GetSpecificGraph,
|
||||
useGetV1ListUserGraphs,
|
||||
} from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
|
||||
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
|
||||
@@ -17,6 +18,7 @@ import { useReactFlow } from "@xyflow/react";
|
||||
import { useControlPanelStore } from "../../../stores/controlPanelStore";
|
||||
import { useHistoryStore } from "../../../stores/historyStore";
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
|
||||
export const useFlow = () => {
|
||||
const [isLocked, setIsLocked] = useState(false);
|
||||
@@ -36,6 +38,9 @@ export const useFlow = () => {
|
||||
const setGraphExecutionStatus = useGraphStore(
|
||||
useShallow((state) => state.setGraphExecutionStatus),
|
||||
);
|
||||
const setAvailableSubGraphs = useGraphStore(
|
||||
useShallow((state) => state.setAvailableSubGraphs),
|
||||
);
|
||||
const updateEdgeBeads = useEdgeStore(
|
||||
useShallow((state) => state.updateEdgeBeads),
|
||||
);
|
||||
@@ -62,6 +67,11 @@ export const useFlow = () => {
|
||||
},
|
||||
);
|
||||
|
||||
// Fetch all available graphs for sub-agent update detection
|
||||
const { data: availableGraphs } = useGetV1ListUserGraphs({
|
||||
query: { select: okData },
|
||||
});
|
||||
|
||||
const { data: graph, isLoading: isGraphLoading } = useGetV1GetSpecificGraph(
|
||||
flowID ?? "",
|
||||
flowVersion !== null ? { version: flowVersion } : {},
|
||||
@@ -116,10 +126,18 @@ export const useFlow = () => {
|
||||
}
|
||||
}, [graph]);
|
||||
|
||||
// Update available sub-graphs in store for sub-agent update detection
|
||||
useEffect(() => {
|
||||
if (availableGraphs) {
|
||||
setAvailableSubGraphs(availableGraphs);
|
||||
}
|
||||
}, [availableGraphs, setAvailableSubGraphs]);
|
||||
|
||||
// adding nodes
|
||||
useEffect(() => {
|
||||
if (customNodes.length > 0) {
|
||||
useNodeStore.getState().setNodes([]);
|
||||
useNodeStore.getState().clearResolutionState();
|
||||
addNodes(customNodes);
|
||||
|
||||
// Sync hardcoded values with handle IDs.
|
||||
@@ -203,6 +221,7 @@ export const useFlow = () => {
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
useNodeStore.getState().setNodes([]);
|
||||
useNodeStore.getState().clearResolutionState();
|
||||
useEdgeStore.getState().setEdges([]);
|
||||
useGraphStore.getState().reset();
|
||||
useEdgeStore.getState().resetEdgeBeads();
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
getBezierPath,
|
||||
} from "@xyflow/react";
|
||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { XIcon } from "@phosphor-icons/react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { NodeExecutionResult } from "@/lib/autogpt-server-api";
|
||||
@@ -35,6 +36,8 @@ const CustomEdge = ({
|
||||
selected,
|
||||
}: EdgeProps<CustomEdge>) => {
|
||||
const removeConnection = useEdgeStore((state) => state.removeEdge);
|
||||
// Subscribe to the brokenEdgeIDs map and check if this edge is broken across any node
|
||||
const isBroken = useNodeStore((state) => state.isEdgeBroken(id));
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
|
||||
const [edgePath, labelX, labelY] = getBezierPath({
|
||||
@@ -50,6 +53,12 @@ const CustomEdge = ({
|
||||
const beadUp = data?.beadUp ?? 0;
|
||||
const beadDown = data?.beadDown ?? 0;
|
||||
|
||||
const handleRemoveEdge = () => {
|
||||
removeConnection(id);
|
||||
// Note: broken edge tracking is cleaned up automatically by useSubAgentUpdateState
|
||||
// when it detects the edge no longer exists
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<BaseEdge
|
||||
@@ -57,9 +66,11 @@ const CustomEdge = ({
|
||||
markerEnd={markerEnd}
|
||||
className={cn(
|
||||
isStatic && "!stroke-[1.5px] [stroke-dasharray:6]",
|
||||
selected
|
||||
? "stroke-zinc-800"
|
||||
: "stroke-zinc-500/50 hover:stroke-zinc-500",
|
||||
isBroken
|
||||
? "!stroke-red-500 !stroke-[2px] [stroke-dasharray:4]"
|
||||
: selected
|
||||
? "stroke-zinc-800"
|
||||
: "stroke-zinc-500/50 hover:stroke-zinc-500",
|
||||
)}
|
||||
/>
|
||||
<JSBeads
|
||||
@@ -70,12 +81,16 @@ const CustomEdge = ({
|
||||
/>
|
||||
<EdgeLabelRenderer>
|
||||
<Button
|
||||
onClick={() => removeConnection(id)}
|
||||
onClick={handleRemoveEdge}
|
||||
className={cn(
|
||||
"absolute h-fit min-w-0 p-1 transition-opacity",
|
||||
isHovered ? "opacity-100" : "opacity-0",
|
||||
isBroken
|
||||
? "bg-red-500 opacity-100 hover:bg-red-600"
|
||||
: isHovered
|
||||
? "opacity-100"
|
||||
: "opacity-0",
|
||||
)}
|
||||
variant="secondary"
|
||||
variant={isBroken ? "primary" : "secondary"}
|
||||
style={{
|
||||
transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
|
||||
pointerEvents: "all",
|
||||
|
||||
@@ -3,6 +3,7 @@ import { Handle, Position } from "@xyflow/react";
|
||||
import { useEdgeStore } from "../../../stores/edgeStore";
|
||||
import { cleanUpHandleId } from "@/components/renderers/InputRenderer/helpers";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useNodeStore } from "../../../stores/nodeStore";
|
||||
|
||||
const InputNodeHandle = ({
|
||||
handleId,
|
||||
@@ -15,6 +16,9 @@ const InputNodeHandle = ({
|
||||
const isInputConnected = useEdgeStore((state) =>
|
||||
state.isInputConnected(nodeId ?? "", cleanedHandleId),
|
||||
);
|
||||
const isInputBroken = useNodeStore((state) =>
|
||||
state.isInputBroken(nodeId, cleanedHandleId),
|
||||
);
|
||||
|
||||
return (
|
||||
<Handle
|
||||
@@ -22,12 +26,16 @@ const InputNodeHandle = ({
|
||||
position={Position.Left}
|
||||
id={cleanedHandleId}
|
||||
className={"-ml-6 mr-2"}
|
||||
data-tutorial-id={`input-handler-${nodeId}-${cleanedHandleId}`}
|
||||
>
|
||||
<div className="pointer-events-none">
|
||||
<CircleIcon
|
||||
size={16}
|
||||
weight={isInputConnected ? "fill" : "duotone"}
|
||||
className={"text-gray-400 opacity-100"}
|
||||
className={cn(
|
||||
"text-gray-400 opacity-100",
|
||||
isInputBroken && "text-red-500",
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
</Handle>
|
||||
@@ -38,27 +46,34 @@ const OutputNodeHandle = ({
|
||||
field_name,
|
||||
nodeId,
|
||||
hexColor,
|
||||
isBroken,
|
||||
}: {
|
||||
field_name: string;
|
||||
nodeId: string;
|
||||
hexColor: string;
|
||||
isBroken: boolean;
|
||||
}) => {
|
||||
const isOutputConnected = useEdgeStore((state) =>
|
||||
state.isOutputConnected(nodeId, field_name),
|
||||
);
|
||||
|
||||
return (
|
||||
<Handle
|
||||
type={"source"}
|
||||
position={Position.Right}
|
||||
id={field_name}
|
||||
className={"-mr-2 ml-2"}
|
||||
data-tutorial-id={`output-handler-${nodeId}-${field_name}`}
|
||||
>
|
||||
<div className="pointer-events-none">
|
||||
<CircleIcon
|
||||
size={16}
|
||||
weight={"duotone"}
|
||||
color={isOutputConnected ? hexColor : "gray"}
|
||||
className={cn("text-gray-400 opacity-100")}
|
||||
className={cn(
|
||||
"text-gray-400 opacity-100",
|
||||
isBroken && "text-red-500",
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
</Handle>
|
||||
|
||||
@@ -20,6 +20,8 @@ import { NodeDataRenderer } from "./components/NodeOutput/NodeOutput";
|
||||
import { NodeRightClickMenu } from "./components/NodeRightClickMenu";
|
||||
import { StickyNoteBlock } from "./components/StickyNoteBlock";
|
||||
import { WebhookDisclaimer } from "./components/WebhookDisclaimer";
|
||||
import { SubAgentUpdateFeature } from "./components/SubAgentUpdate/SubAgentUpdateFeature";
|
||||
import { useCustomNode } from "./useCustomNode";
|
||||
|
||||
export type CustomNodeData = {
|
||||
hardcodedValues: {
|
||||
@@ -45,6 +47,10 @@ export type CustomNode = XYNode<CustomNodeData, "custom">;
|
||||
|
||||
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
({ data, id: nodeId, selected }) => {
|
||||
const { inputSchema, outputSchema } = useCustomNode({ data, nodeId });
|
||||
|
||||
const isAgent = data.uiType === BlockUIType.AGENT;
|
||||
|
||||
if (data.uiType === BlockUIType.NOTE) {
|
||||
return (
|
||||
<StickyNoteBlock data={data} selected={selected} nodeId={nodeId} />
|
||||
@@ -63,16 +69,6 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
|
||||
const isAyrshare = data.uiType === BlockUIType.AYRSHARE;
|
||||
|
||||
const inputSchema =
|
||||
data.uiType === BlockUIType.AGENT
|
||||
? (data.hardcodedValues.input_schema ?? {})
|
||||
: data.inputSchema;
|
||||
|
||||
const outputSchema =
|
||||
data.uiType === BlockUIType.AGENT
|
||||
? (data.hardcodedValues.output_schema ?? {})
|
||||
: data.outputSchema;
|
||||
|
||||
const hasConfigErrors =
|
||||
data.errors &&
|
||||
Object.values(data.errors).some(
|
||||
@@ -87,12 +83,11 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
|
||||
const hasErrors = hasConfigErrors || hasOutputError;
|
||||
|
||||
// Currently all blockTypes design are similar - that's why i am using the same component for all of them
|
||||
// If in future - if we need some drastic change in some blockTypes design - we can create separate components for them
|
||||
const node = (
|
||||
<NodeContainer selected={selected} nodeId={nodeId} hasErrors={hasErrors}>
|
||||
<div className="rounded-xlarge bg-white">
|
||||
<NodeHeader data={data} nodeId={nodeId} />
|
||||
{isAgent && <SubAgentUpdateFeature nodeID={nodeId} nodeData={data} />}
|
||||
{isWebhook && <WebhookDisclaimer nodeId={nodeId} />}
|
||||
{isAyrshare && <AyrshareConnectButton />}
|
||||
<FormCreator
|
||||
|
||||
@@ -27,6 +27,7 @@ export const NodeContainer = ({
|
||||
status && nodeStyleBasedOnStatus[status],
|
||||
hasErrors ? nodeStyleBasedOnStatus[AgentExecutionStatus.FAILED] : "",
|
||||
)}
|
||||
data-id={`custom-node-${nodeId}`}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
|
||||
@@ -23,7 +23,10 @@ export const NodeDataRenderer = ({ nodeId }: { nodeId: string }) => {
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-3 rounded-b-xl border-t border-zinc-200 px-4 py-4">
|
||||
<div
|
||||
data-tutorial-id={`node-output`}
|
||||
className="flex flex-col gap-3 rounded-b-xl border-t border-zinc-200 px-4 py-4"
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<Text variant="body-medium" className="!font-semibold text-slate-700">
|
||||
Node Output
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
import React from "react";
|
||||
import { ArrowUpIcon, WarningIcon } from "@phosphor-icons/react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { cn, beautifyString } from "@/lib/utils";
|
||||
import { CustomNodeData } from "../../CustomNode";
|
||||
import { useSubAgentUpdateState } from "./useSubAgentUpdateState";
|
||||
import { IncompatibleUpdateDialog } from "./components/IncompatibleUpdateDialog";
|
||||
import { ResolutionModeBar } from "./components/ResolutionModeBar";
|
||||
|
||||
/**
|
||||
* Inline component for the update bar that can be placed after the header.
|
||||
* Use this inside the node content where you want the bar to appear.
|
||||
*/
|
||||
type SubAgentUpdateFeatureProps = {
|
||||
nodeID: string;
|
||||
nodeData: CustomNodeData;
|
||||
};
|
||||
|
||||
export function SubAgentUpdateFeature({
|
||||
nodeID,
|
||||
nodeData,
|
||||
}: SubAgentUpdateFeatureProps) {
|
||||
const {
|
||||
updateInfo,
|
||||
isInResolutionMode,
|
||||
handleUpdateClick,
|
||||
showIncompatibilityDialog,
|
||||
setShowIncompatibilityDialog,
|
||||
handleConfirmIncompatibleUpdate,
|
||||
} = useSubAgentUpdateState({ nodeID: nodeID, nodeData: nodeData });
|
||||
|
||||
const agentName = nodeData.title || "Agent";
|
||||
|
||||
if (!updateInfo.hasUpdate && !isInResolutionMode) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
{isInResolutionMode ? (
|
||||
<ResolutionModeBar incompatibilities={updateInfo.incompatibilities} />
|
||||
) : (
|
||||
<SubAgentUpdateAvailableBar
|
||||
currentVersion={updateInfo.currentVersion}
|
||||
latestVersion={updateInfo.latestVersion}
|
||||
isCompatible={updateInfo.isCompatible}
|
||||
onUpdate={handleUpdateClick}
|
||||
/>
|
||||
)}
|
||||
{/* Incompatibility dialog - rendered here since this component owns the state */}
|
||||
{updateInfo.incompatibilities && (
|
||||
<IncompatibleUpdateDialog
|
||||
isOpen={showIncompatibilityDialog}
|
||||
onClose={() => setShowIncompatibilityDialog(false)}
|
||||
onConfirm={handleConfirmIncompatibleUpdate}
|
||||
currentVersion={updateInfo.currentVersion}
|
||||
latestVersion={updateInfo.latestVersion}
|
||||
agentName={beautifyString(agentName)}
|
||||
incompatibilities={updateInfo.incompatibilities}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
type SubAgentUpdateAvailableBarProps = {
|
||||
currentVersion: number;
|
||||
latestVersion: number;
|
||||
isCompatible: boolean;
|
||||
onUpdate: () => void;
|
||||
};
|
||||
|
||||
function SubAgentUpdateAvailableBar({
|
||||
currentVersion,
|
||||
latestVersion,
|
||||
isCompatible,
|
||||
onUpdate,
|
||||
}: SubAgentUpdateAvailableBarProps): React.ReactElement {
|
||||
return (
|
||||
<div className="flex items-center justify-between gap-2 rounded-t-xl bg-blue-50 px-3 py-2 dark:bg-blue-900/30">
|
||||
<div className="flex items-center gap-2">
|
||||
<ArrowUpIcon className="h-4 w-4 text-blue-600 dark:text-blue-400" />
|
||||
<span className="text-sm text-blue-700 dark:text-blue-300">
|
||||
Update available (v{currentVersion} → v{latestVersion})
|
||||
</span>
|
||||
{!isCompatible && (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<WarningIcon className="h-4 w-4 text-amber-500" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className="max-w-xs">
|
||||
<p className="font-medium">Incompatible changes detected</p>
|
||||
<p className="text-xs text-gray-400">
|
||||
Click Update to see details
|
||||
</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
</div>
|
||||
<Button
|
||||
size="small"
|
||||
variant={isCompatible ? "primary" : "outline"}
|
||||
onClick={onUpdate}
|
||||
className={cn(
|
||||
"h-7 text-xs",
|
||||
!isCompatible && "border-amber-500 text-amber-600 hover:bg-amber-50",
|
||||
)}
|
||||
>
|
||||
Update
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,274 @@
|
||||
import React from "react";
|
||||
import {
|
||||
WarningIcon,
|
||||
XCircleIcon,
|
||||
PlusCircleIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { beautifyString } from "@/lib/utils";
|
||||
import { IncompatibilityInfo } from "@/app/(platform)/build/hooks/useSubAgentUpdate/types";
|
||||
|
||||
type IncompatibleUpdateDialogProps = {
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
onConfirm: () => void;
|
||||
currentVersion: number;
|
||||
latestVersion: number;
|
||||
agentName: string;
|
||||
incompatibilities: IncompatibilityInfo;
|
||||
};
|
||||
|
||||
export function IncompatibleUpdateDialog({
|
||||
isOpen,
|
||||
onClose,
|
||||
onConfirm,
|
||||
currentVersion,
|
||||
latestVersion,
|
||||
agentName,
|
||||
incompatibilities,
|
||||
}: IncompatibleUpdateDialogProps) {
|
||||
const hasMissingInputs = incompatibilities.missingInputs.length > 0;
|
||||
const hasMissingOutputs = incompatibilities.missingOutputs.length > 0;
|
||||
const hasNewInputs = incompatibilities.newInputs.length > 0;
|
||||
const hasNewOutputs = incompatibilities.newOutputs.length > 0;
|
||||
const hasNewRequired = incompatibilities.newRequiredInputs.length > 0;
|
||||
const hasTypeMismatches = incompatibilities.inputTypeMismatches.length > 0;
|
||||
|
||||
const hasInputChanges = hasMissingInputs || hasNewInputs;
|
||||
const hasOutputChanges = hasMissingOutputs || hasNewOutputs;
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title={
|
||||
<div className="flex items-center gap-2">
|
||||
<WarningIcon className="h-5 w-5 text-amber-500" weight="fill" />
|
||||
Incompatible Update
|
||||
</div>
|
||||
}
|
||||
controlled={{
|
||||
isOpen,
|
||||
set: async (open) => {
|
||||
if (!open) onClose();
|
||||
},
|
||||
}}
|
||||
onClose={onClose}
|
||||
styling={{ maxWidth: "32rem" }}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="space-y-4">
|
||||
<p className="text-sm text-gray-600 dark:text-gray-400">
|
||||
Updating <strong>{beautifyString(agentName)}</strong> from v
|
||||
{currentVersion} to v{latestVersion} will break some connections.
|
||||
</p>
|
||||
|
||||
{/* Input changes - two column layout */}
|
||||
{hasInputChanges && (
|
||||
<TwoColumnSection
|
||||
title="Input Changes"
|
||||
leftIcon={
|
||||
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
|
||||
}
|
||||
leftTitle="Removed"
|
||||
leftItems={incompatibilities.missingInputs}
|
||||
rightIcon={
|
||||
<PlusCircleIcon
|
||||
className="h-4 w-4 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
}
|
||||
rightTitle="Added"
|
||||
rightItems={incompatibilities.newInputs}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Output changes - two column layout */}
|
||||
{hasOutputChanges && (
|
||||
<TwoColumnSection
|
||||
title="Output Changes"
|
||||
leftIcon={
|
||||
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
|
||||
}
|
||||
leftTitle="Removed"
|
||||
leftItems={incompatibilities.missingOutputs}
|
||||
rightIcon={
|
||||
<PlusCircleIcon
|
||||
className="h-4 w-4 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
}
|
||||
rightTitle="Added"
|
||||
rightItems={incompatibilities.newOutputs}
|
||||
/>
|
||||
)}
|
||||
|
||||
{hasTypeMismatches && (
|
||||
<SingleColumnSection
|
||||
icon={
|
||||
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
|
||||
}
|
||||
title="Type Changed"
|
||||
description="These connected inputs have a different type:"
|
||||
items={incompatibilities.inputTypeMismatches.map(
|
||||
(m) => `${m.name} (${m.oldType} → ${m.newType})`,
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{hasNewRequired && (
|
||||
<SingleColumnSection
|
||||
icon={
|
||||
<PlusCircleIcon
|
||||
className="h-4 w-4 text-amber-500"
|
||||
weight="fill"
|
||||
/>
|
||||
}
|
||||
title="New Required Inputs"
|
||||
description="These inputs are now required:"
|
||||
items={incompatibilities.newRequiredInputs}
|
||||
/>
|
||||
)}
|
||||
|
||||
<Alert variant="warning">
|
||||
<AlertDescription>
|
||||
If you proceed, you'll need to remove the broken connections
|
||||
before you can save or run your agent.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button variant="ghost" size="small" onClick={onClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={onConfirm}
|
||||
className="border-amber-700 bg-amber-600 hover:bg-amber-700"
|
||||
>
|
||||
Update Anyway
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
|
||||
type TwoColumnSectionProps = {
|
||||
title: string;
|
||||
leftIcon: React.ReactNode;
|
||||
leftTitle: string;
|
||||
leftItems: string[];
|
||||
rightIcon: React.ReactNode;
|
||||
rightTitle: string;
|
||||
rightItems: string[];
|
||||
};
|
||||
|
||||
function TwoColumnSection({
|
||||
title,
|
||||
leftIcon,
|
||||
leftTitle,
|
||||
leftItems,
|
||||
rightIcon,
|
||||
rightTitle,
|
||||
rightItems,
|
||||
}: TwoColumnSectionProps) {
|
||||
return (
|
||||
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
|
||||
<span className="font-medium">{title}</span>
|
||||
<div className="mt-2 grid grid-cols-2 items-start gap-4">
|
||||
{/* Left column - Breaking changes */}
|
||||
<div className="min-w-0">
|
||||
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
|
||||
{leftIcon}
|
||||
<span>{leftTitle}</span>
|
||||
</div>
|
||||
<ul className="mt-1.5 space-y-1">
|
||||
{leftItems.length > 0 ? (
|
||||
leftItems.map((item) => (
|
||||
<li
|
||||
key={item}
|
||||
className="text-sm text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<code className="rounded bg-red-50 px-1 py-0.5 font-mono text-xs text-red-700 dark:bg-red-900/30 dark:text-red-300">
|
||||
{item}
|
||||
</code>
|
||||
</li>
|
||||
))
|
||||
) : (
|
||||
<li className="text-sm italic text-gray-400 dark:text-gray-500">
|
||||
None
|
||||
</li>
|
||||
)}
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
{/* Right column - Possible solutions */}
|
||||
<div className="min-w-0">
|
||||
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
|
||||
{rightIcon}
|
||||
<span>{rightTitle}</span>
|
||||
</div>
|
||||
<ul className="mt-1.5 space-y-1">
|
||||
{rightItems.length > 0 ? (
|
||||
rightItems.map((item) => (
|
||||
<li
|
||||
key={item}
|
||||
className="text-sm text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<code className="rounded bg-green-50 px-1 py-0.5 font-mono text-xs text-green-700 dark:bg-green-900/30 dark:text-green-300">
|
||||
{item}
|
||||
</code>
|
||||
</li>
|
||||
))
|
||||
) : (
|
||||
<li className="text-sm italic text-gray-400 dark:text-gray-500">
|
||||
None
|
||||
</li>
|
||||
)}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
type SingleColumnSectionProps = {
|
||||
icon: React.ReactNode;
|
||||
title: string;
|
||||
description: string;
|
||||
items: string[];
|
||||
};
|
||||
|
||||
function SingleColumnSection({
|
||||
icon,
|
||||
title,
|
||||
description,
|
||||
items,
|
||||
}: SingleColumnSectionProps) {
|
||||
return (
|
||||
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
|
||||
<div className="flex items-center gap-2">
|
||||
{icon}
|
||||
<span className="font-medium">{title}</span>
|
||||
</div>
|
||||
<p className="mt-1 text-sm text-gray-500 dark:text-gray-400">
|
||||
{description}
|
||||
</p>
|
||||
<ul className="mt-2 space-y-1">
|
||||
{items.map((item) => (
|
||||
<li
|
||||
key={item}
|
||||
className="ml-4 list-disc text-sm text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<code className="rounded bg-gray-100 px-1 py-0.5 font-mono text-xs dark:bg-gray-800">
|
||||
{item}
|
||||
</code>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
import React from "react";
|
||||
import { InfoIcon, WarningIcon } from "@phosphor-icons/react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { IncompatibilityInfo } from "@/app/(platform)/build/hooks/useSubAgentUpdate/types";
|
||||
|
||||
type ResolutionModeBarProps = {
|
||||
incompatibilities: IncompatibilityInfo | null;
|
||||
};
|
||||
|
||||
export function ResolutionModeBar({
|
||||
incompatibilities,
|
||||
}: ResolutionModeBarProps): React.ReactElement {
|
||||
const renderIncompatibilities = () => {
|
||||
if (!incompatibilities) return <span>No incompatibilities</span>;
|
||||
|
||||
const sections: React.ReactNode[] = [];
|
||||
|
||||
if (incompatibilities.missingInputs.length > 0) {
|
||||
sections.push(
|
||||
<div key="missing-inputs" className="mb-1">
|
||||
<span className="font-semibold">Missing inputs: </span>
|
||||
{incompatibilities.missingInputs.map((name, i) => (
|
||||
<React.Fragment key={name}>
|
||||
<code className="font-mono">{name}</code>
|
||||
{i < incompatibilities.missingInputs.length - 1 && ", "}
|
||||
</React.Fragment>
|
||||
))}
|
||||
</div>,
|
||||
);
|
||||
}
|
||||
if (incompatibilities.missingOutputs.length > 0) {
|
||||
sections.push(
|
||||
<div key="missing-outputs" className="mb-1">
|
||||
<span className="font-semibold">Missing outputs: </span>
|
||||
{incompatibilities.missingOutputs.map((name, i) => (
|
||||
<React.Fragment key={name}>
|
||||
<code className="font-mono">{name}</code>
|
||||
{i < incompatibilities.missingOutputs.length - 1 && ", "}
|
||||
</React.Fragment>
|
||||
))}
|
||||
</div>,
|
||||
);
|
||||
}
|
||||
if (incompatibilities.newRequiredInputs.length > 0) {
|
||||
sections.push(
|
||||
<div key="new-required" className="mb-1">
|
||||
<span className="font-semibold">New required inputs: </span>
|
||||
{incompatibilities.newRequiredInputs.map((name, i) => (
|
||||
<React.Fragment key={name}>
|
||||
<code className="font-mono">{name}</code>
|
||||
{i < incompatibilities.newRequiredInputs.length - 1 && ", "}
|
||||
</React.Fragment>
|
||||
))}
|
||||
</div>,
|
||||
);
|
||||
}
|
||||
if (incompatibilities.inputTypeMismatches.length > 0) {
|
||||
sections.push(
|
||||
<div key="type-mismatches" className="mb-1">
|
||||
<span className="font-semibold">Type changed: </span>
|
||||
{incompatibilities.inputTypeMismatches.map((m, i) => (
|
||||
<React.Fragment key={m.name}>
|
||||
<code className="font-mono">{m.name}</code>
|
||||
<span className="text-gray-400">
|
||||
{" "}
|
||||
({m.oldType} → {m.newType})
|
||||
</span>
|
||||
{i < incompatibilities.inputTypeMismatches.length - 1 && ", "}
|
||||
</React.Fragment>
|
||||
))}
|
||||
</div>,
|
||||
);
|
||||
}
|
||||
|
||||
return <>{sections}</>;
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-between gap-2 rounded-t-xl bg-amber-50 px-3 py-2 dark:bg-amber-900/30">
|
||||
<div className="flex items-center gap-2">
|
||||
<WarningIcon className="h-4 w-4 text-amber-600 dark:text-amber-400" />
|
||||
<span className="text-sm text-amber-700 dark:text-amber-300">
|
||||
Remove incompatible connections
|
||||
</span>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<InfoIcon className="h-4 w-4 cursor-help text-amber-500" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className="max-w-sm">
|
||||
<p className="mb-2 font-semibold">Incompatible changes:</p>
|
||||
<div className="text-xs">{renderIncompatibilities()}</div>
|
||||
<p className="mt-2 text-xs text-gray-400">
|
||||
{(incompatibilities?.newRequiredInputs.length ?? 0) > 0
|
||||
? "Replace / delete"
|
||||
: "Delete"}{" "}
|
||||
the red connections to continue
|
||||
</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,194 @@
|
||||
import { useState, useCallback, useEffect } from "react";
|
||||
import { useShallow } from "zustand/react/shallow";
|
||||
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
|
||||
import {
|
||||
useNodeStore,
|
||||
NodeResolutionData,
|
||||
} from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||
import {
|
||||
useSubAgentUpdate,
|
||||
createUpdatedAgentNodeInputs,
|
||||
getBrokenEdgeIDs,
|
||||
} from "@/app/(platform)/build/hooks/useSubAgentUpdate";
|
||||
import { GraphInputSchema, GraphOutputSchema } from "@/lib/autogpt-server-api";
|
||||
import { CustomNodeData } from "../../CustomNode";
|
||||
|
||||
// Stable empty set to avoid creating new references in selectors
|
||||
const EMPTY_SET: Set<string> = new Set();
|
||||
|
||||
type UseSubAgentUpdateParams = {
|
||||
nodeID: string;
|
||||
nodeData: CustomNodeData;
|
||||
};
|
||||
|
||||
export function useSubAgentUpdateState({
|
||||
nodeID,
|
||||
nodeData,
|
||||
}: UseSubAgentUpdateParams) {
|
||||
const [showIncompatibilityDialog, setShowIncompatibilityDialog] =
|
||||
useState(false);
|
||||
|
||||
// Get store actions
|
||||
const updateNodeData = useNodeStore(
|
||||
useShallow((state) => state.updateNodeData),
|
||||
);
|
||||
const setNodeResolutionMode = useNodeStore(
|
||||
useShallow((state) => state.setNodeResolutionMode),
|
||||
);
|
||||
const isNodeInResolutionMode = useNodeStore(
|
||||
useShallow((state) => state.isNodeInResolutionMode),
|
||||
);
|
||||
const setBrokenEdgeIDs = useNodeStore(
|
||||
useShallow((state) => state.setBrokenEdgeIDs),
|
||||
);
|
||||
// Get this node's broken edge IDs from the per-node map
|
||||
// Use EMPTY_SET as fallback to maintain referential stability
|
||||
const brokenEdgeIDs = useNodeStore(
|
||||
(state) => state.brokenEdgeIDs.get(nodeID) || EMPTY_SET,
|
||||
);
|
||||
const getNodeResolutionData = useNodeStore(
|
||||
useShallow((state) => state.getNodeResolutionData),
|
||||
);
|
||||
const connectedEdges = useEdgeStore(
|
||||
useShallow((state) => state.getNodeEdges(nodeID)),
|
||||
);
|
||||
const availableSubGraphs = useGraphStore(
|
||||
useShallow((state) => state.availableSubGraphs),
|
||||
);
|
||||
|
||||
// Extract agent-specific data
|
||||
const graphID = nodeData.hardcodedValues?.graph_id as string | undefined;
|
||||
const graphVersion = nodeData.hardcodedValues?.graph_version as
|
||||
| number
|
||||
| undefined;
|
||||
const currentInputSchema = nodeData.hardcodedValues?.input_schema as
|
||||
| GraphInputSchema
|
||||
| undefined;
|
||||
const currentOutputSchema = nodeData.hardcodedValues?.output_schema as
|
||||
| GraphOutputSchema
|
||||
| undefined;
|
||||
|
||||
// Use the sub-agent update hook
|
||||
const updateInfo = useSubAgentUpdate(
|
||||
nodeID,
|
||||
graphID,
|
||||
graphVersion,
|
||||
currentInputSchema,
|
||||
currentOutputSchema,
|
||||
connectedEdges,
|
||||
availableSubGraphs,
|
||||
);
|
||||
|
||||
const isInResolutionMode = isNodeInResolutionMode(nodeID);
|
||||
|
||||
// Handle update button click
|
||||
const handleUpdateClick = useCallback(() => {
|
||||
if (!updateInfo.hasUpdate || !updateInfo.latestGraph) return;
|
||||
|
||||
if (updateInfo.isCompatible) {
|
||||
// Compatible update - apply directly
|
||||
const newHardcodedValues = createUpdatedAgentNodeInputs(
|
||||
nodeData.hardcodedValues,
|
||||
updateInfo.latestGraph,
|
||||
);
|
||||
updateNodeData(nodeID, { hardcodedValues: newHardcodedValues });
|
||||
} else {
|
||||
// Incompatible update - show dialog
|
||||
setShowIncompatibilityDialog(true);
|
||||
}
|
||||
}, [
|
||||
updateInfo.hasUpdate,
|
||||
updateInfo.latestGraph,
|
||||
updateInfo.isCompatible,
|
||||
nodeData.hardcodedValues,
|
||||
updateNodeData,
|
||||
nodeID,
|
||||
]);
|
||||
|
||||
// Handle confirming an incompatible update
|
||||
function handleConfirmIncompatibleUpdate() {
|
||||
if (!updateInfo.latestGraph || !updateInfo.incompatibilities) return;
|
||||
|
||||
const latestGraph = updateInfo.latestGraph;
|
||||
|
||||
// Get the new schemas from the latest graph version
|
||||
const newInputSchema =
|
||||
(latestGraph.input_schema as Record<string, unknown>) || {};
|
||||
const newOutputSchema =
|
||||
(latestGraph.output_schema as Record<string, unknown>) || {};
|
||||
|
||||
// Create the updated hardcoded values but DON'T apply them yet
|
||||
// We'll apply them when resolution is complete
|
||||
const pendingHardcodedValues = createUpdatedAgentNodeInputs(
|
||||
nodeData.hardcodedValues,
|
||||
latestGraph,
|
||||
);
|
||||
|
||||
// Get broken edge IDs and store them for this node
|
||||
const brokenIds = getBrokenEdgeIDs(
|
||||
connectedEdges,
|
||||
updateInfo.incompatibilities,
|
||||
nodeID,
|
||||
);
|
||||
setBrokenEdgeIDs(nodeID, brokenIds);
|
||||
|
||||
// Enter resolution mode with both old and new schemas
|
||||
// DON'T apply the update yet - keep old schema so connections remain visible
|
||||
const resolutionData: NodeResolutionData = {
|
||||
incompatibilities: updateInfo.incompatibilities,
|
||||
pendingUpdate: {
|
||||
input_schema: newInputSchema,
|
||||
output_schema: newOutputSchema,
|
||||
},
|
||||
currentSchema: {
|
||||
input_schema: (currentInputSchema as Record<string, unknown>) || {},
|
||||
output_schema: (currentOutputSchema as Record<string, unknown>) || {},
|
||||
},
|
||||
pendingHardcodedValues,
|
||||
};
|
||||
setNodeResolutionMode(nodeID, true, resolutionData);
|
||||
|
||||
setShowIncompatibilityDialog(false);
|
||||
}
|
||||
|
||||
// Check if resolution is complete (all broken edges removed)
|
||||
const resolutionData = getNodeResolutionData(nodeID);
|
||||
|
||||
// Auto-check resolution on edge changes
|
||||
useEffect(() => {
|
||||
if (!isInResolutionMode) return;
|
||||
|
||||
// Check if any broken edges still exist
|
||||
const remainingBroken = Array.from(brokenEdgeIDs).filter((edgeId) =>
|
||||
connectedEdges.some((e) => e.id === edgeId),
|
||||
);
|
||||
|
||||
if (remainingBroken.length === 0) {
|
||||
// Resolution complete - now apply the pending update
|
||||
if (resolutionData?.pendingHardcodedValues) {
|
||||
updateNodeData(nodeID, {
|
||||
hardcodedValues: resolutionData.pendingHardcodedValues,
|
||||
});
|
||||
}
|
||||
// setNodeResolutionMode will clean up this node's broken edges automatically
|
||||
setNodeResolutionMode(nodeID, false);
|
||||
}
|
||||
}, [
|
||||
isInResolutionMode,
|
||||
brokenEdgeIDs,
|
||||
connectedEdges,
|
||||
resolutionData,
|
||||
nodeID,
|
||||
]);
|
||||
|
||||
return {
|
||||
updateInfo,
|
||||
isInResolutionMode,
|
||||
resolutionData,
|
||||
showIncompatibilityDialog,
|
||||
setShowIncompatibilityDialog,
|
||||
handleUpdateClick,
|
||||
handleConfirmIncompatibleUpdate,
|
||||
};
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import { NodeResolutionData } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { RJSFSchema } from "@rjsf/utils";
|
||||
|
||||
export const nodeStyleBasedOnStatus: Record<AgentExecutionStatus, string> = {
|
||||
INCOMPLETE: "ring-slate-300 bg-slate-300",
|
||||
@@ -9,3 +11,48 @@ export const nodeStyleBasedOnStatus: Record<AgentExecutionStatus, string> = {
|
||||
TERMINATED: "ring-orange-300 bg-orange-300 ",
|
||||
FAILED: "ring-red-300 bg-red-300",
|
||||
};
|
||||
|
||||
/**
|
||||
* Merges schemas during resolution mode to include removed inputs/outputs
|
||||
* that still have connections, so users can see and delete them.
|
||||
*/
|
||||
export function mergeSchemaForResolution(
|
||||
currentSchema: Record<string, unknown>,
|
||||
newSchema: Record<string, unknown>,
|
||||
resolutionData: NodeResolutionData,
|
||||
type: "input" | "output",
|
||||
): Record<string, unknown> {
|
||||
const newProps = (newSchema.properties as RJSFSchema) || {};
|
||||
const currentProps = (currentSchema.properties as RJSFSchema) || {};
|
||||
const mergedProps = { ...newProps };
|
||||
const incomp = resolutionData.incompatibilities;
|
||||
|
||||
if (type === "input") {
|
||||
// Add back missing inputs that have connections
|
||||
incomp.missingInputs.forEach((inputName: string) => {
|
||||
if (currentProps[inputName]) {
|
||||
mergedProps[inputName] = currentProps[inputName];
|
||||
}
|
||||
});
|
||||
// Add back inputs with type mismatches (keep old type so connection works visually)
|
||||
incomp.inputTypeMismatches.forEach(
|
||||
(mismatch: { name: string; oldType: string; newType: string }) => {
|
||||
if (currentProps[mismatch.name]) {
|
||||
mergedProps[mismatch.name] = currentProps[mismatch.name];
|
||||
}
|
||||
},
|
||||
);
|
||||
} else {
|
||||
// Add back missing outputs that have connections
|
||||
incomp.missingOutputs.forEach((outputName: string) => {
|
||||
if (currentProps[outputName]) {
|
||||
mergedProps[outputName] = currentProps[outputName];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
...newSchema,
|
||||
properties: mergedProps,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { CustomNodeData } from "./CustomNode";
|
||||
import { BlockUIType } from "../../../types";
|
||||
import { useMemo } from "react";
|
||||
import { mergeSchemaForResolution } from "./helpers";
|
||||
|
||||
export const useCustomNode = ({
|
||||
data,
|
||||
nodeId,
|
||||
}: {
|
||||
data: CustomNodeData;
|
||||
nodeId: string;
|
||||
}) => {
|
||||
const isInResolutionMode = useNodeStore((state) =>
|
||||
state.nodesInResolutionMode.has(nodeId),
|
||||
);
|
||||
const resolutionData = useNodeStore((state) =>
|
||||
state.nodeResolutionData.get(nodeId),
|
||||
);
|
||||
|
||||
const isAgent = data.uiType === BlockUIType.AGENT;
|
||||
|
||||
const currentInputSchema = isAgent
|
||||
? (data.hardcodedValues.input_schema ?? {})
|
||||
: data.inputSchema;
|
||||
const currentOutputSchema = isAgent
|
||||
? (data.hardcodedValues.output_schema ?? {})
|
||||
: data.outputSchema;
|
||||
|
||||
const inputSchema = useMemo(() => {
|
||||
if (isAgent && isInResolutionMode && resolutionData) {
|
||||
return mergeSchemaForResolution(
|
||||
resolutionData.currentSchema.input_schema,
|
||||
resolutionData.pendingUpdate.input_schema,
|
||||
resolutionData,
|
||||
"input",
|
||||
);
|
||||
}
|
||||
return currentInputSchema;
|
||||
}, [isAgent, isInResolutionMode, resolutionData, currentInputSchema]);
|
||||
|
||||
const outputSchema = useMemo(() => {
|
||||
if (isAgent && isInResolutionMode && resolutionData) {
|
||||
return mergeSchemaForResolution(
|
||||
resolutionData.currentSchema.output_schema,
|
||||
resolutionData.pendingUpdate.output_schema,
|
||||
resolutionData,
|
||||
"output",
|
||||
);
|
||||
}
|
||||
return currentOutputSchema;
|
||||
}, [isAgent, isInResolutionMode, resolutionData, currentOutputSchema]);
|
||||
|
||||
return {
|
||||
inputSchema,
|
||||
outputSchema,
|
||||
};
|
||||
};
|
||||
@@ -5,20 +5,16 @@ import { useNodeStore } from "../../../stores/nodeStore";
|
||||
import { BlockUIType } from "../../types";
|
||||
import { FormRenderer } from "@/components/renderers/InputRenderer/FormRenderer";
|
||||
|
||||
export const FormCreator = React.memo(
|
||||
({
|
||||
jsonSchema,
|
||||
nodeId,
|
||||
uiType,
|
||||
showHandles = true,
|
||||
className,
|
||||
}: {
|
||||
jsonSchema: RJSFSchema;
|
||||
nodeId: string;
|
||||
uiType: BlockUIType;
|
||||
showHandles?: boolean;
|
||||
className?: string;
|
||||
}) => {
|
||||
interface FormCreatorProps {
|
||||
jsonSchema: RJSFSchema;
|
||||
nodeId: string;
|
||||
uiType: BlockUIType;
|
||||
showHandles?: boolean;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export const FormCreator: React.FC<FormCreatorProps> = React.memo(
|
||||
({ jsonSchema, nodeId, uiType, showHandles = true, className }) => {
|
||||
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
||||
|
||||
const getHardCodedValues = useNodeStore(
|
||||
@@ -48,7 +44,10 @@ export const FormCreator = React.memo(
|
||||
: hardcodedValues;
|
||||
|
||||
return (
|
||||
<div className={className}>
|
||||
<div
|
||||
className={className}
|
||||
data-id={`form-creator-container-${nodeId}-node`}
|
||||
>
|
||||
<FormRenderer
|
||||
jsonSchema={jsonSchema}
|
||||
handleChange={handleChange}
|
||||
|
||||
@@ -14,6 +14,8 @@ import {
|
||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||
import { getTypeDisplayInfo } from "./helpers";
|
||||
import { BlockUIType } from "../../types";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useBrokenOutputs } from "./useBrokenOutputs";
|
||||
|
||||
export const OutputHandler = ({
|
||||
outputSchema,
|
||||
@@ -27,6 +29,9 @@ export const OutputHandler = ({
|
||||
const { isOutputConnected } = useEdgeStore();
|
||||
const properties = outputSchema?.properties || {};
|
||||
const [isOutputVisible, setIsOutputVisible] = useState(true);
|
||||
const brokenOutputs = useBrokenOutputs(nodeId);
|
||||
|
||||
console.log("brokenOutputs", brokenOutputs);
|
||||
|
||||
const showHandles = uiType !== BlockUIType.OUTPUT;
|
||||
|
||||
@@ -44,9 +49,14 @@ export const OutputHandler = ({
|
||||
const shouldShow = isConnected || isOutputVisible;
|
||||
const { displayType, colorClass, hexColor } =
|
||||
getTypeDisplayInfo(fieldSchema);
|
||||
const isBroken = brokenOutputs.has(fullKey);
|
||||
|
||||
return shouldShow ? (
|
||||
<div key={fullKey} className="flex flex-col items-end gap-2">
|
||||
<div
|
||||
key={fullKey}
|
||||
className="flex flex-col items-end gap-2"
|
||||
data-tutorial-id={`output-handler-${nodeId}-${fieldTitle}`}
|
||||
>
|
||||
<div className="relative flex items-center gap-2">
|
||||
{fieldSchema?.description && (
|
||||
<TooltipProvider>
|
||||
@@ -64,15 +74,29 @@ export const OutputHandler = ({
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
)}
|
||||
<Text variant="body" className="text-slate-700">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"text-slate-700",
|
||||
isBroken && "text-red-500 line-through",
|
||||
)}
|
||||
>
|
||||
{fieldTitle}
|
||||
</Text>
|
||||
<Text variant="small" as="span" className={colorClass}>
|
||||
<Text
|
||||
variant="small"
|
||||
as="span"
|
||||
className={cn(
|
||||
colorClass,
|
||||
isBroken && "!text-red-500 line-through",
|
||||
)}
|
||||
>
|
||||
({displayType})
|
||||
</Text>
|
||||
|
||||
{showHandles && (
|
||||
<OutputNodeHandle
|
||||
isBroken={isBroken}
|
||||
field_name={fullKey}
|
||||
nodeId={nodeId}
|
||||
hexColor={hexColor}
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
import { useMemo } from "react";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
|
||||
/**
|
||||
* Hook to get the set of broken output names for a node in resolution mode.
|
||||
*/
|
||||
export function useBrokenOutputs(nodeID: string): Set<string> {
|
||||
// Subscribe to the actual state values, not just methods
|
||||
const isInResolution = useNodeStore((state) =>
|
||||
state.nodesInResolutionMode.has(nodeID),
|
||||
);
|
||||
const resolutionData = useNodeStore((state) =>
|
||||
state.nodeResolutionData.get(nodeID),
|
||||
);
|
||||
|
||||
return useMemo(() => {
|
||||
if (!isInResolution || !resolutionData) {
|
||||
return new Set<string>();
|
||||
}
|
||||
|
||||
return new Set(resolutionData.incompatibilities.missingOutputs);
|
||||
}, [isInResolution, resolutionData]);
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
// Block IDs for tutorial blocks
|
||||
export const BLOCK_IDS = {
|
||||
CALCULATOR: "b1ab9b19-67a6-406d-abf5-2dba76d00c79",
|
||||
AGENT_INPUT: "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
|
||||
AGENT_OUTPUT: "363ae599-353e-4804-937e-b2ee3cef3da4",
|
||||
} as const;
|
||||
|
||||
export const TUTORIAL_SELECTORS = {
|
||||
// Custom nodes - These are all before saving
|
||||
INPUT_NODE: '[data-id="custom-node-2"]',
|
||||
OUTPUT_NODE: '[data-id="custom-node-3 "]',
|
||||
CALCULATOR_NODE: '[data-id="custom-node-1"]',
|
||||
|
||||
// Paricular field selector
|
||||
NAME_FIELD_OUTPUT_NODE: '[data-id="field-3-root_name"]',
|
||||
|
||||
// Output Handlers
|
||||
SECOND_CALCULATOR_RESULT_OUTPUT_HANDLER:
|
||||
'[data-tutorial-id="output-handler-2-result"]',
|
||||
FIRST_CALCULATOR_RESULT_OUTPUT_HANDLER:
|
||||
'[data-tutorial-id="output-handler-1-result"]',
|
||||
|
||||
// Input Handler
|
||||
SECOND_CALCULATOR_NUMBER_A_INPUT_HANDLER:
|
||||
'[data-tutorial-id="input-handler-2-a"]',
|
||||
OUTPUT_VALUE_INPUT_HANDLEER: '[data-tutorial-id="label-3-root_value"]',
|
||||
|
||||
// Block Menu
|
||||
BLOCKS_TRIGGER: '[data-id="blocks-control-popover-trigger"]',
|
||||
BLOCKS_CONTENT: '[data-id="blocks-control-popover-content"]',
|
||||
BLOCKS_SEARCH_INPUT:
|
||||
'[data-id="blocks-control-search-bar"] input[type="text"]',
|
||||
BLOCKS_SEARCH_INPUT_BOX: '[data-id="blocks-control-search-bar"]',
|
||||
|
||||
// Add a new selector that checks within search results
|
||||
|
||||
// Block Menu Sidebar
|
||||
MENU_ITEM_INPUT_BLOCKS: '[data-id="menu-item-input_blocks"]',
|
||||
MENU_ITEM_ALL_BLOCKS: '[data-id="menu-item-all_blocks"]',
|
||||
MENU_ITEM_ACTION_BLOCKS: '[data-id="menu-item-action_blocks"]',
|
||||
MENU_ITEM_OUTPUT_BLOCKS: '[data-id="menu-item-output_blocks"]',
|
||||
MENU_ITEM_INTEGRATIONS: '[data-id="menu-item-integrations"]',
|
||||
MENU_ITEM_MY_AGENTS: '[data-id="menu-item-my_agents"]',
|
||||
MENU_ITEM_MARKETPLACE: '[data-id="menu-item-marketplace_agents"]',
|
||||
MENU_ITEM_SUGGESTION: '[data-id="menu-item-suggestion"]',
|
||||
|
||||
// Block Cards
|
||||
BLOCK_CARD_PREFIX: '[data-id^="block-card-"]',
|
||||
BLOCK_CARD_AGENT_INPUT: '[data-id="block-card-AgentInputBlock"]',
|
||||
// Calculator block - legacy ID used in old tutorial
|
||||
BLOCK_CARD_CALCULATOR:
|
||||
'[data-id="block-card-b1ab9b1967a6406dabf52dba76d00c79"]',
|
||||
BLOCK_CARD_CALCULATOR_IN_SEARCH:
|
||||
'[data-id="blocks-control-search-results"] [data-id="block-card-b1ab9b1967a6406dabf52dba76d00c79"]',
|
||||
|
||||
// Save Control
|
||||
SAVE_TRIGGER: '[data-id="save-control-popover-trigger"]',
|
||||
SAVE_CONTENT: '[data-id="save-control-popover-content"]',
|
||||
SAVE_AGENT_BUTTON: '[data-id="save-control-save-agent"]',
|
||||
SAVE_NAME_INPUT: '[data-id="save-control-name-input"]',
|
||||
SAVE_DESCRIPTION_INPUT: '[data-id="save-control-description-input"]',
|
||||
|
||||
// Builder Actions (Run, Schedule, Outputs)
|
||||
BUILDER_ACTIONS: '[data-id="builder-actions"]',
|
||||
RUN_BUTTON: '[data-id="run-graph-button"]',
|
||||
STOP_BUTTON: '[data-id="stop-graph-button"]',
|
||||
SCHEDULE_BUTTON: '[data-id="schedule-graph-button"]',
|
||||
AGENT_OUTPUTS_BUTTON: '[data-id="agent-outputs-button"]',
|
||||
|
||||
// Run Input Dialog
|
||||
RUN_INPUT_DIALOG_CONTENT: '[data-id="run-input-dialog-content"]',
|
||||
RUN_INPUT_CREDENTIALS_SECTION: '[data-id="run-input-credentials-section"]',
|
||||
RUN_INPUT_CREDENTIALS_FORM: '[data-id="run-input-credentials-form"]',
|
||||
RUN_INPUT_INPUTS_SECTION: '[data-id="run-input-inputs-section"]',
|
||||
RUN_INPUT_INPUTS_FORM: '[data-id="run-input-inputs-form"]',
|
||||
RUN_INPUT_ACTIONS_SECTION: '[data-id="run-input-actions-section"]',
|
||||
RUN_INPUT_MANUAL_RUN_BUTTON: '[data-id="run-input-manual-run-button"]',
|
||||
RUN_INPUT_SCHEDULE_BUTTON: '[data-id="run-input-schedule-button"]',
|
||||
|
||||
// Custom Controls (bottom left)
|
||||
CUSTOM_CONTROLS: '[data-id="custom-controls"]',
|
||||
ZOOM_IN_BUTTON: '[data-id="zoom-in-button"]',
|
||||
ZOOM_OUT_BUTTON: '[data-id="zoom-out-button"]',
|
||||
FIT_VIEW_BUTTON: '[data-id="fit-view-button"]',
|
||||
LOCK_BUTTON: '[data-id="lock-button"]',
|
||||
TUTORIAL_BUTTON: '[data-id="tutorial-button"]',
|
||||
|
||||
// Canvas
|
||||
REACT_FLOW_CANVAS: ".react-flow__pane",
|
||||
REACT_FLOW_NODE: ".react-flow__node",
|
||||
REACT_FLOW_NODE_FIRST: '[data-testid^="rf__node-"]:first-child',
|
||||
REACT_FLOW_EDGE: '[data-testid^="rf__edge-"]',
|
||||
|
||||
// Node elements
|
||||
NODE_CONTAINER: '[data-id^="custom-node-"]',
|
||||
NODE_HEADER: '[data-id^="node-header-"]',
|
||||
NODE_INPUT_HANDLES: '[data-tutorial-id="input-handles"]',
|
||||
NODE_OUTPUT_HANDLE: '[data-handlepos="right"]',
|
||||
NODE_INPUT_HANDLE: "[data-nodeid]",
|
||||
FIRST_CALCULATOR_NODE_OUTPUT: '[data-tutorial-id="node-output"]',
|
||||
// These are the Id's of the nodes before saving
|
||||
CALCULATOR_NODE_FORM_CONTAINER: '[data-id^="form-creator-container-1-node"]', // <-- Add this line
|
||||
AGENT_INPUT_NODE_FORM_CONTAINER: '[data-id^="form-creator-container-2-node"]', // <-- Add this line
|
||||
AGENT_OUTPUT_NODE_FORM_CONTAINER:
|
||||
'[data-id^="form-creator-container-3-node"]', // <-- Add this line
|
||||
|
||||
// Execution badges
|
||||
BADGE_QUEUED: '[data-id^="badge-"][data-id$="-QUEUED"]',
|
||||
BADGE_COMPLETED: '[data-id^="badge-"][data-id$="-COMPLETED"]',
|
||||
|
||||
// Undo/Redo
|
||||
UNDO_BUTTON: '[data-id="undo-button"]',
|
||||
REDO_BUTTON: '[data-id="redo-button"]',
|
||||
} as const;
|
||||
|
||||
export const CSS_CLASSES = {
|
||||
DISABLE: "new-builder-tutorial-disable",
|
||||
HIGHLIGHT: "new-builder-tutorial-highlight",
|
||||
PULSE: "new-builder-tutorial-pulse",
|
||||
} as const;
|
||||
|
||||
export const TUTORIAL_CONFIG = {
|
||||
ELEMENT_CHECK_INTERVAL: 50, // ms
|
||||
INPUT_CHECK_INTERVAL: 100, // ms
|
||||
USE_MODAL_OVERLAY: true,
|
||||
SCROLL_BEHAVIOR: "smooth" as const,
|
||||
SCROLL_BLOCK: "center" as const,
|
||||
SEARCH_TERM_CALCULATOR: "Calculator",
|
||||
} as const;
|
||||
@@ -0,0 +1,89 @@
|
||||
import { BLOCK_IDS } from "../constants";
|
||||
import { useNodeStore } from "../../../../stores/nodeStore";
|
||||
import { getV2GetSpecificBlocks } from "@/app/api/__generated__/endpoints/default/default";
|
||||
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
|
||||
|
||||
const prefetchedBlocks: Map<string, BlockInfo> = new Map();
|
||||
|
||||
export const prefetchTutorialBlocks = async (): Promise<void> => {
|
||||
try {
|
||||
const blockIds = [BLOCK_IDS.CALCULATOR];
|
||||
const response = await getV2GetSpecificBlocks({ block_ids: blockIds });
|
||||
|
||||
if (response.status === 200 && response.data) {
|
||||
response.data.forEach((block) => {
|
||||
prefetchedBlocks.set(block.id, block);
|
||||
});
|
||||
console.debug("Tutorial blocks prefetched:", prefetchedBlocks.size);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to prefetch tutorial blocks:", error);
|
||||
}
|
||||
};
|
||||
|
||||
export const getPrefetchedBlock = (blockId: string): BlockInfo | undefined => {
|
||||
return prefetchedBlocks.get(blockId);
|
||||
};
|
||||
|
||||
export const clearPrefetchedBlocks = (): void => {
|
||||
prefetchedBlocks.clear();
|
||||
};
|
||||
|
||||
export const addPrefetchedBlock = (
|
||||
blockId: string,
|
||||
position?: { x: number; y: number },
|
||||
): void => {
|
||||
const block = prefetchedBlocks.get(blockId);
|
||||
if (block) {
|
||||
useNodeStore.getState().addBlock(block, {}, position);
|
||||
} else {
|
||||
console.error(`Block ${blockId} not found in prefetched blocks`);
|
||||
}
|
||||
};
|
||||
|
||||
export const getNodeByBlockId = (blockId: string) => {
|
||||
const nodes = useNodeStore.getState().nodes;
|
||||
return nodes.find((n) => n.data?.block_id === blockId);
|
||||
};
|
||||
|
||||
export const addSecondCalculatorBlock = (): void => {
|
||||
const firstCalculatorNode = getNodeByBlockId(BLOCK_IDS.CALCULATOR);
|
||||
|
||||
if (firstCalculatorNode) {
|
||||
const calcX = firstCalculatorNode.position.x;
|
||||
const calcY = firstCalculatorNode.position.y;
|
||||
|
||||
addPrefetchedBlock(BLOCK_IDS.CALCULATOR, {
|
||||
x: calcX + 500,
|
||||
y: calcY,
|
||||
});
|
||||
} else {
|
||||
addPrefetchedBlock(BLOCK_IDS.CALCULATOR);
|
||||
}
|
||||
};
|
||||
|
||||
export const getCalculatorNodes = () => {
|
||||
const nodes = useNodeStore.getState().nodes;
|
||||
return nodes.filter((n) => n.data?.block_id === BLOCK_IDS.CALCULATOR);
|
||||
};
|
||||
|
||||
export const getSecondCalculatorNode = () => {
|
||||
const calculatorNodes = getCalculatorNodes();
|
||||
return calculatorNodes.length >= 2 ? calculatorNodes[1] : null;
|
||||
};
|
||||
|
||||
export const getFormContainerSelector = (blockId: string): string | null => {
|
||||
const node = getNodeByBlockId(blockId);
|
||||
if (node) {
|
||||
return `[data-id="form-creator-container-${node.id}"]`;
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
export const getFormContainerElement = (blockId: string): Element | null => {
|
||||
const selector = getFormContainerSelector(blockId);
|
||||
if (selector) {
|
||||
return document.querySelector(selector);
|
||||
}
|
||||
return null;
|
||||
};
|
||||
@@ -0,0 +1,83 @@
|
||||
import { TUTORIAL_CONFIG, TUTORIAL_SELECTORS } from "../constants";
|
||||
import { useNodeStore } from "../../../../stores/nodeStore";
|
||||
|
||||
export const waitForNodeOnCanvas = (
|
||||
timeout = 10000,
|
||||
): Promise<Element | null> => {
|
||||
return new Promise((resolve) => {
|
||||
const startTime = Date.now();
|
||||
|
||||
const checkNode = () => {
|
||||
const storeNodes = useNodeStore.getState().nodes;
|
||||
if (storeNodes.length > 0) {
|
||||
const domNode = document.querySelector(
|
||||
TUTORIAL_SELECTORS.REACT_FLOW_NODE,
|
||||
);
|
||||
if (domNode) {
|
||||
resolve(domNode);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (Date.now() - startTime > timeout) {
|
||||
resolve(null);
|
||||
} else {
|
||||
setTimeout(checkNode, TUTORIAL_CONFIG.ELEMENT_CHECK_INTERVAL);
|
||||
}
|
||||
};
|
||||
checkNode();
|
||||
});
|
||||
};
|
||||
|
||||
export const waitForNodesCount = (
|
||||
count: number,
|
||||
timeout = 10000,
|
||||
): Promise<boolean> => {
|
||||
return new Promise((resolve) => {
|
||||
const startTime = Date.now();
|
||||
|
||||
const checkNodes = () => {
|
||||
const currentCount = useNodeStore.getState().nodes.length;
|
||||
if (currentCount >= count) {
|
||||
resolve(true);
|
||||
} else if (Date.now() - startTime > timeout) {
|
||||
resolve(false);
|
||||
} else {
|
||||
setTimeout(checkNodes, TUTORIAL_CONFIG.ELEMENT_CHECK_INTERVAL);
|
||||
}
|
||||
};
|
||||
checkNodes();
|
||||
});
|
||||
};
|
||||
|
||||
export const getNodesCount = (): number => {
|
||||
return useNodeStore.getState().nodes.length;
|
||||
};
|
||||
|
||||
export const getFirstNode = () => {
|
||||
const nodes = useNodeStore.getState().nodes;
|
||||
return nodes.length > 0 ? nodes[0] : null;
|
||||
};
|
||||
|
||||
export const getNodeById = (nodeId: string) => {
|
||||
const nodes = useNodeStore.getState().nodes;
|
||||
return nodes.find((n) => n.id === nodeId);
|
||||
};
|
||||
|
||||
export const nodeHasValues = (nodeId: string): boolean => {
|
||||
const node = getNodeById(nodeId);
|
||||
if (!node) return false;
|
||||
const hardcodedValues = node.data?.hardcodedValues || {};
|
||||
return Object.values(hardcodedValues).some(
|
||||
(value) => value !== undefined && value !== null && value !== "",
|
||||
);
|
||||
};
|
||||
|
||||
export const fitViewToScreen = () => {
|
||||
const fitViewButton = document.querySelector(
|
||||
TUTORIAL_SELECTORS.FIT_VIEW_BUTTON,
|
||||
) as HTMLButtonElement;
|
||||
if (fitViewButton) {
|
||||
fitViewButton.click();
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,19 @@
|
||||
import { useNodeStore } from "../../../../stores/nodeStore";
|
||||
import { useEdgeStore } from "../../../../stores/edgeStore";
|
||||
|
||||
export const isConnectionMade = (
|
||||
sourceBlockId: string,
|
||||
targetBlockId: string,
|
||||
): boolean => {
|
||||
const edges = useEdgeStore.getState().edges;
|
||||
const nodes = useNodeStore.getState().nodes;
|
||||
|
||||
const sourceNode = nodes.find((n) => n.data?.block_id === sourceBlockId);
|
||||
const targetNode = nodes.find((n) => n.data?.block_id === targetBlockId);
|
||||
|
||||
if (!sourceNode || !targetNode) return false;
|
||||
|
||||
return edges.some((edge) => {
|
||||
return edge.source === sourceNode.id && edge.target === targetNode.id;
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,180 @@
|
||||
import { TUTORIAL_CONFIG, TUTORIAL_SELECTORS } from "../constants";
|
||||
|
||||
export const waitForElement = (
|
||||
selector: string,
|
||||
timeout = 10000,
|
||||
): Promise<Element> => {
|
||||
return new Promise((resolve, reject) => {
|
||||
const startTime = Date.now();
|
||||
|
||||
const checkElement = () => {
|
||||
const element = document.querySelector(selector);
|
||||
if (element) {
|
||||
resolve(element);
|
||||
} else if (Date.now() - startTime > timeout) {
|
||||
reject(new Error(`Element ${selector} not found within ${timeout}ms`));
|
||||
} else {
|
||||
setTimeout(checkElement, TUTORIAL_CONFIG.ELEMENT_CHECK_INTERVAL);
|
||||
}
|
||||
};
|
||||
checkElement();
|
||||
});
|
||||
};
|
||||
|
||||
export const waitForInputValue = (
|
||||
selector: string,
|
||||
targetValue: string,
|
||||
timeout = 30000,
|
||||
): Promise<void> => {
|
||||
return new Promise((resolve) => {
|
||||
const startTime = Date.now();
|
||||
|
||||
const checkInput = () => {
|
||||
const input = document.querySelector(selector) as HTMLInputElement;
|
||||
if (input) {
|
||||
const currentValue = input.value.toLowerCase().trim();
|
||||
const target = targetValue.toLowerCase().trim();
|
||||
|
||||
if (currentValue.includes(target) || target.includes(currentValue)) {
|
||||
if (currentValue.length >= 4 || currentValue === target) {
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (Date.now() - startTime > timeout) {
|
||||
resolve();
|
||||
} else {
|
||||
setTimeout(checkInput, TUTORIAL_CONFIG.INPUT_CHECK_INTERVAL);
|
||||
}
|
||||
};
|
||||
checkInput();
|
||||
});
|
||||
};
|
||||
|
||||
export const waitForSearchResult = (
|
||||
selector: string,
|
||||
timeout = 15000,
|
||||
): Promise<Element | null> => {
|
||||
return new Promise((resolve) => {
|
||||
const startTime = Date.now();
|
||||
|
||||
const checkResult = () => {
|
||||
const element = document.querySelector(selector);
|
||||
if (element) {
|
||||
resolve(element);
|
||||
} else if (Date.now() - startTime > timeout) {
|
||||
resolve(null);
|
||||
} else {
|
||||
setTimeout(checkResult, TUTORIAL_CONFIG.ELEMENT_CHECK_INTERVAL);
|
||||
}
|
||||
};
|
||||
checkResult();
|
||||
});
|
||||
};
|
||||
|
||||
export const waitForAnyBlockCard = (
|
||||
timeout = 10000,
|
||||
): Promise<Element | null> => {
|
||||
return new Promise((resolve) => {
|
||||
const startTime = Date.now();
|
||||
|
||||
const checkBlock = () => {
|
||||
const block = document.querySelector(
|
||||
TUTORIAL_SELECTORS.BLOCK_CARD_PREFIX,
|
||||
);
|
||||
if (block) {
|
||||
resolve(block);
|
||||
} else if (Date.now() - startTime > timeout) {
|
||||
resolve(null);
|
||||
} else {
|
||||
setTimeout(checkBlock, TUTORIAL_CONFIG.ELEMENT_CHECK_INTERVAL);
|
||||
}
|
||||
};
|
||||
checkBlock();
|
||||
});
|
||||
};
|
||||
|
||||
export const focusElement = (selector: string): void => {
|
||||
const element = document.querySelector(selector) as HTMLElement;
|
||||
if (element) {
|
||||
element.focus();
|
||||
}
|
||||
};
|
||||
|
||||
export const scrollIntoView = (selector: string): void => {
|
||||
const element = document.querySelector(selector);
|
||||
if (element) {
|
||||
element.scrollIntoView({
|
||||
behavior: "smooth",
|
||||
block: "center",
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
export const typeIntoInput = (selector: string, text: string) => {
|
||||
const input = document.querySelector(selector) as HTMLInputElement;
|
||||
if (input) {
|
||||
input.focus();
|
||||
input.value = text;
|
||||
input.dispatchEvent(new Event("input", { bubbles: true }));
|
||||
input.dispatchEvent(new Event("change", { bubbles: true }));
|
||||
}
|
||||
};
|
||||
|
||||
export const observeElement = (
|
||||
selector: string,
|
||||
callback: (element: Element) => void,
|
||||
): MutationObserver => {
|
||||
const observer = new MutationObserver((mutations, obs) => {
|
||||
const element = document.querySelector(selector);
|
||||
if (element) {
|
||||
callback(element);
|
||||
obs.disconnect();
|
||||
}
|
||||
});
|
||||
|
||||
observer.observe(document.body, {
|
||||
childList: true,
|
||||
subtree: true,
|
||||
});
|
||||
|
||||
const element = document.querySelector(selector);
|
||||
if (element) {
|
||||
callback(element);
|
||||
observer.disconnect();
|
||||
}
|
||||
|
||||
return observer;
|
||||
};
|
||||
|
||||
export const watchSearchInput = (
|
||||
targetValue: string,
|
||||
onMatch: () => void,
|
||||
): (() => void) => {
|
||||
const input = document.querySelector(
|
||||
TUTORIAL_SELECTORS.BLOCKS_SEARCH_INPUT,
|
||||
) as HTMLInputElement;
|
||||
if (!input) return () => {};
|
||||
|
||||
let hasMatched = false;
|
||||
|
||||
const handler = () => {
|
||||
if (hasMatched) return;
|
||||
|
||||
const currentValue = input.value.toLowerCase().trim();
|
||||
const target = targetValue.toLowerCase().trim();
|
||||
|
||||
if (currentValue.length >= 4 && target.startsWith(currentValue)) {
|
||||
hasMatched = true;
|
||||
onMatch();
|
||||
}
|
||||
};
|
||||
|
||||
input.addEventListener("input", handler);
|
||||
|
||||
return () => {
|
||||
input.removeEventListener("input", handler);
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,56 @@
|
||||
import { CSS_CLASSES, TUTORIAL_SELECTORS } from "../constants";
|
||||
|
||||
export const disableOtherBlocks = (targetBlockSelector: string) => {
|
||||
document
|
||||
.querySelectorAll(TUTORIAL_SELECTORS.BLOCK_CARD_PREFIX)
|
||||
.forEach((block) => {
|
||||
const isTarget = block.matches(targetBlockSelector);
|
||||
block.classList.toggle(CSS_CLASSES.DISABLE, !isTarget);
|
||||
block.classList.toggle(CSS_CLASSES.HIGHLIGHT, isTarget);
|
||||
});
|
||||
};
|
||||
|
||||
export const enableAllBlocks = () => {
|
||||
document
|
||||
.querySelectorAll(TUTORIAL_SELECTORS.BLOCK_CARD_PREFIX)
|
||||
.forEach((block) => {
|
||||
block.classList.remove(
|
||||
CSS_CLASSES.DISABLE,
|
||||
CSS_CLASSES.HIGHLIGHT,
|
||||
CSS_CLASSES.PULSE,
|
||||
);
|
||||
});
|
||||
};
|
||||
|
||||
export const highlightElement = (selector: string) => {
|
||||
const element = document.querySelector(selector);
|
||||
if (element) {
|
||||
element.classList.add(CSS_CLASSES.HIGHLIGHT);
|
||||
}
|
||||
};
|
||||
|
||||
export const removeAllHighlights = () => {
|
||||
document.querySelectorAll(`.${CSS_CLASSES.HIGHLIGHT}`).forEach((el) => {
|
||||
el.classList.remove(CSS_CLASSES.HIGHLIGHT);
|
||||
});
|
||||
document.querySelectorAll(`.${CSS_CLASSES.PULSE}`).forEach((el) => {
|
||||
el.classList.remove(CSS_CLASSES.PULSE);
|
||||
});
|
||||
};
|
||||
|
||||
export const pulseElement = (selector: string) => {
|
||||
const element = document.querySelector(selector);
|
||||
if (element) {
|
||||
element.classList.add(CSS_CLASSES.PULSE);
|
||||
}
|
||||
};
|
||||
|
||||
export const highlightFirstBlockInSearch = () => {
|
||||
const firstBlock = document.querySelector(
|
||||
TUTORIAL_SELECTORS.BLOCK_CARD_PREFIX,
|
||||
);
|
||||
if (firstBlock) {
|
||||
firstBlock.classList.add(CSS_CLASSES.PULSE);
|
||||
firstBlock.scrollIntoView({ behavior: "smooth", block: "center" });
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,66 @@
|
||||
export {
|
||||
waitForElement,
|
||||
waitForInputValue,
|
||||
waitForSearchResult,
|
||||
waitForAnyBlockCard,
|
||||
focusElement,
|
||||
scrollIntoView,
|
||||
typeIntoInput,
|
||||
observeElement,
|
||||
watchSearchInput,
|
||||
} from "./dom";
|
||||
|
||||
export {
|
||||
disableOtherBlocks,
|
||||
enableAllBlocks,
|
||||
highlightElement,
|
||||
removeAllHighlights,
|
||||
pulseElement,
|
||||
highlightFirstBlockInSearch,
|
||||
} from "./highlights";
|
||||
|
||||
export {
|
||||
prefetchTutorialBlocks,
|
||||
getPrefetchedBlock,
|
||||
clearPrefetchedBlocks,
|
||||
addPrefetchedBlock,
|
||||
getNodeByBlockId,
|
||||
addSecondCalculatorBlock,
|
||||
getCalculatorNodes,
|
||||
getSecondCalculatorNode,
|
||||
getFormContainerSelector,
|
||||
getFormContainerElement,
|
||||
} from "./blocks";
|
||||
|
||||
export {
|
||||
waitForNodeOnCanvas,
|
||||
waitForNodesCount,
|
||||
getNodesCount,
|
||||
getFirstNode,
|
||||
getNodeById,
|
||||
nodeHasValues,
|
||||
fitViewToScreen,
|
||||
} from "./canvas";
|
||||
|
||||
export { isConnectionMade } from "./connections";
|
||||
|
||||
export {
|
||||
forceBlockMenuOpen,
|
||||
openBlockMenu,
|
||||
closeBlockMenu,
|
||||
clearBlockMenuSearch,
|
||||
} from "./menu";
|
||||
|
||||
export {
|
||||
openSaveControl,
|
||||
closeSaveControl,
|
||||
forceSaveOpen,
|
||||
clickSaveButton,
|
||||
isAgentSaved,
|
||||
} from "./save";
|
||||
|
||||
export {
|
||||
handleTutorialCancel,
|
||||
handleTutorialSkip,
|
||||
handleTutorialComplete,
|
||||
} from "./state";
|
||||
@@ -0,0 +1,25 @@
|
||||
import { TUTORIAL_SELECTORS } from "../constants";
|
||||
import { useControlPanelStore } from "../../../../stores/controlPanelStore";
|
||||
|
||||
export const forceBlockMenuOpen = (force: boolean) => {
|
||||
useControlPanelStore.getState().setForceOpenBlockMenu(force);
|
||||
};
|
||||
|
||||
export const openBlockMenu = () => {
|
||||
useControlPanelStore.getState().setBlockMenuOpen(true);
|
||||
};
|
||||
|
||||
export const closeBlockMenu = () => {
|
||||
useControlPanelStore.getState().setBlockMenuOpen(false);
|
||||
useControlPanelStore.getState().setForceOpenBlockMenu(false);
|
||||
};
|
||||
|
||||
export const clearBlockMenuSearch = () => {
|
||||
const input = document.querySelector(
|
||||
TUTORIAL_SELECTORS.BLOCKS_SEARCH_INPUT,
|
||||
) as HTMLInputElement;
|
||||
if (input) {
|
||||
input.value = "";
|
||||
input.dispatchEvent(new Event("input", { bubbles: true }));
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,31 @@
|
||||
import { TUTORIAL_SELECTORS } from "../constants";
|
||||
import { useControlPanelStore } from "../../../../stores/controlPanelStore";
|
||||
|
||||
export const openSaveControl = () => {
|
||||
useControlPanelStore.getState().setSaveControlOpen(true);
|
||||
};
|
||||
|
||||
export const closeSaveControl = () => {
|
||||
useControlPanelStore.getState().setSaveControlOpen(false);
|
||||
useControlPanelStore.getState().setForceOpenSave(false);
|
||||
};
|
||||
|
||||
export const forceSaveOpen = (force: boolean) => {
|
||||
useControlPanelStore.getState().setForceOpenSave(force);
|
||||
};
|
||||
|
||||
export const clickSaveButton = () => {
|
||||
const saveButton = document.querySelector(
|
||||
TUTORIAL_SELECTORS.SAVE_AGENT_BUTTON,
|
||||
) as HTMLButtonElement;
|
||||
if (saveButton && !saveButton.disabled) {
|
||||
saveButton.click();
|
||||
}
|
||||
};
|
||||
|
||||
export const isAgentSaved = (): boolean => {
|
||||
const versionInput = document.querySelector(
|
||||
'[data-tutorial-id="save-control-version-output"]',
|
||||
) as HTMLInputElement;
|
||||
return !!(versionInput && versionInput.value && versionInput.value !== "-");
|
||||
};
|
||||
@@ -0,0 +1,49 @@
|
||||
import { Key, storage } from "@/services/storage/local-storage";
|
||||
import { closeBlockMenu } from "./menu";
|
||||
import { closeSaveControl, forceSaveOpen } from "./save";
|
||||
import { removeAllHighlights, enableAllBlocks } from "./highlights";
|
||||
|
||||
const clearTutorialIntervals = () => {
|
||||
const intervalKeys = [
|
||||
"__tutorialCalcInterval",
|
||||
"__tutorialCheckInterval",
|
||||
"__tutorialSecondCalcInterval",
|
||||
];
|
||||
|
||||
intervalKeys.forEach((key) => {
|
||||
if ((window as any)[key]) {
|
||||
clearInterval((window as any)[key]);
|
||||
delete (window as any)[key];
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
export const handleTutorialCancel = (_tour?: any) => {
|
||||
clearTutorialIntervals();
|
||||
closeBlockMenu();
|
||||
closeSaveControl();
|
||||
forceSaveOpen(false);
|
||||
removeAllHighlights();
|
||||
enableAllBlocks();
|
||||
storage.set(Key.SHEPHERD_TOUR, "canceled");
|
||||
};
|
||||
|
||||
export const handleTutorialSkip = (_tour?: any) => {
|
||||
clearTutorialIntervals();
|
||||
closeBlockMenu();
|
||||
closeSaveControl();
|
||||
forceSaveOpen(false);
|
||||
removeAllHighlights();
|
||||
enableAllBlocks();
|
||||
storage.set(Key.SHEPHERD_TOUR, "skipped");
|
||||
};
|
||||
|
||||
export const handleTutorialComplete = () => {
|
||||
clearTutorialIntervals();
|
||||
closeBlockMenu();
|
||||
closeSaveControl();
|
||||
forceSaveOpen(false);
|
||||
removeAllHighlights();
|
||||
enableAllBlocks();
|
||||
storage.set(Key.SHEPHERD_TOUR, "completed");
|
||||
};
|
||||
@@ -0,0 +1,7 @@
|
||||
// These are SVG Phosphor icons
|
||||
|
||||
export const ICONS = {
|
||||
ClickIcon: `<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" fill="#000000" viewBox="0 0 256 256"><path d="M88,24V16a8,8,0,0,1,16,0v8a8,8,0,0,1-16,0ZM16,104h8a8,8,0,0,0,0-16H16a8,8,0,0,0,0,16ZM124.42,39.16a8,8,0,0,0,10.74-3.58l8-16a8,8,0,0,0-14.31-7.16l-8,16A8,8,0,0,0,124.42,39.16Zm-96,81.69-16,8a8,8,0,0,0,7.16,14.31l16-8a8,8,0,1,0-7.16-14.31ZM219.31,184a16,16,0,0,1,0,22.63l-12.68,12.68a16,16,0,0,1-22.63,0L132.7,168,115,214.09c0,.1-.08.21-.13.32a15.83,15.83,0,0,1-14.6,9.59l-.79,0a15.83,15.83,0,0,1-14.41-11L32.8,52.92A16,16,0,0,1,52.92,32.8L213,85.07a16,16,0,0,1,1.41,29.8l-.32.13L168,132.69ZM208,195.31,156.69,144h0a16,16,0,0,1,4.93-26l.32-.14,45.95-17.64L48,48l52.2,159.86,17.65-46c0-.11.08-.22.13-.33a16,16,0,0,1,11.69-9.34,16.72,16.72,0,0,1,3-.28,16,16,0,0,1,11.3,4.69L195.31,208Z"></path></svg>`,
|
||||
Keyboard: `<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" fill="#000000" viewBox="0 0 256 256"><path d="M224,48H32A16,16,0,0,0,16,64V192a16,16,0,0,0,16,16H224a16,16,0,0,0,16-16V64A16,16,0,0,0,224,48Zm0,144H32V64H224V192Zm-16-64a8,8,0,0,1-8,8H56a8,8,0,0,1,0-16H200A8,8,0,0,1,208,128Zm0-32a8,8,0,0,1-8,8H56a8,8,0,0,1,0-16H200A8,8,0,0,1,208,96ZM72,160a8,8,0,0,1-8,8H56a8,8,0,0,1,0-16h8A8,8,0,0,1,72,160Zm96,0a8,8,0,0,1-8,8H96a8,8,0,0,1,0-16h64A8,8,0,0,1,168,160Zm40,0a8,8,0,0,1-8,8h-8a8,8,0,0,1,0-16h8A8,8,0,0,1,208,160Z"></path></svg>`,
|
||||
Drag: `<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" fill="#000000" viewBox="0 0 256 256"><path d="M188,80a27.79,27.79,0,0,0-13.36,3.4,28,28,0,0,0-46.64-11A28,28,0,0,0,80,92v20H68a28,28,0,0,0-28,28v12a88,88,0,0,0,176,0V108A28,28,0,0,0,188,80Zm12,72a72,72,0,0,1-144,0V140a12,12,0,0,1,12-12H80v24a8,8,0,0,0,16,0V92a12,12,0,0,1,24,0v28a8,8,0,0,0,16,0V92a12,12,0,0,1,24,0v28a8,8,0,0,0,16,0V108a12,12,0,0,1,24,0Z"></path></svg>`,
|
||||
};
|
||||
@@ -0,0 +1,81 @@
|
||||
import Shepherd from "shepherd.js";
|
||||
import { analytics } from "@/services/analytics";
|
||||
import { TUTORIAL_CONFIG } from "./constants";
|
||||
import { createTutorialSteps } from "./steps";
|
||||
import { injectTutorialStyles, removeTutorialStyles } from "./styles";
|
||||
import {
|
||||
handleTutorialComplete,
|
||||
handleTutorialCancel,
|
||||
prefetchTutorialBlocks,
|
||||
clearPrefetchedBlocks,
|
||||
} from "./helpers";
|
||||
import { useNodeStore } from "../../../stores/nodeStore";
|
||||
import { useEdgeStore } from "../../../stores/edgeStore";
|
||||
|
||||
let isTutorialLoading = false;
|
||||
let tutorialLoadingCallback: ((loading: boolean) => void) | null = null;
|
||||
|
||||
export const setTutorialLoadingCallback = (
|
||||
callback: (loading: boolean) => void,
|
||||
) => {
|
||||
tutorialLoadingCallback = callback;
|
||||
};
|
||||
|
||||
export const getTutorialLoadingState = () => isTutorialLoading;
|
||||
|
||||
export const startTutorial = async () => {
|
||||
isTutorialLoading = true;
|
||||
tutorialLoadingCallback?.(true);
|
||||
|
||||
useNodeStore.getState().setNodes([]);
|
||||
useEdgeStore.getState().setEdges([]);
|
||||
useNodeStore.getState().setNodeCounter(0);
|
||||
|
||||
try {
|
||||
await prefetchTutorialBlocks();
|
||||
} finally {
|
||||
isTutorialLoading = false;
|
||||
tutorialLoadingCallback?.(false);
|
||||
}
|
||||
|
||||
const tour = new Shepherd.Tour({
|
||||
useModalOverlay: TUTORIAL_CONFIG.USE_MODAL_OVERLAY,
|
||||
defaultStepOptions: {
|
||||
cancelIcon: { enabled: true },
|
||||
scrollTo: {
|
||||
behavior: TUTORIAL_CONFIG.SCROLL_BEHAVIOR,
|
||||
block: TUTORIAL_CONFIG.SCROLL_BLOCK,
|
||||
},
|
||||
classes: "new-builder-tour",
|
||||
modalOverlayOpeningRadius: 4,
|
||||
},
|
||||
});
|
||||
|
||||
injectTutorialStyles();
|
||||
|
||||
const steps = createTutorialSteps(tour);
|
||||
steps.forEach((step) => tour.addStep(step));
|
||||
|
||||
tour.on("complete", () => {
|
||||
handleTutorialComplete();
|
||||
removeTutorialStyles();
|
||||
clearPrefetchedBlocks();
|
||||
});
|
||||
|
||||
tour.on("cancel", () => {
|
||||
handleTutorialCancel(tour);
|
||||
removeTutorialStyles();
|
||||
clearPrefetchedBlocks();
|
||||
});
|
||||
|
||||
for (const step of tour.steps) {
|
||||
step.on("show", () => {
|
||||
console.debug("sendTutorialStep", step.id);
|
||||
analytics.sendGAEvent("event", "tutorial_step_shown", {
|
||||
value: step.id,
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
tour.start();
|
||||
};
|
||||
@@ -0,0 +1,114 @@
|
||||
import { StepOptions } from "shepherd.js";
|
||||
import { TUTORIAL_SELECTORS } from "../constants";
|
||||
import {
|
||||
waitForElement,
|
||||
waitForNodeOnCanvas,
|
||||
closeBlockMenu,
|
||||
fitViewToScreen,
|
||||
highlightElement,
|
||||
removeAllHighlights,
|
||||
} from "../helpers";
|
||||
import { ICONS } from "../icons";
|
||||
import { banner } from "../styles";
|
||||
|
||||
export const createBlockBasicsSteps = (tour: any): StepOptions[] => [
|
||||
{
|
||||
id: "focus-new-block",
|
||||
title: "Your First Block!",
|
||||
text: `
|
||||
<div class="text-sm leading-[1.375rem] text-zinc-800">
|
||||
<p class="text-sm font-normal leading-[1.375rem] text-zinc-800 m-0">Excellent! This is your <strong>Calculator Block</strong>.</p>
|
||||
<p class="text-sm font-normal leading-[1.375rem] text-zinc-800 m-0" style="margin-top: 0.5rem;">Let's explore how blocks work.</p>
|
||||
</div>
|
||||
`,
|
||||
attachTo: {
|
||||
element: TUTORIAL_SELECTORS.REACT_FLOW_NODE,
|
||||
on: "right",
|
||||
},
|
||||
beforeShowPromise: async () => {
|
||||
closeBlockMenu();
|
||||
await waitForNodeOnCanvas(5000);
|
||||
await new Promise((resolve) => setTimeout(resolve, 300));
|
||||
fitViewToScreen();
|
||||
},
|
||||
when: {
|
||||
show: () => {
|
||||
const node = document.querySelector(TUTORIAL_SELECTORS.REACT_FLOW_NODE);
|
||||
if (node) {
|
||||
highlightElement(TUTORIAL_SELECTORS.REACT_FLOW_NODE);
|
||||
}
|
||||
},
|
||||
hide: () => {
|
||||
removeAllHighlights();
|
||||
},
|
||||
},
|
||||
buttons: [
|
||||
{
|
||||
text: "Show me",
|
||||
action: () => tour.next(),
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
{
|
||||
id: "input-handles",
|
||||
title: "Input Handles",
|
||||
text: `
|
||||
<div class="text-sm leading-[1.375rem] text-zinc-800">
|
||||
<p class="text-sm font-normal leading-[1.375rem] text-zinc-800 m-0">On the <strong>left side</strong> of the block are <strong>input handles</strong>.</p>
|
||||
<p class="text-sm font-normal leading-[1.375rem] text-zinc-800 m-0" style="margin-top: 0.5rem;">These are where data flows <em>into</em> the block from other blocks.</p>
|
||||
</div>
|
||||
`,
|
||||
attachTo: {
|
||||
element: TUTORIAL_SELECTORS.NODE_INPUT_HANDLE,
|
||||
on: "bottom",
|
||||
},
|
||||
classes: "new-builder-tour input-handles-step",
|
||||
beforeShowPromise: () =>
|
||||
waitForElement(TUTORIAL_SELECTORS.NODE_INPUT_HANDLE, 3000).catch(
|
||||
() => {},
|
||||
),
|
||||
buttons: [
|
||||
{
|
||||
text: "Back",
|
||||
action: () => tour.back(),
|
||||
classes: "shepherd-button-secondary",
|
||||
},
|
||||
{
|
||||
text: "Next",
|
||||
action: () => tour.next(),
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
{
|
||||
id: "output-handles",
|
||||
title: "Output Handles",
|
||||
text: `
|
||||
<div class="text-sm leading-[1.375rem] text-zinc-800">
|
||||
<p class="text-sm font-normal leading-[1.375rem] text-zinc-800 m-0">On the <strong>right side</strong> is the <strong>output handle</strong>.</p>
|
||||
<p class="text-sm font-normal leading-[1.375rem] text-zinc-800 m-0" style="margin-top: 0.5rem;">This is where the result flows <em>out</em> to connect to other blocks.</p>
|
||||
${banner(ICONS.Drag, "You can drag from output to input handler to connect blocks", "info")}
|
||||
</div>
|
||||
`,
|
||||
attachTo: {
|
||||
element: TUTORIAL_SELECTORS.NODE_OUTPUT_HANDLE,
|
||||
on: "right",
|
||||
},
|
||||
beforeShowPromise: () =>
|
||||
waitForElement(TUTORIAL_SELECTORS.NODE_OUTPUT_HANDLE, 3000).catch(
|
||||
() => {},
|
||||
),
|
||||
buttons: [
|
||||
{
|
||||
text: "Back",
|
||||
action: () => tour.back(),
|
||||
classes: "shepherd-button-secondary",
|
||||
},
|
||||
{
|
||||
text: "Next →",
|
||||
action: () => tour.next(),
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
@@ -0,0 +1,198 @@
|
||||
import { StepOptions } from "shepherd.js";
|
||||
import { TUTORIAL_CONFIG, TUTORIAL_SELECTORS, BLOCK_IDS } from "../constants";
|
||||
import {
|
||||
waitForElement,
|
||||
forceBlockMenuOpen,
|
||||
focusElement,
|
||||
highlightElement,
|
||||
removeAllHighlights,
|
||||
disableOtherBlocks,
|
||||
enableAllBlocks,
|
||||
pulseElement,
|
||||
highlightFirstBlockInSearch,
|
||||
} from "../helpers";
|
||||
import { ICONS } from "../icons";
|
||||
import { banner } from "../styles";
|
||||
import { useNodeStore } from "../../../../stores/nodeStore";
|
||||
|
||||
export const createBlockMenuSteps = (tour: any): StepOptions[] => [
|
||||
{
|
||||
id: "open-block-menu",
|
||||
title: "Open the Block Menu",
|
||||
text: `
|
||||
<div class="text-sm leading-[1.375rem] text-zinc-800">
|
||||
<p class="text-sm font-normal leading-[1.375rem] text-zinc-800 m-0">Let's start by opening the Block Menu.</p>
|
||||
${banner(ICONS.ClickIcon, "Click this button to open the menu", "action")}
|
||||
</div>
|
||||
`,
|
||||
attachTo: {
|
||||
element: TUTORIAL_SELECTORS.BLOCKS_TRIGGER,
|
||||
on: "right",
|
||||
},
|
||||
advanceOn: {
|
||||
selector: TUTORIAL_SELECTORS.BLOCKS_TRIGGER,
|
||||
event: "click",
|
||||
},
|
||||
buttons: [],
|
||||
when: {
|
||||
show: () => {
|
||||
highlightElement(TUTORIAL_SELECTORS.BLOCKS_TRIGGER);
|
||||
},
|
||||
hide: () => {
|
||||
removeAllHighlights();
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
id: "block-menu-overview",
|
||||
title: "The Block Menu",
|
||||
text: `
|
||||
<div class="text-sm leading-[1.375rem] text-zinc-800">
|
||||
<p class="text-sm font-normal leading-[1.375rem] text-zinc-800 m-0">This is the <strong>Block Menu</strong> — your toolbox for building agents.</p>
|
||||
<p class="text-sm font-medium leading-[1.375rem] text-zinc-800 m-0" style="margin-top: 0.5rem;">Here you'll find:</p>
|
||||
<ul>
|
||||
<li><strong>Input Blocks</strong> — Entry points for data</li>
|
||||
<li><strong>Action Blocks</strong> — Processing and AI operations</li>
|
||||
<li><strong>Output Blocks</strong> — Results and responses</li>
|
||||
<li><strong>Integrations</strong> — Third-party service blocks</li>
|
||||
<li><strong>Library Agents</strong> — Your personal agents</li>
|
||||
<li><strong>Marketplace Agents</strong> — Community agents</li>
|
||||
</ul>
|
||||
</div>
|
||||
`,
|
||||
attachTo: {
|
||||
element: TUTORIAL_SELECTORS.BLOCKS_CONTENT,
|
||||
on: "left",
|
||||
},
|
||||
beforeShowPromise: () => waitForElement(TUTORIAL_SELECTORS.BLOCKS_CONTENT),
|
||||
when: {
|
||||
show: () => forceBlockMenuOpen(true),
|
||||
},
|
||||
buttons: [
|
||||
{
|
||||
text: "Next",
|
||||
action: () => tour.next(),
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
{
|
||||
id: "search-calculator",
|
||||
title: "Search for a Block",
|
||||
text: `
|
||||
<div class="text-sm leading-[1.375rem] text-zinc-800">
|
||||
<p class="text-sm font-normal leading-[1.375rem] text-zinc-800 m-0">Let's add a Calculator block to start.</p>
|
||||
${banner(ICONS.Keyboard, "Type Calculator in the search bar", "action")}
|
||||
<p class="text-xs font-normal leading-[1.125rem] text-zinc-500 m-0" style="margin-top: 0.5rem;">The search will filter blocks as you type.</p>
|
||||
</div>
|
||||
`,
|
||||
attachTo: {
|
||||
element: TUTORIAL_SELECTORS.BLOCKS_SEARCH_INPUT_BOX,
|
||||
on: "bottom",
|
||||
},
|
||||
beforeShowPromise: () =>
|
||||
waitForElement(TUTORIAL_SELECTORS.BLOCKS_SEARCH_INPUT_BOX),
|
||||
when: {
|
||||
show: () => {
|
||||
forceBlockMenuOpen(true);
|
||||
setTimeout(() => {
|
||||
focusElement(TUTORIAL_SELECTORS.BLOCKS_SEARCH_INPUT_BOX);
|
||||
}, 100);
|
||||
|
||||
const checkForCalculator = setInterval(() => {
|
||||
const calcBlock = document.querySelector(
|
||||
TUTORIAL_SELECTORS.BLOCK_CARD_CALCULATOR_IN_SEARCH,
|
||||
);
|
||||
if (calcBlock) {
|
||||
clearInterval(checkForCalculator);
|
||||
|
||||
const searchInput = document.querySelector(
|
||||
TUTORIAL_SELECTORS.BLOCKS_SEARCH_INPUT,
|
||||
) as HTMLInputElement;
|
||||
if (searchInput) {
|
||||
searchInput.blur();
|
||||
}
|
||||
|
||||
disableOtherBlocks(
|
||||
TUTORIAL_SELECTORS.BLOCK_CARD_CALCULATOR_IN_SEARCH,
|
||||
);
|
||||
pulseElement(TUTORIAL_SELECTORS.BLOCK_CARD_CALCULATOR_IN_SEARCH);
|
||||
calcBlock.scrollIntoView({ behavior: "smooth", block: "center" });
|
||||
setTimeout(() => {
|
||||
tour.next();
|
||||
}, 300);
|
||||
}
|
||||
}, TUTORIAL_CONFIG.ELEMENT_CHECK_INTERVAL);
|
||||
|
||||
(window as any).__tutorialCalcInterval = checkForCalculator;
|
||||
},
|
||||
hide: () => {
|
||||
if ((window as any).__tutorialCalcInterval) {
|
||||
clearInterval((window as any).__tutorialCalcInterval);
|
||||
delete (window as any).__tutorialCalcInterval;
|
||||
}
|
||||
enableAllBlocks();
|
||||
},
|
||||
},
|
||||
buttons: [],
|
||||
},
|
||||
|
||||
{
|
||||
id: "select-calculator",
|
||||
title: "Add the Calculator Block",
|
||||
text: `
|
||||
<div class="text-sm leading-[1.375rem] text-zinc-800">
|
||||
<p class="text-sm font-normal leading-[1.375rem] text-zinc-800 m-0">You should see the <strong>Calculator</strong> block in the results.</p>
|
||||
${banner(ICONS.ClickIcon, "Click on the Calculator block to add it", "action")}
|
||||
|
||||
<div class="bg-zinc-100 ring-1 ring-zinc-200 rounded-2xl p-2 px-4 mt-2 flex items-start gap-2 text-sm font-medium text-zinc-600">
|
||||
<span class="flex-shrink-0">${ICONS.Drag}</span>
|
||||
<span>You can also drag blocks onto the canvas</span>
|
||||
</div>
|
||||
</div>
|
||||
`,
|
||||
attachTo: {
|
||||
element: TUTORIAL_SELECTORS.BLOCK_CARD_CALCULATOR,
|
||||
on: "left",
|
||||
},
|
||||
beforeShowPromise: async () => {
|
||||
forceBlockMenuOpen(true);
|
||||
await waitForElement(TUTORIAL_SELECTORS.BLOCK_CARD_CALCULATOR, 5000);
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
},
|
||||
when: {
|
||||
show: () => {
|
||||
const calcBlock = document.querySelector(
|
||||
TUTORIAL_SELECTORS.BLOCK_CARD_CALCULATOR,
|
||||
);
|
||||
if (calcBlock) {
|
||||
disableOtherBlocks(TUTORIAL_SELECTORS.BLOCK_CARD_CALCULATOR);
|
||||
} else {
|
||||
highlightFirstBlockInSearch();
|
||||
}
|
||||
|
||||
const CALCULATOR_BLOCK_ID = BLOCK_IDS.CALCULATOR;
|
||||
|
||||
const initialNodeCount = useNodeStore.getState().nodes.length;
|
||||
|
||||
const unsubscribe = useNodeStore.subscribe((state) => {
|
||||
if (state.nodes.length > initialNodeCount) {
|
||||
const calculatorNode = state.nodes.find(
|
||||
(node) => node.data?.block_id === CALCULATOR_BLOCK_ID,
|
||||
);
|
||||
|
||||
if (calculatorNode) {
|
||||
unsubscribe();
|
||||
enableAllBlocks();
|
||||
forceBlockMenuOpen(false);
|
||||
tour.next();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
(tour.getCurrentStep() as any)._nodeUnsubscribe = unsubscribe;
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
@@ -0,0 +1,51 @@
|
||||
import { StepOptions } from "shepherd.js";
|
||||
|
||||
export const createCompletionSteps = (tour: any): StepOptions[] => [
|
||||
{
|
||||
id: "congratulations",
|
||||
title: "Congratulations! 🎉",
|
||||
text: `
|
||||
<div class="text-sm leading-[1.375rem] text-zinc-800">
|
||||
<p class="text-sm font-normal leading-[1.375rem] text-zinc-800 m-0">You have successfully created and run your first agent flow!</p>
|
||||
|
||||
<div class="mt-3 p-3 bg-green-50 ring-1 ring-green-200 rounded-2xl">
|
||||
<p class="text-sm font-medium text-green-600 m-0">You learned how to:</p>
|
||||
<ul class="text-[0.8125rem] text-green-600 m-0 pl-4 mt-2 space-y-1">
|
||||
<li>• Add blocks from the Block Menu</li>
|
||||
<li>• Understand input and output handles</li>
|
||||
<li>• Configure block values</li>
|
||||
<li>• Connect blocks together</li>
|
||||
<li>• Save and run your agent</li>
|
||||
<li>• View execution status and output</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<p class="text-sm font-medium leading-[1.375rem] text-zinc-800 m-0" style="margin-top: 0.75rem;">Happy building! 🚀</p>
|
||||
</div>
|
||||
`,
|
||||
when: {
|
||||
show: () => {
|
||||
const modal = document.querySelector(
|
||||
".shepherd-modal-overlay-container",
|
||||
);
|
||||
if (modal) {
|
||||
(modal as HTMLElement).style.opacity = "0.3";
|
||||
}
|
||||
},
|
||||
},
|
||||
buttons: [
|
||||
{
|
||||
text: "Restart Tutorial",
|
||||
action: () => {
|
||||
tour.cancel();
|
||||
setTimeout(() => tour.start(), 100);
|
||||
},
|
||||
classes: "shepherd-button-secondary",
|
||||
},
|
||||
{
|
||||
text: "Finish",
|
||||
action: () => tour.complete(),
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
@@ -0,0 +1,197 @@
|
||||
import { StepOptions } from "shepherd.js";
|
||||
import { TUTORIAL_SELECTORS } from "../constants";
|
||||
import {
|
||||
fitViewToScreen,
|
||||
highlightElement,
|
||||
removeAllHighlights,
|
||||
getFirstNode,
|
||||
} from "../helpers";
|
||||
import { ICONS } from "../icons";
|
||||
import { banner } from "../styles";
|
||||
|
||||
const getRequirementsHtml = () => `
|
||||
<div id="requirements-box" class="mt-3 p-3 bg-amber-50 ring-1 ring-amber-200 rounded-2xl">
|
||||
<p id="requirements-title" class="text-sm font-medium text-amber-600 m-0 mb-2">⚠️ Required to continue:</p>
|
||||
<ul id="requirements-list" class="text-[0.8125rem] text-amber-600 m-0 pl-4 space-y-1">
|
||||
<li id="req-a" class="flex items-center gap-2">
|
||||
<span class="req-icon">○</span> Enter a number in field <strong>A</strong> (e.g., 10)
|
||||
</li>
|
||||
<li id="req-b" class="flex items-center gap-2">
|
||||
<span class="req-icon">○</span> Enter a number in field <strong>B</strong> (e.g., 5)
|
||||
</li>
|
||||
<li id="req-op" class="flex items-center gap-2">
|
||||
<span class="req-icon">○</span> Select an <strong>Operation</strong> (Add, Multiply, etc.)
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
`;
|
||||
|
||||
const updateToSuccessState = () => {
|
||||
const reqBox = document.querySelector("#requirements-box");
|
||||
const reqTitle = document.querySelector("#requirements-title");
|
||||
const reqList = document.querySelector("#requirements-list");
|
||||
|
||||
if (reqBox && reqTitle) {
|
||||
reqBox.classList.remove("bg-amber-50", "ring-amber-200");
|
||||
reqBox.classList.add("bg-green-50", "ring-green-200");
|
||||
reqTitle.classList.remove("text-amber-600");
|
||||
reqTitle.classList.add("text-green-600");
|
||||
reqTitle.innerHTML = "🎉 Hurray! All values are completed!";
|
||||
if (reqList) {
|
||||
reqList.classList.add("hidden");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const updateToWarningState = () => {
|
||||
const reqBox = document.querySelector("#requirements-box");
|
||||
const reqTitle = document.querySelector("#requirements-title");
|
||||
const reqList = document.querySelector("#requirements-list");
|
||||
|
||||
if (reqBox && reqTitle) {
|
||||
reqBox.classList.remove("bg-green-50", "ring-green-200");
|
||||
reqBox.classList.add("bg-amber-50", "ring-amber-200");
|
||||
reqTitle.classList.remove("text-green-600");
|
||||
reqTitle.classList.add("text-amber-600");
|
||||
reqTitle.innerHTML = "⚠️ Required to continue:";
|
||||
if (reqList) {
|
||||
reqList.classList.remove("hidden");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
export const createConfigureCalculatorSteps = (tour: any): StepOptions[] => [
|
||||
{
|
||||
id: "enter-values",
|
||||
title: "Enter Values",
|
||||
text: `
|
||||
<div class="text-sm leading-[1.375rem] text-zinc-800">
|
||||
<p class="text-sm font-normal leading-[1.375rem] text-zinc-800 m-0">Now let's configure the block with actual values.</p>
|
||||
${getRequirementsHtml()}
|
||||
${banner(ICONS.ClickIcon, "Fill in all the required fields above", "action")}
|
||||
</div>
|
||||
`,
|
||||
beforeShowPromise: () => {
|
||||
fitViewToScreen();
|
||||
return Promise.resolve();
|
||||
},
|
||||
attachTo: {
|
||||
element: TUTORIAL_SELECTORS.CALCULATOR_NODE_FORM_CONTAINER,
|
||||
on: "right",
|
||||
},
|
||||
when: {
|
||||
show: () => {
|
||||
const node = getFirstNode();
|
||||
if (node) {
|
||||
highlightElement(`[data-id="custom-node-${node.id}"]`);
|
||||
}
|
||||
|
||||
let wasComplete = false;
|
||||
|
||||
const checkInterval = setInterval(() => {
|
||||
const node = getFirstNode();
|
||||
if (!node) return;
|
||||
|
||||
const hardcodedValues = node.data?.hardcodedValues || {};
|
||||
const hasA =
|
||||
hardcodedValues.a !== undefined &&
|
||||
hardcodedValues.a !== null &&
|
||||
hardcodedValues.a !== "";
|
||||
const hasB =
|
||||
hardcodedValues.b !== undefined &&
|
||||
hardcodedValues.b !== null &&
|
||||
hardcodedValues.b !== "";
|
||||
const hasOp =
|
||||
hardcodedValues.operation !== undefined &&
|
||||
hardcodedValues.operation !== null &&
|
||||
hardcodedValues.operation !== "";
|
||||
|
||||
const allComplete = hasA && hasB && hasOp;
|
||||
|
||||
const reqA = document.querySelector("#req-a .req-icon");
|
||||
const reqB = document.querySelector("#req-b .req-icon");
|
||||
const reqOp = document.querySelector("#req-op .req-icon");
|
||||
|
||||
if (reqA) reqA.textContent = hasA ? "✓" : "○";
|
||||
if (reqB) reqB.textContent = hasB ? "✓" : "○";
|
||||
if (reqOp) reqOp.textContent = hasOp ? "✓" : "○";
|
||||
|
||||
const reqAEl = document.querySelector("#req-a");
|
||||
const reqBEl = document.querySelector("#req-b");
|
||||
const reqOpEl = document.querySelector("#req-op");
|
||||
|
||||
if (reqAEl) {
|
||||
reqAEl.classList.toggle("text-green-600", hasA);
|
||||
reqAEl.classList.toggle("text-amber-600", !hasA);
|
||||
}
|
||||
if (reqBEl) {
|
||||
reqBEl.classList.toggle("text-green-600", hasB);
|
||||
reqBEl.classList.toggle("text-amber-600", !hasB);
|
||||
}
|
||||
if (reqOpEl) {
|
||||
reqOpEl.classList.toggle("text-green-600", hasOp);
|
||||
reqOpEl.classList.toggle("text-amber-600", !hasOp);
|
||||
}
|
||||
|
||||
if (allComplete && !wasComplete) {
|
||||
updateToSuccessState();
|
||||
wasComplete = true;
|
||||
} else if (!allComplete && wasComplete) {
|
||||
updateToWarningState();
|
||||
wasComplete = false;
|
||||
}
|
||||
|
||||
const nextBtn = document.querySelector(
|
||||
".shepherd-button-primary",
|
||||
) as HTMLButtonElement;
|
||||
if (nextBtn) {
|
||||
nextBtn.style.opacity = allComplete ? "1" : "0.5";
|
||||
nextBtn.style.pointerEvents = allComplete ? "auto" : "none";
|
||||
}
|
||||
}, 300);
|
||||
|
||||
(window as any).__tutorialCheckInterval = checkInterval;
|
||||
},
|
||||
hide: () => {
|
||||
removeAllHighlights();
|
||||
if ((window as any).__tutorialCheckInterval) {
|
||||
clearInterval((window as any).__tutorialCheckInterval);
|
||||
delete (window as any).__tutorialCheckInterval;
|
||||
}
|
||||
},
|
||||
},
|
||||
buttons: [
|
||||
{
|
||||
text: "Back",
|
||||
action: () => tour.back(),
|
||||
classes: "shepherd-button-secondary",
|
||||
},
|
||||
{
|
||||
text: "Continue",
|
||||
action: () => {
|
||||
const node = getFirstNode();
|
||||
if (!node) return;
|
||||
|
||||
const hardcodedValues = node.data?.hardcodedValues || {};
|
||||
const hasA =
|
||||
hardcodedValues.a !== undefined &&
|
||||
hardcodedValues.a !== null &&
|
||||
hardcodedValues.a !== "";
|
||||
const hasB =
|
||||
hardcodedValues.b !== undefined &&
|
||||
hardcodedValues.b !== null &&
|
||||
hardcodedValues.b !== "";
|
||||
const hasOp =
|
||||
hardcodedValues.operation !== undefined &&
|
||||
hardcodedValues.operation !== null &&
|
||||
hardcodedValues.operation !== "";
|
||||
|
||||
if (hasA && hasB && hasOp) {
|
||||
tour.next();
|
||||
}
|
||||
},
|
||||
classes: "shepherd-button-primary",
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
@@ -0,0 +1,276 @@
|
||||
import { StepOptions } from "shepherd.js";
|
||||
import {
|
||||
fitViewToScreen,
|
||||
highlightElement,
|
||||
removeAllHighlights,
|
||||
} from "../helpers";
|
||||
import { ICONS } from "../icons";
|
||||
import { banner } from "../styles";
|
||||
import { useEdgeStore } from "../../../../stores/edgeStore";
|
||||
import { TUTORIAL_SELECTORS } from "../constants";
|
||||
|
||||
const getConnectionStatusHtml = (id: string, isConnected: boolean = false) => `
|
||||
<div id="${id}" class="mt-3 p-2 ${isConnected ? "bg-green-50 ring-1 ring-green-200" : "bg-amber-50 ring-1 ring-amber-200"} rounded-2xl text-center text-sm ${isConnected ? "text-green-600" : "text-amber-600"}">
|
||||
${isConnected ? "✅ Connected!" : "Waiting for connection..."}
|
||||
</div>
|
||||
`;
|
||||
|
||||
const updateConnectionStatus = (
|
||||
id: string,
|
||||
isConnected: boolean,
|
||||
message?: string,
|
||||
) => {
|
||||
const statusEl = document.querySelector(`#${id}`);
|
||||
if (statusEl) {
|
||||
statusEl.innerHTML =
|
||||
message || (isConnected ? "✅ Connected!" : "Waiting for connection...");
|
||||
statusEl.classList.remove(
|
||||
"bg-amber-50",
|
||||
"ring-amber-200",
|
||||
"text-amber-600",
|
||||
"bg-green-50",
|
||||
"ring-green-200",
|
||||
"text-green-600",
|
||||
);
|
||||
if (isConnected) {
|
||||
statusEl.classList.add("bg-green-50", "ring-green-200", "text-green-600");
|
||||
} else {
|
||||
statusEl.classList.add("bg-amber-50", "ring-amber-200", "text-amber-600");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const hasAnyEdge = (): boolean => {
|
||||
return useEdgeStore.getState().edges.length > 0;
|
||||
};
|
||||
|
||||
export const createConnectionSteps = (tour: any): StepOptions[] => {
|
||||
let isConnecting = false;
|
||||
|
||||
const handleMouseDown = () => {
|
||||
isConnecting = true;
|
||||
|
||||
const inputSelector =
|
||||
TUTORIAL_SELECTORS.FIRST_CALCULATOR_RESULT_OUTPUT_HANDLER;
|
||||
if (inputSelector) {
|
||||
highlightElement(inputSelector);
|
||||
}
|
||||
|
||||
setTimeout(() => {
|
||||
if (isConnecting) {
|
||||
tour.next();
|
||||
}
|
||||
}, 100);
|
||||
};
|
||||
|
||||
const resetConnectionState = () => {
|
||||
isConnecting = false;
|
||||
};
|
||||
|
||||
return [
|
||||
{
|
||||
id: "connect-blocks-output",
|
||||
title: "Connect the Blocks: Output",
|
||||
text: `
|
||||
<div class="text-sm leading-[1.375rem] text-zinc-800">
|
||||
<p class="text-sm font-normal leading-[1.375rem] text-zinc-800 m-0">Now, let's connect the <strong>Result output</strong> of the first Calculator to the <strong>input (A)</strong> of the second Calculator.</p>
|
||||
|
||||
<div class="mt-3 p-3 bg-blue-50 ring-1 ring-blue-200 rounded-2xl">
|
||||
<p class="text-sm font-medium text-blue-600 m-0 mb-2">Drag from the Result output:</p>
|
||||
<p class="text-[0.8125rem] text-blue-600 m-0">Click and drag from the <strong>Result</strong> output pin (right side) of the <strong>first Calculator block</strong>.</p>
|
||||
</div>
|
||||
${getConnectionStatusHtml("connection-status-output", false)}
|
||||
${banner(ICONS.Drag, "Drag from the Result output pin", "action")}
|
||||
</div>
|
||||
`,
|
||||
attachTo: {
|
||||
element: TUTORIAL_SELECTORS.FIRST_CALCULATOR_RESULT_OUTPUT_HANDLER,
|
||||
on: "left",
|
||||
},
|
||||
|
||||
when: {
|
||||
show: () => {
|
||||
resetConnectionState();
|
||||
|
||||
if (hasAnyEdge()) {
|
||||
updateConnectionStatus(
|
||||
"connection-status-output",
|
||||
true,
|
||||
"✅ Connection already exists!",
|
||||
);
|
||||
setTimeout(() => {
|
||||
tour.next();
|
||||
}, 1000);
|
||||
return;
|
||||
}
|
||||
|
||||
const outputSelector =
|
||||
TUTORIAL_SELECTORS.FIRST_CALCULATOR_RESULT_OUTPUT_HANDLER;
|
||||
if (outputSelector) {
|
||||
const outputHandle = document.querySelector(outputSelector);
|
||||
if (outputHandle) {
|
||||
highlightElement(outputSelector);
|
||||
outputHandle.addEventListener("mousedown", handleMouseDown);
|
||||
}
|
||||
}
|
||||
|
||||
const unsubscribe = useEdgeStore.subscribe(() => {
|
||||
if (hasAnyEdge()) {
|
||||
updateConnectionStatus("connection-status-output", true);
|
||||
setTimeout(() => {
|
||||
unsubscribe();
|
||||
tour.next();
|
||||
}, 500);
|
||||
}
|
||||
});
|
||||
|
||||
(tour.getCurrentStep() as any)._edgeUnsubscribe = unsubscribe;
|
||||
},
|
||||
hide: () => {
|
||||
removeAllHighlights();
|
||||
const step = tour.getCurrentStep() as any;
|
||||
if (step?._edgeUnsubscribe) {
|
||||
step._edgeUnsubscribe();
|
||||
}
|
||||
const outputSelector =
|
||||
TUTORIAL_SELECTORS.FIRST_CALCULATOR_RESULT_OUTPUT_HANDLER;
|
||||
if (outputSelector) {
|
||||
const outputHandle = document.querySelector(outputSelector);
|
||||
if (outputHandle) {
|
||||
outputHandle.removeEventListener("mousedown", handleMouseDown);
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
buttons: [
|
||||
{
|
||||
text: "Back",
|
||||
action: () => tour.back(),
|
||||
classes: "shepherd-button-secondary",
|
||||
},
|
||||
{
|
||||
text: "Skip (already connected)",
|
||||
action: () => tour.show("connection-complete"),
|
||||
classes: "shepherd-button-secondary",
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
{
|
||||
id: "connect-blocks-input",
|
||||
title: "Connect the Blocks: Input",
|
||||
text: `
|
||||
<div class="text-sm leading-[1.375rem] text-zinc-800">
|
||||
<p class="text-sm font-normal leading-[1.375rem] text-zinc-800 m-0">Now, connect to the <strong>input (A)</strong> of the second Calculator block.</p>
|
||||
|
||||
<div class="mt-3 p-3 bg-blue-50 ring-1 ring-blue-200 rounded-2xl">
|
||||
<p class="text-sm font-medium text-blue-600 m-0 mb-2">Drop on the A input:</p>
|
||||
<p class="text-[0.8125rem] text-blue-600 m-0">Drag to the <strong>A</strong> input handle (left side) of the <strong>second Calculator block</strong>.</p>
|
||||
</div>
|
||||
${getConnectionStatusHtml("connection-status-input", false)}
|
||||
</div>
|
||||
`,
|
||||
attachTo: {
|
||||
element: TUTORIAL_SELECTORS.SECOND_CALCULATOR_NUMBER_A_INPUT_HANDLER,
|
||||
on: "right",
|
||||
},
|
||||
when: {
|
||||
show: () => {
|
||||
const inputSelector =
|
||||
TUTORIAL_SELECTORS.SECOND_CALCULATOR_NUMBER_A_INPUT_HANDLER;
|
||||
if (inputSelector) {
|
||||
highlightElement(inputSelector);
|
||||
}
|
||||
|
||||
if (hasAnyEdge()) {
|
||||
updateConnectionStatus(
|
||||
"connection-status-input",
|
||||
true,
|
||||
"✅ Connected!",
|
||||
);
|
||||
setTimeout(() => {
|
||||
tour.next();
|
||||
}, 500);
|
||||
return;
|
||||
}
|
||||
|
||||
const unsubscribe = useEdgeStore.subscribe(() => {
|
||||
if (hasAnyEdge()) {
|
||||
updateConnectionStatus("connection-status-input", true);
|
||||
setTimeout(() => {
|
||||
unsubscribe();
|
||||
tour.next();
|
||||
}, 500);
|
||||
}
|
||||
});
|
||||
|
||||
(tour.getCurrentStep() as any)._edgeUnsubscribe = unsubscribe;
|
||||
|
||||
const handleMouseUp = () => {
|
||||
setTimeout(() => {
|
||||
if (!hasAnyEdge()) {
|
||||
isConnecting = false;
|
||||
tour.show("connect-blocks-output");
|
||||
}
|
||||
}, 200);
|
||||
};
|
||||
document.addEventListener("mouseup", handleMouseUp, true);
|
||||
(tour.getCurrentStep() as any)._mouseUpHandler = handleMouseUp;
|
||||
},
|
||||
hide: () => {
|
||||
removeAllHighlights();
|
||||
const step = tour.getCurrentStep() as any;
|
||||
if (step?._edgeUnsubscribe) {
|
||||
step._edgeUnsubscribe();
|
||||
}
|
||||
if (step?._mouseUpHandler) {
|
||||
document.removeEventListener("mouseup", step._mouseUpHandler, true);
|
||||
}
|
||||
},
|
||||
},
|
||||
buttons: [
|
||||
{
|
||||
text: "Back",
|
||||
action: () => tour.show("connect-blocks-output"),
|
||||
classes: "shepherd-button-secondary",
|
||||
},
|
||||
{
|
||||
text: "Skip (already connected)",
|
||||
action: () => tour.next(),
|
||||
classes: "shepherd-button-secondary",
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
{
|
||||
id: "connection-complete",
|
||||
title: "Blocks Connected! 🎉",
|
||||
text: `
|
||||
<div class="text-sm leading-[1.375rem] text-zinc-800">
|
||||
<p class="text-sm font-normal leading-[1.375rem] text-zinc-800 m-0">Excellent! Your Calculator blocks are now connected:</p>
|
||||
|
||||
<div class="mt-3 p-3 bg-green-50 ring-1 ring-green-200 rounded-2xl">
|
||||
<div class="flex items-center justify-center gap-2 text-sm font-medium text-green-600">
|
||||
<span>Calculator 1</span>
|
||||
<span>→</span>
|
||||
<span>Calculator 2</span>
|
||||
</div>
|
||||
<p class="text-[0.75rem] text-green-500 m-0 mt-2 text-center italic">The result of Calculator 1 flows into Calculator 2's input A</p>
|
||||
</div>
|
||||
|
||||
<p class="text-sm font-normal leading-[1.375rem] text-zinc-800 m-0" style="margin-top: 0.75rem;">Now let's save and run your agent!</p>
|
||||
</div>
|
||||
`,
|
||||
beforeShowPromise: async () => {
|
||||
fitViewToScreen();
|
||||
return Promise.resolve();
|
||||
},
|
||||
buttons: [
|
||||
{
|
||||
text: "Save My Agent",
|
||||
action: () => tour.next(),
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
};
|
||||
@@ -0,0 +1,22 @@
|
||||
import { StepOptions } from "shepherd.js";
|
||||
import { createWelcomeSteps } from "./welcome";
|
||||
import { createBlockMenuSteps } from "./block-menu";
|
||||
import { createBlockBasicsSteps } from "./block-basics";
|
||||
import { createConfigureCalculatorSteps } from "./configure-calculator";
|
||||
import { createSecondCalculatorSteps } from "./second-calculator";
|
||||
import { createConnectionSteps } from "./connections";
|
||||
import { createSaveSteps } from "./save";
|
||||
import { createRunSteps } from "./run";
|
||||
import { createCompletionSteps } from "./completion";
|
||||
|
||||
export const createTutorialSteps = (tour: any): StepOptions[] => [
|
||||
...createWelcomeSteps(tour),
|
||||
...createBlockMenuSteps(tour),
|
||||
...createBlockBasicsSteps(tour),
|
||||
...createConfigureCalculatorSteps(tour),
|
||||
...createSecondCalculatorSteps(tour),
|
||||
...createConnectionSteps(tour),
|
||||
...createSaveSteps(),
|
||||
...createRunSteps(tour),
|
||||
...createCompletionSteps(tour),
|
||||
];
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user