mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-12 00:28:31 -05:00
Compare commits
54 Commits
hackathon-
...
hackathon/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9560aa8b41 | ||
|
|
5f0a39bbf0 | ||
|
|
9cfe70e554 | ||
|
|
e69f14353e | ||
|
|
1b96d990c5 | ||
|
|
6db59a2665 | ||
|
|
1fd4ec079f | ||
|
|
fb87d6536f | ||
|
|
7ee03fb0ec | ||
|
|
9a83af2787 | ||
|
|
dc1099e205 | ||
|
|
e68a9eb771 | ||
|
|
858a8a818b | ||
|
|
9e1354bfee | ||
|
|
ba003a5e18 | ||
|
|
7a57531063 | ||
|
|
639a1ab0ed | ||
|
|
9abad07bbc | ||
|
|
eeeeb5fe5f | ||
|
|
a163457bc0 | ||
|
|
a4e38be3e3 | ||
|
|
d71c39d24f | ||
|
|
fa7f17334d | ||
|
|
87728ee085 | ||
|
|
9932b05bc7 | ||
|
|
7835bdd39e | ||
|
|
806e3b63d5 | ||
|
|
0cc9ec5546 | ||
|
|
e5fc9e8573 | ||
|
|
d29ae4105f | ||
|
|
2731fd91c8 | ||
|
|
25bc22cc01 | ||
|
|
a3be6d8170 | ||
|
|
fd4f405008 | ||
|
|
1b352c479f | ||
|
|
7d17e6c470 | ||
|
|
0b576d4d48 | ||
|
|
d4f76f9835 | ||
|
|
290fe5d278 | ||
|
|
1a0dd4770b | ||
|
|
e1c0c9397d | ||
|
|
06ce6fa9a1 | ||
|
|
a8c68b585a | ||
|
|
22298c24fd | ||
|
|
5f45a33786 | ||
|
|
d9d6a66608 | ||
|
|
3d8a967395 | ||
|
|
17cef05b8b | ||
|
|
917802aca8 | ||
|
|
e2b2d5f402 | ||
|
|
d726db6488 | ||
|
|
253f2780c3 | ||
|
|
cc2a366c6a | ||
|
|
ad33659ef8 |
@@ -16,7 +16,6 @@
|
||||
!autogpt_platform/backend/poetry.lock
|
||||
!autogpt_platform/backend/README.md
|
||||
!autogpt_platform/backend/.env
|
||||
!autogpt_platform/backend/gen_prisma_types_stub.py
|
||||
|
||||
# Platform - Market
|
||||
!autogpt_platform/market/market/
|
||||
|
||||
2
.github/workflows/claude-dependabot.yml
vendored
2
.github/workflows/claude-dependabot.yml
vendored
@@ -74,7 +74,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
|
||||
2
.github/workflows/claude.yml
vendored
2
.github/workflows/claude.yml
vendored
@@ -90,7 +90,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
|
||||
12
.github/workflows/copilot-setup-steps.yml
vendored
12
.github/workflows/copilot-setup-steps.yml
vendored
@@ -72,7 +72,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
@@ -108,16 +108,6 @@ jobs:
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
# Remove large unused tools to free disk space for Docker builds
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo docker system prune -af
|
||||
df -h
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -134,7 +134,7 @@ jobs:
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
- id: supabase
|
||||
name: Start Supabase
|
||||
|
||||
9
autogpt_platform/.claude/settings.local.json
Normal file
9
autogpt_platform/.claude/settings.local.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(ls:*)",
|
||||
"WebFetch(domain:langfuse.com)",
|
||||
"Bash(poetry install:*)"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
.PHONY: start-core stop-core logs-core format lint migrate run-backend run-frontend load-store-agents
|
||||
.PHONY: start-core stop-core logs-core format lint migrate run-backend stop-backend run-frontend load-store-agents backfill-store-embeddings
|
||||
|
||||
# Run just Supabase + Redis + RabbitMQ
|
||||
start-core:
|
||||
@@ -6,14 +6,12 @@ start-core:
|
||||
|
||||
# Stop core services
|
||||
stop-core:
|
||||
docker compose stop
|
||||
docker compose stop deps
|
||||
|
||||
reset-db:
|
||||
docker compose stop db
|
||||
rm -rf db/docker/volumes/db/data
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
cd backend && poetry run gen-prisma-stub
|
||||
|
||||
# View logs for core services
|
||||
logs-core:
|
||||
@@ -35,9 +33,15 @@ init-env:
|
||||
migrate:
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
cd backend && poetry run gen-prisma-stub
|
||||
|
||||
run-backend:
|
||||
stop-backend:
|
||||
@echo "Stopping backend processes..."
|
||||
@cd backend && poetry run cli stop 2>/dev/null || true
|
||||
@echo "Killing any processes using backend ports..."
|
||||
@lsof -ti:8001,8002,8003,8004,8005,8006,8007 | xargs kill -9 2>/dev/null || true
|
||||
@echo "Backend stopped"
|
||||
|
||||
run-backend: stop-backend
|
||||
cd backend && poetry run app
|
||||
|
||||
run-frontend:
|
||||
@@ -49,6 +53,9 @@ test-data:
|
||||
load-store-agents:
|
||||
cd backend && poetry run load-store-agents
|
||||
|
||||
backfill-store-embeddings:
|
||||
cd backend && poetry run python -m backend.api.features.store.backfill_embeddings
|
||||
|
||||
help:
|
||||
@echo "Usage: make <target>"
|
||||
@echo "Targets:"
|
||||
@@ -58,7 +65,9 @@ help:
|
||||
@echo " logs-core - Tail the logs for core services"
|
||||
@echo " format - Format & lint backend (Python) and frontend (TypeScript) code"
|
||||
@echo " migrate - Run backend database migrations"
|
||||
@echo " run-backend - Run the backend FastAPI server"
|
||||
@echo " stop-backend - Stop any running backend processes"
|
||||
@echo " run-backend - Run the backend FastAPI server (stops existing processes first)"
|
||||
@echo " run-frontend - Run the frontend Next.js development server"
|
||||
@echo " test-data - Run the test data creator"
|
||||
@echo " load-store-agents - Load store agents from agents/ folder into test database"
|
||||
@echo " backfill-store-embeddings - Generate embeddings for store agents that don't have them"
|
||||
@@ -58,6 +58,13 @@ V0_API_KEY=
|
||||
OPEN_ROUTER_API_KEY=
|
||||
NVIDIA_API_KEY=
|
||||
|
||||
# Langfuse Prompt Management
|
||||
# Used for managing the CoPilot system prompt externally
|
||||
# Get credentials from https://cloud.langfuse.com or your self-hosted instance
|
||||
LANGFUSE_PUBLIC_KEY=
|
||||
LANGFUSE_SECRET_KEY=
|
||||
LANGFUSE_HOST=https://cloud.langfuse.com
|
||||
|
||||
# OAuth Credentials
|
||||
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
|
||||
# e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||
|
||||
@@ -48,8 +48,7 @@ RUN poetry install --no-ansi --no-root
|
||||
# Generate Prisma client
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
||||
RUN poetry run prisma generate
|
||||
|
||||
FROM debian:13-slim AS server_dependencies
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import prisma.enums
|
||||
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.embeddings as store_embeddings
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.util.json
|
||||
|
||||
@@ -150,3 +151,54 @@ async def admin_download_agent_file(
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/embeddings/stats",
|
||||
summary="Get Embedding Statistics",
|
||||
)
|
||||
async def get_embedding_stats() -> dict[str, typing.Any]:
|
||||
"""
|
||||
Get statistics about embedding coverage for store listings.
|
||||
|
||||
Returns counts of total approved listings, listings with embeddings,
|
||||
listings without embeddings, and coverage percentage.
|
||||
"""
|
||||
try:
|
||||
stats = await store_embeddings.get_embedding_stats()
|
||||
return stats
|
||||
except Exception as e:
|
||||
logger.exception("Error getting embedding stats: %s", e)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="An error occurred while retrieving embedding stats",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/embeddings/backfill",
|
||||
summary="Backfill Missing Embeddings",
|
||||
)
|
||||
async def backfill_embeddings(
|
||||
batch_size: int = 10,
|
||||
) -> dict[str, typing.Any]:
|
||||
"""
|
||||
Trigger backfill of embeddings for approved listings that don't have them.
|
||||
|
||||
Args:
|
||||
batch_size: Number of embeddings to generate in one call (default 10)
|
||||
|
||||
Returns:
|
||||
Dict with processed count, success count, failure count, and message
|
||||
"""
|
||||
try:
|
||||
result = await store_embeddings.backfill_missing_embeddings(
|
||||
batch_size=batch_size
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception("Error backfilling embeddings: %s", e)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="An error occurred while backfilling embeddings",
|
||||
)
|
||||
|
||||
@@ -45,6 +45,13 @@ class ChatConfig(BaseSettings):
|
||||
default=3, description="Maximum number of agent schedules"
|
||||
)
|
||||
|
||||
# Langfuse Prompt Management Configuration
|
||||
# Note: Langfuse credentials are in Settings().secrets (settings.py)
|
||||
langfuse_prompt_name: str = Field(
|
||||
default="CoPilot Prompt",
|
||||
description="Name of the prompt in Langfuse to fetch",
|
||||
)
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@classmethod
|
||||
def get_api_key(cls, v):
|
||||
|
||||
@@ -2,15 +2,11 @@
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import (
|
||||
ChatMessageCreateInput,
|
||||
ChatSessionCreateInput,
|
||||
ChatSessionUpdateInput,
|
||||
)
|
||||
from prisma.types import ChatSessionUpdateInput
|
||||
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
@@ -34,13 +30,13 @@ async def create_chat_session(
|
||||
user_id: str | None,
|
||||
) -> PrismaChatSession:
|
||||
"""Create a new chat session in the database."""
|
||||
data = ChatSessionCreateInput(
|
||||
id=session_id,
|
||||
userId=user_id,
|
||||
credentials=SafeJson({}),
|
||||
successfulAgentRuns=SafeJson({}),
|
||||
successfulAgentSchedules=SafeJson({}),
|
||||
)
|
||||
data = {
|
||||
"id": session_id,
|
||||
"userId": user_id,
|
||||
"credentials": SafeJson({}),
|
||||
"successfulAgentRuns": SafeJson({}),
|
||||
"successfulAgentSchedules": SafeJson({}),
|
||||
}
|
||||
return await PrismaChatSession.prisma().create(
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
@@ -94,16 +90,12 @@ async def add_chat_message(
|
||||
function_call: dict[str, Any] | None = None,
|
||||
) -> PrismaChatMessage:
|
||||
"""Add a message to a chat session."""
|
||||
# Build the input dict dynamically - only include optional fields when they
|
||||
# have values, as Prisma TypedDict validation fails when optional fields
|
||||
# are explicitly set to None
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": role,
|
||||
"sequence": sequence,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if content is not None:
|
||||
data["content"] = content
|
||||
if name is not None:
|
||||
@@ -112,8 +104,6 @@ async def add_chat_message(
|
||||
data["toolCallId"] = tool_call_id
|
||||
if refusal is not None:
|
||||
data["refusal"] = refusal
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if tool_calls is not None:
|
||||
data["toolCalls"] = SafeJson(tool_calls)
|
||||
if function_call is not None:
|
||||
@@ -125,9 +115,7 @@ async def add_chat_message(
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
|
||||
return await PrismaChatMessage.prisma().create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
)
|
||||
return await PrismaChatMessage.prisma().create(data=data)
|
||||
|
||||
|
||||
async def add_chat_messages_batch(
|
||||
@@ -141,16 +129,12 @@ async def add_chat_messages_batch(
|
||||
|
||||
created_messages = []
|
||||
for i, msg in enumerate(messages):
|
||||
# Build the input dict dynamically - only include optional JSON fields
|
||||
# when they have values, as Prisma TypedDict validation fails when
|
||||
# optional fields are explicitly set to None
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
@@ -159,16 +143,12 @@ async def add_chat_messages_batch(
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
created = await PrismaChatMessage.prisma().create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
)
|
||||
created = await PrismaChatMessage.prisma().create(data=data)
|
||||
created_messages.append(created)
|
||||
|
||||
# Update session's updatedAt timestamp
|
||||
|
||||
@@ -78,7 +78,7 @@ async def test_chatsession_db_storage():
|
||||
# Create session with messages including assistant message
|
||||
s = ChatSession.new(user_id=None)
|
||||
s.messages = messages # Contains user, assistant, and tool messages
|
||||
assert s.session_id is not None, "Session id is not set"
|
||||
|
||||
# Upsert to save to both cache and DB
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from langfuse import Langfuse
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
||||
|
||||
@@ -12,12 +13,20 @@ from backend.data.understanding import (
|
||||
get_business_understanding,
|
||||
)
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from . import db as chat_db
|
||||
from .config import ChatConfig
|
||||
from .model import ChatMessage, ChatSession, Usage
|
||||
from .model import create_chat_session as model_create_chat_session
|
||||
from .model import get_chat_session, upsert_chat_session
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from .model import (
|
||||
create_chat_session as model_create_chat_session,
|
||||
)
|
||||
from .response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamEnd,
|
||||
@@ -34,8 +43,53 @@ from .tools import execute_tool, tools
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config = ChatConfig()
|
||||
settings = Settings()
|
||||
client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
|
||||
# Langfuse client (lazy initialization)
|
||||
_langfuse_client: Langfuse | None = None
|
||||
|
||||
|
||||
def _get_langfuse_client() -> Langfuse:
|
||||
"""Get or create the Langfuse client for prompt management."""
|
||||
global _langfuse_client
|
||||
if _langfuse_client is None:
|
||||
if not settings.secrets.langfuse_public_key or not settings.secrets.langfuse_secret_key:
|
||||
raise ValueError(
|
||||
"Langfuse credentials not configured. "
|
||||
"Set LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment variables."
|
||||
)
|
||||
_langfuse_client = Langfuse(
|
||||
public_key=settings.secrets.langfuse_public_key,
|
||||
secret_key=settings.secrets.langfuse_secret_key,
|
||||
host=settings.secrets.langfuse_host or "https://cloud.langfuse.com",
|
||||
)
|
||||
return _langfuse_client
|
||||
|
||||
|
||||
def _get_langfuse_prompt() -> str:
|
||||
"""Fetch the latest production prompt from Langfuse.
|
||||
|
||||
Returns:
|
||||
The compiled prompt text from Langfuse.
|
||||
|
||||
Raises:
|
||||
Exception: If Langfuse is unavailable or prompt fetch fails.
|
||||
"""
|
||||
try:
|
||||
langfuse = _get_langfuse_client()
|
||||
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
||||
prompt = langfuse.get_prompt(config.langfuse_prompt_name, cache_ttl_seconds=0)
|
||||
compiled = prompt.compile()
|
||||
logger.info(
|
||||
f"Fetched prompt '{config.langfuse_prompt_name}' from Langfuse "
|
||||
f"(version: {prompt.version})"
|
||||
)
|
||||
return compiled
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch prompt from Langfuse: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def _is_first_session(user_id: str) -> bool:
|
||||
"""Check if this is the user's first chat session.
|
||||
@@ -71,7 +125,12 @@ async def _build_system_prompt(
|
||||
effective_prompt_type = "onboarding"
|
||||
|
||||
# Start with the base system prompt for the specified type
|
||||
base_prompt = config.get_system_prompt_for_type(effective_prompt_type)
|
||||
if effective_prompt_type == "default":
|
||||
# Fetch from Langfuse for the default prompt
|
||||
base_prompt = _get_langfuse_prompt()
|
||||
else:
|
||||
# Use local file for other prompt types (e.g., onboarding)
|
||||
base_prompt = config.get_system_prompt_for_type(effective_prompt_type)
|
||||
|
||||
# If user is authenticated, try to fetch their business understanding
|
||||
if user_id:
|
||||
|
||||
@@ -7,26 +7,41 @@ from backend.api.features.chat.model import ChatSession
|
||||
from .add_understanding import AddUnderstandingTool
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .edit_agent import EditAgentTool
|
||||
from .find_agent import FindAgentTool
|
||||
from .find_block import FindBlockTool
|
||||
from .find_library_agent import FindLibraryAgentTool
|
||||
from .run_agent import RunAgentTool
|
||||
from .run_block import RunBlockTool
|
||||
from .search_docs import SearchDocsTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.chat.response_model import StreamToolExecutionResult
|
||||
|
||||
# Initialize tool instances
|
||||
add_understanding_tool = AddUnderstandingTool()
|
||||
create_agent_tool = CreateAgentTool()
|
||||
edit_agent_tool = EditAgentTool()
|
||||
find_agent_tool = FindAgentTool()
|
||||
find_block_tool = FindBlockTool()
|
||||
find_library_agent_tool = FindLibraryAgentTool()
|
||||
run_agent_tool = RunAgentTool()
|
||||
run_block_tool = RunBlockTool()
|
||||
search_docs_tool = SearchDocsTool()
|
||||
agent_output_tool = AgentOutputTool()
|
||||
|
||||
# Export tools as OpenAI format
|
||||
tools: list[ChatCompletionToolParam] = [
|
||||
add_understanding_tool.as_openai_tool(),
|
||||
create_agent_tool.as_openai_tool(),
|
||||
edit_agent_tool.as_openai_tool(),
|
||||
find_agent_tool.as_openai_tool(),
|
||||
find_block_tool.as_openai_tool(),
|
||||
find_library_agent_tool.as_openai_tool(),
|
||||
run_agent_tool.as_openai_tool(),
|
||||
run_block_tool.as_openai_tool(),
|
||||
search_docs_tool.as_openai_tool(),
|
||||
agent_output_tool.as_openai_tool(),
|
||||
]
|
||||
|
||||
@@ -41,9 +56,14 @@ async def execute_tool(
|
||||
|
||||
tool_map: dict[str, BaseTool] = {
|
||||
"add_understanding": add_understanding_tool,
|
||||
"create_agent": create_agent_tool,
|
||||
"edit_agent": edit_agent_tool,
|
||||
"find_agent": find_agent_tool,
|
||||
"find_block": find_block_tool,
|
||||
"find_library_agent": find_library_agent_tool,
|
||||
"run_agent": run_agent_tool,
|
||||
"run_block": run_block_tool,
|
||||
"search_platform_docs": search_docs_tool,
|
||||
"agent_output": agent_output_tool,
|
||||
}
|
||||
if tool_name not in tool_map:
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import UTC, datetime
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
from prisma.types import ProfileCreateInput
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
@@ -50,13 +49,13 @@ async def setup_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
)
|
||||
|
||||
# 2. Create a test graph with agent input -> agent output
|
||||
@@ -173,13 +172,13 @@ async def setup_llm_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile for LLM tests",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for LLM tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
)
|
||||
|
||||
# 2. Create test OpenAI credentials for the user
|
||||
@@ -333,13 +332,13 @@ async def setup_firecrawl_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile for Firecrawl tests",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for Firecrawl tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
)
|
||||
|
||||
# NOTE: We deliberately do NOT create Firecrawl credentials for this user
|
||||
|
||||
@@ -10,7 +10,11 @@ from backend.data.understanding import (
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
|
||||
from .models import (
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Agent generator package - Creates agents from natural language."""
|
||||
|
||||
from .core import (
|
||||
apply_agent_patch,
|
||||
decompose_goal,
|
||||
generate_agent,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .fixer import apply_all_fixes
|
||||
from .utils import get_blocks_info
|
||||
from .validator import validate_agent
|
||||
|
||||
__all__ = [
|
||||
# Core functions
|
||||
"decompose_goal",
|
||||
"generate_agent",
|
||||
"generate_agent_patch",
|
||||
"apply_agent_patch",
|
||||
"save_agent_to_library",
|
||||
"get_agent_as_json",
|
||||
# Fixer
|
||||
"apply_all_fixes",
|
||||
# Validator
|
||||
"validate_agent",
|
||||
# Utils
|
||||
"get_blocks_info",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
"""OpenRouter client configuration for agent generation."""
|
||||
|
||||
import os
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Configuration - use OPEN_ROUTER_API_KEY for consistency with chat/config.py
|
||||
OPENROUTER_API_KEY = os.getenv("OPEN_ROUTER_API_KEY") or os.getenv("OPENROUTER_API_KEY")
|
||||
AGENT_GENERATOR_MODEL = os.getenv("AGENT_GENERATOR_MODEL", "anthropic/claude-opus-4.5")
|
||||
|
||||
# OpenRouter client (OpenAI-compatible API)
|
||||
_client: AsyncOpenAI | None = None
|
||||
|
||||
|
||||
def get_client() -> AsyncOpenAI:
|
||||
"""Get or create the OpenRouter client."""
|
||||
global _client
|
||||
if _client is None:
|
||||
if not OPENROUTER_API_KEY:
|
||||
raise ValueError("OPENROUTER_API_KEY environment variable is required")
|
||||
_client = AsyncOpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=OPENROUTER_API_KEY,
|
||||
)
|
||||
return _client
|
||||
@@ -0,0 +1,390 @@
|
||||
"""Core agent generation functions."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
|
||||
from .client import AGENT_GENERATOR_MODEL, get_client
|
||||
from .prompts import DECOMPOSITION_PROMPT, GENERATION_PROMPT, PATCH_PROMPT
|
||||
from .utils import get_block_summaries, parse_json_from_llm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
|
||||
"""Break down a goal into steps or return clarifying questions.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
|
||||
Returns:
|
||||
Dict with either:
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
Or None on error
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = DECOMPOSITION_PROMPT.format(block_summaries=get_block_summaries())
|
||||
|
||||
full_description = description
|
||||
if context:
|
||||
full_description = f"{description}\n\nAdditional context:\n{context}"
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": full_description},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for decomposition")
|
||||
return None
|
||||
|
||||
result = parse_json_from_llm(content)
|
||||
|
||||
if result is None:
|
||||
logger.error(f"Failed to parse decomposition response: {content[:200]}")
|
||||
return None
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error decomposing goal: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Generate agent JSON from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
|
||||
Returns:
|
||||
Agent JSON dict or None on error
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = GENERATION_PROMPT.format(block_summaries=get_block_summaries())
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": json.dumps(instructions, indent=2)},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for agent generation")
|
||||
return None
|
||||
|
||||
result = parse_json_from_llm(content)
|
||||
|
||||
if result is None:
|
||||
logger.error(f"Failed to parse agent JSON: {content[:200]}")
|
||||
return None
|
||||
|
||||
# Ensure required fields
|
||||
if "id" not in result:
|
||||
result["id"] = str(uuid.uuid4())
|
||||
if "version" not in result:
|
||||
result["version"] = 1
|
||||
if "is_active" not in result:
|
||||
result["is_active"] = True
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating agent: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
"""Convert agent JSON dict to Graph model.
|
||||
|
||||
Args:
|
||||
agent_json: Agent JSON with nodes and links
|
||||
|
||||
Returns:
|
||||
Graph ready for saving
|
||||
"""
|
||||
nodes = []
|
||||
for n in agent_json.get("nodes", []):
|
||||
node = Node(
|
||||
id=n.get("id", str(uuid.uuid4())),
|
||||
block_id=n["block_id"],
|
||||
input_default=n.get("input_default", {}),
|
||||
metadata=n.get("metadata", {}),
|
||||
)
|
||||
nodes.append(node)
|
||||
|
||||
links = []
|
||||
for link_data in agent_json.get("links", []):
|
||||
link = Link(
|
||||
id=link_data.get("id", str(uuid.uuid4())),
|
||||
source_id=link_data["source_id"],
|
||||
sink_id=link_data["sink_id"],
|
||||
source_name=link_data["source_name"],
|
||||
sink_name=link_data["sink_name"],
|
||||
is_static=link_data.get("is_static", False),
|
||||
)
|
||||
links.append(link)
|
||||
|
||||
return Graph(
|
||||
id=agent_json.get("id", str(uuid.uuid4())),
|
||||
version=agent_json.get("version", 1),
|
||||
is_active=agent_json.get("is_active", True),
|
||||
name=agent_json.get("name", "Generated Agent"),
|
||||
description=agent_json.get("description", ""),
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
|
||||
|
||||
def _reassign_node_ids(graph: Graph) -> None:
|
||||
"""Reassign all node and link IDs to new UUIDs.
|
||||
|
||||
This is needed when creating a new version to avoid unique constraint violations.
|
||||
"""
|
||||
# Create mapping from old node IDs to new UUIDs
|
||||
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
||||
|
||||
# Reassign node IDs
|
||||
for node in graph.nodes:
|
||||
node.id = id_map[node.id]
|
||||
|
||||
# Update link references to use new node IDs
|
||||
for link in graph.links:
|
||||
link.id = str(uuid.uuid4()) # Also give links new IDs
|
||||
if link.source_id in id_map:
|
||||
link.source_id = id_map[link.source_id]
|
||||
if link.sink_id in id_map:
|
||||
link.sink_id = id_map[link.sink_id]
|
||||
|
||||
|
||||
async def save_agent_to_library(
|
||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||
) -> tuple[Graph, Any]:
|
||||
"""Save agent to database and user's library.
|
||||
|
||||
Args:
|
||||
agent_json: Agent JSON dict
|
||||
user_id: User ID
|
||||
is_update: Whether this is an update to an existing agent
|
||||
|
||||
Returns:
|
||||
Tuple of (created Graph, LibraryAgent)
|
||||
"""
|
||||
from backend.data.graph import get_graph_all_versions
|
||||
|
||||
graph = json_to_graph(agent_json)
|
||||
|
||||
if is_update:
|
||||
# For updates, keep the same graph ID but increment version
|
||||
# and reassign node/link IDs to avoid conflicts
|
||||
if graph.id:
|
||||
existing_versions = await get_graph_all_versions(graph.id, user_id)
|
||||
if existing_versions:
|
||||
latest_version = max(v.version for v in existing_versions)
|
||||
graph.version = latest_version + 1
|
||||
# Reassign node IDs (but keep graph ID the same)
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Updating agent {graph.id} to version {graph.version}")
|
||||
else:
|
||||
# For new agents, always generate a fresh UUID to avoid collisions
|
||||
graph.id = str(uuid.uuid4())
|
||||
graph.version = 1
|
||||
# Reassign all node IDs as well
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Creating new agent with ID {graph.id}")
|
||||
|
||||
# Save to database
|
||||
created_graph = await create_graph(graph, user_id)
|
||||
|
||||
# Add to user's library (or update existing library agent)
|
||||
library_agents = await library_db.create_library_agent(
|
||||
graph=created_graph,
|
||||
user_id=user_id,
|
||||
create_library_agents_for_sub_graphs=False,
|
||||
)
|
||||
|
||||
return created_graph, library_agents[0]
|
||||
|
||||
|
||||
async def get_agent_as_json(
|
||||
graph_id: str, user_id: str | None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch an agent and convert to JSON format for editing.
|
||||
|
||||
Args:
|
||||
graph_id: Graph ID or library agent ID
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Agent as JSON dict or None if not found
|
||||
"""
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
# Try to get the graph (version=None gets the active version)
|
||||
graph = await get_graph(graph_id, version=None, user_id=user_id)
|
||||
if not graph:
|
||||
return None
|
||||
|
||||
# Convert to JSON format
|
||||
nodes = []
|
||||
for node in graph.nodes:
|
||||
nodes.append(
|
||||
{
|
||||
"id": node.id,
|
||||
"block_id": node.block_id,
|
||||
"input_default": node.input_default,
|
||||
"metadata": node.metadata,
|
||||
}
|
||||
)
|
||||
|
||||
links = []
|
||||
for node in graph.nodes:
|
||||
for link in node.output_links:
|
||||
links.append(
|
||||
{
|
||||
"id": link.id,
|
||||
"source_id": link.source_id,
|
||||
"sink_id": link.sink_id,
|
||||
"source_name": link.source_name,
|
||||
"sink_name": link.sink_name,
|
||||
"is_static": link.is_static,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"id": graph.id,
|
||||
"name": graph.name,
|
||||
"description": graph.description,
|
||||
"version": graph.version,
|
||||
"is_active": graph.is_active,
|
||||
"nodes": nodes,
|
||||
"links": links,
|
||||
}
|
||||
|
||||
|
||||
async def generate_agent_patch(
|
||||
update_request: str, current_agent: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Generate a patch to update an existing agent.
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
|
||||
Returns:
|
||||
Patch dict or clarifying questions, or None on error
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = PATCH_PROMPT.format(
|
||||
current_agent=json.dumps(current_agent, indent=2),
|
||||
block_summaries=get_block_summaries(),
|
||||
)
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": update_request},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for patch generation")
|
||||
return None
|
||||
|
||||
return parse_json_from_llm(content)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating patch: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def apply_agent_patch(
|
||||
current_agent: dict[str, Any], patch: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Apply a patch to an existing agent.
|
||||
|
||||
Args:
|
||||
current_agent: Current agent JSON
|
||||
patch: Patch dict with operations
|
||||
|
||||
Returns:
|
||||
Updated agent JSON
|
||||
"""
|
||||
agent = copy.deepcopy(current_agent)
|
||||
patches = patch.get("patches", [])
|
||||
|
||||
for p in patches:
|
||||
patch_type = p.get("type")
|
||||
|
||||
if patch_type == "modify":
|
||||
node_id = p.get("node_id")
|
||||
changes = p.get("changes", {})
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
if node["id"] == node_id:
|
||||
_deep_update(node, changes)
|
||||
logger.debug(f"Modified node {node_id}")
|
||||
break
|
||||
|
||||
elif patch_type == "add":
|
||||
new_nodes = p.get("new_nodes", [])
|
||||
new_links = p.get("new_links", [])
|
||||
|
||||
agent["nodes"] = agent.get("nodes", []) + new_nodes
|
||||
agent["links"] = agent.get("links", []) + new_links
|
||||
logger.debug(f"Added {len(new_nodes)} nodes, {len(new_links)} links")
|
||||
|
||||
elif patch_type == "remove":
|
||||
node_ids_to_remove = set(p.get("node_ids", []))
|
||||
link_ids_to_remove = set(p.get("link_ids", []))
|
||||
|
||||
# Remove nodes
|
||||
agent["nodes"] = [
|
||||
n for n in agent.get("nodes", []) if n["id"] not in node_ids_to_remove
|
||||
]
|
||||
|
||||
# Remove links (both explicit and those referencing removed nodes)
|
||||
agent["links"] = [
|
||||
link
|
||||
for link in agent.get("links", [])
|
||||
if link["id"] not in link_ids_to_remove
|
||||
and link["source_id"] not in node_ids_to_remove
|
||||
and link["sink_id"] not in node_ids_to_remove
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
f"Removed {len(node_ids_to_remove)} nodes, {len(link_ids_to_remove)} links"
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def _deep_update(target: dict, source: dict) -> None:
|
||||
"""Recursively update a dict with another dict."""
|
||||
for key, value in source.items():
|
||||
if key in target and isinstance(target[key], dict) and isinstance(value, dict):
|
||||
_deep_update(target[key], value)
|
||||
else:
|
||||
target[key] = value
|
||||
@@ -0,0 +1,606 @@
|
||||
"""Agent fixer - Fixes common LLM generation errors."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from .utils import (
|
||||
ADDTODICTIONARY_BLOCK_ID,
|
||||
ADDTOLIST_BLOCK_ID,
|
||||
CODE_EXECUTION_BLOCK_ID,
|
||||
CONDITION_BLOCK_ID,
|
||||
CREATEDICT_BLOCK_ID,
|
||||
CREATELIST_BLOCK_ID,
|
||||
DATA_SAMPLING_BLOCK_ID,
|
||||
DOUBLE_CURLY_BRACES_BLOCK_IDS,
|
||||
GET_CURRENT_DATE_BLOCK_ID,
|
||||
STORE_VALUE_BLOCK_ID,
|
||||
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
|
||||
get_blocks_info,
|
||||
is_valid_uuid,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def fix_agent_ids(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix invalid UUIDs in agent and link IDs."""
|
||||
# Fix agent ID
|
||||
if not is_valid_uuid(agent.get("id", "")):
|
||||
agent["id"] = str(uuid.uuid4())
|
||||
logger.debug(f"Fixed agent ID: {agent['id']}")
|
||||
|
||||
# Fix node IDs
|
||||
id_mapping = {} # Old ID -> New ID
|
||||
for node in agent.get("nodes", []):
|
||||
if not is_valid_uuid(node.get("id", "")):
|
||||
old_id = node.get("id", "")
|
||||
new_id = str(uuid.uuid4())
|
||||
id_mapping[old_id] = new_id
|
||||
node["id"] = new_id
|
||||
logger.debug(f"Fixed node ID: {old_id} -> {new_id}")
|
||||
|
||||
# Fix link IDs and update references
|
||||
for link in agent.get("links", []):
|
||||
if not is_valid_uuid(link.get("id", "")):
|
||||
link["id"] = str(uuid.uuid4())
|
||||
logger.debug(f"Fixed link ID: {link['id']}")
|
||||
|
||||
# Update source/sink IDs if they were remapped
|
||||
if link.get("source_id") in id_mapping:
|
||||
link["source_id"] = id_mapping[link["source_id"]]
|
||||
if link.get("sink_id") in id_mapping:
|
||||
link["sink_id"] = id_mapping[link["sink_id"]]
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_double_curly_braces(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix single curly braces to double in template blocks."""
|
||||
for node in agent.get("nodes", []):
|
||||
if node.get("block_id") not in DOUBLE_CURLY_BRACES_BLOCK_IDS:
|
||||
continue
|
||||
|
||||
input_data = node.get("input_default", {})
|
||||
for key in ("prompt", "format"):
|
||||
if key in input_data and isinstance(input_data[key], str):
|
||||
original = input_data[key]
|
||||
# Fix simple variable references: {var} -> {{var}}
|
||||
fixed = re.sub(
|
||||
r"(?<!\{)\{([a-zA-Z_][a-zA-Z0-9_]*)\}(?!\})",
|
||||
r"{{\1}}",
|
||||
original,
|
||||
)
|
||||
if fixed != original:
|
||||
input_data[key] = fixed
|
||||
logger.debug(f"Fixed curly braces in {key}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_storevalue_before_condition(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Add StoreValueBlock before ConditionBlock if needed for value2."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
|
||||
# Find all ConditionBlock nodes
|
||||
condition_node_ids = {
|
||||
node["id"] for node in nodes if node.get("block_id") == CONDITION_BLOCK_ID
|
||||
}
|
||||
|
||||
if not condition_node_ids:
|
||||
return agent
|
||||
|
||||
new_nodes = []
|
||||
new_links = []
|
||||
processed_conditions = set()
|
||||
|
||||
for link in links:
|
||||
sink_id = link.get("sink_id")
|
||||
sink_name = link.get("sink_name")
|
||||
|
||||
# Check if this link goes to a ConditionBlock's value2
|
||||
if sink_id in condition_node_ids and sink_name == "value2":
|
||||
source_node = next(
|
||||
(n for n in nodes if n["id"] == link.get("source_id")), None
|
||||
)
|
||||
|
||||
# Skip if source is already a StoreValueBlock
|
||||
if source_node and source_node.get("block_id") == STORE_VALUE_BLOCK_ID:
|
||||
continue
|
||||
|
||||
# Skip if we already processed this condition
|
||||
if sink_id in processed_conditions:
|
||||
continue
|
||||
|
||||
processed_conditions.add(sink_id)
|
||||
|
||||
# Create StoreValueBlock
|
||||
store_node_id = str(uuid.uuid4())
|
||||
store_node = {
|
||||
"id": store_node_id,
|
||||
"block_id": STORE_VALUE_BLOCK_ID,
|
||||
"input_default": {"data": None},
|
||||
"metadata": {"position": {"x": 0, "y": -100}},
|
||||
}
|
||||
new_nodes.append(store_node)
|
||||
|
||||
# Create link: original source -> StoreValueBlock
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": link["source_id"],
|
||||
"source_name": link["source_name"],
|
||||
"sink_id": store_node_id,
|
||||
"sink_name": "input",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
# Update original link: StoreValueBlock -> ConditionBlock
|
||||
link["source_id"] = store_node_id
|
||||
link["source_name"] = "output"
|
||||
|
||||
logger.debug(f"Added StoreValueBlock before ConditionBlock {sink_id}")
|
||||
|
||||
if new_nodes:
|
||||
agent["nodes"] = nodes + new_nodes
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_addtolist_blocks(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix AddToList blocks by adding prerequisite empty AddToList block.
|
||||
|
||||
When an AddToList block is found:
|
||||
1. Checks if there's a CreateListBlock before it
|
||||
2. Removes CreateListBlock if linked directly to AddToList
|
||||
3. Adds an empty AddToList block before the original
|
||||
4. Ensures the original has a self-referencing link
|
||||
"""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
new_nodes = []
|
||||
original_addtolist_ids = set()
|
||||
nodes_to_remove = set()
|
||||
links_to_remove = []
|
||||
|
||||
# First pass: identify CreateListBlock nodes to remove
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
|
||||
|
||||
if (
|
||||
source_node
|
||||
and sink_node
|
||||
and source_node.get("block_id") == CREATELIST_BLOCK_ID
|
||||
and sink_node.get("block_id") == ADDTOLIST_BLOCK_ID
|
||||
):
|
||||
nodes_to_remove.add(source_node.get("id"))
|
||||
links_to_remove.append(link)
|
||||
logger.debug(f"Removing CreateListBlock {source_node.get('id')}")
|
||||
|
||||
# Second pass: process AddToList blocks
|
||||
filtered_nodes = []
|
||||
for node in nodes:
|
||||
if node.get("id") in nodes_to_remove:
|
||||
continue
|
||||
|
||||
if node.get("block_id") == ADDTOLIST_BLOCK_ID:
|
||||
original_addtolist_ids.add(node.get("id"))
|
||||
node_id = node.get("id")
|
||||
pos = node.get("metadata", {}).get("position", {"x": 0, "y": 0})
|
||||
|
||||
# Check if already has prerequisite
|
||||
has_prereq = any(
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "list"
|
||||
and link.get("source_name") == "updated_list"
|
||||
for link in links
|
||||
)
|
||||
|
||||
if not has_prereq:
|
||||
# Remove links to "list" input (except self-reference)
|
||||
for link in links:
|
||||
if (
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "list"
|
||||
and link.get("source_id") != node_id
|
||||
and link not in links_to_remove
|
||||
):
|
||||
links_to_remove.append(link)
|
||||
|
||||
# Create prerequisite AddToList block
|
||||
prereq_id = str(uuid.uuid4())
|
||||
prereq_node = {
|
||||
"id": prereq_id,
|
||||
"block_id": ADDTOLIST_BLOCK_ID,
|
||||
"input_default": {"list": [], "entry": None, "entries": []},
|
||||
"metadata": {
|
||||
"position": {"x": pos.get("x", 0) - 800, "y": pos.get("y", 0)}
|
||||
},
|
||||
}
|
||||
new_nodes.append(prereq_node)
|
||||
|
||||
# Link prerequisite to original
|
||||
links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": prereq_id,
|
||||
"source_name": "updated_list",
|
||||
"sink_id": node_id,
|
||||
"sink_name": "list",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
logger.debug(f"Added prerequisite AddToList block for {node_id}")
|
||||
|
||||
filtered_nodes.append(node)
|
||||
|
||||
# Remove marked links
|
||||
filtered_links = [link for link in links if link not in links_to_remove]
|
||||
|
||||
# Add self-referencing links for original AddToList blocks
|
||||
for node in filtered_nodes + new_nodes:
|
||||
if (
|
||||
node.get("block_id") == ADDTOLIST_BLOCK_ID
|
||||
and node.get("id") in original_addtolist_ids
|
||||
):
|
||||
node_id = node.get("id")
|
||||
has_self_ref = any(
|
||||
link["source_id"] == node_id
|
||||
and link["sink_id"] == node_id
|
||||
and link["source_name"] == "updated_list"
|
||||
and link["sink_name"] == "list"
|
||||
for link in filtered_links
|
||||
)
|
||||
if not has_self_ref:
|
||||
filtered_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": node_id,
|
||||
"source_name": "updated_list",
|
||||
"sink_id": node_id,
|
||||
"sink_name": "list",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
logger.debug(f"Added self-reference for AddToList {node_id}")
|
||||
|
||||
agent["nodes"] = filtered_nodes + new_nodes
|
||||
agent["links"] = filtered_links
|
||||
return agent
|
||||
|
||||
|
||||
def fix_addtodictionary_blocks(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix AddToDictionary blocks by removing empty CreateDictionary nodes."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
nodes_to_remove = set()
|
||||
links_to_remove = []
|
||||
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
|
||||
|
||||
if (
|
||||
source_node
|
||||
and sink_node
|
||||
and source_node.get("block_id") == CREATEDICT_BLOCK_ID
|
||||
and sink_node.get("block_id") == ADDTODICTIONARY_BLOCK_ID
|
||||
):
|
||||
nodes_to_remove.add(source_node.get("id"))
|
||||
links_to_remove.append(link)
|
||||
logger.debug(f"Removing CreateDictionary {source_node.get('id')}")
|
||||
|
||||
agent["nodes"] = [n for n in nodes if n.get("id") not in nodes_to_remove]
|
||||
agent["links"] = [link for link in links if link not in links_to_remove]
|
||||
return agent
|
||||
|
||||
|
||||
def fix_code_execution_output(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix CodeExecutionBlock output: change 'response' to 'stdout_logs'."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
if (
|
||||
source_node
|
||||
and source_node.get("block_id") == CODE_EXECUTION_BLOCK_ID
|
||||
and link.get("source_name") == "response"
|
||||
):
|
||||
link["source_name"] = "stdout_logs"
|
||||
logger.debug("Fixed CodeExecutionBlock output: response -> stdout_logs")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_data_sampling_sample_size(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix DataSamplingBlock by setting sample_size to 1 as default."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
links_to_remove = []
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") == DATA_SAMPLING_BLOCK_ID:
|
||||
node_id = node.get("id")
|
||||
input_default = node.get("input_default", {})
|
||||
|
||||
# Remove links to sample_size
|
||||
for link in links:
|
||||
if (
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "sample_size"
|
||||
):
|
||||
links_to_remove.append(link)
|
||||
|
||||
# Set default
|
||||
input_default["sample_size"] = 1
|
||||
node["input_default"] = input_default
|
||||
logger.debug(f"Fixed DataSamplingBlock {node_id} sample_size to 1")
|
||||
|
||||
if links_to_remove:
|
||||
agent["links"] = [link for link in links if link not in links_to_remove]
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_node_x_coordinates(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix node x-coordinates to ensure 800+ unit spacing between linked nodes."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
node_lookup = {n.get("id"): n for n in nodes}
|
||||
|
||||
for link in links:
|
||||
source_id = link.get("source_id")
|
||||
sink_id = link.get("sink_id")
|
||||
|
||||
source_node = node_lookup.get(source_id)
|
||||
sink_node = node_lookup.get(sink_id)
|
||||
|
||||
if not source_node or not sink_node:
|
||||
continue
|
||||
|
||||
source_pos = source_node.get("metadata", {}).get("position", {})
|
||||
sink_pos = sink_node.get("metadata", {}).get("position", {})
|
||||
|
||||
source_x = source_pos.get("x", 0)
|
||||
sink_x = sink_pos.get("x", 0)
|
||||
|
||||
if abs(sink_x - source_x) < 800:
|
||||
new_x = source_x + 800
|
||||
if "metadata" not in sink_node:
|
||||
sink_node["metadata"] = {}
|
||||
if "position" not in sink_node["metadata"]:
|
||||
sink_node["metadata"]["position"] = {}
|
||||
sink_node["metadata"]["position"]["x"] = new_x
|
||||
logger.debug(f"Fixed node {sink_id} x: {sink_x} -> {new_x}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_getcurrentdate_offset(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix GetCurrentDateBlock offset to ensure it's positive."""
|
||||
for node in agent.get("nodes", []):
|
||||
if node.get("block_id") == GET_CURRENT_DATE_BLOCK_ID:
|
||||
input_default = node.get("input_default", {})
|
||||
if "offset" in input_default:
|
||||
offset = input_default["offset"]
|
||||
if isinstance(offset, (int, float)) and offset < 0:
|
||||
input_default["offset"] = abs(offset)
|
||||
logger.debug(f"Fixed offset: {offset} -> {abs(offset)}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_ai_model_parameter(
|
||||
agent: dict[str, Any],
|
||||
blocks_info: list[dict[str, Any]],
|
||||
default_model: str = "gpt-4o",
|
||||
) -> dict[str, Any]:
|
||||
"""Add default model parameter to AI blocks if missing."""
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
block = block_map.get(block_id)
|
||||
|
||||
if not block:
|
||||
continue
|
||||
|
||||
# Check if block has AI category
|
||||
categories = block.get("categories", [])
|
||||
is_ai_block = any(
|
||||
cat.get("category") == "AI" for cat in categories if isinstance(cat, dict)
|
||||
)
|
||||
|
||||
if is_ai_block:
|
||||
input_default = node.get("input_default", {})
|
||||
if "model" not in input_default:
|
||||
input_default["model"] = default_model
|
||||
node["input_default"] = input_default
|
||||
logger.debug(
|
||||
f"Added model '{default_model}' to AI block {node.get('id')}"
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_link_static_properties(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Fix is_static property based on source block's staticOutput."""
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
if not source_node:
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
if not source_block:
|
||||
continue
|
||||
|
||||
static_output = source_block.get("staticOutput", False)
|
||||
if link.get("is_static") != static_output:
|
||||
link["is_static"] = static_output
|
||||
logger.debug(f"Fixed link {link.get('id')} is_static to {static_output}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_data_type_mismatch(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Fix data type mismatches by inserting UniversalTypeConverterBlock."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in nodes}
|
||||
|
||||
def get_property_type(schema: dict, name: str) -> str | None:
|
||||
if "_#_" in name:
|
||||
parent, child = name.split("_#_", 1)
|
||||
parent_schema = schema.get(parent, {})
|
||||
if "properties" in parent_schema:
|
||||
return parent_schema["properties"].get(child, {}).get("type")
|
||||
return None
|
||||
return schema.get(name, {}).get("type")
|
||||
|
||||
def are_types_compatible(src: str, sink: str) -> bool:
|
||||
if {src, sink} <= {"integer", "number"}:
|
||||
return True
|
||||
return src == sink
|
||||
|
||||
type_mapping = {
|
||||
"string": "string",
|
||||
"text": "string",
|
||||
"integer": "number",
|
||||
"number": "number",
|
||||
"float": "number",
|
||||
"boolean": "boolean",
|
||||
"bool": "boolean",
|
||||
"array": "list",
|
||||
"list": "list",
|
||||
"object": "dictionary",
|
||||
"dict": "dictionary",
|
||||
"dictionary": "dictionary",
|
||||
}
|
||||
|
||||
new_links = []
|
||||
nodes_to_add = []
|
||||
|
||||
for link in links:
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
|
||||
if not source_node or not sink_node:
|
||||
new_links.append(link)
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
sink_block = block_map.get(sink_node.get("block_id"))
|
||||
|
||||
if not source_block or not sink_block:
|
||||
new_links.append(link)
|
||||
continue
|
||||
|
||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
||||
|
||||
source_type = get_property_type(source_outputs, link.get("source_name", ""))
|
||||
sink_type = get_property_type(sink_inputs, link.get("sink_name", ""))
|
||||
|
||||
if (
|
||||
source_type
|
||||
and sink_type
|
||||
and not are_types_compatible(source_type, sink_type)
|
||||
):
|
||||
# Insert type converter
|
||||
converter_id = str(uuid.uuid4())
|
||||
target_type = type_mapping.get(sink_type, sink_type)
|
||||
|
||||
converter_node = {
|
||||
"id": converter_id,
|
||||
"block_id": UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
|
||||
"input_default": {"type": target_type},
|
||||
"metadata": {"position": {"x": 0, "y": 100}},
|
||||
}
|
||||
nodes_to_add.append(converter_node)
|
||||
|
||||
# source -> converter
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": link["source_id"],
|
||||
"source_name": link["source_name"],
|
||||
"sink_id": converter_id,
|
||||
"sink_name": "value",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
# converter -> sink
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": converter_id,
|
||||
"source_name": "value",
|
||||
"sink_id": link["sink_id"],
|
||||
"sink_name": link["sink_name"],
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(f"Inserted type converter: {source_type} -> {target_type}")
|
||||
else:
|
||||
new_links.append(link)
|
||||
|
||||
if nodes_to_add:
|
||||
agent["nodes"] = nodes + nodes_to_add
|
||||
agent["links"] = new_links
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def apply_all_fixes(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Apply all fixes to an agent JSON.
|
||||
|
||||
Args:
|
||||
agent: Agent JSON dict
|
||||
blocks_info: Optional list of block info dicts for advanced fixes
|
||||
|
||||
Returns:
|
||||
Fixed agent JSON
|
||||
"""
|
||||
# Basic fixes (no block info needed)
|
||||
agent = fix_agent_ids(agent)
|
||||
agent = fix_double_curly_braces(agent)
|
||||
agent = fix_storevalue_before_condition(agent)
|
||||
agent = fix_addtolist_blocks(agent)
|
||||
agent = fix_addtodictionary_blocks(agent)
|
||||
agent = fix_code_execution_output(agent)
|
||||
agent = fix_data_sampling_sample_size(agent)
|
||||
agent = fix_node_x_coordinates(agent)
|
||||
agent = fix_getcurrentdate_offset(agent)
|
||||
|
||||
# Advanced fixes (require block info)
|
||||
if blocks_info is None:
|
||||
blocks_info = get_blocks_info()
|
||||
|
||||
agent = fix_ai_model_parameter(agent, blocks_info)
|
||||
agent = fix_link_static_properties(agent, blocks_info)
|
||||
agent = fix_data_type_mismatch(agent, blocks_info)
|
||||
|
||||
return agent
|
||||
@@ -0,0 +1,225 @@
|
||||
"""Prompt templates for agent generation."""
|
||||
|
||||
DECOMPOSITION_PROMPT = """
|
||||
You are an expert AutoGPT Workflow Decomposer. Your task is to analyze a user's high-level goal and break it down into a clear, step-by-step plan using the available blocks.
|
||||
|
||||
Each step should represent a distinct, automatable action suitable for execution by an AI automation system.
|
||||
|
||||
---
|
||||
|
||||
FIRST: Analyze the user's goal and determine:
|
||||
1) Design-time configuration (fixed settings that won't change per run)
|
||||
2) Runtime inputs (values the agent's end-user will provide each time it runs)
|
||||
|
||||
For anything that can vary per run (email addresses, names, dates, search terms, etc.):
|
||||
- DO NOT ask for the actual value
|
||||
- Instead, define it as an Agent Input with a clear name, type, and description
|
||||
|
||||
Only ask clarifying questions about design-time config that affects how you build the workflow:
|
||||
- Which external service to use (e.g., "Gmail vs Outlook", "Notion vs Google Docs")
|
||||
- Required formats or structures (e.g., "CSV, JSON, or PDF output?")
|
||||
- Business rules that must be hard-coded
|
||||
|
||||
IMPORTANT CLARIFICATIONS POLICY:
|
||||
- Ask no more than five essential questions
|
||||
- Do not ask for concrete values that can be provided at runtime as Agent Inputs
|
||||
- Do not ask for API keys or credentials; the platform handles those directly
|
||||
- If there is enough information to infer reasonable defaults, prefer to propose defaults
|
||||
|
||||
---
|
||||
|
||||
GUIDELINES:
|
||||
1. List each step as a numbered item
|
||||
2. Describe the action clearly and specify inputs/outputs
|
||||
3. Ensure steps are in logical, sequential order
|
||||
4. Mention block names naturally (e.g., "Use GetWeatherByLocationBlock to...")
|
||||
5. Help the user reach their goal efficiently
|
||||
|
||||
---
|
||||
|
||||
RULES:
|
||||
1. OUTPUT FORMAT: Only output either clarifying questions OR step-by-step instructions, not both
|
||||
2. USE ONLY THE BLOCKS PROVIDED
|
||||
3. ALL required_input fields must be provided
|
||||
4. Data types of linked properties must match
|
||||
5. Write expert-level prompts for AI-related blocks
|
||||
|
||||
---
|
||||
|
||||
CRITICAL BLOCK RESTRICTIONS:
|
||||
1. AddToListBlock: Outputs updated list EVERY addition, not after all additions
|
||||
2. SendEmailBlock: Draft the email for user review; set SMTP config based on email type
|
||||
3. ConditionBlock: value2 is reference, value1 is contrast
|
||||
4. CodeExecutionBlock: DO NOT USE - use AI blocks instead
|
||||
5. ReadCsvBlock: Only use the 'rows' output, not 'row'
|
||||
|
||||
---
|
||||
|
||||
OUTPUT FORMAT:
|
||||
|
||||
If more information is needed:
|
||||
```json
|
||||
{{
|
||||
"type": "clarifying_questions",
|
||||
"questions": [
|
||||
{{
|
||||
"question": "Which email provider should be used? (Gmail, Outlook, custom SMTP)",
|
||||
"keyword": "email_provider",
|
||||
"example": "Gmail"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
If ready to proceed:
|
||||
```json
|
||||
{{
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{{
|
||||
"step_number": 1,
|
||||
"block_name": "AgentShortTextInputBlock",
|
||||
"description": "Get the URL of the content to analyze.",
|
||||
"inputs": [{{"name": "name", "value": "URL"}}],
|
||||
"outputs": [{{"name": "result", "description": "The URL entered by user"}}]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
AVAILABLE BLOCKS:
|
||||
{block_summaries}
|
||||
"""
|
||||
|
||||
GENERATION_PROMPT = """
|
||||
You are an expert AI workflow builder. Generate a valid agent JSON from the given instructions.
|
||||
|
||||
---
|
||||
|
||||
NODES:
|
||||
Each node must include:
|
||||
- `id`: Unique UUID v4 (e.g. `a8f5b1e2-c3d4-4e5f-8a9b-0c1d2e3f4a5b`)
|
||||
- `block_id`: The block identifier (must match an Allowed Block)
|
||||
- `input_default`: Dict of inputs (can be empty if no static inputs needed)
|
||||
- `metadata`: Must contain:
|
||||
- `position`: {{"x": number, "y": number}} - adjacent nodes should differ by 800+ in X
|
||||
- `customized_name`: Clear name describing this block's purpose in the workflow
|
||||
|
||||
---
|
||||
|
||||
LINKS:
|
||||
Each link connects a source node's output to a sink node's input:
|
||||
- `id`: MUST be UUID v4 (NOT "link-1", "link-2", etc.)
|
||||
- `source_id`: ID of the source node
|
||||
- `source_name`: Output field name from the source block
|
||||
- `sink_id`: ID of the sink node
|
||||
- `sink_name`: Input field name on the sink block
|
||||
- `is_static`: true only if source block has static_output: true
|
||||
|
||||
CRITICAL: All IDs must be valid UUID v4 format!
|
||||
|
||||
---
|
||||
|
||||
AGENT (GRAPH):
|
||||
Wrap nodes and links in:
|
||||
- `id`: UUID of the agent
|
||||
- `name`: Short, generic name (avoid specific company names, URLs)
|
||||
- `description`: Short, generic description
|
||||
- `nodes`: List of all nodes
|
||||
- `links`: List of all links
|
||||
- `version`: 1
|
||||
- `is_active`: true
|
||||
|
||||
---
|
||||
|
||||
TIPS:
|
||||
- All required_input fields must be provided via input_default or a valid link
|
||||
- Ensure consistent source_id and sink_id references
|
||||
- Avoid dangling links
|
||||
- Input/output pins must match block schemas
|
||||
- Do not invent unknown block_ids
|
||||
|
||||
---
|
||||
|
||||
ALLOWED BLOCKS:
|
||||
{block_summaries}
|
||||
|
||||
---
|
||||
|
||||
Generate the complete agent JSON. Output ONLY valid JSON, no explanation.
|
||||
"""
|
||||
|
||||
PATCH_PROMPT = """
|
||||
You are an expert at modifying AutoGPT agent workflows. Given the current agent and a modification request, generate a JSON patch to update the agent.
|
||||
|
||||
CURRENT AGENT:
|
||||
{current_agent}
|
||||
|
||||
AVAILABLE BLOCKS:
|
||||
{block_summaries}
|
||||
|
||||
---
|
||||
|
||||
PATCH FORMAT:
|
||||
Return a JSON object with the following structure:
|
||||
|
||||
```json
|
||||
{{
|
||||
"type": "patch",
|
||||
"intent": "Brief description of what the patch does",
|
||||
"patches": [
|
||||
{{
|
||||
"type": "modify",
|
||||
"node_id": "uuid-of-node-to-modify",
|
||||
"changes": {{
|
||||
"input_default": {{"field": "new_value"}},
|
||||
"metadata": {{"customized_name": "New Name"}}
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"type": "add",
|
||||
"new_nodes": [
|
||||
{{
|
||||
"id": "new-uuid",
|
||||
"block_id": "block-uuid",
|
||||
"input_default": {{}},
|
||||
"metadata": {{"position": {{"x": 0, "y": 0}}, "customized_name": "Name"}}
|
||||
}}
|
||||
],
|
||||
"new_links": [
|
||||
{{
|
||||
"id": "link-uuid",
|
||||
"source_id": "source-node-id",
|
||||
"source_name": "output_field",
|
||||
"sink_id": "sink-node-id",
|
||||
"sink_name": "input_field"
|
||||
}}
|
||||
]
|
||||
}},
|
||||
{{
|
||||
"type": "remove",
|
||||
"node_ids": ["uuid-of-node-to-remove"],
|
||||
"link_ids": ["uuid-of-link-to-remove"]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
If you need more information, return:
|
||||
```json
|
||||
{{
|
||||
"type": "clarifying_questions",
|
||||
"questions": [
|
||||
{{
|
||||
"question": "What specific change do you want?",
|
||||
"keyword": "change_type",
|
||||
"example": "Add error handling"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
Generate the minimal patch needed. Output ONLY valid JSON.
|
||||
"""
|
||||
@@ -0,0 +1,213 @@
|
||||
"""Utilities for agent generation."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
# UUID validation regex
|
||||
UUID_REGEX = re.compile(
|
||||
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$"
|
||||
)
|
||||
|
||||
# Block IDs for various fixes
|
||||
STORE_VALUE_BLOCK_ID = "1ff065e9-88e8-4358-9d82-8dc91f622ba9"
|
||||
CONDITION_BLOCK_ID = "715696a0-e1da-45c8-b209-c2fa9c3b0be6"
|
||||
ADDTOLIST_BLOCK_ID = "aeb08fc1-2fc1-4141-bc8e-f758f183a822"
|
||||
ADDTODICTIONARY_BLOCK_ID = "31d1064e-7446-4693-a7d4-65e5ca1180d1"
|
||||
CREATELIST_BLOCK_ID = "a912d5c7-6e00-4542-b2a9-8034136930e4"
|
||||
CREATEDICT_BLOCK_ID = "b924ddf4-de4f-4b56-9a85-358930dcbc91"
|
||||
CODE_EXECUTION_BLOCK_ID = "0b02b072-abe7-11ef-8372-fb5d162dd712"
|
||||
DATA_SAMPLING_BLOCK_ID = "4a448883-71fa-49cf-91cf-70d793bd7d87"
|
||||
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID = "95d1b990-ce13-4d88-9737-ba5c2070c97b"
|
||||
GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
|
||||
|
||||
DOUBLE_CURLY_BRACES_BLOCK_IDS = [
|
||||
"44f6c8ad-d75c-4ae1-8209-aad1c0326928", # FillTextTemplateBlock
|
||||
"6ab085e2-20b3-4055-bc3e-08036e01eca6",
|
||||
"90f8c45e-e983-4644-aa0b-b4ebe2f531bc",
|
||||
"363ae599-353e-4804-937e-b2ee3cef3da4", # AgentOutputBlock
|
||||
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||
"db7d8f02-2f44-4c55-ab7a-eae0941f0c30",
|
||||
"3a7c4b8d-6e2f-4a5d-b9c1-f8d23c5a9b0e",
|
||||
"ed1ae7a0-b770-4089-b520-1f0005fad19a",
|
||||
"a892b8d9-3e4e-4e9c-9c1e-75f8efcf1bfa",
|
||||
"b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1",
|
||||
"716a67b3-6760-42e7-86dc-18645c6e00fc",
|
||||
"530cf046-2ce0-4854-ae2c-659db17c7a46",
|
||||
"ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
"1f292d4a-41a4-4977-9684-7c8d560b9f91", # LLM blocks
|
||||
"32a87eab-381e-4dd4-bdb8-4c47151be35a",
|
||||
]
|
||||
|
||||
|
||||
def is_valid_uuid(value: str) -> bool:
|
||||
"""Check if a string is a valid UUID v4."""
|
||||
return isinstance(value, str) and UUID_REGEX.match(value) is not None
|
||||
|
||||
|
||||
def _compact_schema(schema: dict) -> dict[str, str]:
|
||||
"""Extract compact type info from a JSON schema properties dict.
|
||||
|
||||
Returns a dict of {field_name: type_string} for essential info only.
|
||||
"""
|
||||
props = schema.get("properties", {})
|
||||
result = {}
|
||||
|
||||
for name, prop in props.items():
|
||||
# Skip internal/complex fields
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
# Get type string
|
||||
type_str = prop.get("type", "any")
|
||||
|
||||
# Handle anyOf/oneOf (optional types)
|
||||
if "anyOf" in prop:
|
||||
types = [t.get("type", "?") for t in prop["anyOf"] if t.get("type")]
|
||||
type_str = "|".join(types) if types else "any"
|
||||
elif "allOf" in prop:
|
||||
type_str = "object"
|
||||
|
||||
# Add array item type if present
|
||||
if type_str == "array" and "items" in prop:
|
||||
items = prop["items"]
|
||||
if isinstance(items, dict):
|
||||
item_type = items.get("type", "any")
|
||||
type_str = f"array[{item_type}]"
|
||||
|
||||
result[name] = type_str
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_block_summaries(include_schemas: bool = True) -> str:
|
||||
"""Generate compact block summaries for prompts.
|
||||
|
||||
Args:
|
||||
include_schemas: Whether to include input/output type info
|
||||
|
||||
Returns:
|
||||
Formatted string of block summaries (compact format)
|
||||
"""
|
||||
blocks = get_blocks()
|
||||
summaries = []
|
||||
|
||||
for block_id, block_cls in blocks.items():
|
||||
block = block_cls()
|
||||
name = block.name
|
||||
desc = getattr(block, "description", "") or ""
|
||||
|
||||
# Truncate description
|
||||
if len(desc) > 150:
|
||||
desc = desc[:147] + "..."
|
||||
|
||||
if not include_schemas:
|
||||
summaries.append(f"- {name} (id: {block_id}): {desc}")
|
||||
else:
|
||||
# Compact format with type info only
|
||||
inputs = {}
|
||||
outputs = {}
|
||||
required = []
|
||||
|
||||
if hasattr(block, "input_schema"):
|
||||
try:
|
||||
schema = block.input_schema.jsonschema()
|
||||
inputs = _compact_schema(schema)
|
||||
required = schema.get("required", [])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if hasattr(block, "output_schema"):
|
||||
try:
|
||||
schema = block.output_schema.jsonschema()
|
||||
outputs = _compact_schema(schema)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Build compact line format
|
||||
# Format: NAME (id): desc | in: {field:type, ...} [required] | out: {field:type}
|
||||
in_str = ", ".join(f"{k}:{v}" for k, v in inputs.items())
|
||||
out_str = ", ".join(f"{k}:{v}" for k, v in outputs.items())
|
||||
req_str = f" req=[{','.join(required)}]" if required else ""
|
||||
|
||||
static = " [static]" if getattr(block, "static_output", False) else ""
|
||||
|
||||
line = f"- {name} (id: {block_id}): {desc}"
|
||||
if in_str:
|
||||
line += f"\n in: {{{in_str}}}{req_str}"
|
||||
if out_str:
|
||||
line += f"\n out: {{{out_str}}}{static}"
|
||||
|
||||
summaries.append(line)
|
||||
|
||||
return "\n".join(summaries)
|
||||
|
||||
|
||||
def get_blocks_info() -> list[dict[str, Any]]:
|
||||
"""Get block information with schemas for validation and fixing."""
|
||||
blocks = get_blocks()
|
||||
blocks_info = []
|
||||
for block_id, block_cls in blocks.items():
|
||||
block = block_cls()
|
||||
blocks_info.append(
|
||||
{
|
||||
"id": block_id,
|
||||
"name": block.name,
|
||||
"description": getattr(block, "description", ""),
|
||||
"categories": getattr(block, "categories", []),
|
||||
"staticOutput": getattr(block, "static_output", False),
|
||||
"inputSchema": (
|
||||
block.input_schema.jsonschema()
|
||||
if hasattr(block, "input_schema")
|
||||
else {}
|
||||
),
|
||||
"outputSchema": (
|
||||
block.output_schema.jsonschema()
|
||||
if hasattr(block, "output_schema")
|
||||
else {}
|
||||
),
|
||||
}
|
||||
)
|
||||
return blocks_info
|
||||
|
||||
|
||||
def parse_json_from_llm(text: str) -> dict[str, Any] | None:
|
||||
"""Extract JSON from LLM response (handles markdown code blocks)."""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
# Try fenced code block
|
||||
match = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(match.group(1).strip())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try raw text
|
||||
try:
|
||||
return json.loads(text.strip())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try finding {...} span
|
||||
start = text.find("{")
|
||||
end = text.rfind("}")
|
||||
if start != -1 and end > start:
|
||||
try:
|
||||
return json.loads(text[start : end + 1])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try finding [...] span
|
||||
start = text.find("[")
|
||||
end = text.rfind("]")
|
||||
if start != -1 and end > start:
|
||||
try:
|
||||
return json.loads(text[start : end + 1])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,279 @@
|
||||
"""Agent validator - Validates agent structure and connections."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from .utils import get_blocks_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentValidator:
|
||||
"""Validator for AutoGPT agents with detailed error reporting."""
|
||||
|
||||
def __init__(self):
|
||||
self.errors: list[str] = []
|
||||
|
||||
def add_error(self, error: str) -> None:
|
||||
"""Add an error message."""
|
||||
self.errors.append(error)
|
||||
|
||||
def validate_block_existence(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate all block IDs exist in the blocks library."""
|
||||
valid = True
|
||||
valid_block_ids = {b.get("id") for b in blocks_info if b.get("id")}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
node_id = node.get("id")
|
||||
|
||||
if not block_id:
|
||||
self.add_error(f"Node '{node_id}' is missing 'block_id' field.")
|
||||
valid = False
|
||||
continue
|
||||
|
||||
if block_id not in valid_block_ids:
|
||||
self.add_error(
|
||||
f"Node '{node_id}' references block_id '{block_id}' which does not exist."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_link_node_references(self, agent: dict[str, Any]) -> bool:
|
||||
"""Validate all node IDs referenced in links exist."""
|
||||
valid = True
|
||||
valid_node_ids = {n.get("id") for n in agent.get("nodes", []) if n.get("id")}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
link_id = link.get("id", "Unknown")
|
||||
source_id = link.get("source_id")
|
||||
sink_id = link.get("sink_id")
|
||||
|
||||
if not source_id:
|
||||
self.add_error(f"Link '{link_id}' is missing 'source_id'.")
|
||||
valid = False
|
||||
elif source_id not in valid_node_ids:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' references non-existent source_id '{source_id}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
if not sink_id:
|
||||
self.add_error(f"Link '{link_id}' is missing 'sink_id'.")
|
||||
valid = False
|
||||
elif sink_id not in valid_node_ids:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' references non-existent sink_id '{sink_id}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_required_inputs(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate required inputs are provided."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
block = block_map.get(block_id)
|
||||
|
||||
if not block:
|
||||
continue
|
||||
|
||||
required_inputs = block.get("inputSchema", {}).get("required", [])
|
||||
input_defaults = node.get("input_default", {})
|
||||
node_id = node.get("id")
|
||||
|
||||
# Get linked inputs
|
||||
linked_inputs = {
|
||||
link["sink_name"]
|
||||
for link in agent.get("links", [])
|
||||
if link.get("sink_id") == node_id
|
||||
}
|
||||
|
||||
for req_input in required_inputs:
|
||||
if (
|
||||
req_input not in input_defaults
|
||||
and req_input not in linked_inputs
|
||||
and req_input != "credentials"
|
||||
):
|
||||
block_name = block.get("name", "Unknown Block")
|
||||
self.add_error(
|
||||
f"Node '{node_id}' ({block_name}) is missing required input '{req_input}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_data_type_compatibility(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate linked data types are compatible."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
def get_type(schema: dict, name: str) -> str | None:
|
||||
if "_#_" in name:
|
||||
parent, child = name.split("_#_", 1)
|
||||
parent_schema = schema.get(parent, {})
|
||||
if "properties" in parent_schema:
|
||||
return parent_schema["properties"].get(child, {}).get("type")
|
||||
return None
|
||||
return schema.get(name, {}).get("type")
|
||||
|
||||
def are_compatible(src: str, sink: str) -> bool:
|
||||
if {src, sink} <= {"integer", "number"}:
|
||||
return True
|
||||
return src == sink
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
|
||||
if not source_node or not sink_node:
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
sink_block = block_map.get(sink_node.get("block_id"))
|
||||
|
||||
if not source_block or not sink_block:
|
||||
continue
|
||||
|
||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
||||
|
||||
source_type = get_type(source_outputs, link.get("source_name", ""))
|
||||
sink_type = get_type(sink_inputs, link.get("sink_name", ""))
|
||||
|
||||
if source_type and sink_type and not are_compatible(source_type, sink_type):
|
||||
self.add_error(
|
||||
f"Type mismatch: {source_block.get('name')} output '{link['source_name']}' "
|
||||
f"({source_type}) -> {sink_block.get('name')} input '{link['sink_name']}' ({sink_type})."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_nested_sink_links(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate nested sink links (with _#_ notation)."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
sink_name = link.get("sink_name", "")
|
||||
|
||||
if "_#_" in sink_name:
|
||||
parent, child = sink_name.split("_#_", 1)
|
||||
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
if not sink_node:
|
||||
continue
|
||||
|
||||
block = block_map.get(sink_node.get("block_id"))
|
||||
if not block:
|
||||
continue
|
||||
|
||||
input_props = block.get("inputSchema", {}).get("properties", {})
|
||||
parent_schema = input_props.get(parent)
|
||||
|
||||
if not parent_schema:
|
||||
self.add_error(
|
||||
f"Invalid nested link '{sink_name}': parent '{parent}' not found."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
if not parent_schema.get("additionalProperties"):
|
||||
if not (
|
||||
isinstance(parent_schema, dict)
|
||||
and "properties" in parent_schema
|
||||
and child in parent_schema.get("properties", {})
|
||||
):
|
||||
self.add_error(
|
||||
f"Invalid nested link '{sink_name}': child '{child}' not found in '{parent}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_prompt_spaces(self, agent: dict[str, Any]) -> bool:
|
||||
"""Validate prompts don't have spaces in template variables."""
|
||||
valid = True
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
input_default = node.get("input_default", {})
|
||||
prompt = input_default.get("prompt", "")
|
||||
|
||||
if not isinstance(prompt, str):
|
||||
continue
|
||||
|
||||
# Find {{...}} with spaces
|
||||
matches = re.finditer(r"\{\{([^}]+)\}\}", prompt)
|
||||
for match in matches:
|
||||
content = match.group(1)
|
||||
if " " in content:
|
||||
self.add_error(
|
||||
f"Node '{node.get('id')}' has spaces in template variable: "
|
||||
f"'{{{{{content}}}}}' should be '{{{{{content.replace(' ', '_')}}}}}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Run all validations.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
self.errors = []
|
||||
|
||||
if blocks_info is None:
|
||||
blocks_info = get_blocks_info()
|
||||
|
||||
checks = [
|
||||
self.validate_block_existence(agent, blocks_info),
|
||||
self.validate_link_node_references(agent),
|
||||
self.validate_required_inputs(agent, blocks_info),
|
||||
self.validate_data_type_compatibility(agent, blocks_info),
|
||||
self.validate_nested_sink_links(agent, blocks_info),
|
||||
self.validate_prompt_spaces(agent),
|
||||
]
|
||||
|
||||
all_passed = all(checks)
|
||||
|
||||
if all_passed:
|
||||
logger.info("Agent validation successful")
|
||||
return True, None
|
||||
|
||||
error_message = "Agent validation failed:\n"
|
||||
for i, error in enumerate(self.errors, 1):
|
||||
error_message += f"{i}. {error}\n"
|
||||
|
||||
logger.warning(f"Agent validation failed with {len(self.errors)} errors")
|
||||
return False, error_message
|
||||
|
||||
|
||||
def validate_agent(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Convenience function to validate an agent.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
validator = AgentValidator()
|
||||
return validator.validate(agent, blocks_info)
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,279 @@
|
||||
"""CreateAgentTool - Creates agents from natural language descriptions."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
apply_all_fixes,
|
||||
decompose_goal,
|
||||
generate_agent,
|
||||
get_blocks_info,
|
||||
save_agent_to_library,
|
||||
validate_agent,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum retries for agent generation with validation feedback
|
||||
MAX_GENERATION_RETRIES = 2
|
||||
|
||||
|
||||
class CreateAgentTool(BaseTool):
|
||||
"""Tool for creating agents from natural language descriptions."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "create_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Create a new agent workflow from a natural language description. "
|
||||
"First generates a preview, then saves to library if save=true."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Natural language description of what the agent should do. "
|
||||
"Be specific about inputs, outputs, and the workflow steps."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions. "
|
||||
"Include any preferences or constraints mentioned by the user."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to save the agent to the user's library. "
|
||||
"Default is true. Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["description"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the create_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Decompose the description into steps (may return clarifying questions)
|
||||
2. Generate agent JSON from the steps
|
||||
3. Apply fixes to correct common LLM errors
|
||||
4. Preview or save based on the save parameter
|
||||
"""
|
||||
description = kwargs.get("description", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not description:
|
||||
return ErrorResponse(
|
||||
message="Please provide a description of what the agent should do.",
|
||||
error="Missing description parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 1: Decompose goal into steps
|
||||
try:
|
||||
decomposition_result = await decompose_goal(description, context)
|
||||
except ValueError as e:
|
||||
# Handle missing API key or configuration errors
|
||||
return ErrorResponse(
|
||||
message=f"Agent generation is not configured: {str(e)}",
|
||||
error="configuration_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to analyze the goal. Please try rephrasing.",
|
||||
error="Decomposition failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if LLM returned clarifying questions
|
||||
if decomposition_result.get("type") == "clarifying_questions":
|
||||
questions = decomposition_result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information to create this agent. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check for unachievable/vague goals
|
||||
if decomposition_result.get("type") == "unachievable_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
reason = decomposition_result.get("reason", "")
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"This goal cannot be accomplished with the available blocks. "
|
||||
f"{reason} "
|
||||
f"Suggestion: {suggested}"
|
||||
),
|
||||
error="unachievable_goal",
|
||||
details={"suggested_goal": suggested, "reason": reason},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result.get("type") == "vague_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"The goal is too vague to create a specific workflow. "
|
||||
f"Suggestion: {suggested}"
|
||||
),
|
||||
error="vague_goal",
|
||||
details={"suggested_goal": suggested},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 2: Generate agent JSON with retry on validation failure
|
||||
blocks_info = get_blocks_info()
|
||||
agent_json = None
|
||||
validation_errors = None
|
||||
|
||||
for attempt in range(MAX_GENERATION_RETRIES + 1):
|
||||
# Generate agent (include validation errors from previous attempt)
|
||||
if attempt == 0:
|
||||
agent_json = await generate_agent(decomposition_result)
|
||||
else:
|
||||
# Retry with validation error feedback
|
||||
logger.info(
|
||||
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
|
||||
)
|
||||
retry_instructions = {
|
||||
**decomposition_result,
|
||||
"previous_errors": validation_errors,
|
||||
"retry_instructions": (
|
||||
"The previous generation had validation errors. "
|
||||
"Please fix these issues in the new generation:\n"
|
||||
f"{validation_errors}"
|
||||
),
|
||||
}
|
||||
agent_json = await generate_agent(retry_instructions)
|
||||
|
||||
if agent_json is None:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate the agent. Please try again.",
|
||||
error="Generation failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# Step 3: Apply fixes to correct common errors
|
||||
agent_json = apply_all_fixes(agent_json, blocks_info)
|
||||
|
||||
# Step 4: Validate the agent
|
||||
is_valid, validation_errors = validate_agent(agent_json, blocks_info)
|
||||
|
||||
if is_valid:
|
||||
logger.info(f"Agent generated successfully on attempt {attempt + 1}")
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
|
||||
)
|
||||
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
# Return error with validation details
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Generated agent has validation errors after {MAX_GENERATION_RETRIES + 1} attempts. "
|
||||
f"Please try rephrasing your request or simplify the workflow."
|
||||
),
|
||||
error="validation_failed",
|
||||
details={"validation_errors": validation_errors},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", "Generated Agent")
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
link_count = len(agent_json.get("links", []))
|
||||
|
||||
# Step 4: Preview or save
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
|
||||
f"Review it and call create_agent with save=true to save it to your library."
|
||||
),
|
||||
agent_json=agent_json,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Save to library
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=f"Agent '{created_graph.name}' has been saved to your library!",
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the agent: {str(e)}",
|
||||
error="save_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,294 @@
|
||||
"""EditAgentTool - Edits existing agents using natural language."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
apply_agent_patch,
|
||||
apply_all_fixes,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_blocks_info,
|
||||
save_agent_to_library,
|
||||
validate_agent,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum retries for patch generation with validation feedback
|
||||
MAX_GENERATION_RETRIES = 2
|
||||
|
||||
|
||||
class EditAgentTool(BaseTool):
|
||||
"""Tool for editing existing agents using natural language."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "edit_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit an existing agent from the user's library using natural language. "
|
||||
"Generates a patch to update the agent while preserving unchanged parts."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The ID of the agent to edit. "
|
||||
"Can be a graph ID or library agent ID."
|
||||
),
|
||||
},
|
||||
"changes": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Natural language description of what changes to make. "
|
||||
"Be specific about what to add, remove, or modify."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to save the changes. "
|
||||
"Default is true. Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["agent_id", "changes"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the edit_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Fetch the current agent
|
||||
2. Generate a patch based on the requested changes
|
||||
3. Apply the patch to create an updated agent
|
||||
4. Preview or save based on the save parameter
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
changes = kwargs.get("changes", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the agent ID to edit.",
|
||||
error="Missing agent_id parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not changes:
|
||||
return ErrorResponse(
|
||||
message="Please describe what changes you want to make.",
|
||||
error="Missing changes parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 1: Fetch current agent
|
||||
current_agent = await get_agent_as_json(agent_id, user_id)
|
||||
|
||||
if current_agent is None:
|
||||
return ErrorResponse(
|
||||
message=f"Could not find agent with ID '{agent_id}' in your library.",
|
||||
error="agent_not_found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Build the update request with context
|
||||
update_request = changes
|
||||
if context:
|
||||
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
||||
|
||||
# Step 2: Generate patch with retry on validation failure
|
||||
blocks_info = get_blocks_info()
|
||||
updated_agent = None
|
||||
validation_errors = None
|
||||
intent = "Applied requested changes"
|
||||
|
||||
for attempt in range(MAX_GENERATION_RETRIES + 1):
|
||||
# Generate patch (include validation errors from previous attempt)
|
||||
try:
|
||||
if attempt == 0:
|
||||
patch_result = await generate_agent_patch(
|
||||
update_request, current_agent
|
||||
)
|
||||
else:
|
||||
# Retry with validation error feedback
|
||||
logger.info(
|
||||
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
|
||||
)
|
||||
retry_request = (
|
||||
f"{update_request}\n\n"
|
||||
f"IMPORTANT: The previous edit had validation errors. "
|
||||
f"Please fix these issues:\n{validation_errors}"
|
||||
)
|
||||
patch_result = await generate_agent_patch(
|
||||
retry_request, current_agent
|
||||
)
|
||||
except ValueError as e:
|
||||
# Handle missing API key or configuration errors
|
||||
return ErrorResponse(
|
||||
message=f"Agent generation is not configured: {str(e)}",
|
||||
error="configuration_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if patch_result is None:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate changes. Please try rephrasing.",
|
||||
error="Patch generation failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# Check if LLM returned clarifying questions
|
||||
if patch_result.get("type") == "clarifying_questions":
|
||||
questions = patch_result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information about the changes. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 3: Apply patch and fixes
|
||||
try:
|
||||
updated_agent = apply_agent_patch(current_agent, patch_result)
|
||||
updated_agent = apply_all_fixes(updated_agent, blocks_info)
|
||||
except Exception as e:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to apply changes: {str(e)}",
|
||||
error="patch_apply_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
validation_errors = str(e)
|
||||
continue
|
||||
|
||||
# Step 4: Validate the updated agent
|
||||
is_valid, validation_errors = validate_agent(updated_agent, blocks_info)
|
||||
|
||||
if is_valid:
|
||||
logger.info(f"Agent edited successfully on attempt {attempt + 1}")
|
||||
intent = patch_result.get("intent", "Applied requested changes")
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
|
||||
)
|
||||
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
# Return error with validation details
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Updated agent has validation errors after "
|
||||
f"{MAX_GENERATION_RETRIES + 1} attempts. "
|
||||
f"Please try rephrasing your request or simplify the changes."
|
||||
),
|
||||
error="validation_failed",
|
||||
details={"validation_errors": validation_errors},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# At this point, updated_agent is guaranteed to be set (we return on all failure paths)
|
||||
assert updated_agent is not None
|
||||
|
||||
agent_name = updated_agent.get("name", "Updated Agent")
|
||||
agent_description = updated_agent.get("description", "")
|
||||
node_count = len(updated_agent.get("nodes", []))
|
||||
link_count = len(updated_agent.get("links", []))
|
||||
|
||||
# Step 5: Preview or save
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've updated the agent. Changes: {intent}. "
|
||||
f"The agent now has {node_count} blocks. "
|
||||
f"Review it and call edit_agent with save=true to save the changes."
|
||||
),
|
||||
agent_json=updated_agent,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Save to library (creates a new version)
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
updated_agent, user_id, is_update=True
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=(
|
||||
f"Updated agent '{created_graph.name}' has been saved to your library! "
|
||||
f"Changes: {intent}"
|
||||
),
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the updated agent: {str(e)}",
|
||||
error="save_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -0,0 +1,253 @@
|
||||
"""Tool for searching available blocks using hybrid search."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.blocks import load_all_blocks
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
BlockInfoSummary,
|
||||
BlockListResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .search_blocks import get_block_search_index
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FindBlockTool(BaseTool):
|
||||
"""Tool for searching available blocks."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "find_block"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search for available blocks by name or description. "
|
||||
"Blocks are reusable components that perform specific tasks like "
|
||||
"sending emails, making API calls, processing text, etc. "
|
||||
"Use this to find blocks that can be executed directly."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query to find blocks by name or description. "
|
||||
"Use keywords like 'email', 'http', 'text', 'ai', etc."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
def _matches_query(self, block, query: str) -> tuple[int, bool]:
|
||||
"""
|
||||
Check if a block matches the query and return a priority score.
|
||||
|
||||
Returns (priority, matches) where:
|
||||
- priority 0: exact name match
|
||||
- priority 1: name contains query
|
||||
- priority 2: description contains query
|
||||
- priority 3: category contains query
|
||||
"""
|
||||
query_lower = query.lower()
|
||||
name_lower = block.name.lower()
|
||||
desc_lower = block.description.lower()
|
||||
|
||||
# Exact name match
|
||||
if query_lower == name_lower:
|
||||
return 0, True
|
||||
|
||||
# Name contains query
|
||||
if query_lower in name_lower:
|
||||
return 1, True
|
||||
|
||||
# Description contains query
|
||||
if query_lower in desc_lower:
|
||||
return 2, True
|
||||
|
||||
# Category contains query
|
||||
for category in block.categories:
|
||||
if query_lower in category.name.lower():
|
||||
return 3, True
|
||||
|
||||
return 4, False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search for blocks matching the query.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
BlockListResponse: List of matching blocks
|
||||
NoResultsResponse: No blocks found
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Try hybrid search first
|
||||
search_results = self._hybrid_search(query)
|
||||
|
||||
if search_results is not None:
|
||||
# Hybrid search succeeded
|
||||
if not search_results:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found matching '{query}'",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Search by category: ai, text, social, search, etc.",
|
||||
"Check block names like 'SendEmail', 'HttpRequest', etc.",
|
||||
],
|
||||
)
|
||||
|
||||
# Get full block info for each result
|
||||
all_blocks = load_all_blocks()
|
||||
blocks = []
|
||||
for result in search_results:
|
||||
block_cls = all_blocks.get(result.block_id)
|
||||
if block_cls:
|
||||
block = block_cls()
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
description=block.description,
|
||||
categories=[cat.name for cat in block.categories],
|
||||
input_schema=block.input_schema.jsonschema(),
|
||||
output_schema=block.output_schema.jsonschema(),
|
||||
)
|
||||
)
|
||||
|
||||
return BlockListResponse(
|
||||
message=(
|
||||
f"Found {len(blocks)} block{'s' if len(blocks) != 1 else ''} "
|
||||
f"matching '{query}'. Use run_block to execute a block with "
|
||||
"the required inputs."
|
||||
),
|
||||
blocks=blocks,
|
||||
count=len(blocks),
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Fallback to simple search if hybrid search failed
|
||||
return self._simple_search(query, session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching blocks: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search blocks. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _hybrid_search(self, query: str) -> list | None:
|
||||
"""
|
||||
Perform hybrid search using embeddings and BM25.
|
||||
|
||||
Returns:
|
||||
List of BlockSearchResult if successful, None if index not available
|
||||
"""
|
||||
try:
|
||||
index = get_block_search_index()
|
||||
if not index.load():
|
||||
logger.info(
|
||||
"Block search index not available, falling back to simple search"
|
||||
)
|
||||
return None
|
||||
|
||||
results = index.search(query, top_k=10)
|
||||
logger.info(f"Hybrid search found {len(results)} blocks for: {query}")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Hybrid search failed, falling back to simple: {e}")
|
||||
return None
|
||||
|
||||
def _simple_search(self, query: str, session_id: str) -> ToolResponseBase:
|
||||
"""Fallback simple search using substring matching."""
|
||||
all_blocks = load_all_blocks()
|
||||
logger.info(f"Simple searching {len(all_blocks)} blocks for: {query}")
|
||||
|
||||
# Find matching blocks with priority scores
|
||||
matches: list[tuple[int, Any]] = []
|
||||
for block_id, block_cls in all_blocks.items():
|
||||
block = block_cls()
|
||||
priority, is_match = self._matches_query(block, query)
|
||||
if is_match:
|
||||
matches.append((priority, block))
|
||||
|
||||
# Sort by priority (lower is better)
|
||||
matches.sort(key=lambda x: x[0])
|
||||
|
||||
# Take top 10 results
|
||||
top_matches = [block for _, block in matches[:10]]
|
||||
|
||||
if not top_matches:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found matching '{query}'",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Search by category: ai, text, social, search, etc.",
|
||||
"Check block names like 'SendEmail', 'HttpRequest', etc.",
|
||||
],
|
||||
)
|
||||
|
||||
# Build response
|
||||
blocks = []
|
||||
for block in top_matches:
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
description=block.description,
|
||||
categories=[cat.name for cat in block.categories],
|
||||
input_schema=block.input_schema.jsonschema(),
|
||||
output_schema=block.output_schema.jsonschema(),
|
||||
)
|
||||
)
|
||||
|
||||
return BlockListResponse(
|
||||
message=(
|
||||
f"Found {len(blocks)} block{'s' if len(blocks) != 1 else ''} "
|
||||
f"matching '{query}'. Use run_block to execute a block with "
|
||||
"the required inputs."
|
||||
),
|
||||
blocks=blocks,
|
||||
count=len(blocks),
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -0,0 +1,483 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Block Indexer for Hybrid Search
|
||||
|
||||
Creates a hybrid search index from blocks:
|
||||
- OpenAI embeddings (text-embedding-3-small)
|
||||
- BM25 index for lexical search
|
||||
- Name index for title matching boost
|
||||
|
||||
Supports incremental updates by tracking content hashes.
|
||||
|
||||
Usage:
|
||||
python -m backend.server.v2.chat.tools.index_blocks [--force]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Check for OpenAI availability
|
||||
try:
|
||||
import openai # noqa: F401
|
||||
|
||||
HAS_OPENAI = True
|
||||
except ImportError:
|
||||
HAS_OPENAI = False
|
||||
print("Warning: openai not installed. Run: pip install openai")
|
||||
|
||||
# Default embedding model (OpenAI)
|
||||
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
DEFAULT_EMBEDDING_DIM = 1536
|
||||
|
||||
# Output path (relative to this file)
|
||||
INDEX_PATH = Path(__file__).parent / "blocks_index.json"
|
||||
|
||||
# Stopwords for tokenization
|
||||
STOPWORDS = {
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"is",
|
||||
"are",
|
||||
"was",
|
||||
"were",
|
||||
"be",
|
||||
"been",
|
||||
"being",
|
||||
"have",
|
||||
"has",
|
||||
"had",
|
||||
"do",
|
||||
"does",
|
||||
"did",
|
||||
"will",
|
||||
"would",
|
||||
"could",
|
||||
"should",
|
||||
"may",
|
||||
"might",
|
||||
"must",
|
||||
"shall",
|
||||
"can",
|
||||
"need",
|
||||
"dare",
|
||||
"ought",
|
||||
"used",
|
||||
"to",
|
||||
"of",
|
||||
"in",
|
||||
"for",
|
||||
"on",
|
||||
"with",
|
||||
"at",
|
||||
"by",
|
||||
"from",
|
||||
"as",
|
||||
"into",
|
||||
"through",
|
||||
"during",
|
||||
"before",
|
||||
"after",
|
||||
"above",
|
||||
"below",
|
||||
"between",
|
||||
"under",
|
||||
"again",
|
||||
"further",
|
||||
"then",
|
||||
"once",
|
||||
"and",
|
||||
"but",
|
||||
"or",
|
||||
"nor",
|
||||
"so",
|
||||
"yet",
|
||||
"both",
|
||||
"either",
|
||||
"neither",
|
||||
"not",
|
||||
"only",
|
||||
"own",
|
||||
"same",
|
||||
"than",
|
||||
"too",
|
||||
"very",
|
||||
"just",
|
||||
"also",
|
||||
"now",
|
||||
"here",
|
||||
"there",
|
||||
"when",
|
||||
"where",
|
||||
"why",
|
||||
"how",
|
||||
"all",
|
||||
"each",
|
||||
"every",
|
||||
"few",
|
||||
"more",
|
||||
"most",
|
||||
"other",
|
||||
"some",
|
||||
"such",
|
||||
"no",
|
||||
"any",
|
||||
"this",
|
||||
"that",
|
||||
"these",
|
||||
"those",
|
||||
"it",
|
||||
"its",
|
||||
"block", # Too common in block context
|
||||
}
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Simple tokenizer for BM25."""
|
||||
text = text.lower()
|
||||
# Remove code blocks if any
|
||||
text = re.sub(r"```[\s\S]*?```", "", text)
|
||||
text = re.sub(r"`[^`]+`", "", text)
|
||||
# Extract words (including camelCase split)
|
||||
# First, split camelCase
|
||||
text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
|
||||
# Extract words
|
||||
words = re.findall(r"\b[a-z][a-z0-9_-]*\b", text)
|
||||
# Remove very short words and stopwords
|
||||
return [w for w in words if len(w) > 2 and w not in STOPWORDS]
|
||||
|
||||
|
||||
def build_searchable_text(block: Any) -> str:
|
||||
"""Build searchable text from block attributes."""
|
||||
parts = []
|
||||
|
||||
# Block name (split camelCase for better tokenization)
|
||||
name = block.name
|
||||
# Split camelCase: GetCurrentTimeBlock -> Get Current Time Block
|
||||
name_split = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
|
||||
parts.append(name_split)
|
||||
|
||||
# Description
|
||||
if block.description:
|
||||
parts.append(block.description)
|
||||
|
||||
# Categories
|
||||
for category in block.categories:
|
||||
parts.append(category.name)
|
||||
|
||||
# Input schema field names and descriptions
|
||||
try:
|
||||
input_schema = block.input_schema.jsonschema()
|
||||
if "properties" in input_schema:
|
||||
for field_name, field_info in input_schema["properties"].items():
|
||||
parts.append(field_name)
|
||||
if "description" in field_info:
|
||||
parts.append(field_info["description"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Output schema field names
|
||||
try:
|
||||
output_schema = block.output_schema.jsonschema()
|
||||
if "properties" in output_schema:
|
||||
for field_name in output_schema["properties"]:
|
||||
parts.append(field_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def compute_content_hash(text: str) -> str:
|
||||
"""Compute MD5 hash of text for change detection."""
|
||||
return hashlib.md5(text.encode()).hexdigest()
|
||||
|
||||
|
||||
def load_existing_index(index_path: Path) -> dict[str, Any] | None:
|
||||
"""Load existing index if present."""
|
||||
if not index_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(index_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load existing index: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def create_embeddings(
|
||||
texts: list[str],
|
||||
model_name: str = DEFAULT_EMBEDDING_MODEL,
|
||||
batch_size: int = 100,
|
||||
) -> np.ndarray:
|
||||
"""Create embeddings using OpenAI API."""
|
||||
if not HAS_OPENAI:
|
||||
raise RuntimeError("openai not installed. Run: pip install openai")
|
||||
|
||||
# Import here to satisfy type checker
|
||||
from openai import OpenAI
|
||||
|
||||
# Check for API key
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
client = OpenAI(api_key=api_key)
|
||||
embeddings = []
|
||||
|
||||
print(f"Creating embeddings for {len(texts)} texts using {model_name}...")
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
# Truncate texts to max token limit (8191 tokens for text-embedding-3-small)
|
||||
# Roughly 4 chars per token, so ~32000 chars max
|
||||
batch = [text[:32000] for text in batch]
|
||||
|
||||
response = client.embeddings.create(
|
||||
model=model_name,
|
||||
input=batch,
|
||||
)
|
||||
|
||||
for embedding_data in response.data:
|
||||
embeddings.append(embedding_data.embedding)
|
||||
|
||||
print(f" Processed {min(i + batch_size, len(texts))}/{len(texts)} texts")
|
||||
|
||||
return np.array(embeddings, dtype=np.float32)
|
||||
|
||||
|
||||
def build_bm25_data(
|
||||
blocks_data: list[dict[str, Any]],
|
||||
) -> dict[str, Any]:
|
||||
"""Build BM25 metadata from block data."""
|
||||
# Tokenize all searchable texts
|
||||
tokenized_docs = []
|
||||
for block in blocks_data:
|
||||
tokens = tokenize(block["searchable_text"])
|
||||
tokenized_docs.append(tokens)
|
||||
|
||||
# Calculate document frequencies
|
||||
doc_freq: dict[str, int] = {}
|
||||
for tokens in tokenized_docs:
|
||||
seen = set()
|
||||
for token in tokens:
|
||||
if token not in seen:
|
||||
doc_freq[token] = doc_freq.get(token, 0) + 1
|
||||
seen.add(token)
|
||||
|
||||
n_docs = len(tokenized_docs)
|
||||
doc_lens = [len(d) for d in tokenized_docs]
|
||||
avgdl = sum(doc_lens) / max(n_docs, 1)
|
||||
|
||||
return {
|
||||
"n_docs": n_docs,
|
||||
"avgdl": avgdl,
|
||||
"df": doc_freq,
|
||||
"doc_lens": doc_lens,
|
||||
}
|
||||
|
||||
|
||||
def build_name_index(
|
||||
blocks_data: list[dict[str, Any]],
|
||||
) -> dict[str, list[list[int | float]]]:
|
||||
"""Build inverted index for name search boost."""
|
||||
index: dict[str, list[list[int | float]]] = defaultdict(list)
|
||||
|
||||
for idx, block in enumerate(blocks_data):
|
||||
# Tokenize block name
|
||||
name_tokens = tokenize(block["name"])
|
||||
seen = set()
|
||||
|
||||
for i, token in enumerate(name_tokens):
|
||||
if token in seen:
|
||||
continue
|
||||
seen.add(token)
|
||||
|
||||
# Score: first token gets higher weight
|
||||
score = 1.5 if i == 0 else 1.0
|
||||
index[token].append([idx, score])
|
||||
|
||||
return dict(index)
|
||||
|
||||
|
||||
def build_block_index(
|
||||
force_rebuild: bool = False,
|
||||
output_path: Path = INDEX_PATH,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Build the block search index.
|
||||
|
||||
Args:
|
||||
force_rebuild: If True, rebuild all embeddings even if unchanged
|
||||
output_path: Path to save the index
|
||||
|
||||
Returns:
|
||||
The generated index dictionary
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from backend.blocks import load_all_blocks
|
||||
|
||||
print("Loading all blocks...")
|
||||
all_blocks = load_all_blocks()
|
||||
print(f"Found {len(all_blocks)} blocks")
|
||||
|
||||
# Load existing index for incremental updates
|
||||
existing_index = None if force_rebuild else load_existing_index(output_path)
|
||||
existing_blocks: dict[str, dict[str, Any]] = {}
|
||||
|
||||
if existing_index:
|
||||
print(
|
||||
f"Loaded existing index with {len(existing_index.get('blocks', []))} blocks"
|
||||
)
|
||||
for block in existing_index.get("blocks", []):
|
||||
existing_blocks[block["id"]] = block
|
||||
|
||||
# Process each block
|
||||
blocks_data: list[dict[str, Any]] = []
|
||||
blocks_needing_embedding: list[tuple[int, str]] = [] # (index, searchable_text)
|
||||
|
||||
for block_id, block_cls in all_blocks.items():
|
||||
try:
|
||||
block = block_cls()
|
||||
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
searchable_text = build_searchable_text(block)
|
||||
content_hash = compute_content_hash(searchable_text)
|
||||
|
||||
block_data = {
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"description": block.description,
|
||||
"categories": [cat.name for cat in block.categories],
|
||||
"searchable_text": searchable_text,
|
||||
"content_hash": content_hash,
|
||||
"emb": None, # Will be filled later
|
||||
}
|
||||
|
||||
# Check if we can reuse existing embedding
|
||||
if (
|
||||
block.id in existing_blocks
|
||||
and existing_blocks[block.id].get("content_hash") == content_hash
|
||||
and existing_blocks[block.id].get("emb")
|
||||
):
|
||||
# Reuse existing embedding
|
||||
block_data["emb"] = existing_blocks[block.id]["emb"]
|
||||
else:
|
||||
# Need new embedding
|
||||
blocks_needing_embedding.append((len(blocks_data), searchable_text))
|
||||
|
||||
blocks_data.append(block_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process block {block_id}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Processed {len(blocks_data)} blocks")
|
||||
print(f"Blocks needing new embeddings: {len(blocks_needing_embedding)}")
|
||||
|
||||
# Create embeddings for new/changed blocks
|
||||
if blocks_needing_embedding and HAS_OPENAI:
|
||||
texts_to_embed = [text for _, text in blocks_needing_embedding]
|
||||
try:
|
||||
embeddings = create_embeddings(texts_to_embed)
|
||||
|
||||
# Assign embeddings to blocks
|
||||
for i, (block_idx, _) in enumerate(blocks_needing_embedding):
|
||||
emb = embeddings[i].astype(np.float32)
|
||||
# Encode as base64
|
||||
blocks_data[block_idx]["emb"] = base64.b64encode(emb.tobytes()).decode(
|
||||
"ascii"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to create embeddings: {e}")
|
||||
elif blocks_needing_embedding:
|
||||
print(
|
||||
"Warning: Cannot create embeddings (openai not installed or OPENAI_API_KEY not set)"
|
||||
)
|
||||
|
||||
# Build BM25 data
|
||||
print("Building BM25 index...")
|
||||
bm25_data = build_bm25_data(blocks_data)
|
||||
|
||||
# Build name index
|
||||
print("Building name index...")
|
||||
name_index = build_name_index(blocks_data)
|
||||
|
||||
# Build final index
|
||||
index = {
|
||||
"version": "1.0.0",
|
||||
"embedding_model": DEFAULT_EMBEDDING_MODEL,
|
||||
"embedding_dim": DEFAULT_EMBEDDING_DIM,
|
||||
"generated_at": datetime.now(timezone.utc).isoformat(),
|
||||
"blocks": blocks_data,
|
||||
"bm25": bm25_data,
|
||||
"name_index": name_index,
|
||||
}
|
||||
|
||||
# Save index
|
||||
print(f"Saving index to {output_path}...")
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, separators=(",", ":"))
|
||||
|
||||
size_kb = output_path.stat().st_size / 1024
|
||||
print(f"Index saved ({size_kb:.1f} KB)")
|
||||
|
||||
# Print statistics
|
||||
print("\nIndex Statistics:")
|
||||
print(f" Blocks indexed: {len(blocks_data)}")
|
||||
print(f" BM25 vocabulary size: {len(bm25_data['df'])}")
|
||||
print(f" Name index terms: {len(name_index)}")
|
||||
print(f" Embeddings: {'Yes' if any(b.get('emb') for b in blocks_data) else 'No'}")
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Build hybrid search index for blocks")
|
||||
parser.add_argument(
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Force rebuild all embeddings even if unchanged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
default=INDEX_PATH,
|
||||
help=f"Output index file path (default: {INDEX_PATH})",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
build_block_index(
|
||||
force_rebuild=args.force,
|
||||
output_path=args.output,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error building index: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,287 @@
|
||||
"""Tool for executing blocks directly."""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.block import get_block
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunBlockTool(BaseTool):
|
||||
"""Tool for executing a block and returning its outputs."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "run_block"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Execute a specific block with the provided input data. "
|
||||
"Use find_block to discover available blocks and their input schemas. "
|
||||
"The block will run and return its outputs once complete."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"block_id": {
|
||||
"type": "string",
|
||||
"description": "The UUID of the block to execute",
|
||||
},
|
||||
"input_data": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"Input values for the block. Must match the block's input schema. "
|
||||
"Check the block's input_schema from find_block for required fields."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["block_id", "input_data"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _check_block_credentials(
|
||||
self,
|
||||
user_id: str,
|
||||
block: Any,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Check if user has required credentials for a block.
|
||||
|
||||
Returns:
|
||||
tuple[matched_credentials, missing_credentials]
|
||||
"""
|
||||
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
||||
missing_credentials: list[CredentialsMetaInput] = []
|
||||
|
||||
# Get credential field info from block's input schema
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
|
||||
if not credentials_fields_info:
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
# Get user's available credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
# field_info.provider is a frozenset of acceptable providers
|
||||
# field_info.supported_types is a frozenset of acceptable types
|
||||
matching_cred = next(
|
||||
(
|
||||
cred
|
||||
for cred in available_creds
|
||||
if cred.provider in field_info.provider
|
||||
and cred.type in field_info.supported_types
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if matching_cred:
|
||||
matched_credentials[field_name] = CredentialsMetaInput(
|
||||
id=matching_cred.id,
|
||||
provider=matching_cred.provider, # type: ignore
|
||||
type=matching_cred.type,
|
||||
title=matching_cred.title,
|
||||
)
|
||||
else:
|
||||
# Create a placeholder for the missing credential
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||
missing_credentials.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
provider=provider, # type: ignore
|
||||
type=cred_type, # type: ignore
|
||||
title=field_name.replace("_", " ").title(),
|
||||
)
|
||||
)
|
||||
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute a block with the given input data.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
block_id: Block UUID to execute
|
||||
input_data: Input values for the block
|
||||
|
||||
Returns:
|
||||
BlockOutputResponse: Block execution outputs
|
||||
SetupRequirementsResponse: Missing credentials
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
block_id = kwargs.get("block_id", "").strip()
|
||||
input_data = kwargs.get("input_data", {})
|
||||
session_id = session.session_id
|
||||
|
||||
if not block_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a block_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not isinstance(input_data, dict):
|
||||
return ErrorResponse(
|
||||
message="input_data must be an object",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Get the block
|
||||
block = get_block(block_id)
|
||||
if not block:
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||
|
||||
# Check credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
matched_credentials, missing_credentials = await self._check_block_credentials(
|
||||
user_id, block
|
||||
)
|
||||
|
||||
if missing_credentials:
|
||||
# Return setup requirements response with missing credentials
|
||||
missing_creds_dict = {c.id: c.model_dump() for c in missing_credentials}
|
||||
|
||||
return SetupRequirementsResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' requires credentials that are not configured. "
|
||||
"Please set up the required credentials before running this block."
|
||||
),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=block_id,
|
||||
agent_name=block.name,
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_creds_dict,
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": [c.model_dump() for c in missing_credentials],
|
||||
"inputs": self._get_inputs_list(block),
|
||||
"execution_modes": ["immediate"],
|
||||
},
|
||||
),
|
||||
graph_id=None,
|
||||
graph_version=None,
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch actual credentials and prepare kwargs for block execution
|
||||
exec_kwargs: dict[str, Any] = {"user_id": user_id}
|
||||
|
||||
for field_name, cred_meta in matched_credentials.items():
|
||||
# Inject metadata into input_data (for validation)
|
||||
if field_name not in input_data:
|
||||
input_data[field_name] = cred_meta.model_dump()
|
||||
|
||||
# Fetch actual credentials and pass as kwargs (for execution)
|
||||
actual_credentials = await creds_manager.get(
|
||||
user_id, cred_meta.id, lock=False
|
||||
)
|
||||
if actual_credentials:
|
||||
exec_kwargs[field_name] = actual_credentials
|
||||
else:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to retrieve credentials for {field_name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
input_data,
|
||||
**exec_kwargs,
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
outputs=dict(outputs),
|
||||
success=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning(f"Block execution failed: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Block execution failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
|
||||
"""Extract non-credential inputs from block schema."""
|
||||
inputs_list = []
|
||||
schema = block.input_schema.jsonschema()
|
||||
properties = schema.get("properties", {})
|
||||
required_fields = set(schema.get("required", []))
|
||||
|
||||
# Get credential field names to exclude
|
||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
# Skip credential fields
|
||||
if field_name in credentials_fields:
|
||||
continue
|
||||
|
||||
inputs_list.append(
|
||||
{
|
||||
"name": field_name,
|
||||
"title": field_schema.get("title", field_name),
|
||||
"type": field_schema.get("type", "string"),
|
||||
"description": field_schema.get("description", ""),
|
||||
"required": field_name in required_fields,
|
||||
}
|
||||
)
|
||||
|
||||
return inputs_list
|
||||
@@ -0,0 +1,460 @@
|
||||
"""
|
||||
Block Hybrid Search
|
||||
|
||||
Combines multiple ranking signals for block search:
|
||||
- Semantic search (OpenAI embeddings + cosine similarity)
|
||||
- Lexical search (BM25)
|
||||
- Name matching (boost for block name matches)
|
||||
- Category matching (boost for category matches)
|
||||
|
||||
Based on the docs search implementation.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OpenAI embedding model
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
|
||||
# Path to the JSON index file
|
||||
INDEX_PATH = Path(__file__).parent / "blocks_index.json"
|
||||
|
||||
# Stopwords for tokenization (same as index_blocks.py)
|
||||
STOPWORDS = {
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"is",
|
||||
"are",
|
||||
"was",
|
||||
"were",
|
||||
"be",
|
||||
"been",
|
||||
"being",
|
||||
"have",
|
||||
"has",
|
||||
"had",
|
||||
"do",
|
||||
"does",
|
||||
"did",
|
||||
"will",
|
||||
"would",
|
||||
"could",
|
||||
"should",
|
||||
"may",
|
||||
"might",
|
||||
"must",
|
||||
"shall",
|
||||
"can",
|
||||
"need",
|
||||
"dare",
|
||||
"ought",
|
||||
"used",
|
||||
"to",
|
||||
"of",
|
||||
"in",
|
||||
"for",
|
||||
"on",
|
||||
"with",
|
||||
"at",
|
||||
"by",
|
||||
"from",
|
||||
"as",
|
||||
"into",
|
||||
"through",
|
||||
"during",
|
||||
"before",
|
||||
"after",
|
||||
"above",
|
||||
"below",
|
||||
"between",
|
||||
"under",
|
||||
"again",
|
||||
"further",
|
||||
"then",
|
||||
"once",
|
||||
"and",
|
||||
"but",
|
||||
"or",
|
||||
"nor",
|
||||
"so",
|
||||
"yet",
|
||||
"both",
|
||||
"either",
|
||||
"neither",
|
||||
"not",
|
||||
"only",
|
||||
"own",
|
||||
"same",
|
||||
"than",
|
||||
"too",
|
||||
"very",
|
||||
"just",
|
||||
"also",
|
||||
"now",
|
||||
"here",
|
||||
"there",
|
||||
"when",
|
||||
"where",
|
||||
"why",
|
||||
"how",
|
||||
"all",
|
||||
"each",
|
||||
"every",
|
||||
"few",
|
||||
"more",
|
||||
"most",
|
||||
"other",
|
||||
"some",
|
||||
"such",
|
||||
"no",
|
||||
"any",
|
||||
"this",
|
||||
"that",
|
||||
"these",
|
||||
"those",
|
||||
"it",
|
||||
"its",
|
||||
"block",
|
||||
}
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Simple tokenizer for search."""
|
||||
text = text.lower()
|
||||
# Remove code blocks if any
|
||||
text = re.sub(r"```[\s\S]*?```", "", text)
|
||||
text = re.sub(r"`[^`]+`", "", text)
|
||||
# Split camelCase
|
||||
text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
|
||||
# Extract words
|
||||
words = re.findall(r"\b[a-z][a-z0-9_-]*\b", text)
|
||||
# Remove very short words and stopwords
|
||||
return [w for w in words if len(w) > 2 and w not in STOPWORDS]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchWeights:
|
||||
"""Configuration for hybrid search signal weights."""
|
||||
|
||||
semantic: float = 0.40 # Embedding similarity
|
||||
bm25: float = 0.25 # Lexical matching
|
||||
name_match: float = 0.25 # Block name matches
|
||||
category_match: float = 0.10 # Category matches
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockSearchResult:
|
||||
"""A single block search result."""
|
||||
|
||||
block_id: str
|
||||
name: str
|
||||
description: str
|
||||
categories: list[str]
|
||||
score: float
|
||||
|
||||
# Individual signal scores (for debugging)
|
||||
semantic_score: float = 0.0
|
||||
bm25_score: float = 0.0
|
||||
name_score: float = 0.0
|
||||
category_score: float = 0.0
|
||||
|
||||
|
||||
class BlockSearchIndex:
|
||||
"""Hybrid search index for blocks combining BM25 + embeddings."""
|
||||
|
||||
def __init__(self, index_path: Path = INDEX_PATH):
|
||||
self.blocks: list[dict[str, Any]] = []
|
||||
self.bm25_data: dict[str, Any] = {}
|
||||
self.name_index: dict[str, list[list[int | float]]] = {}
|
||||
self.embeddings: Optional[np.ndarray] = None
|
||||
self.normalized_embeddings: Optional[np.ndarray] = None
|
||||
self._loaded = False
|
||||
self._index_path = index_path
|
||||
self._embedding_model: Any = None
|
||||
|
||||
def load(self) -> bool:
|
||||
"""Load the index from JSON file."""
|
||||
if self._loaded:
|
||||
return True
|
||||
|
||||
if not self._index_path.exists():
|
||||
logger.warning(f"Block index not found at {self._index_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(self._index_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
self.blocks = data.get("blocks", [])
|
||||
self.bm25_data = data.get("bm25", {})
|
||||
self.name_index = data.get("name_index", {})
|
||||
|
||||
# Decode embeddings from base64
|
||||
embeddings_list = []
|
||||
for block in self.blocks:
|
||||
if block.get("emb"):
|
||||
emb_bytes = base64.b64decode(block["emb"])
|
||||
emb = np.frombuffer(emb_bytes, dtype=np.float32)
|
||||
embeddings_list.append(emb)
|
||||
else:
|
||||
# No embedding, use zeros
|
||||
dim = data.get("embedding_dim", 384)
|
||||
embeddings_list.append(np.zeros(dim, dtype=np.float32))
|
||||
|
||||
if embeddings_list:
|
||||
self.embeddings = np.stack(embeddings_list)
|
||||
# Precompute normalized embeddings for cosine similarity
|
||||
norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True)
|
||||
self.normalized_embeddings = self.embeddings / (norms + 1e-10)
|
||||
|
||||
self._loaded = True
|
||||
logger.info(f"Loaded block index with {len(self.blocks)} blocks")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load block index: {e}")
|
||||
return False
|
||||
|
||||
def _get_openai_client(self) -> Any:
|
||||
"""Get OpenAI client for query embedding."""
|
||||
if self._embedding_model is None:
|
||||
try:
|
||||
from openai import OpenAI
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
logger.warning("OPENAI_API_KEY not set")
|
||||
return None
|
||||
self._embedding_model = OpenAI(api_key=api_key)
|
||||
except ImportError:
|
||||
logger.warning("openai not installed")
|
||||
return None
|
||||
return self._embedding_model
|
||||
|
||||
def _embed_query(self, query: str) -> Optional[np.ndarray]:
|
||||
"""Embed the search query using OpenAI."""
|
||||
client = self._get_openai_client()
|
||||
if client is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
response = client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=query,
|
||||
)
|
||||
embedding = response.data[0].embedding
|
||||
return np.array(embedding, dtype=np.float32)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to embed query: {e}")
|
||||
return None
|
||||
|
||||
def _compute_semantic_scores(self, query_embedding: np.ndarray) -> np.ndarray:
|
||||
"""Compute cosine similarity between query and all blocks."""
|
||||
if self.normalized_embeddings is None:
|
||||
return np.zeros(len(self.blocks))
|
||||
|
||||
# Normalize query embedding
|
||||
query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)
|
||||
|
||||
# Cosine similarity via dot product
|
||||
similarities = self.normalized_embeddings @ query_norm
|
||||
|
||||
# Scale to [0, 1] (cosine ranges from -1 to 1)
|
||||
return (similarities + 1) / 2
|
||||
|
||||
def _compute_bm25_scores(self, query_tokens: list[str]) -> np.ndarray:
|
||||
"""Compute BM25 scores for all blocks."""
|
||||
scores = np.zeros(len(self.blocks))
|
||||
|
||||
if not self.bm25_data or not query_tokens:
|
||||
return scores
|
||||
|
||||
# BM25 parameters
|
||||
k1 = 1.5
|
||||
b = 0.75
|
||||
n_docs = self.bm25_data.get("n_docs", len(self.blocks))
|
||||
avgdl = self.bm25_data.get("avgdl", 100)
|
||||
df = self.bm25_data.get("df", {})
|
||||
doc_lens = self.bm25_data.get("doc_lens", [100] * len(self.blocks))
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
# Tokenize block's searchable text
|
||||
block_tokens = tokenize(block.get("searchable_text", ""))
|
||||
doc_len = doc_lens[i] if i < len(doc_lens) else len(block_tokens)
|
||||
|
||||
# Calculate BM25 score
|
||||
score = 0.0
|
||||
for token in query_tokens:
|
||||
if token not in df:
|
||||
continue
|
||||
|
||||
# Term frequency in this document
|
||||
tf = block_tokens.count(token)
|
||||
if tf == 0:
|
||||
continue
|
||||
|
||||
# IDF
|
||||
doc_freq = df.get(token, 0)
|
||||
idf = math.log((n_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
|
||||
|
||||
# BM25 score component
|
||||
numerator = tf * (k1 + 1)
|
||||
denominator = tf + k1 * (1 - b + b * doc_len / avgdl)
|
||||
score += idf * numerator / denominator
|
||||
|
||||
scores[i] = score
|
||||
|
||||
# Normalize to [0, 1]
|
||||
max_score = scores.max()
|
||||
if max_score > 0:
|
||||
scores = scores / max_score
|
||||
|
||||
return scores
|
||||
|
||||
def _compute_name_scores(self, query_tokens: list[str]) -> np.ndarray:
|
||||
"""Compute name match scores using the name index."""
|
||||
scores = np.zeros(len(self.blocks))
|
||||
|
||||
if not self.name_index or not query_tokens:
|
||||
return scores
|
||||
|
||||
for token in query_tokens:
|
||||
if token in self.name_index:
|
||||
for block_idx, weight in self.name_index[token]:
|
||||
if block_idx < len(scores):
|
||||
scores[int(block_idx)] += weight
|
||||
|
||||
# Also check for partial matches in block names
|
||||
for i, block in enumerate(self.blocks):
|
||||
name_lower = block.get("name", "").lower()
|
||||
for token in query_tokens:
|
||||
if token in name_lower:
|
||||
scores[i] += 0.5
|
||||
|
||||
# Normalize to [0, 1]
|
||||
max_score = scores.max()
|
||||
if max_score > 0:
|
||||
scores = scores / max_score
|
||||
|
||||
return scores
|
||||
|
||||
def _compute_category_scores(self, query_tokens: list[str]) -> np.ndarray:
|
||||
"""Compute category match scores."""
|
||||
scores = np.zeros(len(self.blocks))
|
||||
|
||||
if not query_tokens:
|
||||
return scores
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
categories = block.get("categories", [])
|
||||
category_text = " ".join(categories).lower()
|
||||
|
||||
for token in query_tokens:
|
||||
if token in category_text:
|
||||
scores[i] += 1.0
|
||||
|
||||
# Normalize to [0, 1]
|
||||
max_score = scores.max()
|
||||
if max_score > 0:
|
||||
scores = scores / max_score
|
||||
|
||||
return scores
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 10,
|
||||
weights: Optional[SearchWeights] = None,
|
||||
) -> list[BlockSearchResult]:
|
||||
"""
|
||||
Perform hybrid search combining multiple signals.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
top_k: Number of results to return
|
||||
weights: Optional custom weights for signals
|
||||
|
||||
Returns:
|
||||
List of BlockSearchResult sorted by score
|
||||
"""
|
||||
if not self._loaded and not self.load():
|
||||
return []
|
||||
|
||||
if weights is None:
|
||||
weights = SearchWeights()
|
||||
|
||||
# Tokenize query
|
||||
query_tokens = tokenize(query)
|
||||
if not query_tokens:
|
||||
# Fallback: try raw query words
|
||||
query_tokens = query.lower().split()
|
||||
|
||||
# Compute semantic scores
|
||||
semantic_scores = np.zeros(len(self.blocks))
|
||||
if self.normalized_embeddings is not None:
|
||||
query_embedding = self._embed_query(query)
|
||||
if query_embedding is not None:
|
||||
semantic_scores = self._compute_semantic_scores(query_embedding)
|
||||
|
||||
# Compute other scores
|
||||
bm25_scores = self._compute_bm25_scores(query_tokens)
|
||||
name_scores = self._compute_name_scores(query_tokens)
|
||||
category_scores = self._compute_category_scores(query_tokens)
|
||||
|
||||
# Combine scores using weights
|
||||
combined_scores = (
|
||||
weights.semantic * semantic_scores
|
||||
+ weights.bm25 * bm25_scores
|
||||
+ weights.name_match * name_scores
|
||||
+ weights.category_match * category_scores
|
||||
)
|
||||
|
||||
# Get top-k indices
|
||||
top_indices = np.argsort(combined_scores)[::-1][:top_k]
|
||||
|
||||
# Build results
|
||||
results = []
|
||||
for idx in top_indices:
|
||||
if combined_scores[idx] <= 0:
|
||||
continue
|
||||
|
||||
block = self.blocks[idx]
|
||||
results.append(
|
||||
BlockSearchResult(
|
||||
block_id=block["id"],
|
||||
name=block["name"],
|
||||
description=block["description"],
|
||||
categories=block.get("categories", []),
|
||||
score=float(combined_scores[idx]),
|
||||
semantic_score=float(semantic_scores[idx]),
|
||||
bm25_score=float(bm25_scores[idx]),
|
||||
name_score=float(name_scores[idx]),
|
||||
category_score=float(category_scores[idx]),
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# Global index instance (lazy loaded)
|
||||
_block_search_index: Optional[BlockSearchIndex] = None
|
||||
|
||||
|
||||
def get_block_search_index() -> BlockSearchIndex:
|
||||
"""Get or create the block search index singleton."""
|
||||
global _block_search_index
|
||||
if _block_search_index is None:
|
||||
_block_search_index = BlockSearchIndex(INDEX_PATH)
|
||||
return _block_search_index
|
||||
@@ -0,0 +1,386 @@
|
||||
"""Tool for searching platform documentation."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
DocSearchResult,
|
||||
DocSearchResultsResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Documentation base URL
|
||||
DOCS_BASE_URL = "https://docs.agpt.co/platform"
|
||||
|
||||
# Path to the JSON index file (relative to this file)
|
||||
INDEX_PATH = Path(__file__).parent / "docs_index.json"
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Simple tokenizer for BM25."""
|
||||
text = text.lower()
|
||||
# Remove code blocks
|
||||
text = re.sub(r"```[\s\S]*?```", "", text)
|
||||
text = re.sub(r"`[^`]+`", "", text)
|
||||
# Extract words
|
||||
words = re.findall(r"\b[a-z][a-z0-9_-]*\b", text)
|
||||
# Remove very short words and stopwords
|
||||
stopwords = {
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"is",
|
||||
"are",
|
||||
"was",
|
||||
"were",
|
||||
"be",
|
||||
"been",
|
||||
"being",
|
||||
"have",
|
||||
"has",
|
||||
"had",
|
||||
"do",
|
||||
"does",
|
||||
"did",
|
||||
"will",
|
||||
"would",
|
||||
"could",
|
||||
"should",
|
||||
"may",
|
||||
"might",
|
||||
"must",
|
||||
"shall",
|
||||
"can",
|
||||
"need",
|
||||
"dare",
|
||||
"ought",
|
||||
"used",
|
||||
"to",
|
||||
"of",
|
||||
"in",
|
||||
"for",
|
||||
"on",
|
||||
"with",
|
||||
"at",
|
||||
"by",
|
||||
"from",
|
||||
"as",
|
||||
"into",
|
||||
"through",
|
||||
"during",
|
||||
"before",
|
||||
"after",
|
||||
"above",
|
||||
"below",
|
||||
"between",
|
||||
"under",
|
||||
"again",
|
||||
"further",
|
||||
"then",
|
||||
"once",
|
||||
"and",
|
||||
"but",
|
||||
"or",
|
||||
"nor",
|
||||
"so",
|
||||
"yet",
|
||||
"both",
|
||||
"either",
|
||||
"neither",
|
||||
"not",
|
||||
"only",
|
||||
"own",
|
||||
"same",
|
||||
"than",
|
||||
"too",
|
||||
"very",
|
||||
"just",
|
||||
"also",
|
||||
"now",
|
||||
"here",
|
||||
"there",
|
||||
"when",
|
||||
"where",
|
||||
"why",
|
||||
"how",
|
||||
"all",
|
||||
"each",
|
||||
"every",
|
||||
"both",
|
||||
"few",
|
||||
"more",
|
||||
"most",
|
||||
"other",
|
||||
"some",
|
||||
"such",
|
||||
"no",
|
||||
"any",
|
||||
"this",
|
||||
"that",
|
||||
"these",
|
||||
"those",
|
||||
"it",
|
||||
"its",
|
||||
}
|
||||
return [w for w in words if len(w) > 2 and w not in stopwords]
|
||||
|
||||
|
||||
class DocSearchIndex:
|
||||
"""Lightweight documentation search index using BM25."""
|
||||
|
||||
def __init__(self, index_path: Path):
|
||||
self.chunks: list[dict] = []
|
||||
self.bm25_data: dict = {}
|
||||
self._loaded = False
|
||||
self._index_path = index_path
|
||||
|
||||
def load(self) -> bool:
|
||||
"""Load the index from JSON file."""
|
||||
if self._loaded:
|
||||
return True
|
||||
|
||||
if not self._index_path.exists():
|
||||
logger.warning(f"Documentation index not found at {self._index_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(self._index_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
self.chunks = data.get("chunks", [])
|
||||
self.bm25_data = data.get("bm25", {})
|
||||
self._loaded = True
|
||||
logger.info(f"Loaded documentation index with {len(self.chunks)} chunks")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load documentation index: {e}")
|
||||
return False
|
||||
|
||||
def search(self, query: str, top_k: int = 5) -> list[dict]:
|
||||
"""Search the index using BM25."""
|
||||
if not self._loaded and not self.load():
|
||||
return []
|
||||
|
||||
query_tokens = tokenize(query)
|
||||
if not query_tokens:
|
||||
return []
|
||||
|
||||
# BM25 parameters
|
||||
k1 = 1.5
|
||||
b = 0.75
|
||||
n_docs = self.bm25_data.get("n_docs", len(self.chunks))
|
||||
avgdl = self.bm25_data.get("avgdl", 100)
|
||||
df = self.bm25_data.get("df", {})
|
||||
doc_lens = self.bm25_data.get("doc_lens", [100] * len(self.chunks))
|
||||
|
||||
scores = []
|
||||
for i, chunk in enumerate(self.chunks):
|
||||
# Tokenize chunk text
|
||||
chunk_tokens = tokenize(chunk.get("text", ""))
|
||||
doc_len = doc_lens[i] if i < len(doc_lens) else len(chunk_tokens)
|
||||
|
||||
# Calculate BM25 score
|
||||
score = 0.0
|
||||
for token in query_tokens:
|
||||
if token not in df:
|
||||
continue
|
||||
|
||||
# Term frequency in this document
|
||||
tf = chunk_tokens.count(token)
|
||||
if tf == 0:
|
||||
continue
|
||||
|
||||
# IDF
|
||||
doc_freq = df.get(token, 0)
|
||||
idf = math.log((n_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
|
||||
|
||||
# BM25 score component
|
||||
numerator = tf * (k1 + 1)
|
||||
denominator = tf + k1 * (1 - b + b * doc_len / avgdl)
|
||||
score += idf * numerator / denominator
|
||||
|
||||
# Boost for title/heading matches
|
||||
title = chunk.get("title", "").lower()
|
||||
heading = chunk.get("heading", "").lower()
|
||||
for token in query_tokens:
|
||||
if token in title:
|
||||
score *= 1.5
|
||||
if token in heading:
|
||||
score *= 1.2
|
||||
|
||||
scores.append((i, score))
|
||||
|
||||
# Sort by score and return top_k
|
||||
scores.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
results = []
|
||||
seen_sections = set()
|
||||
for idx, score in scores:
|
||||
if score <= 0:
|
||||
continue
|
||||
|
||||
chunk = self.chunks[idx]
|
||||
section_key = (chunk.get("doc", ""), chunk.get("heading", ""))
|
||||
|
||||
# Deduplicate by section
|
||||
if section_key in seen_sections:
|
||||
continue
|
||||
seen_sections.add(section_key)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"title": chunk.get("title", ""),
|
||||
"path": chunk.get("doc", ""),
|
||||
"heading": chunk.get("heading", ""),
|
||||
"text": chunk.get("text", ""), # Full text for LLM comprehension
|
||||
"score": score,
|
||||
}
|
||||
)
|
||||
|
||||
if len(results) >= top_k:
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# Global index instance (lazy loaded)
|
||||
_search_index: DocSearchIndex | None = None
|
||||
|
||||
|
||||
def get_search_index() -> DocSearchIndex:
|
||||
"""Get or create the search index singleton."""
|
||||
global _search_index
|
||||
if _search_index is None:
|
||||
_search_index = DocSearchIndex(INDEX_PATH)
|
||||
return _search_index
|
||||
|
||||
|
||||
class SearchDocsTool(BaseTool):
|
||||
"""Tool for searching AutoGPT platform documentation."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "search_platform_docs"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search the AutoGPT platform documentation and support Q&A for information about "
|
||||
"how to use the platform, create agents, configure blocks, "
|
||||
"set up integrations, troubleshoot issues, and more. Use this when users ask "
|
||||
"support questions or want to learn how to do something with AutoGPT."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query describing what the user wants to learn about. "
|
||||
"Use keywords like 'blocks', 'agents', 'credentials', 'API', etc."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search documentation for the query.
|
||||
|
||||
Args:
|
||||
user_id: User ID (may be anonymous)
|
||||
session: Chat session
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
DocSearchResultsResponse: List of matching documentation sections
|
||||
NoResultsResponse: No results found
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
index = get_search_index()
|
||||
results = index.search(query, top_k=5)
|
||||
|
||||
if not results:
|
||||
return NoResultsResponse(
|
||||
message=f"No documentation found for '{query}'. Try different keywords.",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms like 'blocks', 'agents', 'setup'",
|
||||
"Check the documentation at docs.agpt.co",
|
||||
],
|
||||
)
|
||||
|
||||
# Convert to response format
|
||||
doc_results = []
|
||||
for r in results:
|
||||
# Build documentation URL
|
||||
path = r["path"]
|
||||
if path.endswith(".md"):
|
||||
path = path[:-3] # Remove .md extension
|
||||
doc_url = f"{DOCS_BASE_URL}/{path}"
|
||||
|
||||
full_text = r["text"]
|
||||
doc_results.append(
|
||||
DocSearchResult(
|
||||
title=r["title"],
|
||||
path=r["path"],
|
||||
section=r["heading"],
|
||||
snippet=(
|
||||
full_text[:300] + "..."
|
||||
if len(full_text) > 300
|
||||
else full_text
|
||||
),
|
||||
content=full_text, # Full text for LLM to read and understand
|
||||
score=round(r["score"], 3),
|
||||
doc_url=doc_url,
|
||||
)
|
||||
)
|
||||
|
||||
return DocSearchResultsResponse(
|
||||
message=(
|
||||
f"Found {len(doc_results)} relevant documentation sections. "
|
||||
"Use these to help answer the user's question. "
|
||||
"Include links to the documentation when helpful."
|
||||
),
|
||||
results=doc_results,
|
||||
count=len(doc_results),
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching documentation: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search documentation. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -48,7 +48,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
owner_user_id: str # ID of user who owns/created this agent graph
|
||||
|
||||
image_url: str | None
|
||||
|
||||
@@ -164,7 +163,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id=agent.id,
|
||||
graph_id=agent.agentGraphId,
|
||||
graph_version=agent.agentGraphVersion,
|
||||
owner_user_id=agent.userId,
|
||||
image_url=agent.imageUrl,
|
||||
creator_name=creator_name,
|
||||
creator_image_url=creator_image_url,
|
||||
|
||||
@@ -42,7 +42,6 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
@@ -65,7 +64,6 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-2",
|
||||
graph_id="test-agent-2",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 2",
|
||||
description="Test Description 2",
|
||||
image_url=None,
|
||||
@@ -140,7 +138,6 @@ async def test_get_favorite_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Favorite Agent 1",
|
||||
description="Test Favorite Description 1",
|
||||
image_url=None,
|
||||
@@ -208,7 +205,6 @@ def test_add_agent_to_library_success(
|
||||
id="test-library-agent-id",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
CLI script to backfill embeddings for store agents.
|
||||
|
||||
Usage:
|
||||
poetry run python -m backend.server.v2.store.backfill_embeddings [--batch-size N]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import prisma
|
||||
|
||||
|
||||
async def main(batch_size: int = 100) -> int:
|
||||
"""Run the backfill process."""
|
||||
# Initialize Prisma client
|
||||
client = prisma.Prisma()
|
||||
await client.connect()
|
||||
prisma.register(client)
|
||||
|
||||
try:
|
||||
from backend.api.features.store.embeddings import (
|
||||
backfill_missing_embeddings,
|
||||
get_embedding_stats,
|
||||
)
|
||||
|
||||
# Get current stats
|
||||
print("Current embedding stats:")
|
||||
stats = await get_embedding_stats()
|
||||
print(f" Total approved: {stats['total_approved']}")
|
||||
print(f" With embeddings: {stats['with_embeddings']}")
|
||||
print(f" Without embeddings: {stats['without_embeddings']}")
|
||||
print(f" Coverage: {stats['coverage_percent']}%")
|
||||
|
||||
if stats["without_embeddings"] == 0:
|
||||
print("\nAll agents already have embeddings. Nothing to do.")
|
||||
return 0
|
||||
|
||||
# Run backfill
|
||||
print(f"\nBackfilling up to {batch_size} embeddings...")
|
||||
result = await backfill_missing_embeddings(batch_size=batch_size)
|
||||
print(f" Processed: {result['processed']}")
|
||||
print(f" Success: {result['success']}")
|
||||
print(f" Failed: {result['failed']}")
|
||||
|
||||
# Get final stats
|
||||
print("\nFinal embedding stats:")
|
||||
stats = await get_embedding_stats()
|
||||
print(f" Total approved: {stats['total_approved']}")
|
||||
print(f" With embeddings: {stats['with_embeddings']}")
|
||||
print(f" Without embeddings: {stats['without_embeddings']}")
|
||||
print(f" Coverage: {stats['coverage_percent']}%")
|
||||
|
||||
return 0 if result["failed"] == 0 else 1
|
||||
|
||||
finally:
|
||||
await client.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Backfill embeddings for store agents")
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of embeddings to generate (default: 100)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sys.exit(asyncio.run(main(batch_size=args.batch_size)))
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal
|
||||
|
||||
@@ -10,7 +9,7 @@ import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
from backend.data.db import query_raw_with_schema, transaction
|
||||
from backend.data.db import transaction
|
||||
from backend.data.graph import (
|
||||
GraphMeta,
|
||||
GraphModel,
|
||||
@@ -57,95 +56,21 @@ async def get_store_agents(
|
||||
)
|
||||
|
||||
try:
|
||||
# If search_query is provided, use full-text search
|
||||
# If search_query is provided, use hybrid search (embeddings + tsvector)
|
||||
if search_query:
|
||||
offset = (page - 1) * page_size
|
||||
from backend.api.features.store.hybrid_search import hybrid_search
|
||||
|
||||
# Whitelist allowed order_by columns
|
||||
ALLOWED_ORDER_BY = {
|
||||
"rating": "rating DESC, rank DESC",
|
||||
"runs": "runs DESC, rank DESC",
|
||||
"name": "agent_name ASC, rank ASC",
|
||||
"updated_at": "updated_at DESC, rank DESC",
|
||||
}
|
||||
# Use hybrid search combining semantic and lexical signals
|
||||
agents, total = await hybrid_search(
|
||||
query=search_query,
|
||||
featured=featured,
|
||||
creators=creators,
|
||||
category=category,
|
||||
sorted_by="relevance", # Use hybrid scoring for relevance
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
# Validate and get order clause
|
||||
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
|
||||
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
|
||||
else:
|
||||
order_by_clause = "updated_at DESC, rank DESC"
|
||||
|
||||
# Build WHERE conditions and parameters list
|
||||
where_parts: list[str] = []
|
||||
params: list[typing.Any] = [search_query] # $1 - search term
|
||||
param_index = 2 # Start at $2 for next parameter
|
||||
|
||||
# Always filter for available agents
|
||||
where_parts.append("is_available = true")
|
||||
|
||||
if featured:
|
||||
where_parts.append("featured = true")
|
||||
|
||||
if creators and creators:
|
||||
# Use ANY with array parameter
|
||||
where_parts.append(f"creator_username = ANY(${param_index})")
|
||||
params.append(creators)
|
||||
param_index += 1
|
||||
|
||||
if category and category:
|
||||
where_parts.append(f"${param_index} = ANY(categories)")
|
||||
params.append(category)
|
||||
param_index += 1
|
||||
|
||||
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
|
||||
|
||||
# Add pagination params
|
||||
params.extend([page_size, offset])
|
||||
limit_param = f"${param_index}"
|
||||
offset_param = f"${param_index + 1}"
|
||||
|
||||
# Execute full-text search query with parameterized values
|
||||
sql_query = f"""
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
ts_rank_cd(search, query) AS rank
|
||||
FROM {{schema_prefix}}"StoreAgent",
|
||||
plainto_tsquery('english', $1) AS query
|
||||
WHERE {sql_where_clause}
|
||||
AND search @@ query
|
||||
ORDER BY {order_by_clause}
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
# Count query for pagination - only uses search term parameter
|
||||
count_query = f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {{schema_prefix}}"StoreAgent",
|
||||
plainto_tsquery('english', $1) AS query
|
||||
WHERE {sql_where_clause}
|
||||
AND search @@ query
|
||||
"""
|
||||
|
||||
# Execute both queries with parameters
|
||||
agents = await query_raw_with_schema(sql_query, *params)
|
||||
|
||||
# For count, use params without pagination (last 2 params)
|
||||
count_params = params[:-2]
|
||||
count_result = await query_raw_with_schema(count_query, *count_params)
|
||||
|
||||
total = count_result[0]["count"] if count_result else 0
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
# Convert raw results to StoreAgent models
|
||||
@@ -614,7 +539,6 @@ async def get_store_submissions(
|
||||
submission_models = []
|
||||
for sub in submissions:
|
||||
submission_model = store_model.StoreSubmission(
|
||||
listing_id=sub.listing_id,
|
||||
agent_id=sub.agent_id,
|
||||
agent_version=sub.agent_version,
|
||||
name=sub.name,
|
||||
@@ -668,48 +592,35 @@ async def delete_store_submission(
|
||||
submission_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete a store submission version as the submitting user.
|
||||
Delete a store listing submission as the submitting user.
|
||||
|
||||
Args:
|
||||
user_id: ID of the authenticated user
|
||||
submission_id: StoreListingVersion ID to delete
|
||||
submission_id: ID of the submission to be deleted
|
||||
|
||||
Returns:
|
||||
bool: True if successfully deleted
|
||||
bool: True if the submission was successfully deleted, False otherwise
|
||||
"""
|
||||
logger.debug(f"Deleting store submission {submission_id} for user {user_id}")
|
||||
|
||||
try:
|
||||
# Find the submission version with ownership check
|
||||
version = await prisma.models.StoreListingVersion.prisma().find_first(
|
||||
where={"id": submission_id}, include={"StoreListing": True}
|
||||
# Verify the submission belongs to this user
|
||||
submission = await prisma.models.StoreListing.prisma().find_first(
|
||||
where={"agentGraphId": submission_id, "owningUserId": user_id}
|
||||
)
|
||||
|
||||
if (
|
||||
not version
|
||||
or not version.StoreListing
|
||||
or version.StoreListing.owningUserId != user_id
|
||||
):
|
||||
raise store_exceptions.SubmissionNotFoundError("Submission not found")
|
||||
|
||||
# Prevent deletion of approved submissions
|
||||
if version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
|
||||
raise store_exceptions.InvalidOperationError(
|
||||
"Cannot delete approved submissions"
|
||||
if not submission:
|
||||
logger.warning(f"Submission not found for user {user_id}: {submission_id}")
|
||||
raise store_exceptions.SubmissionNotFoundError(
|
||||
f"Submission not found for this user. User ID: {user_id}, Submission ID: {submission_id}"
|
||||
)
|
||||
|
||||
# Delete the version
|
||||
await prisma.models.StoreListingVersion.prisma().delete(
|
||||
where={"id": version.id}
|
||||
)
|
||||
# Delete the submission
|
||||
await prisma.models.StoreListing.prisma().delete(where={"id": submission.id})
|
||||
|
||||
# Clean up empty listing if this was the last version
|
||||
remaining = await prisma.models.StoreListingVersion.prisma().count(
|
||||
where={"storeListingId": version.storeListingId}
|
||||
logger.debug(
|
||||
f"Successfully deleted submission {submission_id} for user {user_id}"
|
||||
)
|
||||
if remaining == 0:
|
||||
await prisma.models.StoreListing.prisma().delete(
|
||||
where={"id": version.storeListingId}
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -773,15 +684,9 @@ async def create_store_submission(
|
||||
logger.warning(
|
||||
f"Agent not found for user {user_id}: {agent_id} v{agent_version}"
|
||||
)
|
||||
# Provide more user-friendly error message when agent_id is empty
|
||||
if not agent_id or agent_id.strip() == "":
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
"No agent selected. Please select an agent before submitting to the store."
|
||||
)
|
||||
else:
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
||||
)
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
||||
)
|
||||
|
||||
# Check if listing already exists for this agent
|
||||
existing_listing = await prisma.models.StoreListing.prisma().find_first(
|
||||
@@ -853,7 +758,6 @@ async def create_store_submission(
|
||||
logger.debug(f"Created store listing for agent {agent_id}")
|
||||
# Return submission details
|
||||
return store_model.StoreSubmission(
|
||||
listing_id=listing.id,
|
||||
agent_id=agent_id,
|
||||
agent_version=agent_version,
|
||||
name=name,
|
||||
@@ -965,56 +869,81 @@ async def edit_store_submission(
|
||||
# Currently we are not allowing user to update the agent associated with a submission
|
||||
# If we allow it in future, then we need a check here to verify the agent belongs to this user.
|
||||
|
||||
# Only allow editing of PENDING submissions
|
||||
if current_version.submissionStatus != prisma.enums.SubmissionStatus.PENDING:
|
||||
# Check if we can edit this submission
|
||||
if current_version.submissionStatus == prisma.enums.SubmissionStatus.REJECTED:
|
||||
raise store_exceptions.InvalidOperationError(
|
||||
f"Cannot edit a {current_version.submissionStatus.value.lower()} submission. Only pending submissions can be edited."
|
||||
"Cannot edit a rejected submission"
|
||||
)
|
||||
|
||||
# For APPROVED submissions, we need to create a new version
|
||||
if current_version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
|
||||
# Create a new version for the existing listing
|
||||
return await create_store_version(
|
||||
user_id=user_id,
|
||||
agent_id=current_version.agentGraphId,
|
||||
agent_version=current_version.agentGraphVersion,
|
||||
store_listing_id=current_version.storeListingId,
|
||||
name=name,
|
||||
video_url=video_url,
|
||||
agent_output_demo_url=agent_output_demo_url,
|
||||
image_urls=image_urls,
|
||||
description=description,
|
||||
sub_heading=sub_heading,
|
||||
categories=categories,
|
||||
changes_summary=changes_summary,
|
||||
recommended_schedule_cron=recommended_schedule_cron,
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
# For PENDING submissions, we can update the existing version
|
||||
# Update the existing version
|
||||
updated_version = await prisma.models.StoreListingVersion.prisma().update(
|
||||
where={"id": store_listing_version_id},
|
||||
data=prisma.types.StoreListingVersionUpdateInput(
|
||||
elif current_version.submissionStatus == prisma.enums.SubmissionStatus.PENDING:
|
||||
# Update the existing version
|
||||
updated_version = await prisma.models.StoreListingVersion.prisma().update(
|
||||
where={"id": store_listing_version_id},
|
||||
data=prisma.types.StoreListingVersionUpdateInput(
|
||||
name=name,
|
||||
videoUrl=video_url,
|
||||
agentOutputDemoUrl=agent_output_demo_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
instructions=instructions,
|
||||
),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}"
|
||||
)
|
||||
|
||||
if not updated_version:
|
||||
raise DatabaseError("Failed to update store listing version")
|
||||
return store_model.StoreSubmission(
|
||||
agent_id=current_version.agentGraphId,
|
||||
agent_version=current_version.agentGraphVersion,
|
||||
name=name,
|
||||
videoUrl=video_url,
|
||||
agentOutputDemoUrl=agent_output_demo_url,
|
||||
imageUrls=image_urls,
|
||||
sub_heading=sub_heading,
|
||||
slug=current_version.StoreListing.slug,
|
||||
description=description,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
instructions=instructions,
|
||||
),
|
||||
)
|
||||
image_urls=image_urls,
|
||||
date_submitted=updated_version.submittedAt or updated_version.createdAt,
|
||||
status=updated_version.submissionStatus,
|
||||
runs=0,
|
||||
rating=0.0,
|
||||
store_listing_version_id=updated_version.id,
|
||||
changes_summary=changes_summary,
|
||||
video_url=video_url,
|
||||
categories=categories,
|
||||
version=updated_version.version,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}"
|
||||
)
|
||||
|
||||
if not updated_version:
|
||||
raise DatabaseError("Failed to update store listing version")
|
||||
return store_model.StoreSubmission(
|
||||
listing_id=current_version.StoreListing.id,
|
||||
agent_id=current_version.agentGraphId,
|
||||
agent_version=current_version.agentGraphVersion,
|
||||
name=name,
|
||||
sub_heading=sub_heading,
|
||||
slug=current_version.StoreListing.slug,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
image_urls=image_urls,
|
||||
date_submitted=updated_version.submittedAt or updated_version.createdAt,
|
||||
status=updated_version.submissionStatus,
|
||||
runs=0,
|
||||
rating=0.0,
|
||||
store_listing_version_id=updated_version.id,
|
||||
changes_summary=changes_summary,
|
||||
video_url=video_url,
|
||||
categories=categories,
|
||||
version=updated_version.version,
|
||||
)
|
||||
else:
|
||||
raise store_exceptions.InvalidOperationError(
|
||||
f"Cannot edit submission with status: {current_version.submissionStatus}"
|
||||
)
|
||||
|
||||
except (
|
||||
store_exceptions.SubmissionNotFoundError,
|
||||
@@ -1093,78 +1022,38 @@ async def create_store_version(
|
||||
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
||||
)
|
||||
|
||||
# Check if there's already a PENDING submission for this agent (any version)
|
||||
existing_pending_submission = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_first(
|
||||
where=prisma.types.StoreListingVersionWhereInput(
|
||||
storeListingId=store_listing_id,
|
||||
agentGraphId=agent_id,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
isDeleted=False,
|
||||
)
|
||||
# Get the latest version number
|
||||
latest_version = listing.Versions[0] if listing.Versions else None
|
||||
|
||||
next_version = (latest_version.version + 1) if latest_version else 1
|
||||
|
||||
# Create a new version for the existing listing
|
||||
new_version = await prisma.models.StoreListingVersion.prisma().create(
|
||||
data=prisma.types.StoreListingVersionCreateInput(
|
||||
version=next_version,
|
||||
agentGraphId=agent_id,
|
||||
agentGraphVersion=agent_version,
|
||||
name=name,
|
||||
videoUrl=video_url,
|
||||
agentOutputDemoUrl=agent_output_demo_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
submittedAt=datetime.now(),
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
storeListingId=store_listing_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Handle existing pending submission and create new one atomically
|
||||
async with transaction() as tx:
|
||||
# Get the latest version number first
|
||||
latest_listing = await prisma.models.StoreListing.prisma(tx).find_first(
|
||||
where=prisma.types.StoreListingWhereInput(
|
||||
id=store_listing_id, owningUserId=user_id
|
||||
),
|
||||
include={"Versions": {"order_by": {"version": "desc"}, "take": 1}},
|
||||
)
|
||||
|
||||
if not latest_listing:
|
||||
raise store_exceptions.ListingNotFoundError(
|
||||
f"Store listing not found. User ID: {user_id}, Listing ID: {store_listing_id}"
|
||||
)
|
||||
|
||||
latest_version = (
|
||||
latest_listing.Versions[0] if latest_listing.Versions else None
|
||||
)
|
||||
next_version = (latest_version.version + 1) if latest_version else 1
|
||||
|
||||
# If there's an existing pending submission, delete it atomically before creating new one
|
||||
if existing_pending_submission:
|
||||
logger.info(
|
||||
f"Found existing PENDING submission for agent {agent_id} (was v{existing_pending_submission.agentGraphVersion}, now v{agent_version}), replacing existing submission instead of creating duplicate"
|
||||
)
|
||||
await prisma.models.StoreListingVersion.prisma(tx).delete(
|
||||
where={"id": existing_pending_submission.id}
|
||||
)
|
||||
logger.debug(
|
||||
f"Deleted existing pending submission {existing_pending_submission.id}"
|
||||
)
|
||||
|
||||
# Create a new version for the existing listing
|
||||
new_version = await prisma.models.StoreListingVersion.prisma(tx).create(
|
||||
data=prisma.types.StoreListingVersionCreateInput(
|
||||
version=next_version,
|
||||
agentGraphId=agent_id,
|
||||
agentGraphVersion=agent_version,
|
||||
name=name,
|
||||
videoUrl=video_url,
|
||||
agentOutputDemoUrl=agent_output_demo_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
submittedAt=datetime.now(),
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
storeListingId=store_listing_id,
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Created new version for listing {store_listing_id} of agent {agent_id}"
|
||||
)
|
||||
# Return submission details
|
||||
return store_model.StoreSubmission(
|
||||
listing_id=listing.id,
|
||||
agent_id=agent_id,
|
||||
agent_version=agent_version,
|
||||
name=name,
|
||||
@@ -1600,6 +1489,24 @@ async def review_store_submission(
|
||||
},
|
||||
)
|
||||
|
||||
# Generate embedding for approved listing (non-blocking)
|
||||
try:
|
||||
from backend.api.features.store.embeddings import ensure_embedding
|
||||
|
||||
await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=store_listing_version.name,
|
||||
description=store_listing_version.description,
|
||||
sub_heading=store_listing_version.subHeading,
|
||||
categories=store_listing_version.categories or [],
|
||||
)
|
||||
except Exception as e:
|
||||
# Don't fail approval if embedding generation fails
|
||||
logger.warning(
|
||||
f"Failed to generate embedding for approved listing "
|
||||
f"{store_listing_version_id}: {e}"
|
||||
)
|
||||
|
||||
# If rejecting an approved agent, update the StoreListing accordingly
|
||||
if is_rejecting_approved:
|
||||
# Check if there are other approved versions
|
||||
@@ -1744,12 +1651,15 @@ async def review_store_submission(
|
||||
|
||||
# Convert to Pydantic model for consistency
|
||||
return store_model.StoreSubmission(
|
||||
listing_id=(submission.StoreListing.id if submission.StoreListing else ""),
|
||||
agent_id=submission.agentGraphId,
|
||||
agent_version=submission.agentGraphVersion,
|
||||
name=submission.name,
|
||||
sub_heading=submission.subHeading,
|
||||
slug=(submission.StoreListing.slug if submission.StoreListing else ""),
|
||||
slug=(
|
||||
submission.StoreListing.slug
|
||||
if hasattr(submission, "storeListing") and submission.StoreListing
|
||||
else ""
|
||||
),
|
||||
description=submission.description,
|
||||
instructions=submission.instructions,
|
||||
image_urls=submission.imageUrls or [],
|
||||
@@ -1851,7 +1761,9 @@ async def get_admin_listings_with_versions(
|
||||
where = prisma.types.StoreListingWhereInput(**where_dict)
|
||||
include = prisma.types.StoreListingInclude(
|
||||
Versions=prisma.types.FindManyStoreListingVersionArgsFromStoreListing(
|
||||
order_by={"version": "desc"}
|
||||
order_by=prisma.types._StoreListingVersion_version_OrderByInput(
|
||||
version="desc"
|
||||
)
|
||||
),
|
||||
OwningUser=True,
|
||||
)
|
||||
@@ -1876,7 +1788,6 @@ async def get_admin_listings_with_versions(
|
||||
# If we have versions, turn them into StoreSubmission models
|
||||
for version in listing.Versions or []:
|
||||
version_model = store_model.StoreSubmission(
|
||||
listing_id=listing.id,
|
||||
agent_id=version.agentGraphId,
|
||||
agent_version=version.agentGraphVersion,
|
||||
name=version.name,
|
||||
|
||||
@@ -0,0 +1,408 @@
|
||||
"""
|
||||
Store Listing Embeddings Service
|
||||
|
||||
Handles generation and storage of OpenAI embeddings for store listings
|
||||
to enable semantic/hybrid search.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import prisma
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OpenAI embedding model configuration
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
EMBEDDING_DIM = 1536
|
||||
|
||||
|
||||
def build_searchable_text(
|
||||
name: str,
|
||||
description: str,
|
||||
sub_heading: str,
|
||||
categories: list[str],
|
||||
) -> str:
|
||||
"""
|
||||
Build searchable text from listing version fields.
|
||||
|
||||
Combines relevant fields into a single string for embedding.
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# Name is important - include it
|
||||
if name:
|
||||
parts.append(name)
|
||||
|
||||
# Sub-heading provides context
|
||||
if sub_heading:
|
||||
parts.append(sub_heading)
|
||||
|
||||
# Description is the main content
|
||||
if description:
|
||||
parts.append(description)
|
||||
|
||||
# Categories help with semantic matching
|
||||
if categories:
|
||||
parts.append(" ".join(categories))
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def compute_content_hash(text: str) -> str:
|
||||
"""Compute MD5 hash of text for change detection."""
|
||||
return hashlib.md5(text.encode()).hexdigest()
|
||||
|
||||
|
||||
async def generate_embedding(text: str) -> list[float] | None:
|
||||
"""
|
||||
Generate embedding for text using OpenAI API.
|
||||
|
||||
Returns None if embedding generation fails.
|
||||
"""
|
||||
try:
|
||||
from openai import OpenAI
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
logger.warning("OPENAI_API_KEY not set, cannot generate embedding")
|
||||
return None
|
||||
|
||||
client = OpenAI(api_key=api_key)
|
||||
|
||||
# Truncate text to avoid token limits (~32k chars for safety)
|
||||
truncated_text = text[:32000]
|
||||
|
||||
response = client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=truncated_text,
|
||||
)
|
||||
|
||||
embedding = response.data[0].embedding
|
||||
logger.debug(f"Generated embedding with {len(embedding)} dimensions")
|
||||
return embedding
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate embedding: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def store_embedding(
|
||||
version_id: str,
|
||||
embedding: list[float],
|
||||
searchable_text: str,
|
||||
content_hash: str,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Store embedding in the database.
|
||||
|
||||
Uses raw SQL since Prisma doesn't natively support pgvector.
|
||||
"""
|
||||
try:
|
||||
client = tx if tx else prisma.get_client()
|
||||
|
||||
# Convert embedding to PostgreSQL vector format
|
||||
embedding_str = "[" + ",".join(str(x) for x in embedding) + "]"
|
||||
|
||||
# Upsert the embedding
|
||||
# Set search_path to include public for vector type visibility
|
||||
await client.execute_raw(
|
||||
"""
|
||||
SET LOCAL search_path TO platform, public;
|
||||
INSERT INTO platform."StoreListingEmbedding" (
|
||||
"id", "storeListingVersionId", "embedding",
|
||||
"searchableText", "contentHash", "createdAt", "updatedAt"
|
||||
)
|
||||
VALUES (
|
||||
gen_random_uuid(), $1, $2::vector,
|
||||
$3, $4, NOW(), NOW()
|
||||
)
|
||||
ON CONFLICT ("storeListingVersionId")
|
||||
DO UPDATE SET
|
||||
"embedding" = $2::vector,
|
||||
"searchableText" = $3,
|
||||
"contentHash" = $4,
|
||||
"updatedAt" = NOW()
|
||||
""",
|
||||
version_id,
|
||||
embedding_str,
|
||||
searchable_text,
|
||||
content_hash,
|
||||
)
|
||||
|
||||
logger.info(f"Stored embedding for version {version_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store embedding for version {version_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_embedding(version_id: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Retrieve embedding record for a listing version.
|
||||
|
||||
Returns dict with embedding, searchableText, contentHash or None if not found.
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
result = await client.query_raw(
|
||||
"""
|
||||
SELECT
|
||||
"id",
|
||||
"storeListingVersionId",
|
||||
"embedding"::text as "embedding",
|
||||
"searchableText",
|
||||
"contentHash",
|
||||
"createdAt",
|
||||
"updatedAt"
|
||||
FROM platform."StoreListingEmbedding"
|
||||
WHERE "storeListingVersionId" = $1
|
||||
""",
|
||||
version_id,
|
||||
)
|
||||
|
||||
if result and len(result) > 0:
|
||||
return result[0]
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embedding for version {version_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def ensure_embedding(
|
||||
version_id: str,
|
||||
name: str,
|
||||
description: str,
|
||||
sub_heading: str,
|
||||
categories: list[str],
|
||||
force: bool = False,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Ensure an embedding exists for the listing version.
|
||||
|
||||
Creates embedding if missing or if content has changed.
|
||||
Skips if content hash matches existing embedding.
|
||||
|
||||
Args:
|
||||
version_id: The StoreListingVersion ID
|
||||
name: Agent name
|
||||
description: Agent description
|
||||
sub_heading: Agent sub-heading
|
||||
categories: Agent categories
|
||||
force: Force regeneration even if hash matches
|
||||
tx: Optional transaction client
|
||||
|
||||
Returns:
|
||||
True if embedding exists/was created, False on failure
|
||||
"""
|
||||
try:
|
||||
# Build searchable text and compute hash
|
||||
searchable_text = build_searchable_text(
|
||||
name, description, sub_heading, categories
|
||||
)
|
||||
content_hash = compute_content_hash(searchable_text)
|
||||
|
||||
# Check if embedding already exists with same hash
|
||||
if not force:
|
||||
existing = await get_embedding(version_id)
|
||||
if existing and existing.get("contentHash") == content_hash:
|
||||
logger.debug(
|
||||
f"Embedding for version {version_id} is up to date (hash match)"
|
||||
)
|
||||
return True
|
||||
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
if embedding is None:
|
||||
logger.warning(f"Could not generate embedding for version {version_id}")
|
||||
return False
|
||||
|
||||
# Store the embedding
|
||||
return await store_embedding(
|
||||
version_id=version_id,
|
||||
embedding=embedding,
|
||||
searchable_text=searchable_text,
|
||||
content_hash=content_hash,
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def delete_embedding(version_id: str) -> bool:
|
||||
"""
|
||||
Delete embedding for a listing version.
|
||||
|
||||
Note: This is usually handled automatically by CASCADE delete,
|
||||
but provided for manual cleanup if needed.
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
await client.execute_raw(
|
||||
"""
|
||||
DELETE FROM platform."StoreListingEmbedding"
|
||||
WHERE "storeListingVersionId" = $1
|
||||
""",
|
||||
version_id,
|
||||
)
|
||||
|
||||
logger.info(f"Deleted embedding for version {version_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete embedding for version {version_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_embedding_stats() -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics about embedding coverage.
|
||||
|
||||
Returns counts of:
|
||||
- Total approved listing versions
|
||||
- Versions with embeddings
|
||||
- Versions without embeddings
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
# Count approved versions
|
||||
approved_result = await client.query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM platform."StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
AND "isDeleted" = false
|
||||
"""
|
||||
)
|
||||
total_approved = approved_result[0]["count"] if approved_result else 0
|
||||
|
||||
# Count versions with embeddings
|
||||
embedded_result = await client.query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM platform."StoreListingVersion" slv
|
||||
JOIN platform."StoreListingEmbedding" sle ON slv.id = sle."storeListingVersionId"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
"""
|
||||
)
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total_approved": total_approved,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_approved - with_embeddings,
|
||||
"coverage_percent": (
|
||||
round(with_embeddings / total_approved * 100, 1)
|
||||
if total_approved > 0
|
||||
else 0
|
||||
),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embedding stats: {e}")
|
||||
return {
|
||||
"total_approved": 0,
|
||||
"with_embeddings": 0,
|
||||
"without_embeddings": 0,
|
||||
"coverage_percent": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]:
|
||||
"""
|
||||
Generate embeddings for approved listings that don't have them.
|
||||
|
||||
Args:
|
||||
batch_size: Number of embeddings to generate in one call
|
||||
|
||||
Returns:
|
||||
Dict with success/failure counts
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
# Find approved versions without embeddings
|
||||
missing = await client.query_raw(
|
||||
"""
|
||||
SELECT
|
||||
slv.id,
|
||||
slv.name,
|
||||
slv.description,
|
||||
slv."subHeading",
|
||||
slv.categories
|
||||
FROM platform."StoreListingVersion" slv
|
||||
LEFT JOIN platform."StoreListingEmbedding" sle ON slv.id = sle."storeListingVersionId"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
AND sle.id IS NULL
|
||||
LIMIT $1
|
||||
""",
|
||||
batch_size,
|
||||
)
|
||||
|
||||
if not missing:
|
||||
return {
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"message": "No missing embeddings",
|
||||
}
|
||||
|
||||
success = 0
|
||||
failed = 0
|
||||
|
||||
for row in missing:
|
||||
result = await ensure_embedding(
|
||||
version_id=row["id"],
|
||||
name=row["name"],
|
||||
description=row["description"],
|
||||
sub_heading=row["subHeading"],
|
||||
categories=row["categories"] or [],
|
||||
)
|
||||
if result:
|
||||
success += 1
|
||||
else:
|
||||
failed += 1
|
||||
|
||||
return {
|
||||
"processed": len(missing),
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"message": f"Backfilled {success} embeddings, {failed} failed",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to backfill embeddings: {e}")
|
||||
return {
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
async def embed_query(query: str) -> list[float] | None:
|
||||
"""
|
||||
Generate embedding for a search query.
|
||||
|
||||
Same as generate_embedding but with clearer intent.
|
||||
"""
|
||||
return await generate_embedding(query)
|
||||
|
||||
|
||||
def embedding_to_vector_string(embedding: list[float]) -> str:
|
||||
"""Convert embedding list to PostgreSQL vector string format."""
|
||||
return "[" + ",".join(str(x) for x in embedding) + "]"
|
||||
@@ -0,0 +1,440 @@
|
||||
"""
|
||||
Hybrid Search for Store Agents
|
||||
|
||||
Combines semantic (embedding) search with lexical (tsvector) search
|
||||
for improved relevance in marketplace agent discovery.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
import prisma
|
||||
|
||||
from backend.api.features.store.embeddings import (
|
||||
embed_query,
|
||||
embedding_to_vector_string,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HybridSearchWeights:
|
||||
"""Weights for combining search signals."""
|
||||
|
||||
semantic: float = 0.35 # Embedding cosine similarity
|
||||
lexical: float = 0.35 # tsvector ts_rank_cd score
|
||||
category: float = 0.20 # Category match boost
|
||||
recency: float = 0.10 # Newer agents ranked higher
|
||||
|
||||
|
||||
DEFAULT_WEIGHTS = HybridSearchWeights()
|
||||
|
||||
# Minimum relevance score threshold - agents below this are filtered out
|
||||
# With weights (0.35 semantic + 0.35 lexical + 0.20 category + 0.10 recency):
|
||||
# - 0.20 means at least ~50% semantic match OR strong lexical match required
|
||||
# - Ensures only genuinely relevant results are returned
|
||||
# - Recency alone (0.10 max) won't pass the threshold
|
||||
DEFAULT_MIN_SCORE = 0.20
|
||||
|
||||
|
||||
@dataclass
|
||||
class HybridSearchResult:
|
||||
"""A single search result with score breakdown."""
|
||||
|
||||
slug: str
|
||||
agent_name: str
|
||||
agent_image: str
|
||||
creator_username: str
|
||||
creator_avatar: str
|
||||
sub_heading: str
|
||||
description: str
|
||||
runs: int
|
||||
rating: float
|
||||
categories: list[str]
|
||||
featured: bool
|
||||
is_available: bool
|
||||
updated_at: datetime
|
||||
|
||||
# Score breakdown (for debugging/tuning)
|
||||
combined_score: float
|
||||
semantic_score: float = 0.0
|
||||
lexical_score: float = 0.0
|
||||
category_score: float = 0.0
|
||||
recency_score: float = 0.0
|
||||
|
||||
|
||||
async def hybrid_search(
|
||||
query: str,
|
||||
featured: bool = False,
|
||||
creators: list[str] | None = None,
|
||||
category: str | None = None,
|
||||
sorted_by: (
|
||||
Literal["relevance", "rating", "runs", "name", "updated_at"] | None
|
||||
) = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
weights: HybridSearchWeights | None = None,
|
||||
min_score: float | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Perform hybrid search combining semantic and lexical signals.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
featured: Filter for featured agents only
|
||||
creators: Filter by creator usernames
|
||||
category: Filter by category
|
||||
sorted_by: Sort order (relevance uses hybrid scoring)
|
||||
page: Page number (1-indexed)
|
||||
page_size: Results per page
|
||||
weights: Custom weights for search signals
|
||||
min_score: Minimum relevance score threshold (0-1). Results below
|
||||
this score are filtered out. Defaults to DEFAULT_MIN_SCORE.
|
||||
|
||||
Returns:
|
||||
Tuple of (results list, total count). Returns empty list if no
|
||||
results meet the minimum relevance threshold.
|
||||
"""
|
||||
if weights is None:
|
||||
weights = DEFAULT_WEIGHTS
|
||||
if min_score is None:
|
||||
min_score = DEFAULT_MIN_SCORE
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
client = prisma.get_client()
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await embed_query(query)
|
||||
|
||||
# Build WHERE clause conditions
|
||||
where_parts: list[str] = ["sa.is_available = true"]
|
||||
params: list[Any] = []
|
||||
param_index = 1
|
||||
|
||||
# Add search query for lexical matching
|
||||
params.append(query)
|
||||
query_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
if featured:
|
||||
where_parts.append("sa.featured = true")
|
||||
|
||||
if creators:
|
||||
where_parts.append(f"sa.creator_username = ANY(${param_index})")
|
||||
params.append(creators)
|
||||
param_index += 1
|
||||
|
||||
if category:
|
||||
where_parts.append(f"${param_index} = ANY(sa.categories)")
|
||||
params.append(category)
|
||||
param_index += 1
|
||||
|
||||
where_clause = " AND ".join(where_parts)
|
||||
|
||||
# Determine if we can use hybrid search (have query embedding)
|
||||
use_hybrid = query_embedding is not None
|
||||
|
||||
if use_hybrid:
|
||||
# Add embedding parameter
|
||||
embedding_str = embedding_to_vector_string(query_embedding)
|
||||
params.append(embedding_str)
|
||||
embedding_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
# Build hybrid search query with weighted scoring
|
||||
# The semantic score is (1 - cosine_distance), normalized to [0,1]
|
||||
# The lexical score is ts_rank_cd, normalized by max value
|
||||
# Set search_path to include public for vector type visibility
|
||||
sql_query = f"""
|
||||
SET LOCAL search_path TO platform, public;
|
||||
WITH search_scores AS (
|
||||
SELECT
|
||||
sa.*,
|
||||
-- Semantic score: cosine similarity (1 - distance)
|
||||
COALESCE(1 - (sle.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
||||
-- Lexical score: ts_rank_cd normalized
|
||||
COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||
-- Category match: 1 if query term appears in categories, else 0
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM unnest(sa.categories) cat
|
||||
WHERE LOWER(cat) LIKE '%' || LOWER({query_param}) || '%'
|
||||
) THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
-- Recency score: exponential decay over 90 days
|
||||
EXP(-EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score
|
||||
FROM platform."StoreAgent" sa
|
||||
LEFT JOIN platform."StoreListing" sl ON sa.slug = sl.slug
|
||||
LEFT JOIN platform."StoreListingVersion" slv ON sl."activeVersionId" = slv.id
|
||||
LEFT JOIN platform."StoreListingEmbedding" sle ON slv.id = sle."storeListingVersionId"
|
||||
WHERE {where_clause}
|
||||
AND (
|
||||
sa.search @@ plainto_tsquery('english', {query_param})
|
||||
OR sle.embedding IS NOT NULL
|
||||
)
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
*,
|
||||
-- Normalize lexical score by max in result set
|
||||
CASE
|
||||
WHEN MAX(lexical_raw) OVER () > 0
|
||||
THEN lexical_raw / MAX(lexical_raw) OVER ()
|
||||
ELSE 0
|
||||
END as lexical_score
|
||||
FROM search_scores
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
category_score,
|
||||
recency_score,
|
||||
(
|
||||
{weights.semantic} * semantic_score +
|
||||
{weights.lexical} * lexical_score +
|
||||
{weights.category} * category_score +
|
||||
{weights.recency} * recency_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
)
|
||||
SELECT * FROM scored
|
||||
WHERE combined_score >= {min_score}
|
||||
ORDER BY combined_score DESC
|
||||
LIMIT ${param_index} OFFSET ${param_index + 1}
|
||||
"""
|
||||
|
||||
# Add pagination params
|
||||
params.extend([page_size, offset])
|
||||
|
||||
# Count query - must also filter by min_score
|
||||
count_query = f"""
|
||||
SET LOCAL search_path TO platform, public;
|
||||
WITH search_scores AS (
|
||||
SELECT
|
||||
sa.slug,
|
||||
COALESCE(1 - (sle.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
||||
COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM unnest(sa.categories) cat
|
||||
WHERE LOWER(cat) LIKE '%' || LOWER({query_param}) || '%'
|
||||
) THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
EXP(-EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score
|
||||
FROM platform."StoreAgent" sa
|
||||
LEFT JOIN platform."StoreListing" sl ON sa.slug = sl.slug
|
||||
LEFT JOIN platform."StoreListingVersion" slv ON sl."activeVersionId" = slv.id
|
||||
LEFT JOIN platform."StoreListingEmbedding" sle ON slv.id = sle."storeListingVersionId"
|
||||
WHERE {where_clause}
|
||||
AND (
|
||||
sa.search @@ plainto_tsquery('english', {query_param})
|
||||
OR sle.embedding IS NOT NULL
|
||||
)
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
slug,
|
||||
semantic_score,
|
||||
category_score,
|
||||
recency_score,
|
||||
CASE
|
||||
WHEN MAX(lexical_raw) OVER () > 0
|
||||
THEN lexical_raw / MAX(lexical_raw) OVER ()
|
||||
ELSE 0
|
||||
END as lexical_score
|
||||
FROM search_scores
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
slug,
|
||||
(
|
||||
{weights.semantic} * semantic_score +
|
||||
{weights.lexical} * lexical_score +
|
||||
{weights.category} * category_score +
|
||||
{weights.recency} * recency_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
)
|
||||
SELECT COUNT(*) as count FROM scored
|
||||
WHERE combined_score >= {min_score}
|
||||
"""
|
||||
|
||||
else:
|
||||
# Fallback to lexical-only search (existing behavior)
|
||||
# Note: For lexical-only, we still require tsvector match but don't
|
||||
# apply min_score since ts_rank_cd isn't normalized to [0,1]
|
||||
logger.warning("Falling back to lexical-only search (no query embedding)")
|
||||
|
||||
sql_query = f"""
|
||||
WITH lexical_scores AS (
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
0.0 as semantic_score,
|
||||
ts_rank_cd(search, plainto_tsquery('english', {query_param})) as lexical_raw,
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM unnest(categories) cat
|
||||
WHERE LOWER(cat) LIKE '%' || LOWER({query_param}) || '%'
|
||||
) THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
EXP(-EXTRACT(EPOCH FROM (NOW() - updated_at)) / (90 * 24 * 3600)) as recency_score
|
||||
FROM platform."StoreAgent" sa
|
||||
WHERE {where_clause}
|
||||
AND search @@ plainto_tsquery('english', {query_param})
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
*,
|
||||
CASE
|
||||
WHEN MAX(lexical_raw) OVER () > 0
|
||||
THEN lexical_raw / MAX(lexical_raw) OVER ()
|
||||
ELSE 0
|
||||
END as lexical_score
|
||||
FROM lexical_scores
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
category_score,
|
||||
recency_score,
|
||||
(
|
||||
{weights.lexical} * lexical_score +
|
||||
{weights.category} * category_score +
|
||||
{weights.recency} * recency_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
)
|
||||
SELECT * FROM scored
|
||||
WHERE combined_score >= {min_score}
|
||||
ORDER BY combined_score DESC
|
||||
LIMIT ${param_index} OFFSET ${param_index + 1}
|
||||
"""
|
||||
|
||||
params.extend([page_size, offset])
|
||||
|
||||
count_query = f"""
|
||||
WITH lexical_scores AS (
|
||||
SELECT
|
||||
slug,
|
||||
ts_rank_cd(search, plainto_tsquery('english', {query_param})) as lexical_raw,
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM unnest(categories) cat
|
||||
WHERE LOWER(cat) LIKE '%' || LOWER({query_param}) || '%'
|
||||
) THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
EXP(-EXTRACT(EPOCH FROM (NOW() - updated_at)) / (90 * 24 * 3600)) as recency_score
|
||||
FROM platform."StoreAgent" sa
|
||||
WHERE {where_clause}
|
||||
AND search @@ plainto_tsquery('english', {query_param})
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
slug,
|
||||
category_score,
|
||||
recency_score,
|
||||
CASE
|
||||
WHEN MAX(lexical_raw) OVER () > 0
|
||||
THEN lexical_raw / MAX(lexical_raw) OVER ()
|
||||
ELSE 0
|
||||
END as lexical_score
|
||||
FROM lexical_scores
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
slug,
|
||||
(
|
||||
{weights.lexical} * lexical_score +
|
||||
{weights.category} * category_score +
|
||||
{weights.recency} * recency_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
)
|
||||
SELECT COUNT(*) as count FROM scored
|
||||
WHERE combined_score >= {min_score}
|
||||
"""
|
||||
|
||||
try:
|
||||
# Execute search query
|
||||
# Dynamic SQL is safe here - all user inputs are parameterized ($1, $2, etc.)
|
||||
results = await client.query_raw(sql_query, *params) # type: ignore[arg-type]
|
||||
|
||||
# Execute count query (without pagination params)
|
||||
count_params = params[:-2] # Remove LIMIT and OFFSET params
|
||||
count_result = await client.query_raw(count_query, *count_params) # type: ignore[arg-type]
|
||||
total = count_result[0]["count"] if count_result else 0
|
||||
|
||||
logger.info(
|
||||
f"Hybrid search for '{query}': {len(results)} results, {total} total "
|
||||
f"(hybrid={use_hybrid})"
|
||||
)
|
||||
|
||||
return results, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Hybrid search failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def hybrid_search_simple(
|
||||
query: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Simplified hybrid search for common use cases.
|
||||
|
||||
Uses default weights and no filters.
|
||||
"""
|
||||
return await hybrid_search(
|
||||
query=query,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
@@ -110,7 +110,6 @@ class Profile(pydantic.BaseModel):
|
||||
|
||||
|
||||
class StoreSubmission(pydantic.BaseModel):
|
||||
listing_id: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
name: str
|
||||
@@ -165,12 +164,8 @@ class StoreListingsWithVersionsResponse(pydantic.BaseModel):
|
||||
|
||||
|
||||
class StoreSubmissionRequest(pydantic.BaseModel):
|
||||
agent_id: str = pydantic.Field(
|
||||
..., min_length=1, description="Agent ID cannot be empty"
|
||||
)
|
||||
agent_version: int = pydantic.Field(
|
||||
..., gt=0, description="Agent version must be greater than 0"
|
||||
)
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
slug: str
|
||||
name: str
|
||||
sub_heading: str
|
||||
|
||||
@@ -138,7 +138,6 @@ def test_creator_details():
|
||||
|
||||
def test_store_submission():
|
||||
submission = store_model.StoreSubmission(
|
||||
listing_id="listing123",
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
sub_heading="Test subheading",
|
||||
@@ -160,7 +159,6 @@ def test_store_submissions_response():
|
||||
response = store_model.StoreSubmissionsResponse(
|
||||
submissions=[
|
||||
store_model.StoreSubmission(
|
||||
listing_id="listing123",
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
sub_heading="Test subheading",
|
||||
|
||||
@@ -521,7 +521,6 @@ def test_get_submissions_success(
|
||||
mocked_value = store_model.StoreSubmissionsResponse(
|
||||
submissions=[
|
||||
store_model.StoreSubmission(
|
||||
listing_id="test-listing-id",
|
||||
name="Test Agent",
|
||||
description="Test agent description",
|
||||
image_urls=["test.jpg"],
|
||||
|
||||
@@ -6,9 +6,6 @@ import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import cast
|
||||
|
||||
from prisma.types import Serializable
|
||||
|
||||
from backend.sdk import (
|
||||
BaseWebhooksManager,
|
||||
@@ -87,9 +84,7 @@ class AirtableWebhookManager(BaseWebhooksManager):
|
||||
# update webhook config
|
||||
await update_webhook(
|
||||
webhook.id,
|
||||
config=cast(
|
||||
dict[str, Serializable], {"base_id": base_id, "cursor": response.cursor}
|
||||
),
|
||||
config={"base_id": base_id, "cursor": response.cursor},
|
||||
)
|
||||
|
||||
event_type = "notification"
|
||||
|
||||
@@ -975,28 +975,10 @@ class SmartDecisionMakerBlock(Block):
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
execution_processor: "ExecutionProcessor",
|
||||
nodes_to_skip: set[str] | None = None,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
tool_functions = await self._create_tool_node_signatures(node_id)
|
||||
original_tool_count = len(tool_functions)
|
||||
|
||||
# Filter out tools for nodes that should be skipped (e.g., missing optional credentials)
|
||||
if nodes_to_skip:
|
||||
tool_functions = [
|
||||
tf
|
||||
for tf in tool_functions
|
||||
if tf.get("function", {}).get("_sink_node_id") not in nodes_to_skip
|
||||
]
|
||||
|
||||
# Only raise error if we had tools but they were all filtered out
|
||||
if original_tool_count > 0 and not tool_functions:
|
||||
raise ValueError(
|
||||
"No available tools to execute - all downstream nodes are unavailable "
|
||||
"(possibly due to missing optional credentials)"
|
||||
)
|
||||
|
||||
yield "tool_functions", json.dumps(tool_functions)
|
||||
|
||||
conversation_history = input_data.conversation_history or []
|
||||
|
||||
@@ -383,7 +383,6 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
self,
|
||||
execution_context: ExecutionContext,
|
||||
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
):
|
||||
return GraphExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
@@ -391,7 +390,6 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
graph_version=self.graph_version or 0,
|
||||
graph_exec_id=self.id,
|
||||
nodes_input_masks=compiled_nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip or set(),
|
||||
execution_context=execution_context,
|
||||
)
|
||||
|
||||
@@ -1147,8 +1145,6 @@ class GraphExecutionEntry(BaseModel):
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None
|
||||
nodes_to_skip: set[str] = Field(default_factory=set)
|
||||
"""Node IDs that should be skipped due to optional credentials not being configured."""
|
||||
execution_context: ExecutionContext = Field(default_factory=ExecutionContext)
|
||||
|
||||
|
||||
|
||||
@@ -94,15 +94,6 @@ class Node(BaseDbModel):
|
||||
input_links: list[Link] = []
|
||||
output_links: list[Link] = []
|
||||
|
||||
@property
|
||||
def credentials_optional(self) -> bool:
|
||||
"""
|
||||
Whether credentials are optional for this node.
|
||||
When True and credentials are not configured, the node will be skipped
|
||||
during execution rather than causing a validation error.
|
||||
"""
|
||||
return self.metadata.get("credentials_optional", False)
|
||||
|
||||
@property
|
||||
def block(self) -> AnyBlockSchema | "_UnknownBlockBase":
|
||||
"""Get the block for this node. Returns UnknownBlock if block is deleted/missing."""
|
||||
@@ -335,35 +326,7 @@ class Graph(BaseGraph):
|
||||
@computed_field
|
||||
@property
|
||||
def credentials_input_schema(self) -> dict[str, Any]:
|
||||
schema = self._credentials_input_schema.jsonschema()
|
||||
|
||||
# Determine which credential fields are required based on credentials_optional metadata
|
||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||
required_fields = []
|
||||
|
||||
# Build a map of node_id -> node for quick lookup
|
||||
all_nodes = {node.id: node for node in self.nodes}
|
||||
for sub_graph in self.sub_graphs:
|
||||
for node in sub_graph.nodes:
|
||||
all_nodes[node.id] = node
|
||||
|
||||
for field_key, (
|
||||
_field_info,
|
||||
node_field_pairs,
|
||||
) in graph_credentials_inputs.items():
|
||||
# A field is required if ANY node using it has credentials_optional=False
|
||||
is_required = False
|
||||
for node_id, _field_name in node_field_pairs:
|
||||
node = all_nodes.get(node_id)
|
||||
if node and not node.credentials_optional:
|
||||
is_required = True
|
||||
break
|
||||
|
||||
if is_required:
|
||||
required_fields.append(field_key)
|
||||
|
||||
schema["required"] = required_fields
|
||||
return schema
|
||||
return self._credentials_input_schema.jsonschema()
|
||||
|
||||
@property
|
||||
def _credentials_input_schema(self) -> type[BlockSchema]:
|
||||
|
||||
@@ -396,58 +396,3 @@ async def test_access_store_listing_graph(server: SpinTestServer):
|
||||
created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||
)
|
||||
assert got_graph is not None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for Optional Credentials Feature
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_node_credentials_optional_default():
|
||||
"""Test that credentials_optional defaults to False when not set in metadata."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={},
|
||||
)
|
||||
assert node.credentials_optional is False
|
||||
|
||||
|
||||
def test_node_credentials_optional_true():
|
||||
"""Test that credentials_optional returns True when explicitly set."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={"credentials_optional": True},
|
||||
)
|
||||
assert node.credentials_optional is True
|
||||
|
||||
|
||||
def test_node_credentials_optional_false():
|
||||
"""Test that credentials_optional returns False when explicitly set to False."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={"credentials_optional": False},
|
||||
)
|
||||
assert node.credentials_optional is False
|
||||
|
||||
|
||||
def test_node_credentials_optional_with_other_metadata():
|
||||
"""Test that credentials_optional works correctly with other metadata present."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={
|
||||
"position": {"x": 100, "y": 200},
|
||||
"customized_name": "My Custom Node",
|
||||
"credentials_optional": True,
|
||||
},
|
||||
)
|
||||
assert node.credentials_optional is True
|
||||
assert node.metadata["position"] == {"x": 100, "y": 200}
|
||||
assert node.metadata["customized_name"] == "My Custom Node"
|
||||
|
||||
@@ -178,7 +178,6 @@ async def execute_node(
|
||||
execution_processor: "ExecutionProcessor",
|
||||
execution_stats: NodeExecutionStats | None = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Execute a node in the graph. This will trigger a block execution on a node,
|
||||
@@ -246,7 +245,6 @@ async def execute_node(
|
||||
"user_id": user_id,
|
||||
"execution_context": execution_context,
|
||||
"execution_processor": execution_processor,
|
||||
"nodes_to_skip": nodes_to_skip or set(),
|
||||
}
|
||||
|
||||
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
|
||||
@@ -544,7 +542,6 @@ class ExecutionProcessor:
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
nodes_input_masks: Optional[NodesInputMasks],
|
||||
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
) -> NodeExecutionStats:
|
||||
log_metadata = LogMetadata(
|
||||
logger=_logger,
|
||||
@@ -567,7 +564,6 @@ class ExecutionProcessor:
|
||||
db_client=db_client,
|
||||
log_metadata=log_metadata,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip,
|
||||
)
|
||||
if isinstance(status, BaseException):
|
||||
raise status
|
||||
@@ -613,7 +609,6 @@ class ExecutionProcessor:
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
) -> ExecutionStatus:
|
||||
status = ExecutionStatus.RUNNING
|
||||
|
||||
@@ -650,7 +645,6 @@ class ExecutionProcessor:
|
||||
execution_processor=self,
|
||||
execution_stats=stats,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip,
|
||||
):
|
||||
await persist_output(output_name, output_data)
|
||||
|
||||
@@ -962,21 +956,6 @@ class ExecutionProcessor:
|
||||
|
||||
queued_node_exec = execution_queue.get()
|
||||
|
||||
# Check if this node should be skipped due to optional credentials
|
||||
if queued_node_exec.node_id in graph_exec.nodes_to_skip:
|
||||
log_metadata.info(
|
||||
f"Skipping node execution {queued_node_exec.node_exec_id} "
|
||||
f"for node {queued_node_exec.node_id} - optional credentials not configured"
|
||||
)
|
||||
# Mark the node as completed without executing
|
||||
# No outputs will be produced, so downstream nodes won't trigger
|
||||
update_node_execution_status(
|
||||
db_client=db_client,
|
||||
exec_id=queued_node_exec.node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
)
|
||||
continue
|
||||
|
||||
log_metadata.debug(
|
||||
f"Dispatching node execution {queued_node_exec.node_exec_id} "
|
||||
f"for node {queued_node_exec.node_id}",
|
||||
@@ -1037,7 +1016,6 @@ class ExecutionProcessor:
|
||||
execution_stats,
|
||||
execution_stats_lock,
|
||||
),
|
||||
nodes_to_skip=graph_exec.nodes_to_skip,
|
||||
),
|
||||
self.node_execution_loop,
|
||||
)
|
||||
|
||||
@@ -239,19 +239,14 @@ async def _validate_node_input_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> tuple[dict[str, dict[str, str]], set[str]]:
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Checks all credentials for all nodes of the graph and returns structured errors
|
||||
and a set of nodes that should be skipped due to optional missing credentials.
|
||||
Checks all credentials for all nodes of the graph and returns structured errors.
|
||||
|
||||
Returns:
|
||||
tuple[
|
||||
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node,
|
||||
set[node_id]: Nodes that should be skipped (optional credentials not configured)
|
||||
]
|
||||
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node
|
||||
"""
|
||||
credential_errors: dict[str, dict[str, str]] = defaultdict(dict)
|
||||
nodes_to_skip: set[str] = set()
|
||||
|
||||
for node in graph.nodes:
|
||||
block = node.block
|
||||
@@ -261,46 +256,27 @@ async def _validate_node_input_credentials(
|
||||
if not credentials_fields:
|
||||
continue
|
||||
|
||||
# Track if any credential field is missing for this node
|
||||
has_missing_credentials = False
|
||||
|
||||
for field_name, credentials_meta_type in credentials_fields.items():
|
||||
try:
|
||||
# Check nodes_input_masks first, then input_default
|
||||
field_value = None
|
||||
if (
|
||||
nodes_input_masks
|
||||
and (node_input_mask := nodes_input_masks.get(node.id))
|
||||
and field_name in node_input_mask
|
||||
):
|
||||
field_value = node_input_mask[field_name]
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node_input_mask[field_name]
|
||||
)
|
||||
elif field_name in node.input_default:
|
||||
# For optional credentials, don't use input_default - treat as missing
|
||||
# This prevents stale credential IDs from failing validation
|
||||
if node.credentials_optional:
|
||||
field_value = None
|
||||
else:
|
||||
field_value = node.input_default[field_name]
|
||||
|
||||
# Check if credentials are missing (None, empty, or not present)
|
||||
if field_value is None or (
|
||||
isinstance(field_value, dict) and not field_value.get("id")
|
||||
):
|
||||
has_missing_credentials = True
|
||||
# If node has credentials_optional flag, mark for skipping instead of error
|
||||
if node.credentials_optional:
|
||||
continue # Don't add error, will be marked for skip after loop
|
||||
else:
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = "These credentials are required"
|
||||
continue
|
||||
|
||||
credentials_meta = credentials_meta_type.model_validate(field_value)
|
||||
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node.input_default[field_name]
|
||||
)
|
||||
else:
|
||||
# Missing credentials
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = "These credentials are required"
|
||||
continue
|
||||
except ValidationError as e:
|
||||
# Validation error means credentials were provided but invalid
|
||||
# This should always be an error, even if optional
|
||||
credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
|
||||
continue
|
||||
|
||||
@@ -311,7 +287,6 @@ async def _validate_node_input_credentials(
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle any errors fetching credentials
|
||||
# If credentials were explicitly configured but unavailable, it's an error
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = f"Credentials not available: {e}"
|
||||
@@ -338,19 +313,7 @@ async def _validate_node_input_credentials(
|
||||
] = "Invalid credentials: type/provider mismatch"
|
||||
continue
|
||||
|
||||
# If node has optional credentials and any are missing, mark for skipping
|
||||
# But only if there are no other errors for this node
|
||||
if (
|
||||
has_missing_credentials
|
||||
and node.credentials_optional
|
||||
and node.id not in credential_errors
|
||||
):
|
||||
nodes_to_skip.add(node.id)
|
||||
logger.info(
|
||||
f"Node #{node.id} will be skipped: optional credentials not configured"
|
||||
)
|
||||
|
||||
return credential_errors, nodes_to_skip
|
||||
return credential_errors
|
||||
|
||||
|
||||
def make_node_credentials_input_map(
|
||||
@@ -392,25 +355,21 @@ async def validate_graph_with_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> tuple[Mapping[str, Mapping[str, str]], set[str]]:
|
||||
) -> Mapping[str, Mapping[str, str]]:
|
||||
"""
|
||||
Validate graph including credentials and return structured errors per node,
|
||||
along with a set of nodes that should be skipped due to optional missing credentials.
|
||||
Validate graph including credentials and return structured errors per node.
|
||||
|
||||
Returns:
|
||||
tuple[
|
||||
dict[node_id, dict[field_name, error_message]]: Validation errors per node,
|
||||
set[node_id]: Nodes that should be skipped (optional credentials not configured)
|
||||
]
|
||||
dict[node_id, dict[field_name, error_message]]: Validation errors per node
|
||||
"""
|
||||
# Get input validation errors
|
||||
node_input_errors = GraphModel.validate_graph_get_errors(
|
||||
graph, for_run=True, nodes_input_masks=nodes_input_masks
|
||||
)
|
||||
|
||||
# Get credential input/availability/validation errors and nodes to skip
|
||||
node_credential_input_errors, nodes_to_skip = (
|
||||
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
|
||||
# Get credential input/availability/validation errors
|
||||
node_credential_input_errors = await _validate_node_input_credentials(
|
||||
graph, user_id, nodes_input_masks
|
||||
)
|
||||
|
||||
# Merge credential errors with structural errors
|
||||
@@ -419,7 +378,7 @@ async def validate_graph_with_credentials(
|
||||
node_input_errors[node_id] = {}
|
||||
node_input_errors[node_id].update(field_errors)
|
||||
|
||||
return node_input_errors, nodes_to_skip
|
||||
return node_input_errors
|
||||
|
||||
|
||||
async def _construct_starting_node_execution_input(
|
||||
@@ -427,7 +386,7 @@ async def _construct_starting_node_execution_input(
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> tuple[list[tuple[str, BlockInput]], set[str]]:
|
||||
) -> list[tuple[str, BlockInput]]:
|
||||
"""
|
||||
Validates and prepares the input data for executing a graph.
|
||||
This function checks the graph for starting nodes, validates the input data
|
||||
@@ -441,14 +400,11 @@ async def _construct_starting_node_execution_input(
|
||||
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
|
||||
|
||||
Returns:
|
||||
tuple[
|
||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID
|
||||
and the corresponding input data for that node.
|
||||
set[str]: Node IDs that should be skipped (optional credentials not configured)
|
||||
]
|
||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
|
||||
the corresponding input data for that node.
|
||||
"""
|
||||
# Use new validation function that includes credentials
|
||||
validation_errors, nodes_to_skip = await validate_graph_with_credentials(
|
||||
validation_errors = await validate_graph_with_credentials(
|
||||
graph, user_id, nodes_input_masks
|
||||
)
|
||||
n_error_nodes = len(validation_errors)
|
||||
@@ -489,7 +445,7 @@ async def _construct_starting_node_execution_input(
|
||||
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
|
||||
)
|
||||
|
||||
return nodes_input, nodes_to_skip
|
||||
return nodes_input
|
||||
|
||||
|
||||
async def validate_and_construct_node_execution_input(
|
||||
@@ -500,7 +456,7 @@ async def validate_and_construct_node_execution_input(
|
||||
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
is_sub_graph: bool = False,
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]:
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
|
||||
"""
|
||||
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
|
||||
This centralizes the logic used by both scheduler validation and actual execution.
|
||||
@@ -517,7 +473,6 @@ async def validate_and_construct_node_execution_input(
|
||||
GraphModel: Full graph object for the given `graph_id`.
|
||||
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
|
||||
dict[str, BlockInput]: Node input masks including all passed-in credentials.
|
||||
set[str]: Node IDs that should be skipped (optional credentials not configured).
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the graph is not found.
|
||||
@@ -559,16 +514,14 @@ async def validate_and_construct_node_execution_input(
|
||||
nodes_input_masks or {},
|
||||
)
|
||||
|
||||
starting_nodes_input, nodes_to_skip = (
|
||||
await _construct_starting_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=graph_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
starting_nodes_input = await _construct_starting_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=graph_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
|
||||
return graph, starting_nodes_input, nodes_input_masks, nodes_to_skip
|
||||
return graph, starting_nodes_input, nodes_input_masks
|
||||
|
||||
|
||||
def _merge_nodes_input_masks(
|
||||
@@ -826,9 +779,6 @@ async def add_graph_execution(
|
||||
|
||||
# Use existing execution's compiled input masks
|
||||
compiled_nodes_input_masks = graph_exec.nodes_input_masks or {}
|
||||
# For resumed executions, nodes_to_skip was already determined at creation time
|
||||
# TODO: Consider storing nodes_to_skip in DB if we need to preserve it across resumes
|
||||
nodes_to_skip: set[str] = set()
|
||||
|
||||
logger.info(f"Resuming graph execution #{graph_exec.id} for graph #{graph_id}")
|
||||
else:
|
||||
@@ -837,7 +787,7 @@ async def add_graph_execution(
|
||||
)
|
||||
|
||||
# Create new execution
|
||||
graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip = (
|
||||
graph, starting_nodes_input, compiled_nodes_input_masks = (
|
||||
await validate_and_construct_node_execution_input(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
@@ -886,7 +836,6 @@ async def add_graph_execution(
|
||||
try:
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry(
|
||||
compiled_nodes_input_masks=compiled_nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
|
||||
|
||||
@@ -367,13 +367,10 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
)
|
||||
|
||||
# Setup mock returns
|
||||
# The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip)
|
||||
nodes_to_skip: set[str] = set()
|
||||
mock_validate.return_value = (
|
||||
mock_graph,
|
||||
starting_nodes_input,
|
||||
compiled_nodes_input_masks,
|
||||
nodes_to_skip,
|
||||
)
|
||||
mock_prisma.is_connected.return_value = True
|
||||
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
||||
@@ -459,212 +456,3 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
# Both executions should succeed (though they create different objects)
|
||||
assert result1 == mock_graph_exec
|
||||
assert result2 == mock_graph_exec_2
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for Optional Credentials Feature
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_returns_nodes_to_skip(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
Test that _validate_node_input_credentials returns nodes_to_skip set
|
||||
for nodes with credentials_optional=True and missing credentials.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
# Create a mock node with credentials_optional=True
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-optional-creds"
|
||||
mock_node.credentials_optional = True
|
||||
mock_node.input_default = {} # No credentials configured
|
||||
|
||||
# Create a mock block with credentials field
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_credentials_field_type = mocker.MagicMock()
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||
"credentials": mock_credentials_field_type
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
# Create mock graph
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
# Call the function
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user-id",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Node should be in nodes_to_skip, not in errors
|
||||
assert mock_node.id in nodes_to_skip
|
||||
assert mock_node.id not in errors
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_required_missing_creds_error(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
Test that _validate_node_input_credentials returns errors
|
||||
for nodes with credentials_optional=False and missing credentials.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
# Create a mock node with credentials_optional=False (required)
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-required-creds"
|
||||
mock_node.credentials_optional = False
|
||||
mock_node.input_default = {} # No credentials configured
|
||||
|
||||
# Create a mock block with credentials field
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_credentials_field_type = mocker.MagicMock()
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||
"credentials": mock_credentials_field_type
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
# Create mock graph
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
# Call the function
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user-id",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Node should be in errors, not in nodes_to_skip
|
||||
assert mock_node.id in errors
|
||||
assert "credentials" in errors[mock_node.id]
|
||||
assert "required" in errors[mock_node.id]["credentials"].lower()
|
||||
assert mock_node.id not in nodes_to_skip
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_graph_with_credentials_returns_nodes_to_skip(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
Test that validate_graph_with_credentials returns nodes_to_skip set
|
||||
from _validate_node_input_credentials.
|
||||
"""
|
||||
from backend.executor.utils import validate_graph_with_credentials
|
||||
|
||||
# Mock _validate_node_input_credentials to return specific values
|
||||
mock_validate = mocker.patch(
|
||||
"backend.executor.utils._validate_node_input_credentials"
|
||||
)
|
||||
expected_errors = {"node1": {"field": "error"}}
|
||||
expected_nodes_to_skip = {"node2", "node3"}
|
||||
mock_validate.return_value = (expected_errors, expected_nodes_to_skip)
|
||||
|
||||
# Mock GraphModel with validate_graph_get_errors method
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.validate_graph_get_errors.return_value = {}
|
||||
|
||||
# Call the function
|
||||
errors, nodes_to_skip = await validate_graph_with_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user-id",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Verify nodes_to_skip is passed through
|
||||
assert nodes_to_skip == expected_nodes_to_skip
|
||||
assert "node1" in errors
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
"""
|
||||
Test that add_graph_execution properly passes nodes_to_skip
|
||||
to the graph execution entry.
|
||||
"""
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.executor.utils import add_graph_execution
|
||||
|
||||
# Mock data
|
||||
graph_id = "test-graph-id"
|
||||
user_id = "test-user-id"
|
||||
inputs = {"test_input": "test_value"}
|
||||
graph_version = 1
|
||||
|
||||
# Mock the graph object
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.version = graph_version
|
||||
|
||||
# Starting nodes and masks
|
||||
starting_nodes_input = [("node1", {"input1": "value1"})]
|
||||
compiled_nodes_input_masks = {}
|
||||
nodes_to_skip = {"skipped-node-1", "skipped-node-2"}
|
||||
|
||||
# Mock the graph execution object
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.node_executions = []
|
||||
|
||||
# Track what's passed to to_graph_execution_entry
|
||||
captured_kwargs = {}
|
||||
|
||||
def capture_to_entry(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return mocker.MagicMock()
|
||||
|
||||
mock_graph_exec.to_graph_execution_entry.side_effect = capture_to_entry
|
||||
|
||||
# Setup mocks
|
||||
mock_validate = mocker.patch(
|
||||
"backend.executor.utils.validate_and_construct_node_execution_input"
|
||||
)
|
||||
mock_edb = mocker.patch("backend.executor.utils.execution_db")
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_udb = mocker.patch("backend.executor.utils.user_db")
|
||||
mock_gdb = mocker.patch("backend.executor.utils.graph_db")
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
|
||||
# Setup returns - include nodes_to_skip in the tuple
|
||||
mock_validate.return_value = (
|
||||
mock_graph,
|
||||
starting_nodes_input,
|
||||
compiled_nodes_input_masks,
|
||||
nodes_to_skip, # This should be passed through
|
||||
)
|
||||
mock_prisma.is_connected.return_value = True
|
||||
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
||||
mock_edb.update_graph_execution_stats = mocker.AsyncMock(
|
||||
return_value=mock_graph_exec
|
||||
)
|
||||
mock_edb.update_node_execution_status_batch = mocker.AsyncMock()
|
||||
|
||||
mock_user = mocker.MagicMock()
|
||||
mock_user.timezone = "UTC"
|
||||
mock_settings = mocker.MagicMock()
|
||||
mock_settings.human_in_the_loop_safe_mode = True
|
||||
|
||||
mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
|
||||
mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings)
|
||||
mock_get_queue.return_value = mocker.AsyncMock()
|
||||
mock_get_event_bus.return_value = mocker.MagicMock(publish=mocker.AsyncMock())
|
||||
|
||||
# Call the function
|
||||
await add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
graph_version=graph_version,
|
||||
)
|
||||
|
||||
# Verify nodes_to_skip was passed to to_graph_execution_entry
|
||||
assert "nodes_to_skip" in captured_kwargs
|
||||
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip
|
||||
|
||||
@@ -658,6 +658,14 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
|
||||
ayrshare_api_key: str = Field(default="", description="Ayrshare API Key")
|
||||
ayrshare_jwt_key: str = Field(default="", description="Ayrshare private Key")
|
||||
|
||||
# Langfuse prompt management
|
||||
langfuse_public_key: str = Field(default="", description="Langfuse public key")
|
||||
langfuse_secret_key: str = Field(default="", description="Langfuse secret key")
|
||||
langfuse_host: str = Field(
|
||||
default="https://cloud.langfuse.com", description="Langfuse host URL"
|
||||
)
|
||||
|
||||
# Add more secret fields as needed
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
|
||||
@@ -1,227 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Generate a lightweight stub for prisma/types.py that collapses all exported
|
||||
symbols to Any. This prevents Pyright from spending time/budget on Prisma's
|
||||
query DSL types while keeping runtime behavior unchanged.
|
||||
|
||||
Usage:
|
||||
poetry run gen-prisma-stub
|
||||
|
||||
This script automatically finds the prisma package location and generates
|
||||
the types.pyi stub file in the same directory as types.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Set
|
||||
|
||||
|
||||
def _iter_assigned_names(target: ast.expr) -> Iterable[str]:
|
||||
"""Extract names from assignment targets (handles tuple unpacking)."""
|
||||
if isinstance(target, ast.Name):
|
||||
yield target.id
|
||||
elif isinstance(target, (ast.Tuple, ast.List)):
|
||||
for elt in target.elts:
|
||||
yield from _iter_assigned_names(elt)
|
||||
|
||||
|
||||
def _is_private(name: str) -> bool:
|
||||
"""Check if a name is private (starts with _ but not __)."""
|
||||
return name.startswith("_") and not name.startswith("__")
|
||||
|
||||
|
||||
def _is_safe_type_alias(node: ast.Assign) -> bool:
|
||||
"""Check if an assignment is a safe type alias that shouldn't be stubbed.
|
||||
|
||||
Safe types are:
|
||||
- Literal types (don't cause type budget issues)
|
||||
- Simple type references (SortMode, SortOrder, etc.)
|
||||
- TypeVar definitions
|
||||
"""
|
||||
if not node.value:
|
||||
return False
|
||||
|
||||
# Check if it's a Subscript (like Literal[...], Union[...], TypeVar[...])
|
||||
if isinstance(node.value, ast.Subscript):
|
||||
# Get the base type name
|
||||
if isinstance(node.value.value, ast.Name):
|
||||
base_name = node.value.value.id
|
||||
# Literal types are safe
|
||||
if base_name == "Literal":
|
||||
return True
|
||||
# TypeVar is safe
|
||||
if base_name == "TypeVar":
|
||||
return True
|
||||
elif isinstance(node.value.value, ast.Attribute):
|
||||
# Handle typing_extensions.Literal etc.
|
||||
if node.value.value.attr == "Literal":
|
||||
return True
|
||||
|
||||
# Check if it's a simple Name reference (like SortMode = _types.SortMode)
|
||||
if isinstance(node.value, ast.Attribute):
|
||||
return True
|
||||
|
||||
# Check if it's a Call (like TypeVar(...))
|
||||
if isinstance(node.value, ast.Call):
|
||||
if isinstance(node.value.func, ast.Name):
|
||||
if node.value.func.id == "TypeVar":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def collect_top_level_symbols(
|
||||
tree: ast.Module, source_lines: list[str]
|
||||
) -> tuple[Set[str], Set[str], list[str], Set[str]]:
|
||||
"""Collect all top-level symbols from an AST module.
|
||||
|
||||
Returns:
|
||||
Tuple of (class_names, function_names, safe_variable_sources, unsafe_variable_names)
|
||||
safe_variable_sources contains the actual source code lines for safe variables
|
||||
"""
|
||||
classes: Set[str] = set()
|
||||
functions: Set[str] = set()
|
||||
safe_variable_sources: list[str] = []
|
||||
unsafe_variables: Set[str] = set()
|
||||
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.ClassDef):
|
||||
if not _is_private(node.name):
|
||||
classes.add(node.name)
|
||||
elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
if not _is_private(node.name):
|
||||
functions.add(node.name)
|
||||
elif isinstance(node, ast.Assign):
|
||||
is_safe = _is_safe_type_alias(node)
|
||||
names = []
|
||||
for t in node.targets:
|
||||
for n in _iter_assigned_names(t):
|
||||
if not _is_private(n):
|
||||
names.append(n)
|
||||
if names:
|
||||
if is_safe:
|
||||
# Extract the source code for this assignment
|
||||
start_line = node.lineno - 1 # 0-indexed
|
||||
end_line = node.end_lineno if node.end_lineno else node.lineno
|
||||
source = "\n".join(source_lines[start_line:end_line])
|
||||
safe_variable_sources.append(source)
|
||||
else:
|
||||
unsafe_variables.update(names)
|
||||
elif isinstance(node, ast.AnnAssign) and node.target:
|
||||
# Annotated assignments are always stubbed
|
||||
for n in _iter_assigned_names(node.target):
|
||||
if not _is_private(n):
|
||||
unsafe_variables.add(n)
|
||||
|
||||
return classes, functions, safe_variable_sources, unsafe_variables
|
||||
|
||||
|
||||
def find_prisma_types_path() -> Path:
|
||||
"""Find the prisma types.py file in the installed package."""
|
||||
spec = importlib.util.find_spec("prisma")
|
||||
if spec is None or spec.origin is None:
|
||||
raise RuntimeError("Could not find prisma package. Is it installed?")
|
||||
|
||||
prisma_dir = Path(spec.origin).parent
|
||||
types_path = prisma_dir / "types.py"
|
||||
|
||||
if not types_path.exists():
|
||||
raise RuntimeError(f"prisma/types.py not found at {types_path}")
|
||||
|
||||
return types_path
|
||||
|
||||
|
||||
def generate_stub(src_path: Path, stub_path: Path) -> int:
|
||||
"""Generate the .pyi stub file from the source types.py."""
|
||||
code = src_path.read_text(encoding="utf-8", errors="ignore")
|
||||
source_lines = code.splitlines()
|
||||
tree = ast.parse(code, filename=str(src_path))
|
||||
classes, functions, safe_variable_sources, unsafe_variables = (
|
||||
collect_top_level_symbols(tree, source_lines)
|
||||
)
|
||||
|
||||
header = """\
|
||||
# -*- coding: utf-8 -*-
|
||||
# Auto-generated stub file - DO NOT EDIT
|
||||
# Generated by gen_prisma_types_stub.py
|
||||
#
|
||||
# This stub intentionally collapses complex Prisma query DSL types to Any.
|
||||
# Prisma's generated types can explode Pyright's type inference budgets
|
||||
# on large schemas. We collapse them to Any so the rest of the codebase
|
||||
# can remain strongly typed while keeping runtime behavior unchanged.
|
||||
#
|
||||
# Safe types (Literal, TypeVar, simple references) are preserved from the
|
||||
# original types.py to maintain proper type checking where possible.
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
from typing_extensions import Literal
|
||||
|
||||
# Re-export commonly used typing constructs that may be imported from this module
|
||||
from typing import TYPE_CHECKING, TypeVar, Generic, Union, Optional, List, Dict
|
||||
|
||||
# Base type alias for stubbed Prisma types - allows any dict structure
|
||||
_PrismaDict = dict[str, Any]
|
||||
|
||||
"""
|
||||
|
||||
lines = [header]
|
||||
|
||||
# Include safe variable definitions (Literal types, TypeVars, etc.)
|
||||
lines.append("# Safe type definitions preserved from original types.py")
|
||||
for source in safe_variable_sources:
|
||||
lines.append(source)
|
||||
lines.append("")
|
||||
|
||||
# Stub all classes and unsafe variables uniformly as dict[str, Any] aliases
|
||||
# This allows:
|
||||
# 1. Use in type annotations: x: SomeType
|
||||
# 2. Constructor calls: SomeType(...)
|
||||
# 3. Dict literal assignments: x: SomeType = {...}
|
||||
lines.append(
|
||||
"# Stubbed types (collapsed to dict[str, Any] to prevent type budget exhaustion)"
|
||||
)
|
||||
all_stubbed = sorted(classes | unsafe_variables)
|
||||
for name in all_stubbed:
|
||||
lines.append(f"{name} = _PrismaDict")
|
||||
|
||||
lines.append("")
|
||||
|
||||
# Stub functions
|
||||
for name in sorted(functions):
|
||||
lines.append(f"def {name}(*args: Any, **kwargs: Any) -> Any: ...")
|
||||
|
||||
lines.append("")
|
||||
|
||||
stub_path.write_text("\n".join(lines), encoding="utf-8")
|
||||
return (
|
||||
len(classes)
|
||||
+ len(functions)
|
||||
+ len(safe_variable_sources)
|
||||
+ len(unsafe_variables)
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Main entry point."""
|
||||
try:
|
||||
types_path = find_prisma_types_path()
|
||||
stub_path = types_path.with_suffix(".pyi")
|
||||
|
||||
print(f"Found prisma types.py at: {types_path}")
|
||||
print(f"Generating stub at: {stub_path}")
|
||||
|
||||
num_symbols = generate_stub(types_path, stub_path)
|
||||
print(f"Generated {stub_path.name} with {num_symbols} Any-typed symbols")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -25,9 +25,6 @@ def run(*command: str) -> None:
|
||||
|
||||
|
||||
def lint():
|
||||
# Generate Prisma types stub before running pyright to prevent type budget exhaustion
|
||||
run("gen-prisma-stub")
|
||||
|
||||
lint_step_args: list[list[str]] = [
|
||||
["ruff", "check", *TARGET_DIRS, "--exit-zero"],
|
||||
["ruff", "format", "--diff", "--check", LIBS_DIR],
|
||||
@@ -52,6 +49,4 @@ def format():
|
||||
run("ruff", "format", LIBS_DIR)
|
||||
run("isort", "--profile", "black", BACKEND_DIR)
|
||||
run("black", BACKEND_DIR)
|
||||
# Generate Prisma types stub before running pyright to prevent type budget exhaustion
|
||||
run("gen-prisma-stub")
|
||||
run("pyright", *TARGET_DIRS)
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
-- DropIndex
|
||||
DROP INDEX "StoreListingVersion_storeListingId_version_key";
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "UserBusinessUnderstanding" (
|
||||
"id" TEXT NOT NULL,
|
||||
@@ -0,0 +1,41 @@
|
||||
-- Migration: Add pgvector extension and StoreListingEmbedding table
|
||||
-- This enables hybrid search combining semantic (embedding) and lexical (tsvector) search
|
||||
|
||||
-- Enable pgvector extension for vector similarity search
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
|
||||
-- Create table to store embeddings for store listing versions
|
||||
CREATE TABLE "StoreListingEmbedding" (
|
||||
"id" TEXT NOT NULL DEFAULT gen_random_uuid(),
|
||||
"storeListingVersionId" TEXT NOT NULL,
|
||||
"embedding" vector(1536), -- OpenAI text-embedding-3-small produces 1536 dimensions
|
||||
"searchableText" TEXT, -- The text that was embedded (for debugging/recomputation)
|
||||
"contentHash" TEXT, -- MD5 hash of searchable text for change detection
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
CONSTRAINT "StoreListingEmbedding_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- Unique constraint: one embedding per listing version
|
||||
CREATE UNIQUE INDEX "StoreListingEmbedding_storeListingVersionId_key"
|
||||
ON "StoreListingEmbedding"("storeListingVersionId");
|
||||
|
||||
-- HNSW index for fast approximate nearest neighbor search
|
||||
-- Using cosine distance (vector_cosine_ops) which is standard for text embeddings
|
||||
CREATE INDEX "StoreListingEmbedding_embedding_idx"
|
||||
ON "StoreListingEmbedding"
|
||||
USING hnsw ("embedding" vector_cosine_ops);
|
||||
|
||||
-- Index on content hash for fast lookup during change detection
|
||||
CREATE INDEX "StoreListingEmbedding_contentHash_idx"
|
||||
ON "StoreListingEmbedding"("contentHash");
|
||||
|
||||
-- Foreign key to StoreListingVersion with CASCADE delete
|
||||
-- When a listing version is deleted, its embedding is automatically removed
|
||||
ALTER TABLE "StoreListingEmbedding"
|
||||
ADD CONSTRAINT "StoreListingEmbedding_storeListingVersionId_fkey"
|
||||
FOREIGN KEY ("storeListingVersionId")
|
||||
REFERENCES "StoreListingVersion"("id")
|
||||
ON DELETE CASCADE
|
||||
ON UPDATE CASCADE;
|
||||
@@ -0,0 +1,5 @@
|
||||
-- DropIndex
|
||||
DROP INDEX "StoreListingEmbedding_embedding_idx";
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "StoreListingEmbedding" ALTER COLUMN "id" DROP DEFAULT;
|
||||
@@ -33,6 +33,7 @@ html2text = "^2024.2.26"
|
||||
jinja2 = "^3.1.6"
|
||||
jsonref = "^1.1.0"
|
||||
jsonschema = "^4.25.0"
|
||||
langfuse = "^2.0.0"
|
||||
launchdarkly-server-sdk = "^9.12.0"
|
||||
mem0ai = "^0.1.115"
|
||||
moviepy = "^2.1.2"
|
||||
@@ -117,7 +118,6 @@ lint = "linter:lint"
|
||||
test = "run_tests:test"
|
||||
load-store-agents = "test.load_store_agents:run"
|
||||
export-api-schema = "backend.cli.generate_openapi_json:main"
|
||||
gen-prisma-stub = "gen_prisma_types_stub:main"
|
||||
oauth-tool = "backend.cli.oauth_tool:cli"
|
||||
|
||||
[tool.isort]
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
datasource db {
|
||||
provider = "postgresql"
|
||||
url = env("DATABASE_URL")
|
||||
directUrl = env("DIRECT_URL")
|
||||
provider = "postgresql"
|
||||
url = env("DATABASE_URL")
|
||||
directUrl = env("DIRECT_URL")
|
||||
extensions = [pgvector(map: "vector", schema: "public")]
|
||||
}
|
||||
|
||||
generator client {
|
||||
provider = "prisma-client-py"
|
||||
recursive_type_depth = -1
|
||||
interface = "asyncio"
|
||||
previewFeatures = ["views", "fullTextSearch"]
|
||||
previewFeatures = ["views", "fullTextSearch", "postgresqlExtensions"]
|
||||
partial_type_generator = "backend/data/partial_types.py"
|
||||
}
|
||||
|
||||
@@ -990,6 +991,10 @@ model StoreListingVersion {
|
||||
// Reviews for this specific version
|
||||
Reviews StoreListingReview[]
|
||||
|
||||
// Embedding for semantic search (one-to-one)
|
||||
Embedding StoreListingEmbedding?
|
||||
|
||||
@@unique([storeListingId, version])
|
||||
@@index([storeListingId, submissionStatus, isAvailable])
|
||||
@@index([submissionStatus])
|
||||
@@index([reviewerId])
|
||||
@@ -1014,6 +1019,24 @@ model StoreListingReview {
|
||||
@@index([reviewByUserId])
|
||||
}
|
||||
|
||||
// Stores vector embeddings for semantic search of store listings
|
||||
// Uses pgvector extension for efficient similarity search
|
||||
model StoreListingEmbedding {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
storeListingVersionId String @unique
|
||||
StoreListingVersion StoreListingVersion @relation(fields: [storeListingVersionId], references: [id], onDelete: Cascade)
|
||||
|
||||
// pgvector embedding - stored as Unsupported type since Prisma doesn't natively support vector
|
||||
embedding Unsupported("vector(1536)")?
|
||||
searchableText String? // The text that was embedded (for debugging/recomputation)
|
||||
contentHash String? // MD5 hash for change detection
|
||||
|
||||
@@index([contentHash])
|
||||
}
|
||||
|
||||
enum SubmissionStatus {
|
||||
DRAFT // Being prepared, not yet submitted
|
||||
PENDING // Submitted, awaiting review
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
"created_at": "2025-09-04T13:37:00",
|
||||
"credentials_input_schema": {
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"title": "TestGraphCredentialsInputSchema",
|
||||
"type": "object"
|
||||
},
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
{
|
||||
"credentials_input_schema": {
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"title": "TestGraphCredentialsInputSchema",
|
||||
"type": "object"
|
||||
},
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
"id": "test-agent-1",
|
||||
"graph_id": "test-agent-1",
|
||||
"graph_version": 1,
|
||||
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
"image_url": null,
|
||||
"creator_name": "Test Creator",
|
||||
"creator_image_url": "",
|
||||
@@ -42,7 +41,6 @@
|
||||
"id": "test-agent-2",
|
||||
"graph_id": "test-agent-2",
|
||||
"graph_version": 1,
|
||||
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
"image_url": null,
|
||||
"creator_name": "Test Creator",
|
||||
"creator_image_url": "",
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
{
|
||||
"submissions": [
|
||||
{
|
||||
"listing_id": "test-listing-id",
|
||||
"agent_id": "test-agent-id",
|
||||
"agent_version": 1,
|
||||
"name": "Test Agent",
|
||||
|
||||
@@ -37,7 +37,7 @@ services:
|
||||
context: ../
|
||||
dockerfile: autogpt_platform/backend/Dockerfile
|
||||
target: migrate
|
||||
command: ["sh", "-c", "poetry run prisma generate && poetry run gen-prisma-stub && poetry run prisma migrate deploy"]
|
||||
command: ["sh", "-c", "poetry run prisma generate && poetry run prisma migrate deploy"]
|
||||
develop:
|
||||
watch:
|
||||
- path: ./
|
||||
|
||||
@@ -54,7 +54,7 @@
|
||||
"@radix-ui/react-tooltip": "1.2.8",
|
||||
"@rjsf/core": "6.1.2",
|
||||
"@rjsf/utils": "6.1.2",
|
||||
"@rjsf/validator-ajv8": "6.1.2",
|
||||
"@rjsf/validator-ajv8": "5.24.13",
|
||||
"@sentry/nextjs": "10.27.0",
|
||||
"@supabase/ssr": "0.7.0",
|
||||
"@supabase/supabase-js": "2.78.0",
|
||||
@@ -92,6 +92,7 @@
|
||||
"react-currency-input-field": "4.0.3",
|
||||
"react-day-picker": "9.11.1",
|
||||
"react-dom": "18.3.1",
|
||||
"react-drag-drop-files": "2.4.0",
|
||||
"react-hook-form": "7.66.0",
|
||||
"react-icons": "5.5.0",
|
||||
"react-markdown": "9.0.3",
|
||||
|
||||
3704
autogpt_platform/frontend/pnpm-lock.yaml
generated
3704
autogpt_platform/frontend/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
Binary file not shown.
|
Before Width: | Height: | Size: 2.6 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 16 KiB |
@@ -1,6 +1,6 @@
|
||||
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs";
|
||||
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
|
||||
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
|
||||
import { CredentialsInput } from "@/components/contextual/CredentialsInputs/CredentialsInputs";
|
||||
import { useState } from "react";
|
||||
import { getSchemaDefaultCredentials } from "../../helpers";
|
||||
import { areAllCredentialsSet, getCredentialFields } from "./helpers";
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"use client";
|
||||
|
||||
import { RunAgentInputs } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentInputs/RunAgentInputs";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/__legacy__/ui/card";
|
||||
import { RunAgentInputs } from "@/components/contextual/RunAgentInputs/RunAgentInputs";
|
||||
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { CircleNotchIcon } from "@phosphor-icons/react/dist/ssr";
|
||||
import { Play } from "lucide-react";
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
"use client";
|
||||
|
||||
import { ChatDrawer } from "@/components/contextual/Chat/ChatDrawer";
|
||||
import { usePathname } from "next/navigation";
|
||||
import { Children, ReactNode } from "react";
|
||||
|
||||
interface PlatformLayoutContentProps {
|
||||
children: ReactNode;
|
||||
}
|
||||
|
||||
export function PlatformLayoutContent({
|
||||
children,
|
||||
}: PlatformLayoutContentProps) {
|
||||
const pathname = usePathname();
|
||||
const isAuthPage =
|
||||
pathname?.includes("/login") || pathname?.includes("/signup");
|
||||
|
||||
// Extract Navbar, AdminImpersonationBanner, and page content from children
|
||||
const childrenArray = Children.toArray(children);
|
||||
const navbar = childrenArray[0];
|
||||
const adminBanner = childrenArray[1];
|
||||
const pageContent = childrenArray.slice(2);
|
||||
|
||||
// For login/signup pages, use a simpler layout that doesn't interfere with centering
|
||||
if (isAuthPage) {
|
||||
return (
|
||||
<main className="flex min-h-screen w-full flex-col">
|
||||
{navbar}
|
||||
{adminBanner}
|
||||
<section className="flex-1">{pageContent}</section>
|
||||
{/* ChatDrawer must always be rendered to maintain consistent hook count */}
|
||||
<ChatDrawer />
|
||||
</main>
|
||||
);
|
||||
}
|
||||
|
||||
// For logged-in pages, use the drawer layout
|
||||
return (
|
||||
<main className="flex h-screen w-full flex-col overflow-hidden">
|
||||
{navbar}
|
||||
{adminBanner}
|
||||
<section className="flex min-h-0 flex-1 overflow-auto">
|
||||
{pageContent}
|
||||
</section>
|
||||
<ChatDrawer />
|
||||
</main>
|
||||
);
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import { AuthCard } from "@/components/auth/AuthCard";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs";
|
||||
import { CredentialsInput } from "@/components/contextual/CredentialsInputs/CredentialsInputs";
|
||||
import type {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsMetaInput,
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
import { BlockUIType } from "@/app/(platform)/build/components/types";
|
||||
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import {
|
||||
globalRegistry,
|
||||
OutputActions,
|
||||
OutputItem,
|
||||
} from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers";
|
||||
import { Label } from "@/components/__legacy__/ui/label";
|
||||
import { ScrollArea } from "@/components/__legacy__/ui/scroll-area";
|
||||
import {
|
||||
@@ -23,6 +18,11 @@ import {
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import {
|
||||
globalRegistry,
|
||||
OutputActions,
|
||||
OutputItem,
|
||||
} from "@/components/contextual/OutputRenderers";
|
||||
import { BookOpenIcon } from "@phosphor-icons/react";
|
||||
import { useMemo } from "react";
|
||||
import { useShallow } from "zustand/react/shallow";
|
||||
|
||||
@@ -66,7 +66,6 @@ export const RunInputDialog = ({
|
||||
formContext={{
|
||||
showHandles: false,
|
||||
size: "large",
|
||||
showOptionalToggle: false,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -66,7 +66,7 @@ export const useRunInputDialog = ({
|
||||
if (isCredentialFieldSchema(fieldSchema)) {
|
||||
dynamicUiSchema[fieldName] = {
|
||||
...dynamicUiSchema[fieldName],
|
||||
"ui:field": "custom/credential_field",
|
||||
"ui:field": "credentials",
|
||||
};
|
||||
}
|
||||
});
|
||||
@@ -76,18 +76,12 @@ export const useRunInputDialog = ({
|
||||
}, [credentialsSchema]);
|
||||
|
||||
const handleManualRun = async () => {
|
||||
// Filter out incomplete credentials (those without a valid id)
|
||||
// RJSF auto-populates const values (provider, type) but not id field
|
||||
const validCredentials = Object.fromEntries(
|
||||
Object.entries(credentialValues).filter(([_, cred]) => cred && cred.id),
|
||||
);
|
||||
|
||||
await executeGraph({
|
||||
graphId: flowID ?? "",
|
||||
graphVersion: flowVersion || null,
|
||||
data: {
|
||||
inputs: inputValues,
|
||||
credentials_inputs: validCredentials,
|
||||
credentials_inputs: credentialValues,
|
||||
source: "builder",
|
||||
},
|
||||
});
|
||||
|
||||
@@ -97,9 +97,6 @@ export const Flow = () => {
|
||||
onConnect={onConnect}
|
||||
onEdgesChange={onEdgesChange}
|
||||
onNodeDragStop={onNodeDragStop}
|
||||
onNodeContextMenu={(event) => {
|
||||
event.preventDefault();
|
||||
}}
|
||||
maxZoom={2}
|
||||
minZoom={0.1}
|
||||
onDragOver={onDragOver}
|
||||
|
||||
@@ -1,25 +1,24 @@
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import { BlockCost } from "@/app/api/__generated__/models/blockCost";
|
||||
import { BlockInfoCategoriesItem } from "@/app/api/__generated__/models/blockInfoCategoriesItem";
|
||||
import { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult";
|
||||
import { NodeModelMetadata } from "@/app/api/__generated__/models/nodeModelMetadata";
|
||||
import { preprocessInputSchema } from "@/components/renderers/InputRenderer/utils/input-schema-pre-processor";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { RJSFSchema } from "@rjsf/utils";
|
||||
import { NodeProps, Node as XYNode } from "@xyflow/react";
|
||||
import React from "react";
|
||||
import { Node as XYNode, NodeProps } from "@xyflow/react";
|
||||
import { RJSFSchema } from "@rjsf/utils";
|
||||
import { BlockUIType } from "../../../types";
|
||||
import { FormCreator } from "../FormCreator";
|
||||
import { OutputHandler } from "../OutputHandler";
|
||||
import { AyrshareConnectButton } from "./components/AyrshareConnectButton";
|
||||
import { NodeAdvancedToggle } from "./components/NodeAdvancedToggle";
|
||||
import { NodeContainer } from "./components/NodeContainer";
|
||||
import { NodeExecutionBadge } from "./components/NodeExecutionBadge";
|
||||
import { NodeHeader } from "./components/NodeHeader";
|
||||
import { NodeDataRenderer } from "./components/NodeOutput/NodeOutput";
|
||||
import { NodeRightClickMenu } from "./components/NodeRightClickMenu";
|
||||
import { StickyNoteBlock } from "./components/StickyNoteBlock";
|
||||
import { BlockInfoCategoriesItem } from "@/app/api/__generated__/models/blockInfoCategoriesItem";
|
||||
import { BlockCost } from "@/app/api/__generated__/models/blockCost";
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult";
|
||||
import { NodeContainer } from "./components/NodeContainer";
|
||||
import { NodeHeader } from "./components/NodeHeader";
|
||||
import { FormCreator } from "../FormCreator";
|
||||
import { preprocessInputSchema } from "@/components/renderers/InputRenderer/utils/input-schema-pre-processor";
|
||||
import { OutputHandler } from "../OutputHandler";
|
||||
import { NodeAdvancedToggle } from "./components/NodeAdvancedToggle";
|
||||
import { NodeDataRenderer } from "./components/NodeOutput/NodeOutput";
|
||||
import { NodeExecutionBadge } from "./components/NodeExecutionBadge";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { WebhookDisclaimer } from "./components/WebhookDisclaimer";
|
||||
import { AyrshareConnectButton } from "./components/AyrshareConnectButton";
|
||||
import { NodeModelMetadata } from "@/app/api/__generated__/models/nodeModelMetadata";
|
||||
|
||||
export type CustomNodeData = {
|
||||
hardcodedValues: {
|
||||
@@ -89,7 +88,7 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
|
||||
// Currently all blockTypes design are similar - that's why i am using the same component for all of them
|
||||
// If in future - if we need some drastic change in some blockTypes design - we can create separate components for them
|
||||
const node = (
|
||||
return (
|
||||
<NodeContainer selected={selected} nodeId={nodeId} hasErrors={hasErrors}>
|
||||
<div className="rounded-xlarge bg-white">
|
||||
<NodeHeader data={data} nodeId={nodeId} />
|
||||
@@ -118,15 +117,6 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
<NodeExecutionBadge nodeId={nodeId} />
|
||||
</NodeContainer>
|
||||
);
|
||||
|
||||
return (
|
||||
<NodeRightClickMenu
|
||||
nodeId={nodeId}
|
||||
subGraphID={data.hardcodedValues?.graph_id}
|
||||
>
|
||||
{node}
|
||||
</NodeRightClickMenu>
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
|
||||
@@ -1,31 +1,26 @@
|
||||
import { useCopyPasteStore } from "@/app/(platform)/build/stores/copyPasteStore";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { Separator } from "@/components/__legacy__/ui/separator";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/molecules/DropdownMenu/DropdownMenu";
|
||||
import {
|
||||
SecondaryDropdownMenuContent,
|
||||
SecondaryDropdownMenuItem,
|
||||
SecondaryDropdownMenuSeparator,
|
||||
} from "@/components/molecules/SecondaryMenu/SecondaryMenu";
|
||||
import {
|
||||
ArrowSquareOutIcon,
|
||||
CopyIcon,
|
||||
DotsThreeOutlineVerticalIcon,
|
||||
TrashIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { DotsThreeOutlineVerticalIcon } from "@phosphor-icons/react";
|
||||
import { Copy, Trash2, ExternalLink } from "lucide-react";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { useCopyPasteStore } from "@/app/(platform)/build/stores/copyPasteStore";
|
||||
import { useReactFlow } from "@xyflow/react";
|
||||
|
||||
type Props = {
|
||||
export const NodeContextMenu = ({
|
||||
nodeId,
|
||||
subGraphID,
|
||||
}: {
|
||||
nodeId: string;
|
||||
subGraphID?: string;
|
||||
};
|
||||
|
||||
export const NodeContextMenu = ({ nodeId, subGraphID }: Props) => {
|
||||
}) => {
|
||||
const { deleteElements } = useReactFlow();
|
||||
|
||||
function handleCopy() {
|
||||
const handleCopy = () => {
|
||||
useNodeStore.setState((state) => ({
|
||||
nodes: state.nodes.map((node) => ({
|
||||
...node,
|
||||
@@ -35,47 +30,47 @@ export const NodeContextMenu = ({ nodeId, subGraphID }: Props) => {
|
||||
|
||||
useCopyPasteStore.getState().copySelectedNodes();
|
||||
useCopyPasteStore.getState().pasteNodes();
|
||||
}
|
||||
};
|
||||
|
||||
function handleDelete() {
|
||||
const handleDelete = () => {
|
||||
deleteElements({ nodes: [{ id: nodeId }] });
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger className="py-2">
|
||||
<DotsThreeOutlineVerticalIcon size={16} weight="fill" />
|
||||
</DropdownMenuTrigger>
|
||||
<SecondaryDropdownMenuContent side="right" align="start">
|
||||
<SecondaryDropdownMenuItem onClick={handleCopy}>
|
||||
<CopyIcon size={20} className="mr-2 dark:text-gray-100" />
|
||||
<span className="dark:text-gray-100">Copy</span>
|
||||
</SecondaryDropdownMenuItem>
|
||||
<SecondaryDropdownMenuSeparator />
|
||||
<DropdownMenuContent
|
||||
side="right"
|
||||
align="start"
|
||||
className="rounded-xlarge"
|
||||
>
|
||||
<DropdownMenuItem onClick={handleCopy} className="hover:rounded-xlarge">
|
||||
<Copy className="mr-2 h-4 w-4" />
|
||||
Copy Node
|
||||
</DropdownMenuItem>
|
||||
|
||||
{subGraphID && (
|
||||
<>
|
||||
<SecondaryDropdownMenuItem
|
||||
onClick={() => window.open(`/build?flowID=${subGraphID}`)}
|
||||
>
|
||||
<ArrowSquareOutIcon
|
||||
size={20}
|
||||
className="mr-2 dark:text-gray-100"
|
||||
/>
|
||||
<span className="dark:text-gray-100">Open agent</span>
|
||||
</SecondaryDropdownMenuItem>
|
||||
<SecondaryDropdownMenuSeparator />
|
||||
</>
|
||||
<DropdownMenuItem
|
||||
onClick={() => window.open(`/build?flowID=${subGraphID}`)}
|
||||
className="hover:rounded-xlarge"
|
||||
>
|
||||
<ExternalLink className="mr-2 h-4 w-4" />
|
||||
Open Agent
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
|
||||
<SecondaryDropdownMenuItem variant="destructive" onClick={handleDelete}>
|
||||
<TrashIcon
|
||||
size={20}
|
||||
className="mr-2 text-red-500 dark:text-red-400"
|
||||
/>
|
||||
<span className="dark:text-red-400">Delete</span>
|
||||
</SecondaryDropdownMenuItem>
|
||||
</SecondaryDropdownMenuContent>
|
||||
<Separator className="my-2" />
|
||||
|
||||
<DropdownMenuItem
|
||||
onClick={handleDelete}
|
||||
className="text-red-600 hover:rounded-xlarge"
|
||||
>
|
||||
<Trash2 className="mr-2 h-4 w-4" />
|
||||
Delete
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,24 +1,25 @@
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { beautifyString, cn } from "@/lib/utils";
|
||||
import { NodeCost } from "./NodeCost";
|
||||
import { NodeBadges } from "./NodeBadges";
|
||||
import { NodeContextMenu } from "./NodeContextMenu";
|
||||
import { CustomNodeData } from "../CustomNode";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { useState } from "react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { beautifyString, cn } from "@/lib/utils";
|
||||
import { useState } from "react";
|
||||
import { CustomNodeData } from "../CustomNode";
|
||||
import { NodeBadges } from "./NodeBadges";
|
||||
import { NodeContextMenu } from "./NodeContextMenu";
|
||||
import { NodeCost } from "./NodeCost";
|
||||
|
||||
type Props = {
|
||||
export const NodeHeader = ({
|
||||
data,
|
||||
nodeId,
|
||||
}: {
|
||||
data: CustomNodeData;
|
||||
nodeId: string;
|
||||
};
|
||||
|
||||
export const NodeHeader = ({ data, nodeId }: Props) => {
|
||||
}) => {
|
||||
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
||||
const title = (data.metadata?.customized_name as string) || data.title;
|
||||
const [isEditingTitle, setIsEditingTitle] = useState(false);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import type { OutputMetadata } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers";
|
||||
import { globalRegistry } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers";
|
||||
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
|
||||
import { globalRegistry } from "@/components/contextual/OutputRenderers";
|
||||
|
||||
export const TextRenderer: React.FC<{
|
||||
value: any;
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
import {
|
||||
OutputActions,
|
||||
OutputItem,
|
||||
} from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers";
|
||||
import { ScrollArea } from "@/components/__legacy__/ui/scroll-area";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
@@ -11,6 +7,10 @@ import {
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import {
|
||||
OutputActions,
|
||||
OutputItem,
|
||||
} from "@/components/contextual/OutputRenderers";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { beautifyString } from "@/lib/utils";
|
||||
import {
|
||||
@@ -151,7 +151,7 @@ export const NodeDataViewer: FC<NodeDataViewerProps> = ({
|
||||
</div>
|
||||
|
||||
<div className="flex justify-end pt-4">
|
||||
{outputItems.length > 1 && (
|
||||
{outputItems.length > 0 && (
|
||||
<OutputActions
|
||||
items={outputItems.map((item) => ({
|
||||
value: item.value,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { OutputMetadata } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers";
|
||||
import { globalRegistry } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers";
|
||||
import { downloadOutputs } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers/utils/download";
|
||||
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
|
||||
import { globalRegistry } from "@/components/contextual/OutputRenderers";
|
||||
import { downloadOutputs } from "@/components/contextual/OutputRenderers/utils/download";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { beautifyString } from "@/lib/utils";
|
||||
import React, { useMemo, useState } from "react";
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
import { useCopyPasteStore } from "@/app/(platform)/build/stores/copyPasteStore";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import {
|
||||
SecondaryMenuContent,
|
||||
SecondaryMenuItem,
|
||||
SecondaryMenuSeparator,
|
||||
} from "@/components/molecules/SecondaryMenu/SecondaryMenu";
|
||||
import { ArrowSquareOutIcon, CopyIcon, TrashIcon } from "@phosphor-icons/react";
|
||||
import * as ContextMenu from "@radix-ui/react-context-menu";
|
||||
import { useReactFlow } from "@xyflow/react";
|
||||
import { useEffect, useRef } from "react";
|
||||
import { CustomNode } from "../CustomNode";
|
||||
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
subGraphID?: string;
|
||||
children: React.ReactNode;
|
||||
};
|
||||
|
||||
const DOUBLE_CLICK_TIMEOUT = 300;
|
||||
|
||||
export function NodeRightClickMenu({ nodeId, subGraphID, children }: Props) {
|
||||
const { deleteElements } = useReactFlow<CustomNode>();
|
||||
const lastRightClickTime = useRef<number>(0);
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
function copyNode() {
|
||||
useNodeStore.setState((state) => ({
|
||||
nodes: state.nodes.map((node) => ({
|
||||
...node,
|
||||
selected: node.id === nodeId,
|
||||
})),
|
||||
}));
|
||||
|
||||
useCopyPasteStore.getState().copySelectedNodes();
|
||||
useCopyPasteStore.getState().pasteNodes();
|
||||
}
|
||||
|
||||
function deleteNode() {
|
||||
deleteElements({ nodes: [{ id: nodeId }] });
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
const container = containerRef.current;
|
||||
if (!container) return;
|
||||
|
||||
function handleContextMenu(e: MouseEvent) {
|
||||
const now = Date.now();
|
||||
const timeSinceLastClick = now - lastRightClickTime.current;
|
||||
|
||||
if (timeSinceLastClick < DOUBLE_CLICK_TIMEOUT) {
|
||||
e.stopImmediatePropagation();
|
||||
lastRightClickTime.current = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
lastRightClickTime.current = now;
|
||||
}
|
||||
|
||||
container.addEventListener("contextmenu", handleContextMenu, true);
|
||||
|
||||
return () => {
|
||||
container.removeEventListener("contextmenu", handleContextMenu, true);
|
||||
};
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<ContextMenu.Root>
|
||||
<ContextMenu.Trigger asChild>
|
||||
<div ref={containerRef}>{children}</div>
|
||||
</ContextMenu.Trigger>
|
||||
<SecondaryMenuContent>
|
||||
<SecondaryMenuItem onSelect={copyNode}>
|
||||
<CopyIcon size={20} className="mr-2 dark:text-gray-100" />
|
||||
<span className="dark:text-gray-100">Copy</span>
|
||||
</SecondaryMenuItem>
|
||||
<SecondaryMenuSeparator />
|
||||
|
||||
{subGraphID && (
|
||||
<>
|
||||
<SecondaryMenuItem
|
||||
onClick={() => window.open(`/build?flowID=${subGraphID}`)}
|
||||
>
|
||||
<ArrowSquareOutIcon
|
||||
size={20}
|
||||
className="mr-2 dark:text-gray-100"
|
||||
/>
|
||||
<span className="dark:text-gray-100">Open agent</span>
|
||||
</SecondaryMenuItem>
|
||||
<SecondaryMenuSeparator />
|
||||
</>
|
||||
)}
|
||||
|
||||
<SecondaryMenuItem variant="destructive" onSelect={deleteNode}>
|
||||
<TrashIcon
|
||||
size={20}
|
||||
className="mr-2 text-red-500 dark:text-red-400"
|
||||
/>
|
||||
<span className="dark:text-red-400">Delete</span>
|
||||
</SecondaryMenuItem>
|
||||
</SecondaryMenuContent>
|
||||
</ContextMenu.Root>
|
||||
);
|
||||
}
|
||||
@@ -1,10 +1,10 @@
|
||||
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import Link from "next/link";
|
||||
import { useGetV2GetLibraryAgentByGraphId } from "@/app/api/__generated__/endpoints/library/library";
|
||||
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { useQueryStates, parseAsString } from "nuqs";
|
||||
import { isValidUUID } from "@/app/(platform)/chat/helpers";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { isValidUUID } from "@/components/contextual/Chat/helpers";
|
||||
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
|
||||
import Link from "next/link";
|
||||
import { parseAsString, useQueryStates } from "nuqs";
|
||||
|
||||
export const WebhookDisclaimer = ({ nodeId }: { nodeId: string }) => {
|
||||
const [{ flowID }] = useQueryStates({
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
export const uiSchema = {
|
||||
credentials: {
|
||||
"ui:field": "custom/credential_field",
|
||||
"ui:field": "credentials",
|
||||
provider: { "ui:widget": "hidden" },
|
||||
type: { "ui:widget": "hidden" },
|
||||
id: { "ui:autofocus": true },
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
import { useBlockMenuStore } from "@/app/(platform)/build/stores/blockMenuStore";
|
||||
import { FilterChip } from "../FilterChip";
|
||||
import { categories } from "./constants";
|
||||
import { FilterSheet } from "../FilterSheet/FilterSheet";
|
||||
import { GetV2BuilderSearchFilterAnyOfItem } from "@/app/api/__generated__/models/getV2BuilderSearchFilterAnyOfItem";
|
||||
|
||||
export const BlockMenuFilters = () => {
|
||||
const {
|
||||
filters,
|
||||
addFilter,
|
||||
removeFilter,
|
||||
categoryCounts,
|
||||
creators,
|
||||
addCreator,
|
||||
removeCreator,
|
||||
} = useBlockMenuStore();
|
||||
|
||||
const handleFilterClick = (filter: GetV2BuilderSearchFilterAnyOfItem) => {
|
||||
if (filters.includes(filter)) {
|
||||
removeFilter(filter);
|
||||
} else {
|
||||
addFilter(filter);
|
||||
}
|
||||
};
|
||||
|
||||
const handleCreatorClick = (creator: string) => {
|
||||
if (creators.includes(creator)) {
|
||||
removeCreator(creator);
|
||||
} else {
|
||||
addCreator(creator);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-wrap gap-2">
|
||||
<FilterSheet categories={categories} />
|
||||
{creators.length > 0 &&
|
||||
creators.map((creator) => (
|
||||
<FilterChip
|
||||
key={creator}
|
||||
name={"Created by " + creator.slice(0, 10) + "..."}
|
||||
selected={creators.includes(creator)}
|
||||
onClick={() => handleCreatorClick(creator)}
|
||||
/>
|
||||
))}
|
||||
{categories.map((category) => (
|
||||
<FilterChip
|
||||
key={category.key}
|
||||
name={category.name}
|
||||
selected={filters.includes(category.key)}
|
||||
onClick={() => handleFilterClick(category.key)}
|
||||
number={categoryCounts[category.key] ?? 0}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -1,15 +0,0 @@
|
||||
import { GetV2BuilderSearchFilterAnyOfItem } from "@/app/api/__generated__/models/getV2BuilderSearchFilterAnyOfItem";
|
||||
import { CategoryKey } from "./types";
|
||||
|
||||
export const categories: Array<{ key: CategoryKey; name: string }> = [
|
||||
{ key: GetV2BuilderSearchFilterAnyOfItem.blocks, name: "Blocks" },
|
||||
{
|
||||
key: GetV2BuilderSearchFilterAnyOfItem.integrations,
|
||||
name: "Integrations",
|
||||
},
|
||||
{
|
||||
key: GetV2BuilderSearchFilterAnyOfItem.marketplace_agents,
|
||||
name: "Marketplace agents",
|
||||
},
|
||||
{ key: GetV2BuilderSearchFilterAnyOfItem.my_agents, name: "My agents" },
|
||||
];
|
||||
@@ -1,26 +0,0 @@
|
||||
import { GetV2BuilderSearchFilterAnyOfItem } from "@/app/api/__generated__/models/getV2BuilderSearchFilterAnyOfItem";
|
||||
|
||||
export type DefaultStateType =
|
||||
| "suggestion"
|
||||
| "all_blocks"
|
||||
| "input_blocks"
|
||||
| "action_blocks"
|
||||
| "output_blocks"
|
||||
| "integrations"
|
||||
| "marketplace_agents"
|
||||
| "my_agents";
|
||||
|
||||
export type CategoryKey = GetV2BuilderSearchFilterAnyOfItem;
|
||||
|
||||
export interface Filters {
|
||||
categories: {
|
||||
blocks: boolean;
|
||||
integrations: boolean;
|
||||
marketplace_agents: boolean;
|
||||
my_agents: boolean;
|
||||
providers: boolean;
|
||||
};
|
||||
createdBy: string[];
|
||||
}
|
||||
|
||||
export type CategoryCounts = Record<CategoryKey, number>;
|
||||
@@ -1,14 +1,111 @@
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { useBlockMenuSearch } from "./useBlockMenuSearch";
|
||||
import { InfiniteScroll } from "@/components/contextual/InfiniteScroll/InfiniteScroll";
|
||||
import { LoadingSpinner } from "@/components/__legacy__/ui/loading";
|
||||
import { SearchResponseItemsItem } from "@/app/api/__generated__/models/searchResponseItemsItem";
|
||||
import { MarketplaceAgentBlock } from "../MarketplaceAgentBlock";
|
||||
import { Block } from "../Block";
|
||||
import { UGCAgentBlock } from "../UGCAgentBlock";
|
||||
import { getSearchItemType } from "./helper";
|
||||
import { useBlockMenuStore } from "../../../../stores/blockMenuStore";
|
||||
import { blockMenuContainerStyle } from "../style";
|
||||
import { BlockMenuFilters } from "../BlockMenuFilters/BlockMenuFilters";
|
||||
import { BlockMenuSearchContent } from "../BlockMenuSearchContent/BlockMenuSearchContent";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { NoSearchResult } from "../NoSearchResult";
|
||||
|
||||
export const BlockMenuSearch = () => {
|
||||
const {
|
||||
searchResults,
|
||||
isFetchingNextPage,
|
||||
fetchNextPage,
|
||||
hasNextPage,
|
||||
searchLoading,
|
||||
handleAddLibraryAgent,
|
||||
handleAddMarketplaceAgent,
|
||||
addingLibraryAgentId,
|
||||
addingMarketplaceAgentSlug,
|
||||
} = useBlockMenuSearch();
|
||||
const { searchQuery } = useBlockMenuStore();
|
||||
|
||||
if (searchLoading) {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
blockMenuContainerStyle,
|
||||
"flex items-center justify-center",
|
||||
)}
|
||||
>
|
||||
<LoadingSpinner className="size-13" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (searchResults.length === 0) {
|
||||
return <NoSearchResult />;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={blockMenuContainerStyle}>
|
||||
<BlockMenuFilters />
|
||||
<Text variant="body-medium">Search results</Text>
|
||||
<BlockMenuSearchContent />
|
||||
<InfiniteScroll
|
||||
isFetchingNextPage={isFetchingNextPage}
|
||||
fetchNextPage={fetchNextPage}
|
||||
hasNextPage={hasNextPage}
|
||||
loader={<LoadingSpinner className="size-13" />}
|
||||
className="space-y-2.5"
|
||||
>
|
||||
{searchResults.map((item: SearchResponseItemsItem, index: number) => {
|
||||
const { type, data } = getSearchItemType(item);
|
||||
// backend give support to these 3 types only [right now] - we need to give support to integration and ai agent types in follow up PRs
|
||||
switch (type) {
|
||||
case "store_agent":
|
||||
return (
|
||||
<MarketplaceAgentBlock
|
||||
key={index}
|
||||
slug={data.slug}
|
||||
highlightedText={searchQuery}
|
||||
title={data.agent_name}
|
||||
image_url={data.agent_image}
|
||||
creator_name={data.creator}
|
||||
number_of_runs={data.runs}
|
||||
loading={addingMarketplaceAgentSlug === data.slug}
|
||||
onClick={() =>
|
||||
handleAddMarketplaceAgent({
|
||||
creator_name: data.creator,
|
||||
slug: data.slug,
|
||||
})
|
||||
}
|
||||
/>
|
||||
);
|
||||
case "block":
|
||||
return (
|
||||
<Block
|
||||
key={index}
|
||||
title={data.name}
|
||||
highlightedText={searchQuery}
|
||||
description={data.description}
|
||||
blockData={data}
|
||||
/>
|
||||
);
|
||||
|
||||
case "library_agent":
|
||||
return (
|
||||
<UGCAgentBlock
|
||||
key={index}
|
||||
title={data.name}
|
||||
highlightedText={searchQuery}
|
||||
image_url={data.image_url}
|
||||
version={data.graph_version}
|
||||
edited_time={data.updated_at}
|
||||
isLoading={addingLibraryAgentId === data.id}
|
||||
onClick={() => handleAddLibraryAgent(data)}
|
||||
/>
|
||||
);
|
||||
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
})}
|
||||
</InfiniteScroll>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -23,19 +23,9 @@ import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { getQueryClient } from "@/lib/react-query/queryClient";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { GetV2BuilderSearchFilterAnyOfItem } from "@/app/api/__generated__/models/getV2BuilderSearchFilterAnyOfItem";
|
||||
|
||||
export const useBlockMenuSearchContent = () => {
|
||||
const {
|
||||
searchQuery,
|
||||
searchId,
|
||||
setSearchId,
|
||||
filters,
|
||||
setCreatorsList,
|
||||
creators,
|
||||
setCategoryCounts,
|
||||
} = useBlockMenuStore();
|
||||
|
||||
export const useBlockMenuSearch = () => {
|
||||
const { searchQuery, searchId, setSearchId } = useBlockMenuStore();
|
||||
const { toast } = useToast();
|
||||
const { addAgentToBuilder, addLibraryAgentToBuilder } =
|
||||
useAddAgentToBuilder();
|
||||
@@ -67,8 +57,6 @@ export const useBlockMenuSearchContent = () => {
|
||||
page_size: 8,
|
||||
search_query: searchQuery,
|
||||
search_id: searchId,
|
||||
filter: filters.length > 0 ? filters : undefined,
|
||||
by_creator: creators.length > 0 ? creators : undefined,
|
||||
},
|
||||
{
|
||||
query: { getNextPageParam: getPaginationNextPageNumber },
|
||||
@@ -110,26 +98,6 @@ export const useBlockMenuSearchContent = () => {
|
||||
}
|
||||
}, [searchQueryData, searchId, setSearchId]);
|
||||
|
||||
// from all the results, we need to get all the unique creators
|
||||
useEffect(() => {
|
||||
if (!searchQueryData?.pages?.length) {
|
||||
return;
|
||||
}
|
||||
const latestData = okData(searchQueryData.pages.at(-1));
|
||||
setCategoryCounts(
|
||||
(latestData?.total_items as Record<
|
||||
GetV2BuilderSearchFilterAnyOfItem,
|
||||
number
|
||||
>) || {
|
||||
blocks: 0,
|
||||
integrations: 0,
|
||||
marketplace_agents: 0,
|
||||
my_agents: 0,
|
||||
},
|
||||
);
|
||||
setCreatorsList(latestData?.items || []);
|
||||
}, [searchQueryData]);
|
||||
|
||||
useEffect(() => {
|
||||
if (searchId && !searchQuery) {
|
||||
resetSearchSession();
|
||||
@@ -1,108 +0,0 @@
|
||||
import { SearchResponseItemsItem } from "@/app/api/__generated__/models/searchResponseItemsItem";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { InfiniteScroll } from "@/components/contextual/InfiniteScroll/InfiniteScroll";
|
||||
import { getSearchItemType } from "./helper";
|
||||
import { MarketplaceAgentBlock } from "../MarketplaceAgentBlock";
|
||||
import { Block } from "../Block";
|
||||
import { UGCAgentBlock } from "../UGCAgentBlock";
|
||||
import { useBlockMenuSearchContent } from "./useBlockMenuSearchContent";
|
||||
import { useBlockMenuStore } from "@/app/(platform)/build/stores/blockMenuStore";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { blockMenuContainerStyle } from "../style";
|
||||
import { NoSearchResult } from "../NoSearchResult";
|
||||
|
||||
export const BlockMenuSearchContent = () => {
|
||||
const {
|
||||
searchResults,
|
||||
isFetchingNextPage,
|
||||
fetchNextPage,
|
||||
hasNextPage,
|
||||
searchLoading,
|
||||
handleAddLibraryAgent,
|
||||
handleAddMarketplaceAgent,
|
||||
addingLibraryAgentId,
|
||||
addingMarketplaceAgentSlug,
|
||||
} = useBlockMenuSearchContent();
|
||||
|
||||
const { searchQuery } = useBlockMenuStore();
|
||||
|
||||
if (searchLoading) {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
blockMenuContainerStyle,
|
||||
"flex items-center justify-center",
|
||||
)}
|
||||
>
|
||||
<LoadingSpinner className="size-13" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (searchResults.length === 0) {
|
||||
return <NoSearchResult />;
|
||||
}
|
||||
|
||||
return (
|
||||
<InfiniteScroll
|
||||
isFetchingNextPage={isFetchingNextPage}
|
||||
fetchNextPage={fetchNextPage}
|
||||
hasNextPage={hasNextPage}
|
||||
loader={<LoadingSpinner className="size-13" />}
|
||||
className="space-y-2.5"
|
||||
>
|
||||
{searchResults.map((item: SearchResponseItemsItem, index: number) => {
|
||||
const { type, data } = getSearchItemType(item);
|
||||
// backend give support to these 3 types only [right now] - we need to give support to integration and ai agent types in follow up PRs
|
||||
switch (type) {
|
||||
case "store_agent":
|
||||
return (
|
||||
<MarketplaceAgentBlock
|
||||
key={index}
|
||||
slug={data.slug}
|
||||
highlightedText={searchQuery}
|
||||
title={data.agent_name}
|
||||
image_url={data.agent_image}
|
||||
creator_name={data.creator}
|
||||
number_of_runs={data.runs}
|
||||
loading={addingMarketplaceAgentSlug === data.slug}
|
||||
onClick={() =>
|
||||
handleAddMarketplaceAgent({
|
||||
creator_name: data.creator,
|
||||
slug: data.slug,
|
||||
})
|
||||
}
|
||||
/>
|
||||
);
|
||||
case "block":
|
||||
return (
|
||||
<Block
|
||||
key={index}
|
||||
title={data.name}
|
||||
highlightedText={searchQuery}
|
||||
description={data.description}
|
||||
blockData={data}
|
||||
/>
|
||||
);
|
||||
|
||||
case "library_agent":
|
||||
return (
|
||||
<UGCAgentBlock
|
||||
key={index}
|
||||
title={data.name}
|
||||
highlightedText={searchQuery}
|
||||
image_url={data.image_url}
|
||||
version={data.graph_version}
|
||||
edited_time={data.updated_at}
|
||||
isLoading={addingLibraryAgentId === data.id}
|
||||
onClick={() => handleAddLibraryAgent(data)}
|
||||
/>
|
||||
);
|
||||
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
})}
|
||||
</InfiniteScroll>
|
||||
);
|
||||
};
|
||||
@@ -1,9 +1,7 @@
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { XIcon } from "@phosphor-icons/react";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
|
||||
import React, { ButtonHTMLAttributes, useState } from "react";
|
||||
import { X } from "lucide-react";
|
||||
import React, { ButtonHTMLAttributes } from "react";
|
||||
|
||||
interface Props extends ButtonHTMLAttributes<HTMLButtonElement> {
|
||||
selected?: boolean;
|
||||
@@ -18,51 +16,39 @@ export const FilterChip: React.FC<Props> = ({
|
||||
className,
|
||||
...rest
|
||||
}) => {
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
return (
|
||||
<AnimatePresence mode="wait">
|
||||
<Button
|
||||
onMouseEnter={() => setIsHovered(true)}
|
||||
onMouseLeave={() => setIsHovered(false)}
|
||||
<Button
|
||||
className={cn(
|
||||
"group w-fit space-x-1 rounded-[1.5rem] border border-zinc-300 bg-transparent px-[0.625rem] py-[0.375rem] shadow-none transition-transform duration-300 ease-in-out",
|
||||
"hover:border-violet-500 hover:bg-transparent focus:ring-0 disabled:cursor-not-allowed",
|
||||
selected && "border-0 bg-violet-700 hover:border",
|
||||
className,
|
||||
)}
|
||||
{...rest}
|
||||
>
|
||||
<span
|
||||
className={cn(
|
||||
"group w-fit space-x-1 rounded-[1.5rem] border border-zinc-300 bg-transparent px-[0.625rem] py-[0.375rem] shadow-none",
|
||||
"hover:border-violet-500 hover:bg-transparent focus:ring-0 disabled:cursor-not-allowed",
|
||||
selected && "border-0 bg-violet-700 hover:border",
|
||||
className,
|
||||
"font-sans text-sm font-medium leading-[1.375rem] text-zinc-600 group-hover:text-zinc-600 group-disabled:text-zinc-400",
|
||||
selected && "text-zinc-50",
|
||||
)}
|
||||
{...rest}
|
||||
>
|
||||
<span
|
||||
className={cn(
|
||||
"font-sans text-sm font-medium leading-[1.375rem] text-zinc-600 group-hover:text-zinc-600 group-disabled:text-zinc-400",
|
||||
selected && "text-zinc-50",
|
||||
{name}
|
||||
</span>
|
||||
{selected && (
|
||||
<>
|
||||
<span className="flex h-4 w-4 items-center justify-center rounded-full bg-zinc-50 transition-all duration-300 ease-in-out group-hover:hidden">
|
||||
<X
|
||||
className="h-3 w-3 rounded-full text-violet-700"
|
||||
strokeWidth={2}
|
||||
/>
|
||||
</span>
|
||||
{number !== undefined && (
|
||||
<span className="hidden h-[1.375rem] items-center rounded-[1.25rem] bg-violet-700 p-[0.375rem] text-zinc-50 transition-all duration-300 ease-in-out animate-in fade-in zoom-in group-hover:flex">
|
||||
{number > 100 ? "100+" : number}
|
||||
</span>
|
||||
)}
|
||||
>
|
||||
{name}
|
||||
</span>
|
||||
{selected && !isHovered && (
|
||||
<motion.span
|
||||
initial={{ opacity: 0.5, scale: 0.5, filter: "blur(20px)" }}
|
||||
animate={{ opacity: 1, scale: 1, filter: "blur(0px)" }}
|
||||
exit={{ opacity: 0.5, scale: 0.5, filter: "blur(20px)" }}
|
||||
transition={{ duration: 0.3, type: "spring", bounce: 0.2 }}
|
||||
className="flex h-4 w-4 items-center justify-center rounded-full bg-zinc-50"
|
||||
>
|
||||
<XIcon size={12} weight="bold" className="text-violet-700" />
|
||||
</motion.span>
|
||||
)}
|
||||
{number !== undefined && isHovered && (
|
||||
<motion.span
|
||||
initial={{ opacity: 0.5, scale: 0.5, filter: "blur(10px)" }}
|
||||
animate={{ opacity: 1, scale: 1, filter: "blur(0px)" }}
|
||||
exit={{ opacity: 0.5, scale: 0.5, filter: "blur(10px)" }}
|
||||
transition={{ duration: 0.3, type: "spring", bounce: 0.2 }}
|
||||
className="flex h-[1.375rem] items-center rounded-[1.25rem] bg-violet-700 p-[0.375rem] text-zinc-50"
|
||||
>
|
||||
{number > 100 ? "100+" : number}
|
||||
</motion.span>
|
||||
)}
|
||||
</Button>
|
||||
</AnimatePresence>
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,156 +0,0 @@
|
||||
import { FilterChip } from "../FilterChip";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { CategoryKey } from "../BlockMenuFilters/types";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { XIcon } from "@phosphor-icons/react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Separator } from "@/components/__legacy__/ui/separator";
|
||||
import { Checkbox } from "@/components/__legacy__/ui/checkbox";
|
||||
import { useFilterSheet } from "./useFilterSheet";
|
||||
import { INITIAL_CREATORS_TO_SHOW } from "./constant";
|
||||
|
||||
export function FilterSheet({
|
||||
categories,
|
||||
}: {
|
||||
categories: Array<{ key: CategoryKey; name: string }>;
|
||||
}) {
|
||||
const {
|
||||
isOpen,
|
||||
localCategories,
|
||||
localCreators,
|
||||
displayedCreatorsCount,
|
||||
handleLocalCategoryChange,
|
||||
handleToggleShowMoreCreators,
|
||||
handleLocalCreatorChange,
|
||||
handleClearFilters,
|
||||
handleCloseButton,
|
||||
handleApplyFilters,
|
||||
hasLocalActiveFilters,
|
||||
visibleCreators,
|
||||
creators,
|
||||
handleOpenFilters,
|
||||
hasActiveFilters,
|
||||
} = useFilterSheet();
|
||||
|
||||
return (
|
||||
<div className="m-0 inline w-fit p-0">
|
||||
<FilterChip
|
||||
name={hasActiveFilters() ? "Edit filters" : "All filters"}
|
||||
onClick={handleOpenFilters}
|
||||
/>
|
||||
|
||||
<AnimatePresence>
|
||||
{isOpen && (
|
||||
<motion.div
|
||||
className={cn(
|
||||
"absolute bottom-2 left-2 top-2 z-20 w-3/4 max-w-[22.5rem] space-y-4 overflow-hidden rounded-[0.75rem] bg-white pb-4 shadow-[0_4px_12px_2px_rgba(0,0,0,0.1)]",
|
||||
)}
|
||||
initial={{ x: "-100%", filter: "blur(10px)" }}
|
||||
animate={{ x: 0, filter: "blur(0px)" }}
|
||||
exit={{ x: "-110%", filter: "blur(10px)" }}
|
||||
transition={{ duration: 0.4, type: "spring", bounce: 0.2 }}
|
||||
>
|
||||
{/* Top section */}
|
||||
<div className="flex items-center justify-between px-5 pt-4">
|
||||
<Text variant="body">Filters</Text>
|
||||
<Button
|
||||
className="p-0"
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={handleCloseButton}
|
||||
>
|
||||
<XIcon size={20} />
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<Separator className="h-[1px] w-full text-zinc-300" />
|
||||
|
||||
{/* Category section */}
|
||||
<div className="space-y-4 px-5">
|
||||
<Text variant="large">Categories</Text>
|
||||
<div className="space-y-2">
|
||||
{categories.map((category) => (
|
||||
<div
|
||||
key={category.key}
|
||||
className="flex items-center space-x-2"
|
||||
>
|
||||
<Checkbox
|
||||
id={category.key}
|
||||
checked={localCategories.includes(category.key)}
|
||||
onCheckedChange={() =>
|
||||
handleLocalCategoryChange(category.key)
|
||||
}
|
||||
className="border border-[#D4D4D4] shadow-none data-[state=checked]:border-none data-[state=checked]:bg-violet-700 data-[state=checked]:text-white"
|
||||
/>
|
||||
<label
|
||||
htmlFor={category.key}
|
||||
className="font-sans text-sm leading-[1.375rem] text-zinc-600"
|
||||
>
|
||||
{category.name}
|
||||
</label>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Created by section */}
|
||||
<div className="space-y-4 px-5">
|
||||
<p className="font-sans text-base font-medium text-zinc-800">
|
||||
Created by
|
||||
</p>
|
||||
<div className="space-y-2">
|
||||
{visibleCreators.map((creator, i) => (
|
||||
<div key={i} className="flex items-center space-x-2">
|
||||
<Checkbox
|
||||
id={`creator-${creator}`}
|
||||
checked={localCreators.includes(creator)}
|
||||
onCheckedChange={() => handleLocalCreatorChange(creator)}
|
||||
className="border border-[#D4D4D4] shadow-none data-[state=checked]:border-none data-[state=checked]:bg-violet-700 data-[state=checked]:text-white"
|
||||
/>
|
||||
<label
|
||||
htmlFor={`creator-${creator}`}
|
||||
className="font-sans text-sm leading-[1.375rem] text-zinc-600"
|
||||
>
|
||||
{creator}
|
||||
</label>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
{creators.length > INITIAL_CREATORS_TO_SHOW && (
|
||||
<Button
|
||||
variant={"link"}
|
||||
className="m-0 p-0 font-sans text-sm font-medium leading-[1.375rem] text-zinc-800 underline hover:text-zinc-600"
|
||||
onClick={handleToggleShowMoreCreators}
|
||||
>
|
||||
{displayedCreatorsCount < creators.length ? "More" : "Less"}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Footer section */}
|
||||
<div className="fixed bottom-0 flex w-full justify-between gap-3 border-t border-zinc-200 bg-white px-5 py-3">
|
||||
<Button
|
||||
size="small"
|
||||
variant={"outline"}
|
||||
onClick={handleClearFilters}
|
||||
className="rounded-[8px] px-2 py-1.5"
|
||||
>
|
||||
Clear
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
size="small"
|
||||
onClick={handleApplyFilters}
|
||||
disabled={!hasLocalActiveFilters()}
|
||||
className="rounded-[8px] px-2 py-1.5"
|
||||
>
|
||||
Apply filters
|
||||
</Button>
|
||||
</div>
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
export const INITIAL_CREATORS_TO_SHOW = 5;
|
||||
@@ -1,100 +0,0 @@
|
||||
import { useBlockMenuStore } from "@/app/(platform)/build/stores/blockMenuStore";
|
||||
import { useState } from "react";
|
||||
import { INITIAL_CREATORS_TO_SHOW } from "./constant";
|
||||
import { GetV2BuilderSearchFilterAnyOfItem } from "@/app/api/__generated__/models/getV2BuilderSearchFilterAnyOfItem";
|
||||
|
||||
export const useFilterSheet = () => {
|
||||
const { filters, creators_list, creators, setFilters, setCreators } =
|
||||
useBlockMenuStore();
|
||||
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const [localCategories, setLocalCategories] =
|
||||
useState<GetV2BuilderSearchFilterAnyOfItem[]>(filters);
|
||||
const [localCreators, setLocalCreators] = useState<string[]>(creators);
|
||||
const [displayedCreatorsCount, setDisplayedCreatorsCount] = useState(
|
||||
INITIAL_CREATORS_TO_SHOW,
|
||||
);
|
||||
|
||||
const handleLocalCategoryChange = (
|
||||
category: GetV2BuilderSearchFilterAnyOfItem,
|
||||
) => {
|
||||
setLocalCategories((prev) => {
|
||||
if (prev.includes(category)) {
|
||||
return prev.filter((c) => c !== category);
|
||||
}
|
||||
return [...prev, category];
|
||||
});
|
||||
};
|
||||
|
||||
const hasActiveFilters = () => {
|
||||
return filters.length > 0 || creators.length > 0;
|
||||
};
|
||||
|
||||
const handleToggleShowMoreCreators = () => {
|
||||
if (displayedCreatorsCount < creators.length) {
|
||||
setDisplayedCreatorsCount(creators.length);
|
||||
} else {
|
||||
setDisplayedCreatorsCount(INITIAL_CREATORS_TO_SHOW);
|
||||
}
|
||||
};
|
||||
|
||||
const handleLocalCreatorChange = (creator: string) => {
|
||||
setLocalCreators((prev) => {
|
||||
if (prev.includes(creator)) {
|
||||
return prev.filter((c) => c !== creator);
|
||||
}
|
||||
return [...prev, creator];
|
||||
});
|
||||
};
|
||||
|
||||
const handleClearFilters = () => {
|
||||
setLocalCategories([]);
|
||||
setLocalCreators([]);
|
||||
setDisplayedCreatorsCount(INITIAL_CREATORS_TO_SHOW);
|
||||
};
|
||||
|
||||
const handleCloseButton = () => {
|
||||
setIsOpen(false);
|
||||
setLocalCategories(filters);
|
||||
setLocalCreators(creators);
|
||||
setDisplayedCreatorsCount(INITIAL_CREATORS_TO_SHOW);
|
||||
};
|
||||
|
||||
const handleApplyFilters = () => {
|
||||
setFilters(localCategories);
|
||||
setCreators(localCreators);
|
||||
setIsOpen(false);
|
||||
};
|
||||
|
||||
const handleOpenFilters = () => {
|
||||
setIsOpen(true);
|
||||
setLocalCategories(filters);
|
||||
setLocalCreators(creators);
|
||||
};
|
||||
|
||||
const hasLocalActiveFilters = () => {
|
||||
return localCategories.length > 0 || localCreators.length > 0;
|
||||
};
|
||||
|
||||
const visibleCreators = creators_list.slice(0, displayedCreatorsCount);
|
||||
|
||||
return {
|
||||
creators,
|
||||
isOpen,
|
||||
setIsOpen,
|
||||
localCategories,
|
||||
localCreators,
|
||||
displayedCreatorsCount,
|
||||
setDisplayedCreatorsCount,
|
||||
handleLocalCategoryChange,
|
||||
handleToggleShowMoreCreators,
|
||||
handleLocalCreatorChange,
|
||||
handleClearFilters,
|
||||
handleCloseButton,
|
||||
handleOpenFilters,
|
||||
handleApplyFilters,
|
||||
hasLocalActiveFilters,
|
||||
visibleCreators,
|
||||
hasActiveFilters,
|
||||
};
|
||||
};
|
||||
@@ -1,9 +1,9 @@
|
||||
import type { OutputMetadata } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers";
|
||||
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
|
||||
import {
|
||||
globalRegistry,
|
||||
OutputActions,
|
||||
OutputItem,
|
||||
} from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/OutputRenderers";
|
||||
} from "@/components/contextual/OutputRenderers";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { beautifyString } from "@/lib/utils";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
|
||||
@@ -3,7 +3,6 @@ import {
|
||||
CustomNodeData,
|
||||
} from "@/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode";
|
||||
import { NodeTableInput } from "@/app/(platform)/build/components/legacy-builder/NodeTableInput";
|
||||
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs";
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import { Calendar } from "@/components/__legacy__/ui/calendar";
|
||||
import { LocalValuedInput } from "@/components/__legacy__/ui/input";
|
||||
@@ -28,6 +27,7 @@ import {
|
||||
SelectValue,
|
||||
} from "@/components/__legacy__/ui/select";
|
||||
import { Switch } from "@/components/atoms/Switch/Switch";
|
||||
import { CredentialsInput } from "@/components/contextual/CredentialsInputs/CredentialsInputs";
|
||||
import { GoogleDrivePickerInput } from "@/components/contextual/GoogleDrivePicker/GoogleDrivePickerInput";
|
||||
import {
|
||||
BlockIOArraySubSchema,
|
||||
|
||||
@@ -1,30 +1,12 @@
|
||||
import { create } from "zustand";
|
||||
import { DefaultStateType } from "../components/NewControlPanel/NewBlockMenu/types";
|
||||
import { SearchResponseItemsItem } from "@/app/api/__generated__/models/searchResponseItemsItem";
|
||||
import { getSearchItemType } from "../components/NewControlPanel/NewBlockMenu/BlockMenuSearchContent/helper";
|
||||
import { StoreAgent } from "@/app/api/__generated__/models/storeAgent";
|
||||
import { GetV2BuilderSearchFilterAnyOfItem } from "@/app/api/__generated__/models/getV2BuilderSearchFilterAnyOfItem";
|
||||
|
||||
type BlockMenuStore = {
|
||||
searchQuery: string;
|
||||
searchId: string | undefined;
|
||||
defaultState: DefaultStateType;
|
||||
integration: string | undefined;
|
||||
filters: GetV2BuilderSearchFilterAnyOfItem[];
|
||||
creators: string[];
|
||||
creators_list: string[];
|
||||
categoryCounts: Record<GetV2BuilderSearchFilterAnyOfItem, number>;
|
||||
|
||||
setCategoryCounts: (
|
||||
counts: Record<GetV2BuilderSearchFilterAnyOfItem, number>,
|
||||
) => void;
|
||||
setCreatorsList: (searchData: SearchResponseItemsItem[]) => void;
|
||||
addCreator: (creator: string) => void;
|
||||
setCreators: (creators: string[]) => void;
|
||||
removeCreator: (creator: string) => void;
|
||||
addFilter: (filter: GetV2BuilderSearchFilterAnyOfItem) => void;
|
||||
setFilters: (filters: GetV2BuilderSearchFilterAnyOfItem[]) => void;
|
||||
removeFilter: (filter: GetV2BuilderSearchFilterAnyOfItem) => void;
|
||||
setSearchQuery: (query: string) => void;
|
||||
setSearchId: (id: string | undefined) => void;
|
||||
setDefaultState: (state: DefaultStateType) => void;
|
||||
@@ -37,44 +19,11 @@ export const useBlockMenuStore = create<BlockMenuStore>((set) => ({
|
||||
searchId: undefined,
|
||||
defaultState: DefaultStateType.SUGGESTION,
|
||||
integration: undefined,
|
||||
filters: [],
|
||||
creators: [], // creator filters that are applied to the search results
|
||||
creators_list: [], // all creators that are available to filter by
|
||||
categoryCounts: {
|
||||
blocks: 0,
|
||||
integrations: 0,
|
||||
marketplace_agents: 0,
|
||||
my_agents: 0,
|
||||
},
|
||||
|
||||
setCategoryCounts: (counts) => set({ categoryCounts: counts }),
|
||||
setCreatorsList: (searchData) => {
|
||||
const marketplaceAgents = searchData.filter((item) => {
|
||||
return getSearchItemType(item).type === "store_agent";
|
||||
}) as StoreAgent[];
|
||||
|
||||
const newCreators = marketplaceAgents.map((agent) => agent.creator);
|
||||
|
||||
set((state) => ({
|
||||
creators_list: Array.from(
|
||||
new Set([...state.creators_list, ...newCreators]),
|
||||
),
|
||||
}));
|
||||
},
|
||||
setCreators: (creators) => set({ creators }),
|
||||
setFilters: (filters) => set({ filters }),
|
||||
setSearchQuery: (query) => set({ searchQuery: query }),
|
||||
setSearchId: (id) => set({ searchId: id }),
|
||||
setDefaultState: (state) => set({ defaultState: state }),
|
||||
setIntegration: (integration) => set({ integration }),
|
||||
addFilter: (filter) =>
|
||||
set((state) => ({ filters: [...state.filters, filter] })),
|
||||
removeFilter: (filter) =>
|
||||
set((state) => ({ filters: state.filters.filter((f) => f !== filter) })),
|
||||
addCreator: (creator) =>
|
||||
set((state) => ({ creators: [...state.creators, creator] })),
|
||||
removeCreator: (creator) =>
|
||||
set((state) => ({ creators: state.creators.filter((c) => c !== creator) })),
|
||||
reset: () =>
|
||||
set({
|
||||
searchQuery: "",
|
||||
|
||||
@@ -68,9 +68,6 @@ type NodeStore = {
|
||||
clearAllNodeErrors: () => void; // Add this
|
||||
|
||||
syncHardcodedValuesWithHandleIds: (nodeId: string) => void;
|
||||
|
||||
// Credentials optional helpers
|
||||
setCredentialsOptional: (nodeId: string, optional: boolean) => void;
|
||||
};
|
||||
|
||||
export const useNodeStore = create<NodeStore>((set, get) => ({
|
||||
@@ -229,9 +226,6 @@ export const useNodeStore = create<NodeStore>((set, get) => ({
|
||||
...(node.data.metadata?.customized_name !== undefined && {
|
||||
customized_name: node.data.metadata.customized_name,
|
||||
}),
|
||||
...(node.data.metadata?.credentials_optional !== undefined && {
|
||||
credentials_optional: node.data.metadata.credentials_optional,
|
||||
}),
|
||||
},
|
||||
};
|
||||
},
|
||||
@@ -348,30 +342,4 @@ export const useNodeStore = create<NodeStore>((set, get) => ({
|
||||
}));
|
||||
}
|
||||
},
|
||||
|
||||
setCredentialsOptional: (nodeId: string, optional: boolean) => {
|
||||
set((state) => ({
|
||||
nodes: state.nodes.map((n) =>
|
||||
n.id === nodeId
|
||||
? {
|
||||
...n,
|
||||
data: {
|
||||
...n.data,
|
||||
metadata: {
|
||||
...n.data.metadata,
|
||||
credentials_optional: optional,
|
||||
},
|
||||
},
|
||||
}
|
||||
: n,
|
||||
),
|
||||
}));
|
||||
|
||||
const newState = {
|
||||
nodes: get().nodes,
|
||||
edges: useEdgeStore.getState().edges,
|
||||
};
|
||||
|
||||
useHistoryStore.getState().pushState(newState);
|
||||
},
|
||||
}));
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
import { useState, useCallback, useRef, useMemo } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { useChatStream } from "@/app/(platform)/chat/useChatStream";
|
||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||
import type { ChatMessageData } from "@/app/(platform)/chat/components/ChatMessage/useChatMessage";
|
||||
import {
|
||||
parseToolResponse,
|
||||
isValidMessage,
|
||||
isToolCallArray,
|
||||
createUserMessage,
|
||||
filterAuthMessages,
|
||||
} from "./helpers";
|
||||
import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
|
||||
|
||||
interface UseChatContainerArgs {
|
||||
sessionId: string | null;
|
||||
initialMessages: SessionDetailResponse["messages"];
|
||||
onRefreshSession: () => Promise<void>;
|
||||
}
|
||||
|
||||
export function useChatContainer({
|
||||
sessionId,
|
||||
initialMessages,
|
||||
}: UseChatContainerArgs) {
|
||||
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
||||
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
||||
const [hasTextChunks, setHasTextChunks] = useState(false);
|
||||
const streamingChunksRef = useRef<string[]>([]);
|
||||
const { error, sendMessage: sendStreamMessage } = useChatStream();
|
||||
const isStreaming = hasTextChunks;
|
||||
|
||||
const allMessages = useMemo(() => {
|
||||
const processedInitialMessages = initialMessages
|
||||
.filter((msg: Record<string, unknown>) => {
|
||||
if (!isValidMessage(msg)) {
|
||||
console.warn("Invalid message structure from backend:", msg);
|
||||
return false;
|
||||
}
|
||||
const content = String(msg.content || "").trim();
|
||||
const toolCalls = msg.tool_calls;
|
||||
return (
|
||||
content.length > 0 ||
|
||||
(toolCalls && Array.isArray(toolCalls) && toolCalls.length > 0)
|
||||
);
|
||||
})
|
||||
.map((msg: Record<string, unknown>) => {
|
||||
const content = String(msg.content || "");
|
||||
const role = String(msg.role || "assistant").toLowerCase();
|
||||
const toolCalls = msg.tool_calls;
|
||||
if (
|
||||
role === "assistant" &&
|
||||
toolCalls &&
|
||||
isToolCallArray(toolCalls) &&
|
||||
toolCalls.length > 0
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
if (role === "tool") {
|
||||
const timestamp = msg.timestamp
|
||||
? new Date(msg.timestamp as string)
|
||||
: undefined;
|
||||
const toolResponse = parseToolResponse(
|
||||
content,
|
||||
(msg.tool_call_id as string) || "",
|
||||
"unknown",
|
||||
timestamp,
|
||||
);
|
||||
if (!toolResponse) {
|
||||
return null;
|
||||
}
|
||||
return toolResponse;
|
||||
}
|
||||
return {
|
||||
type: "message",
|
||||
role: role as "user" | "assistant" | "system",
|
||||
content,
|
||||
timestamp: msg.timestamp
|
||||
? new Date(msg.timestamp as string)
|
||||
: undefined,
|
||||
};
|
||||
})
|
||||
.filter((msg): msg is ChatMessageData => msg !== null);
|
||||
|
||||
return [...processedInitialMessages, ...messages];
|
||||
}, [initialMessages, messages]);
|
||||
|
||||
const sendMessage = useCallback(
|
||||
async function sendMessage(content: string, isUserMessage: boolean = true) {
|
||||
if (!sessionId) {
|
||||
console.error("Cannot send message: no session ID");
|
||||
return;
|
||||
}
|
||||
if (isUserMessage) {
|
||||
const userMessage = createUserMessage(content);
|
||||
setMessages((prev) => [...filterAuthMessages(prev), userMessage]);
|
||||
} else {
|
||||
setMessages((prev) => filterAuthMessages(prev));
|
||||
}
|
||||
setStreamingChunks([]);
|
||||
streamingChunksRef.current = [];
|
||||
setHasTextChunks(false);
|
||||
const dispatcher = createStreamEventDispatcher({
|
||||
setHasTextChunks,
|
||||
setStreamingChunks,
|
||||
streamingChunksRef,
|
||||
setMessages,
|
||||
sessionId,
|
||||
});
|
||||
try {
|
||||
await sendStreamMessage(sessionId, content, dispatcher, isUserMessage);
|
||||
} catch (err) {
|
||||
console.error("Failed to send message:", err);
|
||||
const errorMessage =
|
||||
err instanceof Error ? err.message : "Failed to send message";
|
||||
toast.error("Failed to send message", {
|
||||
description: errorMessage,
|
||||
});
|
||||
}
|
||||
},
|
||||
[sessionId, sendStreamMessage],
|
||||
);
|
||||
|
||||
return {
|
||||
messages: allMessages,
|
||||
streamingChunks,
|
||||
isStreaming,
|
||||
error,
|
||||
sendMessage,
|
||||
};
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user