mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-15 01:58:23 -05:00
Compare commits
1 Commits
dev
...
copilot-ba
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5d3903b6fb |
@@ -1,37 +0,0 @@
|
|||||||
{
|
|
||||||
"worktreeCopyPatterns": [
|
|
||||||
".env*",
|
|
||||||
".vscode/**",
|
|
||||||
".auth/**",
|
|
||||||
".claude/**",
|
|
||||||
"autogpt_platform/.env*",
|
|
||||||
"autogpt_platform/backend/.env*",
|
|
||||||
"autogpt_platform/frontend/.env*",
|
|
||||||
"autogpt_platform/frontend/.auth/**",
|
|
||||||
"autogpt_platform/db/docker/.env*"
|
|
||||||
],
|
|
||||||
"worktreeCopyIgnores": [
|
|
||||||
"**/node_modules/**",
|
|
||||||
"**/dist/**",
|
|
||||||
"**/.git/**",
|
|
||||||
"**/Thumbs.db",
|
|
||||||
"**/.DS_Store",
|
|
||||||
"**/.next/**",
|
|
||||||
"**/__pycache__/**",
|
|
||||||
"**/.ruff_cache/**",
|
|
||||||
"**/.pytest_cache/**",
|
|
||||||
"**/*.pyc",
|
|
||||||
"**/playwright-report/**",
|
|
||||||
"**/logs/**",
|
|
||||||
"**/site/**"
|
|
||||||
],
|
|
||||||
"worktreePathTemplate": "$BASE_PATH.worktree",
|
|
||||||
"postCreateCmd": [
|
|
||||||
"cd autogpt_platform/autogpt_libs && poetry install",
|
|
||||||
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
|
|
||||||
"cd autogpt_platform/frontend && pnpm install",
|
|
||||||
"cd docs && pip install -r requirements.txt"
|
|
||||||
],
|
|
||||||
"terminalCommand": "code .",
|
|
||||||
"deleteBranchWithWorktree": false
|
|
||||||
}
|
|
||||||
@@ -16,7 +16,6 @@
|
|||||||
!autogpt_platform/backend/poetry.lock
|
!autogpt_platform/backend/poetry.lock
|
||||||
!autogpt_platform/backend/README.md
|
!autogpt_platform/backend/README.md
|
||||||
!autogpt_platform/backend/.env
|
!autogpt_platform/backend/.env
|
||||||
!autogpt_platform/backend/gen_prisma_types_stub.py
|
|
||||||
|
|
||||||
# Platform - Market
|
# Platform - Market
|
||||||
!autogpt_platform/market/market/
|
!autogpt_platform/market/market/
|
||||||
|
|||||||
2
.github/workflows/claude-dependabot.yml
vendored
2
.github/workflows/claude-dependabot.yml
vendored
@@ -74,7 +74,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Generate Prisma Client
|
- name: Generate Prisma Client
|
||||||
working-directory: autogpt_platform/backend
|
working-directory: autogpt_platform/backend
|
||||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
run: poetry run prisma generate
|
||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
|
|||||||
2
.github/workflows/claude.yml
vendored
2
.github/workflows/claude.yml
vendored
@@ -90,7 +90,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Generate Prisma Client
|
- name: Generate Prisma Client
|
||||||
working-directory: autogpt_platform/backend
|
working-directory: autogpt_platform/backend
|
||||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
run: poetry run prisma generate
|
||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
|
|||||||
12
.github/workflows/copilot-setup-steps.yml
vendored
12
.github/workflows/copilot-setup-steps.yml
vendored
@@ -72,7 +72,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Generate Prisma Client
|
- name: Generate Prisma Client
|
||||||
working-directory: autogpt_platform/backend
|
working-directory: autogpt_platform/backend
|
||||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
run: poetry run prisma generate
|
||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
@@ -108,16 +108,6 @@ jobs:
|
|||||||
# run: pnpm playwright install --with-deps chromium
|
# run: pnpm playwright install --with-deps chromium
|
||||||
|
|
||||||
# Docker setup for development environment
|
# Docker setup for development environment
|
||||||
- name: Free up disk space
|
|
||||||
run: |
|
|
||||||
# Remove large unused tools to free disk space for Docker builds
|
|
||||||
sudo rm -rf /usr/share/dotnet
|
|
||||||
sudo rm -rf /usr/local/lib/android
|
|
||||||
sudo rm -rf /opt/ghc
|
|
||||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
|
||||||
sudo docker system prune -af
|
|
||||||
df -h
|
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
|||||||
4
.github/workflows/platform-backend-ci.yml
vendored
4
.github/workflows/platform-backend-ci.yml
vendored
@@ -134,7 +134,7 @@ jobs:
|
|||||||
run: poetry install
|
run: poetry install
|
||||||
|
|
||||||
- name: Generate Prisma Client
|
- name: Generate Prisma Client
|
||||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
run: poetry run prisma generate
|
||||||
|
|
||||||
- id: supabase
|
- id: supabase
|
||||||
name: Start Supabase
|
name: Start Supabase
|
||||||
@@ -176,7 +176,7 @@ jobs:
|
|||||||
}
|
}
|
||||||
|
|
||||||
- name: Run Database Migrations
|
- name: Run Database Migrations
|
||||||
run: poetry run prisma migrate deploy
|
run: poetry run prisma migrate dev --name updates
|
||||||
env:
|
env:
|
||||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||||
|
|||||||
9
.github/workflows/platform-frontend-ci.yml
vendored
9
.github/workflows/platform-frontend-ci.yml
vendored
@@ -11,7 +11,6 @@ on:
|
|||||||
- ".github/workflows/platform-frontend-ci.yml"
|
- ".github/workflows/platform-frontend-ci.yml"
|
||||||
- "autogpt_platform/frontend/**"
|
- "autogpt_platform/frontend/**"
|
||||||
merge_group:
|
merge_group:
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.event_name == 'merge_group' && format('merge-queue-{0}', github.ref) || format('{0}-{1}', github.ref, github.event.pull_request.number || github.sha) }}
|
group: ${{ github.workflow }}-${{ github.event_name == 'merge_group' && format('merge-queue-{0}', github.ref) || format('{0}-{1}', github.ref, github.event.pull_request.number || github.sha) }}
|
||||||
@@ -152,14 +151,6 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
cp ../.env.default ../.env
|
cp ../.env.default ../.env
|
||||||
|
|
||||||
- name: Copy backend .env and set OpenAI API key
|
|
||||||
run: |
|
|
||||||
cp ../backend/.env.default ../backend/.env
|
|
||||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
|
||||||
env:
|
|
||||||
# Used by E2E test data script to generate embeddings for approved store agents
|
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
|||||||
@@ -6,13 +6,11 @@ start-core:
|
|||||||
|
|
||||||
# Stop core services
|
# Stop core services
|
||||||
stop-core:
|
stop-core:
|
||||||
docker compose stop deps
|
docker compose stop
|
||||||
|
|
||||||
reset-db:
|
reset-db:
|
||||||
|
docker compose stop db
|
||||||
rm -rf db/docker/volumes/db/data
|
rm -rf db/docker/volumes/db/data
|
||||||
cd backend && poetry run prisma migrate deploy
|
|
||||||
cd backend && poetry run prisma generate
|
|
||||||
cd backend && poetry run gen-prisma-stub
|
|
||||||
|
|
||||||
# View logs for core services
|
# View logs for core services
|
||||||
logs-core:
|
logs-core:
|
||||||
@@ -34,7 +32,6 @@ init-env:
|
|||||||
migrate:
|
migrate:
|
||||||
cd backend && poetry run prisma migrate deploy
|
cd backend && poetry run prisma migrate deploy
|
||||||
cd backend && poetry run prisma generate
|
cd backend && poetry run prisma generate
|
||||||
cd backend && poetry run gen-prisma-stub
|
|
||||||
|
|
||||||
run-backend:
|
run-backend:
|
||||||
cd backend && poetry run app
|
cd backend && poetry run app
|
||||||
@@ -60,4 +57,4 @@ help:
|
|||||||
@echo " run-backend - Run the backend FastAPI server"
|
@echo " run-backend - Run the backend FastAPI server"
|
||||||
@echo " run-frontend - Run the frontend Next.js development server"
|
@echo " run-frontend - Run the frontend Next.js development server"
|
||||||
@echo " test-data - Run the test data creator"
|
@echo " test-data - Run the test data creator"
|
||||||
@echo " load-store-agents - Load store agents from agents/ folder into test database"
|
@echo " load-store-agents - Load store agents from agents/ folder into test database"
|
||||||
|
|||||||
1
autogpt_platform/backend/.gitignore
vendored
1
autogpt_platform/backend/.gitignore
vendored
@@ -18,4 +18,3 @@ load-tests/results/
|
|||||||
load-tests/*.json
|
load-tests/*.json
|
||||||
load-tests/*.log
|
load-tests/*.log
|
||||||
load-tests/node_modules/*
|
load-tests/node_modules/*
|
||||||
migrations/*/rollback*.sql
|
|
||||||
|
|||||||
@@ -48,8 +48,7 @@ RUN poetry install --no-ansi --no-root
|
|||||||
# Generate Prisma client
|
# Generate Prisma client
|
||||||
COPY autogpt_platform/backend/schema.prisma ./
|
COPY autogpt_platform/backend/schema.prisma ./
|
||||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
RUN poetry run prisma generate
|
||||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
|
||||||
|
|
||||||
FROM debian:13-slim AS server_dependencies
|
FROM debian:13-slim AS server_dependencies
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,11 @@ class ChatConfig(BaseSettings):
|
|||||||
|
|
||||||
# OpenAI API Configuration
|
# OpenAI API Configuration
|
||||||
model: str = Field(
|
model: str = Field(
|
||||||
default="qwen/qwen3-235b-a22b-2507", description="Default model to use"
|
default="anthropic/claude-opus-4.5", description="Default model to use"
|
||||||
|
)
|
||||||
|
title_model: str = Field(
|
||||||
|
default="openai/gpt-4o-mini",
|
||||||
|
description="Model to use for generating session titles (should be fast/cheap)",
|
||||||
)
|
)
|
||||||
api_key: str | None = Field(default=None, description="OpenAI API key")
|
api_key: str | None = Field(default=None, description="OpenAI API key")
|
||||||
base_url: str | None = Field(
|
base_url: str | None = Field(
|
||||||
@@ -72,8 +76,31 @@ class ChatConfig(BaseSettings):
|
|||||||
v = "https://openrouter.ai/api/v1"
|
v = "https://openrouter.ai/api/v1"
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
# Prompt paths for different contexts
|
||||||
|
PROMPT_PATHS: dict[str, str] = {
|
||||||
|
"default": "prompts/chat_system.md",
|
||||||
|
"onboarding": "prompts/onboarding_system.md",
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_system_prompt_for_type(
|
||||||
|
self, prompt_type: str = "default", **template_vars
|
||||||
|
) -> str:
|
||||||
|
"""Load and render a system prompt by type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_type: The type of prompt to load ("default" or "onboarding")
|
||||||
|
**template_vars: Variables to substitute in the template
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rendered system prompt string
|
||||||
|
"""
|
||||||
|
prompt_path_str = self.PROMPT_PATHS.get(
|
||||||
|
prompt_type, self.PROMPT_PATHS["default"]
|
||||||
|
)
|
||||||
|
return self._load_prompt_from_path(prompt_path_str, **template_vars)
|
||||||
|
|
||||||
def get_system_prompt(self, **template_vars) -> str:
|
def get_system_prompt(self, **template_vars) -> str:
|
||||||
"""Load and render the system prompt from file.
|
"""Load and render the default system prompt from file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**template_vars: Variables to substitute in the template
|
**template_vars: Variables to substitute in the template
|
||||||
@@ -82,9 +109,21 @@ class ChatConfig(BaseSettings):
|
|||||||
Rendered system prompt string
|
Rendered system prompt string
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
return self._load_prompt_from_path(self.system_prompt_path, **template_vars)
|
||||||
|
|
||||||
|
def _load_prompt_from_path(self, prompt_path_str: str, **template_vars) -> str:
|
||||||
|
"""Load and render a system prompt from a given path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_path_str: Path to the prompt file relative to chat module
|
||||||
|
**template_vars: Variables to substitute in the template
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rendered system prompt string
|
||||||
|
"""
|
||||||
# Get the path relative to this module
|
# Get the path relative to this module
|
||||||
module_dir = Path(__file__).parent
|
module_dir = Path(__file__).parent
|
||||||
prompt_path = module_dir / self.system_prompt_path
|
prompt_path = module_dir / prompt_path_str
|
||||||
|
|
||||||
# Check for .j2 extension first (Jinja2 template)
|
# Check for .j2 extension first (Jinja2 template)
|
||||||
j2_path = Path(str(prompt_path) + ".j2")
|
j2_path = Path(str(prompt_path) + ".j2")
|
||||||
|
|||||||
215
autogpt_platform/backend/backend/api/features/chat/db.py
Normal file
215
autogpt_platform/backend/backend/api/features/chat/db.py
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
"""Database operations for chat sessions."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from prisma.models import ChatMessage as PrismaChatMessage
|
||||||
|
from prisma.models import ChatSession as PrismaChatSession
|
||||||
|
from prisma.types import (
|
||||||
|
ChatMessageCreateInput,
|
||||||
|
ChatSessionCreateInput,
|
||||||
|
ChatSessionUpdateInput,
|
||||||
|
)
|
||||||
|
|
||||||
|
from backend.util.json import SafeJson
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
|
||||||
|
"""Get a chat session by ID from the database."""
|
||||||
|
session = await PrismaChatSession.prisma().find_unique(
|
||||||
|
where={"id": session_id},
|
||||||
|
include={"Messages": True},
|
||||||
|
)
|
||||||
|
if session and session.Messages:
|
||||||
|
# Sort messages by sequence in Python since Prisma doesn't support order_by in include
|
||||||
|
session.Messages.sort(key=lambda m: m.sequence)
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
async def create_chat_session(
|
||||||
|
session_id: str,
|
||||||
|
user_id: str | None,
|
||||||
|
) -> PrismaChatSession:
|
||||||
|
"""Create a new chat session in the database."""
|
||||||
|
data = ChatSessionCreateInput(
|
||||||
|
id=session_id,
|
||||||
|
userId=user_id,
|
||||||
|
credentials=SafeJson({}),
|
||||||
|
successfulAgentRuns=SafeJson({}),
|
||||||
|
successfulAgentSchedules=SafeJson({}),
|
||||||
|
)
|
||||||
|
return await PrismaChatSession.prisma().create(
|
||||||
|
data=data,
|
||||||
|
include={"Messages": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def update_chat_session(
|
||||||
|
session_id: str,
|
||||||
|
credentials: dict[str, Any] | None = None,
|
||||||
|
successful_agent_runs: dict[str, Any] | None = None,
|
||||||
|
successful_agent_schedules: dict[str, Any] | None = None,
|
||||||
|
total_prompt_tokens: int | None = None,
|
||||||
|
total_completion_tokens: int | None = None,
|
||||||
|
title: str | None = None,
|
||||||
|
) -> PrismaChatSession | None:
|
||||||
|
"""Update a chat session's metadata."""
|
||||||
|
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
|
||||||
|
|
||||||
|
if credentials is not None:
|
||||||
|
data["credentials"] = SafeJson(credentials)
|
||||||
|
if successful_agent_runs is not None:
|
||||||
|
data["successfulAgentRuns"] = SafeJson(successful_agent_runs)
|
||||||
|
if successful_agent_schedules is not None:
|
||||||
|
data["successfulAgentSchedules"] = SafeJson(successful_agent_schedules)
|
||||||
|
if total_prompt_tokens is not None:
|
||||||
|
data["totalPromptTokens"] = total_prompt_tokens
|
||||||
|
if total_completion_tokens is not None:
|
||||||
|
data["totalCompletionTokens"] = total_completion_tokens
|
||||||
|
if title is not None:
|
||||||
|
data["title"] = title
|
||||||
|
|
||||||
|
session = await PrismaChatSession.prisma().update(
|
||||||
|
where={"id": session_id},
|
||||||
|
data=data,
|
||||||
|
include={"Messages": True},
|
||||||
|
)
|
||||||
|
if session and session.Messages:
|
||||||
|
session.Messages.sort(key=lambda m: m.sequence)
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
async def add_chat_message(
|
||||||
|
session_id: str,
|
||||||
|
role: str,
|
||||||
|
sequence: int,
|
||||||
|
content: str | None = None,
|
||||||
|
name: str | None = None,
|
||||||
|
tool_call_id: str | None = None,
|
||||||
|
refusal: str | None = None,
|
||||||
|
tool_calls: list[dict[str, Any]] | None = None,
|
||||||
|
function_call: dict[str, Any] | None = None,
|
||||||
|
) -> PrismaChatMessage:
|
||||||
|
"""Add a message to a chat session."""
|
||||||
|
# Build the input dict dynamically - only include optional fields when they
|
||||||
|
# have values, as Prisma TypedDict validation fails when optional fields
|
||||||
|
# are explicitly set to None
|
||||||
|
data: dict[str, Any] = {
|
||||||
|
"Session": {"connect": {"id": session_id}},
|
||||||
|
"role": role,
|
||||||
|
"sequence": sequence,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional string fields
|
||||||
|
if content is not None:
|
||||||
|
data["content"] = content
|
||||||
|
if name is not None:
|
||||||
|
data["name"] = name
|
||||||
|
if tool_call_id is not None:
|
||||||
|
data["toolCallId"] = tool_call_id
|
||||||
|
if refusal is not None:
|
||||||
|
data["refusal"] = refusal
|
||||||
|
|
||||||
|
# Add optional JSON fields only when they have values
|
||||||
|
if tool_calls is not None:
|
||||||
|
data["toolCalls"] = SafeJson(tool_calls)
|
||||||
|
if function_call is not None:
|
||||||
|
data["functionCall"] = SafeJson(function_call)
|
||||||
|
|
||||||
|
# Update session's updatedAt timestamp
|
||||||
|
await PrismaChatSession.prisma().update(
|
||||||
|
where={"id": session_id},
|
||||||
|
data={"updatedAt": datetime.now(UTC)},
|
||||||
|
)
|
||||||
|
|
||||||
|
return await PrismaChatMessage.prisma().create(
|
||||||
|
data=cast(ChatMessageCreateInput, data)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def add_chat_messages_batch(
|
||||||
|
session_id: str,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
start_sequence: int,
|
||||||
|
) -> list[PrismaChatMessage]:
|
||||||
|
"""Add multiple messages to a chat session in a batch."""
|
||||||
|
if not messages:
|
||||||
|
return []
|
||||||
|
|
||||||
|
created_messages = []
|
||||||
|
for i, msg in enumerate(messages):
|
||||||
|
# Build the input dict dynamically - only include optional JSON fields
|
||||||
|
# when they have values, as Prisma TypedDict validation fails when
|
||||||
|
# optional fields are explicitly set to None
|
||||||
|
data: dict[str, Any] = {
|
||||||
|
"Session": {"connect": {"id": session_id}},
|
||||||
|
"role": msg["role"],
|
||||||
|
"sequence": start_sequence + i,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional string fields
|
||||||
|
if msg.get("content") is not None:
|
||||||
|
data["content"] = msg["content"]
|
||||||
|
if msg.get("name") is not None:
|
||||||
|
data["name"] = msg["name"]
|
||||||
|
if msg.get("tool_call_id") is not None:
|
||||||
|
data["toolCallId"] = msg["tool_call_id"]
|
||||||
|
if msg.get("refusal") is not None:
|
||||||
|
data["refusal"] = msg["refusal"]
|
||||||
|
|
||||||
|
# Add optional JSON fields only when they have values
|
||||||
|
if msg.get("tool_calls") is not None:
|
||||||
|
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||||
|
if msg.get("function_call") is not None:
|
||||||
|
data["functionCall"] = SafeJson(msg["function_call"])
|
||||||
|
|
||||||
|
created = await PrismaChatMessage.prisma().create(
|
||||||
|
data=cast(ChatMessageCreateInput, data)
|
||||||
|
)
|
||||||
|
created_messages.append(created)
|
||||||
|
|
||||||
|
# Update session's updatedAt timestamp
|
||||||
|
await PrismaChatSession.prisma().update(
|
||||||
|
where={"id": session_id},
|
||||||
|
data={"updatedAt": datetime.now(UTC)},
|
||||||
|
)
|
||||||
|
|
||||||
|
return created_messages
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_chat_sessions(
|
||||||
|
user_id: str,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[PrismaChatSession]:
|
||||||
|
"""Get chat sessions for a user, ordered by most recent."""
|
||||||
|
return await PrismaChatSession.prisma().find_many(
|
||||||
|
where={"userId": user_id},
|
||||||
|
order={"updatedAt": "desc"},
|
||||||
|
take=limit,
|
||||||
|
skip=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_session_count(user_id: str) -> int:
|
||||||
|
"""Get the total number of chat sessions for a user."""
|
||||||
|
return await PrismaChatSession.prisma().count(where={"userId": user_id})
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_chat_session(session_id: str) -> bool:
|
||||||
|
"""Delete a chat session and all its messages."""
|
||||||
|
try:
|
||||||
|
await PrismaChatSession.prisma().delete(where={"id": session_id})
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to delete chat session {session_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def get_chat_session_message_count(session_id: str) -> int:
|
||||||
|
"""Get the number of messages in a chat session."""
|
||||||
|
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
||||||
|
return count
|
||||||
@@ -16,11 +16,15 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
|
|||||||
ChatCompletionMessageToolCallParam,
|
ChatCompletionMessageToolCallParam,
|
||||||
Function,
|
Function,
|
||||||
)
|
)
|
||||||
|
from prisma.models import ChatMessage as PrismaChatMessage
|
||||||
|
from prisma.models import ChatSession as PrismaChatSession
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.data.redis_client import get_redis_async
|
from backend.data.redis_client import get_redis_async
|
||||||
|
from backend.util import json
|
||||||
from backend.util.exceptions import RedisError
|
from backend.util.exceptions import RedisError
|
||||||
|
|
||||||
|
from . import db as chat_db
|
||||||
from .config import ChatConfig
|
from .config import ChatConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -46,6 +50,7 @@ class Usage(BaseModel):
|
|||||||
class ChatSession(BaseModel):
|
class ChatSession(BaseModel):
|
||||||
session_id: str
|
session_id: str
|
||||||
user_id: str | None
|
user_id: str | None
|
||||||
|
title: str | None = None
|
||||||
messages: list[ChatMessage]
|
messages: list[ChatMessage]
|
||||||
usage: list[Usage]
|
usage: list[Usage]
|
||||||
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
|
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
|
||||||
@@ -59,6 +64,7 @@ class ChatSession(BaseModel):
|
|||||||
return ChatSession(
|
return ChatSession(
|
||||||
session_id=str(uuid.uuid4()),
|
session_id=str(uuid.uuid4()),
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
title=None,
|
||||||
messages=[],
|
messages=[],
|
||||||
usage=[],
|
usage=[],
|
||||||
credentials={},
|
credentials={},
|
||||||
@@ -66,6 +72,85 @@ class ChatSession(BaseModel):
|
|||||||
updated_at=datetime.now(UTC),
|
updated_at=datetime.now(UTC),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_prisma(
|
||||||
|
prisma_session: PrismaChatSession,
|
||||||
|
prisma_messages: list[PrismaChatMessage] | None = None,
|
||||||
|
) -> "ChatSession":
|
||||||
|
"""Convert Prisma models to Pydantic ChatSession."""
|
||||||
|
messages = []
|
||||||
|
if prisma_messages:
|
||||||
|
for msg in prisma_messages:
|
||||||
|
tool_calls = None
|
||||||
|
if msg.toolCalls:
|
||||||
|
tool_calls = (
|
||||||
|
json.loads(msg.toolCalls)
|
||||||
|
if isinstance(msg.toolCalls, str)
|
||||||
|
else msg.toolCalls
|
||||||
|
)
|
||||||
|
|
||||||
|
function_call = None
|
||||||
|
if msg.functionCall:
|
||||||
|
function_call = (
|
||||||
|
json.loads(msg.functionCall)
|
||||||
|
if isinstance(msg.functionCall, str)
|
||||||
|
else msg.functionCall
|
||||||
|
)
|
||||||
|
|
||||||
|
messages.append(
|
||||||
|
ChatMessage(
|
||||||
|
role=msg.role,
|
||||||
|
content=msg.content,
|
||||||
|
name=msg.name,
|
||||||
|
tool_call_id=msg.toolCallId,
|
||||||
|
refusal=msg.refusal,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
function_call=function_call,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse JSON fields from Prisma
|
||||||
|
credentials = (
|
||||||
|
json.loads(prisma_session.credentials)
|
||||||
|
if isinstance(prisma_session.credentials, str)
|
||||||
|
else prisma_session.credentials or {}
|
||||||
|
)
|
||||||
|
successful_agent_runs = (
|
||||||
|
json.loads(prisma_session.successfulAgentRuns)
|
||||||
|
if isinstance(prisma_session.successfulAgentRuns, str)
|
||||||
|
else prisma_session.successfulAgentRuns or {}
|
||||||
|
)
|
||||||
|
successful_agent_schedules = (
|
||||||
|
json.loads(prisma_session.successfulAgentSchedules)
|
||||||
|
if isinstance(prisma_session.successfulAgentSchedules, str)
|
||||||
|
else prisma_session.successfulAgentSchedules or {}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate usage from token counts
|
||||||
|
usage = []
|
||||||
|
if prisma_session.totalPromptTokens or prisma_session.totalCompletionTokens:
|
||||||
|
usage.append(
|
||||||
|
Usage(
|
||||||
|
prompt_tokens=prisma_session.totalPromptTokens or 0,
|
||||||
|
completion_tokens=prisma_session.totalCompletionTokens or 0,
|
||||||
|
total_tokens=(prisma_session.totalPromptTokens or 0)
|
||||||
|
+ (prisma_session.totalCompletionTokens or 0),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return ChatSession(
|
||||||
|
session_id=prisma_session.id,
|
||||||
|
user_id=prisma_session.userId,
|
||||||
|
title=prisma_session.title,
|
||||||
|
messages=messages,
|
||||||
|
usage=usage,
|
||||||
|
credentials=credentials,
|
||||||
|
started_at=prisma_session.createdAt,
|
||||||
|
updated_at=prisma_session.updatedAt,
|
||||||
|
successful_agent_runs=successful_agent_runs,
|
||||||
|
successful_agent_schedules=successful_agent_schedules,
|
||||||
|
)
|
||||||
|
|
||||||
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
|
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
|
||||||
messages = []
|
messages = []
|
||||||
for message in self.messages:
|
for message in self.messages:
|
||||||
@@ -155,50 +240,234 @@ class ChatSession(BaseModel):
|
|||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
async def get_chat_session(
|
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||||
session_id: str,
|
"""Get a chat session from Redis cache."""
|
||||||
user_id: str | None,
|
|
||||||
) -> ChatSession | None:
|
|
||||||
"""Get a chat session by ID."""
|
|
||||||
redis_key = f"chat:session:{session_id}"
|
redis_key = f"chat:session:{session_id}"
|
||||||
async_redis = await get_redis_async()
|
async_redis = await get_redis_async()
|
||||||
|
|
||||||
raw_session: bytes | None = await async_redis.get(redis_key)
|
raw_session: bytes | None = await async_redis.get(redis_key)
|
||||||
|
|
||||||
if raw_session is None:
|
if raw_session is None:
|
||||||
logger.warning(f"Session {session_id} not found in Redis")
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
session = ChatSession.model_validate_json(raw_session)
|
session = ChatSession.model_validate_json(raw_session)
|
||||||
|
logger.info(
|
||||||
|
f"Loading session {session_id} from cache: "
|
||||||
|
f"message_count={len(session.messages)}, "
|
||||||
|
f"roles={[m.role for m in session.messages]}"
|
||||||
|
)
|
||||||
|
return session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
|
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
|
||||||
raise RedisError(f"Corrupted session data for {session_id}") from e
|
raise RedisError(f"Corrupted session data for {session_id}") from e
|
||||||
|
|
||||||
|
|
||||||
|
async def _cache_session(session: ChatSession) -> None:
|
||||||
|
"""Cache a chat session in Redis."""
|
||||||
|
redis_key = f"chat:session:{session.session_id}"
|
||||||
|
async_redis = await get_redis_async()
|
||||||
|
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||||
|
"""Get a chat session from the database."""
|
||||||
|
prisma_session = await chat_db.get_chat_session(session_id)
|
||||||
|
if not prisma_session:
|
||||||
|
return None
|
||||||
|
|
||||||
|
messages = prisma_session.Messages
|
||||||
|
logger.info(
|
||||||
|
f"Loading session {session_id} from DB: "
|
||||||
|
f"has_messages={messages is not None}, "
|
||||||
|
f"message_count={len(messages) if messages else 0}, "
|
||||||
|
f"roles={[m.role for m in messages] if messages else []}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return ChatSession.from_prisma(prisma_session, messages)
|
||||||
|
|
||||||
|
|
||||||
|
async def _save_session_to_db(
|
||||||
|
session: ChatSession, existing_message_count: int
|
||||||
|
) -> None:
|
||||||
|
"""Save or update a chat session in the database."""
|
||||||
|
# Check if session exists in DB
|
||||||
|
existing = await chat_db.get_chat_session(session.session_id)
|
||||||
|
|
||||||
|
if not existing:
|
||||||
|
# Create new session
|
||||||
|
await chat_db.create_chat_session(
|
||||||
|
session_id=session.session_id,
|
||||||
|
user_id=session.user_id,
|
||||||
|
)
|
||||||
|
existing_message_count = 0
|
||||||
|
|
||||||
|
# Calculate total tokens from usage
|
||||||
|
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||||
|
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||||
|
|
||||||
|
# Update session metadata
|
||||||
|
await chat_db.update_chat_session(
|
||||||
|
session_id=session.session_id,
|
||||||
|
credentials=session.credentials,
|
||||||
|
successful_agent_runs=session.successful_agent_runs,
|
||||||
|
successful_agent_schedules=session.successful_agent_schedules,
|
||||||
|
total_prompt_tokens=total_prompt,
|
||||||
|
total_completion_tokens=total_completion,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add new messages (only those after existing count)
|
||||||
|
new_messages = session.messages[existing_message_count:]
|
||||||
|
if new_messages:
|
||||||
|
messages_data = []
|
||||||
|
for msg in new_messages:
|
||||||
|
messages_data.append(
|
||||||
|
{
|
||||||
|
"role": msg.role,
|
||||||
|
"content": msg.content,
|
||||||
|
"name": msg.name,
|
||||||
|
"tool_call_id": msg.tool_call_id,
|
||||||
|
"refusal": msg.refusal,
|
||||||
|
"tool_calls": msg.tool_calls,
|
||||||
|
"function_call": msg.function_call,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||||
|
f"roles={[m['role'] for m in messages_data]}, "
|
||||||
|
f"start_sequence={existing_message_count}"
|
||||||
|
)
|
||||||
|
await chat_db.add_chat_messages_batch(
|
||||||
|
session_id=session.session_id,
|
||||||
|
messages=messages_data,
|
||||||
|
start_sequence=existing_message_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_chat_session(
|
||||||
|
session_id: str,
|
||||||
|
user_id: str | None,
|
||||||
|
) -> ChatSession | None:
|
||||||
|
"""Get a chat session by ID.
|
||||||
|
|
||||||
|
Checks Redis cache first, falls back to database if not found.
|
||||||
|
Caches database results back to Redis.
|
||||||
|
"""
|
||||||
|
# Try cache first
|
||||||
|
try:
|
||||||
|
session = await _get_session_from_cache(session_id)
|
||||||
|
if session:
|
||||||
|
# Verify user ownership
|
||||||
|
if session.user_id is not None and session.user_id != user_id:
|
||||||
|
logger.warning(
|
||||||
|
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return session
|
||||||
|
except RedisError:
|
||||||
|
logger.warning(f"Cache error for session {session_id}, trying database")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
||||||
|
|
||||||
|
# Fall back to database
|
||||||
|
logger.info(f"Session {session_id} not in cache, checking database")
|
||||||
|
session = await _get_session_from_db(session_id)
|
||||||
|
|
||||||
|
if session is None:
|
||||||
|
logger.warning(f"Session {session_id} not found in cache or database")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Verify user ownership
|
||||||
if session.user_id is not None and session.user_id != user_id:
|
if session.user_id is not None and session.user_id != user_id:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Cache the session from DB
|
||||||
|
try:
|
||||||
|
await _cache_session(session)
|
||||||
|
logger.info(f"Cached session {session_id} from database")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to cache session {session_id}: {e}")
|
||||||
|
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
async def upsert_chat_session(
|
async def upsert_chat_session(
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
) -> ChatSession:
|
) -> ChatSession:
|
||||||
"""Update a chat session with the given messages."""
|
"""Update a chat session in both cache and database."""
|
||||||
|
# Get existing message count from DB for incremental saves
|
||||||
redis_key = f"chat:session:{session.session_id}"
|
existing_message_count = await chat_db.get_chat_session_message_count(
|
||||||
|
session.session_id
|
||||||
async_redis = await get_redis_async()
|
|
||||||
resp = await async_redis.setex(
|
|
||||||
redis_key, config.session_ttl, session.model_dump_json()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not resp:
|
# Save to database
|
||||||
|
try:
|
||||||
|
await _save_session_to_db(session, existing_message_count)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save session {session.session_id} to database: {e}")
|
||||||
|
# Continue to cache even if DB fails
|
||||||
|
|
||||||
|
# Save to cache
|
||||||
|
try:
|
||||||
|
await _cache_session(session)
|
||||||
|
except Exception as e:
|
||||||
raise RedisError(
|
raise RedisError(
|
||||||
f"Failed to persist chat session {session.session_id} to Redis: {resp}"
|
f"Failed to persist chat session {session.session_id} to Redis: {e}"
|
||||||
)
|
) from e
|
||||||
|
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
async def create_chat_session(user_id: str | None) -> ChatSession:
|
||||||
|
"""Create a new chat session and persist it."""
|
||||||
|
session = ChatSession.new(user_id)
|
||||||
|
|
||||||
|
# Create in database first
|
||||||
|
try:
|
||||||
|
await chat_db.create_chat_session(
|
||||||
|
session_id=session.session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create session in database: {e}")
|
||||||
|
# Continue even if DB fails - cache will still work
|
||||||
|
|
||||||
|
# Cache the session
|
||||||
|
try:
|
||||||
|
await _cache_session(session)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to cache new session: {e}")
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_sessions(
|
||||||
|
user_id: str,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[ChatSession]:
|
||||||
|
"""Get all chat sessions for a user from the database."""
|
||||||
|
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
|
||||||
|
|
||||||
|
sessions = []
|
||||||
|
for prisma_session in prisma_sessions:
|
||||||
|
# Convert without messages for listing (lighter weight)
|
||||||
|
sessions.append(ChatSession.from_prisma(prisma_session, None))
|
||||||
|
|
||||||
|
return sessions
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_chat_session(session_id: str) -> bool:
|
||||||
|
"""Delete a chat session from both cache and database."""
|
||||||
|
# Delete from cache
|
||||||
|
try:
|
||||||
|
redis_key = f"chat:session:{session_id}"
|
||||||
|
async_redis = await get_redis_async()
|
||||||
|
await async_redis.delete(redis_key)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
|
||||||
|
|
||||||
|
# Delete from database
|
||||||
|
return await chat_db.delete_chat_session(session_id)
|
||||||
|
|||||||
@@ -68,3 +68,50 @@ async def test_chatsession_redis_storage_user_id_mismatch():
|
|||||||
s2 = await get_chat_session(s.session_id, None)
|
s2 = await get_chat_session(s.session_id, None)
|
||||||
|
|
||||||
assert s2 is None
|
assert s2 is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_chatsession_db_storage():
|
||||||
|
"""Test that messages are correctly saved to and loaded from DB (not cache)."""
|
||||||
|
from backend.data.redis_client import get_redis_async
|
||||||
|
|
||||||
|
# Create session with messages including assistant message
|
||||||
|
s = ChatSession.new(user_id=None)
|
||||||
|
s.messages = messages # Contains user, assistant, and tool messages
|
||||||
|
assert s.session_id is not None, "Session id is not set"
|
||||||
|
# Upsert to save to both cache and DB
|
||||||
|
s = await upsert_chat_session(s)
|
||||||
|
|
||||||
|
# Clear the Redis cache to force DB load
|
||||||
|
redis_key = f"chat:session:{s.session_id}"
|
||||||
|
async_redis = await get_redis_async()
|
||||||
|
await async_redis.delete(redis_key)
|
||||||
|
|
||||||
|
# Load from DB (cache was cleared)
|
||||||
|
s2 = await get_chat_session(
|
||||||
|
session_id=s.session_id,
|
||||||
|
user_id=s.user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert s2 is not None, "Session not found after loading from DB"
|
||||||
|
assert len(s2.messages) == len(
|
||||||
|
s.messages
|
||||||
|
), f"Message count mismatch: expected {len(s.messages)}, got {len(s2.messages)}"
|
||||||
|
|
||||||
|
# Verify all roles are present
|
||||||
|
roles = [m.role for m in s2.messages]
|
||||||
|
assert "user" in roles, f"User message missing. Roles found: {roles}"
|
||||||
|
assert "assistant" in roles, f"Assistant message missing. Roles found: {roles}"
|
||||||
|
assert "tool" in roles, f"Tool message missing. Roles found: {roles}"
|
||||||
|
|
||||||
|
# Verify message content
|
||||||
|
for orig, loaded in zip(s.messages, s2.messages):
|
||||||
|
assert orig.role == loaded.role, f"Role mismatch: {orig.role} != {loaded.role}"
|
||||||
|
assert (
|
||||||
|
orig.content == loaded.content
|
||||||
|
), f"Content mismatch for {orig.role}: {orig.content} != {loaded.content}"
|
||||||
|
if orig.tool_calls:
|
||||||
|
assert (
|
||||||
|
loaded.tool_calls is not None
|
||||||
|
), f"Tool calls missing for {orig.role} message"
|
||||||
|
assert len(orig.tool_calls) == len(loaded.tool_calls)
|
||||||
|
|||||||
@@ -1,12 +1,80 @@
|
|||||||
You are Otto, an AI Co-Pilot and Forward Deployed Engineer for AutoGPT, an AI Business Automation tool. Your mission is to help users quickly find and set up AutoGPT agents to solve their business problems.
|
You are Otto, an AI Co-Pilot and Forward Deployed Engineer for AutoGPT, an AI Business Automation tool. Your mission is to help users quickly find, create, and set up AutoGPT agents to solve their business problems.
|
||||||
|
|
||||||
Here are the functions available to you:
|
Here are the functions available to you:
|
||||||
|
|
||||||
<functions>
|
<functions>
|
||||||
1. **find_agent** - Search for agents that solve the user's problem
|
**Understanding & Discovery:**
|
||||||
2. **run_agent** - Run or schedule an agent (automatically handles setup)
|
1. **add_understanding** - Save information about the user's business context (use this as you learn about them)
|
||||||
|
2. **find_agent** - Search the marketplace for pre-built agents that solve the user's problem
|
||||||
|
3. **find_library_agent** - Search the user's personal library of saved agents
|
||||||
|
4. **find_block** - Search for individual blocks (building components for agents)
|
||||||
|
5. **search_platform_docs** - Search AutoGPT documentation for help
|
||||||
|
|
||||||
|
**Agent Creation & Editing:**
|
||||||
|
6. **create_agent** - Create a new custom agent from scratch based on user requirements
|
||||||
|
7. **edit_agent** - Modify an existing agent (add/remove blocks, change configuration)
|
||||||
|
|
||||||
|
**Execution & Output:**
|
||||||
|
8. **run_agent** - Run or schedule an agent (automatically handles setup)
|
||||||
|
9. **run_block** - Run a single block directly without creating an agent
|
||||||
|
10. **agent_output** - Get the output/results from a running or completed agent execution
|
||||||
</functions>
|
</functions>
|
||||||
|
|
||||||
|
## ALWAYS GET THE USER'S NAME
|
||||||
|
|
||||||
|
**This is critical:** If you don't know the user's name, ask for it in your first response. Use a friendly, natural approach:
|
||||||
|
- "Hi! I'm Otto. What's your name?"
|
||||||
|
- "Hey there! Before we dive in, what should I call you?"
|
||||||
|
|
||||||
|
Once you have their name, immediately save it with `add_understanding(user_name="...")` and use it throughout the conversation.
|
||||||
|
|
||||||
|
## BUILDING USER UNDERSTANDING
|
||||||
|
|
||||||
|
**If no User Business Context is provided below**, gather information naturally during conversation - don't interrogate them.
|
||||||
|
|
||||||
|
**Key information to gather (in priority order):**
|
||||||
|
1. Their name (ALWAYS first if unknown)
|
||||||
|
2. Their job title and role
|
||||||
|
3. Their business/company and industry
|
||||||
|
4. Pain points and what they want to automate
|
||||||
|
5. Tools they currently use
|
||||||
|
|
||||||
|
**How to gather this information:**
|
||||||
|
- Ask naturally as part of helping them (e.g., "What's your role?" or "What industry are you in?")
|
||||||
|
- When they share information, immediately save it using `add_understanding`
|
||||||
|
- Don't ask all questions at once - spread them across the conversation
|
||||||
|
- Prioritize understanding their immediate problem first
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```
|
||||||
|
User: "I need help automating my social media"
|
||||||
|
Otto: I can help with that! I'm Otto - what's your name?
|
||||||
|
User: "I'm Sarah"
|
||||||
|
Otto: [calls add_understanding with user_name="Sarah"]
|
||||||
|
Nice to meet you, Sarah! What's your role - are you a social media manager or business owner?
|
||||||
|
User: "I'm the marketing director at a fintech startup"
|
||||||
|
Otto: [calls add_understanding with job_title="Marketing Director", industry="fintech", business_size="startup"]
|
||||||
|
Great! Let me find social media automation agents for you.
|
||||||
|
[calls find_agent with query="social media automation marketing"]
|
||||||
|
```
|
||||||
|
|
||||||
|
## WHEN TO USE WHICH TOOL
|
||||||
|
|
||||||
|
**Finding existing agents:**
|
||||||
|
- `find_agent` - Search the marketplace for pre-built agents others have created
|
||||||
|
- `find_library_agent` - Search agents the user has already saved to their library
|
||||||
|
|
||||||
|
**Creating/editing agents:**
|
||||||
|
- `create_agent` - When user wants a custom agent that doesn't exist, or has specific requirements
|
||||||
|
- `edit_agent` - When user wants to modify an existing agent (change inputs, add blocks, etc.)
|
||||||
|
|
||||||
|
**Running agents:**
|
||||||
|
- `run_agent` - To execute an agent (handles credentials and inputs automatically)
|
||||||
|
- `agent_output` - To check the results of a running or completed agent execution
|
||||||
|
|
||||||
|
**Direct execution:**
|
||||||
|
- `run_block` - Run a single block directly without needing a full agent
|
||||||
|
|
||||||
## HOW run_agent WORKS
|
## HOW run_agent WORKS
|
||||||
|
|
||||||
The `run_agent` tool automatically handles the entire setup flow:
|
The `run_agent` tool automatically handles the entire setup flow:
|
||||||
@@ -21,49 +89,61 @@ Parameters:
|
|||||||
- `use_defaults`: Set to `true` to run with default values (only after user confirms)
|
- `use_defaults`: Set to `true` to run with default values (only after user confirms)
|
||||||
- `schedule_name` + `cron`: For scheduled execution
|
- `schedule_name` + `cron`: For scheduled execution
|
||||||
|
|
||||||
|
## HOW create_agent WORKS
|
||||||
|
|
||||||
|
Use `create_agent` when the user wants to build a custom automation:
|
||||||
|
- Describe what the agent should do
|
||||||
|
- The tool will create the agent structure with appropriate blocks
|
||||||
|
- Returns the agent ID for further editing or running
|
||||||
|
|
||||||
|
## HOW agent_output WORKS
|
||||||
|
|
||||||
|
Use `agent_output` to get results from agent executions:
|
||||||
|
- Pass the execution_id from a run_agent response
|
||||||
|
- Returns the current status and any outputs produced
|
||||||
|
- Useful for checking if an agent has completed and what it produced
|
||||||
|
|
||||||
## WORKFLOW
|
## WORKFLOW
|
||||||
|
|
||||||
1. **find_agent** - Search for agents that solve the user's problem
|
1. **Get their name** - If unknown, ask for it first
|
||||||
2. **run_agent** (first call, no inputs) - Get available inputs for the agent
|
2. **Understand context** - Ask 1-2 questions about their problem while helping
|
||||||
3. **Ask user** what values they want to use OR if they want to use defaults
|
3. **Find or create** - Use find_agent for existing solutions, create_agent for custom needs
|
||||||
4. **run_agent** (second call) - Either with `inputs={...}` or `use_defaults=true`
|
4. **Set up and run** - Use run_agent to execute, agent_output to get results
|
||||||
|
|
||||||
## YOUR APPROACH
|
## YOUR APPROACH
|
||||||
|
|
||||||
**Step 1: Understand the Problem**
|
**Step 1: Greet and Identify**
|
||||||
|
- If you don't know their name, ask for it
|
||||||
|
- Be friendly and conversational
|
||||||
|
|
||||||
|
**Step 2: Understand the Problem**
|
||||||
- Ask maximum 1-2 targeted questions
|
- Ask maximum 1-2 targeted questions
|
||||||
- Focus on: What business problem are they solving?
|
- Focus on: What business problem are they solving?
|
||||||
- Move quickly to searching for solutions
|
- If they want to create/edit an agent, understand what it should do
|
||||||
|
|
||||||
**Step 2: Find Agents**
|
**Step 3: Find or Create**
|
||||||
- Use `find_agent` immediately with relevant keywords
|
- For existing solutions: Use `find_agent` with relevant keywords
|
||||||
- Suggest the best option from search results
|
- For custom needs: Use `create_agent` with their requirements
|
||||||
- Explain briefly how it solves their problem
|
- For modifications: Use `edit_agent` on an existing agent
|
||||||
|
|
||||||
**Step 3: Get Agent Inputs**
|
**Step 4: Execute**
|
||||||
- Call `run_agent(username_agent_slug="creator/agent-name")` without inputs
|
- Call `run_agent` without inputs first to see what's available
|
||||||
- This returns the available inputs (required and optional)
|
- Ask user what values they want or if defaults are okay
|
||||||
- Present these to the user and ask what values they want
|
- Call `run_agent` again with inputs or `use_defaults=true`
|
||||||
|
- Use `agent_output` to check results when needed
|
||||||
|
|
||||||
**Step 4: Run with User's Choice**
|
## USING add_understanding
|
||||||
- If user provides values: `run_agent(username_agent_slug="...", inputs={...})`
|
|
||||||
- If user says "use defaults": `run_agent(username_agent_slug="...", use_defaults=true)`
|
|
||||||
- On success, share the agent link with the user
|
|
||||||
|
|
||||||
**For Scheduled Execution:**
|
Call `add_understanding` whenever you learn something about the user:
|
||||||
- Add `schedule_name` and `cron` parameters
|
|
||||||
- Example: `run_agent(username_agent_slug="...", inputs={...}, schedule_name="Daily Report", cron="0 9 * * *")`
|
|
||||||
|
|
||||||
## FUNCTION CALL FORMAT
|
**User info:** `user_name`, `job_title`
|
||||||
|
**Business:** `business_name`, `industry`, `business_size` (1-10, 11-50, 51-200, 201-1000, 1000+), `user_role` (decision maker, implementer, end user)
|
||||||
|
**Processes:** `key_workflows` (array), `daily_activities` (array)
|
||||||
|
**Pain points:** `pain_points` (array), `bottlenecks` (array), `manual_tasks` (array), `automation_goals` (array)
|
||||||
|
**Tools:** `current_software` (array), `existing_automation` (array)
|
||||||
|
**Other:** `additional_notes`
|
||||||
|
|
||||||
To call a function, use this exact format:
|
Example: `add_understanding(user_name="Sarah", job_title="Marketing Director", industry="fintech")`
|
||||||
`<function_call>function_name(parameter="value")</function_call>`
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
- `<function_call>find_agent(query="social media automation")</function_call>`
|
|
||||||
- `<function_call>run_agent(username_agent_slug="creator/agent-name")</function_call>` (get inputs)
|
|
||||||
- `<function_call>run_agent(username_agent_slug="creator/agent-name", inputs={"topic": "AI news"})</function_call>`
|
|
||||||
- `<function_call>run_agent(username_agent_slug="creator/agent-name", use_defaults=true)</function_call>`
|
|
||||||
|
|
||||||
## KEY RULES
|
## KEY RULES
|
||||||
|
|
||||||
@@ -73,8 +153,12 @@ Examples:
|
|||||||
- Don't run agents without first showing available inputs to the user
|
- Don't run agents without first showing available inputs to the user
|
||||||
- Don't use `use_defaults=true` without user explicitly confirming
|
- Don't use `use_defaults=true` without user explicitly confirming
|
||||||
- Don't write responses longer than 3 sentences
|
- Don't write responses longer than 3 sentences
|
||||||
|
- Don't interrogate users with many questions - gather info naturally
|
||||||
|
|
||||||
**What You DO:**
|
**What You DO:**
|
||||||
|
- ALWAYS ask for user's name if you don't have it
|
||||||
|
- Save user information with `add_understanding` as you learn it
|
||||||
|
- Use their name when addressing them
|
||||||
- Always call run_agent first without inputs to see what's available
|
- Always call run_agent first without inputs to see what's available
|
||||||
- Ask user what values they want OR if they want to use defaults
|
- Ask user what values they want OR if they want to use defaults
|
||||||
- Keep all responses to maximum 3 sentences
|
- Keep all responses to maximum 3 sentences
|
||||||
@@ -87,18 +171,22 @@ Examples:
|
|||||||
## RESPONSE STRUCTURE
|
## RESPONSE STRUCTURE
|
||||||
|
|
||||||
Before responding, wrap your analysis in <thinking> tags to systematically plan your approach:
|
Before responding, wrap your analysis in <thinking> tags to systematically plan your approach:
|
||||||
|
- Check if you know the user's name - if not, ask for it
|
||||||
|
- Check if you have user context - if not, plan to gather some naturally
|
||||||
- Extract the key business problem or request from the user's message
|
- Extract the key business problem or request from the user's message
|
||||||
- Determine what function call (if any) you need to make next
|
- Determine what function call (if any) you need to make next
|
||||||
- Plan your response to stay under the 3-sentence maximum
|
- Plan your response to stay under the 3-sentence maximum
|
||||||
|
|
||||||
Example interaction:
|
Example interaction:
|
||||||
```
|
```
|
||||||
User: "Run the AI news agent for me"
|
User: "Hi, I want to build an agent that monitors my competitors"
|
||||||
Otto: <function_call>run_agent(username_agent_slug="autogpt/ai-news")</function_call>
|
Otto: <thinking>I don't know this user's name. I should ask for it while acknowledging their request.</thinking>
|
||||||
[Tool returns: Agent accepts inputs - Required: topic. Optional: num_articles (default: 5)]
|
Hi! I'm Otto and I'd love to help you build a competitor monitoring agent. What's your name?
|
||||||
Otto: The AI News agent needs a topic. What topic would you like news about, or should I use the defaults?
|
User: "I'm Mike"
|
||||||
User: "Use defaults"
|
Otto: [calls add_understanding with user_name="Mike"]
|
||||||
Otto: <function_call>run_agent(username_agent_slug="autogpt/ai-news", use_defaults=true)</function_call>
|
<thinking>Now I know Mike wants competitor monitoring. I should search for existing agents first.</thinking>
|
||||||
|
Great to meet you, Mike! Let me search for competitor monitoring agents.
|
||||||
|
[calls find_agent with query="competitor monitoring analysis"]
|
||||||
```
|
```
|
||||||
|
|
||||||
KEEP ANSWERS TO 3 SENTENCES
|
KEEP ANSWERS TO 3 SENTENCES
|
||||||
|
|||||||
@@ -0,0 +1,155 @@
|
|||||||
|
You are Otto, an AI Co-Pilot helping new users get started with AutoGPT, an AI Business Automation platform. Your mission is to welcome them, learn about their needs, and help them run their first successful agent.
|
||||||
|
|
||||||
|
Here are the functions available to you:
|
||||||
|
|
||||||
|
<functions>
|
||||||
|
**Understanding & Discovery:**
|
||||||
|
1. **add_understanding** - Save information about the user's business context (use this as you learn about them)
|
||||||
|
2. **find_agent** - Search the marketplace for pre-built agents that solve the user's problem
|
||||||
|
3. **find_library_agent** - Search the user's personal library of saved agents
|
||||||
|
4. **find_block** - Search for individual blocks (building components for agents)
|
||||||
|
5. **search_platform_docs** - Search AutoGPT documentation for help
|
||||||
|
|
||||||
|
**Agent Creation & Editing:**
|
||||||
|
6. **create_agent** - Create a new custom agent from scratch based on user requirements
|
||||||
|
7. **edit_agent** - Modify an existing agent (add/remove blocks, change configuration)
|
||||||
|
|
||||||
|
**Execution & Output:**
|
||||||
|
8. **run_agent** - Run or schedule an agent (automatically handles setup)
|
||||||
|
9. **run_block** - Run a single block directly without creating an agent
|
||||||
|
10. **agent_output** - Get the output/results from a running or completed agent execution
|
||||||
|
</functions>
|
||||||
|
|
||||||
|
## YOUR ONBOARDING MISSION
|
||||||
|
|
||||||
|
You are guiding a new user through their first experience with AutoGPT. Your goal is to:
|
||||||
|
1. Welcome them warmly and get their name
|
||||||
|
2. Learn about them and their business
|
||||||
|
3. Find or create an agent that solves a real problem for them
|
||||||
|
4. Get that agent running successfully
|
||||||
|
5. Celebrate their success and point them to next steps
|
||||||
|
|
||||||
|
## PHASE 1: WELCOME & INTRODUCTION
|
||||||
|
|
||||||
|
**Start every conversation by:**
|
||||||
|
- Giving a warm, friendly greeting
|
||||||
|
- Introducing yourself as Otto, their AI assistant
|
||||||
|
- Asking for their name immediately
|
||||||
|
|
||||||
|
**Example opening:**
|
||||||
|
```
|
||||||
|
Hi! I'm Otto, your AI assistant. Welcome to AutoGPT! I'm here to help you set up your first automation. What's your name?
|
||||||
|
```
|
||||||
|
|
||||||
|
Once you have their name, save it immediately with `add_understanding(user_name="...")` and use it throughout.
|
||||||
|
|
||||||
|
## PHASE 2: DISCOVERY
|
||||||
|
|
||||||
|
**After getting their name, learn about them:**
|
||||||
|
- What's their role/job title?
|
||||||
|
- What industry/business are they in?
|
||||||
|
- What's one thing they'd love to automate?
|
||||||
|
|
||||||
|
**Keep it conversational - don't interrogate. Example:**
|
||||||
|
```
|
||||||
|
Nice to meet you, Sarah! What do you do for work, and what's one task you wish you could automate?
|
||||||
|
```
|
||||||
|
|
||||||
|
Save everything you learn with `add_understanding`.
|
||||||
|
|
||||||
|
## PHASE 3: FIND OR CREATE AN AGENT
|
||||||
|
|
||||||
|
**Once you understand their need:**
|
||||||
|
- Search for existing agents with `find_agent`
|
||||||
|
- Present the best match and explain how it helps them
|
||||||
|
- If nothing fits, offer to create a custom agent with `create_agent`
|
||||||
|
|
||||||
|
**Be enthusiastic about the solution:**
|
||||||
|
```
|
||||||
|
I found a great agent for you! The "Social Media Scheduler" can automatically post to your accounts on a schedule. Want to try it?
|
||||||
|
```
|
||||||
|
|
||||||
|
## PHASE 4: SETUP & RUN
|
||||||
|
|
||||||
|
**Guide them through running the agent:**
|
||||||
|
1. Call `run_agent` without inputs first to see what's needed
|
||||||
|
2. Explain each input in simple terms
|
||||||
|
3. Ask what values they want to use
|
||||||
|
4. Run the agent with their inputs or defaults
|
||||||
|
|
||||||
|
**Don't mention credentials** - the UI handles that automatically.
|
||||||
|
|
||||||
|
## PHASE 5: CELEBRATE & HANDOFF
|
||||||
|
|
||||||
|
**After successful execution:**
|
||||||
|
- Congratulate them on their first automation!
|
||||||
|
- Tell them where to find this agent (their Library)
|
||||||
|
- Mention they can explore more agents in the Marketplace
|
||||||
|
- Offer to help with anything else
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```
|
||||||
|
You did it! Your first agent is running. You can find it anytime in your Library. Ready to explore more automations?
|
||||||
|
```
|
||||||
|
|
||||||
|
## KEY RULES
|
||||||
|
|
||||||
|
**What You DON'T Do:**
|
||||||
|
- Don't help with login (frontend handles this)
|
||||||
|
- Don't mention credentials (UI handles automatically)
|
||||||
|
- Don't run agents without showing inputs first
|
||||||
|
- Don't use `use_defaults=true` without explicit confirmation
|
||||||
|
- Don't write responses longer than 3 sentences
|
||||||
|
- Don't overwhelm with too many questions at once
|
||||||
|
|
||||||
|
**What You DO:**
|
||||||
|
- ALWAYS get the user's name first
|
||||||
|
- Be warm, encouraging, and celebratory
|
||||||
|
- Save info with `add_understanding` as you learn it
|
||||||
|
- Use their name when addressing them
|
||||||
|
- Keep responses to maximum 3 sentences
|
||||||
|
- Make them feel successful at each step
|
||||||
|
|
||||||
|
## USING add_understanding
|
||||||
|
|
||||||
|
Save information as you learn it:
|
||||||
|
|
||||||
|
**User info:** `user_name`, `job_title`
|
||||||
|
**Business:** `business_name`, `industry`, `business_size`, `user_role`
|
||||||
|
**Pain points:** `pain_points`, `manual_tasks`, `automation_goals`
|
||||||
|
**Tools:** `current_software`
|
||||||
|
|
||||||
|
Example: `add_understanding(user_name="Sarah", job_title="Marketing Manager", automation_goals=["social media scheduling"])`
|
||||||
|
|
||||||
|
## HOW run_agent WORKS
|
||||||
|
|
||||||
|
1. **First call** (no inputs) → Shows available inputs
|
||||||
|
2. **Credentials** → UI handles automatically (don't mention)
|
||||||
|
3. **Execution** → Run with `inputs={...}` or `use_defaults=true`
|
||||||
|
|
||||||
|
## RESPONSE STRUCTURE
|
||||||
|
|
||||||
|
Before responding, plan your approach in <thinking> tags:
|
||||||
|
- What phase am I in? (Welcome/Discovery/Find/Setup/Celebrate)
|
||||||
|
- Do I know their name? If not, ask for it
|
||||||
|
- What's the next step to move them forward?
|
||||||
|
- Keep response under 3 sentences
|
||||||
|
|
||||||
|
**Example flow:**
|
||||||
|
```
|
||||||
|
User: "Hi"
|
||||||
|
Otto: <thinking>Phase 1 - I need to welcome them and get their name.</thinking>
|
||||||
|
Hi! I'm Otto, welcome to AutoGPT! I'm here to help you set up your first automation - what's your name?
|
||||||
|
|
||||||
|
User: "I'm Alex"
|
||||||
|
Otto: [calls add_understanding with user_name="Alex"]
|
||||||
|
<thinking>Got their name. Phase 2 - learn about them.</thinking>
|
||||||
|
Great to meet you, Alex! What do you do for work, and what's one task you'd love to automate?
|
||||||
|
|
||||||
|
User: "I run an e-commerce store and spend hours on customer support emails"
|
||||||
|
Otto: [calls add_understanding with industry="e-commerce", pain_points=["customer support emails"]]
|
||||||
|
<thinking>Phase 3 - search for agents.</thinking>
|
||||||
|
[calls find_agent with query="customer support email automation"]
|
||||||
|
```
|
||||||
|
|
||||||
|
KEEP ANSWERS TO 3 SENTENCES - Be warm, helpful, and focused on their success!
|
||||||
@@ -26,6 +26,14 @@ router = APIRouter(
|
|||||||
# ========== Request/Response Models ==========
|
# ========== Request/Response Models ==========
|
||||||
|
|
||||||
|
|
||||||
|
class StreamChatRequest(BaseModel):
|
||||||
|
"""Request model for streaming chat with optional context."""
|
||||||
|
|
||||||
|
message: str
|
||||||
|
is_user_message: bool = True
|
||||||
|
context: dict[str, str] | None = None # {url: str, content: str}
|
||||||
|
|
||||||
|
|
||||||
class CreateSessionResponse(BaseModel):
|
class CreateSessionResponse(BaseModel):
|
||||||
"""Response model containing information on a newly created chat session."""
|
"""Response model containing information on a newly created chat session."""
|
||||||
|
|
||||||
@@ -44,9 +52,64 @@ class SessionDetailResponse(BaseModel):
|
|||||||
messages: list[dict]
|
messages: list[dict]
|
||||||
|
|
||||||
|
|
||||||
|
class SessionSummaryResponse(BaseModel):
|
||||||
|
"""Response model for a session summary (without messages)."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
created_at: str
|
||||||
|
updated_at: str
|
||||||
|
title: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ListSessionsResponse(BaseModel):
|
||||||
|
"""Response model for listing chat sessions."""
|
||||||
|
|
||||||
|
sessions: list[SessionSummaryResponse]
|
||||||
|
total: int
|
||||||
|
|
||||||
|
|
||||||
# ========== Routes ==========
|
# ========== Routes ==========
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/sessions",
|
||||||
|
dependencies=[Security(auth.requires_user)],
|
||||||
|
)
|
||||||
|
async def list_sessions(
|
||||||
|
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||||
|
limit: int = Query(default=50, ge=1, le=100),
|
||||||
|
offset: int = Query(default=0, ge=0),
|
||||||
|
) -> ListSessionsResponse:
|
||||||
|
"""
|
||||||
|
List chat sessions for the authenticated user.
|
||||||
|
|
||||||
|
Returns a paginated list of chat sessions belonging to the current user,
|
||||||
|
ordered by most recently updated.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The authenticated user's ID.
|
||||||
|
limit: Maximum number of sessions to return (1-100).
|
||||||
|
offset: Number of sessions to skip for pagination.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ListSessionsResponse: List of session summaries and total count.
|
||||||
|
"""
|
||||||
|
sessions = await chat_service.get_user_sessions(user_id, limit, offset)
|
||||||
|
|
||||||
|
return ListSessionsResponse(
|
||||||
|
sessions=[
|
||||||
|
SessionSummaryResponse(
|
||||||
|
id=session.session_id,
|
||||||
|
created_at=session.started_at.isoformat(),
|
||||||
|
updated_at=session.updated_at.isoformat(),
|
||||||
|
title=None, # TODO: Add title support
|
||||||
|
)
|
||||||
|
for session in sessions
|
||||||
|
],
|
||||||
|
total=len(sessions),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/sessions",
|
"/sessions",
|
||||||
)
|
)
|
||||||
@@ -102,26 +165,89 @@ async def get_session(
|
|||||||
session = await chat_service.get_session(session_id, user_id)
|
session = await chat_service.get_session(session_id, user_id)
|
||||||
if not session:
|
if not session:
|
||||||
raise NotFoundError(f"Session {session_id} not found")
|
raise NotFoundError(f"Session {session_id} not found")
|
||||||
|
|
||||||
|
messages = [message.model_dump() for message in session.messages]
|
||||||
|
logger.info(
|
||||||
|
f"Returning session {session_id}: "
|
||||||
|
f"message_count={len(messages)}, "
|
||||||
|
f"roles={[m.get('role') for m in messages]}"
|
||||||
|
)
|
||||||
|
|
||||||
return SessionDetailResponse(
|
return SessionDetailResponse(
|
||||||
id=session.session_id,
|
id=session.session_id,
|
||||||
created_at=session.started_at.isoformat(),
|
created_at=session.started_at.isoformat(),
|
||||||
updated_at=session.updated_at.isoformat(),
|
updated_at=session.updated_at.isoformat(),
|
||||||
user_id=session.user_id or None,
|
user_id=session.user_id or None,
|
||||||
messages=[message.model_dump() for message in session.messages],
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/sessions/{session_id}/stream",
|
||||||
|
)
|
||||||
|
async def stream_chat_post(
|
||||||
|
session_id: str,
|
||||||
|
request: StreamChatRequest,
|
||||||
|
user_id: str | None = Depends(auth.get_user_id),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Stream chat responses for a session (POST with context support).
|
||||||
|
|
||||||
|
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||||
|
- Text fragments as they are generated
|
||||||
|
- Tool call UI elements (if invoked)
|
||||||
|
- Tool execution results
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The chat session identifier to associate with the streamed messages.
|
||||||
|
request: Request body containing message, is_user_message, and optional context.
|
||||||
|
user_id: Optional authenticated user ID.
|
||||||
|
Returns:
|
||||||
|
StreamingResponse: SSE-formatted response chunks.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Validate session exists before starting the stream
|
||||||
|
# This prevents errors after the response has already started
|
||||||
|
session = await chat_service.get_session(session_id, user_id)
|
||||||
|
|
||||||
|
if not session:
|
||||||
|
raise NotFoundError(f"Session {session_id} not found. ")
|
||||||
|
if session.user_id is None and user_id is not None:
|
||||||
|
session = await chat_service.assign_user_to_session(session_id, user_id)
|
||||||
|
|
||||||
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
|
async for chunk in chat_service.stream_chat_completion(
|
||||||
|
session_id,
|
||||||
|
request.message,
|
||||||
|
is_user_message=request.is_user_message,
|
||||||
|
user_id=user_id,
|
||||||
|
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||||
|
context=request.context,
|
||||||
|
):
|
||||||
|
yield chunk.to_sse()
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_generator(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/sessions/{session_id}/stream",
|
"/sessions/{session_id}/stream",
|
||||||
)
|
)
|
||||||
async def stream_chat(
|
async def stream_chat_get(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
message: Annotated[str, Query(min_length=1, max_length=10000)],
|
message: Annotated[str, Query(min_length=1, max_length=10000)],
|
||||||
user_id: str | None = Depends(auth.get_user_id),
|
user_id: str | None = Depends(auth.get_user_id),
|
||||||
is_user_message: bool = Query(default=True),
|
is_user_message: bool = Query(default=True),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Stream chat responses for a session.
|
Stream chat responses for a session (GET - legacy endpoint).
|
||||||
|
|
||||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||||
- Text fragments as they are generated
|
- Text fragments as they are generated
|
||||||
@@ -193,6 +319,133 @@ async def session_assign_user(
|
|||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Onboarding Routes ==========
|
||||||
|
# These routes use a specialized onboarding system prompt
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/onboarding/sessions",
|
||||||
|
)
|
||||||
|
async def create_onboarding_session(
|
||||||
|
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||||
|
) -> CreateSessionResponse:
|
||||||
|
"""
|
||||||
|
Create a new onboarding chat session.
|
||||||
|
|
||||||
|
Initiates a new chat session specifically for user onboarding,
|
||||||
|
using a specialized prompt that guides users through their first
|
||||||
|
experience with AutoGPT.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The optional authenticated user ID parsed from the JWT.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CreateSessionResponse: Details of the created onboarding session.
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
f"Creating onboarding session with user_id: "
|
||||||
|
f"...{user_id[-8:] if user_id and len(user_id) > 8 else '<redacted>'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
session = await chat_service.create_chat_session(user_id)
|
||||||
|
|
||||||
|
return CreateSessionResponse(
|
||||||
|
id=session.session_id,
|
||||||
|
created_at=session.started_at.isoformat(),
|
||||||
|
user_id=session.user_id or None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/onboarding/sessions/{session_id}",
|
||||||
|
)
|
||||||
|
async def get_onboarding_session(
|
||||||
|
session_id: str,
|
||||||
|
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||||
|
) -> SessionDetailResponse:
|
||||||
|
"""
|
||||||
|
Retrieve the details of an onboarding chat session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The unique identifier for the onboarding session.
|
||||||
|
user_id: The optional authenticated user ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SessionDetailResponse: Details for the requested session.
|
||||||
|
"""
|
||||||
|
session = await chat_service.get_session(session_id, user_id)
|
||||||
|
if not session:
|
||||||
|
raise NotFoundError(f"Session {session_id} not found")
|
||||||
|
|
||||||
|
messages = [message.model_dump() for message in session.messages]
|
||||||
|
logger.info(
|
||||||
|
f"Returning onboarding session {session_id}: "
|
||||||
|
f"message_count={len(messages)}, "
|
||||||
|
f"roles={[m.get('role') for m in messages]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return SessionDetailResponse(
|
||||||
|
id=session.session_id,
|
||||||
|
created_at=session.started_at.isoformat(),
|
||||||
|
updated_at=session.updated_at.isoformat(),
|
||||||
|
user_id=session.user_id or None,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/onboarding/sessions/{session_id}/stream",
|
||||||
|
)
|
||||||
|
async def stream_onboarding_chat(
|
||||||
|
session_id: str,
|
||||||
|
request: StreamChatRequest,
|
||||||
|
user_id: str | None = Depends(auth.get_user_id),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Stream onboarding chat responses for a session.
|
||||||
|
|
||||||
|
Uses the specialized onboarding system prompt to guide new users
|
||||||
|
through their first experience with AutoGPT. Streams AI responses
|
||||||
|
in real time over Server-Sent Events (SSE).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The onboarding session identifier.
|
||||||
|
request: Request body containing message and optional context.
|
||||||
|
user_id: Optional authenticated user ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StreamingResponse: SSE-formatted response chunks.
|
||||||
|
"""
|
||||||
|
session = await chat_service.get_session(session_id, user_id)
|
||||||
|
|
||||||
|
if not session:
|
||||||
|
raise NotFoundError(f"Session {session_id} not found.")
|
||||||
|
if session.user_id is None and user_id is not None:
|
||||||
|
session = await chat_service.assign_user_to_session(session_id, user_id)
|
||||||
|
|
||||||
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
|
async for chunk in chat_service.stream_chat_completion(
|
||||||
|
session_id,
|
||||||
|
request.message,
|
||||||
|
is_user_message=request.is_user_message,
|
||||||
|
user_id=user_id,
|
||||||
|
session=session,
|
||||||
|
context=request.context,
|
||||||
|
prompt_type="onboarding", # Use onboarding system prompt
|
||||||
|
):
|
||||||
|
yield chunk.to_sse()
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_generator(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ========== Health Check ==========
|
# ========== Health Check ==========
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,16 +7,17 @@ import orjson
|
|||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
||||||
|
|
||||||
|
from backend.data.understanding import (
|
||||||
|
format_understanding_for_prompt,
|
||||||
|
get_business_understanding,
|
||||||
|
)
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
|
from . import db as chat_db
|
||||||
from .config import ChatConfig
|
from .config import ChatConfig
|
||||||
from .model import (
|
from .model import ChatMessage, ChatSession, Usage
|
||||||
ChatMessage,
|
from .model import create_chat_session as model_create_chat_session
|
||||||
ChatSession,
|
from .model import get_chat_session, upsert_chat_session
|
||||||
Usage,
|
|
||||||
get_chat_session,
|
|
||||||
upsert_chat_session,
|
|
||||||
)
|
|
||||||
from .response_model import (
|
from .response_model import (
|
||||||
StreamBaseResponse,
|
StreamBaseResponse,
|
||||||
StreamEnd,
|
StreamEnd,
|
||||||
@@ -36,15 +37,109 @@ config = ChatConfig()
|
|||||||
client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||||
|
|
||||||
|
|
||||||
|
async def _is_first_session(user_id: str) -> bool:
|
||||||
|
"""Check if this is the user's first chat session.
|
||||||
|
|
||||||
|
Returns True if the user has 1 or fewer sessions (meaning this is their first).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
session_count = await chat_db.get_user_session_count(user_id)
|
||||||
|
return session_count <= 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to check session count for user {user_id}: {e}")
|
||||||
|
return False # Default to non-onboarding if we can't check
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_system_prompt(
|
||||||
|
user_id: str | None, prompt_type: str = "default"
|
||||||
|
) -> str:
|
||||||
|
"""Build the full system prompt including business understanding if available.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID for fetching business understanding
|
||||||
|
prompt_type: The type of prompt to load ("default" or "onboarding")
|
||||||
|
If "default" and this is the user's first session, will use "onboarding" instead.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The full system prompt with business understanding context if available
|
||||||
|
"""
|
||||||
|
# Auto-detect: if using default prompt and this is user's first session, use onboarding
|
||||||
|
effective_prompt_type = prompt_type
|
||||||
|
if prompt_type == "default" and user_id:
|
||||||
|
if await _is_first_session(user_id):
|
||||||
|
logger.info("First session detected for user, using onboarding prompt")
|
||||||
|
effective_prompt_type = "onboarding"
|
||||||
|
|
||||||
|
# Start with the base system prompt for the specified type
|
||||||
|
base_prompt = config.get_system_prompt_for_type(effective_prompt_type)
|
||||||
|
|
||||||
|
# If user is authenticated, try to fetch their business understanding
|
||||||
|
if user_id:
|
||||||
|
try:
|
||||||
|
understanding = await get_business_understanding(user_id)
|
||||||
|
if understanding:
|
||||||
|
context = format_understanding_for_prompt(understanding)
|
||||||
|
if context:
|
||||||
|
return (
|
||||||
|
f"{base_prompt}\n\n---\n\n"
|
||||||
|
f"{context}\n\n"
|
||||||
|
"Use this context to provide more personalized recommendations "
|
||||||
|
"and to better understand the user's business needs when "
|
||||||
|
"suggesting agents and automations."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch business understanding: {e}")
|
||||||
|
|
||||||
|
return base_prompt
|
||||||
|
|
||||||
|
|
||||||
|
async def _generate_session_title(message: str) -> str | None:
|
||||||
|
"""Generate a concise title for a chat session based on the first message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: The first user message in the session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A short title (3-6 words) or None if generation fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = await client.chat.completions.create(
|
||||||
|
model=config.title_model,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
"Generate a very short title (3-6 words) for a chat conversation "
|
||||||
|
"based on the user's first message. The title should capture the "
|
||||||
|
"main topic or intent. Return ONLY the title, no quotes or punctuation."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{"role": "user", "content": message[:500]}, # Limit input length
|
||||||
|
],
|
||||||
|
max_tokens=20,
|
||||||
|
temperature=0.7,
|
||||||
|
)
|
||||||
|
title = response.choices[0].message.content
|
||||||
|
if title:
|
||||||
|
# Clean up the title
|
||||||
|
title = title.strip().strip("\"'")
|
||||||
|
# Limit length
|
||||||
|
if len(title) > 50:
|
||||||
|
title = title[:47] + "..."
|
||||||
|
return title
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to generate session title: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def create_chat_session(
|
async def create_chat_session(
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
) -> ChatSession:
|
) -> ChatSession:
|
||||||
"""
|
"""
|
||||||
Create a new chat session and persist it to the database.
|
Create a new chat session and persist it to the database.
|
||||||
"""
|
"""
|
||||||
session = ChatSession.new(user_id)
|
return await model_create_chat_session(user_id)
|
||||||
# Persist the session immediately so it can be used for streaming
|
|
||||||
return await upsert_chat_session(session)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_session(
|
async def get_session(
|
||||||
@@ -57,6 +152,19 @@ async def get_session(
|
|||||||
return await get_chat_session(session_id, user_id)
|
return await get_chat_session(session_id, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_sessions(
|
||||||
|
user_id: str,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[ChatSession]:
|
||||||
|
"""
|
||||||
|
Get all chat sessions for a user.
|
||||||
|
"""
|
||||||
|
from .model import get_user_sessions as model_get_user_sessions
|
||||||
|
|
||||||
|
return await model_get_user_sessions(user_id, limit, offset)
|
||||||
|
|
||||||
|
|
||||||
async def assign_user_to_session(
|
async def assign_user_to_session(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
@@ -78,6 +186,8 @@ async def stream_chat_completion(
|
|||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
retry_count: int = 0,
|
retry_count: int = 0,
|
||||||
session: ChatSession | None = None,
|
session: ChatSession | None = None,
|
||||||
|
context: dict[str, str] | None = None, # {url: str, content: str}
|
||||||
|
prompt_type: str = "default",
|
||||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||||
"""Main entry point for streaming chat completions with database handling.
|
"""Main entry point for streaming chat completions with database handling.
|
||||||
|
|
||||||
@@ -89,6 +199,7 @@ async def stream_chat_completion(
|
|||||||
user_message: User's input message
|
user_message: User's input message
|
||||||
user_id: User ID for authentication (None for anonymous)
|
user_id: User ID for authentication (None for anonymous)
|
||||||
session: Optional pre-loaded session object (for recursive calls to avoid Redis refetch)
|
session: Optional pre-loaded session object (for recursive calls to avoid Redis refetch)
|
||||||
|
prompt_type: The type of prompt to use ("default" or "onboarding")
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
StreamBaseResponse objects formatted as SSE
|
StreamBaseResponse objects formatted as SSE
|
||||||
@@ -121,9 +232,18 @@ async def stream_chat_completion(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if message:
|
if message:
|
||||||
|
# Build message content with context if provided
|
||||||
|
message_content = message
|
||||||
|
if context and context.get("url") and context.get("content"):
|
||||||
|
context_text = f"Page URL: {context['url']}\n\nPage Content:\n{context['content']}\n\n---\n\nUser Message: {message}"
|
||||||
|
message_content = context_text
|
||||||
|
logger.info(
|
||||||
|
f"Including page context: URL={context['url']}, content_length={len(context['content'])}"
|
||||||
|
)
|
||||||
|
|
||||||
session.messages.append(
|
session.messages.append(
|
||||||
ChatMessage(
|
ChatMessage(
|
||||||
role="user" if is_user_message else "assistant", content=message
|
role="user" if is_user_message else "assistant", content=message_content
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -141,6 +261,32 @@ async def stream_chat_completion(
|
|||||||
session = await upsert_chat_session(session)
|
session = await upsert_chat_session(session)
|
||||||
assert session, "Session not found"
|
assert session, "Session not found"
|
||||||
|
|
||||||
|
# Generate title for new sessions on first user message (non-blocking)
|
||||||
|
# Check: is_user_message, no title yet, and this is the first user message
|
||||||
|
if is_user_message and message and not session.title:
|
||||||
|
user_messages = [m for m in session.messages if m.role == "user"]
|
||||||
|
if len(user_messages) == 1:
|
||||||
|
# First user message - generate title in background
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def _update_title():
|
||||||
|
try:
|
||||||
|
title = await _generate_session_title(message)
|
||||||
|
if title:
|
||||||
|
session.title = title
|
||||||
|
await upsert_chat_session(session)
|
||||||
|
logger.info(
|
||||||
|
f"Generated title for session {session_id}: {title}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to update session title: {e}")
|
||||||
|
|
||||||
|
# Fire and forget - don't block the chat response
|
||||||
|
asyncio.create_task(_update_title())
|
||||||
|
|
||||||
|
# Build system prompt with business understanding
|
||||||
|
system_prompt = await _build_system_prompt(user_id, prompt_type)
|
||||||
|
|
||||||
assistant_response = ChatMessage(
|
assistant_response = ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="",
|
content="",
|
||||||
@@ -159,6 +305,7 @@ async def stream_chat_completion(
|
|||||||
async for chunk in _stream_chat_chunks(
|
async for chunk in _stream_chat_chunks(
|
||||||
session=session,
|
session=session,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
system_prompt=system_prompt,
|
||||||
):
|
):
|
||||||
|
|
||||||
if isinstance(chunk, StreamTextChunk):
|
if isinstance(chunk, StreamTextChunk):
|
||||||
@@ -279,6 +426,7 @@ async def stream_chat_completion(
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
retry_count=retry_count + 1,
|
retry_count=retry_count + 1,
|
||||||
session=session,
|
session=session,
|
||||||
|
prompt_type=prompt_type,
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
return # Exit after retry to avoid double-saving in finally block
|
return # Exit after retry to avoid double-saving in finally block
|
||||||
@@ -324,6 +472,7 @@ async def stream_chat_completion(
|
|||||||
session_id=session.session_id,
|
session_id=session.session_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session=session, # Pass session object to avoid Redis refetch
|
session=session, # Pass session object to avoid Redis refetch
|
||||||
|
prompt_type=prompt_type,
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
@@ -331,6 +480,7 @@ async def stream_chat_completion(
|
|||||||
async def _stream_chat_chunks(
|
async def _stream_chat_chunks(
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
tools: list[ChatCompletionToolParam],
|
tools: list[ChatCompletionToolParam],
|
||||||
|
system_prompt: str | None = None,
|
||||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||||
"""
|
"""
|
||||||
Pure streaming function for OpenAI chat completions with tool calling.
|
Pure streaming function for OpenAI chat completions with tool calling.
|
||||||
@@ -338,9 +488,9 @@ async def _stream_chat_chunks(
|
|||||||
This function is database-agnostic and focuses only on streaming logic.
|
This function is database-agnostic and focuses only on streaming logic.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: Conversation context as ChatCompletionMessageParam list
|
session: Chat session with conversation history
|
||||||
session_id: Session ID
|
tools: Available tools for the model
|
||||||
user_id: User ID for tool execution
|
system_prompt: System prompt to prepend to messages
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
SSE formatted JSON response objects
|
SSE formatted JSON response objects
|
||||||
@@ -350,6 +500,17 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
logger.info("Starting pure chat stream")
|
logger.info("Starting pure chat stream")
|
||||||
|
|
||||||
|
# Build messages with system prompt prepended
|
||||||
|
messages = session.to_openai_messages()
|
||||||
|
if system_prompt:
|
||||||
|
from openai.types.chat import ChatCompletionSystemMessageParam
|
||||||
|
|
||||||
|
system_message = ChatCompletionSystemMessageParam(
|
||||||
|
role="system",
|
||||||
|
content=system_prompt,
|
||||||
|
)
|
||||||
|
messages = [system_message] + messages
|
||||||
|
|
||||||
# Loop to handle tool calls and continue conversation
|
# Loop to handle tool calls and continue conversation
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -358,7 +519,7 @@ async def _stream_chat_chunks(
|
|||||||
# Create the stream with proper types
|
# Create the stream with proper types
|
||||||
stream = await client.chat.completions.create(
|
stream = await client.chat.completions.create(
|
||||||
model=model,
|
model=model,
|
||||||
messages=session.to_openai_messages(),
|
messages=messages,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
stream=True,
|
stream=True,
|
||||||
@@ -502,8 +663,12 @@ async def _yield_tool_call(
|
|||||||
"""
|
"""
|
||||||
logger.info(f"Yielding tool call: {tool_calls[yield_idx]}")
|
logger.info(f"Yielding tool call: {tool_calls[yield_idx]}")
|
||||||
|
|
||||||
# Parse tool call arguments - exceptions will propagate to caller
|
# Parse tool call arguments - handle empty arguments gracefully
|
||||||
arguments = orjson.loads(tool_calls[yield_idx]["function"]["arguments"])
|
raw_arguments = tool_calls[yield_idx]["function"]["arguments"]
|
||||||
|
if raw_arguments:
|
||||||
|
arguments = orjson.loads(raw_arguments)
|
||||||
|
else:
|
||||||
|
arguments = {}
|
||||||
|
|
||||||
yield StreamToolCall(
|
yield StreamToolCall(
|
||||||
tool_id=tool_calls[yield_idx]["id"],
|
tool_id=tool_calls[yield_idx]["id"],
|
||||||
|
|||||||
@@ -4,21 +4,30 @@ from openai.types.chat import ChatCompletionToolParam
|
|||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
|
||||||
|
from .add_understanding import AddUnderstandingTool
|
||||||
|
from .agent_output import AgentOutputTool
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
from .find_agent import FindAgentTool
|
from .find_agent import FindAgentTool
|
||||||
|
from .find_library_agent import FindLibraryAgentTool
|
||||||
from .run_agent import RunAgentTool
|
from .run_agent import RunAgentTool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from backend.api.features.chat.response_model import StreamToolExecutionResult
|
from backend.api.features.chat.response_model import StreamToolExecutionResult
|
||||||
|
|
||||||
# Initialize tool instances
|
# Initialize tool instances
|
||||||
|
add_understanding_tool = AddUnderstandingTool()
|
||||||
find_agent_tool = FindAgentTool()
|
find_agent_tool = FindAgentTool()
|
||||||
|
find_library_agent_tool = FindLibraryAgentTool()
|
||||||
run_agent_tool = RunAgentTool()
|
run_agent_tool = RunAgentTool()
|
||||||
|
agent_output_tool = AgentOutputTool()
|
||||||
|
|
||||||
# Export tools as OpenAI format
|
# Export tools as OpenAI format
|
||||||
tools: list[ChatCompletionToolParam] = [
|
tools: list[ChatCompletionToolParam] = [
|
||||||
|
add_understanding_tool.as_openai_tool(),
|
||||||
find_agent_tool.as_openai_tool(),
|
find_agent_tool.as_openai_tool(),
|
||||||
|
find_library_agent_tool.as_openai_tool(),
|
||||||
run_agent_tool.as_openai_tool(),
|
run_agent_tool.as_openai_tool(),
|
||||||
|
agent_output_tool.as_openai_tool(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -31,8 +40,11 @@ async def execute_tool(
|
|||||||
) -> "StreamToolExecutionResult":
|
) -> "StreamToolExecutionResult":
|
||||||
|
|
||||||
tool_map: dict[str, BaseTool] = {
|
tool_map: dict[str, BaseTool] = {
|
||||||
|
"add_understanding": add_understanding_tool,
|
||||||
"find_agent": find_agent_tool,
|
"find_agent": find_agent_tool,
|
||||||
|
"find_library_agent": find_library_agent_tool,
|
||||||
"run_agent": run_agent_tool,
|
"run_agent": run_agent_tool,
|
||||||
|
"agent_output": agent_output_tool,
|
||||||
}
|
}
|
||||||
if tool_name not in tool_map:
|
if tool_name not in tool_map:
|
||||||
raise ValueError(f"Tool {tool_name} not found")
|
raise ValueError(f"Tool {tool_name} not found")
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from datetime import UTC, datetime
|
|||||||
from os import getenv
|
from os import getenv
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from prisma.types import ProfileCreateInput
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
@@ -49,13 +50,13 @@ async def setup_test_data():
|
|||||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||||
username = user.email.split("@")[0]
|
username = user.email.split("@")[0]
|
||||||
await prisma.profile.create(
|
await prisma.profile.create(
|
||||||
data={
|
data=ProfileCreateInput(
|
||||||
"userId": user.id,
|
userId=user.id,
|
||||||
"username": username,
|
username=username,
|
||||||
"name": f"Test User {username}",
|
name=f"Test User {username}",
|
||||||
"description": "Test user profile",
|
description="Test user profile",
|
||||||
"links": [], # Required field - empty array for test profiles
|
links=[], # Required field - empty array for test profiles
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Create a test graph with agent input -> agent output
|
# 2. Create a test graph with agent input -> agent output
|
||||||
@@ -172,13 +173,13 @@ async def setup_llm_test_data():
|
|||||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||||
username = user.email.split("@")[0]
|
username = user.email.split("@")[0]
|
||||||
await prisma.profile.create(
|
await prisma.profile.create(
|
||||||
data={
|
data=ProfileCreateInput(
|
||||||
"userId": user.id,
|
userId=user.id,
|
||||||
"username": username,
|
username=username,
|
||||||
"name": f"Test User {username}",
|
name=f"Test User {username}",
|
||||||
"description": "Test user profile for LLM tests",
|
description="Test user profile for LLM tests",
|
||||||
"links": [], # Required field - empty array for test profiles
|
links=[], # Required field - empty array for test profiles
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Create test OpenAI credentials for the user
|
# 2. Create test OpenAI credentials for the user
|
||||||
@@ -332,13 +333,13 @@ async def setup_firecrawl_test_data():
|
|||||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||||
username = user.email.split("@")[0]
|
username = user.email.split("@")[0]
|
||||||
await prisma.profile.create(
|
await prisma.profile.create(
|
||||||
data={
|
data=ProfileCreateInput(
|
||||||
"userId": user.id,
|
userId=user.id,
|
||||||
"username": username,
|
username=username,
|
||||||
"name": f"Test User {username}",
|
name=f"Test User {username}",
|
||||||
"description": "Test user profile for Firecrawl tests",
|
description="Test user profile for Firecrawl tests",
|
||||||
"links": [], # Required field - empty array for test profiles
|
links=[], # Required field - empty array for test profiles
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE: We deliberately do NOT create Firecrawl credentials for this user
|
# NOTE: We deliberately do NOT create Firecrawl credentials for this user
|
||||||
|
|||||||
@@ -0,0 +1,202 @@
|
|||||||
|
"""Tool for capturing user business understanding incrementally."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.data.understanding import (
|
||||||
|
BusinessUnderstandingInput,
|
||||||
|
upsert_business_understanding,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .base import BaseTool
|
||||||
|
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AddUnderstandingTool(BaseTool):
|
||||||
|
"""Tool for capturing user's business understanding incrementally."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "add_understanding"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """Capture and store information about the user's business context,
|
||||||
|
workflows, pain points, and automation goals. Call this tool whenever the user
|
||||||
|
shares information about their business. Each call incrementally adds to the
|
||||||
|
existing understanding - you don't need to provide all fields at once.
|
||||||
|
|
||||||
|
Use this to build a comprehensive profile that helps recommend better agents
|
||||||
|
and automations for the user's specific needs."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"user_name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The user's name",
|
||||||
|
},
|
||||||
|
"job_title": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The user's job title (e.g., 'Marketing Manager', 'CEO', 'Software Engineer')",
|
||||||
|
},
|
||||||
|
"business_name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Name of the user's business or organization",
|
||||||
|
},
|
||||||
|
"industry": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Industry or sector (e.g., 'e-commerce', 'healthcare', 'finance')",
|
||||||
|
},
|
||||||
|
"business_size": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Company size: '1-10', '11-50', '51-200', '201-1000', or '1000+'",
|
||||||
|
},
|
||||||
|
"user_role": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "User's role in organization context (e.g., 'decision maker', 'implementer', 'end user')",
|
||||||
|
},
|
||||||
|
"key_workflows": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "Key business workflows (e.g., 'lead qualification', 'content publishing')",
|
||||||
|
},
|
||||||
|
"daily_activities": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "Regular daily activities the user performs",
|
||||||
|
},
|
||||||
|
"pain_points": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "Current pain points or challenges",
|
||||||
|
},
|
||||||
|
"bottlenecks": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "Process bottlenecks slowing things down",
|
||||||
|
},
|
||||||
|
"manual_tasks": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "Manual or repetitive tasks that could be automated",
|
||||||
|
},
|
||||||
|
"automation_goals": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "Desired automation outcomes or goals",
|
||||||
|
},
|
||||||
|
"current_software": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "Software and tools currently in use",
|
||||||
|
},
|
||||||
|
"existing_automation": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "Any existing automations or integrations",
|
||||||
|
},
|
||||||
|
"additional_notes": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Any other relevant context or notes",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
"""Requires authentication to store user-specific data."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
"""
|
||||||
|
Capture and store business understanding incrementally.
|
||||||
|
|
||||||
|
Each call merges new data with existing understanding:
|
||||||
|
- String fields are overwritten if provided
|
||||||
|
- List fields are appended (with deduplication)
|
||||||
|
"""
|
||||||
|
session_id = session.session_id
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Authentication required to save business understanding.",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if any data was provided
|
||||||
|
if not any(v is not None for v in kwargs.values()):
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide at least one field to update.",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build input model
|
||||||
|
input_data = BusinessUnderstandingInput(
|
||||||
|
user_name=kwargs.get("user_name"),
|
||||||
|
job_title=kwargs.get("job_title"),
|
||||||
|
business_name=kwargs.get("business_name"),
|
||||||
|
industry=kwargs.get("industry"),
|
||||||
|
business_size=kwargs.get("business_size"),
|
||||||
|
user_role=kwargs.get("user_role"),
|
||||||
|
key_workflows=kwargs.get("key_workflows"),
|
||||||
|
daily_activities=kwargs.get("daily_activities"),
|
||||||
|
pain_points=kwargs.get("pain_points"),
|
||||||
|
bottlenecks=kwargs.get("bottlenecks"),
|
||||||
|
manual_tasks=kwargs.get("manual_tasks"),
|
||||||
|
automation_goals=kwargs.get("automation_goals"),
|
||||||
|
current_software=kwargs.get("current_software"),
|
||||||
|
existing_automation=kwargs.get("existing_automation"),
|
||||||
|
additional_notes=kwargs.get("additional_notes"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Track which fields were updated
|
||||||
|
updated_fields = [k for k, v in kwargs.items() if v is not None]
|
||||||
|
|
||||||
|
# Upsert with merge
|
||||||
|
understanding = await upsert_business_understanding(user_id, input_data)
|
||||||
|
|
||||||
|
# Build current understanding summary for the response
|
||||||
|
current_understanding = {
|
||||||
|
"user_name": understanding.user_name,
|
||||||
|
"job_title": understanding.job_title,
|
||||||
|
"business_name": understanding.business_name,
|
||||||
|
"industry": understanding.industry,
|
||||||
|
"business_size": understanding.business_size,
|
||||||
|
"user_role": understanding.user_role,
|
||||||
|
"key_workflows": understanding.key_workflows,
|
||||||
|
"daily_activities": understanding.daily_activities,
|
||||||
|
"pain_points": understanding.pain_points,
|
||||||
|
"bottlenecks": understanding.bottlenecks,
|
||||||
|
"manual_tasks": understanding.manual_tasks,
|
||||||
|
"automation_goals": understanding.automation_goals,
|
||||||
|
"current_software": understanding.current_software,
|
||||||
|
"existing_automation": understanding.existing_automation,
|
||||||
|
"additional_notes": understanding.additional_notes,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Filter out empty values for cleaner response
|
||||||
|
current_understanding = {
|
||||||
|
k: v
|
||||||
|
for k, v in current_understanding.items()
|
||||||
|
if v is not None and v != [] and v != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return UnderstandingUpdatedResponse(
|
||||||
|
message=f"Updated understanding with: {', '.join(updated_fields)}. "
|
||||||
|
"I now have a better picture of your business context.",
|
||||||
|
session_id=session_id,
|
||||||
|
updated_fields=updated_fields,
|
||||||
|
current_understanding=current_understanding,
|
||||||
|
)
|
||||||
@@ -0,0 +1,455 @@
|
|||||||
|
"""Tool for retrieving agent execution outputs from user's library."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.library import db as library_db
|
||||||
|
from backend.api.features.library.model import LibraryAgent
|
||||||
|
from backend.data import execution as execution_db
|
||||||
|
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
|
||||||
|
|
||||||
|
from .base import BaseTool
|
||||||
|
from .models import (
|
||||||
|
AgentOutputResponse,
|
||||||
|
ErrorResponse,
|
||||||
|
ExecutionOutputInfo,
|
||||||
|
NoResultsResponse,
|
||||||
|
ToolResponseBase,
|
||||||
|
)
|
||||||
|
from .utils import fetch_graph_from_store_slug
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentOutputInput(BaseModel):
|
||||||
|
"""Input parameters for the agent_output tool."""
|
||||||
|
|
||||||
|
agent_name: str = ""
|
||||||
|
library_agent_id: str = ""
|
||||||
|
store_slug: str = ""
|
||||||
|
execution_id: str = ""
|
||||||
|
run_time: str = "latest"
|
||||||
|
|
||||||
|
@field_validator(
|
||||||
|
"agent_name",
|
||||||
|
"library_agent_id",
|
||||||
|
"store_slug",
|
||||||
|
"execution_id",
|
||||||
|
"run_time",
|
||||||
|
mode="before",
|
||||||
|
)
|
||||||
|
@classmethod
|
||||||
|
def strip_strings(cls, v: Any) -> Any:
|
||||||
|
"""Strip whitespace from string fields."""
|
||||||
|
return v.strip() if isinstance(v, str) else v
|
||||||
|
|
||||||
|
|
||||||
|
def parse_time_expression(
|
||||||
|
time_expr: str | None,
|
||||||
|
) -> tuple[datetime | None, datetime | None]:
|
||||||
|
"""
|
||||||
|
Parse time expression into datetime range (start, end).
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- "latest" or None -> returns (None, None) to get most recent
|
||||||
|
- "yesterday" -> 24h window for yesterday
|
||||||
|
- "today" -> Today from midnight
|
||||||
|
- "last week" / "last 7 days" -> 7 day window
|
||||||
|
- "last month" / "last 30 days" -> 30 day window
|
||||||
|
- ISO date "YYYY-MM-DD" -> 24h window for that date
|
||||||
|
"""
|
||||||
|
if not time_expr or time_expr.lower() == "latest":
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
expr = time_expr.lower().strip()
|
||||||
|
|
||||||
|
# Relative expressions
|
||||||
|
if expr == "yesterday":
|
||||||
|
end = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
start = end - timedelta(days=1)
|
||||||
|
return start, end
|
||||||
|
|
||||||
|
if expr in ("last week", "last 7 days"):
|
||||||
|
return now - timedelta(days=7), now
|
||||||
|
|
||||||
|
if expr in ("last month", "last 30 days"):
|
||||||
|
return now - timedelta(days=30), now
|
||||||
|
|
||||||
|
if expr == "today":
|
||||||
|
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
return start, now
|
||||||
|
|
||||||
|
# Try ISO date format (YYYY-MM-DD)
|
||||||
|
date_match = re.match(r"^(\d{4})-(\d{2})-(\d{2})$", expr)
|
||||||
|
if date_match:
|
||||||
|
year, month, day = map(int, date_match.groups())
|
||||||
|
start = datetime(year, month, day, 0, 0, 0, tzinfo=timezone.utc)
|
||||||
|
end = start + timedelta(days=1)
|
||||||
|
return start, end
|
||||||
|
|
||||||
|
# Try ISO datetime
|
||||||
|
try:
|
||||||
|
parsed = datetime.fromisoformat(expr.replace("Z", "+00:00"))
|
||||||
|
if parsed.tzinfo is None:
|
||||||
|
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||||
|
# Return +/- 1 hour window around the specified time
|
||||||
|
return parsed - timedelta(hours=1), parsed + timedelta(hours=1)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fallback: treat as "latest"
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentOutputTool(BaseTool):
|
||||||
|
"""Tool for retrieving execution outputs from user's library agents."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "agent_output"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """Retrieve execution outputs from agents in the user's library.
|
||||||
|
|
||||||
|
Identify the agent using one of:
|
||||||
|
- agent_name: Fuzzy search in user's library
|
||||||
|
- library_agent_id: Exact library agent ID
|
||||||
|
- store_slug: Marketplace format 'username/agent-name'
|
||||||
|
|
||||||
|
Select which run to retrieve using:
|
||||||
|
- execution_id: Specific execution ID
|
||||||
|
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"agent_name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Agent name to search for in user's library (fuzzy match)",
|
||||||
|
},
|
||||||
|
"library_agent_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Exact library agent ID",
|
||||||
|
},
|
||||||
|
"store_slug": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Marketplace identifier: 'username/agent-slug'",
|
||||||
|
},
|
||||||
|
"execution_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Specific execution ID to retrieve",
|
||||||
|
},
|
||||||
|
"run_time": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _resolve_agent(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
agent_name: str | None,
|
||||||
|
library_agent_id: str | None,
|
||||||
|
store_slug: str | None,
|
||||||
|
) -> tuple[LibraryAgent | None, str | None]:
|
||||||
|
"""
|
||||||
|
Resolve agent from provided identifiers.
|
||||||
|
Returns (library_agent, error_message).
|
||||||
|
"""
|
||||||
|
# Priority 1: Exact library agent ID
|
||||||
|
if library_agent_id:
|
||||||
|
try:
|
||||||
|
agent = await library_db.get_library_agent(library_agent_id, user_id)
|
||||||
|
return agent, None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get library agent by ID: {e}")
|
||||||
|
return None, f"Library agent '{library_agent_id}' not found"
|
||||||
|
|
||||||
|
# Priority 2: Store slug (username/agent-name)
|
||||||
|
if store_slug and "/" in store_slug:
|
||||||
|
username, agent_slug = store_slug.split("/", 1)
|
||||||
|
graph, _ = await fetch_graph_from_store_slug(username, agent_slug)
|
||||||
|
if not graph:
|
||||||
|
return None, f"Agent '{store_slug}' not found in marketplace"
|
||||||
|
|
||||||
|
# Find in user's library by graph_id
|
||||||
|
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
|
||||||
|
if not agent:
|
||||||
|
return (
|
||||||
|
None,
|
||||||
|
f"Agent '{store_slug}' is not in your library. "
|
||||||
|
"Add it first to see outputs.",
|
||||||
|
)
|
||||||
|
return agent, None
|
||||||
|
|
||||||
|
# Priority 3: Fuzzy name search in library
|
||||||
|
if agent_name:
|
||||||
|
try:
|
||||||
|
response = await library_db.list_library_agents(
|
||||||
|
user_id=user_id,
|
||||||
|
search_term=agent_name,
|
||||||
|
page_size=5,
|
||||||
|
)
|
||||||
|
if not response.agents:
|
||||||
|
return (
|
||||||
|
None,
|
||||||
|
f"No agents matching '{agent_name}' found in your library",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return best match (first result from search)
|
||||||
|
return response.agents[0], None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error searching library agents: {e}")
|
||||||
|
return None, f"Error searching for agent: {e}"
|
||||||
|
|
||||||
|
return (
|
||||||
|
None,
|
||||||
|
"Please specify an agent name, library_agent_id, or store_slug",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get_execution(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
graph_id: str,
|
||||||
|
execution_id: str | None,
|
||||||
|
time_start: datetime | None,
|
||||||
|
time_end: datetime | None,
|
||||||
|
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
|
||||||
|
"""
|
||||||
|
Fetch execution(s) based on filters.
|
||||||
|
Returns (single_execution, available_executions_meta, error_message).
|
||||||
|
"""
|
||||||
|
# If specific execution_id provided, fetch it directly
|
||||||
|
if execution_id:
|
||||||
|
execution = await execution_db.get_graph_execution(
|
||||||
|
user_id=user_id,
|
||||||
|
execution_id=execution_id,
|
||||||
|
include_node_executions=False,
|
||||||
|
)
|
||||||
|
if not execution:
|
||||||
|
return None, [], f"Execution '{execution_id}' not found"
|
||||||
|
return execution, [], None
|
||||||
|
|
||||||
|
# Get completed executions with time filters
|
||||||
|
executions = await execution_db.get_graph_executions(
|
||||||
|
graph_id=graph_id,
|
||||||
|
user_id=user_id,
|
||||||
|
statuses=[ExecutionStatus.COMPLETED],
|
||||||
|
created_time_gte=time_start,
|
||||||
|
created_time_lte=time_end,
|
||||||
|
limit=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not executions:
|
||||||
|
return None, [], None # No error, just no executions
|
||||||
|
|
||||||
|
# If only one execution, fetch full details
|
||||||
|
if len(executions) == 1:
|
||||||
|
full_execution = await execution_db.get_graph_execution(
|
||||||
|
user_id=user_id,
|
||||||
|
execution_id=executions[0].id,
|
||||||
|
include_node_executions=False,
|
||||||
|
)
|
||||||
|
return full_execution, [], None
|
||||||
|
|
||||||
|
# Multiple executions - return latest with full details, plus list of available
|
||||||
|
full_execution = await execution_db.get_graph_execution(
|
||||||
|
user_id=user_id,
|
||||||
|
execution_id=executions[0].id,
|
||||||
|
include_node_executions=False,
|
||||||
|
)
|
||||||
|
return full_execution, executions, None
|
||||||
|
|
||||||
|
def _build_response(
|
||||||
|
self,
|
||||||
|
agent: LibraryAgent,
|
||||||
|
execution: GraphExecution | None,
|
||||||
|
available_executions: list[GraphExecutionMeta],
|
||||||
|
session_id: str | None,
|
||||||
|
) -> AgentOutputResponse:
|
||||||
|
"""Build the response based on execution data."""
|
||||||
|
library_agent_link = f"/library/agents/{agent.id}"
|
||||||
|
|
||||||
|
if not execution:
|
||||||
|
return AgentOutputResponse(
|
||||||
|
message=f"No completed executions found for agent '{agent.name}'",
|
||||||
|
session_id=session_id,
|
||||||
|
agent_name=agent.name,
|
||||||
|
agent_id=agent.graph_id,
|
||||||
|
library_agent_id=agent.id,
|
||||||
|
library_agent_link=library_agent_link,
|
||||||
|
total_executions=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
execution_info = ExecutionOutputInfo(
|
||||||
|
execution_id=execution.id,
|
||||||
|
status=execution.status.value,
|
||||||
|
started_at=execution.started_at,
|
||||||
|
ended_at=execution.ended_at,
|
||||||
|
outputs=dict(execution.outputs),
|
||||||
|
inputs_summary=execution.inputs if execution.inputs else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
available_list = None
|
||||||
|
if len(available_executions) > 1:
|
||||||
|
available_list = [
|
||||||
|
{
|
||||||
|
"id": e.id,
|
||||||
|
"status": e.status.value,
|
||||||
|
"started_at": e.started_at.isoformat() if e.started_at else None,
|
||||||
|
}
|
||||||
|
for e in available_executions[:5]
|
||||||
|
]
|
||||||
|
|
||||||
|
message = f"Found execution outputs for agent '{agent.name}'"
|
||||||
|
if len(available_executions) > 1:
|
||||||
|
message += (
|
||||||
|
f". Showing latest of {len(available_executions)} matching executions."
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentOutputResponse(
|
||||||
|
message=message,
|
||||||
|
session_id=session_id,
|
||||||
|
agent_name=agent.name,
|
||||||
|
agent_id=agent.graph_id,
|
||||||
|
library_agent_id=agent.id,
|
||||||
|
library_agent_link=library_agent_link,
|
||||||
|
execution=execution_info,
|
||||||
|
available_executions=available_list,
|
||||||
|
total_executions=len(available_executions) if available_executions else 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
"""Execute the agent_output tool."""
|
||||||
|
session_id = session.session_id
|
||||||
|
|
||||||
|
# Parse and validate input
|
||||||
|
try:
|
||||||
|
input_data = AgentOutputInput(**kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Invalid input: {e}")
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Invalid input parameters",
|
||||||
|
error=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure user_id is present (should be guaranteed by requires_auth)
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="User authentication required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if at least one identifier is provided
|
||||||
|
if not any(
|
||||||
|
[
|
||||||
|
input_data.agent_name,
|
||||||
|
input_data.library_agent_id,
|
||||||
|
input_data.store_slug,
|
||||||
|
input_data.execution_id,
|
||||||
|
]
|
||||||
|
):
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
"Please specify at least one of: agent_name, "
|
||||||
|
"library_agent_id, store_slug, or execution_id"
|
||||||
|
),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# If only execution_id provided, we need to find the agent differently
|
||||||
|
if (
|
||||||
|
input_data.execution_id
|
||||||
|
and not input_data.agent_name
|
||||||
|
and not input_data.library_agent_id
|
||||||
|
and not input_data.store_slug
|
||||||
|
):
|
||||||
|
# Fetch execution directly to get graph_id
|
||||||
|
execution = await execution_db.get_graph_execution(
|
||||||
|
user_id=user_id,
|
||||||
|
execution_id=input_data.execution_id,
|
||||||
|
include_node_executions=False,
|
||||||
|
)
|
||||||
|
if not execution:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Execution '{input_data.execution_id}' not found",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find library agent by graph_id
|
||||||
|
agent = await library_db.get_library_agent_by_graph_id(
|
||||||
|
user_id, execution.graph_id
|
||||||
|
)
|
||||||
|
if not agent:
|
||||||
|
return NoResultsResponse(
|
||||||
|
message=(
|
||||||
|
f"Execution found but agent not in your library. "
|
||||||
|
f"Graph ID: {execution.graph_id}"
|
||||||
|
),
|
||||||
|
session_id=session_id,
|
||||||
|
suggestions=["Add the agent to your library to see more details"],
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._build_response(agent, execution, [], session_id)
|
||||||
|
|
||||||
|
# Resolve agent from identifiers
|
||||||
|
agent, error = await self._resolve_agent(
|
||||||
|
user_id=user_id,
|
||||||
|
agent_name=input_data.agent_name or None,
|
||||||
|
library_agent_id=input_data.library_agent_id or None,
|
||||||
|
store_slug=input_data.store_slug or None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if error or not agent:
|
||||||
|
return NoResultsResponse(
|
||||||
|
message=error or "Agent not found",
|
||||||
|
session_id=session_id,
|
||||||
|
suggestions=[
|
||||||
|
"Check the agent name or ID",
|
||||||
|
"Make sure the agent is in your library",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse time expression
|
||||||
|
time_start, time_end = parse_time_expression(input_data.run_time)
|
||||||
|
|
||||||
|
# Fetch execution(s)
|
||||||
|
execution, available_executions, exec_error = await self._get_execution(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_id=agent.graph_id,
|
||||||
|
execution_id=input_data.execution_id or None,
|
||||||
|
time_start=time_start,
|
||||||
|
time_end=time_end,
|
||||||
|
)
|
||||||
|
|
||||||
|
if exec_error:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=exec_error,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._build_response(agent, execution, available_executions, session_id)
|
||||||
@@ -0,0 +1,157 @@
|
|||||||
|
"""Tool for searching agents in the user's library."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.library import db as library_db
|
||||||
|
from backend.util.exceptions import DatabaseError
|
||||||
|
|
||||||
|
from .base import BaseTool
|
||||||
|
from .models import (
|
||||||
|
AgentCarouselResponse,
|
||||||
|
AgentInfo,
|
||||||
|
ErrorResponse,
|
||||||
|
NoResultsResponse,
|
||||||
|
ToolResponseBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FindLibraryAgentTool(BaseTool):
|
||||||
|
"""Tool for searching agents in the user's library."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "find_library_agent"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Search for agents in the user's library. Use this to find agents "
|
||||||
|
"the user has already added to their library, including agents they "
|
||||||
|
"created or added from the marketplace."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Search query to find agents by name or description. "
|
||||||
|
"Use keywords for best results."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
"""Search for agents in the user's library.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID (required)
|
||||||
|
session: Chat session
|
||||||
|
query: Search query
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentCarouselResponse: List of agents found in the library
|
||||||
|
NoResultsResponse: No agents found
|
||||||
|
ErrorResponse: Error message
|
||||||
|
"""
|
||||||
|
query = kwargs.get("query", "").strip()
|
||||||
|
session_id = session.session_id
|
||||||
|
|
||||||
|
if not query:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide a search query",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="User authentication required to search library",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
agents = []
|
||||||
|
try:
|
||||||
|
logger.info(f"Searching user library for: {query}")
|
||||||
|
library_results = await library_db.list_library_agents(
|
||||||
|
user_id=user_id,
|
||||||
|
search_term=query,
|
||||||
|
page_size=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Find library agents tool found {len(library_results.agents)} agents"
|
||||||
|
)
|
||||||
|
|
||||||
|
for agent in library_results.agents:
|
||||||
|
agents.append(
|
||||||
|
AgentInfo(
|
||||||
|
id=agent.id,
|
||||||
|
name=agent.name,
|
||||||
|
description=agent.description or "",
|
||||||
|
source="library",
|
||||||
|
in_library=True,
|
||||||
|
creator=agent.creator_name,
|
||||||
|
status=agent.status.value,
|
||||||
|
can_access_graph=agent.can_access_graph,
|
||||||
|
has_external_trigger=agent.has_external_trigger,
|
||||||
|
new_output=agent.new_output,
|
||||||
|
graph_id=agent.graph_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
except DatabaseError as e:
|
||||||
|
logger.error(f"Error searching library agents: {e}", exc_info=True)
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Failed to search library. Please try again.",
|
||||||
|
error=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not agents:
|
||||||
|
return NoResultsResponse(
|
||||||
|
message=(
|
||||||
|
f"No agents found matching '{query}' in your library. "
|
||||||
|
"Try different keywords or use find_agent to search the marketplace."
|
||||||
|
),
|
||||||
|
session_id=session_id,
|
||||||
|
suggestions=[
|
||||||
|
"Try more general terms",
|
||||||
|
"Use find_agent to search the marketplace",
|
||||||
|
"Check your library at /library",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
title = (
|
||||||
|
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
|
||||||
|
f"in your library for '{query}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentCarouselResponse(
|
||||||
|
message=(
|
||||||
|
"Found agents in the user's library. You can provide a link to "
|
||||||
|
"view an agent at: /library/agents/{agent_id}. "
|
||||||
|
"Use agent_output to get execution results, or run_agent to execute."
|
||||||
|
),
|
||||||
|
title=title,
|
||||||
|
agents=agents,
|
||||||
|
count=len(agents),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Pydantic models for tool responses."""
|
"""Pydantic models for tool responses."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -19,6 +20,15 @@ class ResponseType(str, Enum):
|
|||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
NO_RESULTS = "no_results"
|
NO_RESULTS = "no_results"
|
||||||
SUCCESS = "success"
|
SUCCESS = "success"
|
||||||
|
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||||
|
AGENT_OUTPUT = "agent_output"
|
||||||
|
BLOCK_LIST = "block_list"
|
||||||
|
BLOCK_OUTPUT = "block_output"
|
||||||
|
UNDERSTANDING_UPDATED = "understanding_updated"
|
||||||
|
# Agent generation responses
|
||||||
|
AGENT_PREVIEW = "agent_preview"
|
||||||
|
AGENT_SAVED = "agent_saved"
|
||||||
|
CLARIFICATION_NEEDED = "clarification_needed"
|
||||||
|
|
||||||
|
|
||||||
# Base response model
|
# Base response model
|
||||||
@@ -173,3 +183,128 @@ class ErrorResponse(ToolResponseBase):
|
|||||||
type: ResponseType = ResponseType.ERROR
|
type: ResponseType = ResponseType.ERROR
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
details: dict[str, Any] | None = None
|
details: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# Documentation search models
|
||||||
|
class DocSearchResult(BaseModel):
|
||||||
|
"""A single documentation search result."""
|
||||||
|
|
||||||
|
title: str
|
||||||
|
path: str
|
||||||
|
section: str
|
||||||
|
snippet: str # Short excerpt for UI display
|
||||||
|
content: str # Full text content for LLM to read and understand
|
||||||
|
score: float
|
||||||
|
doc_url: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DocSearchResultsResponse(ToolResponseBase):
|
||||||
|
"""Response for search_docs tool."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.DOC_SEARCH_RESULTS
|
||||||
|
results: list[DocSearchResult]
|
||||||
|
count: int
|
||||||
|
query: str
|
||||||
|
|
||||||
|
|
||||||
|
# Agent output models
|
||||||
|
class ExecutionOutputInfo(BaseModel):
|
||||||
|
"""Summary of a single execution's outputs."""
|
||||||
|
|
||||||
|
execution_id: str
|
||||||
|
status: str
|
||||||
|
started_at: datetime | None = None
|
||||||
|
ended_at: datetime | None = None
|
||||||
|
outputs: dict[str, list[Any]]
|
||||||
|
inputs_summary: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentOutputResponse(ToolResponseBase):
|
||||||
|
"""Response for agent_output tool."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.AGENT_OUTPUT
|
||||||
|
agent_name: str
|
||||||
|
agent_id: str
|
||||||
|
library_agent_id: str | None = None
|
||||||
|
library_agent_link: str | None = None
|
||||||
|
execution: ExecutionOutputInfo | None = None
|
||||||
|
available_executions: list[dict[str, Any]] | None = None
|
||||||
|
total_executions: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
# Block models
|
||||||
|
class BlockInfoSummary(BaseModel):
|
||||||
|
"""Summary of a block for search results."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
categories: list[str]
|
||||||
|
input_schema: dict[str, Any]
|
||||||
|
output_schema: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class BlockListResponse(ToolResponseBase):
|
||||||
|
"""Response for find_block tool."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.BLOCK_LIST
|
||||||
|
blocks: list[BlockInfoSummary]
|
||||||
|
count: int
|
||||||
|
query: str
|
||||||
|
|
||||||
|
|
||||||
|
class BlockOutputResponse(ToolResponseBase):
|
||||||
|
"""Response for run_block tool."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.BLOCK_OUTPUT
|
||||||
|
block_id: str
|
||||||
|
block_name: str
|
||||||
|
outputs: dict[str, list[Any]]
|
||||||
|
success: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
# Business understanding models
|
||||||
|
class UnderstandingUpdatedResponse(ToolResponseBase):
|
||||||
|
"""Response for add_understanding tool."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.UNDERSTANDING_UPDATED
|
||||||
|
updated_fields: list[str] = Field(default_factory=list)
|
||||||
|
current_understanding: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
# Agent generation models
|
||||||
|
class ClarifyingQuestion(BaseModel):
|
||||||
|
"""A question that needs user clarification."""
|
||||||
|
|
||||||
|
question: str
|
||||||
|
keyword: str
|
||||||
|
example: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentPreviewResponse(ToolResponseBase):
|
||||||
|
"""Response for previewing a generated agent before saving."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.AGENT_PREVIEW
|
||||||
|
agent_json: dict[str, Any]
|
||||||
|
agent_name: str
|
||||||
|
description: str
|
||||||
|
node_count: int
|
||||||
|
link_count: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class AgentSavedResponse(ToolResponseBase):
|
||||||
|
"""Response when an agent is saved to the library."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.AGENT_SAVED
|
||||||
|
agent_id: str
|
||||||
|
agent_name: str
|
||||||
|
library_agent_id: str
|
||||||
|
library_agent_link: str
|
||||||
|
agent_page_link: str # Link to the agent builder/editor page
|
||||||
|
|
||||||
|
|
||||||
|
class ClarificationNeededResponse(ToolResponseBase):
|
||||||
|
"""Response when the LLM needs more information from the user."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.CLARIFICATION_NEEDED
|
||||||
|
questions: list[ClarifyingQuestion] = Field(default_factory=list)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator
|
|||||||
|
|
||||||
from backend.api.features.chat.config import ChatConfig
|
from backend.api.features.chat.config import ChatConfig
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.library import db as library_db
|
||||||
from backend.data.graph import GraphModel
|
from backend.data.graph import GraphModel
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
from backend.data.user import get_user_by_id
|
from backend.data.user import get_user_by_id
|
||||||
@@ -57,6 +58,7 @@ class RunAgentInput(BaseModel):
|
|||||||
"""Input parameters for the run_agent tool."""
|
"""Input parameters for the run_agent tool."""
|
||||||
|
|
||||||
username_agent_slug: str = ""
|
username_agent_slug: str = ""
|
||||||
|
library_agent_id: str = ""
|
||||||
inputs: dict[str, Any] = Field(default_factory=dict)
|
inputs: dict[str, Any] = Field(default_factory=dict)
|
||||||
use_defaults: bool = False
|
use_defaults: bool = False
|
||||||
schedule_name: str = ""
|
schedule_name: str = ""
|
||||||
@@ -64,7 +66,12 @@ class RunAgentInput(BaseModel):
|
|||||||
timezone: str = "UTC"
|
timezone: str = "UTC"
|
||||||
|
|
||||||
@field_validator(
|
@field_validator(
|
||||||
"username_agent_slug", "schedule_name", "cron", "timezone", mode="before"
|
"username_agent_slug",
|
||||||
|
"library_agent_id",
|
||||||
|
"schedule_name",
|
||||||
|
"cron",
|
||||||
|
"timezone",
|
||||||
|
mode="before",
|
||||||
)
|
)
|
||||||
@classmethod
|
@classmethod
|
||||||
def strip_strings(cls, v: Any) -> Any:
|
def strip_strings(cls, v: Any) -> Any:
|
||||||
@@ -90,7 +97,7 @@ class RunAgentTool(BaseTool):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return """Run or schedule an agent from the marketplace.
|
return """Run or schedule an agent from the marketplace or user's library.
|
||||||
|
|
||||||
The tool automatically handles the setup flow:
|
The tool automatically handles the setup flow:
|
||||||
- Returns missing inputs if required fields are not provided
|
- Returns missing inputs if required fields are not provided
|
||||||
@@ -98,6 +105,10 @@ class RunAgentTool(BaseTool):
|
|||||||
- Executes immediately if all requirements are met
|
- Executes immediately if all requirements are met
|
||||||
- Schedules execution if cron expression is provided
|
- Schedules execution if cron expression is provided
|
||||||
|
|
||||||
|
Identify the agent using either:
|
||||||
|
- username_agent_slug: Marketplace format 'username/agent-name'
|
||||||
|
- library_agent_id: ID of an agent in the user's library
|
||||||
|
|
||||||
For scheduled execution, provide: schedule_name, cron, and optionally timezone."""
|
For scheduled execution, provide: schedule_name, cron, and optionally timezone."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -109,6 +120,10 @@ class RunAgentTool(BaseTool):
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Agent identifier in format 'username/agent-name'",
|
"description": "Agent identifier in format 'username/agent-name'",
|
||||||
},
|
},
|
||||||
|
"library_agent_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Library agent ID from user's library",
|
||||||
|
},
|
||||||
"inputs": {
|
"inputs": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"description": "Input values for the agent",
|
"description": "Input values for the agent",
|
||||||
@@ -131,7 +146,7 @@ class RunAgentTool(BaseTool):
|
|||||||
"description": "IANA timezone for schedule (default: UTC)",
|
"description": "IANA timezone for schedule (default: UTC)",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"required": ["username_agent_slug"],
|
"required": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -149,10 +164,16 @@ class RunAgentTool(BaseTool):
|
|||||||
params = RunAgentInput(**kwargs)
|
params = RunAgentInput(**kwargs)
|
||||||
session_id = session.session_id
|
session_id = session.session_id
|
||||||
|
|
||||||
# Validate agent slug format
|
# Validate at least one identifier is provided
|
||||||
if not params.username_agent_slug or "/" not in params.username_agent_slug:
|
has_slug = params.username_agent_slug and "/" in params.username_agent_slug
|
||||||
|
has_library_id = bool(params.library_agent_id)
|
||||||
|
|
||||||
|
if not has_slug and not has_library_id:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide an agent slug in format 'username/agent-name'",
|
message=(
|
||||||
|
"Please provide either a username_agent_slug "
|
||||||
|
"(format 'username/agent-name') or a library_agent_id"
|
||||||
|
),
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -167,13 +188,41 @@ class RunAgentTool(BaseTool):
|
|||||||
is_schedule = bool(params.schedule_name or params.cron)
|
is_schedule = bool(params.schedule_name or params.cron)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Step 1: Fetch agent details (always happens first)
|
# Step 1: Fetch agent details
|
||||||
username, agent_name = params.username_agent_slug.split("/", 1)
|
graph: GraphModel | None = None
|
||||||
graph, store_agent = await fetch_graph_from_store_slug(username, agent_name)
|
library_agent = None
|
||||||
|
|
||||||
|
# Priority: library_agent_id if provided
|
||||||
|
if has_library_id:
|
||||||
|
library_agent = await library_db.get_library_agent(
|
||||||
|
params.library_agent_id, user_id
|
||||||
|
)
|
||||||
|
if not library_agent:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Library agent '{params.library_agent_id}' not found",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
# Get the graph from the library agent
|
||||||
|
from backend.data.graph import get_graph
|
||||||
|
|
||||||
|
graph = await get_graph(
|
||||||
|
library_agent.graph_id,
|
||||||
|
library_agent.graph_version,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fetch from marketplace slug
|
||||||
|
username, agent_name = params.username_agent_slug.split("/", 1)
|
||||||
|
graph, _ = await fetch_graph_from_store_slug(username, agent_name)
|
||||||
|
|
||||||
if not graph:
|
if not graph:
|
||||||
|
identifier = (
|
||||||
|
params.library_agent_id
|
||||||
|
if has_library_id
|
||||||
|
else params.username_agent_slug
|
||||||
|
)
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=f"Agent '{params.username_agent_slug}' not found in marketplace",
|
message=f"Agent '{identifier}' not found",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
import pytest
|
import pytest
|
||||||
@@ -18,17 +17,6 @@ setup_test_data = setup_test_data
|
|||||||
setup_firecrawl_test_data = setup_firecrawl_test_data
|
setup_firecrawl_test_data = setup_firecrawl_test_data
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def mock_embedding_functions():
|
|
||||||
"""Mock embedding functions for all tests to avoid database/API dependencies."""
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.store.db.ensure_embedding",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
return_value=True,
|
|
||||||
):
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="session")
|
@pytest.mark.asyncio(scope="session")
|
||||||
async def test_run_agent(setup_test_data):
|
async def test_run_agent(setup_test_data):
|
||||||
"""Test that the run_agent tool successfully executes an approved agent"""
|
"""Test that the run_agent tool successfully executes an approved agent"""
|
||||||
|
|||||||
@@ -35,7 +35,11 @@ from backend.data.model import (
|
|||||||
OAuth2Credentials,
|
OAuth2Credentials,
|
||||||
UserIntegrations,
|
UserIntegrations,
|
||||||
)
|
)
|
||||||
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
from backend.data.onboarding import (
|
||||||
|
OnboardingStep,
|
||||||
|
complete_onboarding_step,
|
||||||
|
increment_runs,
|
||||||
|
)
|
||||||
from backend.data.user import get_user_integrations
|
from backend.data.user import get_user_integrations
|
||||||
from backend.executor.utils import add_graph_execution
|
from backend.executor.utils import add_graph_execution
|
||||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||||
@@ -374,6 +378,7 @@ async def webhook_ingress_generic(
|
|||||||
return
|
return
|
||||||
|
|
||||||
await complete_onboarding_step(user_id, OnboardingStep.TRIGGER_WEBHOOK)
|
await complete_onboarding_step(user_id, OnboardingStep.TRIGGER_WEBHOOK)
|
||||||
|
await increment_runs(user_id)
|
||||||
|
|
||||||
# Execute all triggers concurrently for better performance
|
# Execute all triggers concurrently for better performance
|
||||||
tasks = []
|
tasks = []
|
||||||
|
|||||||
@@ -489,7 +489,7 @@ async def update_agent_version_in_library(
|
|||||||
agent_graph_version: int,
|
agent_graph_version: int,
|
||||||
) -> library_model.LibraryAgent:
|
) -> library_model.LibraryAgent:
|
||||||
"""
|
"""
|
||||||
Updates the agent version in the library for any agent owned by the user.
|
Updates the agent version in the library if useGraphIsActiveVersion is True.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: Owner of the LibraryAgent.
|
user_id: Owner of the LibraryAgent.
|
||||||
@@ -498,31 +498,20 @@ async def update_agent_version_in_library(
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
DatabaseError: If there's an error with the update.
|
DatabaseError: If there's an error with the update.
|
||||||
NotFoundError: If no library agent is found for this user and agent.
|
|
||||||
"""
|
"""
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Updating agent version in library for user #{user_id}, "
|
f"Updating agent version in library for user #{user_id}, "
|
||||||
f"agent #{agent_graph_id} v{agent_graph_version}"
|
f"agent #{agent_graph_id} v{agent_graph_version}"
|
||||||
)
|
)
|
||||||
async with transaction() as tx:
|
try:
|
||||||
library_agent = await prisma.models.LibraryAgent.prisma(tx).find_first_or_raise(
|
library_agent = await prisma.models.LibraryAgent.prisma().find_first_or_raise(
|
||||||
where={
|
where={
|
||||||
"userId": user_id,
|
"userId": user_id,
|
||||||
"agentGraphId": agent_graph_id,
|
"agentGraphId": agent_graph_id,
|
||||||
|
"useGraphIsActiveVersion": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
lib = await prisma.models.LibraryAgent.prisma().update(
|
||||||
# Delete any conflicting LibraryAgent for the target version
|
|
||||||
await prisma.models.LibraryAgent.prisma(tx).delete_many(
|
|
||||||
where={
|
|
||||||
"userId": user_id,
|
|
||||||
"agentGraphId": agent_graph_id,
|
|
||||||
"agentGraphVersion": agent_graph_version,
|
|
||||||
"id": {"not": library_agent.id},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
lib = await prisma.models.LibraryAgent.prisma(tx).update(
|
|
||||||
where={"id": library_agent.id},
|
where={"id": library_agent.id},
|
||||||
data={
|
data={
|
||||||
"AgentGraph": {
|
"AgentGraph": {
|
||||||
@@ -536,13 +525,13 @@ async def update_agent_version_in_library(
|
|||||||
},
|
},
|
||||||
include={"AgentGraph": True},
|
include={"AgentGraph": True},
|
||||||
)
|
)
|
||||||
|
if lib is None:
|
||||||
|
raise NotFoundError(f"Library agent {library_agent.id} not found")
|
||||||
|
|
||||||
if lib is None:
|
return library_model.LibraryAgent.from_db(lib)
|
||||||
raise NotFoundError(
|
except prisma.errors.PrismaError as e:
|
||||||
f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}"
|
logger.error(f"Database error updating agent version in library: {e}")
|
||||||
)
|
raise DatabaseError("Failed to update agent version in library") from e
|
||||||
|
|
||||||
return library_model.LibraryAgent.from_db(lib)
|
|
||||||
|
|
||||||
|
|
||||||
async def update_library_agent(
|
async def update_library_agent(
|
||||||
@@ -836,7 +825,6 @@ async def add_store_agent_to_library(
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"isCreatedByUser": False,
|
"isCreatedByUser": False,
|
||||||
"useGraphIsActiveVersion": False,
|
|
||||||
"settings": SafeJson(
|
"settings": SafeJson(
|
||||||
_initialize_graph_settings(graph_model).model_dump()
|
_initialize_graph_settings(graph_model).model_dump()
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -48,7 +48,6 @@ class LibraryAgent(pydantic.BaseModel):
|
|||||||
id: str
|
id: str
|
||||||
graph_id: str
|
graph_id: str
|
||||||
graph_version: int
|
graph_version: int
|
||||||
owner_user_id: str # ID of user who owns/created this agent graph
|
|
||||||
|
|
||||||
image_url: str | None
|
image_url: str | None
|
||||||
|
|
||||||
@@ -164,7 +163,6 @@ class LibraryAgent(pydantic.BaseModel):
|
|||||||
id=agent.id,
|
id=agent.id,
|
||||||
graph_id=agent.agentGraphId,
|
graph_id=agent.agentGraphId,
|
||||||
graph_version=agent.agentGraphVersion,
|
graph_version=agent.agentGraphVersion,
|
||||||
owner_user_id=agent.userId,
|
|
||||||
image_url=agent.imageUrl,
|
image_url=agent.imageUrl,
|
||||||
creator_name=creator_name,
|
creator_name=creator_name,
|
||||||
creator_image_url=creator_image_url,
|
creator_image_url=creator_image_url,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from backend.data.execution import GraphExecutionMeta
|
|||||||
from backend.data.graph import get_graph
|
from backend.data.graph import get_graph
|
||||||
from backend.data.integrations import get_webhook
|
from backend.data.integrations import get_webhook
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
|
from backend.data.onboarding import increment_runs
|
||||||
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
|
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.webhooks import get_webhook_manager
|
from backend.integrations.webhooks import get_webhook_manager
|
||||||
@@ -402,6 +403,8 @@ async def execute_preset(
|
|||||||
merged_node_input = preset.inputs | inputs
|
merged_node_input = preset.inputs | inputs
|
||||||
merged_credential_inputs = preset.credentials | credential_inputs
|
merged_credential_inputs = preset.credentials | credential_inputs
|
||||||
|
|
||||||
|
await increment_runs(user_id)
|
||||||
|
|
||||||
return await add_graph_execution(
|
return await add_graph_execution(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
graph_id=preset.graph_id,
|
graph_id=preset.graph_id,
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ async def test_get_library_agents_success(
|
|||||||
id="test-agent-1",
|
id="test-agent-1",
|
||||||
graph_id="test-agent-1",
|
graph_id="test-agent-1",
|
||||||
graph_version=1,
|
graph_version=1,
|
||||||
owner_user_id=test_user_id,
|
|
||||||
name="Test Agent 1",
|
name="Test Agent 1",
|
||||||
description="Test Description 1",
|
description="Test Description 1",
|
||||||
image_url=None,
|
image_url=None,
|
||||||
@@ -65,7 +64,6 @@ async def test_get_library_agents_success(
|
|||||||
id="test-agent-2",
|
id="test-agent-2",
|
||||||
graph_id="test-agent-2",
|
graph_id="test-agent-2",
|
||||||
graph_version=1,
|
graph_version=1,
|
||||||
owner_user_id=test_user_id,
|
|
||||||
name="Test Agent 2",
|
name="Test Agent 2",
|
||||||
description="Test Description 2",
|
description="Test Description 2",
|
||||||
image_url=None,
|
image_url=None,
|
||||||
@@ -140,7 +138,6 @@ async def test_get_favorite_library_agents_success(
|
|||||||
id="test-agent-1",
|
id="test-agent-1",
|
||||||
graph_id="test-agent-1",
|
graph_id="test-agent-1",
|
||||||
graph_version=1,
|
graph_version=1,
|
||||||
owner_user_id=test_user_id,
|
|
||||||
name="Favorite Agent 1",
|
name="Favorite Agent 1",
|
||||||
description="Test Favorite Description 1",
|
description="Test Favorite Description 1",
|
||||||
image_url=None,
|
image_url=None,
|
||||||
@@ -208,7 +205,6 @@ def test_add_agent_to_library_success(
|
|||||||
id="test-library-agent-id",
|
id="test-library-agent-id",
|
||||||
graph_id="test-agent-1",
|
graph_id="test-agent-1",
|
||||||
graph_version=1,
|
graph_version=1,
|
||||||
owner_user_id=test_user_id,
|
|
||||||
name="Test Agent 1",
|
name="Test Agent 1",
|
||||||
description="Test Description 1",
|
description="Test Description 1",
|
||||||
image_url=None,
|
image_url=None,
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import typing
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Literal
|
from typing import Literal
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import prisma.enums
|
import prisma.enums
|
||||||
@@ -9,7 +10,7 @@ import prisma.errors
|
|||||||
import prisma.models
|
import prisma.models
|
||||||
import prisma.types
|
import prisma.types
|
||||||
|
|
||||||
from backend.data.db import transaction
|
from backend.data.db import query_raw_with_schema, transaction
|
||||||
from backend.data.graph import (
|
from backend.data.graph import (
|
||||||
GraphMeta,
|
GraphMeta,
|
||||||
GraphModel,
|
GraphModel,
|
||||||
@@ -29,8 +30,6 @@ from backend.util.settings import Settings
|
|||||||
|
|
||||||
from . import exceptions as store_exceptions
|
from . import exceptions as store_exceptions
|
||||||
from . import model as store_model
|
from . import model as store_model
|
||||||
from .embeddings import ensure_embedding
|
|
||||||
from .hybrid_search import hybrid_search
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
@@ -51,77 +50,128 @@ async def get_store_agents(
|
|||||||
page_size: int = 20,
|
page_size: int = 20,
|
||||||
) -> store_model.StoreAgentsResponse:
|
) -> store_model.StoreAgentsResponse:
|
||||||
"""
|
"""
|
||||||
Get PUBLIC store agents from the StoreAgent view.
|
Get PUBLIC store agents from the StoreAgent view
|
||||||
|
|
||||||
Search behavior:
|
|
||||||
- With search_query: Uses hybrid search (semantic + lexical)
|
|
||||||
- Fallback: If embeddings unavailable, gracefully degrades to lexical-only
|
|
||||||
- Rationale: User-facing endpoint prioritizes availability over accuracy
|
|
||||||
|
|
||||||
Note: Admin operations (approval) use fail-fast to prevent inconsistent state.
|
|
||||||
"""
|
"""
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
||||||
)
|
)
|
||||||
|
|
||||||
search_used_hybrid = False
|
|
||||||
store_agents: list[store_model.StoreAgent] = []
|
|
||||||
agents: list[dict[str, Any]] = []
|
|
||||||
total = 0
|
|
||||||
total_pages = 0
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# If search_query is provided, use hybrid search (embeddings + tsvector)
|
# If search_query is provided, use full-text search
|
||||||
if search_query:
|
if search_query:
|
||||||
# Try hybrid search combining semantic and lexical signals
|
offset = (page - 1) * page_size
|
||||||
# Falls back to lexical-only if OpenAI unavailable (user-facing, high SLA)
|
|
||||||
try:
|
|
||||||
agents, total = await hybrid_search(
|
|
||||||
query=search_query,
|
|
||||||
featured=featured,
|
|
||||||
creators=creators,
|
|
||||||
category=category,
|
|
||||||
sorted_by="relevance", # Use hybrid scoring for relevance
|
|
||||||
page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
)
|
|
||||||
search_used_hybrid = True
|
|
||||||
except Exception as e:
|
|
||||||
# Log error but fall back to lexical search for better UX
|
|
||||||
logger.error(
|
|
||||||
f"Hybrid search failed (likely OpenAI unavailable), "
|
|
||||||
f"falling back to lexical search: {e}"
|
|
||||||
)
|
|
||||||
# search_used_hybrid remains False, will use fallback path below
|
|
||||||
|
|
||||||
# Convert hybrid search results (dict format) if hybrid succeeded
|
# Whitelist allowed order_by columns
|
||||||
if search_used_hybrid:
|
ALLOWED_ORDER_BY = {
|
||||||
total_pages = (total + page_size - 1) // page_size
|
"rating": "rating DESC, rank DESC",
|
||||||
store_agents: list[store_model.StoreAgent] = []
|
"runs": "runs DESC, rank DESC",
|
||||||
for agent in agents:
|
"name": "agent_name ASC, rank ASC",
|
||||||
try:
|
"updated_at": "updated_at DESC, rank DESC",
|
||||||
store_agent = store_model.StoreAgent(
|
}
|
||||||
slug=agent["slug"],
|
|
||||||
agent_name=agent["agent_name"],
|
|
||||||
agent_image=(
|
|
||||||
agent["agent_image"][0] if agent["agent_image"] else ""
|
|
||||||
),
|
|
||||||
creator=agent["creator_username"] or "Needs Profile",
|
|
||||||
creator_avatar=agent["creator_avatar"] or "",
|
|
||||||
sub_heading=agent["sub_heading"],
|
|
||||||
description=agent["description"],
|
|
||||||
runs=agent["runs"],
|
|
||||||
rating=agent["rating"],
|
|
||||||
)
|
|
||||||
store_agents.append(store_agent)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error parsing Store agent from hybrid search results: {e}"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not search_used_hybrid:
|
# Validate and get order clause
|
||||||
# Fallback path - use basic search or no search
|
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
|
||||||
|
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
|
||||||
|
else:
|
||||||
|
order_by_clause = "updated_at DESC, rank DESC"
|
||||||
|
|
||||||
|
# Build WHERE conditions and parameters list
|
||||||
|
where_parts: list[str] = []
|
||||||
|
params: list[typing.Any] = [search_query] # $1 - search term
|
||||||
|
param_index = 2 # Start at $2 for next parameter
|
||||||
|
|
||||||
|
# Always filter for available agents
|
||||||
|
where_parts.append("is_available = true")
|
||||||
|
|
||||||
|
if featured:
|
||||||
|
where_parts.append("featured = true")
|
||||||
|
|
||||||
|
if creators and creators:
|
||||||
|
# Use ANY with array parameter
|
||||||
|
where_parts.append(f"creator_username = ANY(${param_index})")
|
||||||
|
params.append(creators)
|
||||||
|
param_index += 1
|
||||||
|
|
||||||
|
if category and category:
|
||||||
|
where_parts.append(f"${param_index} = ANY(categories)")
|
||||||
|
params.append(category)
|
||||||
|
param_index += 1
|
||||||
|
|
||||||
|
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
|
||||||
|
|
||||||
|
# Add pagination params
|
||||||
|
params.extend([page_size, offset])
|
||||||
|
limit_param = f"${param_index}"
|
||||||
|
offset_param = f"${param_index + 1}"
|
||||||
|
|
||||||
|
# Execute full-text search query with parameterized values
|
||||||
|
sql_query = f"""
|
||||||
|
SELECT
|
||||||
|
slug,
|
||||||
|
agent_name,
|
||||||
|
agent_image,
|
||||||
|
creator_username,
|
||||||
|
creator_avatar,
|
||||||
|
sub_heading,
|
||||||
|
description,
|
||||||
|
runs,
|
||||||
|
rating,
|
||||||
|
categories,
|
||||||
|
featured,
|
||||||
|
is_available,
|
||||||
|
updated_at,
|
||||||
|
ts_rank_cd(search, query) AS rank
|
||||||
|
FROM {{schema_prefix}}"StoreAgent",
|
||||||
|
plainto_tsquery('english', $1) AS query
|
||||||
|
WHERE {sql_where_clause}
|
||||||
|
AND search @@ query
|
||||||
|
ORDER BY {order_by_clause}
|
||||||
|
LIMIT {limit_param} OFFSET {offset_param}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Count query for pagination - only uses search term parameter
|
||||||
|
count_query = f"""
|
||||||
|
SELECT COUNT(*) as count
|
||||||
|
FROM {{schema_prefix}}"StoreAgent",
|
||||||
|
plainto_tsquery('english', $1) AS query
|
||||||
|
WHERE {sql_where_clause}
|
||||||
|
AND search @@ query
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Execute both queries with parameters
|
||||||
|
agents = await query_raw_with_schema(sql_query, *params)
|
||||||
|
|
||||||
|
# For count, use params without pagination (last 2 params)
|
||||||
|
count_params = params[:-2]
|
||||||
|
count_result = await query_raw_with_schema(count_query, *count_params)
|
||||||
|
|
||||||
|
total = count_result[0]["count"] if count_result else 0
|
||||||
|
total_pages = (total + page_size - 1) // page_size
|
||||||
|
|
||||||
|
# Convert raw results to StoreAgent models
|
||||||
|
store_agents: list[store_model.StoreAgent] = []
|
||||||
|
for agent in agents:
|
||||||
|
try:
|
||||||
|
store_agent = store_model.StoreAgent(
|
||||||
|
slug=agent["slug"],
|
||||||
|
agent_name=agent["agent_name"],
|
||||||
|
agent_image=(
|
||||||
|
agent["agent_image"][0] if agent["agent_image"] else ""
|
||||||
|
),
|
||||||
|
creator=agent["creator_username"] or "Needs Profile",
|
||||||
|
creator_avatar=agent["creator_avatar"] or "",
|
||||||
|
sub_heading=agent["sub_heading"],
|
||||||
|
description=agent["description"],
|
||||||
|
runs=agent["runs"],
|
||||||
|
rating=agent["rating"],
|
||||||
|
)
|
||||||
|
store_agents.append(store_agent)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error parsing Store agent from search results: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Non-search query path (original logic)
|
||||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||||
if featured:
|
if featured:
|
||||||
where_clause["featured"] = featured
|
where_clause["featured"] = featured
|
||||||
@@ -130,14 +180,6 @@ async def get_store_agents(
|
|||||||
if category:
|
if category:
|
||||||
where_clause["categories"] = {"has": category}
|
where_clause["categories"] = {"has": category}
|
||||||
|
|
||||||
# Add basic text search if search_query provided but hybrid failed
|
|
||||||
if search_query:
|
|
||||||
where_clause["OR"] = [
|
|
||||||
{"agent_name": {"contains": search_query, "mode": "insensitive"}},
|
|
||||||
{"sub_heading": {"contains": search_query, "mode": "insensitive"}},
|
|
||||||
{"description": {"contains": search_query, "mode": "insensitive"}},
|
|
||||||
]
|
|
||||||
|
|
||||||
order_by = []
|
order_by = []
|
||||||
if sorted_by == "rating":
|
if sorted_by == "rating":
|
||||||
order_by.append({"rating": "desc"})
|
order_by.append({"rating": "desc"})
|
||||||
@@ -146,7 +188,7 @@ async def get_store_agents(
|
|||||||
elif sorted_by == "name":
|
elif sorted_by == "name":
|
||||||
order_by.append({"agent_name": "asc"})
|
order_by.append({"agent_name": "asc"})
|
||||||
|
|
||||||
db_agents = await prisma.models.StoreAgent.prisma().find_many(
|
agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||||
where=where_clause,
|
where=where_clause,
|
||||||
order=order_by,
|
order=order_by,
|
||||||
skip=(page - 1) * page_size,
|
skip=(page - 1) * page_size,
|
||||||
@@ -157,7 +199,7 @@ async def get_store_agents(
|
|||||||
total_pages = (total + page_size - 1) // page_size
|
total_pages = (total + page_size - 1) // page_size
|
||||||
|
|
||||||
store_agents: list[store_model.StoreAgent] = []
|
store_agents: list[store_model.StoreAgent] = []
|
||||||
for agent in db_agents:
|
for agent in agents:
|
||||||
try:
|
try:
|
||||||
# Create the StoreAgent object safely
|
# Create the StoreAgent object safely
|
||||||
store_agent = store_model.StoreAgent(
|
store_agent = store_model.StoreAgent(
|
||||||
@@ -572,7 +614,6 @@ async def get_store_submissions(
|
|||||||
submission_models = []
|
submission_models = []
|
||||||
for sub in submissions:
|
for sub in submissions:
|
||||||
submission_model = store_model.StoreSubmission(
|
submission_model = store_model.StoreSubmission(
|
||||||
listing_id=sub.listing_id,
|
|
||||||
agent_id=sub.agent_id,
|
agent_id=sub.agent_id,
|
||||||
agent_version=sub.agent_version,
|
agent_version=sub.agent_version,
|
||||||
name=sub.name,
|
name=sub.name,
|
||||||
@@ -626,48 +667,35 @@ async def delete_store_submission(
|
|||||||
submission_id: str,
|
submission_id: str,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Delete a store submission version as the submitting user.
|
Delete a store listing submission as the submitting user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: ID of the authenticated user
|
user_id: ID of the authenticated user
|
||||||
submission_id: StoreListingVersion ID to delete
|
submission_id: ID of the submission to be deleted
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if successfully deleted
|
bool: True if the submission was successfully deleted, False otherwise
|
||||||
"""
|
"""
|
||||||
|
logger.debug(f"Deleting store submission {submission_id} for user {user_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Find the submission version with ownership check
|
# Verify the submission belongs to this user
|
||||||
version = await prisma.models.StoreListingVersion.prisma().find_first(
|
submission = await prisma.models.StoreListing.prisma().find_first(
|
||||||
where={"id": submission_id}, include={"StoreListing": True}
|
where={"agentGraphId": submission_id, "owningUserId": user_id}
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if not submission:
|
||||||
not version
|
logger.warning(f"Submission not found for user {user_id}: {submission_id}")
|
||||||
or not version.StoreListing
|
raise store_exceptions.SubmissionNotFoundError(
|
||||||
or version.StoreListing.owningUserId != user_id
|
f"Submission not found for this user. User ID: {user_id}, Submission ID: {submission_id}"
|
||||||
):
|
|
||||||
raise store_exceptions.SubmissionNotFoundError("Submission not found")
|
|
||||||
|
|
||||||
# Prevent deletion of approved submissions
|
|
||||||
if version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
|
|
||||||
raise store_exceptions.InvalidOperationError(
|
|
||||||
"Cannot delete approved submissions"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Delete the version
|
# Delete the submission
|
||||||
await prisma.models.StoreListingVersion.prisma().delete(
|
await prisma.models.StoreListing.prisma().delete(where={"id": submission.id})
|
||||||
where={"id": version.id}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Clean up empty listing if this was the last version
|
logger.debug(
|
||||||
remaining = await prisma.models.StoreListingVersion.prisma().count(
|
f"Successfully deleted submission {submission_id} for user {user_id}"
|
||||||
where={"storeListingId": version.storeListingId}
|
|
||||||
)
|
)
|
||||||
if remaining == 0:
|
|
||||||
await prisma.models.StoreListing.prisma().delete(
|
|
||||||
where={"id": version.storeListingId}
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -731,15 +759,9 @@ async def create_store_submission(
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f"Agent not found for user {user_id}: {agent_id} v{agent_version}"
|
f"Agent not found for user {user_id}: {agent_id} v{agent_version}"
|
||||||
)
|
)
|
||||||
# Provide more user-friendly error message when agent_id is empty
|
raise store_exceptions.AgentNotFoundError(
|
||||||
if not agent_id or agent_id.strip() == "":
|
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
||||||
raise store_exceptions.AgentNotFoundError(
|
)
|
||||||
"No agent selected. Please select an agent before submitting to the store."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise store_exceptions.AgentNotFoundError(
|
|
||||||
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if listing already exists for this agent
|
# Check if listing already exists for this agent
|
||||||
existing_listing = await prisma.models.StoreListing.prisma().find_first(
|
existing_listing = await prisma.models.StoreListing.prisma().find_first(
|
||||||
@@ -811,7 +833,6 @@ async def create_store_submission(
|
|||||||
logger.debug(f"Created store listing for agent {agent_id}")
|
logger.debug(f"Created store listing for agent {agent_id}")
|
||||||
# Return submission details
|
# Return submission details
|
||||||
return store_model.StoreSubmission(
|
return store_model.StoreSubmission(
|
||||||
listing_id=listing.id,
|
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
agent_version=agent_version,
|
agent_version=agent_version,
|
||||||
name=name,
|
name=name,
|
||||||
@@ -923,56 +944,81 @@ async def edit_store_submission(
|
|||||||
# Currently we are not allowing user to update the agent associated with a submission
|
# Currently we are not allowing user to update the agent associated with a submission
|
||||||
# If we allow it in future, then we need a check here to verify the agent belongs to this user.
|
# If we allow it in future, then we need a check here to verify the agent belongs to this user.
|
||||||
|
|
||||||
# Only allow editing of PENDING submissions
|
# Check if we can edit this submission
|
||||||
if current_version.submissionStatus != prisma.enums.SubmissionStatus.PENDING:
|
if current_version.submissionStatus == prisma.enums.SubmissionStatus.REJECTED:
|
||||||
raise store_exceptions.InvalidOperationError(
|
raise store_exceptions.InvalidOperationError(
|
||||||
f"Cannot edit a {current_version.submissionStatus.value.lower()} submission. Only pending submissions can be edited."
|
"Cannot edit a rejected submission"
|
||||||
|
)
|
||||||
|
|
||||||
|
# For APPROVED submissions, we need to create a new version
|
||||||
|
if current_version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
|
||||||
|
# Create a new version for the existing listing
|
||||||
|
return await create_store_version(
|
||||||
|
user_id=user_id,
|
||||||
|
agent_id=current_version.agentGraphId,
|
||||||
|
agent_version=current_version.agentGraphVersion,
|
||||||
|
store_listing_id=current_version.storeListingId,
|
||||||
|
name=name,
|
||||||
|
video_url=video_url,
|
||||||
|
agent_output_demo_url=agent_output_demo_url,
|
||||||
|
image_urls=image_urls,
|
||||||
|
description=description,
|
||||||
|
sub_heading=sub_heading,
|
||||||
|
categories=categories,
|
||||||
|
changes_summary=changes_summary,
|
||||||
|
recommended_schedule_cron=recommended_schedule_cron,
|
||||||
|
instructions=instructions,
|
||||||
)
|
)
|
||||||
|
|
||||||
# For PENDING submissions, we can update the existing version
|
# For PENDING submissions, we can update the existing version
|
||||||
# Update the existing version
|
elif current_version.submissionStatus == prisma.enums.SubmissionStatus.PENDING:
|
||||||
updated_version = await prisma.models.StoreListingVersion.prisma().update(
|
# Update the existing version
|
||||||
where={"id": store_listing_version_id},
|
updated_version = await prisma.models.StoreListingVersion.prisma().update(
|
||||||
data=prisma.types.StoreListingVersionUpdateInput(
|
where={"id": store_listing_version_id},
|
||||||
|
data=prisma.types.StoreListingVersionUpdateInput(
|
||||||
|
name=name,
|
||||||
|
videoUrl=video_url,
|
||||||
|
agentOutputDemoUrl=agent_output_demo_url,
|
||||||
|
imageUrls=image_urls,
|
||||||
|
description=description,
|
||||||
|
categories=categories,
|
||||||
|
subHeading=sub_heading,
|
||||||
|
changesSummary=changes_summary,
|
||||||
|
recommendedScheduleCron=recommended_schedule_cron,
|
||||||
|
instructions=instructions,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not updated_version:
|
||||||
|
raise DatabaseError("Failed to update store listing version")
|
||||||
|
return store_model.StoreSubmission(
|
||||||
|
agent_id=current_version.agentGraphId,
|
||||||
|
agent_version=current_version.agentGraphVersion,
|
||||||
name=name,
|
name=name,
|
||||||
videoUrl=video_url,
|
sub_heading=sub_heading,
|
||||||
agentOutputDemoUrl=agent_output_demo_url,
|
slug=current_version.StoreListing.slug,
|
||||||
imageUrls=image_urls,
|
|
||||||
description=description,
|
description=description,
|
||||||
categories=categories,
|
|
||||||
subHeading=sub_heading,
|
|
||||||
changesSummary=changes_summary,
|
|
||||||
recommendedScheduleCron=recommended_schedule_cron,
|
|
||||||
instructions=instructions,
|
instructions=instructions,
|
||||||
),
|
image_urls=image_urls,
|
||||||
)
|
date_submitted=updated_version.submittedAt or updated_version.createdAt,
|
||||||
|
status=updated_version.submissionStatus,
|
||||||
|
runs=0,
|
||||||
|
rating=0.0,
|
||||||
|
store_listing_version_id=updated_version.id,
|
||||||
|
changes_summary=changes_summary,
|
||||||
|
video_url=video_url,
|
||||||
|
categories=categories,
|
||||||
|
version=updated_version.version,
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(
|
else:
|
||||||
f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}"
|
raise store_exceptions.InvalidOperationError(
|
||||||
)
|
f"Cannot edit submission with status: {current_version.submissionStatus}"
|
||||||
|
)
|
||||||
if not updated_version:
|
|
||||||
raise DatabaseError("Failed to update store listing version")
|
|
||||||
return store_model.StoreSubmission(
|
|
||||||
listing_id=current_version.StoreListing.id,
|
|
||||||
agent_id=current_version.agentGraphId,
|
|
||||||
agent_version=current_version.agentGraphVersion,
|
|
||||||
name=name,
|
|
||||||
sub_heading=sub_heading,
|
|
||||||
slug=current_version.StoreListing.slug,
|
|
||||||
description=description,
|
|
||||||
instructions=instructions,
|
|
||||||
image_urls=image_urls,
|
|
||||||
date_submitted=updated_version.submittedAt or updated_version.createdAt,
|
|
||||||
status=updated_version.submissionStatus,
|
|
||||||
runs=0,
|
|
||||||
rating=0.0,
|
|
||||||
store_listing_version_id=updated_version.id,
|
|
||||||
changes_summary=changes_summary,
|
|
||||||
video_url=video_url,
|
|
||||||
categories=categories,
|
|
||||||
version=updated_version.version,
|
|
||||||
)
|
|
||||||
|
|
||||||
except (
|
except (
|
||||||
store_exceptions.SubmissionNotFoundError,
|
store_exceptions.SubmissionNotFoundError,
|
||||||
@@ -1051,78 +1097,38 @@ async def create_store_version(
|
|||||||
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if there's already a PENDING submission for this agent (any version)
|
# Get the latest version number
|
||||||
existing_pending_submission = (
|
latest_version = listing.Versions[0] if listing.Versions else None
|
||||||
await prisma.models.StoreListingVersion.prisma().find_first(
|
|
||||||
where=prisma.types.StoreListingVersionWhereInput(
|
next_version = (latest_version.version + 1) if latest_version else 1
|
||||||
storeListingId=store_listing_id,
|
|
||||||
agentGraphId=agent_id,
|
# Create a new version for the existing listing
|
||||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
new_version = await prisma.models.StoreListingVersion.prisma().create(
|
||||||
isDeleted=False,
|
data=prisma.types.StoreListingVersionCreateInput(
|
||||||
)
|
version=next_version,
|
||||||
|
agentGraphId=agent_id,
|
||||||
|
agentGraphVersion=agent_version,
|
||||||
|
name=name,
|
||||||
|
videoUrl=video_url,
|
||||||
|
agentOutputDemoUrl=agent_output_demo_url,
|
||||||
|
imageUrls=image_urls,
|
||||||
|
description=description,
|
||||||
|
instructions=instructions,
|
||||||
|
categories=categories,
|
||||||
|
subHeading=sub_heading,
|
||||||
|
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||||
|
submittedAt=datetime.now(),
|
||||||
|
changesSummary=changes_summary,
|
||||||
|
recommendedScheduleCron=recommended_schedule_cron,
|
||||||
|
storeListingId=store_listing_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle existing pending submission and create new one atomically
|
|
||||||
async with transaction() as tx:
|
|
||||||
# Get the latest version number first
|
|
||||||
latest_listing = await prisma.models.StoreListing.prisma(tx).find_first(
|
|
||||||
where=prisma.types.StoreListingWhereInput(
|
|
||||||
id=store_listing_id, owningUserId=user_id
|
|
||||||
),
|
|
||||||
include={"Versions": {"order_by": {"version": "desc"}, "take": 1}},
|
|
||||||
)
|
|
||||||
|
|
||||||
if not latest_listing:
|
|
||||||
raise store_exceptions.ListingNotFoundError(
|
|
||||||
f"Store listing not found. User ID: {user_id}, Listing ID: {store_listing_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
latest_version = (
|
|
||||||
latest_listing.Versions[0] if latest_listing.Versions else None
|
|
||||||
)
|
|
||||||
next_version = (latest_version.version + 1) if latest_version else 1
|
|
||||||
|
|
||||||
# If there's an existing pending submission, delete it atomically before creating new one
|
|
||||||
if existing_pending_submission:
|
|
||||||
logger.info(
|
|
||||||
f"Found existing PENDING submission for agent {agent_id} (was v{existing_pending_submission.agentGraphVersion}, now v{agent_version}), replacing existing submission instead of creating duplicate"
|
|
||||||
)
|
|
||||||
await prisma.models.StoreListingVersion.prisma(tx).delete(
|
|
||||||
where={"id": existing_pending_submission.id}
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"Deleted existing pending submission {existing_pending_submission.id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a new version for the existing listing
|
|
||||||
new_version = await prisma.models.StoreListingVersion.prisma(tx).create(
|
|
||||||
data=prisma.types.StoreListingVersionCreateInput(
|
|
||||||
version=next_version,
|
|
||||||
agentGraphId=agent_id,
|
|
||||||
agentGraphVersion=agent_version,
|
|
||||||
name=name,
|
|
||||||
videoUrl=video_url,
|
|
||||||
agentOutputDemoUrl=agent_output_demo_url,
|
|
||||||
imageUrls=image_urls,
|
|
||||||
description=description,
|
|
||||||
instructions=instructions,
|
|
||||||
categories=categories,
|
|
||||||
subHeading=sub_heading,
|
|
||||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
|
||||||
submittedAt=datetime.now(),
|
|
||||||
changesSummary=changes_summary,
|
|
||||||
recommendedScheduleCron=recommended_schedule_cron,
|
|
||||||
storeListingId=store_listing_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Created new version for listing {store_listing_id} of agent {agent_id}"
|
f"Created new version for listing {store_listing_id} of agent {agent_id}"
|
||||||
)
|
)
|
||||||
# Return submission details
|
# Return submission details
|
||||||
return store_model.StoreSubmission(
|
return store_model.StoreSubmission(
|
||||||
listing_id=listing.id,
|
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
agent_version=agent_version,
|
agent_version=agent_version,
|
||||||
name=name,
|
name=name,
|
||||||
@@ -1535,7 +1541,7 @@ async def review_store_submission(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Update the AgentGraph with store listing data
|
# Update the AgentGraph with store listing data
|
||||||
await prisma.models.AgentGraph.prisma(tx).update(
|
await prisma.models.AgentGraph.prisma().update(
|
||||||
where={
|
where={
|
||||||
"graphVersionId": {
|
"graphVersionId": {
|
||||||
"id": store_listing_version.agentGraphId,
|
"id": store_listing_version.agentGraphId,
|
||||||
@@ -1550,23 +1556,6 @@ async def review_store_submission(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate embedding for approved listing (blocking - admin operation)
|
|
||||||
# Inside transaction: if embedding fails, entire transaction rolls back
|
|
||||||
embedding_success = await ensure_embedding(
|
|
||||||
version_id=store_listing_version_id,
|
|
||||||
name=store_listing_version.name,
|
|
||||||
description=store_listing_version.description,
|
|
||||||
sub_heading=store_listing_version.subHeading,
|
|
||||||
categories=store_listing_version.categories or [],
|
|
||||||
tx=tx,
|
|
||||||
)
|
|
||||||
if not embedding_success:
|
|
||||||
raise ValueError(
|
|
||||||
f"Failed to generate embedding for listing {store_listing_version_id}. "
|
|
||||||
"This is likely due to OpenAI API being unavailable. "
|
|
||||||
"Please try again later or contact support if the issue persists."
|
|
||||||
)
|
|
||||||
|
|
||||||
await prisma.models.StoreListing.prisma(tx).update(
|
await prisma.models.StoreListing.prisma(tx).update(
|
||||||
where={"id": store_listing_version.StoreListing.id},
|
where={"id": store_listing_version.StoreListing.id},
|
||||||
data={
|
data={
|
||||||
@@ -1719,12 +1708,15 @@ async def review_store_submission(
|
|||||||
|
|
||||||
# Convert to Pydantic model for consistency
|
# Convert to Pydantic model for consistency
|
||||||
return store_model.StoreSubmission(
|
return store_model.StoreSubmission(
|
||||||
listing_id=(submission.StoreListing.id if submission.StoreListing else ""),
|
|
||||||
agent_id=submission.agentGraphId,
|
agent_id=submission.agentGraphId,
|
||||||
agent_version=submission.agentGraphVersion,
|
agent_version=submission.agentGraphVersion,
|
||||||
name=submission.name,
|
name=submission.name,
|
||||||
sub_heading=submission.subHeading,
|
sub_heading=submission.subHeading,
|
||||||
slug=(submission.StoreListing.slug if submission.StoreListing else ""),
|
slug=(
|
||||||
|
submission.StoreListing.slug
|
||||||
|
if hasattr(submission, "storeListing") and submission.StoreListing
|
||||||
|
else ""
|
||||||
|
),
|
||||||
description=submission.description,
|
description=submission.description,
|
||||||
instructions=submission.instructions,
|
instructions=submission.instructions,
|
||||||
image_urls=submission.imageUrls or [],
|
image_urls=submission.imageUrls or [],
|
||||||
@@ -1826,7 +1818,9 @@ async def get_admin_listings_with_versions(
|
|||||||
where = prisma.types.StoreListingWhereInput(**where_dict)
|
where = prisma.types.StoreListingWhereInput(**where_dict)
|
||||||
include = prisma.types.StoreListingInclude(
|
include = prisma.types.StoreListingInclude(
|
||||||
Versions=prisma.types.FindManyStoreListingVersionArgsFromStoreListing(
|
Versions=prisma.types.FindManyStoreListingVersionArgsFromStoreListing(
|
||||||
order_by={"version": "desc"}
|
order_by=prisma.types._StoreListingVersion_version_OrderByInput(
|
||||||
|
version="desc"
|
||||||
|
)
|
||||||
),
|
),
|
||||||
OwningUser=True,
|
OwningUser=True,
|
||||||
)
|
)
|
||||||
@@ -1851,7 +1845,6 @@ async def get_admin_listings_with_versions(
|
|||||||
# If we have versions, turn them into StoreSubmission models
|
# If we have versions, turn them into StoreSubmission models
|
||||||
for version in listing.Versions or []:
|
for version in listing.Versions or []:
|
||||||
version_model = store_model.StoreSubmission(
|
version_model = store_model.StoreSubmission(
|
||||||
listing_id=listing.id,
|
|
||||||
agent_id=version.agentGraphId,
|
agent_id=version.agentGraphId,
|
||||||
agent_version=version.agentGraphVersion,
|
agent_version=version.agentGraphVersion,
|
||||||
name=version.name,
|
name=version.name,
|
||||||
|
|||||||
@@ -1,568 +0,0 @@
|
|||||||
"""
|
|
||||||
Unified Content Embeddings Service
|
|
||||||
|
|
||||||
Handles generation and storage of OpenAI embeddings for all content types
|
|
||||||
(store listings, blocks, documentation, library agents) to enable semantic/hybrid search.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import prisma
|
|
||||||
from prisma.enums import ContentType
|
|
||||||
from tiktoken import encoding_for_model
|
|
||||||
|
|
||||||
from backend.data.db import execute_raw_with_schema, query_raw_with_schema
|
|
||||||
from backend.util.clients import get_openai_client
|
|
||||||
from backend.util.json import dumps
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# OpenAI embedding model configuration
|
|
||||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
|
||||||
# OpenAI embedding token limit (8,191 with 1 token buffer for safety)
|
|
||||||
EMBEDDING_MAX_TOKENS = 8191
|
|
||||||
|
|
||||||
|
|
||||||
def build_searchable_text(
|
|
||||||
name: str,
|
|
||||||
description: str,
|
|
||||||
sub_heading: str,
|
|
||||||
categories: list[str],
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Build searchable text from listing version fields.
|
|
||||||
|
|
||||||
Combines relevant fields into a single string for embedding.
|
|
||||||
"""
|
|
||||||
parts = []
|
|
||||||
|
|
||||||
# Name is important - include it
|
|
||||||
if name:
|
|
||||||
parts.append(name)
|
|
||||||
|
|
||||||
# Sub-heading provides context
|
|
||||||
if sub_heading:
|
|
||||||
parts.append(sub_heading)
|
|
||||||
|
|
||||||
# Description is the main content
|
|
||||||
if description:
|
|
||||||
parts.append(description)
|
|
||||||
|
|
||||||
# Categories help with semantic matching
|
|
||||||
if categories:
|
|
||||||
parts.append(" ".join(categories))
|
|
||||||
|
|
||||||
return " ".join(parts)
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_embedding(text: str) -> list[float] | None:
|
|
||||||
"""
|
|
||||||
Generate embedding for text using OpenAI API.
|
|
||||||
|
|
||||||
Returns None if embedding generation fails.
|
|
||||||
Fail-fast: no retries to maintain consistency with approval flow.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
client = get_openai_client()
|
|
||||||
if not client:
|
|
||||||
logger.error("openai_internal_api_key not set, cannot generate embedding")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Truncate text to token limit using tiktoken
|
|
||||||
# Character-based truncation is insufficient because token ratios vary by content type
|
|
||||||
enc = encoding_for_model(EMBEDDING_MODEL)
|
|
||||||
tokens = enc.encode(text)
|
|
||||||
if len(tokens) > EMBEDDING_MAX_TOKENS:
|
|
||||||
tokens = tokens[:EMBEDDING_MAX_TOKENS]
|
|
||||||
truncated_text = enc.decode(tokens)
|
|
||||||
logger.info(
|
|
||||||
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
truncated_text = text
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
response = await client.embeddings.create(
|
|
||||||
model=EMBEDDING_MODEL,
|
|
||||||
input=truncated_text,
|
|
||||||
)
|
|
||||||
latency_ms = (time.time() - start_time) * 1000
|
|
||||||
|
|
||||||
embedding = response.data[0].embedding
|
|
||||||
logger.info(
|
|
||||||
f"Generated embedding: {len(embedding)} dims, "
|
|
||||||
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
|
|
||||||
)
|
|
||||||
return embedding
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to generate embedding: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def store_embedding(
|
|
||||||
version_id: str,
|
|
||||||
embedding: list[float],
|
|
||||||
tx: prisma.Prisma | None = None,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Store embedding in the database.
|
|
||||||
|
|
||||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
|
||||||
DEPRECATED: Use ensure_embedding() instead (includes searchable_text).
|
|
||||||
"""
|
|
||||||
return await store_content_embedding(
|
|
||||||
content_type=ContentType.STORE_AGENT,
|
|
||||||
content_id=version_id,
|
|
||||||
embedding=embedding,
|
|
||||||
searchable_text="", # Empty for backward compat; ensure_embedding() populates this
|
|
||||||
metadata=None,
|
|
||||||
user_id=None, # Store agents are public
|
|
||||||
tx=tx,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def store_content_embedding(
|
|
||||||
content_type: ContentType,
|
|
||||||
content_id: str,
|
|
||||||
embedding: list[float],
|
|
||||||
searchable_text: str,
|
|
||||||
metadata: dict | None = None,
|
|
||||||
user_id: str | None = None,
|
|
||||||
tx: prisma.Prisma | None = None,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Store embedding in the unified content embeddings table.
|
|
||||||
|
|
||||||
New function for unified content embedding storage.
|
|
||||||
Uses raw SQL since Prisma doesn't natively support pgvector.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
client = tx if tx else prisma.get_client()
|
|
||||||
|
|
||||||
# Convert embedding to PostgreSQL vector format
|
|
||||||
embedding_str = embedding_to_vector_string(embedding)
|
|
||||||
metadata_json = dumps(metadata or {})
|
|
||||||
|
|
||||||
# Upsert the embedding
|
|
||||||
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
|
|
||||||
await execute_raw_with_schema(
|
|
||||||
"""
|
|
||||||
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
|
|
||||||
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
|
|
||||||
)
|
|
||||||
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
|
|
||||||
ON CONFLICT ("contentType", "contentId", "userId")
|
|
||||||
DO UPDATE SET
|
|
||||||
"embedding" = $4::vector,
|
|
||||||
"searchableText" = $5,
|
|
||||||
"metadata" = $6::jsonb,
|
|
||||||
"updatedAt" = NOW()
|
|
||||||
WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
|
|
||||||
AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
|
|
||||||
AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
|
|
||||||
""",
|
|
||||||
content_type,
|
|
||||||
content_id,
|
|
||||||
user_id,
|
|
||||||
embedding_str,
|
|
||||||
searchable_text,
|
|
||||||
metadata_json,
|
|
||||||
client=client,
|
|
||||||
set_public_search_path=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Stored embedding for {content_type}:{content_id}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def get_embedding(version_id: str) -> dict[str, Any] | None:
|
|
||||||
"""
|
|
||||||
Retrieve embedding record for a listing version.
|
|
||||||
|
|
||||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
|
||||||
Returns dict with storeListingVersionId, embedding, timestamps or None if not found.
|
|
||||||
"""
|
|
||||||
result = await get_content_embedding(
|
|
||||||
ContentType.STORE_AGENT, version_id, user_id=None
|
|
||||||
)
|
|
||||||
if result:
|
|
||||||
# Transform to old format for backward compatibility
|
|
||||||
return {
|
|
||||||
"storeListingVersionId": result["contentId"],
|
|
||||||
"embedding": result["embedding"],
|
|
||||||
"createdAt": result["createdAt"],
|
|
||||||
"updatedAt": result["updatedAt"],
|
|
||||||
}
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def get_content_embedding(
|
|
||||||
content_type: ContentType, content_id: str, user_id: str | None = None
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""
|
|
||||||
Retrieve embedding record for any content type.
|
|
||||||
|
|
||||||
New function for unified content embedding retrieval.
|
|
||||||
Returns dict with contentType, contentId, embedding, timestamps or None if not found.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await query_raw_with_schema(
|
|
||||||
"""
|
|
||||||
SELECT
|
|
||||||
"contentType",
|
|
||||||
"contentId",
|
|
||||||
"userId",
|
|
||||||
"embedding"::text as "embedding",
|
|
||||||
"searchableText",
|
|
||||||
"metadata",
|
|
||||||
"createdAt",
|
|
||||||
"updatedAt"
|
|
||||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
|
||||||
WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
|
|
||||||
""",
|
|
||||||
content_type,
|
|
||||||
content_id,
|
|
||||||
user_id,
|
|
||||||
set_public_search_path=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if result and len(result) > 0:
|
|
||||||
return result[0]
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def ensure_embedding(
|
|
||||||
version_id: str,
|
|
||||||
name: str,
|
|
||||||
description: str,
|
|
||||||
sub_heading: str,
|
|
||||||
categories: list[str],
|
|
||||||
force: bool = False,
|
|
||||||
tx: prisma.Prisma | None = None,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Ensure an embedding exists for the listing version.
|
|
||||||
|
|
||||||
Creates embedding if missing. Use force=True to regenerate.
|
|
||||||
Backward-compatible wrapper for store listings.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
version_id: The StoreListingVersion ID
|
|
||||||
name: Agent name
|
|
||||||
description: Agent description
|
|
||||||
sub_heading: Agent sub-heading
|
|
||||||
categories: Agent categories
|
|
||||||
force: Force regeneration even if embedding exists
|
|
||||||
tx: Optional transaction client
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if embedding exists/was created, False on failure
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Check if embedding already exists
|
|
||||||
if not force:
|
|
||||||
existing = await get_embedding(version_id)
|
|
||||||
if existing and existing.get("embedding"):
|
|
||||||
logger.debug(f"Embedding for version {version_id} already exists")
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Build searchable text for embedding
|
|
||||||
searchable_text = build_searchable_text(
|
|
||||||
name, description, sub_heading, categories
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate new embedding
|
|
||||||
embedding = await generate_embedding(searchable_text)
|
|
||||||
if embedding is None:
|
|
||||||
logger.warning(f"Could not generate embedding for version {version_id}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Store the embedding with metadata using new function
|
|
||||||
metadata = {
|
|
||||||
"name": name,
|
|
||||||
"subHeading": sub_heading,
|
|
||||||
"categories": categories,
|
|
||||||
}
|
|
||||||
return await store_content_embedding(
|
|
||||||
content_type=ContentType.STORE_AGENT,
|
|
||||||
content_id=version_id,
|
|
||||||
embedding=embedding,
|
|
||||||
searchable_text=searchable_text,
|
|
||||||
metadata=metadata,
|
|
||||||
user_id=None, # Store agents are public
|
|
||||||
tx=tx,
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_embedding(version_id: str) -> bool:
|
|
||||||
"""
|
|
||||||
Delete embedding for a listing version.
|
|
||||||
|
|
||||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
|
||||||
Note: This is usually handled automatically by CASCADE delete,
|
|
||||||
but provided for manual cleanup if needed.
|
|
||||||
"""
|
|
||||||
return await delete_content_embedding(ContentType.STORE_AGENT, version_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_content_embedding(
    content_type: ContentType, content_id: str, user_id: str | None = None
) -> bool:
    """
    Delete the embedding row for a piece of content of any type.

    Unified deletion entry point for all content embeddings. CASCADE deletes
    normally handle this automatically; the function exists for manual cleanup.

    Args:
        content_type: The type of content (STORE_AGENT, LIBRARY_AGENT, etc.)
        content_id: The unique identifier for the content
        user_id: Optional user ID. For public content (STORE_AGENT, BLOCK),
            pass None. For user-scoped content (LIBRARY_AGENT), pass the
            owner's ID so embeddings belonging to other users are untouched.

    Returns:
        True if deletion succeeded, False otherwise
    """
    try:
        db_client = prisma.get_client()

        # NULL-safe ownership match: a NULL $3 only matches rows whose
        # "userId" is also NULL (public content).
        await execute_raw_with_schema(
            """
            DELETE FROM {schema_prefix}"UnifiedContentEmbedding"
            WHERE "contentType" = $1::{schema_prefix}"ContentType"
            AND "contentId" = $2
            AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
            """,
            content_type,
            content_id,
            user_id,
            client=db_client,
        )

        user_str = f" (user: {user_id})" if user_id else ""
        logger.info(f"Deleted embedding for {content_type}:{content_id}{user_str}")
        return True

    except Exception as e:
        logger.error(f"Failed to delete embedding for {content_type}:{content_id}: {e}")
        return False
|
|
||||||
|
|
||||||
|
|
||||||
async def get_embedding_stats() -> dict[str, Any]:
    """
    Report embedding coverage over approved store listing versions.

    Returns a dict with:
    - total_approved: approved, non-deleted listing versions
    - with_embeddings: of those, how many have a STORE_AGENT embedding row
    - without_embeddings: the remainder
    - coverage_percent: with_embeddings as a percentage of total_approved
    On failure, all counts are 0 and an "error" key carries the message.
    """
    try:
        # Count approved, non-deleted listing versions.
        total_rows = await query_raw_with_schema(
            """
            SELECT COUNT(*) as count
            FROM {schema_prefix}"StoreListingVersion"
            WHERE "submissionStatus" = 'APPROVED'
            AND "isDeleted" = false
            """
        )
        total_approved = total_rows[0]["count"] if total_rows else 0

        # Count the subset of those versions that have an embedding row.
        embedded_rows = await query_raw_with_schema(
            """
            SELECT COUNT(*) as count
            FROM {schema_prefix}"StoreListingVersion" slv
            JOIN {schema_prefix}"UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
            WHERE slv."submissionStatus" = 'APPROVED'
            AND slv."isDeleted" = false
            """
        )
        with_embeddings = embedded_rows[0]["count"] if embedded_rows else 0

        # Guard against division by zero when there are no approved versions.
        if total_approved > 0:
            coverage = round(with_embeddings / total_approved * 100, 1)
        else:
            coverage = 0

        return {
            "total_approved": total_approved,
            "with_embeddings": with_embeddings,
            "without_embeddings": total_approved - with_embeddings,
            "coverage_percent": coverage,
        }

    except Exception as e:
        logger.error(f"Failed to get embedding stats: {e}")
        return {
            "total_approved": 0,
            "with_embeddings": 0,
            "without_embeddings": 0,
            "coverage_percent": 0,
            "error": str(e),
        }
|
|
||||||
|
|
||||||
|
|
||||||
async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]:
    """
    Generate embeddings for approved listings that don't have them.

    Selects up to `batch_size` approved, non-deleted StoreListingVersion rows
    that have no STORE_AGENT embedding, then generates and stores embeddings
    for all of them concurrently via ensure_embedding().

    Args:
        batch_size: Number of embeddings to generate in one call

    Returns:
        Dict with success/failure counts (and "error" on total failure)
    """
    try:
        # Find approved versions without embeddings (anti-join on the
        # unified embedding table).
        missing = await query_raw_with_schema(
            """
            SELECT
                slv.id,
                slv.name,
                slv.description,
                slv."subHeading",
                slv.categories
            FROM {schema_prefix}"StoreListingVersion" slv
            LEFT JOIN {schema_prefix}"UnifiedContentEmbedding" uce
                ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
            WHERE slv."submissionStatus" = 'APPROVED'
            AND slv."isDeleted" = false
            AND uce."contentId" IS NULL
            LIMIT $1
            """,
            batch_size,
        )

        if not missing:
            return {
                "processed": 0,
                "success": 0,
                "failed": 0,
                "message": "No missing embeddings",
            }

        # Process embeddings concurrently for better performance
        embedding_tasks = [
            ensure_embedding(
                version_id=row["id"],
                name=row["name"],
                description=row["description"],
                sub_heading=row["subHeading"],
                categories=row["categories"] or [],
            )
            for row in missing
        ]

        results = await asyncio.gather(*embedding_tasks, return_exceptions=True)

        # FIX: previously, exceptions returned by gather(return_exceptions=True)
        # were silently discarded, hiding the cause of each failure. Log them
        # per version so backfill failures can be diagnosed.
        for row, result in zip(missing, results):
            if isinstance(result, BaseException):
                logger.error(
                    f"Embedding backfill raised for version {row['id']}: {result}"
                )

        # Any non-True result (False or an exception) counts as a failure.
        success = sum(1 for result in results if result is True)
        failed = len(results) - success

        return {
            "processed": len(missing),
            "success": success,
            "failed": failed,
            "message": f"Backfilled {success} embeddings, {failed} failed",
        }

    except Exception as e:
        logger.error(f"Failed to backfill embeddings: {e}")
        return {
            "processed": 0,
            "success": 0,
            "failed": 0,
            "error": str(e),
        }
|
|
||||||
|
|
||||||
|
|
||||||
async def embed_query(query: str) -> list[float] | None:
    """
    Generate an embedding vector for a search query.

    Thin alias over generate_embedding(), kept so call sites read clearly
    when embedding user search input rather than stored content.

    Returns:
        The embedding vector, or None if generation failed.
    """
    return await generate_embedding(query)
|
|
||||||
|
|
||||||
|
|
||||||
def embedding_to_vector_string(embedding: list[float]) -> str:
    """Serialize an embedding into PostgreSQL's vector literal syntax.

    Example: [0.1, 0.2] -> "[0.1,0.2]" (no spaces between components).
    """
    joined = ",".join(map(str, embedding))
    return f"[{joined}]"
|
|
||||||
|
|
||||||
|
|
||||||
async def ensure_content_embedding(
    content_type: ContentType,
    content_id: str,
    searchable_text: str,
    metadata: dict | None = None,
    user_id: str | None = None,
    force: bool = False,
    tx: prisma.Prisma | None = None,
) -> bool:
    """
    Ensure an embedding exists for any content type.

    Generic creation path shared by store agents, blocks, docs, etc.:
    check for an existing row, generate a vector if needed, then persist it.

    Args:
        content_type: ContentType enum value (STORE_AGENT, BLOCK, etc.)
        content_id: Unique identifier for the content
        searchable_text: Combined text for embedding generation
        metadata: Optional metadata to store with embedding
        user_id: Optional owner id; None for public content
        force: Force regeneration even if embedding exists
        tx: Optional transaction client

    Returns:
        True if embedding exists/was created, False on failure
    """
    try:
        # Fast path: reuse an existing embedding unless forced to regenerate.
        if not force:
            cached = await get_content_embedding(content_type, content_id, user_id)
            if cached and cached.get("embedding"):
                logger.debug(
                    f"Embedding for {content_type}:{content_id} already exists"
                )
                return True

        # Generate a fresh vector from the searchable text.
        vector = await generate_embedding(searchable_text)
        if vector is None:
            logger.warning(
                f"Could not generate embedding for {content_type}:{content_id}"
            )
            return False

        # Persist (insert or update) the embedding row.
        return await store_content_embedding(
            content_type=content_type,
            content_id=content_id,
            embedding=vector,
            searchable_text=searchable_text,
            metadata=metadata or {},
            user_id=user_id,
            tx=tx,
        )

    except Exception as e:
        logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
        return False
|
|
||||||
@@ -1,329 +0,0 @@
|
|||||||
"""
|
|
||||||
Integration tests for embeddings with schema handling.
|
|
||||||
|
|
||||||
These tests verify that embeddings operations work correctly across different database schemas.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from prisma.enums import ContentType
|
|
||||||
|
|
||||||
from backend.api.features.store import embeddings
|
|
||||||
|
|
||||||
# Schema prefix tests removed - functionality moved to db.raw_with_schema() helper
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_store_content_embedding_with_schema():
    """Test storing embeddings with proper schema handling.

    Verifies the INSERT targets the schema-qualified table and that the
    function reports success when the database call succeeds.
    """
    with patch("backend.data.db.get_database_schema") as mock_schema:
        mock_schema.return_value = "platform"

        with patch("prisma.get_client") as mock_get_client:
            mock_client = AsyncMock()
            mock_get_client.return_value = mock_client

            result = await embeddings.store_content_embedding(
                content_type=ContentType.STORE_AGENT,
                content_id="test-id",
                embedding=[0.1] * 1536,
                searchable_text="test text",
                metadata={"test": "data"},
                user_id=None,  # public content carries no owner
            )

            # Verify the query was called
            assert mock_client.execute_raw.called

            # Get the SQL query that was executed
            call_args = mock_client.execute_raw.call_args
            sql_query = call_args[0][0]

            # Verify schema prefix is in the query
            assert '"platform"."UnifiedContentEmbedding"' in sql_query

            # Verify result
            assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_get_content_embedding_with_schema():
    """Test retrieving embeddings with proper schema handling.

    The SELECT must reference the schema-qualified table and the stored
    row must be returned unchanged.
    """
    with patch("backend.data.db.get_database_schema") as mock_schema:
        mock_schema.return_value = "platform"

        with patch("prisma.get_client") as mock_get_client:
            mock_client = AsyncMock()
            # Simulate a single stored embedding row for the queried id.
            mock_client.query_raw.return_value = [
                {
                    "contentType": "STORE_AGENT",
                    "contentId": "test-id",
                    "userId": None,
                    "embedding": "[0.1, 0.2]",
                    "searchableText": "test",
                    "metadata": {},
                    "createdAt": "2024-01-01",
                    "updatedAt": "2024-01-01",
                }
            ]
            mock_get_client.return_value = mock_client

            result = await embeddings.get_content_embedding(
                ContentType.STORE_AGENT,
                "test-id",
                user_id=None,
            )

            # Verify the query was called
            assert mock_client.query_raw.called

            # Get the SQL query that was executed
            call_args = mock_client.query_raw.call_args
            sql_query = call_args[0][0]

            # Verify schema prefix is in the query
            assert '"platform"."UnifiedContentEmbedding"' in sql_query

            # Verify result
            assert result is not None
            assert result["contentId"] == "test-id"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_delete_content_embedding_with_schema():
    """Test deleting embeddings with proper schema handling.

    The DELETE must target the schema-qualified table and report success.
    """
    with patch("backend.data.db.get_database_schema") as mock_schema:
        mock_schema.return_value = "platform"

        with patch("prisma.get_client") as mock_get_client:
            mock_client = AsyncMock()
            mock_get_client.return_value = mock_client

            result = await embeddings.delete_content_embedding(
                ContentType.STORE_AGENT,
                "test-id",
            )

            # Verify the query was called
            assert mock_client.execute_raw.called

            # Get the SQL query that was executed
            call_args = mock_client.execute_raw.call_args
            sql_query = call_args[0][0]

            # Verify schema prefix is in the query
            assert '"platform"."UnifiedContentEmbedding"' in sql_query

            # Verify result
            assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_get_embedding_stats_with_schema():
    """Test embedding statistics with proper schema handling.

    Both COUNT queries must use the schema prefix and the derived
    coverage numbers must be computed from the two counts.
    """
    with patch("backend.data.db.get_database_schema") as mock_schema:
        mock_schema.return_value = "platform"

        with patch("prisma.get_client") as mock_get_client:
            mock_client = AsyncMock()
            # Mock both query results (queries run in this order).
            mock_client.query_raw.side_effect = [
                [{"count": 100}],  # total_approved
                [{"count": 80}],  # with_embeddings
            ]
            mock_get_client.return_value = mock_client

            result = await embeddings.get_embedding_stats()

            # Verify both queries were called
            assert mock_client.query_raw.call_count == 2

            # Get both SQL queries
            first_call = mock_client.query_raw.call_args_list[0]
            second_call = mock_client.query_raw.call_args_list[1]

            first_sql = first_call[0][0]
            second_sql = second_call[0][0]

            # Verify schema prefix in both queries
            assert '"platform"."StoreListingVersion"' in first_sql
            assert '"platform"."StoreListingVersion"' in second_sql
            assert '"platform"."UnifiedContentEmbedding"' in second_sql

            # Verify results
            assert result["total_approved"] == 100
            assert result["with_embeddings"] == 80
            assert result["without_embeddings"] == 20
            assert result["coverage_percent"] == 80.0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_backfill_missing_embeddings_with_schema():
    """Test backfilling embeddings with proper schema handling.

    One missing version is returned by the anti-join query; the backfill
    must call ensure_embedding for it and count a single success.
    """
    with patch("backend.data.db.get_database_schema") as mock_schema:
        mock_schema.return_value = "platform"

        with patch("prisma.get_client") as mock_get_client:
            mock_client = AsyncMock()
            # Mock missing embeddings query
            mock_client.query_raw.return_value = [
                {
                    "id": "version-1",
                    "name": "Test Agent",
                    "description": "Test description",
                    "subHeading": "Test heading",
                    "categories": ["test"],
                }
            ]
            mock_get_client.return_value = mock_client

            with patch(
                "backend.api.features.store.embeddings.ensure_embedding"
            ) as mock_ensure:
                mock_ensure.return_value = True

                result = await embeddings.backfill_missing_embeddings(batch_size=10)

                # Verify the query was called
                assert mock_client.query_raw.called

                # Get the SQL query
                call_args = mock_client.query_raw.call_args
                sql_query = call_args[0][0]

                # Verify schema prefix in query
                assert '"platform"."StoreListingVersion"' in sql_query
                assert '"platform"."UnifiedContentEmbedding"' in sql_query

                # Verify ensure_embedding was called
                assert mock_ensure.called

                # Verify results
                assert result["processed"] == 1
                assert result["success"] == 1
                assert result["failed"] == 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_ensure_content_embedding_with_schema():
    """Test ensuring embeddings exist with proper schema handling.

    Exercises the full create path: lookup miss -> generate -> store.
    """
    with patch("backend.data.db.get_database_schema") as mock_schema:
        mock_schema.return_value = "platform"

        with patch(
            "backend.api.features.store.embeddings.get_content_embedding"
        ) as mock_get:
            # Simulate no existing embedding
            mock_get.return_value = None

            with patch(
                "backend.api.features.store.embeddings.generate_embedding"
            ) as mock_generate:
                mock_generate.return_value = [0.1] * 1536

                with patch(
                    "backend.api.features.store.embeddings.store_content_embedding"
                ) as mock_store:
                    mock_store.return_value = True

                    result = await embeddings.ensure_content_embedding(
                        content_type=ContentType.STORE_AGENT,
                        content_id="test-id",
                        searchable_text="test text",
                        metadata={"test": "data"},
                        user_id=None,
                        force=False,
                    )

                    # Verify the flow: check, generate, then store.
                    assert mock_get.called
                    assert mock_generate.called
                    assert mock_store.called
                    assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_backward_compatibility_store_embedding():
    """Test backward compatibility wrapper for store_embedding.

    The legacy wrapper must delegate to store_content_embedding with the
    STORE_AGENT content type and no user scope.
    """
    with patch(
        "backend.api.features.store.embeddings.store_content_embedding"
    ) as mock_store:
        mock_store.return_value = True

        result = await embeddings.store_embedding(
            version_id="test-version-id",
            embedding=[0.1] * 1536,
            tx=None,
        )

        # Verify it calls the new function with correct parameters
        assert mock_store.called
        call_args = mock_store.call_args

        assert call_args[1]["content_type"] == ContentType.STORE_AGENT
        assert call_args[1]["content_id"] == "test-version-id"
        assert call_args[1]["user_id"] is None
        assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_backward_compatibility_get_embedding():
    """Test backward compatibility wrapper for get_embedding.

    The wrapper must delegate to get_content_embedding and re-shape the
    unified row into the legacy format (storeListingVersionId key).
    """
    with patch(
        "backend.api.features.store.embeddings.get_content_embedding"
    ) as mock_get:
        mock_get.return_value = {
            "contentType": "STORE_AGENT",
            "contentId": "test-version-id",
            "embedding": "[0.1, 0.2]",
            "createdAt": "2024-01-01",
            "updatedAt": "2024-01-01",
        }

        result = await embeddings.get_embedding("test-version-id")

        # Verify it calls the new function
        assert mock_get.called

        # Verify it transforms to old format
        assert result is not None
        assert result["storeListingVersionId"] == "test-version-id"
        assert "embedding" in result
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_schema_handling_error_cases():
    """Test error handling in schema-aware operations.

    Database failures must be swallowed and reported as False, never
    propagated to the caller.
    """
    with patch("backend.data.db.get_database_schema") as mock_schema:
        mock_schema.return_value = "platform"

        with patch("prisma.get_client") as mock_get_client:
            mock_client = AsyncMock()
            # Any execute_raw call fails, simulating a DB outage.
            mock_client.execute_raw.side_effect = Exception("Database error")
            mock_get_client.return_value = mock_client

            result = await embeddings.store_content_embedding(
                content_type=ContentType.STORE_AGENT,
                content_id="test-id",
                embedding=[0.1] * 1536,
                searchable_text="test",
                metadata=None,
                user_id=None,
            )

            # Should return False on error, not raise
            assert result is False
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Allow running this test module directly (verbose, unbuffered output).
    pytest.main([__file__, "-v", "-s"])
|
|
||||||
@@ -1,387 +0,0 @@
|
|||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import prisma
|
|
||||||
import pytest
|
|
||||||
from prisma import Prisma
|
|
||||||
from prisma.enums import ContentType
|
|
||||||
|
|
||||||
from backend.api.features.store import embeddings
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
async def setup_prisma():
    """Setup Prisma client for tests.

    Registering a Prisma client twice raises ClientAlreadyRegisteredError;
    that is fine here, so it is ignored.
    """
    try:
        Prisma()
    except prisma.errors.ClientAlreadyRegisteredError:
        pass
    yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_build_searchable_text():
    """Test searchable text building from listing fields."""
    result = embeddings.build_searchable_text(
        name="AI Assistant",
        description="A helpful AI assistant for productivity",
        sub_heading="Boost your productivity",
        categories=["AI", "Productivity"],
    )

    # Expected field order: name, sub-heading, description, then categories,
    # joined by single spaces.
    expected = "AI Assistant Boost your productivity A helpful AI assistant for productivity AI Productivity"
    assert result == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_build_searchable_text_empty_fields():
    """Test searchable text building with empty fields."""
    result = embeddings.build_searchable_text(
        name="", description="Test description", sub_heading="", categories=[]
    )

    # Empty fields must be dropped rather than contributing stray whitespace.
    assert result == "Test description"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_success():
    """Test successful embedding generation."""
    # Mock OpenAI response
    mock_client = MagicMock()
    mock_response = MagicMock()
    mock_response.data = [MagicMock()]
    mock_response.data[0].embedding = [0.1, 0.2, 0.3] * 512  # 1536 dimensions

    # Use AsyncMock for async embeddings.create method
    mock_client.embeddings.create = AsyncMock(return_value=mock_response)

    # Patch at the point of use in embeddings.py
    with patch(
        "backend.api.features.store.embeddings.get_openai_client"
    ) as mock_get_client:
        mock_get_client.return_value = mock_client

        result = await embeddings.generate_embedding("test text")

        # The full 1536-dim vector is returned untouched.
        assert result is not None
        assert len(result) == 1536
        assert result[0] == 0.1

        # The expected model and raw input text are passed straight through.
        mock_client.embeddings.create.assert_called_once_with(
            model="text-embedding-3-small", input="test text"
        )
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_no_api_key():
    """Test embedding generation without API key."""
    # Patch at the point of use in embeddings.py
    with patch(
        "backend.api.features.store.embeddings.get_openai_client"
    ) as mock_get_client:
        # No client available (e.g. missing API key) -> no embedding.
        mock_get_client.return_value = None

        result = await embeddings.generate_embedding("test text")

        assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_api_error():
    """Test embedding generation with API error."""
    mock_client = MagicMock()
    # The OpenAI call itself raises; generate_embedding must absorb it.
    mock_client.embeddings.create = AsyncMock(side_effect=Exception("API Error"))

    # Patch at the point of use in embeddings.py
    with patch(
        "backend.api.features.store.embeddings.get_openai_client"
    ) as mock_get_client:
        mock_get_client.return_value = mock_client

        result = await embeddings.generate_embedding("test text")

        # Errors are reported as None, not raised.
        assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_text_truncation():
    """Test that long text is properly truncated using tiktoken."""
    from tiktoken import encoding_for_model

    mock_client = MagicMock()
    mock_response = MagicMock()
    mock_response.data = [MagicMock()]
    mock_response.data[0].embedding = [0.1] * 1536

    # Use AsyncMock for async embeddings.create method
    mock_client.embeddings.create = AsyncMock(return_value=mock_response)

    # Patch at the point of use in embeddings.py
    with patch(
        "backend.api.features.store.embeddings.get_openai_client"
    ) as mock_get_client:
        mock_get_client.return_value = mock_client

        # Create text that will exceed 8191 tokens
        # Use varied characters to ensure token-heavy text: each word is ~1 token
        words = [f"word{i}" for i in range(10000)]
        long_text = " ".join(words)  # ~10000 tokens

        await embeddings.generate_embedding(long_text)

        # Verify text was truncated to 8191 tokens
        call_args = mock_client.embeddings.create.call_args
        truncated_text = call_args.kwargs["input"]

        # Count actual tokens in truncated text
        enc = encoding_for_model("text-embedding-3-small")
        actual_tokens = len(enc.encode(truncated_text))

        # Should be at or just under 8191 tokens
        assert actual_tokens <= 8191
        # Should be close to the limit (not over-truncated)
        assert actual_tokens >= 8100
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_store_embedding_success(mocker):
    """Test successful embedding storage."""
    mock_client = mocker.AsyncMock()
    mock_client.execute_raw = mocker.AsyncMock()

    embedding = [0.1, 0.2, 0.3]

    result = await embeddings.store_embedding(
        version_id="test-version-id", embedding=embedding, tx=mock_client
    )

    assert result is True
    # execute_raw is called twice: once for SET search_path, once for INSERT
    assert mock_client.execute_raw.call_count == 2

    # First call: SET search_path
    first_call_args = mock_client.execute_raw.call_args_list[0][0]
    assert "SET search_path" in first_call_args[0]

    # Second call: INSERT query with the actual data
    second_call_args = mock_client.execute_raw.call_args_list[1][0]
    assert "test-version-id" in second_call_args
    # The vector is serialized into PostgreSQL "[...]" literal form.
    assert "[0.1,0.2,0.3]" in second_call_args
    assert None in second_call_args  # userId should be None for store agents
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_store_embedding_database_error(mocker):
    """Test embedding storage with database error."""
    mock_client = mocker.AsyncMock()
    # Every raw query fails, simulating a DB outage.
    mock_client.execute_raw.side_effect = Exception("Database error")

    embedding = [0.1, 0.2, 0.3]

    result = await embeddings.store_embedding(
        version_id="test-version-id", embedding=embedding, tx=mock_client
    )

    # Errors are reported as False, not raised.
    assert result is False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_success():
    """Test successful embedding retrieval."""
    # One unified-table row as the raw query would return it.
    mock_result = [
        {
            "contentType": "STORE_AGENT",
            "contentId": "test-version-id",
            "userId": None,
            "embedding": "[0.1,0.2,0.3]",
            "searchableText": "Test text",
            "metadata": {},
            "createdAt": "2024-01-01T00:00:00Z",
            "updatedAt": "2024-01-01T00:00:00Z",
        }
    ]

    with patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        return_value=mock_result,
    ):
        result = await embeddings.get_embedding("test-version-id")

    # The legacy wrapper re-keys contentId as storeListingVersionId.
    assert result is not None
    assert result["storeListingVersionId"] == "test-version-id"
    assert result["embedding"] == "[0.1,0.2,0.3]"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_not_found():
    """Test embedding retrieval when not found."""
    # Empty result set -> wrapper must return None, not raise.
    with patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        return_value=[],
    ):
        result = await embeddings.get_embedding("test-version-id")

    assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
@patch("backend.api.features.store.embeddings.store_embedding")
@patch("backend.api.features.store.embeddings.get_embedding")
async def test_ensure_embedding_already_exists(mock_get, mock_store, mock_generate):
    """Test ensure_embedding when embedding already exists.

    Decorators apply bottom-up, so mock_get patches get_embedding.
    """
    mock_get.return_value = {"embedding": "[0.1,0.2,0.3]"}

    result = await embeddings.ensure_embedding(
        version_id="test-id",
        name="Test",
        description="Test description",
        sub_heading="Test heading",
        categories=["test"],
    )

    assert result is True
    # An existing embedding short-circuits: nothing generated or stored.
    mock_generate.assert_not_called()
    mock_store.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
@patch("backend.api.features.store.embeddings.store_content_embedding")
@patch("backend.api.features.store.embeddings.get_embedding")
async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate):
    """Test ensure_embedding creating new embedding.

    Decorators apply bottom-up: mock_get -> get_embedding,
    mock_store -> store_content_embedding, mock_generate -> generate_embedding.
    """
    # No cached embedding, generation succeeds, storage succeeds.
    mock_get.return_value = None
    mock_generate.return_value = [0.1, 0.2, 0.3]
    mock_store.return_value = True

    result = await embeddings.ensure_embedding(
        version_id="test-id",
        name="Test",
        description="Test description",
        sub_heading="Test heading",
        categories=["test"],
    )

    assert result is True
    # Searchable text = name, sub-heading, description, categories.
    mock_generate.assert_called_once_with("Test Test heading Test description test")
    mock_store.assert_called_once_with(
        content_type=ContentType.STORE_AGENT,
        content_id="test-id",
        embedding=[0.1, 0.2, 0.3],
        searchable_text="Test Test heading Test description test",
        metadata={"name": "Test", "subHeading": "Test heading", "categories": ["test"]},
        user_id=None,
        tx=None,
    )
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
@patch("backend.api.features.store.embeddings.get_embedding")
async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
    """ensure_embedding reports failure when the generator returns nothing."""
    # Nothing stored, and generation itself yields None.
    mock_get.return_value = None
    mock_generate.return_value = None

    succeeded = await embeddings.ensure_embedding(
        version_id="test-id",
        name="Test",
        description="Test description",
        sub_heading="Test heading",
        categories=["test"],
    )

    assert succeeded is False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_stats():
    """get_embedding_stats derives coverage numbers from two count queries."""
    # The first query returns the approved total, the second the embedded total.
    approved_rows = [{"count": 100}]
    embedded_rows = [{"count": 75}]

    with patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        side_effect=[approved_rows, embedded_rows],
    ):
        stats = await embeddings.get_embedding_stats()

    expected = {
        "total_approved": 100,
        "with_embeddings": 75,
        "without_embeddings": 25,
        "coverage_percent": 75.0,
    }
    for key, value in expected.items():
        assert stats[key] == value
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.ensure_embedding")
async def test_backfill_missing_embeddings_success(mock_ensure):
    """Backfill walks every missing row and tallies successes and failures."""
    # Two store agents currently lacking embeddings.
    missing_rows = [
        {
            "id": "version-1",
            "name": "Agent 1",
            "description": "Description 1",
            "subHeading": "Heading 1",
            "categories": ["AI"],
        },
        {
            "id": "version-2",
            "name": "Agent 2",
            "description": "Description 2",
            "subHeading": "Heading 2",
            "categories": ["Productivity"],
        },
    ]

    # The first row succeeds, the second one fails.
    mock_ensure.side_effect = [True, False]

    with patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        return_value=missing_rows,
    ):
        summary = await embeddings.backfill_missing_embeddings(batch_size=5)

    assert summary["processed"] == 2
    assert summary["success"] == 1
    assert summary["failed"] == 1
    assert mock_ensure.call_count == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_backfill_missing_embeddings_no_missing():
    """Backfill is a no-op when every agent already has an embedding."""
    with patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        return_value=[],
    ):
        summary = await embeddings.backfill_missing_embeddings(batch_size=5)

    # All counters stay at zero and an explanatory message is returned.
    assert summary["processed"] == 0
    assert summary["success"] == 0
    assert summary["failed"] == 0
    assert summary["message"] == "No missing embeddings"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_embedding_to_vector_string():
    """Vectors are rendered as a bracketed, comma-separated pgvector literal."""
    values = [0.1, 0.2, 0.3, -0.4]
    assert embeddings.embedding_to_vector_string(values) == "[0.1,0.2,0.3,-0.4]"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_embed_query():
    """embed_query delegates directly to generate_embedding."""
    with patch(
        "backend.api.features.store.embeddings.generate_embedding"
    ) as mock_generate:
        mock_generate.return_value = [0.1, 0.2, 0.3]

        vector = await embeddings.embed_query("test query")

    assert vector == [0.1, 0.2, 0.3]
    mock_generate.assert_called_once_with("test query")
|
|
||||||
@@ -1,393 +0,0 @@
|
|||||||
"""
|
|
||||||
Hybrid Search for Store Agents
|
|
||||||
|
|
||||||
Combines semantic (embedding) search with lexical (tsvector) search
|
|
||||||
for improved relevance in marketplace agent discovery.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from backend.api.features.store.embeddings import (
|
|
||||||
embed_query,
|
|
||||||
embedding_to_vector_string,
|
|
||||||
)
|
|
||||||
from backend.data.db import query_raw_with_schema
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class HybridSearchWeights:
    """Relative weights used to blend the individual search signals.

    The five weights must be non-negative and sum to approximately 1.0;
    violations raise ValueError at construction time.
    """

    semantic: float = 0.30  # Embedding cosine similarity
    lexical: float = 0.30  # tsvector ts_rank_cd score
    category: float = 0.20  # Category match boost
    recency: float = 0.10  # Newer agents ranked higher
    popularity: float = 0.10  # Agent usage/runs (PageRank-like)

    def __post_init__(self):
        """Validate weights are non-negative and sum to approximately 1.0."""
        values = (
            self.semantic,
            self.lexical,
            self.category,
            self.recency,
            self.popularity,
        )

        if min(values) < 0:
            raise ValueError("All weights must be non-negative")

        total = sum(values)
        # Allow a small tolerance for floating-point rounding.
        if not (0.99 <= total <= 1.01):
            raise ValueError(f"Weights must sum to ~1.0, got {total:.3f}")
|
|
||||||
|
|
||||||
|
|
||||||
# Default weight configuration shared by every hybrid search call.
DEFAULT_WEIGHTS = HybridSearchWeights()

# Minimum relevance score threshold - agents below this are filtered out
# With weights (0.30 semantic + 0.30 lexical + 0.20 category + 0.10 recency + 0.10 popularity):
# - 0.20 means at least ~60% semantic match OR strong lexical match required
# - Ensures only genuinely relevant results are returned
# - Recency/popularity alone (0.10 each) won't pass the threshold
DEFAULT_MIN_SCORE = 0.20
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class HybridSearchResult:
    """A single search result with score breakdown.

    NOTE(review): hybrid_search() below returns plain dicts from the raw
    query rather than instances of this class — this dataclass appears to
    document the expected row shape; confirm whether anything constructs it.
    """

    # Store listing identity and display fields.
    slug: str
    agent_name: str
    agent_image: str
    creator_username: str
    creator_avatar: str
    sub_heading: str
    description: str
    runs: int
    rating: float
    categories: list[str]
    featured: bool
    is_available: bool
    updated_at: datetime

    # Score breakdown (for debugging/tuning)
    combined_score: float
    semantic_score: float = 0.0
    lexical_score: float = 0.0
    category_score: float = 0.0
    recency_score: float = 0.0
    popularity_score: float = 0.0
|
|
||||||
|
|
||||||
|
|
||||||
async def hybrid_search(
    query: str,
    featured: bool = False,
    creators: list[str] | None = None,
    category: str | None = None,
    sorted_by: (
        Literal["relevance", "rating", "runs", "name", "updated_at"] | None
    ) = None,
    page: int = 1,
    page_size: int = 20,
    weights: HybridSearchWeights | None = None,
    min_score: float | None = None,
) -> tuple[list[dict[str, Any]], int]:
    """
    Perform hybrid search combining semantic and lexical signals.

    Args:
        query: Search query string
        featured: Filter for featured agents only
        creators: Filter by creator usernames
        category: Filter by category
        sorted_by: Sort order (relevance uses hybrid scoring)
        page: Page number (1-indexed)
        page_size: Results per page
        weights: Custom weights for search signals
        min_score: Minimum relevance score threshold (0-1). Results below
            this score are filtered out. Defaults to DEFAULT_MIN_SCORE.

    Returns:
        Tuple of (results list, total count). Returns empty list if no
        results meet the minimum relevance threshold.

    Raises:
        ValueError: If the query embedding cannot be generated (generic
            message; details are logged server-side only).
    """
    # NOTE(review): sorted_by is accepted but never referenced in this body;
    # results are always ordered by combined_score. Confirm whether callers
    # expect the other sort orders to take effect.

    # Validate inputs
    query = query.strip()
    if not query:
        return [], 0  # Empty query returns no results

    if page < 1:
        page = 1
    if page_size < 1:
        page_size = 1
    if page_size > 100:  # Cap at reasonable limit to prevent performance issues
        page_size = 100

    if weights is None:
        weights = DEFAULT_WEIGHTS
    if min_score is None:
        min_score = DEFAULT_MIN_SCORE

    offset = (page - 1) * page_size

    # Generate query embedding
    query_embedding = await embed_query(query)

    # Build WHERE clause conditions.
    # Positional parameters ($1, $2, ...) are tracked manually via
    # param_index; the append order below must match the placeholder order.
    where_parts: list[str] = ["sa.is_available = true"]
    params: list[Any] = []
    param_index = 1

    # Add search query for lexical matching
    params.append(query)
    query_param = f"${param_index}"
    param_index += 1

    # Add lowercased query for category matching
    params.append(query.lower())
    query_lower_param = f"${param_index}"
    param_index += 1

    if featured:
        where_parts.append("sa.featured = true")

    if creators:
        where_parts.append(f"sa.creator_username = ANY(${param_index})")
        params.append(creators)
        param_index += 1

    if category:
        where_parts.append(f"${param_index} = ANY(sa.categories)")
        params.append(category)
        param_index += 1

    # Safe: where_parts only contains hardcoded strings with $N parameter placeholders
    # No user input is concatenated directly into the SQL string
    where_clause = " AND ".join(where_parts)

    # Embedding is required for hybrid search - fail fast if unavailable
    if query_embedding is None or not query_embedding:
        # Log detailed error server-side
        logger.error(
            "Failed to generate query embedding. "
            "Check that openai_internal_api_key is configured and OpenAI API is accessible."
        )
        # Raise generic error to client
        raise ValueError("Search service temporarily unavailable")

    # Add embedding parameter
    embedding_str = embedding_to_vector_string(query_embedding)
    params.append(embedding_str)
    embedding_param = f"${param_index}"
    param_index += 1

    # Add weight parameters for SQL calculation
    params.append(weights.semantic)
    weight_semantic_param = f"${param_index}"
    param_index += 1

    params.append(weights.lexical)
    weight_lexical_param = f"${param_index}"
    param_index += 1

    params.append(weights.category)
    weight_category_param = f"${param_index}"
    param_index += 1

    params.append(weights.recency)
    weight_recency_param = f"${param_index}"
    param_index += 1

    params.append(weights.popularity)
    weight_popularity_param = f"${param_index}"
    param_index += 1

    # Add min_score parameter
    params.append(min_score)
    min_score_param = f"${param_index}"
    param_index += 1

    # Optimized hybrid search query:
    # 1. Direct join to UnifiedContentEmbedding via contentId=storeListingVersionId (no redundant JOINs)
    # 2. UNION approach (deduplicates agents matching both branches)
    # 3. COUNT(*) OVER() to get total count in single query
    # 4. Optimized category matching with EXISTS + unnest
    # 5. Pre-calculated max values for lexical and popularity normalization
    # 6. Simplified recency calculation with linear decay
    # 7. Logarithmic popularity scaling to prevent viral agents from dominating
    # The doubled braces render as literal {schema_prefix} for later substitution.
    sql_query = f"""
    WITH candidates AS (
        -- Lexical matches (uses GIN index on search column)
        SELECT sa."storeListingVersionId"
        FROM {{schema_prefix}}"StoreAgent" sa
        WHERE {where_clause}
        AND sa.search @@ plainto_tsquery('english', {query_param})

        UNION

        -- Semantic matches (uses HNSW index on embedding with KNN)
        SELECT "storeListingVersionId"
        FROM (
            SELECT sa."storeListingVersionId", uce.embedding
            FROM {{schema_prefix}}"StoreAgent" sa
            INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
                ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
            WHERE {where_clause}
            ORDER BY uce.embedding <=> {embedding_param}::vector
            LIMIT 200
        ) semantic_results
    ),
    search_scores AS (
        SELECT
            sa.slug,
            sa.agent_name,
            sa.agent_image,
            sa.creator_username,
            sa.creator_avatar,
            sa.sub_heading,
            sa.description,
            sa.runs,
            sa.rating,
            sa.categories,
            sa.featured,
            sa.is_available,
            sa.updated_at,
            -- Semantic score: cosine similarity (1 - distance)
            COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
            -- Lexical score: ts_rank_cd (will be normalized later)
            COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
            -- Category match: optimized with unnest for better performance
            CASE
                WHEN EXISTS (
                    SELECT 1 FROM unnest(sa.categories) cat
                    WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%'
                )
                THEN 1.0
                ELSE 0.0
            END as category_score,
            -- Recency score: linear decay over 90 days (simpler than exponential)
            GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score,
            -- Popularity raw: agent runs count (will be normalized with log scaling)
            sa.runs as popularity_raw
        FROM candidates c
        INNER JOIN {{schema_prefix}}"StoreAgent" sa
            ON c."storeListingVersionId" = sa."storeListingVersionId"
        LEFT JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
            ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
    ),
    max_lexical AS (
        SELECT MAX(lexical_raw) as max_val FROM search_scores
    ),
    max_popularity AS (
        SELECT MAX(popularity_raw) as max_val FROM search_scores
    ),
    normalized AS (
        SELECT
            ss.*,
            -- Normalize lexical score by pre-calculated max
            CASE
                WHEN ml.max_val > 0
                THEN ss.lexical_raw / ml.max_val
                ELSE 0
            END as lexical_score,
            -- Normalize popularity with logarithmic scaling to prevent viral agents from dominating
            -- LOG(1 + runs) / LOG(1 + max_runs) ensures score is 0-1 range
            CASE
                WHEN mp.max_val > 0 AND ss.popularity_raw > 0
                THEN LN(1 + ss.popularity_raw) / LN(1 + mp.max_val)
                ELSE 0
            END as popularity_score
        FROM search_scores ss
        CROSS JOIN max_lexical ml
        CROSS JOIN max_popularity mp
    ),
    scored AS (
        SELECT
            slug,
            agent_name,
            agent_image,
            creator_username,
            creator_avatar,
            sub_heading,
            description,
            runs,
            rating,
            categories,
            featured,
            is_available,
            updated_at,
            semantic_score,
            lexical_score,
            category_score,
            recency_score,
            popularity_score,
            (
                {weight_semantic_param} * semantic_score +
                {weight_lexical_param} * lexical_score +
                {weight_category_param} * category_score +
                {weight_recency_param} * recency_score +
                {weight_popularity_param} * popularity_score
            ) as combined_score
        FROM normalized
    ),
    filtered AS (
        SELECT
            *,
            COUNT(*) OVER () as total_count
        FROM scored
        WHERE combined_score >= {min_score_param}
    )
    SELECT * FROM filtered
    ORDER BY combined_score DESC
    LIMIT ${param_index} OFFSET ${param_index + 1}
    """

    # Add pagination params
    params.extend([page_size, offset])

    # Execute search query - includes total_count via window function
    results = await query_raw_with_schema(
        sql_query, *params, set_public_search_path=True
    )

    # Extract total count from first result (all rows have same count)
    total = results[0]["total_count"] if results else 0

    # Remove total_count from results before returning
    for result in results:
        result.pop("total_count", None)

    # Log without sensitive query content
    logger.info(f"Hybrid search: {len(results)} results, {total} total")

    return results, total
|
|
||||||
|
|
||||||
|
|
||||||
async def hybrid_search_simple(
    query: str,
    page: int = 1,
    page_size: int = 20,
) -> tuple[list[dict[str, Any]], int]:
    """Convenience wrapper around hybrid_search.

    Runs the search with the default weights and no filters applied.
    """
    return await hybrid_search(query=query, page=page, page_size=page_size)
|
|
||||||
@@ -1,334 +0,0 @@
|
|||||||
"""
|
|
||||||
Integration tests for hybrid search with schema handling.
|
|
||||||
|
|
||||||
These tests verify that hybrid search works correctly across different database schemas.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.api.features.store.hybrid_search import HybridSearchWeights, hybrid_search
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_schema_handling():
    """Test that hybrid search correctly handles database schema prefixes."""
    query = "test agent"

    # A single fake row, including the window-function total_count column.
    fake_row = {
        "slug": "test/agent",
        "agent_name": "Test Agent",
        "agent_image": "test.png",
        "creator_username": "test",
        "creator_avatar": "avatar.png",
        "sub_heading": "Test sub-heading",
        "description": "Test description",
        "runs": 10,
        "rating": 4.5,
        "categories": ["test"],
        "featured": False,
        "is_available": True,
        "updated_at": "2024-01-01T00:00:00Z",
        "combined_score": 0.8,
        "semantic_score": 0.7,
        "lexical_score": 0.6,
        "category_score": 0.5,
        "recency_score": 0.4,
        "total_count": 1,
    }

    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query, patch(
        "backend.api.features.store.hybrid_search.embed_query"
    ) as mock_embed:
        mock_query.return_value = [fake_row]
        mock_embed.return_value = [0.1] * 1536  # Mock embedding

        results, total = await hybrid_search(
            query=query,
            page=1,
            page_size=20,
        )

    # Verify the query was called
    assert mock_query.called
    # Verify the SQL template uses schema_prefix placeholder
    sql_template = mock_query.call_args[0][0]
    assert "{schema_prefix}" in sql_template

    # Verify results
    assert len(results) == 1
    assert total == 1
    assert results[0]["slug"] == "test/agent"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_public_schema():
    """Test hybrid search when using public schema (no prefix needed)."""
    with patch("backend.data.db.get_database_schema") as mock_schema, patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query, patch(
        "backend.api.features.store.hybrid_search.embed_query"
    ) as mock_embed:
        mock_schema.return_value = "public"
        mock_query.return_value = []
        mock_embed.return_value = [0.1] * 1536

        results, total = await hybrid_search(
            query="test",
            page=1,
            page_size=20,
        )

        # Verify the mock was set up correctly
        assert mock_schema.return_value == "public"

    # Results should work even with empty results
    assert results == []
    assert total == 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_custom_schema():
    """Test hybrid search when using custom schema (e.g., 'platform')."""
    with patch("backend.data.db.get_database_schema") as mock_schema, patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query, patch(
        "backend.api.features.store.hybrid_search.embed_query"
    ) as mock_embed:
        mock_schema.return_value = "platform"
        mock_query.return_value = []
        mock_embed.return_value = [0.1] * 1536

        results, total = await hybrid_search(
            query="test",
            page=1,
            page_size=20,
        )

        # Verify the mock was set up correctly
        assert mock_schema.return_value == "platform"

    assert results == []
    assert total == 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_without_embeddings():
    """Test hybrid search fails fast when embeddings are unavailable."""
    # Patch where the function is used, not where it's defined
    with patch("backend.api.features.store.hybrid_search.embed_query") as mock_embed:
        mock_embed.return_value = None  # Simulate embedding failure

        # The search must refuse to run and surface a generic error.
        with pytest.raises(ValueError) as exc_info:
            await hybrid_search(
                query="test",
                page=1,
                page_size=20,
            )

    # Verify error message is generic (doesn't leak implementation details)
    assert "Search service temporarily unavailable" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_filters():
    """Test hybrid search with various filters."""
    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query, patch(
        "backend.api.features.store.hybrid_search.embed_query"
    ) as mock_embed:
        mock_query.return_value = []
        mock_embed.return_value = [0.1] * 1536

        # Exercise every supported filter at once.
        results, total = await hybrid_search(
            query="test",
            featured=True,
            creators=["user1", "user2"],
            category="productivity",
            page=1,
            page_size=10,
        )

    # Should have query, query_lower, creators array, category
    sql_params = mock_query.call_args[0][1:]  # Skip SQL template
    assert len(sql_params) >= 4
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_weights():
    """Test hybrid search with custom weights."""
    custom_weights = HybridSearchWeights(
        semantic=0.5,
        lexical=0.3,
        category=0.1,
        recency=0.1,
        popularity=0.0,
    )

    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query, patch(
        "backend.api.features.store.hybrid_search.embed_query"
    ) as mock_embed:
        mock_query.return_value = []
        mock_embed.return_value = [0.1] * 1536

        results, total = await hybrid_search(
            query="test",
            weights=custom_weights,
            page=1,
            page_size=20,
        )

    sql_template = mock_query.call_args[0][0]
    sql_params = mock_query.call_args[0][1:]  # Get all parameters passed

    # Check that SQL uses parameterized weights (not f-string interpolation)
    assert "$" in sql_template  # Verify parameterization is used

    # Check that custom weights are in the params
    assert 0.5 in sql_params  # semantic weight
    assert 0.3 in sql_params  # lexical weight
    assert 0.1 in sql_params  # category and recency weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_min_score_filtering():
    """Test hybrid search minimum score threshold."""
    # One row that clears the threshold (only the relevant fields are faked).
    high_score_row = {
        "slug": "high-score/agent",
        "agent_name": "High Score Agent",
        "combined_score": 0.8,
        "total_count": 1,
        # ... other fields
    }

    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query, patch(
        "backend.api.features.store.hybrid_search.embed_query"
    ) as mock_embed:
        mock_query.return_value = [high_score_row]
        mock_embed.return_value = [0.1] * 1536

        # Test with custom min_score
        results, total = await hybrid_search(
            query="test",
            min_score=0.5,  # High threshold
            page=1,
            page_size=20,
        )

    sql_template = mock_query.call_args[0][0]
    sql_params = mock_query.call_args[0][1:]  # Get all parameters

    # Check that SQL uses parameterized min_score
    assert "combined_score >=" in sql_template
    assert "$" in sql_template  # Verify parameterization

    # Check that custom min_score is in the params
    assert 0.5 in sql_params
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_pagination():
    """Test hybrid search pagination."""
    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query, patch(
        "backend.api.features.store.hybrid_search.embed_query"
    ) as mock_embed:
        mock_query.return_value = []
        mock_embed.return_value = [0.1] * 1536

        # Test page 2 with page_size 10
        results, total = await hybrid_search(
            query="test",
            page=2,
            page_size=10,
        )

    # Last two params should be LIMIT and OFFSET
    call_positional = mock_query.call_args[0]
    assert call_positional[-2] == 10  # page_size
    assert call_positional[-1] == 10  # (page - 1) * page_size = (2 - 1) * 10
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_error_handling():
    """Test hybrid search error handling."""
    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query, patch(
        "backend.api.features.store.hybrid_search.embed_query"
    ) as mock_embed:
        # Simulate database error
        mock_query.side_effect = Exception("Database connection error")
        mock_embed.return_value = [0.1] * 1536

        # Should raise exception
        with pytest.raises(Exception) as exc_info:
            await hybrid_search(
                query="test",
                page=1,
                page_size=20,
            )

    assert "Database connection error" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
# Allow running this test module directly: executes the tests verbosely.
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
|
|
||||||
@@ -110,7 +110,6 @@ class Profile(pydantic.BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class StoreSubmission(pydantic.BaseModel):
|
class StoreSubmission(pydantic.BaseModel):
|
||||||
listing_id: str
|
|
||||||
agent_id: str
|
agent_id: str
|
||||||
agent_version: int
|
agent_version: int
|
||||||
name: str
|
name: str
|
||||||
@@ -165,12 +164,8 @@ class StoreListingsWithVersionsResponse(pydantic.BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class StoreSubmissionRequest(pydantic.BaseModel):
|
class StoreSubmissionRequest(pydantic.BaseModel):
|
||||||
agent_id: str = pydantic.Field(
|
agent_id: str
|
||||||
..., min_length=1, description="Agent ID cannot be empty"
|
agent_version: int
|
||||||
)
|
|
||||||
agent_version: int = pydantic.Field(
|
|
||||||
..., gt=0, description="Agent version must be greater than 0"
|
|
||||||
)
|
|
||||||
slug: str
|
slug: str
|
||||||
name: str
|
name: str
|
||||||
sub_heading: str
|
sub_heading: str
|
||||||
|
|||||||
@@ -138,7 +138,6 @@ def test_creator_details():
|
|||||||
|
|
||||||
def test_store_submission():
|
def test_store_submission():
|
||||||
submission = store_model.StoreSubmission(
|
submission = store_model.StoreSubmission(
|
||||||
listing_id="listing123",
|
|
||||||
agent_id="agent123",
|
agent_id="agent123",
|
||||||
agent_version=1,
|
agent_version=1,
|
||||||
sub_heading="Test subheading",
|
sub_heading="Test subheading",
|
||||||
@@ -160,7 +159,6 @@ def test_store_submissions_response():
|
|||||||
response = store_model.StoreSubmissionsResponse(
|
response = store_model.StoreSubmissionsResponse(
|
||||||
submissions=[
|
submissions=[
|
||||||
store_model.StoreSubmission(
|
store_model.StoreSubmission(
|
||||||
listing_id="listing123",
|
|
||||||
agent_id="agent123",
|
agent_id="agent123",
|
||||||
agent_version=1,
|
agent_version=1,
|
||||||
sub_heading="Test subheading",
|
sub_heading="Test subheading",
|
||||||
|
|||||||
@@ -521,7 +521,6 @@ def test_get_submissions_success(
|
|||||||
mocked_value = store_model.StoreSubmissionsResponse(
|
mocked_value = store_model.StoreSubmissionsResponse(
|
||||||
submissions=[
|
submissions=[
|
||||||
store_model.StoreSubmission(
|
store_model.StoreSubmission(
|
||||||
listing_id="test-listing-id",
|
|
||||||
name="Test Agent",
|
name="Test Agent",
|
||||||
description="Test agent description",
|
description="Test agent description",
|
||||||
image_urls=["test.jpg"],
|
image_urls=["test.jpg"],
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ from backend.data.onboarding import (
|
|||||||
complete_re_run_agent,
|
complete_re_run_agent,
|
||||||
get_recommended_agents,
|
get_recommended_agents,
|
||||||
get_user_onboarding,
|
get_user_onboarding,
|
||||||
|
increment_runs,
|
||||||
onboarding_enabled,
|
onboarding_enabled,
|
||||||
reset_user_onboarding,
|
reset_user_onboarding,
|
||||||
update_user_onboarding,
|
update_user_onboarding,
|
||||||
@@ -974,6 +975,7 @@ async def execute_graph(
|
|||||||
# Record successful graph execution
|
# Record successful graph execution
|
||||||
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
|
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
|
||||||
record_graph_operation(operation="execute", status="success")
|
record_graph_operation(operation="execute", status="success")
|
||||||
|
await increment_runs(user_id)
|
||||||
await complete_re_run_agent(user_id, graph_id)
|
await complete_re_run_agent(user_id, graph_id)
|
||||||
if source == "library":
|
if source == "library":
|
||||||
await complete_onboarding_step(
|
await complete_onboarding_step(
|
||||||
|
|||||||
@@ -6,9 +6,6 @@ import hashlib
|
|||||||
import hmac
|
import hmac
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
from prisma.types import Serializable
|
|
||||||
|
|
||||||
from backend.sdk import (
|
from backend.sdk import (
|
||||||
BaseWebhooksManager,
|
BaseWebhooksManager,
|
||||||
@@ -87,9 +84,7 @@ class AirtableWebhookManager(BaseWebhooksManager):
|
|||||||
# update webhook config
|
# update webhook config
|
||||||
await update_webhook(
|
await update_webhook(
|
||||||
webhook.id,
|
webhook.id,
|
||||||
config=cast(
|
config={"base_id": base_id, "cursor": response.cursor},
|
||||||
dict[str, Serializable], {"base_id": base_id, "cursor": response.cursor}
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
event_type = "notification"
|
event_type = "notification"
|
||||||
|
|||||||
@@ -1,184 +0,0 @@
|
|||||||
"""
|
|
||||||
Shared helpers for Human-In-The-Loop (HITL) review functionality.
|
|
||||||
Used by both the dedicated HumanInTheLoopBlock and blocks that require human review.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from prisma.enums import ReviewStatus
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from backend.data.execution import ExecutionContext, ExecutionStatus
|
|
||||||
from backend.data.human_review import ReviewResult
|
|
||||||
from backend.executor.manager import async_update_node_execution_status
|
|
||||||
from backend.util.clients import get_database_manager_async_client
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ReviewDecision(BaseModel):
|
|
||||||
"""Result of a review decision."""
|
|
||||||
|
|
||||||
should_proceed: bool
|
|
||||||
message: str
|
|
||||||
review_result: ReviewResult
|
|
||||||
|
|
||||||
|
|
||||||
class HITLReviewHelper:
|
|
||||||
"""Helper class for Human-In-The-Loop review operations."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
|
|
||||||
"""Create or retrieve a human review from the database."""
|
|
||||||
return await get_database_manager_async_client().get_or_create_human_review(
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def update_node_execution_status(**kwargs) -> None:
|
|
||||||
"""Update the execution status of a node."""
|
|
||||||
await async_update_node_execution_status(
|
|
||||||
db_client=get_database_manager_async_client(), **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def update_review_processed_status(
|
|
||||||
node_exec_id: str, processed: bool
|
|
||||||
) -> None:
|
|
||||||
"""Update the processed status of a review."""
|
|
||||||
return await get_database_manager_async_client().update_review_processed_status(
|
|
||||||
node_exec_id, processed
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _handle_review_request(
|
|
||||||
input_data: Any,
|
|
||||||
user_id: str,
|
|
||||||
node_exec_id: str,
|
|
||||||
graph_exec_id: str,
|
|
||||||
graph_id: str,
|
|
||||||
graph_version: int,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
block_name: str = "Block",
|
|
||||||
editable: bool = False,
|
|
||||||
) -> Optional[ReviewResult]:
|
|
||||||
"""
|
|
||||||
Handle a review request for a block that requires human review.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_data: The input data to be reviewed
|
|
||||||
user_id: ID of the user requesting the review
|
|
||||||
node_exec_id: ID of the node execution
|
|
||||||
graph_exec_id: ID of the graph execution
|
|
||||||
graph_id: ID of the graph
|
|
||||||
graph_version: Version of the graph
|
|
||||||
execution_context: Current execution context
|
|
||||||
block_name: Name of the block requesting review
|
|
||||||
editable: Whether the reviewer can edit the data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ReviewResult if review is complete, None if waiting for human input
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: If review creation or status update fails
|
|
||||||
"""
|
|
||||||
# Skip review if safe mode is disabled - return auto-approved result
|
|
||||||
if not execution_context.safe_mode:
|
|
||||||
logger.info(
|
|
||||||
f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
|
|
||||||
)
|
|
||||||
return ReviewResult(
|
|
||||||
data=input_data,
|
|
||||||
status=ReviewStatus.APPROVED,
|
|
||||||
message="Auto-approved (safe mode disabled)",
|
|
||||||
processed=True,
|
|
||||||
node_exec_id=node_exec_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await HITLReviewHelper.get_or_create_human_review(
|
|
||||||
user_id=user_id,
|
|
||||||
node_exec_id=node_exec_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
graph_id=graph_id,
|
|
||||||
graph_version=graph_version,
|
|
||||||
input_data=input_data,
|
|
||||||
message=f"Review required for {block_name} execution",
|
|
||||||
editable=editable,
|
|
||||||
)
|
|
||||||
|
|
||||||
if result is None:
|
|
||||||
logger.info(
|
|
||||||
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
|
|
||||||
)
|
|
||||||
await HITLReviewHelper.update_node_execution_status(
|
|
||||||
exec_id=node_exec_id,
|
|
||||||
status=ExecutionStatus.REVIEW,
|
|
||||||
)
|
|
||||||
return None # Signal that execution should pause
|
|
||||||
|
|
||||||
# Mark review as processed if not already done
|
|
||||||
if not result.processed:
|
|
||||||
await HITLReviewHelper.update_review_processed_status(
|
|
||||||
node_exec_id=node_exec_id, processed=True
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def handle_review_decision(
|
|
||||||
input_data: Any,
|
|
||||||
user_id: str,
|
|
||||||
node_exec_id: str,
|
|
||||||
graph_exec_id: str,
|
|
||||||
graph_id: str,
|
|
||||||
graph_version: int,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
block_name: str = "Block",
|
|
||||||
editable: bool = False,
|
|
||||||
) -> Optional[ReviewDecision]:
|
|
||||||
"""
|
|
||||||
Handle a review request and return the decision in a single call.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_data: The input data to be reviewed
|
|
||||||
user_id: ID of the user requesting the review
|
|
||||||
node_exec_id: ID of the node execution
|
|
||||||
graph_exec_id: ID of the graph execution
|
|
||||||
graph_id: ID of the graph
|
|
||||||
graph_version: Version of the graph
|
|
||||||
execution_context: Current execution context
|
|
||||||
block_name: Name of the block requesting review
|
|
||||||
editable: Whether the reviewer can edit the data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ReviewDecision if review is complete (approved/rejected),
|
|
||||||
None if execution should pause (awaiting review)
|
|
||||||
"""
|
|
||||||
review_result = await HITLReviewHelper._handle_review_request(
|
|
||||||
input_data=input_data,
|
|
||||||
user_id=user_id,
|
|
||||||
node_exec_id=node_exec_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
graph_id=graph_id,
|
|
||||||
graph_version=graph_version,
|
|
||||||
execution_context=execution_context,
|
|
||||||
block_name=block_name,
|
|
||||||
editable=editable,
|
|
||||||
)
|
|
||||||
|
|
||||||
if review_result is None:
|
|
||||||
# Still awaiting review - return None to pause execution
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Review is complete, determine outcome
|
|
||||||
should_proceed = review_result.status == ReviewStatus.APPROVED
|
|
||||||
message = review_result.message or (
|
|
||||||
"Execution approved by reviewer"
|
|
||||||
if should_proceed
|
|
||||||
else "Execution rejected by reviewer"
|
|
||||||
)
|
|
||||||
|
|
||||||
return ReviewDecision(
|
|
||||||
should_proceed=should_proceed, message=message, review_result=review_result
|
|
||||||
)
|
|
||||||
@@ -3,7 +3,6 @@ from typing import Any
|
|||||||
|
|
||||||
from prisma.enums import ReviewStatus
|
from prisma.enums import ReviewStatus
|
||||||
|
|
||||||
from backend.blocks.helpers.review import HITLReviewHelper
|
|
||||||
from backend.data.block import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
@@ -12,9 +11,11 @@ from backend.data.block import (
|
|||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
BlockType,
|
BlockType,
|
||||||
)
|
)
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext, ExecutionStatus
|
||||||
from backend.data.human_review import ReviewResult
|
from backend.data.human_review import ReviewResult
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
|
from backend.executor.manager import async_update_node_execution_status
|
||||||
|
from backend.util.clients import get_database_manager_async_client
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -71,26 +72,32 @@ class HumanInTheLoopBlock(Block):
|
|||||||
("approved_data", {"name": "John Doe", "age": 30}),
|
("approved_data", {"name": "John Doe", "age": 30}),
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
"handle_review_decision": lambda **kwargs: type(
|
"get_or_create_human_review": lambda *_args, **_kwargs: ReviewResult(
|
||||||
"ReviewDecision",
|
data={"name": "John Doe", "age": 30},
|
||||||
(),
|
status=ReviewStatus.APPROVED,
|
||||||
{
|
message="",
|
||||||
"should_proceed": True,
|
processed=False,
|
||||||
"message": "Test approval message",
|
node_exec_id="test-node-exec-id",
|
||||||
"review_result": ReviewResult(
|
),
|
||||||
data={"name": "John Doe", "age": 30},
|
"update_node_execution_status": lambda *_args, **_kwargs: None,
|
||||||
status=ReviewStatus.APPROVED,
|
"update_review_processed_status": lambda *_args, **_kwargs: None,
|
||||||
message="",
|
|
||||||
processed=False,
|
|
||||||
node_exec_id="test-node-exec-id",
|
|
||||||
),
|
|
||||||
},
|
|
||||||
)(),
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def handle_review_decision(self, **kwargs):
|
async def get_or_create_human_review(self, **kwargs):
|
||||||
return await HITLReviewHelper.handle_review_decision(**kwargs)
|
return await get_database_manager_async_client().get_or_create_human_review(
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_node_execution_status(self, **kwargs):
|
||||||
|
return await async_update_node_execution_status(
|
||||||
|
db_client=get_database_manager_async_client(), **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_review_processed_status(self, node_exec_id: str, processed: bool):
|
||||||
|
return await get_database_manager_async_client().update_review_processed_status(
|
||||||
|
node_exec_id, processed
|
||||||
|
)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
@@ -102,7 +109,7 @@ class HumanInTheLoopBlock(Block):
|
|||||||
graph_id: str,
|
graph_id: str,
|
||||||
graph_version: int,
|
graph_version: int,
|
||||||
execution_context: ExecutionContext,
|
execution_context: ExecutionContext,
|
||||||
**_kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
if not execution_context.safe_mode:
|
if not execution_context.safe_mode:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -112,28 +119,48 @@ class HumanInTheLoopBlock(Block):
|
|||||||
yield "review_message", "Auto-approved (safe mode disabled)"
|
yield "review_message", "Auto-approved (safe mode disabled)"
|
||||||
return
|
return
|
||||||
|
|
||||||
decision = await self.handle_review_decision(
|
try:
|
||||||
input_data=input_data.data,
|
result = await self.get_or_create_human_review(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
node_exec_id=node_exec_id,
|
node_exec_id=node_exec_id,
|
||||||
graph_exec_id=graph_exec_id,
|
graph_exec_id=graph_exec_id,
|
||||||
graph_id=graph_id,
|
graph_id=graph_id,
|
||||||
graph_version=graph_version,
|
graph_version=graph_version,
|
||||||
execution_context=execution_context,
|
input_data=input_data.data,
|
||||||
block_name=self.name,
|
message=input_data.name,
|
||||||
editable=input_data.editable,
|
editable=input_data.editable,
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in HITL block for node {node_exec_id}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
if decision is None:
|
if result is None:
|
||||||
return
|
logger.info(
|
||||||
|
f"HITL block pausing execution for node {node_exec_id} - awaiting human review"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await self.update_node_execution_status(
|
||||||
|
exec_id=node_exec_id,
|
||||||
|
status=ExecutionStatus.REVIEW,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to update node status for HITL block {node_exec_id}: {str(e)}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
status = decision.review_result.status
|
if not result.processed:
|
||||||
if status == ReviewStatus.APPROVED:
|
await self.update_review_processed_status(
|
||||||
yield "approved_data", decision.review_result.data
|
node_exec_id=node_exec_id, processed=True
|
||||||
elif status == ReviewStatus.REJECTED:
|
)
|
||||||
yield "rejected_data", decision.review_result.data
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Unexpected review status: {status}")
|
|
||||||
|
|
||||||
if decision.message:
|
if result.status == ReviewStatus.APPROVED:
|
||||||
yield "review_message", decision.message
|
yield "approved_data", result.data
|
||||||
|
if result.message:
|
||||||
|
yield "review_message", result.message
|
||||||
|
|
||||||
|
elif result.status == ReviewStatus.REJECTED:
|
||||||
|
yield "rejected_data", result.data
|
||||||
|
if result.message:
|
||||||
|
yield "review_message", result.message
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -18,7 +18,6 @@ from backend.data.model import (
|
|||||||
SchemaField,
|
SchemaField,
|
||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.request import DEFAULT_USER_AGENT
|
|
||||||
|
|
||||||
|
|
||||||
class GetWikipediaSummaryBlock(Block, GetRequest):
|
class GetWikipediaSummaryBlock(Block, GetRequest):
|
||||||
@@ -40,27 +39,17 @@ class GetWikipediaSummaryBlock(Block, GetRequest):
|
|||||||
output_schema=GetWikipediaSummaryBlock.Output,
|
output_schema=GetWikipediaSummaryBlock.Output,
|
||||||
test_input={"topic": "Artificial Intelligence"},
|
test_input={"topic": "Artificial Intelligence"},
|
||||||
test_output=("summary", "summary content"),
|
test_output=("summary", "summary content"),
|
||||||
test_mock={
|
test_mock={"get_request": lambda url, json: {"extract": "summary content"}},
|
||||||
"get_request": lambda url, headers, json: {"extract": "summary content"}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||||
topic = input_data.topic
|
topic = input_data.topic
|
||||||
# URL-encode the topic to handle spaces and special characters
|
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
|
||||||
encoded_topic = quote(topic, safe="")
|
|
||||||
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{encoded_topic}"
|
|
||||||
|
|
||||||
# Set headers per Wikimedia robot policy (https://w.wiki/4wJS)
|
|
||||||
# - User-Agent: Required, must identify the bot
|
|
||||||
# - Accept-Encoding: gzip recommended to reduce bandwidth
|
|
||||||
headers = {
|
|
||||||
"User-Agent": DEFAULT_USER_AGENT,
|
|
||||||
"Accept-Encoding": "gzip, deflate",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
# Note: User-Agent is now automatically set by the request library
|
||||||
|
# to comply with Wikimedia's robot policy (https://w.wiki/4wJS)
|
||||||
try:
|
try:
|
||||||
response = await self.get_request(url, headers=headers, json=True)
|
response = await self.get_request(url, json=True)
|
||||||
if "extract" not in response:
|
if "extract" not in response:
|
||||||
raise ValueError(f"Unable to parse Wikipedia response: {response}")
|
raise ValueError(f"Unable to parse Wikipedia response: {response}")
|
||||||
yield "summary", response["extract"]
|
yield "summary", response["extract"]
|
||||||
|
|||||||
@@ -391,12 +391,8 @@ class SmartDecisionMakerBlock(Block):
|
|||||||
"""
|
"""
|
||||||
block = sink_node.block
|
block = sink_node.block
|
||||||
|
|
||||||
# Use custom name from node metadata if set, otherwise fall back to block.name
|
|
||||||
custom_name = sink_node.metadata.get("customized_name")
|
|
||||||
tool_name = custom_name if custom_name else block.name
|
|
||||||
|
|
||||||
tool_function: dict[str, Any] = {
|
tool_function: dict[str, Any] = {
|
||||||
"name": SmartDecisionMakerBlock.cleanup(tool_name),
|
"name": SmartDecisionMakerBlock.cleanup(block.name),
|
||||||
"description": block.description,
|
"description": block.description,
|
||||||
}
|
}
|
||||||
sink_block_input_schema = block.input_schema
|
sink_block_input_schema = block.input_schema
|
||||||
@@ -493,24 +489,14 @@ class SmartDecisionMakerBlock(Block):
|
|||||||
f"Sink graph metadata not found: {graph_id} {graph_version}"
|
f"Sink graph metadata not found: {graph_id} {graph_version}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use custom name from node metadata if set, otherwise fall back to graph name
|
|
||||||
custom_name = sink_node.metadata.get("customized_name")
|
|
||||||
tool_name = custom_name if custom_name else sink_graph_meta.name
|
|
||||||
|
|
||||||
tool_function: dict[str, Any] = {
|
tool_function: dict[str, Any] = {
|
||||||
"name": SmartDecisionMakerBlock.cleanup(tool_name),
|
"name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name),
|
||||||
"description": sink_graph_meta.description,
|
"description": sink_graph_meta.description,
|
||||||
}
|
}
|
||||||
|
|
||||||
properties = {}
|
properties = {}
|
||||||
field_mapping = {}
|
|
||||||
|
|
||||||
for link in links:
|
for link in links:
|
||||||
field_name = link.sink_name
|
|
||||||
|
|
||||||
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
|
|
||||||
field_mapping[clean_field_name] = field_name
|
|
||||||
|
|
||||||
sink_block_input_schema = sink_node.input_default["input_schema"]
|
sink_block_input_schema = sink_node.input_default["input_schema"]
|
||||||
sink_block_properties = sink_block_input_schema.get("properties", {}).get(
|
sink_block_properties = sink_block_input_schema.get("properties", {}).get(
|
||||||
link.sink_name, {}
|
link.sink_name, {}
|
||||||
@@ -520,7 +506,7 @@ class SmartDecisionMakerBlock(Block):
|
|||||||
if "description" in sink_block_properties
|
if "description" in sink_block_properties
|
||||||
else f"The {link.sink_name} of the tool"
|
else f"The {link.sink_name} of the tool"
|
||||||
)
|
)
|
||||||
properties[clean_field_name] = {
|
properties[link.sink_name] = {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": description,
|
"description": description,
|
||||||
"default": json.dumps(sink_block_properties.get("default", None)),
|
"default": json.dumps(sink_block_properties.get("default", None)),
|
||||||
@@ -533,7 +519,7 @@ class SmartDecisionMakerBlock(Block):
|
|||||||
"strict": True,
|
"strict": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
tool_function["_field_mapping"] = field_mapping
|
# Store node info for later use in output processing
|
||||||
tool_function["_sink_node_id"] = sink_node.id
|
tool_function["_sink_node_id"] = sink_node.id
|
||||||
|
|
||||||
return {"type": "function", "function": tool_function}
|
return {"type": "function", "function": tool_function}
|
||||||
@@ -989,28 +975,10 @@ class SmartDecisionMakerBlock(Block):
|
|||||||
graph_version: int,
|
graph_version: int,
|
||||||
execution_context: ExecutionContext,
|
execution_context: ExecutionContext,
|
||||||
execution_processor: "ExecutionProcessor",
|
execution_processor: "ExecutionProcessor",
|
||||||
nodes_to_skip: set[str] | None = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
|
|
||||||
tool_functions = await self._create_tool_node_signatures(node_id)
|
tool_functions = await self._create_tool_node_signatures(node_id)
|
||||||
original_tool_count = len(tool_functions)
|
|
||||||
|
|
||||||
# Filter out tools for nodes that should be skipped (e.g., missing optional credentials)
|
|
||||||
if nodes_to_skip:
|
|
||||||
tool_functions = [
|
|
||||||
tf
|
|
||||||
for tf in tool_functions
|
|
||||||
if tf.get("function", {}).get("_sink_node_id") not in nodes_to_skip
|
|
||||||
]
|
|
||||||
|
|
||||||
# Only raise error if we had tools but they were all filtered out
|
|
||||||
if original_tool_count > 0 and not tool_functions:
|
|
||||||
raise ValueError(
|
|
||||||
"No available tools to execute - all downstream nodes are unavailable "
|
|
||||||
"(possibly due to missing optional credentials)"
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "tool_functions", json.dumps(tool_functions)
|
yield "tool_functions", json.dumps(tool_functions)
|
||||||
|
|
||||||
conversation_history = input_data.conversation_history or []
|
conversation_history = input_data.conversation_history or []
|
||||||
@@ -1161,9 +1129,8 @@ class SmartDecisionMakerBlock(Block):
|
|||||||
original_field_name = field_mapping.get(clean_arg_name, clean_arg_name)
|
original_field_name = field_mapping.get(clean_arg_name, clean_arg_name)
|
||||||
arg_value = tool_args.get(clean_arg_name)
|
arg_value = tool_args.get(clean_arg_name)
|
||||||
|
|
||||||
# Use original_field_name directly (not sanitized) to match link sink_name
|
sanitized_arg_name = self.cleanup(original_field_name)
|
||||||
# The field_mapping already translates from LLM's cleaned names to original names
|
emit_key = f"tools_^_{sink_node_id}_~_{sanitized_arg_name}"
|
||||||
emit_key = f"tools_^_{sink_node_id}_~_{original_field_name}"
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",
|
"[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",
|
||||||
|
|||||||
@@ -1057,153 +1057,3 @@ async def test_smart_decision_maker_traditional_mode_default():
|
|||||||
) # Should yield individual tool parameters
|
) # Should yield individual tool parameters
|
||||||
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
|
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
|
||||||
assert "conversations" in outputs
|
assert "conversations" in outputs
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
|
||||||
"""Test that SmartDecisionMakerBlock uses customized_name from node metadata for tool names."""
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
from backend.blocks.basic import StoreValueBlock
|
|
||||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
|
||||||
from backend.data.graph import Link, Node
|
|
||||||
|
|
||||||
# Create a mock node with customized_name in metadata
|
|
||||||
mock_node = MagicMock(spec=Node)
|
|
||||||
mock_node.id = "test-node-id"
|
|
||||||
mock_node.block_id = StoreValueBlock().id
|
|
||||||
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
|
|
||||||
mock_node.block = StoreValueBlock()
|
|
||||||
|
|
||||||
# Create a mock link
|
|
||||||
mock_link = MagicMock(spec=Link)
|
|
||||||
mock_link.sink_name = "input"
|
|
||||||
|
|
||||||
# Call the function directly
|
|
||||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
|
||||||
mock_node, [mock_link]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify the tool name uses the customized name (cleaned up)
|
|
||||||
assert result["type"] == "function"
|
|
||||||
assert result["function"]["name"] == "my_custom_tool_name" # Cleaned version
|
|
||||||
assert result["function"]["_sink_node_id"] == "test-node-id"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_smart_decision_maker_falls_back_to_block_name():
|
|
||||||
"""Test that SmartDecisionMakerBlock falls back to block.name when no customized_name."""
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
from backend.blocks.basic import StoreValueBlock
|
|
||||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
|
||||||
from backend.data.graph import Link, Node
|
|
||||||
|
|
||||||
# Create a mock node without customized_name
|
|
||||||
mock_node = MagicMock(spec=Node)
|
|
||||||
mock_node.id = "test-node-id"
|
|
||||||
mock_node.block_id = StoreValueBlock().id
|
|
||||||
mock_node.metadata = {} # No customized_name
|
|
||||||
mock_node.block = StoreValueBlock()
|
|
||||||
|
|
||||||
# Create a mock link
|
|
||||||
mock_link = MagicMock(spec=Link)
|
|
||||||
mock_link.sink_name = "input"
|
|
||||||
|
|
||||||
# Call the function directly
|
|
||||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
|
||||||
mock_node, [mock_link]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify the tool name uses the block's default name
|
|
||||||
assert result["type"] == "function"
|
|
||||||
assert result["function"]["name"] == "storevalueblock" # Default block name cleaned
|
|
||||||
assert result["function"]["_sink_node_id"] == "test-node-id"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_smart_decision_maker_uses_customized_name_for_agents():
|
|
||||||
"""Test that SmartDecisionMakerBlock uses customized_name from metadata for agent nodes."""
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
|
||||||
from backend.data.graph import Link, Node
|
|
||||||
|
|
||||||
# Create a mock node with customized_name in metadata
|
|
||||||
mock_node = MagicMock(spec=Node)
|
|
||||||
mock_node.id = "test-agent-node-id"
|
|
||||||
mock_node.metadata = {"customized_name": "My Custom Agent"}
|
|
||||||
mock_node.input_default = {
|
|
||||||
"graph_id": "test-graph-id",
|
|
||||||
"graph_version": 1,
|
|
||||||
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create a mock link
|
|
||||||
mock_link = MagicMock(spec=Link)
|
|
||||||
mock_link.sink_name = "test_input"
|
|
||||||
|
|
||||||
# Mock the database client
|
|
||||||
mock_graph_meta = MagicMock()
|
|
||||||
mock_graph_meta.name = "Original Agent Name"
|
|
||||||
mock_graph_meta.description = "Agent description"
|
|
||||||
|
|
||||||
mock_db_client = AsyncMock()
|
|
||||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
|
||||||
return_value=mock_db_client,
|
|
||||||
):
|
|
||||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
|
||||||
mock_node, [mock_link]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify the tool name uses the customized name (cleaned up)
|
|
||||||
assert result["type"] == "function"
|
|
||||||
assert result["function"]["name"] == "my_custom_agent" # Cleaned version
|
|
||||||
assert result["function"]["_sink_node_id"] == "test-agent-node-id"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_smart_decision_maker_agent_falls_back_to_graph_name():
|
|
||||||
"""Test that agent node falls back to graph name when no customized_name."""
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
|
||||||
from backend.data.graph import Link, Node
|
|
||||||
|
|
||||||
# Create a mock node without customized_name
|
|
||||||
mock_node = MagicMock(spec=Node)
|
|
||||||
mock_node.id = "test-agent-node-id"
|
|
||||||
mock_node.metadata = {} # No customized_name
|
|
||||||
mock_node.input_default = {
|
|
||||||
"graph_id": "test-graph-id",
|
|
||||||
"graph_version": 1,
|
|
||||||
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create a mock link
|
|
||||||
mock_link = MagicMock(spec=Link)
|
|
||||||
mock_link.sink_name = "test_input"
|
|
||||||
|
|
||||||
# Mock the database client
|
|
||||||
mock_graph_meta = MagicMock()
|
|
||||||
mock_graph_meta.name = "Original Agent Name"
|
|
||||||
mock_graph_meta.description = "Agent description"
|
|
||||||
|
|
||||||
mock_db_client = AsyncMock()
|
|
||||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
|
||||||
return_value=mock_db_client,
|
|
||||||
):
|
|
||||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
|
||||||
mock_node, [mock_link]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify the tool name uses the graph's default name
|
|
||||||
assert result["type"] == "function"
|
|
||||||
assert result["function"]["name"] == "original_agent_name" # Graph name cleaned
|
|
||||||
assert result["function"]["_sink_node_id"] == "test-agent-node-id"
|
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
|||||||
mock_node.block = CreateDictionaryBlock()
|
mock_node.block = CreateDictionaryBlock()
|
||||||
mock_node.block_id = CreateDictionaryBlock().id
|
mock_node.block_id = CreateDictionaryBlock().id
|
||||||
mock_node.input_default = {}
|
mock_node.input_default = {}
|
||||||
mock_node.metadata = {}
|
|
||||||
|
|
||||||
# Create mock links with dynamic dictionary fields
|
# Create mock links with dynamic dictionary fields
|
||||||
mock_links = [
|
mock_links = [
|
||||||
@@ -78,7 +77,6 @@ async def test_smart_decision_maker_handles_dynamic_list_fields():
|
|||||||
mock_node.block = AddToListBlock()
|
mock_node.block = AddToListBlock()
|
||||||
mock_node.block_id = AddToListBlock().id
|
mock_node.block_id = AddToListBlock().id
|
||||||
mock_node.input_default = {}
|
mock_node.input_default = {}
|
||||||
mock_node.metadata = {}
|
|
||||||
|
|
||||||
# Create mock links with dynamic list fields
|
# Create mock links with dynamic list fields
|
||||||
mock_links = [
|
mock_links = [
|
||||||
|
|||||||
@@ -44,7 +44,6 @@ async def test_create_block_function_signature_with_dict_fields():
|
|||||||
mock_node.block = CreateDictionaryBlock()
|
mock_node.block = CreateDictionaryBlock()
|
||||||
mock_node.block_id = CreateDictionaryBlock().id
|
mock_node.block_id = CreateDictionaryBlock().id
|
||||||
mock_node.input_default = {}
|
mock_node.input_default = {}
|
||||||
mock_node.metadata = {}
|
|
||||||
|
|
||||||
# Create mock links with dynamic dictionary fields (source sanitized, sink original)
|
# Create mock links with dynamic dictionary fields (source sanitized, sink original)
|
||||||
mock_links = [
|
mock_links = [
|
||||||
@@ -107,7 +106,6 @@ async def test_create_block_function_signature_with_list_fields():
|
|||||||
mock_node.block = AddToListBlock()
|
mock_node.block = AddToListBlock()
|
||||||
mock_node.block_id = AddToListBlock().id
|
mock_node.block_id = AddToListBlock().id
|
||||||
mock_node.input_default = {}
|
mock_node.input_default = {}
|
||||||
mock_node.metadata = {}
|
|
||||||
|
|
||||||
# Create mock links with dynamic list fields
|
# Create mock links with dynamic list fields
|
||||||
mock_links = [
|
mock_links = [
|
||||||
@@ -161,7 +159,6 @@ async def test_create_block_function_signature_with_object_fields():
|
|||||||
mock_node.block = MatchTextPatternBlock()
|
mock_node.block = MatchTextPatternBlock()
|
||||||
mock_node.block_id = MatchTextPatternBlock().id
|
mock_node.block_id = MatchTextPatternBlock().id
|
||||||
mock_node.input_default = {}
|
mock_node.input_default = {}
|
||||||
mock_node.metadata = {}
|
|
||||||
|
|
||||||
# Create mock links with dynamic object fields
|
# Create mock links with dynamic object fields
|
||||||
mock_links = [
|
mock_links = [
|
||||||
@@ -211,13 +208,11 @@ async def test_create_tool_node_signatures():
|
|||||||
mock_dict_node.block = CreateDictionaryBlock()
|
mock_dict_node.block = CreateDictionaryBlock()
|
||||||
mock_dict_node.block_id = CreateDictionaryBlock().id
|
mock_dict_node.block_id = CreateDictionaryBlock().id
|
||||||
mock_dict_node.input_default = {}
|
mock_dict_node.input_default = {}
|
||||||
mock_dict_node.metadata = {}
|
|
||||||
|
|
||||||
mock_list_node = Mock()
|
mock_list_node = Mock()
|
||||||
mock_list_node.block = AddToListBlock()
|
mock_list_node.block = AddToListBlock()
|
||||||
mock_list_node.block_id = AddToListBlock().id
|
mock_list_node.block_id = AddToListBlock().id
|
||||||
mock_list_node.input_default = {}
|
mock_list_node.input_default = {}
|
||||||
mock_list_node.metadata = {}
|
|
||||||
|
|
||||||
# Mock links with dynamic fields
|
# Mock links with dynamic fields
|
||||||
dict_link1 = Mock(
|
dict_link1 = Mock(
|
||||||
@@ -428,7 +423,6 @@ async def test_mixed_regular_and_dynamic_fields():
|
|||||||
mock_node.block.name = "TestBlock"
|
mock_node.block.name = "TestBlock"
|
||||||
mock_node.block.description = "A test block"
|
mock_node.block.description = "A test block"
|
||||||
mock_node.block.input_schema = Mock()
|
mock_node.block.input_schema = Mock()
|
||||||
mock_node.metadata = {}
|
|
||||||
|
|
||||||
# Mock the get_field_schema to return a proper schema for regular fields
|
# Mock the get_field_schema to return a proper schema for regular fields
|
||||||
def get_field_schema(field_name):
|
def get_field_schema(field_name):
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
from .blog import WordPressCreatePostBlock, WordPressGetAllPostsBlock
|
from .blog import WordPressCreatePostBlock
|
||||||
|
|
||||||
__all__ = ["WordPressCreatePostBlock", "WordPressGetAllPostsBlock"]
|
__all__ = ["WordPressCreatePostBlock"]
|
||||||
|
|||||||
@@ -161,7 +161,7 @@ async def oauth_exchange_code_for_tokens(
|
|||||||
grant_type="authorization_code",
|
grant_type="authorization_code",
|
||||||
).model_dump(exclude_none=True)
|
).model_dump(exclude_none=True)
|
||||||
|
|
||||||
response = await Requests(raise_for_status=False).post(
|
response = await Requests().post(
|
||||||
f"{WORDPRESS_BASE_URL}oauth2/token",
|
f"{WORDPRESS_BASE_URL}oauth2/token",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
data=data,
|
data=data,
|
||||||
@@ -205,7 +205,7 @@ async def oauth_refresh_tokens(
|
|||||||
grant_type="refresh_token",
|
grant_type="refresh_token",
|
||||||
).model_dump(exclude_none=True)
|
).model_dump(exclude_none=True)
|
||||||
|
|
||||||
response = await Requests(raise_for_status=False).post(
|
response = await Requests().post(
|
||||||
f"{WORDPRESS_BASE_URL}oauth2/token",
|
f"{WORDPRESS_BASE_URL}oauth2/token",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
data=data,
|
data=data,
|
||||||
@@ -252,7 +252,7 @@ async def validate_token(
|
|||||||
"token": token,
|
"token": token,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = await Requests(raise_for_status=False).get(
|
response = await Requests().get(
|
||||||
f"{WORDPRESS_BASE_URL}oauth2/token-info",
|
f"{WORDPRESS_BASE_URL}oauth2/token-info",
|
||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
@@ -296,7 +296,7 @@ async def make_api_request(
|
|||||||
|
|
||||||
url = f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}"
|
url = f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}"
|
||||||
|
|
||||||
request_method = getattr(Requests(raise_for_status=False), method.lower())
|
request_method = getattr(Requests(), method.lower())
|
||||||
response = await request_method(
|
response = await request_method(
|
||||||
url,
|
url,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
@@ -476,7 +476,6 @@ async def create_post(
|
|||||||
data["tags"] = ",".join(str(t) for t in data["tags"])
|
data["tags"] = ",".join(str(t) for t in data["tags"])
|
||||||
|
|
||||||
# Make the API request
|
# Make the API request
|
||||||
site = normalize_site(site)
|
|
||||||
endpoint = f"/rest/v1.1/sites/{site}/posts/new"
|
endpoint = f"/rest/v1.1/sites/{site}/posts/new"
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
@@ -484,7 +483,7 @@ async def create_post(
|
|||||||
"Content-Type": "application/x-www-form-urlencoded",
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
}
|
}
|
||||||
|
|
||||||
response = await Requests(raise_for_status=False).post(
|
response = await Requests().post(
|
||||||
f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}",
|
f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
data=data,
|
data=data,
|
||||||
@@ -500,132 +499,3 @@ async def create_post(
|
|||||||
)
|
)
|
||||||
error_message = error_data.get("message", response.text)
|
error_message = error_data.get("message", response.text)
|
||||||
raise ValueError(f"Failed to create post: {response.status} - {error_message}")
|
raise ValueError(f"Failed to create post: {response.status} - {error_message}")
|
||||||
|
|
||||||
|
|
||||||
class Post(BaseModel):
|
|
||||||
"""Response model for individual posts in a posts list response.
|
|
||||||
|
|
||||||
This is a simplified version compared to PostResponse, as the list endpoint
|
|
||||||
returns less detailed information than the create/get single post endpoints.
|
|
||||||
"""
|
|
||||||
|
|
||||||
ID: int
|
|
||||||
site_ID: int
|
|
||||||
author: PostAuthor
|
|
||||||
date: datetime
|
|
||||||
modified: datetime
|
|
||||||
title: str
|
|
||||||
URL: str
|
|
||||||
short_URL: str
|
|
||||||
content: str | None = None
|
|
||||||
excerpt: str | None = None
|
|
||||||
slug: str
|
|
||||||
guid: str
|
|
||||||
status: str
|
|
||||||
sticky: bool
|
|
||||||
password: str | None = ""
|
|
||||||
parent: Union[Dict[str, Any], bool, None] = None
|
|
||||||
type: str
|
|
||||||
discussion: Dict[str, Union[str, bool, int]] | None = None
|
|
||||||
likes_enabled: bool | None = None
|
|
||||||
sharing_enabled: bool | None = None
|
|
||||||
like_count: int | None = None
|
|
||||||
i_like: bool | None = None
|
|
||||||
is_reblogged: bool | None = None
|
|
||||||
is_following: bool | None = None
|
|
||||||
global_ID: str | None = None
|
|
||||||
featured_image: str | None = None
|
|
||||||
post_thumbnail: Dict[str, Any] | None = None
|
|
||||||
format: str | None = None
|
|
||||||
geo: Union[Dict[str, Any], bool, None] = None
|
|
||||||
menu_order: int | None = None
|
|
||||||
page_template: str | None = None
|
|
||||||
publicize_URLs: List[str] | None = None
|
|
||||||
terms: Dict[str, Dict[str, Any]] | None = None
|
|
||||||
tags: Dict[str, Dict[str, Any]] | None = None
|
|
||||||
categories: Dict[str, Dict[str, Any]] | None = None
|
|
||||||
attachments: Dict[str, Dict[str, Any]] | None = None
|
|
||||||
attachment_count: int | None = None
|
|
||||||
metadata: List[Dict[str, Any]] | None = None
|
|
||||||
meta: Dict[str, Any] | None = None
|
|
||||||
capabilities: Dict[str, bool] | None = None
|
|
||||||
revisions: List[int] | None = None
|
|
||||||
other_URLs: Dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class PostsResponse(BaseModel):
|
|
||||||
"""Response model for WordPress posts list."""
|
|
||||||
|
|
||||||
found: int
|
|
||||||
posts: List[Post]
|
|
||||||
meta: Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_site(site: str) -> str:
|
|
||||||
"""
|
|
||||||
Normalize a site identifier by stripping protocol and trailing slashes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
site: Site URL, domain, or ID (e.g., "https://myblog.wordpress.com/", "myblog.wordpress.com", "123456789")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Normalized site identifier (domain or ID only)
|
|
||||||
"""
|
|
||||||
site = site.strip()
|
|
||||||
if site.startswith("https://"):
|
|
||||||
site = site[8:]
|
|
||||||
elif site.startswith("http://"):
|
|
||||||
site = site[7:]
|
|
||||||
return site.rstrip("/")
|
|
||||||
|
|
||||||
|
|
||||||
async def get_posts(
|
|
||||||
credentials: Credentials,
|
|
||||||
site: str,
|
|
||||||
status: PostStatus | None = None,
|
|
||||||
number: int = 100,
|
|
||||||
offset: int = 0,
|
|
||||||
) -> PostsResponse:
|
|
||||||
"""
|
|
||||||
Get posts from a WordPress site.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
credentials: OAuth credentials
|
|
||||||
site: Site ID or domain (e.g., "myblog.wordpress.com" or "123456789")
|
|
||||||
status: Filter by post status using PostStatus enum, or None for all
|
|
||||||
number: Number of posts to retrieve (max 100)
|
|
||||||
offset: Number of posts to skip (for pagination)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PostsResponse with the list of posts
|
|
||||||
"""
|
|
||||||
site = normalize_site(site)
|
|
||||||
endpoint = f"/rest/v1.1/sites/{site}/posts"
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Authorization": credentials.auth_header(),
|
|
||||||
}
|
|
||||||
|
|
||||||
params: Dict[str, Any] = {
|
|
||||||
"number": max(1, min(number, 100)), # 1–100 posts per request
|
|
||||||
"offset": offset,
|
|
||||||
}
|
|
||||||
|
|
||||||
if status:
|
|
||||||
params["status"] = status.value
|
|
||||||
response = await Requests(raise_for_status=False).get(
|
|
||||||
f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}",
|
|
||||||
headers=headers,
|
|
||||||
params=params,
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.ok:
|
|
||||||
return PostsResponse.model_validate(response.json())
|
|
||||||
|
|
||||||
error_data = (
|
|
||||||
response.json()
|
|
||||||
if response.headers.get("content-type", "").startswith("application/json")
|
|
||||||
else {}
|
|
||||||
)
|
|
||||||
error_message = error_data.get("message", response.text)
|
|
||||||
raise ValueError(f"Failed to get posts: {response.status} - {error_message}")
|
|
||||||
|
|||||||
@@ -9,15 +9,7 @@ from backend.sdk import (
|
|||||||
SchemaField,
|
SchemaField,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ._api import (
|
from ._api import CreatePostRequest, PostResponse, PostStatus, create_post
|
||||||
CreatePostRequest,
|
|
||||||
Post,
|
|
||||||
PostResponse,
|
|
||||||
PostsResponse,
|
|
||||||
PostStatus,
|
|
||||||
create_post,
|
|
||||||
get_posts,
|
|
||||||
)
|
|
||||||
from ._config import wordpress
|
from ._config import wordpress
|
||||||
|
|
||||||
|
|
||||||
@@ -57,15 +49,8 @@ class WordPressCreatePostBlock(Block):
|
|||||||
media_urls: list[str] = SchemaField(
|
media_urls: list[str] = SchemaField(
|
||||||
description="URLs of images to sideload and attach to the post", default=[]
|
description="URLs of images to sideload and attach to the post", default=[]
|
||||||
)
|
)
|
||||||
publish_as_draft: bool = SchemaField(
|
|
||||||
description="If True, publishes the post as a draft. If False, publishes it publicly.",
|
|
||||||
default=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
class Output(BlockSchemaOutput):
|
||||||
site: str = SchemaField(
|
|
||||||
description="The site ID or domain (pass-through for chaining with other blocks)"
|
|
||||||
)
|
|
||||||
post_id: int = SchemaField(description="The ID of the created post")
|
post_id: int = SchemaField(description="The ID of the created post")
|
||||||
post_url: str = SchemaField(description="The full URL of the created post")
|
post_url: str = SchemaField(description="The full URL of the created post")
|
||||||
short_url: str = SchemaField(description="The shortened wp.me URL")
|
short_url: str = SchemaField(description="The shortened wp.me URL")
|
||||||
@@ -93,9 +78,7 @@ class WordPressCreatePostBlock(Block):
|
|||||||
tags=input_data.tags,
|
tags=input_data.tags,
|
||||||
featured_image=input_data.featured_image,
|
featured_image=input_data.featured_image,
|
||||||
media_urls=input_data.media_urls,
|
media_urls=input_data.media_urls,
|
||||||
status=(
|
status=PostStatus.PUBLISH,
|
||||||
PostStatus.DRAFT if input_data.publish_as_draft else PostStatus.PUBLISH
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
post_response: PostResponse = await create_post(
|
post_response: PostResponse = await create_post(
|
||||||
@@ -104,69 +87,7 @@ class WordPressCreatePostBlock(Block):
|
|||||||
post_data=post_request,
|
post_data=post_request,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield "site", input_data.site
|
|
||||||
yield "post_id", post_response.ID
|
yield "post_id", post_response.ID
|
||||||
yield "post_url", post_response.URL
|
yield "post_url", post_response.URL
|
||||||
yield "short_url", post_response.short_URL
|
yield "short_url", post_response.short_URL
|
||||||
yield "post_data", post_response.model_dump()
|
yield "post_data", post_response.model_dump()
|
||||||
|
|
||||||
|
|
||||||
class WordPressGetAllPostsBlock(Block):
|
|
||||||
"""
|
|
||||||
Fetches all posts from a WordPress.com site or Jetpack-enabled site.
|
|
||||||
Supports filtering by status and pagination.
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
credentials: CredentialsMetaInput = wordpress.credentials_field()
|
|
||||||
site: str = SchemaField(
|
|
||||||
description="Site ID or domain (e.g., 'myblog.wordpress.com' or '123456789')"
|
|
||||||
)
|
|
||||||
status: PostStatus | None = SchemaField(
|
|
||||||
description="Filter by post status, or None for all",
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
number: int = SchemaField(
|
|
||||||
description="Number of posts to retrieve (max 100 per request)", default=20
|
|
||||||
)
|
|
||||||
offset: int = SchemaField(
|
|
||||||
description="Number of posts to skip (for pagination)", default=0
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
site: str = SchemaField(
|
|
||||||
description="The site ID or domain (pass-through for chaining with other blocks)"
|
|
||||||
)
|
|
||||||
found: int = SchemaField(description="Total number of posts found")
|
|
||||||
posts: list[Post] = SchemaField(
|
|
||||||
description="List of post objects with their details"
|
|
||||||
)
|
|
||||||
post: Post = SchemaField(
|
|
||||||
description="Individual post object (yielded for each post)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="97728fa7-7f6f-4789-ba0c-f2c114119536",
|
|
||||||
description="Fetch all posts from WordPress.com or Jetpack sites",
|
|
||||||
categories={BlockCategory.SOCIAL},
|
|
||||||
input_schema=self.Input,
|
|
||||||
output_schema=self.Output,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self, input_data: Input, *, credentials: Credentials, **kwargs
|
|
||||||
) -> BlockOutput:
|
|
||||||
posts_response: PostsResponse = await get_posts(
|
|
||||||
credentials=credentials,
|
|
||||||
site=input_data.site,
|
|
||||||
status=input_data.status,
|
|
||||||
number=input_data.number,
|
|
||||||
offset=input_data.offset,
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "site", input_data.site
|
|
||||||
yield "found", posts_response.found
|
|
||||||
yield "posts", posts_response.posts
|
|
||||||
for post in posts_response.posts:
|
|
||||||
yield "post", post
|
|
||||||
|
|||||||
@@ -50,8 +50,6 @@ from .model import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
|
|
||||||
from .graph import Link
|
from .graph import Link
|
||||||
|
|
||||||
app_config = Config()
|
app_config = Config()
|
||||||
@@ -474,7 +472,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
|||||||
self.block_type = block_type
|
self.block_type = block_type
|
||||||
self.webhook_config = webhook_config
|
self.webhook_config = webhook_config
|
||||||
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
|
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
|
||||||
self.requires_human_review: bool = False
|
|
||||||
|
|
||||||
if self.webhook_config:
|
if self.webhook_config:
|
||||||
if isinstance(self.webhook_config, BlockWebhookConfig):
|
if isinstance(self.webhook_config, BlockWebhookConfig):
|
||||||
@@ -617,77 +614,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
|||||||
block_id=self.id,
|
block_id=self.id,
|
||||||
) from ex
|
) from ex
|
||||||
|
|
||||||
async def is_block_exec_need_review(
|
|
||||||
self,
|
|
||||||
input_data: BlockInput,
|
|
||||||
*,
|
|
||||||
user_id: str,
|
|
||||||
node_exec_id: str,
|
|
||||||
graph_exec_id: str,
|
|
||||||
graph_id: str,
|
|
||||||
graph_version: int,
|
|
||||||
execution_context: "ExecutionContext",
|
|
||||||
**kwargs,
|
|
||||||
) -> tuple[bool, BlockInput]:
|
|
||||||
"""
|
|
||||||
Check if this block execution needs human review and handle the review process.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (should_pause, input_data_to_use)
|
|
||||||
- should_pause: True if execution should be paused for review
|
|
||||||
- input_data_to_use: The input data to use (may be modified by reviewer)
|
|
||||||
"""
|
|
||||||
# Skip review if not required or safe mode is disabled
|
|
||||||
if not self.requires_human_review or not execution_context.safe_mode:
|
|
||||||
return False, input_data
|
|
||||||
|
|
||||||
from backend.blocks.helpers.review import HITLReviewHelper
|
|
||||||
|
|
||||||
# Handle the review request and get decision
|
|
||||||
decision = await HITLReviewHelper.handle_review_decision(
|
|
||||||
input_data=input_data,
|
|
||||||
user_id=user_id,
|
|
||||||
node_exec_id=node_exec_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
graph_id=graph_id,
|
|
||||||
graph_version=graph_version,
|
|
||||||
execution_context=execution_context,
|
|
||||||
block_name=self.name,
|
|
||||||
editable=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if decision is None:
|
|
||||||
# We're awaiting review - pause execution
|
|
||||||
return True, input_data
|
|
||||||
|
|
||||||
if not decision.should_proceed:
|
|
||||||
# Review was rejected, raise an error to stop execution
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Block execution rejected by reviewer: {decision.message}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=self.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Review was approved - use the potentially modified data
|
|
||||||
# ReviewResult.data must be a dict for block inputs
|
|
||||||
reviewed_data = decision.review_result.data
|
|
||||||
if not isinstance(reviewed_data, dict):
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=self.id,
|
|
||||||
)
|
|
||||||
return False, reviewed_data
|
|
||||||
|
|
||||||
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||||
# Check for review requirement and get potentially modified input data
|
|
||||||
should_pause, input_data = await self.is_block_exec_need_review(
|
|
||||||
input_data, **kwargs
|
|
||||||
)
|
|
||||||
if should_pause:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Validate the input data (original or reviewer-modified) once
|
|
||||||
if error := self.input_schema.validate_data(input_data):
|
if error := self.input_schema.validate_data(input_data):
|
||||||
raise BlockInputError(
|
raise BlockInputError(
|
||||||
message=f"Unable to execute block with invalid input data: {error}",
|
message=f"Unable to execute block with invalid input data: {error}",
|
||||||
@@ -695,7 +622,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
|||||||
block_id=self.id,
|
block_id=self.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the validated input data
|
|
||||||
async for output_name, output_data in self.run(
|
async for output_name, output_data in self.run(
|
||||||
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
|
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|||||||
@@ -38,20 +38,6 @@ POOL_TIMEOUT = os.getenv("DB_POOL_TIMEOUT")
|
|||||||
if POOL_TIMEOUT:
|
if POOL_TIMEOUT:
|
||||||
DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)
|
DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)
|
||||||
|
|
||||||
# Add public schema to search_path for pgvector type access
|
|
||||||
# The vector extension is in public schema, but search_path is determined by schema parameter
|
|
||||||
# Extract the schema from DATABASE_URL or default to 'public' (matching get_database_schema())
|
|
||||||
parsed_url = urlparse(DATABASE_URL)
|
|
||||||
url_params = dict(parse_qsl(parsed_url.query))
|
|
||||||
db_schema = url_params.get("schema", "public")
|
|
||||||
# Build search_path, avoiding duplicates if db_schema is already 'public'
|
|
||||||
search_path_schemas = list(
|
|
||||||
dict.fromkeys([db_schema, "public"])
|
|
||||||
) # Preserves order, removes duplicates
|
|
||||||
search_path = ",".join(search_path_schemas)
|
|
||||||
# This allows using ::vector without schema qualification
|
|
||||||
DATABASE_URL = add_param(DATABASE_URL, "options", f"-c search_path={search_path}")
|
|
||||||
|
|
||||||
HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None
|
HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None
|
||||||
|
|
||||||
prisma = Prisma(
|
prisma = Prisma(
|
||||||
@@ -122,102 +108,21 @@ def get_database_schema() -> str:
|
|||||||
return query_params.get("schema", "public")
|
return query_params.get("schema", "public")
|
||||||
|
|
||||||
|
|
||||||
async def _raw_with_schema(
|
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
|
||||||
query_template: str,
|
"""Execute raw SQL query with proper schema handling."""
|
||||||
*args,
|
|
||||||
execute: bool = False,
|
|
||||||
client: Prisma | None = None,
|
|
||||||
set_public_search_path: bool = False,
|
|
||||||
) -> list[dict] | int:
|
|
||||||
"""Internal: Execute raw SQL with proper schema handling.
|
|
||||||
|
|
||||||
Use query_raw_with_schema() or execute_raw_with_schema() instead.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_template: SQL query with {schema_prefix} placeholder
|
|
||||||
*args: Query parameters
|
|
||||||
execute: If False, executes SELECT query. If True, executes INSERT/UPDATE/DELETE.
|
|
||||||
client: Optional Prisma client for transactions (only used when execute=True).
|
|
||||||
set_public_search_path: If True, sets search_path to include public schema.
|
|
||||||
Needed for pgvector types and other public schema objects.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- list[dict] if execute=False (query results)
|
|
||||||
- int if execute=True (number of affected rows)
|
|
||||||
"""
|
|
||||||
schema = get_database_schema()
|
schema = get_database_schema()
|
||||||
schema_prefix = f'"{schema}".' if schema != "public" else ""
|
schema_prefix = f'"{schema}".' if schema != "public" else ""
|
||||||
formatted_query = query_template.format(schema_prefix=schema_prefix)
|
formatted_query = query_template.format(schema_prefix=schema_prefix)
|
||||||
|
|
||||||
import prisma as prisma_module
|
import prisma as prisma_module
|
||||||
|
|
||||||
db_client = client if client else prisma_module.get_client()
|
result = await prisma_module.get_client().query_raw(
|
||||||
|
formatted_query, *args # type: ignore
|
||||||
# Set search_path to include public schema if requested
|
)
|
||||||
# Prisma doesn't support the 'options' connection parameter, so we set it per-session
|
|
||||||
# This is idempotent and safe to call multiple times
|
|
||||||
if set_public_search_path:
|
|
||||||
await db_client.execute_raw(f"SET search_path = {schema}, public") # type: ignore
|
|
||||||
|
|
||||||
if execute:
|
|
||||||
result = await db_client.execute_raw(formatted_query, *args) # type: ignore
|
|
||||||
else:
|
|
||||||
result = await db_client.query_raw(formatted_query, *args) # type: ignore
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def query_raw_with_schema(
|
|
||||||
query_template: str, *args, set_public_search_path: bool = False
|
|
||||||
) -> list[dict]:
|
|
||||||
"""Execute raw SQL SELECT query with proper schema handling.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_template: SQL query with {schema_prefix} placeholder
|
|
||||||
*args: Query parameters
|
|
||||||
set_public_search_path: If True, sets search_path to include public schema.
|
|
||||||
Needed for pgvector types and other public schema objects.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of result rows as dictionaries
|
|
||||||
|
|
||||||
Example:
|
|
||||||
results = await query_raw_with_schema(
|
|
||||||
'SELECT * FROM {schema_prefix}"User" WHERE id = $1',
|
|
||||||
user_id
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
return await _raw_with_schema(query_template, *args, execute=False, set_public_search_path=set_public_search_path) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
async def execute_raw_with_schema(
|
|
||||||
query_template: str,
|
|
||||||
*args,
|
|
||||||
client: Prisma | None = None,
|
|
||||||
set_public_search_path: bool = False,
|
|
||||||
) -> int:
|
|
||||||
"""Execute raw SQL command (INSERT/UPDATE/DELETE) with proper schema handling.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_template: SQL query with {schema_prefix} placeholder
|
|
||||||
*args: Query parameters
|
|
||||||
client: Optional Prisma client for transactions
|
|
||||||
set_public_search_path: If True, sets search_path to include public schema.
|
|
||||||
Needed for pgvector types and other public schema objects.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of affected rows
|
|
||||||
|
|
||||||
Example:
|
|
||||||
await execute_raw_with_schema(
|
|
||||||
'INSERT INTO {schema_prefix}"User" (id, name) VALUES ($1, $2)',
|
|
||||||
user_id, name,
|
|
||||||
client=tx # Optional transaction client
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
return await _raw_with_schema(query_template, *args, execute=True, client=client, set_public_search_path=set_public_search_path) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDbModel(BaseModel):
|
class BaseDbModel(BaseModel):
|
||||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||||
|
|
||||||
|
|||||||
@@ -383,7 +383,6 @@ class GraphExecutionWithNodes(GraphExecution):
|
|||||||
self,
|
self,
|
||||||
execution_context: ExecutionContext,
|
execution_context: ExecutionContext,
|
||||||
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
|
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||||
nodes_to_skip: Optional[set[str]] = None,
|
|
||||||
):
|
):
|
||||||
return GraphExecutionEntry(
|
return GraphExecutionEntry(
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
@@ -391,7 +390,6 @@ class GraphExecutionWithNodes(GraphExecution):
|
|||||||
graph_version=self.graph_version or 0,
|
graph_version=self.graph_version or 0,
|
||||||
graph_exec_id=self.id,
|
graph_exec_id=self.id,
|
||||||
nodes_input_masks=compiled_nodes_input_masks,
|
nodes_input_masks=compiled_nodes_input_masks,
|
||||||
nodes_to_skip=nodes_to_skip or set(),
|
|
||||||
execution_context=execution_context,
|
execution_context=execution_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1147,8 +1145,6 @@ class GraphExecutionEntry(BaseModel):
|
|||||||
graph_id: str
|
graph_id: str
|
||||||
graph_version: int
|
graph_version: int
|
||||||
nodes_input_masks: Optional[NodesInputMasks] = None
|
nodes_input_masks: Optional[NodesInputMasks] = None
|
||||||
nodes_to_skip: set[str] = Field(default_factory=set)
|
|
||||||
"""Node IDs that should be skipped due to optional credentials not being configured."""
|
|
||||||
execution_context: ExecutionContext = Field(default_factory=ExecutionContext)
|
execution_context: ExecutionContext = Field(default_factory=ExecutionContext)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -94,15 +94,6 @@ class Node(BaseDbModel):
|
|||||||
input_links: list[Link] = []
|
input_links: list[Link] = []
|
||||||
output_links: list[Link] = []
|
output_links: list[Link] = []
|
||||||
|
|
||||||
@property
|
|
||||||
def credentials_optional(self) -> bool:
|
|
||||||
"""
|
|
||||||
Whether credentials are optional for this node.
|
|
||||||
When True and credentials are not configured, the node will be skipped
|
|
||||||
during execution rather than causing a validation error.
|
|
||||||
"""
|
|
||||||
return self.metadata.get("credentials_optional", False)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def block(self) -> AnyBlockSchema | "_UnknownBlockBase":
|
def block(self) -> AnyBlockSchema | "_UnknownBlockBase":
|
||||||
"""Get the block for this node. Returns UnknownBlock if block is deleted/missing."""
|
"""Get the block for this node. Returns UnknownBlock if block is deleted/missing."""
|
||||||
@@ -244,10 +235,7 @@ class BaseGraph(BaseDbModel):
|
|||||||
return any(
|
return any(
|
||||||
node.block_id
|
node.block_id
|
||||||
for node in self.nodes
|
for node in self.nodes
|
||||||
if (
|
if node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
|
||||||
node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
|
|
||||||
or node.block.requires_human_review
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -338,35 +326,7 @@ class Graph(BaseGraph):
|
|||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
def credentials_input_schema(self) -> dict[str, Any]:
|
def credentials_input_schema(self) -> dict[str, Any]:
|
||||||
schema = self._credentials_input_schema.jsonschema()
|
return self._credentials_input_schema.jsonschema()
|
||||||
|
|
||||||
# Determine which credential fields are required based on credentials_optional metadata
|
|
||||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
|
||||||
required_fields = []
|
|
||||||
|
|
||||||
# Build a map of node_id -> node for quick lookup
|
|
||||||
all_nodes = {node.id: node for node in self.nodes}
|
|
||||||
for sub_graph in self.sub_graphs:
|
|
||||||
for node in sub_graph.nodes:
|
|
||||||
all_nodes[node.id] = node
|
|
||||||
|
|
||||||
for field_key, (
|
|
||||||
_field_info,
|
|
||||||
node_field_pairs,
|
|
||||||
) in graph_credentials_inputs.items():
|
|
||||||
# A field is required if ANY node using it has credentials_optional=False
|
|
||||||
is_required = False
|
|
||||||
for node_id, _field_name in node_field_pairs:
|
|
||||||
node = all_nodes.get(node_id)
|
|
||||||
if node and not node.credentials_optional:
|
|
||||||
is_required = True
|
|
||||||
break
|
|
||||||
|
|
||||||
if is_required:
|
|
||||||
required_fields.append(field_key)
|
|
||||||
|
|
||||||
schema["required"] = required_fields
|
|
||||||
return schema
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _credentials_input_schema(self) -> type[BlockSchema]:
|
def _credentials_input_schema(self) -> type[BlockSchema]:
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import fastapi.exceptions
|
import fastapi.exceptions
|
||||||
@@ -19,17 +18,6 @@ from backend.usecases.sample import create_test_user
|
|||||||
from backend.util.test import SpinTestServer
|
from backend.util.test import SpinTestServer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def mock_embedding_functions():
|
|
||||||
"""Mock embedding functions for all tests to avoid database/API dependencies."""
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.store.db.ensure_embedding",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
return_value=True,
|
|
||||||
):
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_graph_creation(server: SpinTestServer, snapshot: Snapshot):
|
async def test_graph_creation(server: SpinTestServer, snapshot: Snapshot):
|
||||||
"""
|
"""
|
||||||
@@ -408,58 +396,3 @@ async def test_access_store_listing_graph(server: SpinTestServer):
|
|||||||
created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||||
)
|
)
|
||||||
assert got_graph is not None
|
assert got_graph is not None
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Tests for Optional Credentials Feature
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def test_node_credentials_optional_default():
|
|
||||||
"""Test that credentials_optional defaults to False when not set in metadata."""
|
|
||||||
node = Node(
|
|
||||||
id="test_node",
|
|
||||||
block_id=StoreValueBlock().id,
|
|
||||||
input_default={},
|
|
||||||
metadata={},
|
|
||||||
)
|
|
||||||
assert node.credentials_optional is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_node_credentials_optional_true():
|
|
||||||
"""Test that credentials_optional returns True when explicitly set."""
|
|
||||||
node = Node(
|
|
||||||
id="test_node",
|
|
||||||
block_id=StoreValueBlock().id,
|
|
||||||
input_default={},
|
|
||||||
metadata={"credentials_optional": True},
|
|
||||||
)
|
|
||||||
assert node.credentials_optional is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_node_credentials_optional_false():
|
|
||||||
"""Test that credentials_optional returns False when explicitly set to False."""
|
|
||||||
node = Node(
|
|
||||||
id="test_node",
|
|
||||||
block_id=StoreValueBlock().id,
|
|
||||||
input_default={},
|
|
||||||
metadata={"credentials_optional": False},
|
|
||||||
)
|
|
||||||
assert node.credentials_optional is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_node_credentials_optional_with_other_metadata():
|
|
||||||
"""Test that credentials_optional works correctly with other metadata present."""
|
|
||||||
node = Node(
|
|
||||||
id="test_node",
|
|
||||||
block_id=StoreValueBlock().id,
|
|
||||||
input_default={},
|
|
||||||
metadata={
|
|
||||||
"position": {"x": 100, "y": 200},
|
|
||||||
"customized_name": "My Custom Node",
|
|
||||||
"credentials_optional": True,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert node.credentials_optional is True
|
|
||||||
assert node.metadata["position"] == {"x": 100, "y": 200}
|
|
||||||
assert node.metadata["customized_name"] == "My Custom Node"
|
|
||||||
|
|||||||
@@ -334,7 +334,7 @@ async def _get_user_timezone(user_id: str) -> str:
|
|||||||
return get_user_timezone_or_utc(user.timezone if user else None)
|
return get_user_timezone_or_utc(user.timezone if user else None)
|
||||||
|
|
||||||
|
|
||||||
async def increment_onboarding_runs(user_id: str):
|
async def increment_runs(user_id: str):
|
||||||
"""
|
"""
|
||||||
Increment a user's run counters and trigger any onboarding milestones.
|
Increment a user's run counters and trigger any onboarding milestones.
|
||||||
"""
|
"""
|
||||||
|
|||||||
429
autogpt_platform/backend/backend/data/understanding.py
Normal file
429
autogpt_platform/backend/backend/data/understanding.py
Normal file
@@ -0,0 +1,429 @@
|
|||||||
|
"""Data models and access layer for user business understanding."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
from prisma.models import UserBusinessUnderstanding
|
||||||
|
from prisma.types import (
|
||||||
|
UserBusinessUnderstandingCreateInput,
|
||||||
|
UserBusinessUnderstandingUpdateInput,
|
||||||
|
)
|
||||||
|
|
||||||
|
from backend.data.redis_client import get_redis_async
|
||||||
|
from backend.util.json import SafeJson
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Cache configuration
|
||||||
|
CACHE_KEY_PREFIX = "understanding"
|
||||||
|
CACHE_TTL_SECONDS = 48 * 60 * 60 # 48 hours
|
||||||
|
|
||||||
|
|
||||||
|
def _cache_key(user_id: str) -> str:
|
||||||
|
"""Generate cache key for user business understanding."""
|
||||||
|
return f"{CACHE_KEY_PREFIX}:{user_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def _json_to_list(value: Any) -> list[str]:
|
||||||
|
"""Convert Json field to list[str], handling None."""
|
||||||
|
if value is None:
|
||||||
|
return []
|
||||||
|
if isinstance(value, list):
|
||||||
|
return cast(list[str], value)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class BusinessUnderstandingInput(pydantic.BaseModel):
|
||||||
|
"""Input model for updating business understanding - all fields optional for incremental updates."""
|
||||||
|
|
||||||
|
# User info
|
||||||
|
user_name: Optional[str] = pydantic.Field(None, description="The user's name")
|
||||||
|
job_title: Optional[str] = pydantic.Field(None, description="The user's job title")
|
||||||
|
|
||||||
|
# Business basics
|
||||||
|
business_name: Optional[str] = pydantic.Field(
|
||||||
|
None, description="Name of the user's business"
|
||||||
|
)
|
||||||
|
industry: Optional[str] = pydantic.Field(None, description="Industry or sector")
|
||||||
|
business_size: Optional[str] = pydantic.Field(
|
||||||
|
None, description="Company size (e.g., '1-10', '11-50')"
|
||||||
|
)
|
||||||
|
user_role: Optional[str] = pydantic.Field(
|
||||||
|
None,
|
||||||
|
description="User's role in the organization (e.g., 'decision maker', 'implementer')",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Processes & activities
|
||||||
|
key_workflows: Optional[list[str]] = pydantic.Field(
|
||||||
|
None, description="Key business workflows"
|
||||||
|
)
|
||||||
|
daily_activities: Optional[list[str]] = pydantic.Field(
|
||||||
|
None, description="Daily activities performed"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pain points & goals
|
||||||
|
pain_points: Optional[list[str]] = pydantic.Field(
|
||||||
|
None, description="Current pain points"
|
||||||
|
)
|
||||||
|
bottlenecks: Optional[list[str]] = pydantic.Field(
|
||||||
|
None, description="Process bottlenecks"
|
||||||
|
)
|
||||||
|
manual_tasks: Optional[list[str]] = pydantic.Field(
|
||||||
|
None, description="Manual/repetitive tasks"
|
||||||
|
)
|
||||||
|
automation_goals: Optional[list[str]] = pydantic.Field(
|
||||||
|
None, description="Desired automation goals"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Current tools
|
||||||
|
current_software: Optional[list[str]] = pydantic.Field(
|
||||||
|
None, description="Software/tools currently used"
|
||||||
|
)
|
||||||
|
existing_automation: Optional[list[str]] = pydantic.Field(
|
||||||
|
None, description="Existing automations"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Additional context
|
||||||
|
additional_notes: Optional[str] = pydantic.Field(
|
||||||
|
None, description="Any additional context"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BusinessUnderstanding(pydantic.BaseModel):
|
||||||
|
"""Full business understanding model returned from database."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
user_id: str
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
# User info
|
||||||
|
user_name: Optional[str] = None
|
||||||
|
job_title: Optional[str] = None
|
||||||
|
|
||||||
|
# Business basics
|
||||||
|
business_name: Optional[str] = None
|
||||||
|
industry: Optional[str] = None
|
||||||
|
business_size: Optional[str] = None
|
||||||
|
user_role: Optional[str] = None
|
||||||
|
|
||||||
|
# Processes & activities
|
||||||
|
key_workflows: list[str] = pydantic.Field(default_factory=list)
|
||||||
|
daily_activities: list[str] = pydantic.Field(default_factory=list)
|
||||||
|
|
||||||
|
# Pain points & goals
|
||||||
|
pain_points: list[str] = pydantic.Field(default_factory=list)
|
||||||
|
bottlenecks: list[str] = pydantic.Field(default_factory=list)
|
||||||
|
manual_tasks: list[str] = pydantic.Field(default_factory=list)
|
||||||
|
automation_goals: list[str] = pydantic.Field(default_factory=list)
|
||||||
|
|
||||||
|
# Current tools
|
||||||
|
current_software: list[str] = pydantic.Field(default_factory=list)
|
||||||
|
existing_automation: list[str] = pydantic.Field(default_factory=list)
|
||||||
|
|
||||||
|
# Additional context
|
||||||
|
additional_notes: Optional[str] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_db(cls, db_record: UserBusinessUnderstanding) -> "BusinessUnderstanding":
|
||||||
|
"""Convert database record to Pydantic model."""
|
||||||
|
return cls(
|
||||||
|
id=db_record.id,
|
||||||
|
user_id=db_record.userId,
|
||||||
|
created_at=db_record.createdAt,
|
||||||
|
updated_at=db_record.updatedAt,
|
||||||
|
user_name=db_record.userName,
|
||||||
|
job_title=db_record.jobTitle,
|
||||||
|
business_name=db_record.businessName,
|
||||||
|
industry=db_record.industry,
|
||||||
|
business_size=db_record.businessSize,
|
||||||
|
user_role=db_record.userRole,
|
||||||
|
key_workflows=_json_to_list(db_record.keyWorkflows),
|
||||||
|
daily_activities=_json_to_list(db_record.dailyActivities),
|
||||||
|
pain_points=_json_to_list(db_record.painPoints),
|
||||||
|
bottlenecks=_json_to_list(db_record.bottlenecks),
|
||||||
|
manual_tasks=_json_to_list(db_record.manualTasks),
|
||||||
|
automation_goals=_json_to_list(db_record.automationGoals),
|
||||||
|
current_software=_json_to_list(db_record.currentSoftware),
|
||||||
|
existing_automation=_json_to_list(db_record.existingAutomation),
|
||||||
|
additional_notes=db_record.additionalNotes,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_lists(existing: list | None, new: list | None) -> list | None:
|
||||||
|
"""Merge two lists, removing duplicates while preserving order."""
|
||||||
|
if new is None:
|
||||||
|
return existing
|
||||||
|
if existing is None:
|
||||||
|
return new
|
||||||
|
# Preserve order, add new items that don't exist
|
||||||
|
merged = list(existing)
|
||||||
|
for item in new:
|
||||||
|
if item not in merged:
|
||||||
|
merged.append(item)
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
|
||||||
|
"""Get business understanding from Redis cache."""
|
||||||
|
try:
|
||||||
|
redis = await get_redis_async()
|
||||||
|
cached_data = await redis.get(_cache_key(user_id))
|
||||||
|
if cached_data:
|
||||||
|
return BusinessUnderstanding.model_validate_json(cached_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get understanding from cache: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _set_cache(user_id: str, understanding: BusinessUnderstanding) -> None:
|
||||||
|
"""Set business understanding in Redis cache with TTL."""
|
||||||
|
try:
|
||||||
|
redis = await get_redis_async()
|
||||||
|
await redis.setex(
|
||||||
|
_cache_key(user_id),
|
||||||
|
CACHE_TTL_SECONDS,
|
||||||
|
understanding.model_dump_json(),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to set understanding in cache: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _delete_cache(user_id: str) -> None:
|
||||||
|
"""Delete business understanding from Redis cache."""
|
||||||
|
try:
|
||||||
|
redis = await get_redis_async()
|
||||||
|
await redis.delete(_cache_key(user_id))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to delete understanding from cache: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_business_understanding(
|
||||||
|
user_id: str,
|
||||||
|
) -> Optional[BusinessUnderstanding]:
|
||||||
|
"""Get the business understanding for a user.
|
||||||
|
|
||||||
|
Checks cache first, falls back to database if not cached.
|
||||||
|
Results are cached for 48 hours.
|
||||||
|
"""
|
||||||
|
# Try cache first
|
||||||
|
cached = await _get_from_cache(user_id)
|
||||||
|
if cached:
|
||||||
|
logger.debug(f"Business understanding cache hit for user {user_id}")
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# Cache miss - load from database
|
||||||
|
logger.debug(f"Business understanding cache miss for user {user_id}")
|
||||||
|
record = await UserBusinessUnderstanding.prisma().find_unique(
|
||||||
|
where={"userId": user_id}
|
||||||
|
)
|
||||||
|
if record is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
understanding = BusinessUnderstanding.from_db(record)
|
||||||
|
|
||||||
|
# Store in cache for next time
|
||||||
|
await _set_cache(user_id, understanding)
|
||||||
|
|
||||||
|
return understanding
|
||||||
|
|
||||||
|
|
||||||
|
async def upsert_business_understanding(
|
||||||
|
user_id: str,
|
||||||
|
data: BusinessUnderstandingInput,
|
||||||
|
) -> BusinessUnderstanding:
|
||||||
|
"""
|
||||||
|
Create or update business understanding with incremental merge strategy.
|
||||||
|
|
||||||
|
- String fields: new value overwrites if provided (not None)
|
||||||
|
- List fields: new items are appended to existing (deduplicated)
|
||||||
|
"""
|
||||||
|
# Get existing record for merge
|
||||||
|
existing = await UserBusinessUnderstanding.prisma().find_unique(
|
||||||
|
where={"userId": user_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build update data with merge strategy
|
||||||
|
update_data: UserBusinessUnderstandingUpdateInput = {}
|
||||||
|
create_data: dict[str, Any] = {"userId": user_id}
|
||||||
|
|
||||||
|
# String fields - overwrite if provided
|
||||||
|
if data.user_name is not None:
|
||||||
|
update_data["userName"] = data.user_name
|
||||||
|
create_data["userName"] = data.user_name
|
||||||
|
if data.job_title is not None:
|
||||||
|
update_data["jobTitle"] = data.job_title
|
||||||
|
create_data["jobTitle"] = data.job_title
|
||||||
|
if data.business_name is not None:
|
||||||
|
update_data["businessName"] = data.business_name
|
||||||
|
create_data["businessName"] = data.business_name
|
||||||
|
if data.industry is not None:
|
||||||
|
update_data["industry"] = data.industry
|
||||||
|
create_data["industry"] = data.industry
|
||||||
|
if data.business_size is not None:
|
||||||
|
update_data["businessSize"] = data.business_size
|
||||||
|
create_data["businessSize"] = data.business_size
|
||||||
|
if data.user_role is not None:
|
||||||
|
update_data["userRole"] = data.user_role
|
||||||
|
create_data["userRole"] = data.user_role
|
||||||
|
if data.additional_notes is not None:
|
||||||
|
update_data["additionalNotes"] = data.additional_notes
|
||||||
|
create_data["additionalNotes"] = data.additional_notes
|
||||||
|
|
||||||
|
# List fields - merge with existing
|
||||||
|
if data.key_workflows is not None:
|
||||||
|
existing_list = _json_to_list(existing.keyWorkflows) if existing else None
|
||||||
|
merged = _merge_lists(existing_list, data.key_workflows)
|
||||||
|
update_data["keyWorkflows"] = SafeJson(merged)
|
||||||
|
create_data["keyWorkflows"] = SafeJson(merged)
|
||||||
|
|
||||||
|
if data.daily_activities is not None:
|
||||||
|
existing_list = _json_to_list(existing.dailyActivities) if existing else None
|
||||||
|
merged = _merge_lists(existing_list, data.daily_activities)
|
||||||
|
update_data["dailyActivities"] = SafeJson(merged)
|
||||||
|
create_data["dailyActivities"] = SafeJson(merged)
|
||||||
|
|
||||||
|
if data.pain_points is not None:
|
||||||
|
existing_list = _json_to_list(existing.painPoints) if existing else None
|
||||||
|
merged = _merge_lists(existing_list, data.pain_points)
|
||||||
|
update_data["painPoints"] = SafeJson(merged)
|
||||||
|
create_data["painPoints"] = SafeJson(merged)
|
||||||
|
|
||||||
|
if data.bottlenecks is not None:
|
||||||
|
existing_list = _json_to_list(existing.bottlenecks) if existing else None
|
||||||
|
merged = _merge_lists(existing_list, data.bottlenecks)
|
||||||
|
update_data["bottlenecks"] = SafeJson(merged)
|
||||||
|
create_data["bottlenecks"] = SafeJson(merged)
|
||||||
|
|
||||||
|
if data.manual_tasks is not None:
|
||||||
|
existing_list = _json_to_list(existing.manualTasks) if existing else None
|
||||||
|
merged = _merge_lists(existing_list, data.manual_tasks)
|
||||||
|
update_data["manualTasks"] = SafeJson(merged)
|
||||||
|
create_data["manualTasks"] = SafeJson(merged)
|
||||||
|
|
||||||
|
if data.automation_goals is not None:
|
||||||
|
existing_list = _json_to_list(existing.automationGoals) if existing else None
|
||||||
|
merged = _merge_lists(existing_list, data.automation_goals)
|
||||||
|
update_data["automationGoals"] = SafeJson(merged)
|
||||||
|
create_data["automationGoals"] = SafeJson(merged)
|
||||||
|
|
||||||
|
if data.current_software is not None:
|
||||||
|
existing_list = _json_to_list(existing.currentSoftware) if existing else None
|
||||||
|
merged = _merge_lists(existing_list, data.current_software)
|
||||||
|
update_data["currentSoftware"] = SafeJson(merged)
|
||||||
|
create_data["currentSoftware"] = SafeJson(merged)
|
||||||
|
|
||||||
|
if data.existing_automation is not None:
|
||||||
|
existing_list = _json_to_list(existing.existingAutomation) if existing else None
|
||||||
|
merged = _merge_lists(existing_list, data.existing_automation)
|
||||||
|
update_data["existingAutomation"] = SafeJson(merged)
|
||||||
|
create_data["existingAutomation"] = SafeJson(merged)
|
||||||
|
|
||||||
|
# Upsert
|
||||||
|
record = await UserBusinessUnderstanding.prisma().upsert(
|
||||||
|
where={"userId": user_id},
|
||||||
|
data={
|
||||||
|
"create": UserBusinessUnderstandingCreateInput(**create_data),
|
||||||
|
"update": update_data,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
understanding = BusinessUnderstanding.from_db(record)
|
||||||
|
|
||||||
|
# Update cache with new understanding
|
||||||
|
await _set_cache(user_id, understanding)
|
||||||
|
|
||||||
|
return understanding
|
||||||
|
|
||||||
|
|
||||||
|
async def clear_business_understanding(user_id: str) -> bool:
|
||||||
|
"""Clear/delete business understanding for a user from both DB and cache."""
|
||||||
|
# Delete from cache first
|
||||||
|
await _delete_cache(user_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await UserBusinessUnderstanding.prisma().delete(where={"userId": user_id})
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
# Record might not exist
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def format_understanding_for_prompt(understanding: BusinessUnderstanding) -> str:
|
||||||
|
"""Format business understanding as text for system prompt injection."""
|
||||||
|
sections = []
|
||||||
|
|
||||||
|
# User info section
|
||||||
|
user_info = []
|
||||||
|
if understanding.user_name:
|
||||||
|
user_info.append(f"Name: {understanding.user_name}")
|
||||||
|
if understanding.job_title:
|
||||||
|
user_info.append(f"Job Title: {understanding.job_title}")
|
||||||
|
if user_info:
|
||||||
|
sections.append("## User\n" + "\n".join(user_info))
|
||||||
|
|
||||||
|
# Business section
|
||||||
|
business_info = []
|
||||||
|
if understanding.business_name:
|
||||||
|
business_info.append(f"Company: {understanding.business_name}")
|
||||||
|
if understanding.industry:
|
||||||
|
business_info.append(f"Industry: {understanding.industry}")
|
||||||
|
if understanding.business_size:
|
||||||
|
business_info.append(f"Size: {understanding.business_size}")
|
||||||
|
if understanding.user_role:
|
||||||
|
business_info.append(f"Role Context: {understanding.user_role}")
|
||||||
|
if business_info:
|
||||||
|
sections.append("## Business\n" + "\n".join(business_info))
|
||||||
|
|
||||||
|
# Processes section
|
||||||
|
processes = []
|
||||||
|
if understanding.key_workflows:
|
||||||
|
processes.append(f"Key Workflows: {', '.join(understanding.key_workflows)}")
|
||||||
|
if understanding.daily_activities:
|
||||||
|
processes.append(
|
||||||
|
f"Daily Activities: {', '.join(understanding.daily_activities)}"
|
||||||
|
)
|
||||||
|
if processes:
|
||||||
|
sections.append("## Processes\n" + "\n".join(processes))
|
||||||
|
|
||||||
|
# Pain points section
|
||||||
|
pain_points = []
|
||||||
|
if understanding.pain_points:
|
||||||
|
pain_points.append(f"Pain Points: {', '.join(understanding.pain_points)}")
|
||||||
|
if understanding.bottlenecks:
|
||||||
|
pain_points.append(f"Bottlenecks: {', '.join(understanding.bottlenecks)}")
|
||||||
|
if understanding.manual_tasks:
|
||||||
|
pain_points.append(f"Manual Tasks: {', '.join(understanding.manual_tasks)}")
|
||||||
|
if pain_points:
|
||||||
|
sections.append("## Pain Points\n" + "\n".join(pain_points))
|
||||||
|
|
||||||
|
# Goals section
|
||||||
|
if understanding.automation_goals:
|
||||||
|
sections.append(
|
||||||
|
"## Automation Goals\n"
|
||||||
|
+ "\n".join(f"- {goal}" for goal in understanding.automation_goals)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Current tools section
|
||||||
|
tools_info = []
|
||||||
|
if understanding.current_software:
|
||||||
|
tools_info.append(
|
||||||
|
f"Current Software: {', '.join(understanding.current_software)}"
|
||||||
|
)
|
||||||
|
if understanding.existing_automation:
|
||||||
|
tools_info.append(
|
||||||
|
f"Existing Automation: {', '.join(understanding.existing_automation)}"
|
||||||
|
)
|
||||||
|
if tools_info:
|
||||||
|
sections.append("## Current Tools\n" + "\n".join(tools_info))
|
||||||
|
|
||||||
|
# Additional notes
|
||||||
|
if understanding.additional_notes:
|
||||||
|
sections.append(f"## Additional Context\n{understanding.additional_notes}")
|
||||||
|
|
||||||
|
if not sections:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return "# User Business Context\n\n" + "\n\n".join(sections)
|
||||||
@@ -7,10 +7,6 @@ from backend.api.features.library.db import (
|
|||||||
list_library_agents,
|
list_library_agents,
|
||||||
)
|
)
|
||||||
from backend.api.features.store.db import get_store_agent_details, get_store_agents
|
from backend.api.features.store.db import get_store_agent_details, get_store_agents
|
||||||
from backend.api.features.store.embeddings import (
|
|
||||||
backfill_missing_embeddings,
|
|
||||||
get_embedding_stats,
|
|
||||||
)
|
|
||||||
from backend.data import db
|
from backend.data import db
|
||||||
from backend.data.analytics import (
|
from backend.data.analytics import (
|
||||||
get_accuracy_trends_and_alerts,
|
get_accuracy_trends_and_alerts,
|
||||||
@@ -24,7 +20,6 @@ from backend.data.execution import (
|
|||||||
get_execution_kv_data,
|
get_execution_kv_data,
|
||||||
get_execution_outputs_by_node_exec_id,
|
get_execution_outputs_by_node_exec_id,
|
||||||
get_frequently_executed_graphs,
|
get_frequently_executed_graphs,
|
||||||
get_graph_execution,
|
|
||||||
get_graph_execution_meta,
|
get_graph_execution_meta,
|
||||||
get_graph_executions,
|
get_graph_executions,
|
||||||
get_graph_executions_count,
|
get_graph_executions_count,
|
||||||
@@ -62,7 +57,6 @@ from backend.data.notifications import (
|
|||||||
get_user_notification_oldest_message_in_batch,
|
get_user_notification_oldest_message_in_batch,
|
||||||
remove_notifications_from_batch,
|
remove_notifications_from_batch,
|
||||||
)
|
)
|
||||||
from backend.data.onboarding import increment_onboarding_runs
|
|
||||||
from backend.data.user import (
|
from backend.data.user import (
|
||||||
get_active_user_ids_in_timerange,
|
get_active_user_ids_in_timerange,
|
||||||
get_user_by_id,
|
get_user_by_id,
|
||||||
@@ -146,7 +140,6 @@ class DatabaseManager(AppService):
|
|||||||
get_child_graph_executions = _(get_child_graph_executions)
|
get_child_graph_executions = _(get_child_graph_executions)
|
||||||
get_graph_executions = _(get_graph_executions)
|
get_graph_executions = _(get_graph_executions)
|
||||||
get_graph_executions_count = _(get_graph_executions_count)
|
get_graph_executions_count = _(get_graph_executions_count)
|
||||||
get_graph_execution = _(get_graph_execution)
|
|
||||||
get_graph_execution_meta = _(get_graph_execution_meta)
|
get_graph_execution_meta = _(get_graph_execution_meta)
|
||||||
create_graph_execution = _(create_graph_execution)
|
create_graph_execution = _(create_graph_execution)
|
||||||
get_node_execution = _(get_node_execution)
|
get_node_execution = _(get_node_execution)
|
||||||
@@ -211,17 +204,10 @@ class DatabaseManager(AppService):
|
|||||||
add_store_agent_to_library = _(add_store_agent_to_library)
|
add_store_agent_to_library = _(add_store_agent_to_library)
|
||||||
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
|
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
|
||||||
|
|
||||||
# Onboarding
|
|
||||||
increment_onboarding_runs = _(increment_onboarding_runs)
|
|
||||||
|
|
||||||
# Store
|
# Store
|
||||||
get_store_agents = _(get_store_agents)
|
get_store_agents = _(get_store_agents)
|
||||||
get_store_agent_details = _(get_store_agent_details)
|
get_store_agent_details = _(get_store_agent_details)
|
||||||
|
|
||||||
# Store Embeddings
|
|
||||||
get_embedding_stats = _(get_embedding_stats)
|
|
||||||
backfill_missing_embeddings = _(backfill_missing_embeddings)
|
|
||||||
|
|
||||||
# Summary data - async
|
# Summary data - async
|
||||||
get_user_execution_summary_data = _(get_user_execution_summary_data)
|
get_user_execution_summary_data = _(get_user_execution_summary_data)
|
||||||
|
|
||||||
@@ -273,10 +259,6 @@ class DatabaseManagerClient(AppServiceClient):
|
|||||||
get_store_agents = _(d.get_store_agents)
|
get_store_agents = _(d.get_store_agents)
|
||||||
get_store_agent_details = _(d.get_store_agent_details)
|
get_store_agent_details = _(d.get_store_agent_details)
|
||||||
|
|
||||||
# Store Embeddings
|
|
||||||
get_embedding_stats = _(d.get_embedding_stats)
|
|
||||||
backfill_missing_embeddings = _(d.backfill_missing_embeddings)
|
|
||||||
|
|
||||||
|
|
||||||
class DatabaseManagerAsyncClient(AppServiceClient):
|
class DatabaseManagerAsyncClient(AppServiceClient):
|
||||||
d = DatabaseManager
|
d = DatabaseManager
|
||||||
@@ -292,7 +274,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
|||||||
get_graph = d.get_graph
|
get_graph = d.get_graph
|
||||||
get_graph_metadata = d.get_graph_metadata
|
get_graph_metadata = d.get_graph_metadata
|
||||||
get_graph_settings = d.get_graph_settings
|
get_graph_settings = d.get_graph_settings
|
||||||
get_graph_execution = d.get_graph_execution
|
|
||||||
get_graph_execution_meta = d.get_graph_execution_meta
|
get_graph_execution_meta = d.get_graph_execution_meta
|
||||||
get_node = d.get_node
|
get_node = d.get_node
|
||||||
get_node_execution = d.get_node_execution
|
get_node_execution = d.get_node_execution
|
||||||
@@ -337,9 +318,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
|||||||
add_store_agent_to_library = d.add_store_agent_to_library
|
add_store_agent_to_library = d.add_store_agent_to_library
|
||||||
validate_graph_execution_permissions = d.validate_graph_execution_permissions
|
validate_graph_execution_permissions = d.validate_graph_execution_permissions
|
||||||
|
|
||||||
# Onboarding
|
|
||||||
increment_onboarding_runs = d.increment_onboarding_runs
|
|
||||||
|
|
||||||
# Store
|
# Store
|
||||||
get_store_agents = d.get_store_agents
|
get_store_agents = d.get_store_agents
|
||||||
get_store_agent_details = d.get_store_agent_details
|
get_store_agent_details = d.get_store_agent_details
|
||||||
|
|||||||
@@ -178,7 +178,6 @@ async def execute_node(
|
|||||||
execution_processor: "ExecutionProcessor",
|
execution_processor: "ExecutionProcessor",
|
||||||
execution_stats: NodeExecutionStats | None = None,
|
execution_stats: NodeExecutionStats | None = None,
|
||||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||||
nodes_to_skip: Optional[set[str]] = None,
|
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
"""
|
"""
|
||||||
Execute a node in the graph. This will trigger a block execution on a node,
|
Execute a node in the graph. This will trigger a block execution on a node,
|
||||||
@@ -246,7 +245,6 @@ async def execute_node(
|
|||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"execution_context": execution_context,
|
"execution_context": execution_context,
|
||||||
"execution_processor": execution_processor,
|
"execution_processor": execution_processor,
|
||||||
"nodes_to_skip": nodes_to_skip or set(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
|
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
|
||||||
@@ -544,7 +542,6 @@ class ExecutionProcessor:
|
|||||||
node_exec_progress: NodeExecutionProgress,
|
node_exec_progress: NodeExecutionProgress,
|
||||||
nodes_input_masks: Optional[NodesInputMasks],
|
nodes_input_masks: Optional[NodesInputMasks],
|
||||||
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
|
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
|
||||||
nodes_to_skip: Optional[set[str]] = None,
|
|
||||||
) -> NodeExecutionStats:
|
) -> NodeExecutionStats:
|
||||||
log_metadata = LogMetadata(
|
log_metadata = LogMetadata(
|
||||||
logger=_logger,
|
logger=_logger,
|
||||||
@@ -567,7 +564,6 @@ class ExecutionProcessor:
|
|||||||
db_client=db_client,
|
db_client=db_client,
|
||||||
log_metadata=log_metadata,
|
log_metadata=log_metadata,
|
||||||
nodes_input_masks=nodes_input_masks,
|
nodes_input_masks=nodes_input_masks,
|
||||||
nodes_to_skip=nodes_to_skip,
|
|
||||||
)
|
)
|
||||||
if isinstance(status, BaseException):
|
if isinstance(status, BaseException):
|
||||||
raise status
|
raise status
|
||||||
@@ -613,7 +609,6 @@ class ExecutionProcessor:
|
|||||||
db_client: "DatabaseManagerAsyncClient",
|
db_client: "DatabaseManagerAsyncClient",
|
||||||
log_metadata: LogMetadata,
|
log_metadata: LogMetadata,
|
||||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||||
nodes_to_skip: Optional[set[str]] = None,
|
|
||||||
) -> ExecutionStatus:
|
) -> ExecutionStatus:
|
||||||
status = ExecutionStatus.RUNNING
|
status = ExecutionStatus.RUNNING
|
||||||
|
|
||||||
@@ -650,7 +645,6 @@ class ExecutionProcessor:
|
|||||||
execution_processor=self,
|
execution_processor=self,
|
||||||
execution_stats=stats,
|
execution_stats=stats,
|
||||||
nodes_input_masks=nodes_input_masks,
|
nodes_input_masks=nodes_input_masks,
|
||||||
nodes_to_skip=nodes_to_skip,
|
|
||||||
):
|
):
|
||||||
await persist_output(output_name, output_data)
|
await persist_output(output_name, output_data)
|
||||||
|
|
||||||
@@ -962,21 +956,6 @@ class ExecutionProcessor:
|
|||||||
|
|
||||||
queued_node_exec = execution_queue.get()
|
queued_node_exec = execution_queue.get()
|
||||||
|
|
||||||
# Check if this node should be skipped due to optional credentials
|
|
||||||
if queued_node_exec.node_id in graph_exec.nodes_to_skip:
|
|
||||||
log_metadata.info(
|
|
||||||
f"Skipping node execution {queued_node_exec.node_exec_id} "
|
|
||||||
f"for node {queued_node_exec.node_id} - optional credentials not configured"
|
|
||||||
)
|
|
||||||
# Mark the node as completed without executing
|
|
||||||
# No outputs will be produced, so downstream nodes won't trigger
|
|
||||||
update_node_execution_status(
|
|
||||||
db_client=db_client,
|
|
||||||
exec_id=queued_node_exec.node_exec_id,
|
|
||||||
status=ExecutionStatus.COMPLETED,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
log_metadata.debug(
|
log_metadata.debug(
|
||||||
f"Dispatching node execution {queued_node_exec.node_exec_id} "
|
f"Dispatching node execution {queued_node_exec.node_exec_id} "
|
||||||
f"for node {queued_node_exec.node_id}",
|
f"for node {queued_node_exec.node_id}",
|
||||||
@@ -1037,7 +1016,6 @@ class ExecutionProcessor:
|
|||||||
execution_stats,
|
execution_stats,
|
||||||
execution_stats_lock,
|
execution_stats_lock,
|
||||||
),
|
),
|
||||||
nodes_to_skip=graph_exec.nodes_to_skip,
|
|
||||||
),
|
),
|
||||||
self.node_execution_loop,
|
self.node_execution_loop,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
import fastapi.responses
|
import fastapi.responses
|
||||||
import pytest
|
import pytest
|
||||||
@@ -20,17 +19,6 @@ from backend.util.test import SpinTestServer, wait_execution
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def mock_embedding_functions():
|
|
||||||
"""Mock embedding functions for all tests to avoid database/API dependencies."""
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.store.db.ensure_embedding",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
return_value=True,
|
|
||||||
):
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph:
|
async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph:
|
||||||
logger.info(f"Creating graph for user {u.id}")
|
logger.info(f"Creating graph for user {u.id}")
|
||||||
return await s.agent_server.test_create_graph(CreateGraph(graph=g), u.id)
|
return await s.agent_server.test_create_graph(CreateGraph(graph=g), u.id)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
import uuid
|
import uuid
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -28,6 +27,7 @@ from backend.data.auth.oauth import cleanup_expired_oauth_tokens
|
|||||||
from backend.data.block import BlockInput
|
from backend.data.block import BlockInput
|
||||||
from backend.data.execution import GraphExecutionWithNodes
|
from backend.data.execution import GraphExecutionWithNodes
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
|
from backend.data.onboarding import increment_runs
|
||||||
from backend.executor import utils as execution_utils
|
from backend.executor import utils as execution_utils
|
||||||
from backend.monitoring import (
|
from backend.monitoring import (
|
||||||
NotificationJobArgs,
|
NotificationJobArgs,
|
||||||
@@ -37,7 +37,7 @@ from backend.monitoring import (
|
|||||||
report_execution_accuracy_alerts,
|
report_execution_accuracy_alerts,
|
||||||
report_late_executions,
|
report_late_executions,
|
||||||
)
|
)
|
||||||
from backend.util.clients import get_database_manager_client, get_scheduler_client
|
from backend.util.clients import get_scheduler_client
|
||||||
from backend.util.cloud_storage import cleanup_expired_files_async
|
from backend.util.cloud_storage import cleanup_expired_files_async
|
||||||
from backend.util.exceptions import (
|
from backend.util.exceptions import (
|
||||||
GraphNotFoundError,
|
GraphNotFoundError,
|
||||||
@@ -156,6 +156,7 @@ async def _execute_graph(**kwargs):
|
|||||||
inputs=args.input_data,
|
inputs=args.input_data,
|
||||||
graph_credentials_inputs=args.input_credentials,
|
graph_credentials_inputs=args.input_credentials,
|
||||||
)
|
)
|
||||||
|
await increment_runs(args.user_id)
|
||||||
elapsed = asyncio.get_event_loop().time() - start_time
|
elapsed = asyncio.get_event_loop().time() - start_time
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
|
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
|
||||||
@@ -253,74 +254,6 @@ def execution_accuracy_alerts():
|
|||||||
return report_execution_accuracy_alerts()
|
return report_execution_accuracy_alerts()
|
||||||
|
|
||||||
|
|
||||||
def ensure_embeddings_coverage():
|
|
||||||
"""
|
|
||||||
Ensure approved store agents have embeddings for hybrid search.
|
|
||||||
|
|
||||||
Processes ALL missing embeddings in batches of 10 until 100% coverage.
|
|
||||||
Missing embeddings = agents invisible in hybrid search.
|
|
||||||
|
|
||||||
Schedule: Runs every 6 hours (balanced between coverage and API costs).
|
|
||||||
- Catches agents approved between scheduled runs
|
|
||||||
- Batch size 10: gradual processing to avoid rate limits
|
|
||||||
- Manual trigger available via execute_ensure_embeddings_coverage endpoint
|
|
||||||
"""
|
|
||||||
db_client = get_database_manager_client()
|
|
||||||
stats = db_client.get_embedding_stats()
|
|
||||||
|
|
||||||
# Check for error from get_embedding_stats() first
|
|
||||||
if "error" in stats:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to get embedding stats: {stats['error']} - skipping backfill"
|
|
||||||
)
|
|
||||||
return {"processed": 0, "success": 0, "failed": 0, "error": stats["error"]}
|
|
||||||
|
|
||||||
if stats["without_embeddings"] == 0:
|
|
||||||
logger.info("All approved agents have embeddings, skipping backfill")
|
|
||||||
return {"processed": 0, "success": 0, "failed": 0}
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Found {stats['without_embeddings']} agents without embeddings "
|
|
||||||
f"({stats['coverage_percent']}% coverage) - processing all"
|
|
||||||
)
|
|
||||||
|
|
||||||
total_processed = 0
|
|
||||||
total_success = 0
|
|
||||||
total_failed = 0
|
|
||||||
|
|
||||||
# Process in batches until no more missing embeddings
|
|
||||||
while True:
|
|
||||||
result = db_client.backfill_missing_embeddings(batch_size=10)
|
|
||||||
|
|
||||||
total_processed += result["processed"]
|
|
||||||
total_success += result["success"]
|
|
||||||
total_failed += result["failed"]
|
|
||||||
|
|
||||||
if result["processed"] == 0:
|
|
||||||
# No more missing embeddings
|
|
||||||
break
|
|
||||||
|
|
||||||
if result["success"] == 0 and result["processed"] > 0:
|
|
||||||
# All attempts in this batch failed - stop to avoid infinite loop
|
|
||||||
logger.error(
|
|
||||||
f"All {result['processed']} embedding attempts failed - stopping backfill"
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
# Small delay between batches to avoid rate limits
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Embedding backfill completed: {total_success}/{total_processed} succeeded, "
|
|
||||||
f"{total_failed} failed"
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"processed": total_processed,
|
|
||||||
"success": total_success,
|
|
||||||
"failed": total_failed,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Monitoring functions are now imported from monitoring module
|
# Monitoring functions are now imported from monitoring module
|
||||||
|
|
||||||
|
|
||||||
@@ -542,19 +475,6 @@ class Scheduler(AppService):
|
|||||||
jobstore=Jobstores.EXECUTION.value,
|
jobstore=Jobstores.EXECUTION.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Embedding Coverage - Every 6 hours
|
|
||||||
# Ensures all approved agents have embeddings for hybrid search
|
|
||||||
# Critical: missing embeddings = agents invisible in search
|
|
||||||
self.scheduler.add_job(
|
|
||||||
ensure_embeddings_coverage,
|
|
||||||
id="ensure_embeddings_coverage",
|
|
||||||
trigger="interval",
|
|
||||||
hours=6,
|
|
||||||
replace_existing=True,
|
|
||||||
max_instances=1, # Prevent overlapping runs
|
|
||||||
jobstore=Jobstores.EXECUTION.value,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
||||||
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
|
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
|
||||||
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
|
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
|
||||||
@@ -712,11 +632,6 @@ class Scheduler(AppService):
|
|||||||
"""Manually trigger execution accuracy alert checking."""
|
"""Manually trigger execution accuracy alert checking."""
|
||||||
return execution_accuracy_alerts()
|
return execution_accuracy_alerts()
|
||||||
|
|
||||||
@expose
|
|
||||||
def execute_ensure_embeddings_coverage(self):
|
|
||||||
"""Manually trigger embedding backfill for approved store agents."""
|
|
||||||
return ensure_embeddings_coverage()
|
|
||||||
|
|
||||||
|
|
||||||
class SchedulerClient(AppServiceClient):
|
class SchedulerClient(AppServiceClient):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from pydantic import BaseModel, JsonValue, ValidationError
|
|||||||
|
|
||||||
from backend.data import execution as execution_db
|
from backend.data import execution as execution_db
|
||||||
from backend.data import graph as graph_db
|
from backend.data import graph as graph_db
|
||||||
from backend.data import onboarding as onboarding_db
|
|
||||||
from backend.data import user as user_db
|
from backend.data import user as user_db
|
||||||
from backend.data.block import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
@@ -32,6 +31,7 @@ from backend.data.execution import (
|
|||||||
GraphExecutionStats,
|
GraphExecutionStats,
|
||||||
GraphExecutionWithNodes,
|
GraphExecutionWithNodes,
|
||||||
NodesInputMasks,
|
NodesInputMasks,
|
||||||
|
get_graph_execution,
|
||||||
)
|
)
|
||||||
from backend.data.graph import GraphModel, Node
|
from backend.data.graph import GraphModel, Node
|
||||||
from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput
|
from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput
|
||||||
@@ -239,19 +239,14 @@ async def _validate_node_input_credentials(
|
|||||||
graph: GraphModel,
|
graph: GraphModel,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||||
) -> tuple[dict[str, dict[str, str]], set[str]]:
|
) -> dict[str, dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
Checks all credentials for all nodes of the graph and returns structured errors
|
Checks all credentials for all nodes of the graph and returns structured errors.
|
||||||
and a set of nodes that should be skipped due to optional missing credentials.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[
|
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node
|
||||||
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node,
|
|
||||||
set[node_id]: Nodes that should be skipped (optional credentials not configured)
|
|
||||||
]
|
|
||||||
"""
|
"""
|
||||||
credential_errors: dict[str, dict[str, str]] = defaultdict(dict)
|
credential_errors: dict[str, dict[str, str]] = defaultdict(dict)
|
||||||
nodes_to_skip: set[str] = set()
|
|
||||||
|
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
block = node.block
|
block = node.block
|
||||||
@@ -261,46 +256,27 @@ async def _validate_node_input_credentials(
|
|||||||
if not credentials_fields:
|
if not credentials_fields:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Track if any credential field is missing for this node
|
|
||||||
has_missing_credentials = False
|
|
||||||
|
|
||||||
for field_name, credentials_meta_type in credentials_fields.items():
|
for field_name, credentials_meta_type in credentials_fields.items():
|
||||||
try:
|
try:
|
||||||
# Check nodes_input_masks first, then input_default
|
|
||||||
field_value = None
|
|
||||||
if (
|
if (
|
||||||
nodes_input_masks
|
nodes_input_masks
|
||||||
and (node_input_mask := nodes_input_masks.get(node.id))
|
and (node_input_mask := nodes_input_masks.get(node.id))
|
||||||
and field_name in node_input_mask
|
and field_name in node_input_mask
|
||||||
):
|
):
|
||||||
field_value = node_input_mask[field_name]
|
credentials_meta = credentials_meta_type.model_validate(
|
||||||
|
node_input_mask[field_name]
|
||||||
|
)
|
||||||
elif field_name in node.input_default:
|
elif field_name in node.input_default:
|
||||||
# For optional credentials, don't use input_default - treat as missing
|
credentials_meta = credentials_meta_type.model_validate(
|
||||||
# This prevents stale credential IDs from failing validation
|
node.input_default[field_name]
|
||||||
if node.credentials_optional:
|
)
|
||||||
field_value = None
|
else:
|
||||||
else:
|
# Missing credentials
|
||||||
field_value = node.input_default[field_name]
|
credential_errors[node.id][
|
||||||
|
field_name
|
||||||
# Check if credentials are missing (None, empty, or not present)
|
] = "These credentials are required"
|
||||||
if field_value is None or (
|
continue
|
||||||
isinstance(field_value, dict) and not field_value.get("id")
|
|
||||||
):
|
|
||||||
has_missing_credentials = True
|
|
||||||
# If node has credentials_optional flag, mark for skipping instead of error
|
|
||||||
if node.credentials_optional:
|
|
||||||
continue # Don't add error, will be marked for skip after loop
|
|
||||||
else:
|
|
||||||
credential_errors[node.id][
|
|
||||||
field_name
|
|
||||||
] = "These credentials are required"
|
|
||||||
continue
|
|
||||||
|
|
||||||
credentials_meta = credentials_meta_type.model_validate(field_value)
|
|
||||||
|
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
# Validation error means credentials were provided but invalid
|
|
||||||
# This should always be an error, even if optional
|
|
||||||
credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
|
credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -311,7 +287,6 @@ async def _validate_node_input_credentials(
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Handle any errors fetching credentials
|
# Handle any errors fetching credentials
|
||||||
# If credentials were explicitly configured but unavailable, it's an error
|
|
||||||
credential_errors[node.id][
|
credential_errors[node.id][
|
||||||
field_name
|
field_name
|
||||||
] = f"Credentials not available: {e}"
|
] = f"Credentials not available: {e}"
|
||||||
@@ -338,19 +313,7 @@ async def _validate_node_input_credentials(
|
|||||||
] = "Invalid credentials: type/provider mismatch"
|
] = "Invalid credentials: type/provider mismatch"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# If node has optional credentials and any are missing, mark for skipping
|
return credential_errors
|
||||||
# But only if there are no other errors for this node
|
|
||||||
if (
|
|
||||||
has_missing_credentials
|
|
||||||
and node.credentials_optional
|
|
||||||
and node.id not in credential_errors
|
|
||||||
):
|
|
||||||
nodes_to_skip.add(node.id)
|
|
||||||
logger.info(
|
|
||||||
f"Node #{node.id} will be skipped: optional credentials not configured"
|
|
||||||
)
|
|
||||||
|
|
||||||
return credential_errors, nodes_to_skip
|
|
||||||
|
|
||||||
|
|
||||||
def make_node_credentials_input_map(
|
def make_node_credentials_input_map(
|
||||||
@@ -392,25 +355,21 @@ async def validate_graph_with_credentials(
|
|||||||
graph: GraphModel,
|
graph: GraphModel,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||||
) -> tuple[Mapping[str, Mapping[str, str]], set[str]]:
|
) -> Mapping[str, Mapping[str, str]]:
|
||||||
"""
|
"""
|
||||||
Validate graph including credentials and return structured errors per node,
|
Validate graph including credentials and return structured errors per node.
|
||||||
along with a set of nodes that should be skipped due to optional missing credentials.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[
|
dict[node_id, dict[field_name, error_message]]: Validation errors per node
|
||||||
dict[node_id, dict[field_name, error_message]]: Validation errors per node,
|
|
||||||
set[node_id]: Nodes that should be skipped (optional credentials not configured)
|
|
||||||
]
|
|
||||||
"""
|
"""
|
||||||
# Get input validation errors
|
# Get input validation errors
|
||||||
node_input_errors = GraphModel.validate_graph_get_errors(
|
node_input_errors = GraphModel.validate_graph_get_errors(
|
||||||
graph, for_run=True, nodes_input_masks=nodes_input_masks
|
graph, for_run=True, nodes_input_masks=nodes_input_masks
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get credential input/availability/validation errors and nodes to skip
|
# Get credential input/availability/validation errors
|
||||||
node_credential_input_errors, nodes_to_skip = (
|
node_credential_input_errors = await _validate_node_input_credentials(
|
||||||
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
|
graph, user_id, nodes_input_masks
|
||||||
)
|
)
|
||||||
|
|
||||||
# Merge credential errors with structural errors
|
# Merge credential errors with structural errors
|
||||||
@@ -419,7 +378,7 @@ async def validate_graph_with_credentials(
|
|||||||
node_input_errors[node_id] = {}
|
node_input_errors[node_id] = {}
|
||||||
node_input_errors[node_id].update(field_errors)
|
node_input_errors[node_id].update(field_errors)
|
||||||
|
|
||||||
return node_input_errors, nodes_to_skip
|
return node_input_errors
|
||||||
|
|
||||||
|
|
||||||
async def _construct_starting_node_execution_input(
|
async def _construct_starting_node_execution_input(
|
||||||
@@ -427,7 +386,7 @@ async def _construct_starting_node_execution_input(
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
graph_inputs: BlockInput,
|
graph_inputs: BlockInput,
|
||||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||||
) -> tuple[list[tuple[str, BlockInput]], set[str]]:
|
) -> list[tuple[str, BlockInput]]:
|
||||||
"""
|
"""
|
||||||
Validates and prepares the input data for executing a graph.
|
Validates and prepares the input data for executing a graph.
|
||||||
This function checks the graph for starting nodes, validates the input data
|
This function checks the graph for starting nodes, validates the input data
|
||||||
@@ -441,14 +400,11 @@ async def _construct_starting_node_execution_input(
|
|||||||
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
|
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[
|
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
|
||||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID
|
the corresponding input data for that node.
|
||||||
and the corresponding input data for that node.
|
|
||||||
set[str]: Node IDs that should be skipped (optional credentials not configured)
|
|
||||||
]
|
|
||||||
"""
|
"""
|
||||||
# Use new validation function that includes credentials
|
# Use new validation function that includes credentials
|
||||||
validation_errors, nodes_to_skip = await validate_graph_with_credentials(
|
validation_errors = await validate_graph_with_credentials(
|
||||||
graph, user_id, nodes_input_masks
|
graph, user_id, nodes_input_masks
|
||||||
)
|
)
|
||||||
n_error_nodes = len(validation_errors)
|
n_error_nodes = len(validation_errors)
|
||||||
@@ -489,7 +445,7 @@ async def _construct_starting_node_execution_input(
|
|||||||
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
|
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
|
||||||
)
|
)
|
||||||
|
|
||||||
return nodes_input, nodes_to_skip
|
return nodes_input
|
||||||
|
|
||||||
|
|
||||||
async def validate_and_construct_node_execution_input(
|
async def validate_and_construct_node_execution_input(
|
||||||
@@ -500,7 +456,7 @@ async def validate_and_construct_node_execution_input(
|
|||||||
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||||
is_sub_graph: bool = False,
|
is_sub_graph: bool = False,
|
||||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]:
|
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
|
||||||
"""
|
"""
|
||||||
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
|
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
|
||||||
This centralizes the logic used by both scheduler validation and actual execution.
|
This centralizes the logic used by both scheduler validation and actual execution.
|
||||||
@@ -517,7 +473,6 @@ async def validate_and_construct_node_execution_input(
|
|||||||
GraphModel: Full graph object for the given `graph_id`.
|
GraphModel: Full graph object for the given `graph_id`.
|
||||||
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
|
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
|
||||||
dict[str, BlockInput]: Node input masks including all passed-in credentials.
|
dict[str, BlockInput]: Node input masks including all passed-in credentials.
|
||||||
set[str]: Node IDs that should be skipped (optional credentials not configured).
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
NotFoundError: If the graph is not found.
|
NotFoundError: If the graph is not found.
|
||||||
@@ -559,16 +514,14 @@ async def validate_and_construct_node_execution_input(
|
|||||||
nodes_input_masks or {},
|
nodes_input_masks or {},
|
||||||
)
|
)
|
||||||
|
|
||||||
starting_nodes_input, nodes_to_skip = (
|
starting_nodes_input = await _construct_starting_node_execution_input(
|
||||||
await _construct_starting_node_execution_input(
|
graph=graph,
|
||||||
graph=graph,
|
user_id=user_id,
|
||||||
user_id=user_id,
|
graph_inputs=graph_inputs,
|
||||||
graph_inputs=graph_inputs,
|
nodes_input_masks=nodes_input_masks,
|
||||||
nodes_input_masks=nodes_input_masks,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return graph, starting_nodes_input, nodes_input_masks, nodes_to_skip
|
return graph, starting_nodes_input, nodes_input_masks
|
||||||
|
|
||||||
|
|
||||||
def _merge_nodes_input_masks(
|
def _merge_nodes_input_masks(
|
||||||
@@ -809,14 +762,13 @@ async def add_graph_execution(
|
|||||||
edb = execution_db
|
edb = execution_db
|
||||||
udb = user_db
|
udb = user_db
|
||||||
gdb = graph_db
|
gdb = graph_db
|
||||||
odb = onboarding_db
|
|
||||||
else:
|
else:
|
||||||
edb = udb = gdb = odb = get_database_manager_async_client()
|
edb = udb = gdb = get_database_manager_async_client()
|
||||||
|
|
||||||
# Get or create the graph execution
|
# Get or create the graph execution
|
||||||
if graph_exec_id:
|
if graph_exec_id:
|
||||||
# Resume existing execution
|
# Resume existing execution
|
||||||
graph_exec = await edb.get_graph_execution(
|
graph_exec = await get_graph_execution(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
execution_id=graph_exec_id,
|
execution_id=graph_exec_id,
|
||||||
include_node_executions=True,
|
include_node_executions=True,
|
||||||
@@ -827,9 +779,6 @@ async def add_graph_execution(
|
|||||||
|
|
||||||
# Use existing execution's compiled input masks
|
# Use existing execution's compiled input masks
|
||||||
compiled_nodes_input_masks = graph_exec.nodes_input_masks or {}
|
compiled_nodes_input_masks = graph_exec.nodes_input_masks or {}
|
||||||
# For resumed executions, nodes_to_skip was already determined at creation time
|
|
||||||
# TODO: Consider storing nodes_to_skip in DB if we need to preserve it across resumes
|
|
||||||
nodes_to_skip: set[str] = set()
|
|
||||||
|
|
||||||
logger.info(f"Resuming graph execution #{graph_exec.id} for graph #{graph_id}")
|
logger.info(f"Resuming graph execution #{graph_exec.id} for graph #{graph_id}")
|
||||||
else:
|
else:
|
||||||
@@ -838,7 +787,7 @@ async def add_graph_execution(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create new execution
|
# Create new execution
|
||||||
graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip = (
|
graph, starting_nodes_input, compiled_nodes_input_masks = (
|
||||||
await validate_and_construct_node_execution_input(
|
await validate_and_construct_node_execution_input(
|
||||||
graph_id=graph_id,
|
graph_id=graph_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -887,12 +836,10 @@ async def add_graph_execution(
|
|||||||
try:
|
try:
|
||||||
graph_exec_entry = graph_exec.to_graph_execution_entry(
|
graph_exec_entry = graph_exec.to_graph_execution_entry(
|
||||||
compiled_nodes_input_masks=compiled_nodes_input_masks,
|
compiled_nodes_input_masks=compiled_nodes_input_masks,
|
||||||
nodes_to_skip=nodes_to_skip,
|
|
||||||
execution_context=execution_context,
|
execution_context=execution_context,
|
||||||
)
|
)
|
||||||
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
|
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
|
||||||
|
|
||||||
# Publish to execution queue for executor to pick up
|
|
||||||
exec_queue = await get_async_execution_queue()
|
exec_queue = await get_async_execution_queue()
|
||||||
await exec_queue.publish_message(
|
await exec_queue.publish_message(
|
||||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||||
@@ -901,12 +848,14 @@ async def add_graph_execution(
|
|||||||
)
|
)
|
||||||
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
|
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
|
||||||
|
|
||||||
# Update execution status to QUEUED
|
|
||||||
graph_exec.status = ExecutionStatus.QUEUED
|
graph_exec.status = ExecutionStatus.QUEUED
|
||||||
await edb.update_graph_execution_stats(
|
await edb.update_graph_execution_stats(
|
||||||
graph_exec_id=graph_exec.id,
|
graph_exec_id=graph_exec.id,
|
||||||
status=graph_exec.status,
|
status=graph_exec.status,
|
||||||
)
|
)
|
||||||
|
await get_async_execution_event_bus().publish(graph_exec)
|
||||||
|
|
||||||
|
return graph_exec
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
err = str(e) or type(e).__name__
|
err = str(e) or type(e).__name__
|
||||||
if not graph_exec:
|
if not graph_exec:
|
||||||
@@ -927,24 +876,6 @@ async def add_graph_execution(
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
try:
|
|
||||||
await get_async_execution_event_bus().publish(graph_exec)
|
|
||||||
logger.info(f"Published update for execution #{graph_exec.id} to event bus")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to publish execution event for graph exec #{graph_exec.id}: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await odb.increment_onboarding_runs(user_id)
|
|
||||||
logger.info(
|
|
||||||
f"Incremented user #{user_id} onboarding runs for exec #{graph_exec.id}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to increment onboarding runs for user #{user_id}: {e}")
|
|
||||||
|
|
||||||
return graph_exec
|
|
||||||
|
|
||||||
|
|
||||||
# ============ Execution Output Helpers ============ #
|
# ============ Execution Output Helpers ============ #
|
||||||
|
|
||||||
|
|||||||
@@ -367,13 +367,10 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup mock returns
|
# Setup mock returns
|
||||||
# The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip)
|
|
||||||
nodes_to_skip: set[str] = set()
|
|
||||||
mock_validate.return_value = (
|
mock_validate.return_value = (
|
||||||
mock_graph,
|
mock_graph,
|
||||||
starting_nodes_input,
|
starting_nodes_input,
|
||||||
compiled_nodes_input_masks,
|
compiled_nodes_input_masks,
|
||||||
nodes_to_skip,
|
|
||||||
)
|
)
|
||||||
mock_prisma.is_connected.return_value = True
|
mock_prisma.is_connected.return_value = True
|
||||||
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
||||||
@@ -459,212 +456,3 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
|||||||
# Both executions should succeed (though they create different objects)
|
# Both executions should succeed (though they create different objects)
|
||||||
assert result1 == mock_graph_exec
|
assert result1 == mock_graph_exec
|
||||||
assert result2 == mock_graph_exec_2
|
assert result2 == mock_graph_exec_2
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Tests for Optional Credentials Feature
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_validate_node_input_credentials_returns_nodes_to_skip(
|
|
||||||
mocker: MockerFixture,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Test that _validate_node_input_credentials returns nodes_to_skip set
|
|
||||||
for nodes with credentials_optional=True and missing credentials.
|
|
||||||
"""
|
|
||||||
from backend.executor.utils import _validate_node_input_credentials
|
|
||||||
|
|
||||||
# Create a mock node with credentials_optional=True
|
|
||||||
mock_node = mocker.MagicMock()
|
|
||||||
mock_node.id = "node-with-optional-creds"
|
|
||||||
mock_node.credentials_optional = True
|
|
||||||
mock_node.input_default = {} # No credentials configured
|
|
||||||
|
|
||||||
# Create a mock block with credentials field
|
|
||||||
mock_block = mocker.MagicMock()
|
|
||||||
mock_credentials_field_type = mocker.MagicMock()
|
|
||||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
|
||||||
"credentials": mock_credentials_field_type
|
|
||||||
}
|
|
||||||
mock_node.block = mock_block
|
|
||||||
|
|
||||||
# Create mock graph
|
|
||||||
mock_graph = mocker.MagicMock()
|
|
||||||
mock_graph.nodes = [mock_node]
|
|
||||||
|
|
||||||
# Call the function
|
|
||||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
|
||||||
graph=mock_graph,
|
|
||||||
user_id="test-user-id",
|
|
||||||
nodes_input_masks=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Node should be in nodes_to_skip, not in errors
|
|
||||||
assert mock_node.id in nodes_to_skip
|
|
||||||
assert mock_node.id not in errors
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_validate_node_input_credentials_required_missing_creds_error(
|
|
||||||
mocker: MockerFixture,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Test that _validate_node_input_credentials returns errors
|
|
||||||
for nodes with credentials_optional=False and missing credentials.
|
|
||||||
"""
|
|
||||||
from backend.executor.utils import _validate_node_input_credentials
|
|
||||||
|
|
||||||
# Create a mock node with credentials_optional=False (required)
|
|
||||||
mock_node = mocker.MagicMock()
|
|
||||||
mock_node.id = "node-with-required-creds"
|
|
||||||
mock_node.credentials_optional = False
|
|
||||||
mock_node.input_default = {} # No credentials configured
|
|
||||||
|
|
||||||
# Create a mock block with credentials field
|
|
||||||
mock_block = mocker.MagicMock()
|
|
||||||
mock_credentials_field_type = mocker.MagicMock()
|
|
||||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
|
||||||
"credentials": mock_credentials_field_type
|
|
||||||
}
|
|
||||||
mock_node.block = mock_block
|
|
||||||
|
|
||||||
# Create mock graph
|
|
||||||
mock_graph = mocker.MagicMock()
|
|
||||||
mock_graph.nodes = [mock_node]
|
|
||||||
|
|
||||||
# Call the function
|
|
||||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
|
||||||
graph=mock_graph,
|
|
||||||
user_id="test-user-id",
|
|
||||||
nodes_input_masks=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Node should be in errors, not in nodes_to_skip
|
|
||||||
assert mock_node.id in errors
|
|
||||||
assert "credentials" in errors[mock_node.id]
|
|
||||||
assert "required" in errors[mock_node.id]["credentials"].lower()
|
|
||||||
assert mock_node.id not in nodes_to_skip
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_validate_graph_with_credentials_returns_nodes_to_skip(
|
|
||||||
mocker: MockerFixture,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Test that validate_graph_with_credentials returns nodes_to_skip set
|
|
||||||
from _validate_node_input_credentials.
|
|
||||||
"""
|
|
||||||
from backend.executor.utils import validate_graph_with_credentials
|
|
||||||
|
|
||||||
# Mock _validate_node_input_credentials to return specific values
|
|
||||||
mock_validate = mocker.patch(
|
|
||||||
"backend.executor.utils._validate_node_input_credentials"
|
|
||||||
)
|
|
||||||
expected_errors = {"node1": {"field": "error"}}
|
|
||||||
expected_nodes_to_skip = {"node2", "node3"}
|
|
||||||
mock_validate.return_value = (expected_errors, expected_nodes_to_skip)
|
|
||||||
|
|
||||||
# Mock GraphModel with validate_graph_get_errors method
|
|
||||||
mock_graph = mocker.MagicMock()
|
|
||||||
mock_graph.validate_graph_get_errors.return_value = {}
|
|
||||||
|
|
||||||
# Call the function
|
|
||||||
errors, nodes_to_skip = await validate_graph_with_credentials(
|
|
||||||
graph=mock_graph,
|
|
||||||
user_id="test-user-id",
|
|
||||||
nodes_input_masks=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify nodes_to_skip is passed through
|
|
||||||
assert nodes_to_skip == expected_nodes_to_skip
|
|
||||||
assert "node1" in errors
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
|
||||||
"""
|
|
||||||
Test that add_graph_execution properly passes nodes_to_skip
|
|
||||||
to the graph execution entry.
|
|
||||||
"""
|
|
||||||
from backend.data.execution import GraphExecutionWithNodes
|
|
||||||
from backend.executor.utils import add_graph_execution
|
|
||||||
|
|
||||||
# Mock data
|
|
||||||
graph_id = "test-graph-id"
|
|
||||||
user_id = "test-user-id"
|
|
||||||
inputs = {"test_input": "test_value"}
|
|
||||||
graph_version = 1
|
|
||||||
|
|
||||||
# Mock the graph object
|
|
||||||
mock_graph = mocker.MagicMock()
|
|
||||||
mock_graph.version = graph_version
|
|
||||||
|
|
||||||
# Starting nodes and masks
|
|
||||||
starting_nodes_input = [("node1", {"input1": "value1"})]
|
|
||||||
compiled_nodes_input_masks = {}
|
|
||||||
nodes_to_skip = {"skipped-node-1", "skipped-node-2"}
|
|
||||||
|
|
||||||
# Mock the graph execution object
|
|
||||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
|
||||||
mock_graph_exec.id = "execution-id-123"
|
|
||||||
mock_graph_exec.node_executions = []
|
|
||||||
|
|
||||||
# Track what's passed to to_graph_execution_entry
|
|
||||||
captured_kwargs = {}
|
|
||||||
|
|
||||||
def capture_to_entry(**kwargs):
|
|
||||||
captured_kwargs.update(kwargs)
|
|
||||||
return mocker.MagicMock()
|
|
||||||
|
|
||||||
mock_graph_exec.to_graph_execution_entry.side_effect = capture_to_entry
|
|
||||||
|
|
||||||
# Setup mocks
|
|
||||||
mock_validate = mocker.patch(
|
|
||||||
"backend.executor.utils.validate_and_construct_node_execution_input"
|
|
||||||
)
|
|
||||||
mock_edb = mocker.patch("backend.executor.utils.execution_db")
|
|
||||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
|
||||||
mock_udb = mocker.patch("backend.executor.utils.user_db")
|
|
||||||
mock_gdb = mocker.patch("backend.executor.utils.graph_db")
|
|
||||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
|
||||||
mock_get_event_bus = mocker.patch(
|
|
||||||
"backend.executor.utils.get_async_execution_event_bus"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Setup returns - include nodes_to_skip in the tuple
|
|
||||||
mock_validate.return_value = (
|
|
||||||
mock_graph,
|
|
||||||
starting_nodes_input,
|
|
||||||
compiled_nodes_input_masks,
|
|
||||||
nodes_to_skip, # This should be passed through
|
|
||||||
)
|
|
||||||
mock_prisma.is_connected.return_value = True
|
|
||||||
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
|
||||||
mock_edb.update_graph_execution_stats = mocker.AsyncMock(
|
|
||||||
return_value=mock_graph_exec
|
|
||||||
)
|
|
||||||
mock_edb.update_node_execution_status_batch = mocker.AsyncMock()
|
|
||||||
|
|
||||||
mock_user = mocker.MagicMock()
|
|
||||||
mock_user.timezone = "UTC"
|
|
||||||
mock_settings = mocker.MagicMock()
|
|
||||||
mock_settings.human_in_the_loop_safe_mode = True
|
|
||||||
|
|
||||||
mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
|
|
||||||
mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings)
|
|
||||||
mock_get_queue.return_value = mocker.AsyncMock()
|
|
||||||
mock_get_event_bus.return_value = mocker.MagicMock(publish=mocker.AsyncMock())
|
|
||||||
|
|
||||||
# Call the function
|
|
||||||
await add_graph_execution(
|
|
||||||
graph_id=graph_id,
|
|
||||||
user_id=user_id,
|
|
||||||
inputs=inputs,
|
|
||||||
graph_version=graph_version,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify nodes_to_skip was passed to to_graph_execution_entry
|
|
||||||
assert "nodes_to_skip" in captured_kwargs
|
|
||||||
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from .discord import DiscordOAuthHandler
|
|||||||
from .github import GitHubOAuthHandler
|
from .github import GitHubOAuthHandler
|
||||||
from .google import GoogleOAuthHandler
|
from .google import GoogleOAuthHandler
|
||||||
from .notion import NotionOAuthHandler
|
from .notion import NotionOAuthHandler
|
||||||
from .reddit import RedditOAuthHandler
|
|
||||||
from .twitter import TwitterOAuthHandler
|
from .twitter import TwitterOAuthHandler
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -21,7 +20,6 @@ _ORIGINAL_HANDLERS = [
|
|||||||
GitHubOAuthHandler,
|
GitHubOAuthHandler,
|
||||||
GoogleOAuthHandler,
|
GoogleOAuthHandler,
|
||||||
NotionOAuthHandler,
|
NotionOAuthHandler,
|
||||||
RedditOAuthHandler,
|
|
||||||
TwitterOAuthHandler,
|
TwitterOAuthHandler,
|
||||||
TodoistOAuthHandler,
|
TodoistOAuthHandler,
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,208 +0,0 @@
|
|||||||
import time
|
|
||||||
import urllib.parse
|
|
||||||
from typing import ClassVar, Optional
|
|
||||||
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util.request import Requests
|
|
||||||
from backend.util.settings import Settings
|
|
||||||
|
|
||||||
settings = Settings()
|
|
||||||
|
|
||||||
|
|
||||||
class RedditOAuthHandler(BaseOAuthHandler):
|
|
||||||
"""
|
|
||||||
Reddit OAuth 2.0 handler.
|
|
||||||
|
|
||||||
Based on the documentation at:
|
|
||||||
- https://github.com/reddit-archive/reddit/wiki/OAuth2
|
|
||||||
|
|
||||||
Notes:
|
|
||||||
- Reddit requires `duration=permanent` to get refresh tokens
|
|
||||||
- Access tokens expire after 1 hour (3600 seconds)
|
|
||||||
- Reddit requires HTTP Basic Auth for token requests
|
|
||||||
- Reddit requires a unique User-Agent header
|
|
||||||
"""
|
|
||||||
|
|
||||||
PROVIDER_NAME = ProviderName.REDDIT
|
|
||||||
DEFAULT_SCOPES: ClassVar[list[str]] = [
|
|
||||||
"identity", # Get username, verify auth
|
|
||||||
"read", # Access posts and comments
|
|
||||||
"submit", # Submit new posts and comments
|
|
||||||
"edit", # Edit own posts and comments
|
|
||||||
"history", # Access user's post history
|
|
||||||
"privatemessages", # Access inbox and send private messages
|
|
||||||
"flair", # Access and set flair on posts/subreddits
|
|
||||||
]
|
|
||||||
|
|
||||||
AUTHORIZE_URL = "https://www.reddit.com/api/v1/authorize"
|
|
||||||
TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
|
|
||||||
USERNAME_URL = "https://oauth.reddit.com/api/v1/me"
|
|
||||||
REVOKE_URL = "https://www.reddit.com/api/v1/revoke_token"
|
|
||||||
|
|
||||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
|
||||||
self.client_id = client_id
|
|
||||||
self.client_secret = client_secret
|
|
||||||
self.redirect_uri = redirect_uri
|
|
||||||
|
|
||||||
def get_login_url(
|
|
||||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
|
||||||
) -> str:
|
|
||||||
"""Generate Reddit OAuth 2.0 authorization URL"""
|
|
||||||
scopes = self.handle_default_scopes(scopes)
|
|
||||||
|
|
||||||
params = {
|
|
||||||
"response_type": "code",
|
|
||||||
"client_id": self.client_id,
|
|
||||||
"redirect_uri": self.redirect_uri,
|
|
||||||
"scope": " ".join(scopes),
|
|
||||||
"state": state,
|
|
||||||
"duration": "permanent", # Required for refresh tokens
|
|
||||||
}
|
|
||||||
|
|
||||||
return f"{self.AUTHORIZE_URL}?{urllib.parse.urlencode(params)}"
|
|
||||||
|
|
||||||
async def exchange_code_for_tokens(
|
|
||||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
|
||||||
) -> OAuth2Credentials:
|
|
||||||
"""Exchange authorization code for access tokens"""
|
|
||||||
scopes = self.handle_default_scopes(scopes)
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/x-www-form-urlencoded",
|
|
||||||
"User-Agent": settings.config.reddit_user_agent,
|
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
"code": code,
|
|
||||||
"redirect_uri": self.redirect_uri,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Reddit requires HTTP Basic Auth for token requests
|
|
||||||
auth = (self.client_id, self.client_secret)
|
|
||||||
|
|
||||||
response = await Requests().post(
|
|
||||||
self.TOKEN_URL, headers=headers, data=data, auth=auth
|
|
||||||
)
|
|
||||||
|
|
||||||
if not response.ok:
|
|
||||||
error_text = response.text()
|
|
||||||
raise ValueError(
|
|
||||||
f"Reddit token exchange failed: {response.status} - {error_text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
tokens = response.json()
|
|
||||||
|
|
||||||
if "error" in tokens:
|
|
||||||
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
|
|
||||||
|
|
||||||
username = await self._get_username(tokens["access_token"])
|
|
||||||
|
|
||||||
return OAuth2Credentials(
|
|
||||||
provider=self.PROVIDER_NAME,
|
|
||||||
title=None,
|
|
||||||
username=username,
|
|
||||||
access_token=tokens["access_token"],
|
|
||||||
refresh_token=tokens.get("refresh_token"),
|
|
||||||
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
|
|
||||||
refresh_token_expires_at=None, # Reddit refresh tokens don't expire
|
|
||||||
scopes=scopes,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _get_username(self, access_token: str) -> str:
|
|
||||||
"""Get the username from the access token"""
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {access_token}",
|
|
||||||
"User-Agent": settings.config.reddit_user_agent,
|
|
||||||
}
|
|
||||||
|
|
||||||
response = await Requests().get(self.USERNAME_URL, headers=headers)
|
|
||||||
|
|
||||||
if not response.ok:
|
|
||||||
raise ValueError(f"Failed to get Reddit username: {response.status}")
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
return data.get("name", "unknown")
|
|
||||||
|
|
||||||
async def _refresh_tokens(
|
|
||||||
self, credentials: OAuth2Credentials
|
|
||||||
) -> OAuth2Credentials:
|
|
||||||
"""Refresh access tokens using refresh token"""
|
|
||||||
if not credentials.refresh_token:
|
|
||||||
raise ValueError("No refresh token available")
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/x-www-form-urlencoded",
|
|
||||||
"User-Agent": settings.config.reddit_user_agent,
|
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"grant_type": "refresh_token",
|
|
||||||
"refresh_token": credentials.refresh_token.get_secret_value(),
|
|
||||||
}
|
|
||||||
|
|
||||||
auth = (self.client_id, self.client_secret)
|
|
||||||
|
|
||||||
response = await Requests().post(
|
|
||||||
self.TOKEN_URL, headers=headers, data=data, auth=auth
|
|
||||||
)
|
|
||||||
|
|
||||||
if not response.ok:
|
|
||||||
error_text = response.text()
|
|
||||||
raise ValueError(
|
|
||||||
f"Reddit token refresh failed: {response.status} - {error_text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
tokens = response.json()
|
|
||||||
|
|
||||||
if "error" in tokens:
|
|
||||||
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
|
|
||||||
|
|
||||||
username = await self._get_username(tokens["access_token"])
|
|
||||||
|
|
||||||
# Reddit may or may not return a new refresh token
|
|
||||||
new_refresh_token = tokens.get("refresh_token")
|
|
||||||
if new_refresh_token:
|
|
||||||
refresh_token: SecretStr | None = SecretStr(new_refresh_token)
|
|
||||||
elif credentials.refresh_token:
|
|
||||||
# Keep the existing refresh token
|
|
||||||
refresh_token = credentials.refresh_token
|
|
||||||
else:
|
|
||||||
refresh_token = None
|
|
||||||
|
|
||||||
return OAuth2Credentials(
|
|
||||||
id=credentials.id,
|
|
||||||
provider=self.PROVIDER_NAME,
|
|
||||||
title=credentials.title,
|
|
||||||
username=username,
|
|
||||||
access_token=tokens["access_token"],
|
|
||||||
refresh_token=refresh_token,
|
|
||||||
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
|
|
||||||
refresh_token_expires_at=None,
|
|
||||||
scopes=credentials.scopes,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
|
||||||
"""Revoke the access token"""
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/x-www-form-urlencoded",
|
|
||||||
"User-Agent": settings.config.reddit_user_agent,
|
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"token": credentials.access_token.get_secret_value(),
|
|
||||||
"token_type_hint": "access_token",
|
|
||||||
}
|
|
||||||
|
|
||||||
auth = (self.client_id, self.client_secret)
|
|
||||||
|
|
||||||
response = await Requests().post(
|
|
||||||
self.REVOKE_URL, headers=headers, data=data, auth=auth
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reddit returns 204 No Content on successful revocation
|
|
||||||
return response.ok
|
|
||||||
@@ -10,7 +10,6 @@ from backend.util.settings import Settings
|
|||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from openai import AsyncOpenAI
|
|
||||||
from supabase import AClient, Client
|
from supabase import AClient, Client
|
||||||
|
|
||||||
from backend.data.execution import (
|
from backend.data.execution import (
|
||||||
@@ -140,24 +139,6 @@ async def get_async_supabase() -> "AClient":
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ============ OpenAI Client ============ #
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl_seconds=3600)
|
|
||||||
def get_openai_client() -> "AsyncOpenAI | None":
|
|
||||||
"""
|
|
||||||
Get a process-cached async OpenAI client for embeddings.
|
|
||||||
|
|
||||||
Returns None if API key is not configured.
|
|
||||||
"""
|
|
||||||
from openai import AsyncOpenAI
|
|
||||||
|
|
||||||
api_key = settings.secrets.openai_internal_api_key
|
|
||||||
if not api_key:
|
|
||||||
return None
|
|
||||||
return AsyncOpenAI(api_key=api_key)
|
|
||||||
|
|
||||||
|
|
||||||
# ============ Notification Queue Helpers ============ #
|
# ============ Notification Queue Helpers ============ #
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -264,7 +264,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
|||||||
)
|
)
|
||||||
|
|
||||||
reddit_user_agent: str = Field(
|
reddit_user_agent: str = Field(
|
||||||
default="web:AutoGPT:v0.6.0 (by /u/autogpt)",
|
default="AutoGPT:1.0 (by /u/autogpt)",
|
||||||
description="The user agent for the Reddit API",
|
description="The user agent for the Reddit API",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,227 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Generate a lightweight stub for prisma/types.py that collapses all exported
|
|
||||||
symbols to Any. This prevents Pyright from spending time/budget on Prisma's
|
|
||||||
query DSL types while keeping runtime behavior unchanged.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
poetry run gen-prisma-stub
|
|
||||||
|
|
||||||
This script automatically finds the prisma package location and generates
|
|
||||||
the types.pyi stub file in the same directory as types.py.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import ast
|
|
||||||
import importlib.util
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Iterable, Set
|
|
||||||
|
|
||||||
|
|
||||||
def _iter_assigned_names(target: ast.expr) -> Iterable[str]:
|
|
||||||
"""Extract names from assignment targets (handles tuple unpacking)."""
|
|
||||||
if isinstance(target, ast.Name):
|
|
||||||
yield target.id
|
|
||||||
elif isinstance(target, (ast.Tuple, ast.List)):
|
|
||||||
for elt in target.elts:
|
|
||||||
yield from _iter_assigned_names(elt)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_private(name: str) -> bool:
|
|
||||||
"""Check if a name is private (starts with _ but not __)."""
|
|
||||||
return name.startswith("_") and not name.startswith("__")
|
|
||||||
|
|
||||||
|
|
||||||
def _is_safe_type_alias(node: ast.Assign) -> bool:
|
|
||||||
"""Check if an assignment is a safe type alias that shouldn't be stubbed.
|
|
||||||
|
|
||||||
Safe types are:
|
|
||||||
- Literal types (don't cause type budget issues)
|
|
||||||
- Simple type references (SortMode, SortOrder, etc.)
|
|
||||||
- TypeVar definitions
|
|
||||||
"""
|
|
||||||
if not node.value:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check if it's a Subscript (like Literal[...], Union[...], TypeVar[...])
|
|
||||||
if isinstance(node.value, ast.Subscript):
|
|
||||||
# Get the base type name
|
|
||||||
if isinstance(node.value.value, ast.Name):
|
|
||||||
base_name = node.value.value.id
|
|
||||||
# Literal types are safe
|
|
||||||
if base_name == "Literal":
|
|
||||||
return True
|
|
||||||
# TypeVar is safe
|
|
||||||
if base_name == "TypeVar":
|
|
||||||
return True
|
|
||||||
elif isinstance(node.value.value, ast.Attribute):
|
|
||||||
# Handle typing_extensions.Literal etc.
|
|
||||||
if node.value.value.attr == "Literal":
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check if it's a simple Name reference (like SortMode = _types.SortMode)
|
|
||||||
if isinstance(node.value, ast.Attribute):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check if it's a Call (like TypeVar(...))
|
|
||||||
if isinstance(node.value, ast.Call):
|
|
||||||
if isinstance(node.value.func, ast.Name):
|
|
||||||
if node.value.func.id == "TypeVar":
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def collect_top_level_symbols(
|
|
||||||
tree: ast.Module, source_lines: list[str]
|
|
||||||
) -> tuple[Set[str], Set[str], list[str], Set[str]]:
|
|
||||||
"""Collect all top-level symbols from an AST module.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (class_names, function_names, safe_variable_sources, unsafe_variable_names)
|
|
||||||
safe_variable_sources contains the actual source code lines for safe variables
|
|
||||||
"""
|
|
||||||
classes: Set[str] = set()
|
|
||||||
functions: Set[str] = set()
|
|
||||||
safe_variable_sources: list[str] = []
|
|
||||||
unsafe_variables: Set[str] = set()
|
|
||||||
|
|
||||||
for node in tree.body:
|
|
||||||
if isinstance(node, ast.ClassDef):
|
|
||||||
if not _is_private(node.name):
|
|
||||||
classes.add(node.name)
|
|
||||||
elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
|
||||||
if not _is_private(node.name):
|
|
||||||
functions.add(node.name)
|
|
||||||
elif isinstance(node, ast.Assign):
|
|
||||||
is_safe = _is_safe_type_alias(node)
|
|
||||||
names = []
|
|
||||||
for t in node.targets:
|
|
||||||
for n in _iter_assigned_names(t):
|
|
||||||
if not _is_private(n):
|
|
||||||
names.append(n)
|
|
||||||
if names:
|
|
||||||
if is_safe:
|
|
||||||
# Extract the source code for this assignment
|
|
||||||
start_line = node.lineno - 1 # 0-indexed
|
|
||||||
end_line = node.end_lineno if node.end_lineno else node.lineno
|
|
||||||
source = "\n".join(source_lines[start_line:end_line])
|
|
||||||
safe_variable_sources.append(source)
|
|
||||||
else:
|
|
||||||
unsafe_variables.update(names)
|
|
||||||
elif isinstance(node, ast.AnnAssign) and node.target:
|
|
||||||
# Annotated assignments are always stubbed
|
|
||||||
for n in _iter_assigned_names(node.target):
|
|
||||||
if not _is_private(n):
|
|
||||||
unsafe_variables.add(n)
|
|
||||||
|
|
||||||
return classes, functions, safe_variable_sources, unsafe_variables
|
|
||||||
|
|
||||||
|
|
||||||
def find_prisma_types_path() -> Path:
|
|
||||||
"""Find the prisma types.py file in the installed package."""
|
|
||||||
spec = importlib.util.find_spec("prisma")
|
|
||||||
if spec is None or spec.origin is None:
|
|
||||||
raise RuntimeError("Could not find prisma package. Is it installed?")
|
|
||||||
|
|
||||||
prisma_dir = Path(spec.origin).parent
|
|
||||||
types_path = prisma_dir / "types.py"
|
|
||||||
|
|
||||||
if not types_path.exists():
|
|
||||||
raise RuntimeError(f"prisma/types.py not found at {types_path}")
|
|
||||||
|
|
||||||
return types_path
|
|
||||||
|
|
||||||
|
|
||||||
def generate_stub(src_path: Path, stub_path: Path) -> int:
|
|
||||||
"""Generate the .pyi stub file from the source types.py."""
|
|
||||||
code = src_path.read_text(encoding="utf-8", errors="ignore")
|
|
||||||
source_lines = code.splitlines()
|
|
||||||
tree = ast.parse(code, filename=str(src_path))
|
|
||||||
classes, functions, safe_variable_sources, unsafe_variables = (
|
|
||||||
collect_top_level_symbols(tree, source_lines)
|
|
||||||
)
|
|
||||||
|
|
||||||
header = """\
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Auto-generated stub file - DO NOT EDIT
|
|
||||||
# Generated by gen_prisma_types_stub.py
|
|
||||||
#
|
|
||||||
# This stub intentionally collapses complex Prisma query DSL types to Any.
|
|
||||||
# Prisma's generated types can explode Pyright's type inference budgets
|
|
||||||
# on large schemas. We collapse them to Any so the rest of the codebase
|
|
||||||
# can remain strongly typed while keeping runtime behavior unchanged.
|
|
||||||
#
|
|
||||||
# Safe types (Literal, TypeVar, simple references) are preserved from the
|
|
||||||
# original types.py to maintain proper type checking where possible.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
from typing import Any
|
|
||||||
from typing_extensions import Literal
|
|
||||||
|
|
||||||
# Re-export commonly used typing constructs that may be imported from this module
|
|
||||||
from typing import TYPE_CHECKING, TypeVar, Generic, Union, Optional, List, Dict
|
|
||||||
|
|
||||||
# Base type alias for stubbed Prisma types - allows any dict structure
|
|
||||||
_PrismaDict = dict[str, Any]
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
lines = [header]
|
|
||||||
|
|
||||||
# Include safe variable definitions (Literal types, TypeVars, etc.)
|
|
||||||
lines.append("# Safe type definitions preserved from original types.py")
|
|
||||||
for source in safe_variable_sources:
|
|
||||||
lines.append(source)
|
|
||||||
lines.append("")
|
|
||||||
|
|
||||||
# Stub all classes and unsafe variables uniformly as dict[str, Any] aliases
|
|
||||||
# This allows:
|
|
||||||
# 1. Use in type annotations: x: SomeType
|
|
||||||
# 2. Constructor calls: SomeType(...)
|
|
||||||
# 3. Dict literal assignments: x: SomeType = {...}
|
|
||||||
lines.append(
|
|
||||||
"# Stubbed types (collapsed to dict[str, Any] to prevent type budget exhaustion)"
|
|
||||||
)
|
|
||||||
all_stubbed = sorted(classes | unsafe_variables)
|
|
||||||
for name in all_stubbed:
|
|
||||||
lines.append(f"{name} = _PrismaDict")
|
|
||||||
|
|
||||||
lines.append("")
|
|
||||||
|
|
||||||
# Stub functions
|
|
||||||
for name in sorted(functions):
|
|
||||||
lines.append(f"def {name}(*args: Any, **kwargs: Any) -> Any: ...")
|
|
||||||
|
|
||||||
lines.append("")
|
|
||||||
|
|
||||||
stub_path.write_text("\n".join(lines), encoding="utf-8")
|
|
||||||
return (
|
|
||||||
len(classes)
|
|
||||||
+ len(functions)
|
|
||||||
+ len(safe_variable_sources)
|
|
||||||
+ len(unsafe_variables)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
"""Main entry point."""
|
|
||||||
try:
|
|
||||||
types_path = find_prisma_types_path()
|
|
||||||
stub_path = types_path.with_suffix(".pyi")
|
|
||||||
|
|
||||||
print(f"Found prisma types.py at: {types_path}")
|
|
||||||
print(f"Generating stub at: {stub_path}")
|
|
||||||
|
|
||||||
num_symbols = generate_stub(types_path, stub_path)
|
|
||||||
print(f"Generated {stub_path.name} with {num_symbols} Any-typed symbols")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}", file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -25,9 +25,6 @@ def run(*command: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def lint():
|
def lint():
|
||||||
# Generate Prisma types stub before running pyright to prevent type budget exhaustion
|
|
||||||
run("gen-prisma-stub")
|
|
||||||
|
|
||||||
lint_step_args: list[list[str]] = [
|
lint_step_args: list[list[str]] = [
|
||||||
["ruff", "check", *TARGET_DIRS, "--exit-zero"],
|
["ruff", "check", *TARGET_DIRS, "--exit-zero"],
|
||||||
["ruff", "format", "--diff", "--check", LIBS_DIR],
|
["ruff", "format", "--diff", "--check", LIBS_DIR],
|
||||||
@@ -52,6 +49,4 @@ def format():
|
|||||||
run("ruff", "format", LIBS_DIR)
|
run("ruff", "format", LIBS_DIR)
|
||||||
run("isort", "--profile", "black", BACKEND_DIR)
|
run("isort", "--profile", "black", BACKEND_DIR)
|
||||||
run("black", BACKEND_DIR)
|
run("black", BACKEND_DIR)
|
||||||
# Generate Prisma types stub before running pyright to prevent type budget exhaustion
|
|
||||||
run("gen-prisma-stub")
|
|
||||||
run("pyright", *TARGET_DIRS)
|
run("pyright", *TARGET_DIRS)
|
||||||
|
|||||||
@@ -0,0 +1,81 @@
|
|||||||
|
-- DropIndex
|
||||||
|
DROP INDEX "StoreListingVersion_storeListingId_version_key";
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "UserBusinessUnderstanding" (
|
||||||
|
"id" TEXT NOT NULL,
|
||||||
|
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"userId" TEXT NOT NULL,
|
||||||
|
"userName" TEXT,
|
||||||
|
"jobTitle" TEXT,
|
||||||
|
"businessName" TEXT,
|
||||||
|
"industry" TEXT,
|
||||||
|
"businessSize" TEXT,
|
||||||
|
"userRole" TEXT,
|
||||||
|
"keyWorkflows" JSONB,
|
||||||
|
"dailyActivities" JSONB,
|
||||||
|
"painPoints" JSONB,
|
||||||
|
"bottlenecks" JSONB,
|
||||||
|
"manualTasks" JSONB,
|
||||||
|
"automationGoals" JSONB,
|
||||||
|
"currentSoftware" JSONB,
|
||||||
|
"existingAutomation" JSONB,
|
||||||
|
"additionalNotes" TEXT,
|
||||||
|
|
||||||
|
CONSTRAINT "UserBusinessUnderstanding_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "ChatSession" (
|
||||||
|
"id" TEXT NOT NULL,
|
||||||
|
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"userId" TEXT,
|
||||||
|
"title" TEXT,
|
||||||
|
"credentials" JSONB NOT NULL DEFAULT '{}',
|
||||||
|
"successfulAgentRuns" JSONB NOT NULL DEFAULT '{}',
|
||||||
|
"successfulAgentSchedules" JSONB NOT NULL DEFAULT '{}',
|
||||||
|
"totalPromptTokens" INTEGER NOT NULL DEFAULT 0,
|
||||||
|
"totalCompletionTokens" INTEGER NOT NULL DEFAULT 0,
|
||||||
|
|
||||||
|
CONSTRAINT "ChatSession_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "ChatMessage" (
|
||||||
|
"id" TEXT NOT NULL,
|
||||||
|
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"sessionId" TEXT NOT NULL,
|
||||||
|
"role" TEXT NOT NULL,
|
||||||
|
"content" TEXT,
|
||||||
|
"name" TEXT,
|
||||||
|
"toolCallId" TEXT,
|
||||||
|
"refusal" TEXT,
|
||||||
|
"toolCalls" JSONB,
|
||||||
|
"functionCall" JSONB,
|
||||||
|
"sequence" INTEGER NOT NULL,
|
||||||
|
|
||||||
|
CONSTRAINT "ChatMessage_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE UNIQUE INDEX "UserBusinessUnderstanding_userId_key" ON "UserBusinessUnderstanding"("userId");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE INDEX "UserBusinessUnderstanding_userId_idx" ON "UserBusinessUnderstanding"("userId");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE INDEX "ChatSession_userId_updatedAt_idx" ON "ChatSession"("userId", "updatedAt");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE INDEX "ChatMessage_sessionId_sequence_idx" ON "ChatMessage"("sessionId", "sequence");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE UNIQUE INDEX "ChatMessage_sessionId_sequence_key" ON "ChatMessage"("sessionId", "sequence");
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "UserBusinessUnderstanding" ADD CONSTRAINT "UserBusinessUnderstanding_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "ChatMessage" ADD CONSTRAINT "ChatMessage_sessionId_fkey" FOREIGN KEY ("sessionId") REFERENCES "ChatSession"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
-- CreateExtension
|
|
||||||
-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
|
|
||||||
-- Create in public schema so vector type is available across all schemas
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "vector" WITH SCHEMA "public";
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE NOTICE 'vector extension not available or already exists, skipping';
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
-- CreateEnum
|
|
||||||
CREATE TYPE "ContentType" AS ENUM ('STORE_AGENT', 'BLOCK', 'INTEGRATION', 'DOCUMENTATION', 'LIBRARY_AGENT');
|
|
||||||
|
|
||||||
-- CreateTable
|
|
||||||
CREATE TABLE "UnifiedContentEmbedding" (
|
|
||||||
"id" TEXT NOT NULL,
|
|
||||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
|
||||||
"contentType" "ContentType" NOT NULL,
|
|
||||||
"contentId" TEXT NOT NULL,
|
|
||||||
"userId" TEXT,
|
|
||||||
"embedding" public.vector(1536) NOT NULL,
|
|
||||||
"searchableText" TEXT NOT NULL,
|
|
||||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
|
||||||
|
|
||||||
CONSTRAINT "UnifiedContentEmbedding_pkey" PRIMARY KEY ("id")
|
|
||||||
);
|
|
||||||
|
|
||||||
-- CreateIndex
|
|
||||||
CREATE INDEX "UnifiedContentEmbedding_contentType_idx" ON "UnifiedContentEmbedding"("contentType");
|
|
||||||
|
|
||||||
-- CreateIndex
|
|
||||||
CREATE INDEX "UnifiedContentEmbedding_userId_idx" ON "UnifiedContentEmbedding"("userId");
|
|
||||||
|
|
||||||
-- CreateIndex
|
|
||||||
CREATE INDEX "UnifiedContentEmbedding_contentType_userId_idx" ON "UnifiedContentEmbedding"("contentType", "userId");
|
|
||||||
|
|
||||||
-- CreateIndex
|
|
||||||
-- NULLS NOT DISTINCT ensures only one public (NULL userId) embedding per contentType+contentId
|
|
||||||
-- Requires PostgreSQL 15+. Supabase uses PostgreSQL 15+.
|
|
||||||
CREATE UNIQUE INDEX "UnifiedContentEmbedding_contentType_contentId_userId_key" ON "UnifiedContentEmbedding"("contentType", "contentId", "userId") NULLS NOT DISTINCT;
|
|
||||||
|
|
||||||
-- CreateIndex
|
|
||||||
-- HNSW index for fast vector similarity search on embeddings
|
|
||||||
-- Uses cosine distance operator (<=>), which matches the query in hybrid_search.py
|
|
||||||
CREATE INDEX "UnifiedContentEmbedding_embedding_idx" ON "UnifiedContentEmbedding" USING hnsw ("embedding" public.vector_cosine_ops);
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
-- Acknowledge Supabase-managed extensions to prevent drift warnings
|
|
||||||
-- These extensions are pre-installed by Supabase in specific schemas
|
|
||||||
-- This migration ensures they exist where available (Supabase) or skips gracefully (CI)
|
|
||||||
|
|
||||||
-- Create schemas (safe in both CI and Supabase)
|
|
||||||
CREATE SCHEMA IF NOT EXISTS "extensions";
|
|
||||||
|
|
||||||
-- Extensions that exist in both CI and Supabase
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA "extensions";
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE NOTICE 'pgcrypto extension not available, skipping';
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA "extensions";
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE NOTICE 'uuid-ossp extension not available, skipping';
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
-- Supabase-specific extensions (skip gracefully in CI)
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "pg_stat_statements" WITH SCHEMA "extensions";
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE NOTICE 'pg_stat_statements extension not available, skipping';
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "pg_net" WITH SCHEMA "extensions";
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE NOTICE 'pg_net extension not available, skipping';
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "pgjwt" WITH SCHEMA "extensions";
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE NOTICE 'pgjwt extension not available, skipping';
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
CREATE SCHEMA IF NOT EXISTS "graphql";
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "pg_graphql" WITH SCHEMA "graphql";
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE NOTICE 'pg_graphql extension not available, skipping';
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
CREATE SCHEMA IF NOT EXISTS "pgsodium";
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "pgsodium" WITH SCHEMA "pgsodium";
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE NOTICE 'pgsodium extension not available, skipping';
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
CREATE SCHEMA IF NOT EXISTS "vault";
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "supabase_vault" WITH SCHEMA "vault";
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE NOTICE 'supabase_vault extension not available, skipping';
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
|
|
||||||
-- Return to platform
|
|
||||||
CREATE SCHEMA IF NOT EXISTS "platform";
|
|
||||||
@@ -117,7 +117,6 @@ lint = "linter:lint"
|
|||||||
test = "run_tests:test"
|
test = "run_tests:test"
|
||||||
load-store-agents = "test.load_store_agents:run"
|
load-store-agents = "test.load_store_agents:run"
|
||||||
export-api-schema = "backend.cli.generate_openapi_json:main"
|
export-api-schema = "backend.cli.generate_openapi_json:main"
|
||||||
gen-prisma-stub = "gen_prisma_types_stub:main"
|
|
||||||
oauth-tool = "backend.cli.oauth_tool:cli"
|
oauth-tool = "backend.cli.oauth_tool:cli"
|
||||||
|
|
||||||
[tool.isort]
|
[tool.isort]
|
||||||
@@ -135,9 +134,6 @@ ignore_patterns = []
|
|||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
asyncio_default_fixture_loop_scope = "session"
|
asyncio_default_fixture_loop_scope = "session"
|
||||||
# Disable syrupy plugin to avoid conflict with pytest-snapshot
|
|
||||||
# Both provide --snapshot-update argument causing ArgumentError
|
|
||||||
addopts = "-p no:syrupy"
|
|
||||||
filterwarnings = [
|
filterwarnings = [
|
||||||
"ignore:'audioop' is deprecated:DeprecationWarning:discord.player",
|
"ignore:'audioop' is deprecated:DeprecationWarning:discord.player",
|
||||||
"ignore:invalid escape sequence:DeprecationWarning:tweepy.api",
|
"ignore:invalid escape sequence:DeprecationWarning:tweepy.api",
|
||||||
|
|||||||
@@ -1,15 +1,14 @@
|
|||||||
datasource db {
|
datasource db {
|
||||||
provider = "postgresql"
|
provider = "postgresql"
|
||||||
url = env("DATABASE_URL")
|
url = env("DATABASE_URL")
|
||||||
directUrl = env("DIRECT_URL")
|
directUrl = env("DIRECT_URL")
|
||||||
extensions = [pgvector(map: "vector")]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
generator client {
|
generator client {
|
||||||
provider = "prisma-client-py"
|
provider = "prisma-client-py"
|
||||||
recursive_type_depth = -1
|
recursive_type_depth = -1
|
||||||
interface = "asyncio"
|
interface = "asyncio"
|
||||||
previewFeatures = ["views", "fullTextSearch", "postgresqlExtensions"]
|
previewFeatures = ["views", "fullTextSearch"]
|
||||||
partial_type_generator = "backend/data/partial_types.py"
|
partial_type_generator = "backend/data/partial_types.py"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,6 +53,7 @@ model User {
|
|||||||
|
|
||||||
Profile Profile[]
|
Profile Profile[]
|
||||||
UserOnboarding UserOnboarding?
|
UserOnboarding UserOnboarding?
|
||||||
|
BusinessUnderstanding UserBusinessUnderstanding?
|
||||||
BuilderSearchHistory BuilderSearchHistory[]
|
BuilderSearchHistory BuilderSearchHistory[]
|
||||||
StoreListings StoreListing[]
|
StoreListings StoreListing[]
|
||||||
StoreListingReviews StoreListingReview[]
|
StoreListingReviews StoreListingReview[]
|
||||||
@@ -122,6 +122,43 @@ model UserOnboarding {
|
|||||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
model UserBusinessUnderstanding {
|
||||||
|
id String @id @default(uuid())
|
||||||
|
createdAt DateTime @default(now())
|
||||||
|
updatedAt DateTime @default(now()) @updatedAt
|
||||||
|
|
||||||
|
userId String @unique
|
||||||
|
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
|
// User info
|
||||||
|
userName String?
|
||||||
|
jobTitle String?
|
||||||
|
|
||||||
|
// Business basics (string columns)
|
||||||
|
businessName String?
|
||||||
|
industry String?
|
||||||
|
businessSize String? // "1-10", "11-50", "51-200", "201-1000", "1000+"
|
||||||
|
userRole String? // Role in organization context (e.g., "decision maker", "implementer")
|
||||||
|
|
||||||
|
// Processes & activities (JSON arrays)
|
||||||
|
keyWorkflows Json?
|
||||||
|
dailyActivities Json?
|
||||||
|
|
||||||
|
// Pain points & goals (JSON arrays)
|
||||||
|
painPoints Json?
|
||||||
|
bottlenecks Json?
|
||||||
|
manualTasks Json?
|
||||||
|
automationGoals Json?
|
||||||
|
|
||||||
|
// Current tools (JSON arrays)
|
||||||
|
currentSoftware Json?
|
||||||
|
existingAutomation Json?
|
||||||
|
|
||||||
|
additionalNotes String?
|
||||||
|
|
||||||
|
@@index([userId])
|
||||||
|
}
|
||||||
|
|
||||||
model BuilderSearchHistory {
|
model BuilderSearchHistory {
|
||||||
id String @id @default(uuid())
|
id String @id @default(uuid())
|
||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
@@ -135,6 +172,59 @@ model BuilderSearchHistory {
|
|||||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////
|
||||||
|
////////////////////////////////////////////////////////////
|
||||||
|
//////////////// CHAT SESSION TABLES ///////////////////
|
||||||
|
////////////////////////////////////////////////////////////
|
||||||
|
////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
model ChatSession {
|
||||||
|
id String @id @default(uuid())
|
||||||
|
createdAt DateTime @default(now())
|
||||||
|
updatedAt DateTime @default(now()) @updatedAt
|
||||||
|
|
||||||
|
userId String?
|
||||||
|
|
||||||
|
// Session metadata
|
||||||
|
title String?
|
||||||
|
credentials Json @default("{}") // Map of provider -> credential metadata
|
||||||
|
|
||||||
|
// Rate limiting counters (stored as JSON maps)
|
||||||
|
successfulAgentRuns Json @default("{}") // Map of graph_id -> count
|
||||||
|
successfulAgentSchedules Json @default("{}") // Map of graph_id -> count
|
||||||
|
|
||||||
|
// Usage tracking
|
||||||
|
totalPromptTokens Int @default(0)
|
||||||
|
totalCompletionTokens Int @default(0)
|
||||||
|
|
||||||
|
Messages ChatMessage[]
|
||||||
|
|
||||||
|
@@index([userId, updatedAt])
|
||||||
|
}
|
||||||
|
|
||||||
|
model ChatMessage {
|
||||||
|
id String @id @default(uuid())
|
||||||
|
createdAt DateTime @default(now())
|
||||||
|
|
||||||
|
sessionId String
|
||||||
|
Session ChatSession @relation(fields: [sessionId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
|
// Message content
|
||||||
|
role String // "user", "assistant", "system", "tool", "function"
|
||||||
|
content String?
|
||||||
|
name String?
|
||||||
|
toolCallId String?
|
||||||
|
refusal String?
|
||||||
|
toolCalls Json? // List of tool calls for assistant messages
|
||||||
|
functionCall Json? // Deprecated but kept for compatibility
|
||||||
|
|
||||||
|
// Ordering within session
|
||||||
|
sequence Int
|
||||||
|
|
||||||
|
@@unique([sessionId, sequence])
|
||||||
|
@@index([sessionId, sequence])
|
||||||
|
}
|
||||||
|
|
||||||
// This model describes the Agent Graph/Flow (Multi Agent System).
|
// This model describes the Agent Graph/Flow (Multi Agent System).
|
||||||
model AgentGraph {
|
model AgentGraph {
|
||||||
id String @default(uuid())
|
id String @default(uuid())
|
||||||
@@ -728,19 +818,20 @@ view StoreAgent {
|
|||||||
agent_output_demo String?
|
agent_output_demo String?
|
||||||
agent_image String[]
|
agent_image String[]
|
||||||
|
|
||||||
featured Boolean @default(false)
|
featured Boolean @default(false)
|
||||||
creator_username String?
|
creator_username String?
|
||||||
creator_avatar String?
|
creator_avatar String?
|
||||||
sub_heading String
|
sub_heading String
|
||||||
description String
|
description String
|
||||||
categories String[]
|
categories String[]
|
||||||
|
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
|
||||||
runs Int
|
runs Int
|
||||||
rating Float
|
rating Float
|
||||||
versions String[]
|
versions String[]
|
||||||
agentGraphVersions String[]
|
agentGraphVersions String[]
|
||||||
agentGraphId String
|
agentGraphId String
|
||||||
is_available Boolean @default(true)
|
is_available Boolean @default(true)
|
||||||
useForOnboarding Boolean @default(false)
|
useForOnboarding Boolean @default(false)
|
||||||
|
|
||||||
// Materialized views used (refreshed every 15 minutes via pg_cron):
|
// Materialized views used (refreshed every 15 minutes via pg_cron):
|
||||||
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
|
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
|
||||||
@@ -899,52 +990,12 @@ model StoreListingVersion {
|
|||||||
// Reviews for this specific version
|
// Reviews for this specific version
|
||||||
Reviews StoreListingReview[]
|
Reviews StoreListingReview[]
|
||||||
|
|
||||||
// Note: Embeddings now stored in UnifiedContentEmbedding table
|
|
||||||
// Use contentType=STORE_AGENT and contentId=storeListingVersionId
|
|
||||||
|
|
||||||
@@unique([storeListingId, version])
|
|
||||||
@@index([storeListingId, submissionStatus, isAvailable])
|
@@index([storeListingId, submissionStatus, isAvailable])
|
||||||
@@index([submissionStatus])
|
@@index([submissionStatus])
|
||||||
@@index([reviewerId])
|
@@index([reviewerId])
|
||||||
@@index([agentGraphId, agentGraphVersion]) // Non-unique index for efficient lookups
|
@@index([agentGraphId, agentGraphVersion]) // Non-unique index for efficient lookups
|
||||||
}
|
}
|
||||||
|
|
||||||
// Content type enum for unified search across store agents, blocks, docs
|
|
||||||
// Note: BLOCK/INTEGRATION are file-based (Python classes), not DB records
|
|
||||||
// DOCUMENTATION are file-based (.md files), not DB records
|
|
||||||
// Only STORE_AGENT and LIBRARY_AGENT are stored in database
|
|
||||||
enum ContentType {
|
|
||||||
STORE_AGENT // Database: StoreListingVersion
|
|
||||||
BLOCK // File-based: Python classes in /backend/blocks/
|
|
||||||
INTEGRATION // File-based: Python classes (blocks with credentials)
|
|
||||||
DOCUMENTATION // File-based: .md/.mdx files
|
|
||||||
LIBRARY_AGENT // Database: User's personal agents
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unified embeddings table for all searchable content types
|
|
||||||
// Supports both public content (userId=null) and user-specific content (userId=userID)
|
|
||||||
model UnifiedContentEmbedding {
|
|
||||||
id String @id @default(uuid())
|
|
||||||
createdAt DateTime @default(now())
|
|
||||||
updatedAt DateTime @updatedAt
|
|
||||||
|
|
||||||
// Content identification
|
|
||||||
contentType ContentType
|
|
||||||
contentId String // DB ID (storeListingVersionId) or file identifier (block.id, file_path)
|
|
||||||
userId String? // NULL for public content (store, blocks, docs), userId for private content (library agents)
|
|
||||||
|
|
||||||
// Search data
|
|
||||||
embedding Unsupported("vector(1536)") // pgvector embedding (extension in platform schema)
|
|
||||||
searchableText String // Combined text for search and fallback
|
|
||||||
metadata Json @default("{}") // Content-specific metadata
|
|
||||||
|
|
||||||
@@unique([contentType, contentId, userId], map: "UnifiedContentEmbedding_contentType_contentId_userId_key")
|
|
||||||
@@index([contentType])
|
|
||||||
@@index([userId])
|
|
||||||
@@index([contentType, userId])
|
|
||||||
@@index([embedding], map: "UnifiedContentEmbedding_embedding_idx")
|
|
||||||
}
|
|
||||||
|
|
||||||
model StoreListingReview {
|
model StoreListingReview {
|
||||||
id String @id @default(uuid())
|
id String @id @default(uuid())
|
||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
"created_at": "2025-09-04T13:37:00",
|
"created_at": "2025-09-04T13:37:00",
|
||||||
"credentials_input_schema": {
|
"credentials_input_schema": {
|
||||||
"properties": {},
|
"properties": {},
|
||||||
"required": [],
|
|
||||||
"title": "TestGraphCredentialsInputSchema",
|
"title": "TestGraphCredentialsInputSchema",
|
||||||
"type": "object"
|
"type": "object"
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
{
|
{
|
||||||
"credentials_input_schema": {
|
"credentials_input_schema": {
|
||||||
"properties": {},
|
"properties": {},
|
||||||
"required": [],
|
|
||||||
"title": "TestGraphCredentialsInputSchema",
|
"title": "TestGraphCredentialsInputSchema",
|
||||||
"type": "object"
|
"type": "object"
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
"id": "test-agent-1",
|
"id": "test-agent-1",
|
||||||
"graph_id": "test-agent-1",
|
"graph_id": "test-agent-1",
|
||||||
"graph_version": 1,
|
"graph_version": 1,
|
||||||
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
|
||||||
"image_url": null,
|
"image_url": null,
|
||||||
"creator_name": "Test Creator",
|
"creator_name": "Test Creator",
|
||||||
"creator_image_url": "",
|
"creator_image_url": "",
|
||||||
@@ -42,7 +41,6 @@
|
|||||||
"id": "test-agent-2",
|
"id": "test-agent-2",
|
||||||
"graph_id": "test-agent-2",
|
"graph_id": "test-agent-2",
|
||||||
"graph_version": 1,
|
"graph_version": 1,
|
||||||
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
|
||||||
"image_url": null,
|
"image_url": null,
|
||||||
"creator_name": "Test Creator",
|
"creator_name": "Test Creator",
|
||||||
"creator_image_url": "",
|
"creator_image_url": "",
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
{
|
{
|
||||||
"submissions": [
|
"submissions": [
|
||||||
{
|
{
|
||||||
"listing_id": "test-listing-id",
|
|
||||||
"agent_id": "test-agent-id",
|
"agent_id": "test-agent-id",
|
||||||
"agent_version": 1,
|
"agent_version": 1,
|
||||||
"name": "Test Agent",
|
"name": "Test Agent",
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ services:
|
|||||||
context: ../
|
context: ../
|
||||||
dockerfile: autogpt_platform/backend/Dockerfile
|
dockerfile: autogpt_platform/backend/Dockerfile
|
||||||
target: migrate
|
target: migrate
|
||||||
command: ["sh", "-c", "poetry run prisma generate && poetry run gen-prisma-stub && poetry run prisma migrate deploy"]
|
command: ["sh", "-c", "poetry run prisma generate && poetry run prisma migrate deploy"]
|
||||||
develop:
|
develop:
|
||||||
watch:
|
watch:
|
||||||
- path: ./
|
- path: ./
|
||||||
|
|||||||
@@ -92,6 +92,7 @@
|
|||||||
"react-currency-input-field": "4.0.3",
|
"react-currency-input-field": "4.0.3",
|
||||||
"react-day-picker": "9.11.1",
|
"react-day-picker": "9.11.1",
|
||||||
"react-dom": "18.3.1",
|
"react-dom": "18.3.1",
|
||||||
|
"react-drag-drop-files": "2.4.0",
|
||||||
"react-hook-form": "7.66.0",
|
"react-hook-form": "7.66.0",
|
||||||
"react-icons": "5.5.0",
|
"react-icons": "5.5.0",
|
||||||
"react-markdown": "9.0.3",
|
"react-markdown": "9.0.3",
|
||||||
|
|||||||
112
autogpt_platform/frontend/pnpm-lock.yaml
generated
112
autogpt_platform/frontend/pnpm-lock.yaml
generated
@@ -200,6 +200,9 @@ importers:
|
|||||||
react-dom:
|
react-dom:
|
||||||
specifier: 18.3.1
|
specifier: 18.3.1
|
||||||
version: 18.3.1(react@18.3.1)
|
version: 18.3.1(react@18.3.1)
|
||||||
|
react-drag-drop-files:
|
||||||
|
specifier: 2.4.0
|
||||||
|
version: 2.4.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||||
react-hook-form:
|
react-hook-form:
|
||||||
specifier: 7.66.0
|
specifier: 7.66.0
|
||||||
version: 7.66.0(react@18.3.1)
|
version: 7.66.0(react@18.3.1)
|
||||||
@@ -1001,6 +1004,9 @@ packages:
|
|||||||
'@emotion/memoize@0.8.1':
|
'@emotion/memoize@0.8.1':
|
||||||
resolution: {integrity: sha512-W2P2c/VRW1/1tLox0mVUalvnWXxavmv/Oum2aPsRcoDJuob75FC3Y8FbpfLwUegRcxINtGUMPq0tFCvYNTBXNA==}
|
resolution: {integrity: sha512-W2P2c/VRW1/1tLox0mVUalvnWXxavmv/Oum2aPsRcoDJuob75FC3Y8FbpfLwUegRcxINtGUMPq0tFCvYNTBXNA==}
|
||||||
|
|
||||||
|
'@emotion/unitless@0.8.1':
|
||||||
|
resolution: {integrity: sha512-KOEGMu6dmJZtpadb476IsZBclKvILjopjUii3V+7MnXIQCYh8W3NgNcgwo21n9LXZX6EDIKvqfjYxXebDwxKmQ==}
|
||||||
|
|
||||||
'@epic-web/invariant@1.0.0':
|
'@epic-web/invariant@1.0.0':
|
||||||
resolution: {integrity: sha512-lrTPqgvfFQtR/eY/qkIzp98OGdNJu0m5ji3q/nJI8v3SXkRKEnWiOxMmbvcSoAIzv/cGiuvRy57k4suKQSAdwA==}
|
resolution: {integrity: sha512-lrTPqgvfFQtR/eY/qkIzp98OGdNJu0m5ji3q/nJI8v3SXkRKEnWiOxMmbvcSoAIzv/cGiuvRy57k4suKQSAdwA==}
|
||||||
|
|
||||||
@@ -3116,6 +3122,9 @@ packages:
|
|||||||
'@types/statuses@2.0.6':
|
'@types/statuses@2.0.6':
|
||||||
resolution: {integrity: sha512-xMAgYwceFhRA2zY+XbEA7mxYbA093wdiW8Vu6gZPGWy9cmOyU9XesH1tNcEWsKFd5Vzrqx5T3D38PWx1FIIXkA==}
|
resolution: {integrity: sha512-xMAgYwceFhRA2zY+XbEA7mxYbA093wdiW8Vu6gZPGWy9cmOyU9XesH1tNcEWsKFd5Vzrqx5T3D38PWx1FIIXkA==}
|
||||||
|
|
||||||
|
'@types/stylis@4.2.7':
|
||||||
|
resolution: {integrity: sha512-VgDNokpBoKF+wrdvhAAfS55OMQpL6QRglwTwNC3kIgBrzZxA4WsFj+2eLfEA/uMUDzBcEhYmjSbwQakn/i3ajA==}
|
||||||
|
|
||||||
'@types/tedious@4.0.14':
|
'@types/tedious@4.0.14':
|
||||||
resolution: {integrity: sha512-KHPsfX/FoVbUGbyYvk1q9MMQHLPeRZhRJZdO45Q4YjvFkv4hMNghCWTvy7rdKessBsmtz4euWCWAB6/tVpI1Iw==}
|
resolution: {integrity: sha512-KHPsfX/FoVbUGbyYvk1q9MMQHLPeRZhRJZdO45Q4YjvFkv4hMNghCWTvy7rdKessBsmtz4euWCWAB6/tVpI1Iw==}
|
||||||
|
|
||||||
@@ -3772,6 +3781,9 @@ packages:
|
|||||||
resolution: {integrity: sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA==}
|
resolution: {integrity: sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA==}
|
||||||
engines: {node: '>= 6'}
|
engines: {node: '>= 6'}
|
||||||
|
|
||||||
|
camelize@1.0.1:
|
||||||
|
resolution: {integrity: sha512-dU+Tx2fsypxTgtLoE36npi3UqcjSSMNYfkqgmoEhtZrraP5VWq0K7FkWVTYa8eMPtnU/G2txVsfdCJTn9uzpuQ==}
|
||||||
|
|
||||||
caniuse-lite@1.0.30001762:
|
caniuse-lite@1.0.30001762:
|
||||||
resolution: {integrity: sha512-PxZwGNvH7Ak8WX5iXzoK1KPZttBXNPuaOvI2ZYU7NrlM+d9Ov+TUvlLOBNGzVXAntMSMMlJPd+jY6ovrVjSmUw==}
|
resolution: {integrity: sha512-PxZwGNvH7Ak8WX5iXzoK1KPZttBXNPuaOvI2ZYU7NrlM+d9Ov+TUvlLOBNGzVXAntMSMMlJPd+jY6ovrVjSmUw==}
|
||||||
|
|
||||||
@@ -3985,6 +3997,10 @@ packages:
|
|||||||
resolution: {integrity: sha512-r4ESw/IlusD17lgQi1O20Fa3qNnsckR126TdUuBgAu7GBYSIPvdNyONd3Zrxh0xCwA4+6w/TDArBPsMvhur+KQ==}
|
resolution: {integrity: sha512-r4ESw/IlusD17lgQi1O20Fa3qNnsckR126TdUuBgAu7GBYSIPvdNyONd3Zrxh0xCwA4+6w/TDArBPsMvhur+KQ==}
|
||||||
engines: {node: '>= 0.10'}
|
engines: {node: '>= 0.10'}
|
||||||
|
|
||||||
|
css-color-keywords@1.0.0:
|
||||||
|
resolution: {integrity: sha512-FyyrDHZKEjXDpNJYvVsV960FiqQyXc/LlYmsxl2BcdMb2WPx0OGRVgTg55rPSyLSNMqP52R9r8geSp7apN3Ofg==}
|
||||||
|
engines: {node: '>=4'}
|
||||||
|
|
||||||
css-loader@6.11.0:
|
css-loader@6.11.0:
|
||||||
resolution: {integrity: sha512-CTJ+AEQJjq5NzLga5pE39qdiSV56F8ywCIsqNIRF0r7BDgWsN25aazToqAFg7ZrtA/U016xudB3ffgweORxX7g==}
|
resolution: {integrity: sha512-CTJ+AEQJjq5NzLga5pE39qdiSV56F8ywCIsqNIRF0r7BDgWsN25aazToqAFg7ZrtA/U016xudB3ffgweORxX7g==}
|
||||||
engines: {node: '>= 12.13.0'}
|
engines: {node: '>= 12.13.0'}
|
||||||
@@ -4000,6 +4016,9 @@ packages:
|
|||||||
css-select@4.3.0:
|
css-select@4.3.0:
|
||||||
resolution: {integrity: sha512-wPpOYtnsVontu2mODhA19JrqWxNsfdatRKd64kmpRbQgh1KtItko5sTnEpPdpSaJszTOhEMlF/RPz28qj4HqhQ==}
|
resolution: {integrity: sha512-wPpOYtnsVontu2mODhA19JrqWxNsfdatRKd64kmpRbQgh1KtItko5sTnEpPdpSaJszTOhEMlF/RPz28qj4HqhQ==}
|
||||||
|
|
||||||
|
css-to-react-native@3.2.0:
|
||||||
|
resolution: {integrity: sha512-e8RKaLXMOFii+02mOlqwjbD00KSEKqblnpO9e++1aXS1fPQOpS1YoqdVHBqPjHNoxeF2mimzVqawm2KCbEdtHQ==}
|
||||||
|
|
||||||
css-what@6.2.2:
|
css-what@6.2.2:
|
||||||
resolution: {integrity: sha512-u/O3vwbptzhMs3L1fQE82ZSLHQQfto5gyZzwteVIEyeaY5Fc7R4dapF/BvRoSYFeqfBk4m0V1Vafq5Pjv25wvA==}
|
resolution: {integrity: sha512-u/O3vwbptzhMs3L1fQE82ZSLHQQfto5gyZzwteVIEyeaY5Fc7R4dapF/BvRoSYFeqfBk4m0V1Vafq5Pjv25wvA==}
|
||||||
engines: {node: '>= 6'}
|
engines: {node: '>= 6'}
|
||||||
@@ -6112,6 +6131,10 @@ packages:
|
|||||||
resolution: {integrity: sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==}
|
resolution: {integrity: sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==}
|
||||||
engines: {node: ^10 || ^12 || >=14}
|
engines: {node: ^10 || ^12 || >=14}
|
||||||
|
|
||||||
|
postcss@8.4.49:
|
||||||
|
resolution: {integrity: sha512-OCVPnIObs4N29kxTjzLfUryOkvZEq+pf8jTF0lg8E7uETuWHA+v7j3c/xJmiqpX450191LlmZfUKkXxkTry7nA==}
|
||||||
|
engines: {node: ^10 || ^12 || >=14}
|
||||||
|
|
||||||
postcss@8.5.6:
|
postcss@8.5.6:
|
||||||
resolution: {integrity: sha512-3Ybi1tAuwAP9s0r1UQ2J4n5Y0G05bJkpUIO0/bI9MhwmD70S5aTWbXGBwxHrelT+XM1k6dM0pk+SwNkpTRN7Pg==}
|
resolution: {integrity: sha512-3Ybi1tAuwAP9s0r1UQ2J4n5Y0G05bJkpUIO0/bI9MhwmD70S5aTWbXGBwxHrelT+XM1k6dM0pk+SwNkpTRN7Pg==}
|
||||||
engines: {node: ^10 || ^12 || >=14}
|
engines: {node: ^10 || ^12 || >=14}
|
||||||
@@ -6283,6 +6306,12 @@ packages:
|
|||||||
peerDependencies:
|
peerDependencies:
|
||||||
react: ^18.3.1
|
react: ^18.3.1
|
||||||
|
|
||||||
|
react-drag-drop-files@2.4.0:
|
||||||
|
resolution: {integrity: sha512-MGPV3HVVnwXEXq3gQfLtSU3jz5j5jrabvGedokpiSEMoONrDHgYl/NpIOlfsqGQ4zBv1bzzv7qbKURZNOX32PA==}
|
||||||
|
peerDependencies:
|
||||||
|
react: ^18.0.0
|
||||||
|
react-dom: ^18.0.0
|
||||||
|
|
||||||
react-hook-form@7.66.0:
|
react-hook-form@7.66.0:
|
||||||
resolution: {integrity: sha512-xXBqsWGKrY46ZqaHDo+ZUYiMUgi8suYu5kdrS20EG8KiL7VRQitEbNjm+UcrDYrNi1YLyfpmAeGjCZYXLT9YBw==}
|
resolution: {integrity: sha512-xXBqsWGKrY46ZqaHDo+ZUYiMUgi8suYu5kdrS20EG8KiL7VRQitEbNjm+UcrDYrNi1YLyfpmAeGjCZYXLT9YBw==}
|
||||||
engines: {node: '>=18.0.0'}
|
engines: {node: '>=18.0.0'}
|
||||||
@@ -6649,6 +6678,9 @@ packages:
|
|||||||
engines: {node: '>= 0.10'}
|
engines: {node: '>= 0.10'}
|
||||||
hasBin: true
|
hasBin: true
|
||||||
|
|
||||||
|
shallowequal@1.1.0:
|
||||||
|
resolution: {integrity: sha512-y0m1JoUZSlPAjXVtPPW70aZWfIL/dSP7AFkRnniLCrK/8MDKog3TySTBmckD+RObVxH0v4Tox67+F14PdED2oQ==}
|
||||||
|
|
||||||
sharp@0.34.5:
|
sharp@0.34.5:
|
||||||
resolution: {integrity: sha512-Ou9I5Ft9WNcCbXrU9cMgPBcCK8LiwLqcbywW3t4oDV37n1pzpuNLsYiAV8eODnjbtQlSDwZ2cUEeQz4E54Hltg==}
|
resolution: {integrity: sha512-Ou9I5Ft9WNcCbXrU9cMgPBcCK8LiwLqcbywW3t4oDV37n1pzpuNLsYiAV8eODnjbtQlSDwZ2cUEeQz4E54Hltg==}
|
||||||
engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0}
|
engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0}
|
||||||
@@ -6862,6 +6894,13 @@ packages:
|
|||||||
style-to-object@1.0.14:
|
style-to-object@1.0.14:
|
||||||
resolution: {integrity: sha512-LIN7rULI0jBscWQYaSswptyderlarFkjQ+t79nzty8tcIAceVomEVlLzH5VP4Cmsv6MtKhs7qaAiwlcp+Mgaxw==}
|
resolution: {integrity: sha512-LIN7rULI0jBscWQYaSswptyderlarFkjQ+t79nzty8tcIAceVomEVlLzH5VP4Cmsv6MtKhs7qaAiwlcp+Mgaxw==}
|
||||||
|
|
||||||
|
styled-components@6.2.0:
|
||||||
|
resolution: {integrity: sha512-ryFCkETE++8jlrBmC+BoGPUN96ld1/Yp0s7t5bcXDobrs4XoXroY1tN+JbFi09hV6a5h3MzbcVi8/BGDP0eCgQ==}
|
||||||
|
engines: {node: '>= 16'}
|
||||||
|
peerDependencies:
|
||||||
|
react: '>= 16.8.0'
|
||||||
|
react-dom: '>= 16.8.0'
|
||||||
|
|
||||||
styled-jsx@5.1.6:
|
styled-jsx@5.1.6:
|
||||||
resolution: {integrity: sha512-qSVyDTeMotdvQYoHWLNGwRFJHC+i+ZvdBRYosOFgC+Wg1vx4frN2/RG/NA7SYqqvKNLf39P2LSRA2pu6n0XYZA==}
|
resolution: {integrity: sha512-qSVyDTeMotdvQYoHWLNGwRFJHC+i+ZvdBRYosOFgC+Wg1vx4frN2/RG/NA7SYqqvKNLf39P2LSRA2pu6n0XYZA==}
|
||||||
engines: {node: '>= 12.0.0'}
|
engines: {node: '>= 12.0.0'}
|
||||||
@@ -6888,6 +6927,9 @@ packages:
|
|||||||
babel-plugin-macros:
|
babel-plugin-macros:
|
||||||
optional: true
|
optional: true
|
||||||
|
|
||||||
|
stylis@4.3.6:
|
||||||
|
resolution: {integrity: sha512-yQ3rwFWRfwNUY7H5vpU0wfdkNSnvnJinhF9830Swlaxl03zsOjCfmX0ugac+3LtK0lYSgwL/KXc8oYL3mG4YFQ==}
|
||||||
|
|
||||||
sucrase@3.35.1:
|
sucrase@3.35.1:
|
||||||
resolution: {integrity: sha512-DhuTmvZWux4H1UOnWMB3sk0sbaCVOoQZjv8u1rDoTV0HTdGem9hkAZtl4JZy8P2z4Bg0nT+YMeOFyVr4zcG5Tw==}
|
resolution: {integrity: sha512-DhuTmvZWux4H1UOnWMB3sk0sbaCVOoQZjv8u1rDoTV0HTdGem9hkAZtl4JZy8P2z4Bg0nT+YMeOFyVr4zcG5Tw==}
|
||||||
engines: {node: '>=16 || 14 >=14.17'}
|
engines: {node: '>=16 || 14 >=14.17'}
|
||||||
@@ -7054,6 +7096,9 @@ packages:
|
|||||||
tslib@1.14.1:
|
tslib@1.14.1:
|
||||||
resolution: {integrity: sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==}
|
resolution: {integrity: sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==}
|
||||||
|
|
||||||
|
tslib@2.6.2:
|
||||||
|
resolution: {integrity: sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==}
|
||||||
|
|
||||||
tslib@2.8.1:
|
tslib@2.8.1:
|
||||||
resolution: {integrity: sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==}
|
resolution: {integrity: sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==}
|
||||||
|
|
||||||
@@ -8290,10 +8335,10 @@ snapshots:
|
|||||||
'@emotion/is-prop-valid@1.2.2':
|
'@emotion/is-prop-valid@1.2.2':
|
||||||
dependencies:
|
dependencies:
|
||||||
'@emotion/memoize': 0.8.1
|
'@emotion/memoize': 0.8.1
|
||||||
optional: true
|
|
||||||
|
|
||||||
'@emotion/memoize@0.8.1':
|
'@emotion/memoize@0.8.1': {}
|
||||||
optional: true
|
|
||||||
|
'@emotion/unitless@0.8.1': {}
|
||||||
|
|
||||||
'@epic-web/invariant@1.0.0': {}
|
'@epic-web/invariant@1.0.0': {}
|
||||||
|
|
||||||
@@ -10689,6 +10734,8 @@ snapshots:
|
|||||||
|
|
||||||
'@types/statuses@2.0.6': {}
|
'@types/statuses@2.0.6': {}
|
||||||
|
|
||||||
|
'@types/stylis@4.2.7': {}
|
||||||
|
|
||||||
'@types/tedious@4.0.14':
|
'@types/tedious@4.0.14':
|
||||||
dependencies:
|
dependencies:
|
||||||
'@types/node': 24.10.0
|
'@types/node': 24.10.0
|
||||||
@@ -11385,6 +11432,8 @@ snapshots:
|
|||||||
|
|
||||||
camelcase-css@2.0.1: {}
|
camelcase-css@2.0.1: {}
|
||||||
|
|
||||||
|
camelize@1.0.1: {}
|
||||||
|
|
||||||
caniuse-lite@1.0.30001762: {}
|
caniuse-lite@1.0.30001762: {}
|
||||||
|
|
||||||
case-sensitive-paths-webpack-plugin@2.4.0: {}
|
case-sensitive-paths-webpack-plugin@2.4.0: {}
|
||||||
@@ -11596,6 +11645,8 @@ snapshots:
|
|||||||
randombytes: 2.1.0
|
randombytes: 2.1.0
|
||||||
randomfill: 1.0.4
|
randomfill: 1.0.4
|
||||||
|
|
||||||
|
css-color-keywords@1.0.0: {}
|
||||||
|
|
||||||
css-loader@6.11.0(webpack@5.104.1(esbuild@0.25.12)):
|
css-loader@6.11.0(webpack@5.104.1(esbuild@0.25.12)):
|
||||||
dependencies:
|
dependencies:
|
||||||
icss-utils: 5.1.0(postcss@8.5.6)
|
icss-utils: 5.1.0(postcss@8.5.6)
|
||||||
@@ -11617,6 +11668,12 @@ snapshots:
|
|||||||
domutils: 2.8.0
|
domutils: 2.8.0
|
||||||
nth-check: 2.1.1
|
nth-check: 2.1.1
|
||||||
|
|
||||||
|
css-to-react-native@3.2.0:
|
||||||
|
dependencies:
|
||||||
|
camelize: 1.0.1
|
||||||
|
css-color-keywords: 1.0.0
|
||||||
|
postcss-value-parser: 4.2.0
|
||||||
|
|
||||||
css-what@6.2.2: {}
|
css-what@6.2.2: {}
|
||||||
|
|
||||||
css.escape@1.5.1: {}
|
css.escape@1.5.1: {}
|
||||||
@@ -12070,8 +12127,8 @@ snapshots:
|
|||||||
'@typescript-eslint/parser': 8.52.0(eslint@8.57.1)(typescript@5.9.3)
|
'@typescript-eslint/parser': 8.52.0(eslint@8.57.1)(typescript@5.9.3)
|
||||||
eslint: 8.57.1
|
eslint: 8.57.1
|
||||||
eslint-import-resolver-node: 0.3.9
|
eslint-import-resolver-node: 0.3.9
|
||||||
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1)
|
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1)
|
||||||
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
|
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
|
||||||
eslint-plugin-jsx-a11y: 6.10.2(eslint@8.57.1)
|
eslint-plugin-jsx-a11y: 6.10.2(eslint@8.57.1)
|
||||||
eslint-plugin-react: 7.37.5(eslint@8.57.1)
|
eslint-plugin-react: 7.37.5(eslint@8.57.1)
|
||||||
eslint-plugin-react-hooks: 5.2.0(eslint@8.57.1)
|
eslint-plugin-react-hooks: 5.2.0(eslint@8.57.1)
|
||||||
@@ -12090,7 +12147,7 @@ snapshots:
|
|||||||
transitivePeerDependencies:
|
transitivePeerDependencies:
|
||||||
- supports-color
|
- supports-color
|
||||||
|
|
||||||
eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1):
|
eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1):
|
||||||
dependencies:
|
dependencies:
|
||||||
'@nolyfill/is-core-module': 1.0.39
|
'@nolyfill/is-core-module': 1.0.39
|
||||||
debug: 4.4.3
|
debug: 4.4.3
|
||||||
@@ -12101,22 +12158,22 @@ snapshots:
|
|||||||
tinyglobby: 0.2.15
|
tinyglobby: 0.2.15
|
||||||
unrs-resolver: 1.11.1
|
unrs-resolver: 1.11.1
|
||||||
optionalDependencies:
|
optionalDependencies:
|
||||||
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
|
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
|
||||||
transitivePeerDependencies:
|
transitivePeerDependencies:
|
||||||
- supports-color
|
- supports-color
|
||||||
|
|
||||||
eslint-module-utils@2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1):
|
eslint-module-utils@2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1):
|
||||||
dependencies:
|
dependencies:
|
||||||
debug: 3.2.7
|
debug: 3.2.7
|
||||||
optionalDependencies:
|
optionalDependencies:
|
||||||
'@typescript-eslint/parser': 8.52.0(eslint@8.57.1)(typescript@5.9.3)
|
'@typescript-eslint/parser': 8.52.0(eslint@8.57.1)(typescript@5.9.3)
|
||||||
eslint: 8.57.1
|
eslint: 8.57.1
|
||||||
eslint-import-resolver-node: 0.3.9
|
eslint-import-resolver-node: 0.3.9
|
||||||
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1)
|
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1)
|
||||||
transitivePeerDependencies:
|
transitivePeerDependencies:
|
||||||
- supports-color
|
- supports-color
|
||||||
|
|
||||||
eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1):
|
eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1):
|
||||||
dependencies:
|
dependencies:
|
||||||
'@rtsao/scc': 1.1.0
|
'@rtsao/scc': 1.1.0
|
||||||
array-includes: 3.1.9
|
array-includes: 3.1.9
|
||||||
@@ -12127,7 +12184,7 @@ snapshots:
|
|||||||
doctrine: 2.1.0
|
doctrine: 2.1.0
|
||||||
eslint: 8.57.1
|
eslint: 8.57.1
|
||||||
eslint-import-resolver-node: 0.3.9
|
eslint-import-resolver-node: 0.3.9
|
||||||
eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
|
eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
|
||||||
hasown: 2.0.2
|
hasown: 2.0.2
|
||||||
is-core-module: 2.16.1
|
is-core-module: 2.16.1
|
||||||
is-glob: 4.0.3
|
is-glob: 4.0.3
|
||||||
@@ -14202,6 +14259,12 @@ snapshots:
|
|||||||
picocolors: 1.1.1
|
picocolors: 1.1.1
|
||||||
source-map-js: 1.2.1
|
source-map-js: 1.2.1
|
||||||
|
|
||||||
|
postcss@8.4.49:
|
||||||
|
dependencies:
|
||||||
|
nanoid: 3.3.11
|
||||||
|
picocolors: 1.1.1
|
||||||
|
source-map-js: 1.2.1
|
||||||
|
|
||||||
postcss@8.5.6:
|
postcss@8.5.6:
|
||||||
dependencies:
|
dependencies:
|
||||||
nanoid: 3.3.11
|
nanoid: 3.3.11
|
||||||
@@ -14323,6 +14386,13 @@ snapshots:
|
|||||||
react: 18.3.1
|
react: 18.3.1
|
||||||
scheduler: 0.23.2
|
scheduler: 0.23.2
|
||||||
|
|
||||||
|
react-drag-drop-files@2.4.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
|
||||||
|
dependencies:
|
||||||
|
prop-types: 15.8.1
|
||||||
|
react: 18.3.1
|
||||||
|
react-dom: 18.3.1(react@18.3.1)
|
||||||
|
styled-components: 6.2.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||||
|
|
||||||
react-hook-form@7.66.0(react@18.3.1):
|
react-hook-form@7.66.0(react@18.3.1):
|
||||||
dependencies:
|
dependencies:
|
||||||
react: 18.3.1
|
react: 18.3.1
|
||||||
@@ -14816,6 +14886,8 @@ snapshots:
|
|||||||
safe-buffer: 5.2.1
|
safe-buffer: 5.2.1
|
||||||
to-buffer: 1.2.2
|
to-buffer: 1.2.2
|
||||||
|
|
||||||
|
shallowequal@1.1.0: {}
|
||||||
|
|
||||||
sharp@0.34.5:
|
sharp@0.34.5:
|
||||||
dependencies:
|
dependencies:
|
||||||
'@img/colour': 1.0.0
|
'@img/colour': 1.0.0
|
||||||
@@ -15106,6 +15178,20 @@ snapshots:
|
|||||||
dependencies:
|
dependencies:
|
||||||
inline-style-parser: 0.2.7
|
inline-style-parser: 0.2.7
|
||||||
|
|
||||||
|
styled-components@6.2.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
|
||||||
|
dependencies:
|
||||||
|
'@emotion/is-prop-valid': 1.2.2
|
||||||
|
'@emotion/unitless': 0.8.1
|
||||||
|
'@types/stylis': 4.2.7
|
||||||
|
css-to-react-native: 3.2.0
|
||||||
|
csstype: 3.2.3
|
||||||
|
postcss: 8.4.49
|
||||||
|
react: 18.3.1
|
||||||
|
react-dom: 18.3.1(react@18.3.1)
|
||||||
|
shallowequal: 1.1.0
|
||||||
|
stylis: 4.3.6
|
||||||
|
tslib: 2.6.2
|
||||||
|
|
||||||
styled-jsx@5.1.6(@babel/core@7.28.5)(react@18.3.1):
|
styled-jsx@5.1.6(@babel/core@7.28.5)(react@18.3.1):
|
||||||
dependencies:
|
dependencies:
|
||||||
client-only: 0.0.1
|
client-only: 0.0.1
|
||||||
@@ -15120,6 +15206,8 @@ snapshots:
|
|||||||
optionalDependencies:
|
optionalDependencies:
|
||||||
'@babel/core': 7.28.5
|
'@babel/core': 7.28.5
|
||||||
|
|
||||||
|
stylis@4.3.6: {}
|
||||||
|
|
||||||
sucrase@3.35.1:
|
sucrase@3.35.1:
|
||||||
dependencies:
|
dependencies:
|
||||||
'@jridgewell/gen-mapping': 0.3.13
|
'@jridgewell/gen-mapping': 0.3.13
|
||||||
@@ -15302,6 +15390,8 @@ snapshots:
|
|||||||
|
|
||||||
tslib@1.14.1: {}
|
tslib@1.14.1: {}
|
||||||
|
|
||||||
|
tslib@2.6.2: {}
|
||||||
|
|
||||||
tslib@2.8.1: {}
|
tslib@2.8.1: {}
|
||||||
|
|
||||||
tty-browserify@0.0.1: {}
|
tty-browserify@0.0.1: {}
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 2.6 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 16 KiB |
@@ -66,7 +66,6 @@ export const RunInputDialog = ({
|
|||||||
formContext={{
|
formContext={{
|
||||||
showHandles: false,
|
showHandles: false,
|
||||||
size: "large",
|
size: "large",
|
||||||
showOptionalToggle: false,
|
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
@@ -81,16 +80,18 @@ export const RunInputDialog = ({
|
|||||||
Inputs
|
Inputs
|
||||||
</Text>
|
</Text>
|
||||||
</div>
|
</div>
|
||||||
<FormRenderer
|
<div className="px-2">
|
||||||
jsonSchema={inputSchema as RJSFSchema}
|
<FormRenderer
|
||||||
handleChange={(v) => handleInputChange(v.formData)}
|
jsonSchema={inputSchema as RJSFSchema}
|
||||||
uiSchema={uiSchema}
|
handleChange={(v) => handleInputChange(v.formData)}
|
||||||
initialValues={{}}
|
uiSchema={uiSchema}
|
||||||
formContext={{
|
initialValues={{}}
|
||||||
showHandles: false,
|
formContext={{
|
||||||
size: "large",
|
showHandles: false,
|
||||||
}}
|
size: "large",
|
||||||
/>
|
}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ export const useRunInputDialog = ({
|
|||||||
if (isCredentialFieldSchema(fieldSchema)) {
|
if (isCredentialFieldSchema(fieldSchema)) {
|
||||||
dynamicUiSchema[fieldName] = {
|
dynamicUiSchema[fieldName] = {
|
||||||
...dynamicUiSchema[fieldName],
|
...dynamicUiSchema[fieldName],
|
||||||
"ui:field": "custom/credential_field",
|
"ui:field": "credentials",
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -76,18 +76,12 @@ export const useRunInputDialog = ({
|
|||||||
}, [credentialsSchema]);
|
}, [credentialsSchema]);
|
||||||
|
|
||||||
const handleManualRun = async () => {
|
const handleManualRun = async () => {
|
||||||
// Filter out incomplete credentials (those without a valid id)
|
|
||||||
// RJSF auto-populates const values (provider, type) but not id field
|
|
||||||
const validCredentials = Object.fromEntries(
|
|
||||||
Object.entries(credentialValues).filter(([_, cred]) => cred && cred.id),
|
|
||||||
);
|
|
||||||
|
|
||||||
await executeGraph({
|
await executeGraph({
|
||||||
graphId: flowID ?? "",
|
graphId: flowID ?? "",
|
||||||
graphVersion: flowVersion || null,
|
graphVersion: flowVersion || null,
|
||||||
data: {
|
data: {
|
||||||
inputs: inputValues,
|
inputs: inputValues,
|
||||||
credentials_inputs: validCredentials,
|
credentials_inputs: credentialValues,
|
||||||
source: "builder",
|
source: "builder",
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import { useGetV2GetSpecificBlocks } from "@/app/api/__generated__/endpoints/def
|
|||||||
import {
|
import {
|
||||||
useGetV1GetExecutionDetails,
|
useGetV1GetExecutionDetails,
|
||||||
useGetV1GetSpecificGraph,
|
useGetV1GetSpecificGraph,
|
||||||
useGetV1ListUserGraphs,
|
|
||||||
} from "@/app/api/__generated__/endpoints/graphs/graphs";
|
} from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||||
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
|
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
|
||||||
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
|
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
|
||||||
@@ -18,7 +17,6 @@ import { useReactFlow } from "@xyflow/react";
|
|||||||
import { useControlPanelStore } from "../../../stores/controlPanelStore";
|
import { useControlPanelStore } from "../../../stores/controlPanelStore";
|
||||||
import { useHistoryStore } from "../../../stores/historyStore";
|
import { useHistoryStore } from "../../../stores/historyStore";
|
||||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||||
import { okData } from "@/app/api/helpers";
|
|
||||||
|
|
||||||
export const useFlow = () => {
|
export const useFlow = () => {
|
||||||
const [isLocked, setIsLocked] = useState(false);
|
const [isLocked, setIsLocked] = useState(false);
|
||||||
@@ -38,9 +36,6 @@ export const useFlow = () => {
|
|||||||
const setGraphExecutionStatus = useGraphStore(
|
const setGraphExecutionStatus = useGraphStore(
|
||||||
useShallow((state) => state.setGraphExecutionStatus),
|
useShallow((state) => state.setGraphExecutionStatus),
|
||||||
);
|
);
|
||||||
const setAvailableSubGraphs = useGraphStore(
|
|
||||||
useShallow((state) => state.setAvailableSubGraphs),
|
|
||||||
);
|
|
||||||
const updateEdgeBeads = useEdgeStore(
|
const updateEdgeBeads = useEdgeStore(
|
||||||
useShallow((state) => state.updateEdgeBeads),
|
useShallow((state) => state.updateEdgeBeads),
|
||||||
);
|
);
|
||||||
@@ -67,11 +62,6 @@ export const useFlow = () => {
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
// Fetch all available graphs for sub-agent update detection
|
|
||||||
const { data: availableGraphs } = useGetV1ListUserGraphs({
|
|
||||||
query: { select: okData },
|
|
||||||
});
|
|
||||||
|
|
||||||
const { data: graph, isLoading: isGraphLoading } = useGetV1GetSpecificGraph(
|
const { data: graph, isLoading: isGraphLoading } = useGetV1GetSpecificGraph(
|
||||||
flowID ?? "",
|
flowID ?? "",
|
||||||
flowVersion !== null ? { version: flowVersion } : {},
|
flowVersion !== null ? { version: flowVersion } : {},
|
||||||
@@ -126,18 +116,10 @@ export const useFlow = () => {
|
|||||||
}
|
}
|
||||||
}, [graph]);
|
}, [graph]);
|
||||||
|
|
||||||
// Update available sub-graphs in store for sub-agent update detection
|
|
||||||
useEffect(() => {
|
|
||||||
if (availableGraphs) {
|
|
||||||
setAvailableSubGraphs(availableGraphs);
|
|
||||||
}
|
|
||||||
}, [availableGraphs, setAvailableSubGraphs]);
|
|
||||||
|
|
||||||
// adding nodes
|
// adding nodes
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (customNodes.length > 0) {
|
if (customNodes.length > 0) {
|
||||||
useNodeStore.getState().setNodes([]);
|
useNodeStore.getState().setNodes([]);
|
||||||
useNodeStore.getState().clearResolutionState();
|
|
||||||
addNodes(customNodes);
|
addNodes(customNodes);
|
||||||
|
|
||||||
// Sync hardcoded values with handle IDs.
|
// Sync hardcoded values with handle IDs.
|
||||||
@@ -221,7 +203,6 @@ export const useFlow = () => {
|
|||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
return () => {
|
return () => {
|
||||||
useNodeStore.getState().setNodes([]);
|
useNodeStore.getState().setNodes([]);
|
||||||
useNodeStore.getState().clearResolutionState();
|
|
||||||
useEdgeStore.getState().setEdges([]);
|
useEdgeStore.getState().setEdges([]);
|
||||||
useGraphStore.getState().reset();
|
useGraphStore.getState().reset();
|
||||||
useEdgeStore.getState().resetEdgeBeads();
|
useEdgeStore.getState().resetEdgeBeads();
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import {
|
|||||||
getBezierPath,
|
getBezierPath,
|
||||||
} from "@xyflow/react";
|
} from "@xyflow/react";
|
||||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
|
||||||
import { XIcon } from "@phosphor-icons/react";
|
import { XIcon } from "@phosphor-icons/react";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { NodeExecutionResult } from "@/lib/autogpt-server-api";
|
import { NodeExecutionResult } from "@/lib/autogpt-server-api";
|
||||||
@@ -36,8 +35,6 @@ const CustomEdge = ({
|
|||||||
selected,
|
selected,
|
||||||
}: EdgeProps<CustomEdge>) => {
|
}: EdgeProps<CustomEdge>) => {
|
||||||
const removeConnection = useEdgeStore((state) => state.removeEdge);
|
const removeConnection = useEdgeStore((state) => state.removeEdge);
|
||||||
// Subscribe to the brokenEdgeIDs map and check if this edge is broken across any node
|
|
||||||
const isBroken = useNodeStore((state) => state.isEdgeBroken(id));
|
|
||||||
const [isHovered, setIsHovered] = useState(false);
|
const [isHovered, setIsHovered] = useState(false);
|
||||||
|
|
||||||
const [edgePath, labelX, labelY] = getBezierPath({
|
const [edgePath, labelX, labelY] = getBezierPath({
|
||||||
@@ -53,12 +50,6 @@ const CustomEdge = ({
|
|||||||
const beadUp = data?.beadUp ?? 0;
|
const beadUp = data?.beadUp ?? 0;
|
||||||
const beadDown = data?.beadDown ?? 0;
|
const beadDown = data?.beadDown ?? 0;
|
||||||
|
|
||||||
const handleRemoveEdge = () => {
|
|
||||||
removeConnection(id);
|
|
||||||
// Note: broken edge tracking is cleaned up automatically by useSubAgentUpdateState
|
|
||||||
// when it detects the edge no longer exists
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<BaseEdge
|
<BaseEdge
|
||||||
@@ -66,11 +57,9 @@ const CustomEdge = ({
|
|||||||
markerEnd={markerEnd}
|
markerEnd={markerEnd}
|
||||||
className={cn(
|
className={cn(
|
||||||
isStatic && "!stroke-[1.5px] [stroke-dasharray:6]",
|
isStatic && "!stroke-[1.5px] [stroke-dasharray:6]",
|
||||||
isBroken
|
selected
|
||||||
? "!stroke-red-500 !stroke-[2px] [stroke-dasharray:4]"
|
? "stroke-zinc-800"
|
||||||
: selected
|
: "stroke-zinc-500/50 hover:stroke-zinc-500",
|
||||||
? "stroke-zinc-800"
|
|
||||||
: "stroke-zinc-500/50 hover:stroke-zinc-500",
|
|
||||||
)}
|
)}
|
||||||
/>
|
/>
|
||||||
<JSBeads
|
<JSBeads
|
||||||
@@ -81,16 +70,12 @@ const CustomEdge = ({
|
|||||||
/>
|
/>
|
||||||
<EdgeLabelRenderer>
|
<EdgeLabelRenderer>
|
||||||
<Button
|
<Button
|
||||||
onClick={handleRemoveEdge}
|
onClick={() => removeConnection(id)}
|
||||||
className={cn(
|
className={cn(
|
||||||
"absolute h-fit min-w-0 p-1 transition-opacity",
|
"absolute h-fit min-w-0 p-1 transition-opacity",
|
||||||
isBroken
|
isHovered ? "opacity-100" : "opacity-0",
|
||||||
? "bg-red-500 opacity-100 hover:bg-red-600"
|
|
||||||
: isHovered
|
|
||||||
? "opacity-100"
|
|
||||||
: "opacity-0",
|
|
||||||
)}
|
)}
|
||||||
variant={isBroken ? "primary" : "secondary"}
|
variant="secondary"
|
||||||
style={{
|
style={{
|
||||||
transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
|
transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
|
||||||
pointerEvents: "all",
|
pointerEvents: "all",
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import { Handle, Position } from "@xyflow/react";
|
|||||||
import { useEdgeStore } from "../../../stores/edgeStore";
|
import { useEdgeStore } from "../../../stores/edgeStore";
|
||||||
import { cleanUpHandleId } from "@/components/renderers/InputRenderer/helpers";
|
import { cleanUpHandleId } from "@/components/renderers/InputRenderer/helpers";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { useNodeStore } from "../../../stores/nodeStore";
|
|
||||||
|
|
||||||
const InputNodeHandle = ({
|
const InputNodeHandle = ({
|
||||||
handleId,
|
handleId,
|
||||||
@@ -16,9 +15,6 @@ const InputNodeHandle = ({
|
|||||||
const isInputConnected = useEdgeStore((state) =>
|
const isInputConnected = useEdgeStore((state) =>
|
||||||
state.isInputConnected(nodeId ?? "", cleanedHandleId),
|
state.isInputConnected(nodeId ?? "", cleanedHandleId),
|
||||||
);
|
);
|
||||||
const isInputBroken = useNodeStore((state) =>
|
|
||||||
state.isInputBroken(nodeId, cleanedHandleId),
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Handle
|
<Handle
|
||||||
@@ -31,10 +27,7 @@ const InputNodeHandle = ({
|
|||||||
<CircleIcon
|
<CircleIcon
|
||||||
size={16}
|
size={16}
|
||||||
weight={isInputConnected ? "fill" : "duotone"}
|
weight={isInputConnected ? "fill" : "duotone"}
|
||||||
className={cn(
|
className={"text-gray-400 opacity-100"}
|
||||||
"text-gray-400 opacity-100",
|
|
||||||
isInputBroken && "text-red-500",
|
|
||||||
)}
|
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</Handle>
|
</Handle>
|
||||||
@@ -45,17 +38,14 @@ const OutputNodeHandle = ({
|
|||||||
field_name,
|
field_name,
|
||||||
nodeId,
|
nodeId,
|
||||||
hexColor,
|
hexColor,
|
||||||
isBroken,
|
|
||||||
}: {
|
}: {
|
||||||
field_name: string;
|
field_name: string;
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
hexColor: string;
|
hexColor: string;
|
||||||
isBroken: boolean;
|
|
||||||
}) => {
|
}) => {
|
||||||
const isOutputConnected = useEdgeStore((state) =>
|
const isOutputConnected = useEdgeStore((state) =>
|
||||||
state.isOutputConnected(nodeId, field_name),
|
state.isOutputConnected(nodeId, field_name),
|
||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Handle
|
<Handle
|
||||||
type={"source"}
|
type={"source"}
|
||||||
@@ -68,10 +58,7 @@ const OutputNodeHandle = ({
|
|||||||
size={16}
|
size={16}
|
||||||
weight={"duotone"}
|
weight={"duotone"}
|
||||||
color={isOutputConnected ? hexColor : "gray"}
|
color={isOutputConnected ? hexColor : "gray"}
|
||||||
className={cn(
|
className={cn("text-gray-400 opacity-100")}
|
||||||
"text-gray-400 opacity-100",
|
|
||||||
isBroken && "text-red-500",
|
|
||||||
)}
|
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</Handle>
|
</Handle>
|
||||||
|
|||||||
@@ -20,8 +20,6 @@ import { NodeDataRenderer } from "./components/NodeOutput/NodeOutput";
|
|||||||
import { NodeRightClickMenu } from "./components/NodeRightClickMenu";
|
import { NodeRightClickMenu } from "./components/NodeRightClickMenu";
|
||||||
import { StickyNoteBlock } from "./components/StickyNoteBlock";
|
import { StickyNoteBlock } from "./components/StickyNoteBlock";
|
||||||
import { WebhookDisclaimer } from "./components/WebhookDisclaimer";
|
import { WebhookDisclaimer } from "./components/WebhookDisclaimer";
|
||||||
import { SubAgentUpdateFeature } from "./components/SubAgentUpdate/SubAgentUpdateFeature";
|
|
||||||
import { useCustomNode } from "./useCustomNode";
|
|
||||||
|
|
||||||
export type CustomNodeData = {
|
export type CustomNodeData = {
|
||||||
hardcodedValues: {
|
hardcodedValues: {
|
||||||
@@ -47,10 +45,6 @@ export type CustomNode = XYNode<CustomNodeData, "custom">;
|
|||||||
|
|
||||||
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||||
({ data, id: nodeId, selected }) => {
|
({ data, id: nodeId, selected }) => {
|
||||||
const { inputSchema, outputSchema } = useCustomNode({ data, nodeId });
|
|
||||||
|
|
||||||
const isAgent = data.uiType === BlockUIType.AGENT;
|
|
||||||
|
|
||||||
if (data.uiType === BlockUIType.NOTE) {
|
if (data.uiType === BlockUIType.NOTE) {
|
||||||
return (
|
return (
|
||||||
<StickyNoteBlock data={data} selected={selected} nodeId={nodeId} />
|
<StickyNoteBlock data={data} selected={selected} nodeId={nodeId} />
|
||||||
@@ -69,6 +63,16 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
|||||||
|
|
||||||
const isAyrshare = data.uiType === BlockUIType.AYRSHARE;
|
const isAyrshare = data.uiType === BlockUIType.AYRSHARE;
|
||||||
|
|
||||||
|
const inputSchema =
|
||||||
|
data.uiType === BlockUIType.AGENT
|
||||||
|
? (data.hardcodedValues.input_schema ?? {})
|
||||||
|
: data.inputSchema;
|
||||||
|
|
||||||
|
const outputSchema =
|
||||||
|
data.uiType === BlockUIType.AGENT
|
||||||
|
? (data.hardcodedValues.output_schema ?? {})
|
||||||
|
: data.outputSchema;
|
||||||
|
|
||||||
const hasConfigErrors =
|
const hasConfigErrors =
|
||||||
data.errors &&
|
data.errors &&
|
||||||
Object.values(data.errors).some(
|
Object.values(data.errors).some(
|
||||||
@@ -83,11 +87,12 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
|||||||
|
|
||||||
const hasErrors = hasConfigErrors || hasOutputError;
|
const hasErrors = hasConfigErrors || hasOutputError;
|
||||||
|
|
||||||
|
// Currently all blockTypes design are similar - that's why i am using the same component for all of them
|
||||||
|
// If in future - if we need some drastic change in some blockTypes design - we can create separate components for them
|
||||||
const node = (
|
const node = (
|
||||||
<NodeContainer selected={selected} nodeId={nodeId} hasErrors={hasErrors}>
|
<NodeContainer selected={selected} nodeId={nodeId} hasErrors={hasErrors}>
|
||||||
<div className="rounded-xlarge bg-white">
|
<div className="rounded-xlarge bg-white">
|
||||||
<NodeHeader data={data} nodeId={nodeId} />
|
<NodeHeader data={data} nodeId={nodeId} />
|
||||||
{isAgent && <SubAgentUpdateFeature nodeID={nodeId} nodeData={data} />}
|
|
||||||
{isWebhook && <WebhookDisclaimer nodeId={nodeId} />}
|
{isWebhook && <WebhookDisclaimer nodeId={nodeId} />}
|
||||||
{isAyrshare && <AyrshareConnectButton />}
|
{isAyrshare && <AyrshareConnectButton />}
|
||||||
<FormCreator
|
<FormCreator
|
||||||
|
|||||||
@@ -68,10 +68,7 @@ export const NodeHeader = ({ data, nodeId }: Props) => {
|
|||||||
<Tooltip>
|
<Tooltip>
|
||||||
<TooltipTrigger asChild>
|
<TooltipTrigger asChild>
|
||||||
<div>
|
<div>
|
||||||
<Text
|
<Text variant="large-semibold" className="line-clamp-1">
|
||||||
variant="large-semibold"
|
|
||||||
className="line-clamp-1 hover:cursor-text"
|
|
||||||
>
|
|
||||||
{beautifyString(title).replace("Block", "").trim()}
|
{beautifyString(title).replace("Block", "").trim()}
|
||||||
</Text>
|
</Text>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ export const NodeDataViewer: FC<NodeDataViewerProps> = ({
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="flex justify-end pt-4">
|
<div className="flex justify-end pt-4">
|
||||||
{outputItems.length > 1 && (
|
{outputItems.length > 0 && (
|
||||||
<OutputActions
|
<OutputActions
|
||||||
items={outputItems.map((item) => ({
|
items={outputItems.map((item) => ({
|
||||||
value: item.value,
|
value: item.value,
|
||||||
|
|||||||
@@ -1,118 +0,0 @@
|
|||||||
import React from "react";
|
|
||||||
import { ArrowUpIcon, WarningIcon } from "@phosphor-icons/react";
|
|
||||||
import { Button } from "@/components/atoms/Button/Button";
|
|
||||||
import {
|
|
||||||
Tooltip,
|
|
||||||
TooltipContent,
|
|
||||||
TooltipTrigger,
|
|
||||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
|
||||||
import { cn, beautifyString } from "@/lib/utils";
|
|
||||||
import { CustomNodeData } from "../../CustomNode";
|
|
||||||
import { useSubAgentUpdateState } from "./useSubAgentUpdateState";
|
|
||||||
import { IncompatibleUpdateDialog } from "./components/IncompatibleUpdateDialog";
|
|
||||||
import { ResolutionModeBar } from "./components/ResolutionModeBar";
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Inline component for the update bar that can be placed after the header.
|
|
||||||
* Use this inside the node content where you want the bar to appear.
|
|
||||||
*/
|
|
||||||
type SubAgentUpdateFeatureProps = {
|
|
||||||
nodeID: string;
|
|
||||||
nodeData: CustomNodeData;
|
|
||||||
};
|
|
||||||
|
|
||||||
export function SubAgentUpdateFeature({
|
|
||||||
nodeID,
|
|
||||||
nodeData,
|
|
||||||
}: SubAgentUpdateFeatureProps) {
|
|
||||||
const {
|
|
||||||
updateInfo,
|
|
||||||
isInResolutionMode,
|
|
||||||
handleUpdateClick,
|
|
||||||
showIncompatibilityDialog,
|
|
||||||
setShowIncompatibilityDialog,
|
|
||||||
handleConfirmIncompatibleUpdate,
|
|
||||||
} = useSubAgentUpdateState({ nodeID: nodeID, nodeData: nodeData });
|
|
||||||
|
|
||||||
const agentName = nodeData.title || "Agent";
|
|
||||||
|
|
||||||
if (!updateInfo.hasUpdate && !isInResolutionMode) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<>
|
|
||||||
{isInResolutionMode ? (
|
|
||||||
<ResolutionModeBar incompatibilities={updateInfo.incompatibilities} />
|
|
||||||
) : (
|
|
||||||
<SubAgentUpdateAvailableBar
|
|
||||||
currentVersion={updateInfo.currentVersion}
|
|
||||||
latestVersion={updateInfo.latestVersion}
|
|
||||||
isCompatible={updateInfo.isCompatible}
|
|
||||||
onUpdate={handleUpdateClick}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{/* Incompatibility dialog - rendered here since this component owns the state */}
|
|
||||||
{updateInfo.incompatibilities && (
|
|
||||||
<IncompatibleUpdateDialog
|
|
||||||
isOpen={showIncompatibilityDialog}
|
|
||||||
onClose={() => setShowIncompatibilityDialog(false)}
|
|
||||||
onConfirm={handleConfirmIncompatibleUpdate}
|
|
||||||
currentVersion={updateInfo.currentVersion}
|
|
||||||
latestVersion={updateInfo.latestVersion}
|
|
||||||
agentName={beautifyString(agentName)}
|
|
||||||
incompatibilities={updateInfo.incompatibilities}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
type SubAgentUpdateAvailableBarProps = {
|
|
||||||
currentVersion: number;
|
|
||||||
latestVersion: number;
|
|
||||||
isCompatible: boolean;
|
|
||||||
onUpdate: () => void;
|
|
||||||
};
|
|
||||||
|
|
||||||
function SubAgentUpdateAvailableBar({
|
|
||||||
currentVersion,
|
|
||||||
latestVersion,
|
|
||||||
isCompatible,
|
|
||||||
onUpdate,
|
|
||||||
}: SubAgentUpdateAvailableBarProps): React.ReactElement {
|
|
||||||
return (
|
|
||||||
<div className="flex items-center justify-between gap-2 rounded-t-xl bg-blue-50 px-3 py-2 dark:bg-blue-900/30">
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
<ArrowUpIcon className="h-4 w-4 text-blue-600 dark:text-blue-400" />
|
|
||||||
<span className="text-sm text-blue-700 dark:text-blue-300">
|
|
||||||
Update available (v{currentVersion} → v{latestVersion})
|
|
||||||
</span>
|
|
||||||
{!isCompatible && (
|
|
||||||
<Tooltip>
|
|
||||||
<TooltipTrigger asChild>
|
|
||||||
<WarningIcon className="h-4 w-4 text-amber-500" />
|
|
||||||
</TooltipTrigger>
|
|
||||||
<TooltipContent className="max-w-xs">
|
|
||||||
<p className="font-medium">Incompatible changes detected</p>
|
|
||||||
<p className="text-xs text-gray-400">
|
|
||||||
Click Update to see details
|
|
||||||
</p>
|
|
||||||
</TooltipContent>
|
|
||||||
</Tooltip>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
<Button
|
|
||||||
size="small"
|
|
||||||
variant={isCompatible ? "primary" : "outline"}
|
|
||||||
onClick={onUpdate}
|
|
||||||
className={cn(
|
|
||||||
"h-7 text-xs",
|
|
||||||
!isCompatible && "border-amber-500 text-amber-600 hover:bg-amber-50",
|
|
||||||
)}
|
|
||||||
>
|
|
||||||
Update
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,274 +0,0 @@
|
|||||||
import React from "react";
|
|
||||||
import {
|
|
||||||
WarningIcon,
|
|
||||||
XCircleIcon,
|
|
||||||
PlusCircleIcon,
|
|
||||||
} from "@phosphor-icons/react";
|
|
||||||
import { Button } from "@/components/atoms/Button/Button";
|
|
||||||
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
|
|
||||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
|
||||||
import { beautifyString } from "@/lib/utils";
|
|
||||||
import { IncompatibilityInfo } from "@/app/(platform)/build/hooks/useSubAgentUpdate/types";
|
|
||||||
|
|
||||||
type IncompatibleUpdateDialogProps = {
|
|
||||||
isOpen: boolean;
|
|
||||||
onClose: () => void;
|
|
||||||
onConfirm: () => void;
|
|
||||||
currentVersion: number;
|
|
||||||
latestVersion: number;
|
|
||||||
agentName: string;
|
|
||||||
incompatibilities: IncompatibilityInfo;
|
|
||||||
};
|
|
||||||
|
|
||||||
export function IncompatibleUpdateDialog({
|
|
||||||
isOpen,
|
|
||||||
onClose,
|
|
||||||
onConfirm,
|
|
||||||
currentVersion,
|
|
||||||
latestVersion,
|
|
||||||
agentName,
|
|
||||||
incompatibilities,
|
|
||||||
}: IncompatibleUpdateDialogProps) {
|
|
||||||
const hasMissingInputs = incompatibilities.missingInputs.length > 0;
|
|
||||||
const hasMissingOutputs = incompatibilities.missingOutputs.length > 0;
|
|
||||||
const hasNewInputs = incompatibilities.newInputs.length > 0;
|
|
||||||
const hasNewOutputs = incompatibilities.newOutputs.length > 0;
|
|
||||||
const hasNewRequired = incompatibilities.newRequiredInputs.length > 0;
|
|
||||||
const hasTypeMismatches = incompatibilities.inputTypeMismatches.length > 0;
|
|
||||||
|
|
||||||
const hasInputChanges = hasMissingInputs || hasNewInputs;
|
|
||||||
const hasOutputChanges = hasMissingOutputs || hasNewOutputs;
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Dialog
|
|
||||||
title={
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
<WarningIcon className="h-5 w-5 text-amber-500" weight="fill" />
|
|
||||||
Incompatible Update
|
|
||||||
</div>
|
|
||||||
}
|
|
||||||
controlled={{
|
|
||||||
isOpen,
|
|
||||||
set: async (open) => {
|
|
||||||
if (!open) onClose();
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
onClose={onClose}
|
|
||||||
styling={{ maxWidth: "32rem" }}
|
|
||||||
>
|
|
||||||
<Dialog.Content>
|
|
||||||
<div className="space-y-4">
|
|
||||||
<p className="text-sm text-gray-600 dark:text-gray-400">
|
|
||||||
Updating <strong>{beautifyString(agentName)}</strong> from v
|
|
||||||
{currentVersion} to v{latestVersion} will break some connections.
|
|
||||||
</p>
|
|
||||||
|
|
||||||
{/* Input changes - two column layout */}
|
|
||||||
{hasInputChanges && (
|
|
||||||
<TwoColumnSection
|
|
||||||
title="Input Changes"
|
|
||||||
leftIcon={
|
|
||||||
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
|
|
||||||
}
|
|
||||||
leftTitle="Removed"
|
|
||||||
leftItems={incompatibilities.missingInputs}
|
|
||||||
rightIcon={
|
|
||||||
<PlusCircleIcon
|
|
||||||
className="h-4 w-4 text-green-500"
|
|
||||||
weight="fill"
|
|
||||||
/>
|
|
||||||
}
|
|
||||||
rightTitle="Added"
|
|
||||||
rightItems={incompatibilities.newInputs}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{/* Output changes - two column layout */}
|
|
||||||
{hasOutputChanges && (
|
|
||||||
<TwoColumnSection
|
|
||||||
title="Output Changes"
|
|
||||||
leftIcon={
|
|
||||||
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
|
|
||||||
}
|
|
||||||
leftTitle="Removed"
|
|
||||||
leftItems={incompatibilities.missingOutputs}
|
|
||||||
rightIcon={
|
|
||||||
<PlusCircleIcon
|
|
||||||
className="h-4 w-4 text-green-500"
|
|
||||||
weight="fill"
|
|
||||||
/>
|
|
||||||
}
|
|
||||||
rightTitle="Added"
|
|
||||||
rightItems={incompatibilities.newOutputs}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{hasTypeMismatches && (
|
|
||||||
<SingleColumnSection
|
|
||||||
icon={
|
|
||||||
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
|
|
||||||
}
|
|
||||||
title="Type Changed"
|
|
||||||
description="These connected inputs have a different type:"
|
|
||||||
items={incompatibilities.inputTypeMismatches.map(
|
|
||||||
(m) => `${m.name} (${m.oldType} → ${m.newType})`,
|
|
||||||
)}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{hasNewRequired && (
|
|
||||||
<SingleColumnSection
|
|
||||||
icon={
|
|
||||||
<PlusCircleIcon
|
|
||||||
className="h-4 w-4 text-amber-500"
|
|
||||||
weight="fill"
|
|
||||||
/>
|
|
||||||
}
|
|
||||||
title="New Required Inputs"
|
|
||||||
description="These inputs are now required:"
|
|
||||||
items={incompatibilities.newRequiredInputs}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<Alert variant="warning">
|
|
||||||
<AlertDescription>
|
|
||||||
If you proceed, you'll need to remove the broken connections
|
|
||||||
before you can save or run your agent.
|
|
||||||
</AlertDescription>
|
|
||||||
</Alert>
|
|
||||||
|
|
||||||
<Dialog.Footer>
|
|
||||||
<Button variant="ghost" size="small" onClick={onClose}>
|
|
||||||
Cancel
|
|
||||||
</Button>
|
|
||||||
<Button
|
|
||||||
variant="primary"
|
|
||||||
size="small"
|
|
||||||
onClick={onConfirm}
|
|
||||||
className="border-amber-700 bg-amber-600 hover:bg-amber-700"
|
|
||||||
>
|
|
||||||
Update Anyway
|
|
||||||
</Button>
|
|
||||||
</Dialog.Footer>
|
|
||||||
</div>
|
|
||||||
</Dialog.Content>
|
|
||||||
</Dialog>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
type TwoColumnSectionProps = {
|
|
||||||
title: string;
|
|
||||||
leftIcon: React.ReactNode;
|
|
||||||
leftTitle: string;
|
|
||||||
leftItems: string[];
|
|
||||||
rightIcon: React.ReactNode;
|
|
||||||
rightTitle: string;
|
|
||||||
rightItems: string[];
|
|
||||||
};
|
|
||||||
|
|
||||||
function TwoColumnSection({
|
|
||||||
title,
|
|
||||||
leftIcon,
|
|
||||||
leftTitle,
|
|
||||||
leftItems,
|
|
||||||
rightIcon,
|
|
||||||
rightTitle,
|
|
||||||
rightItems,
|
|
||||||
}: TwoColumnSectionProps) {
|
|
||||||
return (
|
|
||||||
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
|
|
||||||
<span className="font-medium">{title}</span>
|
|
||||||
<div className="mt-2 grid grid-cols-2 items-start gap-4">
|
|
||||||
{/* Left column - Breaking changes */}
|
|
||||||
<div className="min-w-0">
|
|
||||||
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
|
|
||||||
{leftIcon}
|
|
||||||
<span>{leftTitle}</span>
|
|
||||||
</div>
|
|
||||||
<ul className="mt-1.5 space-y-1">
|
|
||||||
{leftItems.length > 0 ? (
|
|
||||||
leftItems.map((item) => (
|
|
||||||
<li
|
|
||||||
key={item}
|
|
||||||
className="text-sm text-gray-700 dark:text-gray-300"
|
|
||||||
>
|
|
||||||
<code className="rounded bg-red-50 px-1 py-0.5 font-mono text-xs text-red-700 dark:bg-red-900/30 dark:text-red-300">
|
|
||||||
{item}
|
|
||||||
</code>
|
|
||||||
</li>
|
|
||||||
))
|
|
||||||
) : (
|
|
||||||
<li className="text-sm italic text-gray-400 dark:text-gray-500">
|
|
||||||
None
|
|
||||||
</li>
|
|
||||||
)}
|
|
||||||
</ul>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* Right column - Possible solutions */}
|
|
||||||
<div className="min-w-0">
|
|
||||||
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
|
|
||||||
{rightIcon}
|
|
||||||
<span>{rightTitle}</span>
|
|
||||||
</div>
|
|
||||||
<ul className="mt-1.5 space-y-1">
|
|
||||||
{rightItems.length > 0 ? (
|
|
||||||
rightItems.map((item) => (
|
|
||||||
<li
|
|
||||||
key={item}
|
|
||||||
className="text-sm text-gray-700 dark:text-gray-300"
|
|
||||||
>
|
|
||||||
<code className="rounded bg-green-50 px-1 py-0.5 font-mono text-xs text-green-700 dark:bg-green-900/30 dark:text-green-300">
|
|
||||||
{item}
|
|
||||||
</code>
|
|
||||||
</li>
|
|
||||||
))
|
|
||||||
) : (
|
|
||||||
<li className="text-sm italic text-gray-400 dark:text-gray-500">
|
|
||||||
None
|
|
||||||
</li>
|
|
||||||
)}
|
|
||||||
</ul>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
type SingleColumnSectionProps = {
|
|
||||||
icon: React.ReactNode;
|
|
||||||
title: string;
|
|
||||||
description: string;
|
|
||||||
items: string[];
|
|
||||||
};
|
|
||||||
|
|
||||||
function SingleColumnSection({
|
|
||||||
icon,
|
|
||||||
title,
|
|
||||||
description,
|
|
||||||
items,
|
|
||||||
}: SingleColumnSectionProps) {
|
|
||||||
return (
|
|
||||||
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
{icon}
|
|
||||||
<span className="font-medium">{title}</span>
|
|
||||||
</div>
|
|
||||||
<p className="mt-1 text-sm text-gray-500 dark:text-gray-400">
|
|
||||||
{description}
|
|
||||||
</p>
|
|
||||||
<ul className="mt-2 space-y-1">
|
|
||||||
{items.map((item) => (
|
|
||||||
<li
|
|
||||||
key={item}
|
|
||||||
className="ml-4 list-disc text-sm text-gray-700 dark:text-gray-300"
|
|
||||||
>
|
|
||||||
<code className="rounded bg-gray-100 px-1 py-0.5 font-mono text-xs dark:bg-gray-800">
|
|
||||||
{item}
|
|
||||||
</code>
|
|
||||||
</li>
|
|
||||||
))}
|
|
||||||
</ul>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,107 +0,0 @@
|
|||||||
import React from "react";
|
|
||||||
import { InfoIcon, WarningIcon } from "@phosphor-icons/react";
|
|
||||||
import {
|
|
||||||
Tooltip,
|
|
||||||
TooltipContent,
|
|
||||||
TooltipTrigger,
|
|
||||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
|
||||||
import { IncompatibilityInfo } from "@/app/(platform)/build/hooks/useSubAgentUpdate/types";
|
|
||||||
|
|
||||||
type ResolutionModeBarProps = {
|
|
||||||
incompatibilities: IncompatibilityInfo | null;
|
|
||||||
};
|
|
||||||
|
|
||||||
export function ResolutionModeBar({
|
|
||||||
incompatibilities,
|
|
||||||
}: ResolutionModeBarProps): React.ReactElement {
|
|
||||||
const renderIncompatibilities = () => {
|
|
||||||
if (!incompatibilities) return <span>No incompatibilities</span>;
|
|
||||||
|
|
||||||
const sections: React.ReactNode[] = [];
|
|
||||||
|
|
||||||
if (incompatibilities.missingInputs.length > 0) {
|
|
||||||
sections.push(
|
|
||||||
<div key="missing-inputs" className="mb-1">
|
|
||||||
<span className="font-semibold">Missing inputs: </span>
|
|
||||||
{incompatibilities.missingInputs.map((name, i) => (
|
|
||||||
<React.Fragment key={name}>
|
|
||||||
<code className="font-mono">{name}</code>
|
|
||||||
{i < incompatibilities.missingInputs.length - 1 && ", "}
|
|
||||||
</React.Fragment>
|
|
||||||
))}
|
|
||||||
</div>,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if (incompatibilities.missingOutputs.length > 0) {
|
|
||||||
sections.push(
|
|
||||||
<div key="missing-outputs" className="mb-1">
|
|
||||||
<span className="font-semibold">Missing outputs: </span>
|
|
||||||
{incompatibilities.missingOutputs.map((name, i) => (
|
|
||||||
<React.Fragment key={name}>
|
|
||||||
<code className="font-mono">{name}</code>
|
|
||||||
{i < incompatibilities.missingOutputs.length - 1 && ", "}
|
|
||||||
</React.Fragment>
|
|
||||||
))}
|
|
||||||
</div>,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if (incompatibilities.newRequiredInputs.length > 0) {
|
|
||||||
sections.push(
|
|
||||||
<div key="new-required" className="mb-1">
|
|
||||||
<span className="font-semibold">New required inputs: </span>
|
|
||||||
{incompatibilities.newRequiredInputs.map((name, i) => (
|
|
||||||
<React.Fragment key={name}>
|
|
||||||
<code className="font-mono">{name}</code>
|
|
||||||
{i < incompatibilities.newRequiredInputs.length - 1 && ", "}
|
|
||||||
</React.Fragment>
|
|
||||||
))}
|
|
||||||
</div>,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if (incompatibilities.inputTypeMismatches.length > 0) {
|
|
||||||
sections.push(
|
|
||||||
<div key="type-mismatches" className="mb-1">
|
|
||||||
<span className="font-semibold">Type changed: </span>
|
|
||||||
{incompatibilities.inputTypeMismatches.map((m, i) => (
|
|
||||||
<React.Fragment key={m.name}>
|
|
||||||
<code className="font-mono">{m.name}</code>
|
|
||||||
<span className="text-gray-400">
|
|
||||||
{" "}
|
|
||||||
({m.oldType} → {m.newType})
|
|
||||||
</span>
|
|
||||||
{i < incompatibilities.inputTypeMismatches.length - 1 && ", "}
|
|
||||||
</React.Fragment>
|
|
||||||
))}
|
|
||||||
</div>,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return <>{sections}</>;
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="flex items-center justify-between gap-2 rounded-t-xl bg-amber-50 px-3 py-2 dark:bg-amber-900/30">
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
<WarningIcon className="h-4 w-4 text-amber-600 dark:text-amber-400" />
|
|
||||||
<span className="text-sm text-amber-700 dark:text-amber-300">
|
|
||||||
Remove incompatible connections
|
|
||||||
</span>
|
|
||||||
<Tooltip>
|
|
||||||
<TooltipTrigger asChild>
|
|
||||||
<InfoIcon className="h-4 w-4 cursor-help text-amber-500" />
|
|
||||||
</TooltipTrigger>
|
|
||||||
<TooltipContent className="max-w-sm">
|
|
||||||
<p className="mb-2 font-semibold">Incompatible changes:</p>
|
|
||||||
<div className="text-xs">{renderIncompatibilities()}</div>
|
|
||||||
<p className="mt-2 text-xs text-gray-400">
|
|
||||||
{(incompatibilities?.newRequiredInputs.length ?? 0) > 0
|
|
||||||
? "Replace / delete"
|
|
||||||
: "Delete"}{" "}
|
|
||||||
the red connections to continue
|
|
||||||
</p>
|
|
||||||
</TooltipContent>
|
|
||||||
</Tooltip>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,194 +0,0 @@
|
|||||||
import { useState, useCallback, useEffect } from "react";
|
|
||||||
import { useShallow } from "zustand/react/shallow";
|
|
||||||
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
|
|
||||||
import {
|
|
||||||
useNodeStore,
|
|
||||||
NodeResolutionData,
|
|
||||||
} from "@/app/(platform)/build/stores/nodeStore";
|
|
||||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
|
||||||
import {
|
|
||||||
useSubAgentUpdate,
|
|
||||||
createUpdatedAgentNodeInputs,
|
|
||||||
getBrokenEdgeIDs,
|
|
||||||
} from "@/app/(platform)/build/hooks/useSubAgentUpdate";
|
|
||||||
import { GraphInputSchema, GraphOutputSchema } from "@/lib/autogpt-server-api";
|
|
||||||
import { CustomNodeData } from "../../CustomNode";
|
|
||||||
|
|
||||||
// Stable empty set to avoid creating new references in selectors
|
|
||||||
const EMPTY_SET: Set<string> = new Set();
|
|
||||||
|
|
||||||
type UseSubAgentUpdateParams = {
|
|
||||||
nodeID: string;
|
|
||||||
nodeData: CustomNodeData;
|
|
||||||
};
|
|
||||||
|
|
||||||
export function useSubAgentUpdateState({
|
|
||||||
nodeID,
|
|
||||||
nodeData,
|
|
||||||
}: UseSubAgentUpdateParams) {
|
|
||||||
const [showIncompatibilityDialog, setShowIncompatibilityDialog] =
|
|
||||||
useState(false);
|
|
||||||
|
|
||||||
// Get store actions
|
|
||||||
const updateNodeData = useNodeStore(
|
|
||||||
useShallow((state) => state.updateNodeData),
|
|
||||||
);
|
|
||||||
const setNodeResolutionMode = useNodeStore(
|
|
||||||
useShallow((state) => state.setNodeResolutionMode),
|
|
||||||
);
|
|
||||||
const isNodeInResolutionMode = useNodeStore(
|
|
||||||
useShallow((state) => state.isNodeInResolutionMode),
|
|
||||||
);
|
|
||||||
const setBrokenEdgeIDs = useNodeStore(
|
|
||||||
useShallow((state) => state.setBrokenEdgeIDs),
|
|
||||||
);
|
|
||||||
// Get this node's broken edge IDs from the per-node map
|
|
||||||
// Use EMPTY_SET as fallback to maintain referential stability
|
|
||||||
const brokenEdgeIDs = useNodeStore(
|
|
||||||
(state) => state.brokenEdgeIDs.get(nodeID) || EMPTY_SET,
|
|
||||||
);
|
|
||||||
const getNodeResolutionData = useNodeStore(
|
|
||||||
useShallow((state) => state.getNodeResolutionData),
|
|
||||||
);
|
|
||||||
const connectedEdges = useEdgeStore(
|
|
||||||
useShallow((state) => state.getNodeEdges(nodeID)),
|
|
||||||
);
|
|
||||||
const availableSubGraphs = useGraphStore(
|
|
||||||
useShallow((state) => state.availableSubGraphs),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Extract agent-specific data
|
|
||||||
const graphID = nodeData.hardcodedValues?.graph_id as string | undefined;
|
|
||||||
const graphVersion = nodeData.hardcodedValues?.graph_version as
|
|
||||||
| number
|
|
||||||
| undefined;
|
|
||||||
const currentInputSchema = nodeData.hardcodedValues?.input_schema as
|
|
||||||
| GraphInputSchema
|
|
||||||
| undefined;
|
|
||||||
const currentOutputSchema = nodeData.hardcodedValues?.output_schema as
|
|
||||||
| GraphOutputSchema
|
|
||||||
| undefined;
|
|
||||||
|
|
||||||
// Use the sub-agent update hook
|
|
||||||
const updateInfo = useSubAgentUpdate(
|
|
||||||
nodeID,
|
|
||||||
graphID,
|
|
||||||
graphVersion,
|
|
||||||
currentInputSchema,
|
|
||||||
currentOutputSchema,
|
|
||||||
connectedEdges,
|
|
||||||
availableSubGraphs,
|
|
||||||
);
|
|
||||||
|
|
||||||
const isInResolutionMode = isNodeInResolutionMode(nodeID);
|
|
||||||
|
|
||||||
// Handle update button click
|
|
||||||
const handleUpdateClick = useCallback(() => {
|
|
||||||
if (!updateInfo.hasUpdate || !updateInfo.latestGraph) return;
|
|
||||||
|
|
||||||
if (updateInfo.isCompatible) {
|
|
||||||
// Compatible update - apply directly
|
|
||||||
const newHardcodedValues = createUpdatedAgentNodeInputs(
|
|
||||||
nodeData.hardcodedValues,
|
|
||||||
updateInfo.latestGraph,
|
|
||||||
);
|
|
||||||
updateNodeData(nodeID, { hardcodedValues: newHardcodedValues });
|
|
||||||
} else {
|
|
||||||
// Incompatible update - show dialog
|
|
||||||
setShowIncompatibilityDialog(true);
|
|
||||||
}
|
|
||||||
}, [
|
|
||||||
updateInfo.hasUpdate,
|
|
||||||
updateInfo.latestGraph,
|
|
||||||
updateInfo.isCompatible,
|
|
||||||
nodeData.hardcodedValues,
|
|
||||||
updateNodeData,
|
|
||||||
nodeID,
|
|
||||||
]);
|
|
||||||
|
|
||||||
// Handle confirming an incompatible update
|
|
||||||
function handleConfirmIncompatibleUpdate() {
|
|
||||||
if (!updateInfo.latestGraph || !updateInfo.incompatibilities) return;
|
|
||||||
|
|
||||||
const latestGraph = updateInfo.latestGraph;
|
|
||||||
|
|
||||||
// Get the new schemas from the latest graph version
|
|
||||||
const newInputSchema =
|
|
||||||
(latestGraph.input_schema as Record<string, unknown>) || {};
|
|
||||||
const newOutputSchema =
|
|
||||||
(latestGraph.output_schema as Record<string, unknown>) || {};
|
|
||||||
|
|
||||||
// Create the updated hardcoded values but DON'T apply them yet
|
|
||||||
// We'll apply them when resolution is complete
|
|
||||||
const pendingHardcodedValues = createUpdatedAgentNodeInputs(
|
|
||||||
nodeData.hardcodedValues,
|
|
||||||
latestGraph,
|
|
||||||
);
|
|
||||||
|
|
||||||
// Get broken edge IDs and store them for this node
|
|
||||||
const brokenIds = getBrokenEdgeIDs(
|
|
||||||
connectedEdges,
|
|
||||||
updateInfo.incompatibilities,
|
|
||||||
nodeID,
|
|
||||||
);
|
|
||||||
setBrokenEdgeIDs(nodeID, brokenIds);
|
|
||||||
|
|
||||||
// Enter resolution mode with both old and new schemas
|
|
||||||
// DON'T apply the update yet - keep old schema so connections remain visible
|
|
||||||
const resolutionData: NodeResolutionData = {
|
|
||||||
incompatibilities: updateInfo.incompatibilities,
|
|
||||||
pendingUpdate: {
|
|
||||||
input_schema: newInputSchema,
|
|
||||||
output_schema: newOutputSchema,
|
|
||||||
},
|
|
||||||
currentSchema: {
|
|
||||||
input_schema: (currentInputSchema as Record<string, unknown>) || {},
|
|
||||||
output_schema: (currentOutputSchema as Record<string, unknown>) || {},
|
|
||||||
},
|
|
||||||
pendingHardcodedValues,
|
|
||||||
};
|
|
||||||
setNodeResolutionMode(nodeID, true, resolutionData);
|
|
||||||
|
|
||||||
setShowIncompatibilityDialog(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if resolution is complete (all broken edges removed)
|
|
||||||
const resolutionData = getNodeResolutionData(nodeID);
|
|
||||||
|
|
||||||
// Auto-check resolution on edge changes
|
|
||||||
useEffect(() => {
|
|
||||||
if (!isInResolutionMode) return;
|
|
||||||
|
|
||||||
// Check if any broken edges still exist
|
|
||||||
const remainingBroken = Array.from(brokenEdgeIDs).filter((edgeId) =>
|
|
||||||
connectedEdges.some((e) => e.id === edgeId),
|
|
||||||
);
|
|
||||||
|
|
||||||
if (remainingBroken.length === 0) {
|
|
||||||
// Resolution complete - now apply the pending update
|
|
||||||
if (resolutionData?.pendingHardcodedValues) {
|
|
||||||
updateNodeData(nodeID, {
|
|
||||||
hardcodedValues: resolutionData.pendingHardcodedValues,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
// setNodeResolutionMode will clean up this node's broken edges automatically
|
|
||||||
setNodeResolutionMode(nodeID, false);
|
|
||||||
}
|
|
||||||
}, [
|
|
||||||
isInResolutionMode,
|
|
||||||
brokenEdgeIDs,
|
|
||||||
connectedEdges,
|
|
||||||
resolutionData,
|
|
||||||
nodeID,
|
|
||||||
]);
|
|
||||||
|
|
||||||
return {
|
|
||||||
updateInfo,
|
|
||||||
isInResolutionMode,
|
|
||||||
resolutionData,
|
|
||||||
showIncompatibilityDialog,
|
|
||||||
setShowIncompatibilityDialog,
|
|
||||||
handleUpdateClick,
|
|
||||||
handleConfirmIncompatibleUpdate,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,4 @@
|
|||||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||||
import { NodeResolutionData } from "@/app/(platform)/build/stores/nodeStore";
|
|
||||||
import { RJSFSchema } from "@rjsf/utils";
|
|
||||||
|
|
||||||
export const nodeStyleBasedOnStatus: Record<AgentExecutionStatus, string> = {
|
export const nodeStyleBasedOnStatus: Record<AgentExecutionStatus, string> = {
|
||||||
INCOMPLETE: "ring-slate-300 bg-slate-300",
|
INCOMPLETE: "ring-slate-300 bg-slate-300",
|
||||||
@@ -11,48 +9,3 @@ export const nodeStyleBasedOnStatus: Record<AgentExecutionStatus, string> = {
|
|||||||
TERMINATED: "ring-orange-300 bg-orange-300 ",
|
TERMINATED: "ring-orange-300 bg-orange-300 ",
|
||||||
FAILED: "ring-red-300 bg-red-300",
|
FAILED: "ring-red-300 bg-red-300",
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
|
||||||
* Merges schemas during resolution mode to include removed inputs/outputs
|
|
||||||
* that still have connections, so users can see and delete them.
|
|
||||||
*/
|
|
||||||
export function mergeSchemaForResolution(
|
|
||||||
currentSchema: Record<string, unknown>,
|
|
||||||
newSchema: Record<string, unknown>,
|
|
||||||
resolutionData: NodeResolutionData,
|
|
||||||
type: "input" | "output",
|
|
||||||
): Record<string, unknown> {
|
|
||||||
const newProps = (newSchema.properties as RJSFSchema) || {};
|
|
||||||
const currentProps = (currentSchema.properties as RJSFSchema) || {};
|
|
||||||
const mergedProps = { ...newProps };
|
|
||||||
const incomp = resolutionData.incompatibilities;
|
|
||||||
|
|
||||||
if (type === "input") {
|
|
||||||
// Add back missing inputs that have connections
|
|
||||||
incomp.missingInputs.forEach((inputName: string) => {
|
|
||||||
if (currentProps[inputName]) {
|
|
||||||
mergedProps[inputName] = currentProps[inputName];
|
|
||||||
}
|
|
||||||
});
|
|
||||||
// Add back inputs with type mismatches (keep old type so connection works visually)
|
|
||||||
incomp.inputTypeMismatches.forEach(
|
|
||||||
(mismatch: { name: string; oldType: string; newType: string }) => {
|
|
||||||
if (currentProps[mismatch.name]) {
|
|
||||||
mergedProps[mismatch.name] = currentProps[mismatch.name];
|
|
||||||
}
|
|
||||||
},
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
// Add back missing outputs that have connections
|
|
||||||
incomp.missingOutputs.forEach((outputName: string) => {
|
|
||||||
if (currentProps[outputName]) {
|
|
||||||
mergedProps[outputName] = currentProps[outputName];
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
...newSchema,
|
|
||||||
properties: mergedProps,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
|
||||||
import { CustomNodeData } from "./CustomNode";
|
|
||||||
import { BlockUIType } from "../../../types";
|
|
||||||
import { useMemo } from "react";
|
|
||||||
import { mergeSchemaForResolution } from "./helpers";
|
|
||||||
|
|
||||||
export const useCustomNode = ({
|
|
||||||
data,
|
|
||||||
nodeId,
|
|
||||||
}: {
|
|
||||||
data: CustomNodeData;
|
|
||||||
nodeId: string;
|
|
||||||
}) => {
|
|
||||||
const isInResolutionMode = useNodeStore((state) =>
|
|
||||||
state.nodesInResolutionMode.has(nodeId),
|
|
||||||
);
|
|
||||||
const resolutionData = useNodeStore((state) =>
|
|
||||||
state.nodeResolutionData.get(nodeId),
|
|
||||||
);
|
|
||||||
|
|
||||||
const isAgent = data.uiType === BlockUIType.AGENT;
|
|
||||||
|
|
||||||
const currentInputSchema = isAgent
|
|
||||||
? (data.hardcodedValues.input_schema ?? {})
|
|
||||||
: data.inputSchema;
|
|
||||||
const currentOutputSchema = isAgent
|
|
||||||
? (data.hardcodedValues.output_schema ?? {})
|
|
||||||
: data.outputSchema;
|
|
||||||
|
|
||||||
const inputSchema = useMemo(() => {
|
|
||||||
if (isAgent && isInResolutionMode && resolutionData) {
|
|
||||||
return mergeSchemaForResolution(
|
|
||||||
resolutionData.currentSchema.input_schema,
|
|
||||||
resolutionData.pendingUpdate.input_schema,
|
|
||||||
resolutionData,
|
|
||||||
"input",
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return currentInputSchema;
|
|
||||||
}, [isAgent, isInResolutionMode, resolutionData, currentInputSchema]);
|
|
||||||
|
|
||||||
const outputSchema = useMemo(() => {
|
|
||||||
if (isAgent && isInResolutionMode && resolutionData) {
|
|
||||||
return mergeSchemaForResolution(
|
|
||||||
resolutionData.currentSchema.output_schema,
|
|
||||||
resolutionData.pendingUpdate.output_schema,
|
|
||||||
resolutionData,
|
|
||||||
"output",
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return currentOutputSchema;
|
|
||||||
}, [isAgent, isInResolutionMode, resolutionData, currentOutputSchema]);
|
|
||||||
|
|
||||||
return {
|
|
||||||
inputSchema,
|
|
||||||
outputSchema,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user