Compare commits


17 Commits

Author SHA1 Message Date
Bentlybro
88532185bc format....
2026-01-14 12:50:54 +00:00
Bentlybro
3a56343013 Add error field to ClaudeCodeBlock schema
Introduces an 'error' field to the ClaudeCodeBlock schema to store error messages when code execution fails.
2026-01-14 12:41:24 +00:00
Bentlybro
b02b5e0708 Handle timestamp command failure in ClaudeCodeBlock
Adds error handling for the timestamp command in ClaudeCodeBlock by raising a RuntimeError if the command fails, ensuring issues are surfaced when the timestamp cannot be captured.
2026-01-14 12:35:53 +00:00
Bentlybro
1239cb7a4d Fix race condition in timestamp capture for Claude Code
Adjusts the timestamp capture to use '1 second ago' instead of the current time to prevent race conditions with file creation during Claude Code execution.
2026-01-14 12:27:13 +00:00
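A minimal sketch of the timestamp capture described in the two commits above, assuming a generic `run_cmd` callable that executes a shell command inside the sandbox and returns its exit code and output; the block's real execution helper may differ.

```python
from typing import Callable

# (exit_code, stdout, stderr) returned by whatever runs commands in the sandbox
RunCmd = Callable[[str], tuple[int, str, str]]


def capture_start_timestamp(run_cmd: RunCmd) -> str:
    # Back-date by one second so files created in the same second as the
    # capture are still matched by the later `find -newermt` filter.
    exit_code, stdout, stderr = run_cmd('date -d "1 second ago" "+%Y-%m-%d %H:%M:%S"')
    if exit_code != 0:
        # Surface the failure instead of silently extracting the wrong files.
        raise RuntimeError(f"Failed to capture execution start timestamp: {stderr}")
    return stdout.strip()
```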
Bentlybro
6bc05de917 Add sandbox_id field to ClaudeCodeBlock test cases
Introduces a 'sandbox_id' field with a value of None in the test cases for ClaudeCodeBlock, reflecting the use of dispose_sandbox=True in test_input.
2026-01-14 12:16:36 +00:00
Bentlybro
da04bcbea5 Filter extracted files by execution timestamp
Update ClaudeCodeBlock to only extract text files created or modified during the current execution by capturing a start timestamp and filtering files using the 'find' command with -newermt. This prevents returning files from previous runs and ensures more accurate file outputs.
2026-01-14 12:04:28 +00:00
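A hedged sketch of the filtering step: build a `find` invocation that only matches files modified after the captured timestamp. The `.git` exclusion and exact flags are assumptions based on the commit texts, not the block's verbatim command.

```python
import shlex


def build_extraction_cmd(working_dir: str, start_timestamp: str) -> str:
    # Only pick up files created or modified during this execution,
    # excluding .git so repository metadata is never returned.
    return (
        f"find {shlex.quote(working_dir)} -type f "
        f"-newermt {shlex.quote(start_timestamp)} "
        "-not -path '*/.git/*'"
    )


# Example: build_extraction_cmd("/home/user/work", "2026-01-14 12:27:12")
# -> find /home/user/work -type f -newermt '2026-01-14 12:27:12' -not -path '*/.git/*'
```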
Bentlybro
34c1644bd6 Improve sandbox_id handling and file extraction in ClaudeCodeBlock
Updated the Output schema to allow sandbox_id to be optional and clarified its behavior when the sandbox is disposed. Enhanced file extraction to clarify that all text files in the working directory are returned, not just those created or modified during execution. Adjusted logic to always yield sandbox_id (None if disposed) and improved documentation and comments for clarity.
2026-01-14 11:48:02 +00:00
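A short sketch of the output behaviour described above, assuming the block yields `(name, value)` pairs; the generator shape is illustrative only.

```python
def yield_sandbox_id(dispose_sandbox: bool, sandbox_id: str | None):
    # Always emit the key so the output matches the Output schema;
    # the value is None when the sandbox was disposed and cannot be reused.
    yield "sandbox_id", None if dispose_sandbox else sandbox_id
```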
Bentlybro
6220396cd7 Refactor imports and update credential descriptions
Moved json and uuid imports to the top of the file and updated credential field descriptions to use markdown links for better readability.
2026-01-14 11:32:11 +00:00
Bentlybro
75a92fc3f3 Always yield files in ClaudeCodeBlock output
Updated the ClaudeCodeBlock to always yield the 'files' key with an empty list if no files are present, ensuring consistent output that matches the Output schema.
2026-01-13 21:55:53 +00:00
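A sketch of the consistent-output rule above, again assuming `(name, value)` yields.

```python
def yield_files(files: list | None):
    # Always emit the "files" key, even when nothing was extracted,
    # so consumers can rely on the Output schema.
    yield "files", files or []
```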
Bentlybro
2c6353a6a2 Clarify timeout behavior and fix path slicing
Updated the timeout field description to specify that the timeout only applies when creating a new sandbox, not when reconnecting. Also fixed whitespace in relative path slicing for improved code clarity.
2026-01-13 21:49:28 +00:00
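A hedged sketch of that timeout distinction, assuming the E2B Python SDK's `Sandbox(...)` constructor and a `Sandbox.connect(...)` reconnect call; the import path, method names, and parameters here are assumptions based on the commit text, not confirmed against the block.

```python
from e2b import Sandbox  # assumption: E2B SDK import path


def get_sandbox(sandbox_id: str | None, timeout: int) -> Sandbox:
    if sandbox_id:
        # Reconnecting to an existing sandbox: the timeout input is ignored.
        return Sandbox.connect(sandbox_id)
    # Creating a new sandbox: the timeout (seconds) applies here only.
    return Sandbox(timeout=timeout)
```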
Bentlybro
21b70ae9ae Improve shell command safety and error handling
Use shlex.quote to safely escape shell arguments such as working_directory and session IDs, preventing shell injection vulnerabilities. Add explicit error handling for failed Claude command executions. Refine file listing to allow hidden files except for .git, and update comments for clarity.
2026-01-13 21:40:56 +00:00
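A minimal sketch of the escaping pattern from the commit above, assuming Claude Code's `-p` (print) and `--resume` CLI flags; the exact command the block composes may differ.

```python
import shlex


def build_claude_command(
    prompt: str, working_directory: str, session_id: str | None = None
) -> str:
    # shlex.quote neutralises shell metacharacters in user-controlled values,
    # preventing injection via the prompt, path, or session ID.
    cmd = f"cd {shlex.quote(working_directory)} && claude -p {shlex.quote(prompt)}"
    if session_id:
        cmd += f" --resume {shlex.quote(session_id)}"
    return cmd
```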
Bentlybro
b007e02364 Fix credential title assignment and improve session validation
Corrects the assignment of the 'title' field in test credential inputs to use the actual title instead of type. Adds validation to require 'sandbox_id' when resuming a session with 'session_id', and improves error handling for setup command failures in ClaudeCodeBlock.
2026-01-13 21:21:54 +00:00
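A sketch of the session validation mentioned above.

```python
def validate_continuation(session_id: str | None, sandbox_id: str | None) -> None:
    # Resuming a conversation requires the sandbox that holds its state.
    if session_id and not sandbox_id:
        raise ValueError("sandbox_id is required when session_id is provided")
```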
Bentlybro
bbc289894a Move ClaudeCodeBlock to separate module
Extracted the ClaudeCodeBlock implementation from code_executor.py into a new claude_code.py module for better separation of concerns and maintainability. Removed Anthropic test credentials from code_executor.py as they are now defined in claude_code.py.
2026-01-13 21:06:11 +00:00
Bentlybro
639ee5f073 Add conversation history support to ClaudeCodeBlock
Introduces a conversation_history input and output to enable restoring context when continuing a session in a new sandbox. Updates the code execution flow to maintain and return the full conversation history, and marks several schema fields as advanced for improved UI clarity.
2026-01-13 16:40:17 +00:00
Bentlybro
898781134d Add session and sandbox continuation to ClaudeCodeBlock
Introduces session_id and sandbox_id fields to support resuming previous conversations and reconnecting to existing sandboxes. Updates input/output schemas, command execution logic, and documentation to enable session continuation and sandbox reuse for Claude code execution.
2026-01-13 15:10:46 +00:00
Bentlybro
7a28db1649 Refactor ClaudeCodeBlock to return output files
Updated ClaudeCodeBlock to extract and return a list of files created or modified during code execution, replacing stdout and stderr logs in the output schema. Added FileOutput model, file extraction logic, and updated test mocks and method signatures accordingly.
2026-01-13 14:41:10 +00:00
Bentlybro
a916ea0f8f Add Claude Code block for Anthropic code execution
Introduces a new ClaudeCodeBlock that enables execution of coding tasks using Anthropic's Claude Code in an E2B sandbox. The block supports configuration of credentials, prompt, timeout, setup commands, and working directory, and handles sandbox lifecycle and output collection.
2026-01-13 14:28:13 +00:00
252 changed files with 2722 additions and 20330 deletions

View File

@@ -1,9 +1,6 @@
# Ignore everything by default, selectively add things to context
*
# Documentation (for embeddings/search)
!docs/
# Platform - Libs
!autogpt_platform/autogpt_libs/autogpt_libs/
!autogpt_platform/autogpt_libs/pyproject.toml

View File

@@ -176,7 +176,7 @@ jobs:
}
- name: Run Database Migrations
run: poetry run prisma migrate deploy
run: poetry run prisma migrate dev --name updates
env:
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}

View File

@@ -11,7 +11,6 @@ on:
- ".github/workflows/platform-frontend-ci.yml"
- "autogpt_platform/frontend/**"
merge_group:
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event_name == 'merge_group' && format('merge-queue-{0}', github.ref) || format('{0}-{1}', github.ref, github.event.pull_request.number || github.sha) }}
@@ -152,14 +151,6 @@ jobs:
run: |
cp ../.env.default ../.env
- name: Copy backend .env and set OpenAI API key
run: |
cp ../backend/.env.default ../backend/.env
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
env:
# Used by E2E test data script to generate embeddings for approved store agents
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -235,25 +226,13 @@ jobs:
- name: Run Playwright tests
run: pnpm test:no-build
continue-on-error: false
- name: Upload Playwright report
if: always()
- name: Upload Playwright artifacts
if: failure()
uses: actions/upload-artifact@v4
with:
name: playwright-report
path: playwright-report
if-no-files-found: ignore
retention-days: 3
- name: Upload Playwright test results
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-test-results
path: test-results
if-no-files-found: ignore
retention-days: 3
- name: Print Final Docker Compose logs
if: always()

View File

@@ -6,10 +6,9 @@ start-core:
# Stop core services
stop-core:
docker compose stop
docker compose stop deps
reset-db:
docker compose stop db
rm -rf db/docker/volumes/db/data
cd backend && poetry run prisma migrate deploy
cd backend && poetry run prisma generate
@@ -61,4 +60,4 @@ help:
@echo " run-backend - Run the backend FastAPI server"
@echo " run-frontend - Run the frontend Next.js development server"
@echo " test-data - Run the test data creator"
@echo " load-store-agents - Load store agents from agents/ folder into test database"
@echo " load-store-agents - Load store agents from agents/ folder into test database"

View File

@@ -58,13 +58,6 @@ V0_API_KEY=
OPEN_ROUTER_API_KEY=
NVIDIA_API_KEY=
# Langfuse Prompt Management
# Used for managing the CoPilot system prompt externally
# Get credentials from https://cloud.langfuse.com or your self-hosted instance
LANGFUSE_PUBLIC_KEY=
LANGFUSE_SECRET_KEY=
LANGFUSE_HOST=https://cloud.langfuse.com
# OAuth Credentials
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
# e.g. http://localhost:3000/auth/integrations/oauth_callback

View File

@@ -18,4 +18,3 @@ load-tests/results/
load-tests/*.json
load-tests/*.log
load-tests/node_modules/*
migrations/*/rollback*.sql

View File

@@ -100,7 +100,6 @@ COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migration
FROM server_dependencies AS server
COPY autogpt_platform/backend /app/autogpt_platform/backend
COPY docs /app/docs
RUN poetry install --no-ansi --only-root
ENV PORT=8000

View File

@@ -1,57 +1,21 @@
"""
External API Application
This module defines the main FastAPI application for the external API,
which mounts the v1 and v2 sub-applications.
"""
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from backend.api.middleware.security import SecurityHeadersMiddleware
from backend.monitoring.instrumentation import instrument_fastapi
from .v1.app import v1_app
from .v2.app import v2_app
DESCRIPTION = """
The external API provides programmatic access to the AutoGPT Platform for building
integrations, automations, and custom applications.
### API Versions
| Version | End of Life | Path | Documentation |
|---------------------|-------------|------------------------|---------------|
| **v2** | | `/external-api/v2/...` | [v2 docs](v2/docs) |
| **v1** (deprecated) | 2025-05-01 | `/external-api/v1/...` | [v1 docs](v1/docs) |
**Recommendation**: New integrations should use v2.
For authentication details and usage examples, see the
[API Integration Guide](https://docs.agpt.co/platform/integrating/api-guide/).
"""
from .v1.routes import v1_router
external_api = FastAPI(
title="AutoGPT Platform API",
summary="External API for AutoGPT Platform integrations",
description=DESCRIPTION,
version="2.0.0",
title="AutoGPT External API",
description="External API for AutoGPT integrations",
docs_url="/docs",
redoc_url="/redoc",
version="1.0",
)
external_api.add_middleware(SecurityHeadersMiddleware)
external_api.include_router(v1_router, prefix="/v1")
@external_api.get("/", include_in_schema=False)
async def root_redirect() -> RedirectResponse:
"""Redirect root to API documentation."""
return RedirectResponse(url="/docs")
# Mount versioned sub-applications
# Each sub-app has its own /docs page at /v1/docs and /v2/docs
external_api.mount("/v1", v1_app)
external_api.mount("/v2", v2_app)
# Add Prometheus instrumentation to the main app
# Add Prometheus instrumentation
instrument_fastapi(
external_api,
service_name="external-api",

View File

@@ -1,39 +0,0 @@
"""
V1 External API Application
This module defines the FastAPI application for the v1 external API.
"""
from fastapi import FastAPI
from backend.api.middleware.security import SecurityHeadersMiddleware
from .routes import v1_router
DESCRIPTION = """
The v1 API provides access to core AutoGPT functionality for external integrations.
For authentication details and usage examples, see the
[API Integration Guide](https://docs.agpt.co/platform/integrating/api-guide/).
"""
v1_app = FastAPI(
title="AutoGPT Platform API",
summary="External API for AutoGPT Platform integrations (v1)",
description=DESCRIPTION,
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json",
openapi_tags=[
{"name": "user", "description": "User information"},
{"name": "blocks", "description": "Block operations"},
{"name": "graphs", "description": "Graph execution"},
{"name": "store", "description": "Marketplace agents and creators"},
{"name": "integrations", "description": "OAuth credential management"},
{"name": "tools", "description": "AI assistant tools"},
],
)
v1_app.add_middleware(SecurityHeadersMiddleware)
v1_app.include_router(v1_router)

View File

@@ -70,7 +70,7 @@ class RunAgentRequest(BaseModel):
)
def _create_ephemeral_session(user_id: str) -> ChatSession:
def _create_ephemeral_session(user_id: str | None) -> ChatSession:
"""Create an ephemeral session for stateless API requests."""
return ChatSession.new(user_id)

View File

@@ -1,9 +0,0 @@
"""
V2 External API
This module provides the v2 external API for programmatic access to the AutoGPT Platform.
"""
from .routes import v2_router
__all__ = ["v2_router"]

View File

@@ -1,82 +0,0 @@
"""
V2 External API Application
This module defines the FastAPI application for the v2 external API.
"""
from fastapi import FastAPI
from backend.api.middleware.security import SecurityHeadersMiddleware
from .routes import v2_router
DESCRIPTION = """
The v2 API provides comprehensive access to the AutoGPT Platform for building
integrations, automations, and custom applications.
### Key Improvements over v1
- **Consistent naming**: Uses `graph_id`/`graph_version` consistently
- **Better pagination**: All list endpoints support pagination
- **Comprehensive coverage**: Access to library, runs, schedules, credits, and more
- **Human-in-the-loop**: Review and approve agent decisions via the API
For authentication details and usage examples, see the
[API Integration Guide](https://docs.agpt.co/platform/integrating/api-guide/).
### Pagination
List endpoints return paginated responses. Use `page` and `page_size` query
parameters to navigate results. Maximum page size is 100 items.
"""
v2_app = FastAPI(
title="AutoGPT Platform External API",
summary="External API for AutoGPT Platform integrations (v2)",
description=DESCRIPTION,
version="2.0.0",
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json",
openapi_tags=[
{
"name": "graphs",
"description": "Create, update, and manage agent graphs",
},
{
"name": "schedules",
"description": "Manage scheduled graph executions",
},
{
"name": "blocks",
"description": "Discover available building blocks",
},
{
"name": "marketplace",
"description": "Browse agents and creators, manage submissions",
},
{
"name": "library",
"description": "Access your agent library and execute agents",
},
{
"name": "runs",
"description": "Monitor execution runs and human-in-the-loop reviews",
},
{
"name": "credits",
"description": "Check balance and view transaction history",
},
{
"name": "integrations",
"description": "Manage OAuth credentials for external services",
},
{
"name": "files",
"description": "Upload files for agent input",
},
],
)
v2_app.add_middleware(SecurityHeadersMiddleware)
v2_app.include_router(v2_router)

View File

@@ -1,140 +0,0 @@
"""
V2 External API - Blocks Endpoints
Provides read-only access to available building blocks.
"""
import logging
from typing import Any
from fastapi import APIRouter, Response, Security
from fastapi.concurrency import run_in_threadpool
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field
from backend.api.external.middleware import require_permission
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.block import get_blocks
from backend.util.cache import cached
from backend.util.json import dumps
logger = logging.getLogger(__name__)
blocks_router = APIRouter()
# ============================================================================
# Models
# ============================================================================
class BlockCost(BaseModel):
"""Cost information for a block."""
cost_type: str = Field(description="Type of cost (e.g., 'per_call', 'per_token')")
cost_filter: dict[str, Any] = Field(
default_factory=dict, description="Conditions for this cost"
)
cost_amount: int = Field(description="Cost amount in credits")
class Block(BaseModel):
"""A building block that can be used in graphs."""
id: str
name: str
description: str
categories: list[str] = Field(default_factory=list)
input_schema: dict[str, Any]
output_schema: dict[str, Any]
costs: list[BlockCost] = Field(default_factory=list)
disabled: bool = Field(default=False)
class BlocksListResponse(BaseModel):
"""Response for listing blocks."""
blocks: list[Block]
total_count: int
# ============================================================================
# Internal Functions
# ============================================================================
def _compute_blocks_sync() -> str:
"""
Synchronous function to compute blocks data.
This does the heavy lifting: instantiate 226+ blocks, compute costs, serialize.
"""
from backend.data.credit import get_block_cost
block_classes = get_blocks()
result = []
for block_class in block_classes.values():
block_instance = block_class()
if not block_instance.disabled:
costs = get_block_cost(block_instance)
# Convert BlockCost BaseModel objects to dictionaries
costs_dict = [
cost.model_dump() if isinstance(cost, BaseModel) else cost
for cost in costs
]
result.append({**block_instance.to_dict(), "costs": costs_dict})
return dumps(result)
@cached(ttl_seconds=3600)
async def _get_cached_blocks() -> str:
"""
Async cached function with thundering herd protection.
On cache miss: runs heavy work in thread pool
On cache hit: returns cached string immediately
"""
return await run_in_threadpool(_compute_blocks_sync)
# ============================================================================
# Endpoints
# ============================================================================
@blocks_router.get(
path="",
summary="List available blocks",
responses={
200: {
"description": "List of available building blocks",
"content": {
"application/json": {
"schema": {
"items": {"additionalProperties": True, "type": "object"},
"type": "array",
}
}
},
}
},
)
async def list_blocks(
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_BLOCK)
),
) -> Response:
"""
List all available building blocks that can be used in graphs.
Each block represents a specific capability (e.g., HTTP request, text processing,
AI completion, etc.) that can be connected in a graph to create an agent.
The response includes input/output schemas for each block, as well as
cost information for blocks that consume credits.
"""
content = await _get_cached_blocks()
return Response(
content=content,
media_type="application/json",
)

View File

@@ -1,36 +0,0 @@
"""
Common utilities for V2 External API
"""
from typing import TypeVar
from pydantic import BaseModel, Field
# Constants for pagination
MAX_PAGE_SIZE = 100
DEFAULT_PAGE_SIZE = 20
class PaginationParams(BaseModel):
"""Common pagination parameters."""
page: int = Field(default=1, ge=1, description="Page number (1-indexed)")
page_size: int = Field(
default=DEFAULT_PAGE_SIZE,
ge=1,
le=MAX_PAGE_SIZE,
description=f"Number of items per page (max {MAX_PAGE_SIZE})",
)
T = TypeVar("T")
class PaginatedResponse(BaseModel):
"""Generic paginated response wrapper."""
items: list
total_count: int = Field(description="Total number of items across all pages")
page: int = Field(description="Current page number (1-indexed)")
page_size: int = Field(description="Number of items per page")
total_pages: int = Field(description="Total number of pages")

View File

@@ -1,141 +0,0 @@
"""
V2 External API - Credits Endpoints
Provides access to credit balance and transaction history.
"""
import logging
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Query, Security
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field
from backend.api.external.middleware import require_permission
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.credit import get_user_credit_model
from .common import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
logger = logging.getLogger(__name__)
credits_router = APIRouter()
# ============================================================================
# Models
# ============================================================================
class CreditBalance(BaseModel):
"""User's credit balance."""
balance: int = Field(description="Current credit balance")
class CreditTransaction(BaseModel):
"""A credit transaction."""
transaction_key: str
amount: int = Field(description="Transaction amount (positive or negative)")
type: str = Field(description="One of: TOP_UP, USAGE, GRANT, REFUND")
transaction_time: datetime
running_balance: Optional[int] = Field(
default=None, description="Balance after this transaction"
)
description: Optional[str] = None
class CreditTransactionsResponse(BaseModel):
"""Response for listing credit transactions."""
transactions: list[CreditTransaction]
total_count: int
page: int
page_size: int
total_pages: int
# ============================================================================
# Endpoints
# ============================================================================
@credits_router.get(
path="",
summary="Get credit balance",
response_model=CreditBalance,
)
async def get_balance(
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_CREDITS)
),
) -> CreditBalance:
"""
Get the current credit balance for the authenticated user.
"""
user_credit_model = await get_user_credit_model(auth.user_id)
balance = await user_credit_model.get_credits(auth.user_id)
return CreditBalance(balance=balance)
@credits_router.get(
path="/transactions",
summary="Get transaction history",
response_model=CreditTransactionsResponse,
)
async def get_transactions(
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_CREDITS)
),
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = Query(
default=DEFAULT_PAGE_SIZE,
ge=1,
le=MAX_PAGE_SIZE,
description=f"Items per page (max {MAX_PAGE_SIZE})",
),
transaction_type: Optional[str] = Query(
default=None,
description="Filter by transaction type (TOP_UP, USAGE, GRANT, REFUND)",
),
) -> CreditTransactionsResponse:
"""
Get credit transaction history for the authenticated user.
Returns transactions sorted by most recent first.
"""
user_credit_model = await get_user_credit_model(auth.user_id)
history = await user_credit_model.get_transaction_history(
user_id=auth.user_id,
transaction_count_limit=page_size,
transaction_type=transaction_type,
)
transactions = [
CreditTransaction(
transaction_key=t.transaction_key,
amount=t.amount,
type=t.transaction_type.value,
transaction_time=t.transaction_time,
running_balance=t.running_balance,
description=t.description,
)
for t in history.transactions
]
# Note: The current credit module doesn't support true pagination,
# so we're returning what we have
total_count = len(transactions)
total_pages = 1 # Without true pagination support
return CreditTransactionsResponse(
transactions=transactions,
total_count=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)

View File

@@ -1,132 +0,0 @@
"""
V2 External API - Files Endpoints
Provides file upload functionality for agent inputs.
"""
import base64
import logging
from fastapi import APIRouter, File, HTTPException, Query, Security, UploadFile
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field
from backend.api.external.middleware import require_permission
from backend.data.auth.base import APIAuthorizationInfo
from backend.util.cloud_storage import get_cloud_storage_handler
from backend.util.settings import Settings
from backend.util.virus_scanner import scan_content_safe
logger = logging.getLogger(__name__)
settings = Settings()
files_router = APIRouter()
# ============================================================================
# Models
# ============================================================================
class UploadFileResponse(BaseModel):
"""Response after uploading a file."""
file_uri: str = Field(description="URI to reference the uploaded file in agents")
file_name: str
size: int = Field(description="File size in bytes")
content_type: str
expires_in_hours: int
# ============================================================================
# Endpoints
# ============================================================================
def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
"""Create standardized file size error response."""
return HTTPException(
status_code=400,
detail=f"File size ({size_bytes} bytes) exceeds the maximum allowed size of {max_size_mb}MB",
)
@files_router.post(
path="/upload",
summary="Upload a file",
response_model=UploadFileResponse,
)
async def upload_file(
file: UploadFile = File(...),
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.UPLOAD_FILES)
),
provider: str = Query(
default="gcs", description="Storage provider (gcs, s3, azure)"
),
expiration_hours: int = Query(
default=24, ge=1, le=48, description="Hours until file expires (1-48)"
),
) -> UploadFileResponse:
"""
Upload a file to cloud storage for use with agents.
The returned `file_uri` can be used as input to agents that accept file inputs
(e.g., FileStoreBlock, AgentFileInputBlock).
Files are automatically scanned for viruses before storage.
"""
# Check file size limit
max_size_mb = settings.config.upload_file_size_limit_mb
max_size_bytes = max_size_mb * 1024 * 1024
# Try to get file size from headers first
if hasattr(file, "size") and file.size is not None and file.size > max_size_bytes:
raise _create_file_size_error(file.size, max_size_mb)
# Read file content
content = await file.read()
content_size = len(content)
# Double-check file size after reading
if content_size > max_size_bytes:
raise _create_file_size_error(content_size, max_size_mb)
# Extract file info
file_name = file.filename or "uploaded_file"
content_type = file.content_type or "application/octet-stream"
# Virus scan the content
await scan_content_safe(content, filename=file_name)
# Check if cloud storage is configured
cloud_storage = await get_cloud_storage_handler()
if not cloud_storage.config.gcs_bucket_name:
# Fallback to base64 data URI when GCS is not configured
base64_content = base64.b64encode(content).decode("utf-8")
data_uri = f"data:{content_type};base64,{base64_content}"
return UploadFileResponse(
file_uri=data_uri,
file_name=file_name,
size=content_size,
content_type=content_type,
expires_in_hours=expiration_hours,
)
# Store in cloud storage
storage_path = await cloud_storage.store_file(
content=content,
filename=file_name,
provider=provider,
expiration_hours=expiration_hours,
user_id=auth.user_id,
)
return UploadFileResponse(
file_uri=storage_path,
file_name=file_name,
size=content_size,
content_type=content_type,
expires_in_hours=expiration_hours,
)

View File

@@ -1,445 +0,0 @@
"""
V2 External API - Graphs Endpoints
Provides endpoints for managing agent graphs (CRUD operations).
"""
import logging
from fastapi import APIRouter, HTTPException, Query, Security
from prisma.enums import APIKeyPermission
from backend.api.external.middleware import require_permission
from backend.data import graph as graph_db
from backend.data.auth.base import APIAuthorizationInfo
from backend.integrations.webhooks.graph_lifecycle_hooks import (
on_graph_activate,
on_graph_deactivate,
)
from .common import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
from .models import (
CreateGraphRequest,
DeleteGraphResponse,
GraphDetails,
GraphLink,
GraphMeta,
GraphNode,
GraphSettings,
GraphsListResponse,
SetActiveVersionRequest,
)
logger = logging.getLogger(__name__)
graphs_router = APIRouter()
def _convert_graph_meta(graph: graph_db.GraphMeta) -> GraphMeta:
"""Convert internal GraphMeta to v2 API model."""
return GraphMeta(
id=graph.id,
version=graph.version,
is_active=graph.is_active,
name=graph.name,
description=graph.description,
created_at=graph.created_at,
input_schema=graph.input_schema,
output_schema=graph.output_schema,
)
def _convert_graph_details(graph: graph_db.GraphModel) -> GraphDetails:
"""Convert internal GraphModel to v2 API GraphDetails model."""
return GraphDetails(
id=graph.id,
version=graph.version,
is_active=graph.is_active,
name=graph.name,
description=graph.description,
created_at=graph.created_at,
input_schema=graph.input_schema,
output_schema=graph.output_schema,
nodes=[
GraphNode(
id=node.id,
block_id=node.block_id,
input_default=node.input_default,
metadata=node.metadata,
)
for node in graph.nodes
],
links=[
GraphLink(
id=link.id,
source_id=link.source_id,
sink_id=link.sink_id,
source_name=link.source_name,
sink_name=link.sink_name,
is_static=link.is_static,
)
for link in graph.links
],
credentials_input_schema=graph.credentials_input_schema,
)
@graphs_router.get(
path="",
summary="List user's graphs",
response_model=GraphsListResponse,
)
async def list_graphs(
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_GRAPH)
),
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = Query(
default=DEFAULT_PAGE_SIZE,
ge=1,
le=MAX_PAGE_SIZE,
description=f"Items per page (max {MAX_PAGE_SIZE})",
),
) -> GraphsListResponse:
"""
List all graphs owned by the authenticated user.
Returns a paginated list of graph metadata (not full graph details).
"""
graphs, pagination_info = await graph_db.list_graphs_paginated(
user_id=auth.user_id,
page=page,
page_size=page_size,
filter_by="active",
)
return GraphsListResponse(
graphs=[_convert_graph_meta(g) for g in graphs],
total_count=pagination_info.total_items,
page=pagination_info.current_page,
page_size=pagination_info.page_size,
total_pages=pagination_info.total_pages,
)
@graphs_router.post(
path="",
summary="Create a new graph",
response_model=GraphDetails,
)
async def create_graph(
create_graph_request: CreateGraphRequest,
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.WRITE_GRAPH)
),
) -> GraphDetails:
"""
Create a new agent graph.
The graph will be validated and assigned a new ID. It will automatically
be added to the user's library.
"""
# Import here to avoid circular imports
from backend.api.features.library import db as library_db
# Convert v2 API Graph model to internal Graph model
internal_graph = graph_db.Graph(
id=create_graph_request.graph.id or "",
version=create_graph_request.graph.version,
is_active=create_graph_request.graph.is_active,
name=create_graph_request.graph.name,
description=create_graph_request.graph.description,
nodes=[
graph_db.Node(
id=node.id,
block_id=node.block_id,
input_default=node.input_default,
metadata=node.metadata,
)
for node in create_graph_request.graph.nodes
],
links=[
graph_db.Link(
id=link.id,
source_id=link.source_id,
sink_id=link.sink_id,
source_name=link.source_name,
sink_name=link.sink_name,
is_static=link.is_static,
)
for link in create_graph_request.graph.links
],
)
graph = graph_db.make_graph_model(internal_graph, auth.user_id)
graph.reassign_ids(user_id=auth.user_id, reassign_graph_id=True)
graph.validate_graph(for_run=False)
await graph_db.create_graph(graph, user_id=auth.user_id)
await library_db.create_library_agent(graph, user_id=auth.user_id)
activated_graph = await on_graph_activate(graph, user_id=auth.user_id)
return _convert_graph_details(activated_graph)
@graphs_router.get(
path="/{graph_id}",
summary="Get graph details",
response_model=GraphDetails,
)
async def get_graph(
graph_id: str,
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_GRAPH)
),
version: int | None = Query(
default=None,
description="Specific version to retrieve (default: active version)",
),
) -> GraphDetails:
"""
Get detailed information about a specific graph.
By default returns the active version. Use the `version` query parameter
to retrieve a specific version.
"""
graph = await graph_db.get_graph(
graph_id,
version,
user_id=auth.user_id,
include_subgraphs=True,
)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return _convert_graph_details(graph)
@graphs_router.put(
path="/{graph_id}",
summary="Update graph (creates new version)",
response_model=GraphDetails,
)
async def update_graph(
graph_id: str,
graph_request: CreateGraphRequest,
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.WRITE_GRAPH)
),
) -> GraphDetails:
"""
Update a graph by creating a new version.
This does not modify existing versions - it creates a new version with
the provided content. The new version becomes the active version.
"""
# Import here to avoid circular imports
from backend.api.features.library import db as library_db
graph_data = graph_request.graph
if graph_data.id and graph_data.id != graph_id:
raise HTTPException(400, detail="Graph ID does not match ID in URI")
existing_versions = await graph_db.get_graph_all_versions(
graph_id, user_id=auth.user_id
)
if not existing_versions:
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
latest_version_number = max(g.version for g in existing_versions)
# Convert v2 API Graph model to internal Graph model
internal_graph = graph_db.Graph(
id=graph_id,
version=latest_version_number + 1,
is_active=graph_data.is_active,
name=graph_data.name,
description=graph_data.description,
nodes=[
graph_db.Node(
id=node.id,
block_id=node.block_id,
input_default=node.input_default,
metadata=node.metadata,
)
for node in graph_data.nodes
],
links=[
graph_db.Link(
id=link.id,
source_id=link.source_id,
sink_id=link.sink_id,
source_name=link.source_name,
sink_name=link.sink_name,
is_static=link.is_static,
)
for link in graph_data.links
],
)
current_active_version = next((v for v in existing_versions if v.is_active), None)
graph = graph_db.make_graph_model(internal_graph, auth.user_id)
graph.reassign_ids(user_id=auth.user_id, reassign_graph_id=False)
graph.validate_graph(for_run=False)
new_graph_version = await graph_db.create_graph(graph, user_id=auth.user_id)
if new_graph_version.is_active:
await library_db.update_agent_version_in_library(
auth.user_id, new_graph_version.id, new_graph_version.version
)
new_graph_version = await on_graph_activate(
new_graph_version, user_id=auth.user_id
)
await graph_db.set_graph_active_version(
graph_id=graph_id, version=new_graph_version.version, user_id=auth.user_id
)
if current_active_version:
await on_graph_deactivate(current_active_version, user_id=auth.user_id)
new_graph_version_with_subgraphs = await graph_db.get_graph(
graph_id,
new_graph_version.version,
user_id=auth.user_id,
include_subgraphs=True,
)
assert new_graph_version_with_subgraphs
return _convert_graph_details(new_graph_version_with_subgraphs)
@graphs_router.delete(
path="/{graph_id}",
summary="Delete graph permanently",
response_model=DeleteGraphResponse,
)
async def delete_graph(
graph_id: str,
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.WRITE_GRAPH)
),
) -> DeleteGraphResponse:
"""
Permanently delete a graph and all its versions.
This action cannot be undone. All associated executions will remain
but will reference a deleted graph.
"""
if active_version := await graph_db.get_graph(
graph_id=graph_id, version=None, user_id=auth.user_id
):
await on_graph_deactivate(active_version, user_id=auth.user_id)
version_count = await graph_db.delete_graph(graph_id, user_id=auth.user_id)
return DeleteGraphResponse(version_count=version_count)
@graphs_router.get(
path="/{graph_id}/versions",
summary="List all graph versions",
response_model=list[GraphDetails],
)
async def list_graph_versions(
graph_id: str,
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_GRAPH)
),
) -> list[GraphDetails]:
"""
Get all versions of a specific graph.
Returns a list of all versions, with the active version marked.
"""
graphs = await graph_db.get_graph_all_versions(graph_id, user_id=auth.user_id)
if not graphs:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return [_convert_graph_details(g) for g in graphs]
@graphs_router.put(
path="/{graph_id}/versions/active",
summary="Set active graph version",
)
async def set_active_version(
graph_id: str,
request_body: SetActiveVersionRequest,
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.WRITE_GRAPH)
),
) -> None:
"""
Set which version of a graph is the active version.
The active version is used when executing the graph without specifying
a version number.
"""
# Import here to avoid circular imports
from backend.api.features.library import db as library_db
new_active_version = request_body.active_graph_version
new_active_graph = await graph_db.get_graph(
graph_id, new_active_version, user_id=auth.user_id
)
if not new_active_graph:
raise HTTPException(404, f"Graph #{graph_id} v{new_active_version} not found")
current_active_graph = await graph_db.get_graph(
graph_id=graph_id,
version=None,
user_id=auth.user_id,
)
await on_graph_activate(new_active_graph, user_id=auth.user_id)
await graph_db.set_graph_active_version(
graph_id=graph_id,
version=new_active_version,
user_id=auth.user_id,
)
await library_db.update_agent_version_in_library(
auth.user_id, new_active_graph.id, new_active_graph.version
)
if current_active_graph and current_active_graph.version != new_active_version:
await on_graph_deactivate(current_active_graph, user_id=auth.user_id)
@graphs_router.patch(
path="/{graph_id}/settings",
summary="Update graph settings",
response_model=GraphSettings,
)
async def update_graph_settings(
graph_id: str,
settings: GraphSettings,
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.WRITE_GRAPH)
),
) -> GraphSettings:
"""
Update settings for a graph.
Currently supports:
- human_in_the_loop_safe_mode: Enable/disable safe mode for human-in-the-loop blocks
"""
# Import here to avoid circular imports
from backend.api.features.library import db as library_db
from backend.data.graph import GraphSettings as InternalGraphSettings
library_agent = await library_db.get_library_agent_by_graph_id(
graph_id=graph_id, user_id=auth.user_id
)
if not library_agent:
raise HTTPException(404, f"Graph #{graph_id} not found in user's library")
# Convert to internal model
internal_settings = InternalGraphSettings(
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode
)
updated_agent = await library_db.update_library_agent_settings(
user_id=auth.user_id,
agent_id=library_agent.id,
settings=internal_settings,
)
return GraphSettings(
human_in_the_loop_safe_mode=updated_agent.settings.human_in_the_loop_safe_mode
)

View File

@@ -1,271 +0,0 @@
"""
V2 External API - Integrations Endpoints
Provides access to user's integration credentials.
"""
import logging
from typing import Optional
from fastapi import APIRouter, HTTPException, Path, Security
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field
from backend.api.external.middleware import require_permission
from backend.api.features.library import db as library_db
from backend.data import graph as graph_db
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.model import Credentials, OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
logger = logging.getLogger(__name__)
integrations_router = APIRouter()
creds_manager = IntegrationCredentialsManager()
# ============================================================================
# Models
# ============================================================================
class Credential(BaseModel):
"""A user's credential for an integration."""
id: str
provider: str = Field(description="Integration provider name")
title: Optional[str] = Field(
default=None, description="User-assigned title for this credential"
)
scopes: list[str] = Field(default_factory=list, description="Granted scopes")
class CredentialsListResponse(BaseModel):
"""Response for listing credentials."""
credentials: list[Credential]
class CredentialRequirement(BaseModel):
"""A credential requirement for a graph or agent."""
provider: str = Field(description="Required provider name")
required_scopes: list[str] = Field(
default_factory=list, description="Required scopes"
)
matching_credentials: list[Credential] = Field(
default_factory=list,
description="User's credentials that match this requirement",
)
class CredentialRequirementsResponse(BaseModel):
"""Response for listing credential requirements."""
requirements: list[CredentialRequirement]
# ============================================================================
# Conversion Functions
# ============================================================================
def _convert_credential(cred: Credentials) -> Credential:
"""Convert internal credential to v2 API model."""
scopes: list[str] = []
if isinstance(cred, OAuth2Credentials):
scopes = cred.scopes or []
return Credential(
id=cred.id,
provider=cred.provider,
title=cred.title,
scopes=scopes,
)
# ============================================================================
# Endpoints
# ============================================================================
@integrations_router.get(
path="/credentials",
summary="List all credentials",
response_model=CredentialsListResponse,
)
async def list_credentials(
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_INTEGRATIONS)
),
) -> CredentialsListResponse:
"""
List all integration credentials for the authenticated user.
This returns all OAuth credentials the user has connected, across
all integration providers.
"""
credentials = await creds_manager.store.get_all_creds(auth.user_id)
return CredentialsListResponse(
credentials=[_convert_credential(c) for c in credentials]
)
@integrations_router.get(
path="/credentials/{provider}",
summary="List credentials by provider",
response_model=CredentialsListResponse,
)
async def list_credentials_by_provider(
provider: str = Path(description="Provider name (e.g., 'github', 'google')"),
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_INTEGRATIONS)
),
) -> CredentialsListResponse:
"""
List integration credentials for a specific provider.
"""
all_credentials = await creds_manager.store.get_all_creds(auth.user_id)
# Filter by provider
filtered = [c for c in all_credentials if c.provider.lower() == provider.lower()]
return CredentialsListResponse(
credentials=[_convert_credential(c) for c in filtered]
)
@integrations_router.get(
path="/graphs/{graph_id}/credentials",
summary="List credentials matching graph requirements",
response_model=CredentialRequirementsResponse,
)
async def list_graph_credential_requirements(
graph_id: str = Path(description="Graph ID"),
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_INTEGRATIONS)
),
) -> CredentialRequirementsResponse:
"""
List credential requirements for a graph and matching user credentials.
This helps identify which credentials the user needs to provide
when executing a graph.
"""
# Get the graph
graph = await graph_db.get_graph(
graph_id=graph_id,
version=None, # Active version
user_id=auth.user_id,
include_subgraphs=True,
)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found")
# Get the credentials input schema which contains provider requirements
creds_schema = graph.credentials_input_schema
all_credentials = await creds_manager.store.get_all_creds(auth.user_id)
requirements = []
for field_name, field_schema in creds_schema.get("properties", {}).items():
# Extract provider from schema
# The schema structure varies, but typically has provider info
providers = []
if "anyOf" in field_schema:
for option in field_schema["anyOf"]:
if "provider" in option:
providers.append(option["provider"])
elif "provider" in field_schema:
providers.append(field_schema["provider"])
for provider in providers:
# Find matching credentials
matching = [
_convert_credential(c)
for c in all_credentials
if c.provider.lower() == provider.lower()
]
requirements.append(
CredentialRequirement(
provider=provider,
required_scopes=[], # Would need to extract from schema
matching_credentials=matching,
)
)
return CredentialRequirementsResponse(requirements=requirements)
@integrations_router.get(
path="/library/{agent_id}/credentials",
summary="List credentials matching library agent requirements",
response_model=CredentialRequirementsResponse,
)
async def list_library_agent_credential_requirements(
agent_id: str = Path(description="Library agent ID"),
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_INTEGRATIONS)
),
) -> CredentialRequirementsResponse:
"""
List credential requirements for a library agent and matching user credentials.
This helps identify which credentials the user needs to provide
when executing an agent from their library.
"""
# Get the library agent
try:
library_agent = await library_db.get_library_agent(
id=agent_id,
user_id=auth.user_id,
)
except Exception:
raise HTTPException(status_code=404, detail=f"Agent #{agent_id} not found")
# Get the underlying graph
graph = await graph_db.get_graph(
graph_id=library_agent.graph_id,
version=library_agent.graph_version,
user_id=auth.user_id,
include_subgraphs=True,
)
if not graph:
raise HTTPException(
status_code=404,
detail=f"Graph for agent #{agent_id} not found",
)
# Get the credentials input schema
creds_schema = graph.credentials_input_schema
all_credentials = await creds_manager.store.get_all_creds(auth.user_id)
requirements = []
for field_name, field_schema in creds_schema.get("properties", {}).items():
# Extract provider from schema
providers = []
if "anyOf" in field_schema:
for option in field_schema["anyOf"]:
if "provider" in option:
providers.append(option["provider"])
elif "provider" in field_schema:
providers.append(field_schema["provider"])
for provider in providers:
# Find matching credentials
matching = [
_convert_credential(c)
for c in all_credentials
if c.provider.lower() == provider.lower()
]
requirements.append(
CredentialRequirement(
provider=provider,
required_scopes=[],
matching_credentials=matching,
)
)
return CredentialRequirementsResponse(requirements=requirements)

View File

@@ -1,247 +0,0 @@
"""
V2 External API - Library Endpoints
Provides access to the user's agent library and agent execution.
"""
import logging
from fastapi import APIRouter, HTTPException, Path, Query, Security
from prisma.enums import APIKeyPermission
from backend.api.external.middleware import require_permission
from backend.api.features.library import db as library_db
from backend.api.features.library import model as library_model
from backend.data import execution as execution_db
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.credit import get_user_credit_model
from backend.executor import utils as execution_utils
from .common import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
from .models import (
ExecuteAgentRequest,
LibraryAgent,
LibraryAgentsResponse,
Run,
RunsListResponse,
)
logger = logging.getLogger(__name__)
library_router = APIRouter()
# ============================================================================
# Conversion Functions
# ============================================================================
def _convert_library_agent(agent: library_model.LibraryAgent) -> LibraryAgent:
"""Convert internal LibraryAgent to v2 API model."""
return LibraryAgent(
id=agent.id,
graph_id=agent.graph_id,
graph_version=agent.graph_version,
name=agent.name,
description=agent.description,
is_favorite=agent.is_favorite,
can_access_graph=agent.can_access_graph,
is_latest_version=agent.is_latest_version,
image_url=agent.image_url,
creator_name=agent.creator_name,
input_schema=agent.input_schema,
output_schema=agent.output_schema,
created_at=agent.created_at,
updated_at=agent.updated_at,
)
def _convert_execution_to_run(exec: execution_db.GraphExecutionMeta) -> Run:
"""Convert internal execution to v2 API Run model."""
return Run(
id=exec.id,
graph_id=exec.graph_id,
graph_version=exec.graph_version,
status=exec.status.value,
started_at=exec.started_at,
ended_at=exec.ended_at,
inputs=exec.inputs,
cost=exec.stats.cost if exec.stats else 0,
duration=exec.stats.duration if exec.stats else 0,
node_count=exec.stats.node_exec_count if exec.stats else 0,
)
# ============================================================================
# Endpoints
# ============================================================================
@library_router.get(
path="/agents",
summary="List library agents",
response_model=LibraryAgentsResponse,
)
async def list_library_agents(
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_LIBRARY)
),
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = Query(
default=DEFAULT_PAGE_SIZE,
ge=1,
le=MAX_PAGE_SIZE,
description=f"Items per page (max {MAX_PAGE_SIZE})",
),
) -> LibraryAgentsResponse:
"""
List agents in the user's library.
The library contains agents the user has created or added from the marketplace.
"""
result = await library_db.list_library_agents(
user_id=auth.user_id,
page=page,
page_size=page_size,
)
return LibraryAgentsResponse(
agents=[_convert_library_agent(a) for a in result.agents],
total_count=result.pagination.total_items,
page=result.pagination.current_page,
page_size=result.pagination.page_size,
total_pages=result.pagination.total_pages,
)
@library_router.get(
path="/agents/favorites",
summary="List favorite agents",
response_model=LibraryAgentsResponse,
)
async def list_favorite_agents(
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_LIBRARY)
),
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = Query(
default=DEFAULT_PAGE_SIZE,
ge=1,
le=MAX_PAGE_SIZE,
description=f"Items per page (max {MAX_PAGE_SIZE})",
),
) -> LibraryAgentsResponse:
"""
List favorite agents in the user's library.
"""
result = await library_db.list_favorite_library_agents(
user_id=auth.user_id,
page=page,
page_size=page_size,
)
return LibraryAgentsResponse(
agents=[_convert_library_agent(a) for a in result.agents],
total_count=result.pagination.total_items,
page=result.pagination.current_page,
page_size=result.pagination.page_size,
total_pages=result.pagination.total_pages,
)
@library_router.post(
path="/agents/{agent_id}/runs",
summary="Execute an agent",
response_model=Run,
)
async def execute_agent(
request: ExecuteAgentRequest,
agent_id: str = Path(description="Library agent ID"),
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.RUN_AGENT)
),
) -> Run:
"""
Execute an agent from the library.
This creates a new run with the provided inputs. The run executes
asynchronously and you can poll the run status using GET /runs/{run_id}.
"""
# Check credit balance
user_credit_model = await get_user_credit_model(auth.user_id)
current_balance = await user_credit_model.get_credits(auth.user_id)
if current_balance <= 0:
raise HTTPException(
status_code=402,
detail="Insufficient balance to execute the agent. Please top up your account.",
)
# Get the library agent to find the graph ID and version
try:
library_agent = await library_db.get_library_agent(
id=agent_id,
user_id=auth.user_id,
)
except Exception:
raise HTTPException(status_code=404, detail=f"Agent #{agent_id} not found")
try:
result = await execution_utils.add_graph_execution(
graph_id=library_agent.graph_id,
user_id=auth.user_id,
inputs=request.inputs,
graph_version=library_agent.graph_version,
graph_credentials_inputs=request.credentials_inputs,
)
return _convert_execution_to_run(result)
except Exception as e:
logger.error(f"Failed to execute agent: {e}")
raise HTTPException(status_code=400, detail=str(e))
@library_router.get(
path="/agents/{agent_id}/runs",
summary="List runs for an agent",
response_model=RunsListResponse,
)
async def list_agent_runs(
agent_id: str = Path(description="Library agent ID"),
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_LIBRARY)
),
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = Query(
default=DEFAULT_PAGE_SIZE,
ge=1,
le=MAX_PAGE_SIZE,
description=f"Items per page (max {MAX_PAGE_SIZE})",
),
) -> RunsListResponse:
"""
List execution runs for a specific agent.
"""
# Get the library agent to find the graph ID
try:
library_agent = await library_db.get_library_agent(
id=agent_id,
user_id=auth.user_id,
)
except Exception:
raise HTTPException(status_code=404, detail=f"Agent #{agent_id} not found")
result = await execution_db.get_graph_executions_paginated(
graph_id=library_agent.graph_id,
user_id=auth.user_id,
page=page,
page_size=page_size,
)
return RunsListResponse(
runs=[_convert_execution_to_run(e) for e in result.executions],
total_count=result.pagination.total_items,
page=result.pagination.current_page,
page_size=result.pagination.page_size,
total_pages=result.pagination.total_pages,
)

View File

@@ -1,510 +0,0 @@
"""
V2 External API - Marketplace Endpoints
Provides access to the agent marketplace (store).
"""
import logging
import urllib.parse
from datetime import datetime
from typing import Literal, Optional
from fastapi import APIRouter, HTTPException, Path, Query, Security
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field
from backend.api.external.middleware import require_permission
from backend.api.features.store import cache as store_cache
from backend.api.features.store import db as store_db
from backend.api.features.store import model as store_model
from backend.data.auth.base import APIAuthorizationInfo
from .common import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
logger = logging.getLogger(__name__)
marketplace_router = APIRouter()
# ============================================================================
# Models
# ============================================================================
class MarketplaceAgent(BaseModel):
"""An agent available in the marketplace."""
slug: str
name: str
description: str
sub_heading: str
creator: str
creator_avatar: str
runs: int = Field(default=0, description="Number of times this agent has been run")
rating: float = Field(default=0.0, description="Average rating")
image_url: str = Field(default="")
class MarketplaceAgentDetails(BaseModel):
"""Detailed information about a marketplace agent."""
store_listing_version_id: str
slug: str
name: str
description: str
sub_heading: str
instructions: Optional[str] = None
creator: str
creator_avatar: str
categories: list[str] = Field(default_factory=list)
runs: int = Field(default=0)
rating: float = Field(default=0.0)
image_urls: list[str] = Field(default_factory=list)
video_url: str = Field(default="")
versions: list[str] = Field(default_factory=list, description="Available versions")
agent_graph_versions: list[str] = Field(default_factory=list)
agent_graph_id: str
last_updated: datetime
class MarketplaceAgentsResponse(BaseModel):
"""Response for listing marketplace agents."""
agents: list[MarketplaceAgent]
total_count: int
page: int
page_size: int
total_pages: int
class MarketplaceCreator(BaseModel):
"""A creator on the marketplace."""
name: str
username: str
description: str
avatar_url: str
num_agents: int
agent_rating: float
agent_runs: int
is_featured: bool = False
class MarketplaceCreatorDetails(BaseModel):
"""Detailed information about a marketplace creator."""
name: str
username: str
description: str
avatar_url: str
agent_rating: float
agent_runs: int
top_categories: list[str] = Field(default_factory=list)
links: list[str] = Field(default_factory=list)
class MarketplaceCreatorsResponse(BaseModel):
"""Response for listing marketplace creators."""
creators: list[MarketplaceCreator]
total_count: int
page: int
page_size: int
total_pages: int
class MarketplaceSubmission(BaseModel):
"""A marketplace submission."""
graph_id: str
graph_version: int
name: str
sub_heading: str
slug: str
description: str
instructions: Optional[str] = None
image_urls: list[str] = Field(default_factory=list)
date_submitted: datetime
status: str = Field(description="One of: DRAFT, PENDING, APPROVED, REJECTED")
runs: int = Field(default=0)
rating: float = Field(default=0.0)
store_listing_version_id: Optional[str] = None
version: Optional[int] = None
review_comments: Optional[str] = None
reviewed_at: Optional[datetime] = None
video_url: Optional[str] = None
categories: list[str] = Field(default_factory=list)
class SubmissionsListResponse(BaseModel):
"""Response for listing submissions."""
submissions: list[MarketplaceSubmission]
total_count: int
page: int
page_size: int
total_pages: int
class CreateSubmissionRequest(BaseModel):
"""Request to create a marketplace submission."""
graph_id: str = Field(description="ID of the graph to submit")
graph_version: int = Field(description="Version of the graph to submit")
name: str = Field(description="Display name for the agent")
slug: str = Field(description="URL-friendly identifier")
description: str = Field(description="Full description")
sub_heading: str = Field(description="Short tagline")
image_urls: list[str] = Field(default_factory=list)
video_url: Optional[str] = None
categories: list[str] = Field(default_factory=list)
# ============================================================================
# Conversion Functions
# ============================================================================
def _convert_store_agent(agent: store_model.StoreAgent) -> MarketplaceAgent:
"""Convert internal StoreAgent to v2 API model."""
return MarketplaceAgent(
slug=agent.slug,
name=agent.agent_name,
description=agent.description,
sub_heading=agent.sub_heading,
creator=agent.creator,
creator_avatar=agent.creator_avatar,
runs=agent.runs,
rating=agent.rating,
image_url=agent.agent_image,
)
def _convert_store_agent_details(
agent: store_model.StoreAgentDetails,
) -> MarketplaceAgentDetails:
"""Convert internal StoreAgentDetails to v2 API model."""
return MarketplaceAgentDetails(
store_listing_version_id=agent.store_listing_version_id,
slug=agent.slug,
name=agent.agent_name,
description=agent.description,
sub_heading=agent.sub_heading,
instructions=agent.instructions,
creator=agent.creator,
creator_avatar=agent.creator_avatar,
categories=agent.categories,
runs=agent.runs,
rating=agent.rating,
image_urls=agent.agent_image,
video_url=agent.agent_video,
versions=agent.versions,
agent_graph_versions=agent.agentGraphVersions,
agent_graph_id=agent.agentGraphId,
last_updated=agent.last_updated,
)
def _convert_creator(creator: store_model.Creator) -> MarketplaceCreator:
"""Convert internal Creator to v2 API model."""
return MarketplaceCreator(
name=creator.name,
username=creator.username,
description=creator.description,
avatar_url=creator.avatar_url,
num_agents=creator.num_agents,
agent_rating=creator.agent_rating,
agent_runs=creator.agent_runs,
is_featured=creator.is_featured,
)
def _convert_creator_details(
creator: store_model.CreatorDetails,
) -> MarketplaceCreatorDetails:
"""Convert internal CreatorDetails to v2 API model."""
return MarketplaceCreatorDetails(
name=creator.name,
username=creator.username,
description=creator.description,
avatar_url=creator.avatar_url,
agent_rating=creator.agent_rating,
agent_runs=creator.agent_runs,
top_categories=creator.top_categories,
links=creator.links,
)
def _convert_submission(sub: store_model.StoreSubmission) -> MarketplaceSubmission:
"""Convert internal StoreSubmission to v2 API model."""
return MarketplaceSubmission(
graph_id=sub.agent_id,
graph_version=sub.agent_version,
name=sub.name,
sub_heading=sub.sub_heading,
slug=sub.slug,
description=sub.description,
instructions=sub.instructions,
image_urls=sub.image_urls,
date_submitted=sub.date_submitted,
status=sub.status.value,
runs=sub.runs,
rating=sub.rating,
store_listing_version_id=sub.store_listing_version_id,
version=sub.version,
review_comments=sub.review_comments,
reviewed_at=sub.reviewed_at,
video_url=sub.video_url,
categories=sub.categories,
)
# ============================================================================
# Endpoints - Read (authenticated)
# ============================================================================
@marketplace_router.get(
path="/agents",
summary="List marketplace agents",
response_model=MarketplaceAgentsResponse,
)
async def list_agents(
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_STORE)
),
featured: bool = Query(default=False, description="Filter to featured agents only"),
creator: Optional[str] = Query(
default=None, description="Filter by creator username"
),
sorted_by: Optional[Literal["rating", "runs", "name", "updated_at"]] = Query(
default=None, description="Sort field"
),
search_query: Optional[str] = Query(default=None, description="Search query"),
category: Optional[str] = Query(default=None, description="Filter by category"),
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = Query(
default=DEFAULT_PAGE_SIZE,
ge=1,
le=MAX_PAGE_SIZE,
description=f"Items per page (max {MAX_PAGE_SIZE})",
),
) -> MarketplaceAgentsResponse:
"""
List agents available in the marketplace.
Supports filtering by featured status, creator, category, and search query.
Results can be sorted by rating, runs, name, or update time.
"""
result = await store_cache._get_cached_store_agents(
featured=featured,
creator=creator,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
return MarketplaceAgentsResponse(
agents=[_convert_store_agent(a) for a in result.agents],
total_count=result.pagination.total_items,
page=result.pagination.current_page,
page_size=result.pagination.page_size,
total_pages=result.pagination.total_pages,
)
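# Illustrative usage sketch (not part of this module): calling the listing endpoint
# with filters, sorting, and pagination. The base URL and the "X-API-Key" header
# name are assumptions; the path, query parameters, and response fields come from
# the endpoint above.
import httpx

resp = httpx.get(
    "https://platform.example.com/api/v2/marketplace/agents",
    headers={"X-API-Key": "<api-key-with-READ_STORE-permission>"},
    params={"featured": True, "sorted_by": "runs", "page": 1, "page_size": 20},
)
resp.raise_for_status()
listing = resp.json()  # keys: agents, total_count, page, page_size, total_pages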
@marketplace_router.get(
path="/agents/{username}/{agent_name}",
summary="Get agent details",
response_model=MarketplaceAgentDetails,
)
async def get_agent_details(
username: str = Path(description="Creator username"),
agent_name: str = Path(description="Agent slug/name"),
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_STORE)
),
) -> MarketplaceAgentDetails:
"""
Get detailed information about a specific marketplace agent.
"""
username = urllib.parse.unquote(username).lower()
agent_name = urllib.parse.unquote(agent_name).lower()
agent = await store_cache._get_cached_agent_details(
username=username, agent_name=agent_name
)
return _convert_store_agent_details(agent)
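# Illustrative usage sketch (not part of this module): the username and agent slug
# are URL-decoded and lowercased server-side, so an encoded mixed-case path resolves
# to the same listing. The base URL and "X-API-Key" header name are assumptions.
import urllib.parse

import httpx

creator, slug = "My Creator", "Email-Summarizer"
details = httpx.get(
    "https://platform.example.com/api/v2/marketplace/agents/"
    f"{urllib.parse.quote(creator)}/{urllib.parse.quote(slug)}",
    headers={"X-API-Key": "<api-key-with-READ_STORE-permission>"},
).json()  # MarketplaceAgentDetails for creator "my creator", agent "email-summarizer"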
@marketplace_router.get(
path="/creators",
summary="List marketplace creators",
response_model=MarketplaceCreatorsResponse,
)
async def list_creators(
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_STORE)
),
featured: bool = Query(
default=False, description="Filter to featured creators only"
),
search_query: Optional[str] = Query(default=None, description="Search query"),
sorted_by: Optional[Literal["agent_rating", "agent_runs", "num_agents"]] = Query(
default=None, description="Sort field"
),
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = Query(
default=DEFAULT_PAGE_SIZE,
ge=1,
le=MAX_PAGE_SIZE,
description=f"Items per page (max {MAX_PAGE_SIZE})",
),
) -> MarketplaceCreatorsResponse:
"""
List creators on the marketplace.
Supports filtering by featured status and search query.
Results can be sorted by rating, runs, or number of agents.
"""
result = await store_cache._get_cached_store_creators(
featured=featured,
search_query=search_query,
sorted_by=sorted_by,
page=page,
page_size=page_size,
)
return MarketplaceCreatorsResponse(
creators=[_convert_creator(c) for c in result.creators],
total_count=result.pagination.total_items,
page=result.pagination.current_page,
page_size=result.pagination.page_size,
total_pages=result.pagination.total_pages,
)
@marketplace_router.get(
path="/creators/{username}",
summary="Get creator details",
response_model=MarketplaceCreatorDetails,
)
async def get_creator_details(
username: str = Path(description="Creator username"),
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_STORE)
),
) -> MarketplaceCreatorDetails:
"""
Get detailed information about a specific marketplace creator.
"""
username = urllib.parse.unquote(username).lower()
creator = await store_cache._get_cached_creator_details(username=username)
return _convert_creator_details(creator)
# ============================================================================
# Endpoints - Submissions (CRUD)
# ============================================================================
@marketplace_router.get(
path="/submissions",
summary="List my submissions",
response_model=SubmissionsListResponse,
)
async def list_submissions(
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_STORE)
),
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = Query(
default=DEFAULT_PAGE_SIZE,
ge=1,
le=MAX_PAGE_SIZE,
description=f"Items per page (max {MAX_PAGE_SIZE})",
),
) -> SubmissionsListResponse:
"""
List your marketplace submissions.
Returns every submission you have created, in draft, pending, approved,
or rejected status.

"""
result = await store_db.get_store_submissions(
user_id=auth.user_id,
page=page,
page_size=page_size,
)
return SubmissionsListResponse(
submissions=[_convert_submission(s) for s in result.submissions],
total_count=result.pagination.total_items,
page=result.pagination.current_page,
page_size=result.pagination.page_size,
total_pages=result.pagination.total_pages,
)
@marketplace_router.post(
path="/submissions",
summary="Create a submission",
response_model=MarketplaceSubmission,
)
async def create_submission(
request: CreateSubmissionRequest,
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.WRITE_STORE)
),
) -> MarketplaceSubmission:
"""
Create a new marketplace submission.
This submits an agent for review to be published in the marketplace.
The submission will be in PENDING status until reviewed by the team.
"""
submission = await store_db.create_store_submission(
user_id=auth.user_id,
agent_id=request.graph_id,
agent_version=request.graph_version,
slug=request.slug,
name=request.name,
sub_heading=request.sub_heading,
description=request.description,
image_urls=request.image_urls,
video_url=request.video_url,
categories=request.categories,
)
return _convert_submission(submission)
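# Illustrative usage sketch (not part of this module): submitting a graph for
# marketplace review. The request body mirrors CreateSubmissionRequest above; the
# base URL, header name, and example values are assumptions. The created submission
# starts in PENDING status.
import httpx

payload = {
    "graph_id": "<graph-uuid>",
    "graph_version": 1,
    "name": "Email Summarizer",
    "slug": "email-summarizer",
    "description": "Summarizes unread email into a daily digest.",
    "sub_heading": "Your inbox, condensed",
    "image_urls": [],
    "categories": ["productivity"],
}
submission = httpx.post(
    "https://platform.example.com/api/v2/marketplace/submissions",
    headers={"X-API-Key": "<api-key-with-WRITE_STORE-permission>"},
    json=payload,
).json()  # MarketplaceSubmission with status, store_listing_version_id, etc.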
@marketplace_router.delete(
path="/submissions/{submission_id}",
summary="Delete a submission",
)
async def delete_submission(
submission_id: str = Path(description="Submission ID"),
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.WRITE_STORE)
),
) -> None:
"""
Delete a marketplace submission.
Only submissions in DRAFT status can be deleted.
"""
success = await store_db.delete_store_submission(
user_id=auth.user_id,
submission_id=submission_id,
)
if not success:
raise HTTPException(
status_code=404, detail=f"Submission #{submission_id} not found"
)

View File

@@ -1,552 +0,0 @@
"""
V2 External API - Request and Response Models
This module defines all request and response models for the v2 external API.
All models are self-contained and specific to the external API contract.
"""
from datetime import datetime
from typing import Any, Optional
from pydantic import BaseModel, Field
# ============================================================================
# Common/Shared Models
# ============================================================================
class PaginatedResponse(BaseModel):
"""Base class for paginated responses."""
total_count: int = Field(description="Total number of items across all pages")
page: int = Field(description="Current page number (1-indexed)")
page_size: int = Field(description="Number of items per page")
total_pages: int = Field(description="Total number of pages")
# ============================================================================
# Graph Models
# ============================================================================
class GraphLink(BaseModel):
"""A link between two nodes in a graph."""
id: str
source_id: str = Field(description="ID of the source node")
sink_id: str = Field(description="ID of the target node")
source_name: str = Field(description="Output pin name on source node")
sink_name: str = Field(description="Input pin name on target node")
is_static: bool = Field(
default=False, description="Whether this link provides static data"
)
class GraphNode(BaseModel):
"""A node in an agent graph."""
id: str
block_id: str = Field(description="ID of the block type")
input_default: dict[str, Any] = Field(
default_factory=dict, description="Default input values"
)
metadata: dict[str, Any] = Field(
default_factory=dict, description="Node metadata (e.g., position)"
)
class Graph(BaseModel):
"""Graph definition for creating or updating an agent."""
id: Optional[str] = Field(default=None, description="Graph ID (assigned by server)")
version: int = Field(default=1, description="Graph version")
is_active: bool = Field(default=True, description="Whether this version is active")
name: str = Field(description="Graph name")
description: str = Field(default="", description="Graph description")
nodes: list[GraphNode] = Field(default_factory=list, description="List of nodes")
links: list[GraphLink] = Field(
default_factory=list, description="Links between nodes"
)
class GraphMeta(BaseModel):
"""Graph metadata (summary information)."""
id: str
version: int
is_active: bool
name: str
description: str
created_at: datetime
input_schema: dict[str, Any] = Field(description="Input schema for the graph")
output_schema: dict[str, Any] = Field(description="Output schema for the graph")
class GraphDetails(GraphMeta):
"""Full graph details including nodes and links."""
nodes: list[GraphNode]
links: list[GraphLink]
credentials_input_schema: dict[str, Any] = Field(
description="Schema for required credentials"
)
class GraphSettings(BaseModel):
"""Settings for a graph."""
human_in_the_loop_safe_mode: Optional[bool] = Field(
default=None, description="Enable safe mode for human-in-the-loop blocks"
)
class CreateGraphRequest(BaseModel):
"""Request to create a new graph."""
graph: Graph = Field(description="The graph definition")
class SetActiveVersionRequest(BaseModel):
"""Request to set the active graph version."""
active_graph_version: int = Field(description="Version number to set as active")
class GraphsListResponse(PaginatedResponse):
"""Response for listing graphs."""
graphs: list[GraphMeta]
class DeleteGraphResponse(BaseModel):
"""Response for deleting a graph."""
version_count: int = Field(description="Number of versions deleted")
# ============================================================================
# Schedule Models
# ============================================================================
class Schedule(BaseModel):
"""An execution schedule for a graph."""
id: str
name: str
graph_id: str
graph_version: int
cron: str = Field(description="Cron expression for the schedule")
input_data: dict[str, Any] = Field(
default_factory=dict, description="Input data for scheduled executions"
)
is_enabled: bool = Field(default=True, description="Whether schedule is enabled")
next_run_time: Optional[datetime] = Field(
default=None, description="Next scheduled run time"
)
class CreateScheduleRequest(BaseModel):
"""Request to create a schedule."""
name: str = Field(description="Display name for the schedule")
cron: str = Field(description="Cron expression (e.g., '0 9 * * *' for 9am daily)")
input_data: dict[str, Any] = Field(
default_factory=dict, description="Input data for scheduled executions"
)
credentials_inputs: dict[str, Any] = Field(
default_factory=dict, description="Credentials for the schedule"
)
graph_version: Optional[int] = Field(
default=None, description="Graph version (default: active version)"
)
timezone: Optional[str] = Field(
default=None,
description="Timezone for schedule (e.g., 'America/New_York')",
)
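# Illustrative sketch: a CreateScheduleRequest for a weekday 09:00 run in New York
# time. The input_data keys are hypothetical; which inputs a graph accepts depends
# on its input schema.
weekday_digest = CreateScheduleRequest(
    name="Weekday morning digest",
    cron="0 9 * * 1-5",  # 09:00, Monday through Friday
    input_data={"topic": "AI news"},
    timezone="America/New_York",
)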
class SchedulesListResponse(PaginatedResponse):
"""Response for listing schedules."""
schedules: list[Schedule]
# ============================================================================
# Block Models
# ============================================================================
class BlockCost(BaseModel):
"""Cost information for a block."""
cost_type: str = Field(description="Type of cost (e.g., 'per_call', 'per_token')")
cost_filter: dict[str, Any] = Field(
default_factory=dict, description="Conditions for this cost"
)
cost_amount: int = Field(description="Cost amount in credits")
class Block(BaseModel):
"""A building block that can be used in graphs."""
id: str
name: str
description: str
categories: list[str] = Field(default_factory=list)
input_schema: dict[str, Any]
output_schema: dict[str, Any]
costs: list[BlockCost] = Field(default_factory=list)
class BlocksListResponse(BaseModel):
"""Response for listing blocks."""
blocks: list[Block]
# ============================================================================
# Marketplace Models
# ============================================================================
class MarketplaceAgent(BaseModel):
"""An agent available in the marketplace."""
slug: str
agent_name: str
agent_image: str
creator: str
creator_avatar: str
sub_heading: str
description: str
runs: int = Field(default=0, description="Number of times this agent has been run")
rating: float = Field(default=0.0, description="Average rating")
class MarketplaceAgentDetails(BaseModel):
"""Detailed information about a marketplace agent."""
store_listing_version_id: str
slug: str
agent_name: str
agent_video: str
agent_output_demo: str
agent_image: list[str]
creator: str
creator_avatar: str
sub_heading: str
description: str
instructions: Optional[str] = None
categories: list[str]
runs: int
rating: float
versions: list[str]
agent_graph_versions: list[str]
agent_graph_id: str
last_updated: datetime
recommended_schedule_cron: Optional[str] = None
class MarketplaceCreator(BaseModel):
"""A creator on the marketplace."""
name: str
username: str
description: str
avatar_url: str
num_agents: int
agent_rating: float
agent_runs: int
is_featured: bool = False
class MarketplaceAgentsResponse(PaginatedResponse):
"""Response for listing marketplace agents."""
agents: list[MarketplaceAgent]
class MarketplaceCreatorsResponse(PaginatedResponse):
"""Response for listing marketplace creators."""
creators: list[MarketplaceCreator]
# Submission models
class MarketplaceSubmission(BaseModel):
"""A marketplace submission."""
agent_id: str
agent_version: int
name: str
sub_heading: str
slug: str
description: str
instructions: Optional[str] = None
image_urls: list[str] = Field(default_factory=list)
date_submitted: datetime
status: str = Field(description="One of: DRAFT, PENDING, APPROVED, REJECTED")
runs: int
rating: float
store_listing_version_id: Optional[str] = None
version: Optional[int] = None
# Review fields
review_comments: Optional[str] = None
reviewed_at: Optional[datetime] = None
# Additional optional fields
video_url: Optional[str] = None
categories: list[str] = Field(default_factory=list)
class CreateSubmissionRequest(BaseModel):
"""Request to create a marketplace submission."""
agent_id: str = Field(description="ID of the graph to submit")
agent_version: int = Field(description="Version of the graph to submit")
name: str = Field(description="Display name for the agent")
slug: str = Field(description="URL-friendly identifier")
description: str = Field(description="Full description")
sub_heading: str = Field(description="Short tagline")
image_urls: list[str] = Field(default_factory=list)
video_url: Optional[str] = None
categories: list[str] = Field(default_factory=list)
class UpdateSubmissionRequest(BaseModel):
"""Request to update a marketplace submission."""
name: Optional[str] = None
description: Optional[str] = None
sub_heading: Optional[str] = None
image_urls: Optional[list[str]] = None
video_url: Optional[str] = None
categories: Optional[list[str]] = None
class SubmissionsListResponse(PaginatedResponse):
"""Response for listing submissions."""
submissions: list[MarketplaceSubmission]
# ============================================================================
# Library Models
# ============================================================================
class LibraryAgent(BaseModel):
"""An agent in the user's library."""
id: str
graph_id: str
graph_version: int
name: str
description: str
is_favorite: bool = False
can_access_graph: bool = False
is_latest_version: bool = False
image_url: Optional[str] = None
creator_name: str
input_schema: dict[str, Any] = Field(description="Input schema for the agent")
output_schema: dict[str, Any] = Field(description="Output schema for the agent")
created_at: datetime
updated_at: datetime
class LibraryAgentsResponse(PaginatedResponse):
"""Response for listing library agents."""
agents: list[LibraryAgent]
class ExecuteAgentRequest(BaseModel):
"""Request to execute an agent."""
inputs: dict[str, Any] = Field(
default_factory=dict, description="Input values for the agent"
)
credentials_inputs: dict[str, Any] = Field(
default_factory=dict, description="Credentials for the agent"
)
# ============================================================================
# Run Models
# ============================================================================
class Run(BaseModel):
"""An execution run."""
id: str
graph_id: str
graph_version: int
status: str = Field(
description="One of: INCOMPLETE, QUEUED, RUNNING, COMPLETED, TERMINATED, FAILED, REVIEW"
)
started_at: datetime
ended_at: Optional[datetime] = None
inputs: Optional[dict[str, Any]] = None
cost: int = Field(default=0, description="Cost in credits")
duration: float = Field(default=0, description="Duration in seconds")
node_count: int = Field(default=0, description="Number of nodes executed")
class RunDetails(Run):
"""Detailed information about a run including node executions."""
outputs: Optional[dict[str, list[Any]]] = None
node_executions: list[dict[str, Any]] = Field(
default_factory=list, description="Individual node execution results"
)
class RunsListResponse(PaginatedResponse):
"""Response for listing runs."""
runs: list[Run]
# ============================================================================
# Run Review Models (Human-in-the-loop)
# ============================================================================
class PendingReview(BaseModel):
"""A pending human-in-the-loop review."""
id: str # node_exec_id
run_id: str
graph_id: str
graph_version: int
payload: Any = Field(description="Data to be reviewed")
instructions: Optional[str] = Field(
default=None, description="Instructions for the reviewer"
)
editable: bool = Field(
default=True, description="Whether the reviewer can edit the data"
)
status: str = Field(description="One of: WAITING, APPROVED, REJECTED")
created_at: datetime
class PendingReviewsResponse(PaginatedResponse):
"""Response for listing pending reviews."""
reviews: list[PendingReview]
class ReviewDecision(BaseModel):
"""Decision for a single review item."""
node_exec_id: str = Field(description="Node execution ID (review ID)")
approved: bool = Field(description="Whether to approve the data")
edited_payload: Optional[Any] = Field(
default=None, description="Modified payload data (if editing)"
)
message: Optional[str] = Field(
default=None, description="Optional message from reviewer", max_length=2000
)
class SubmitReviewsRequest(BaseModel):
"""Request to submit review responses for all pending reviews of an execution."""
reviews: list[ReviewDecision] = Field(
description="All review decisions for the execution"
)
class SubmitReviewsResponse(BaseModel):
"""Response after submitting reviews."""
run_id: str
approved_count: int = Field(description="Number of reviews approved")
rejected_count: int = Field(description="Number of reviews rejected")
# ============================================================================
# Credit Models
# ============================================================================
class CreditBalance(BaseModel):
"""User's credit balance."""
balance: int = Field(description="Current credit balance")
class CreditTransaction(BaseModel):
"""A credit transaction."""
transaction_key: str
amount: int
transaction_type: str = Field(description="Transaction type")
transaction_time: datetime
running_balance: int
description: Optional[str] = None
class CreditTransactionsResponse(PaginatedResponse):
"""Response for listing credit transactions."""
transactions: list[CreditTransaction]
# ============================================================================
# Integration Models
# ============================================================================
class Credential(BaseModel):
"""A user's credential for an integration."""
id: str
provider: str = Field(description="Integration provider name")
title: Optional[str] = Field(
default=None, description="User-assigned title for this credential"
)
scopes: list[str] = Field(default_factory=list, description="Granted scopes")
class CredentialsListResponse(BaseModel):
"""Response for listing credentials."""
credentials: list[Credential]
class CredentialRequirement(BaseModel):
"""A credential requirement for a graph or agent."""
provider: str = Field(description="Required provider name")
required_scopes: list[str] = Field(
default_factory=list, description="Required scopes"
)
matching_credentials: list[Credential] = Field(
default_factory=list,
description="User's credentials that match this requirement",
)
class CredentialRequirementsResponse(BaseModel):
"""Response for listing credential requirements."""
requirements: list[CredentialRequirement]
# ============================================================================
# File Models
# ============================================================================
class UploadFileResponse(BaseModel):
"""Response after uploading a file."""
file_uri: str = Field(description="URI to reference the uploaded file")
file_name: str
size: int = Field(description="File size in bytes")
content_type: str
expires_in_hours: int

View File

@@ -1,35 +0,0 @@
"""
V2 External API Routes
This module defines the main v2 router that aggregates all v2 API endpoints.
"""
from fastapi import APIRouter
from .blocks import blocks_router
from .credits import credits_router
from .files import files_router
from .graphs import graphs_router
from .integrations import integrations_router
from .library import library_router
from .marketplace import marketplace_router
from .runs import runs_router
from .schedules import graph_schedules_router, schedules_router
v2_router = APIRouter()
# Include all sub-routers
v2_router.include_router(graphs_router, prefix="/graphs", tags=["graphs"])
v2_router.include_router(graph_schedules_router, prefix="/graphs", tags=["schedules"])
v2_router.include_router(schedules_router, prefix="/schedules", tags=["schedules"])
v2_router.include_router(blocks_router, prefix="/blocks", tags=["blocks"])
v2_router.include_router(
marketplace_router, prefix="/marketplace", tags=["marketplace"]
)
v2_router.include_router(library_router, prefix="/library", tags=["library"])
v2_router.include_router(runs_router, prefix="/runs", tags=["runs"])
v2_router.include_router(credits_router, prefix="/credits", tags=["credits"])
v2_router.include_router(
integrations_router, prefix="/integrations", tags=["integrations"]
)
v2_router.include_router(files_router, prefix="/files", tags=["files"])

View File

@@ -1,451 +0,0 @@
"""
V2 External API - Runs Endpoints
Provides access to execution runs and human-in-the-loop reviews.
"""
import logging
from datetime import datetime
from typing import Any, Optional
from fastapi import APIRouter, HTTPException, Path, Query, Security
from prisma.enums import APIKeyPermission, ReviewStatus
from pydantic import BaseModel, Field
from backend.api.external.middleware import require_permission
from backend.api.features.executions.review.model import (
PendingHumanReviewModel,
SafeJsonData,
)
from backend.data import execution as execution_db
from backend.data import human_review as review_db
from backend.data.auth.base import APIAuthorizationInfo
from backend.executor import utils as execution_utils
from .common import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
logger = logging.getLogger(__name__)
runs_router = APIRouter()
# ============================================================================
# Models
# ============================================================================
class Run(BaseModel):
"""An execution run."""
id: str
graph_id: str
graph_version: int
status: str = Field(
description="One of: INCOMPLETE, QUEUED, RUNNING, COMPLETED, TERMINATED, FAILED, REVIEW"
)
started_at: datetime
ended_at: Optional[datetime] = None
inputs: Optional[dict[str, Any]] = None
cost: int = Field(default=0, description="Cost in credits")
duration: float = Field(default=0, description="Duration in seconds")
node_count: int = Field(default=0, description="Number of nodes executed")
class RunDetails(Run):
"""Detailed information about a run including outputs and node executions."""
outputs: Optional[dict[str, list[Any]]] = None
node_executions: list[dict[str, Any]] = Field(
default_factory=list, description="Individual node execution results"
)
class RunsListResponse(BaseModel):
"""Response for listing runs."""
runs: list[Run]
total_count: int
page: int
page_size: int
total_pages: int
class PendingReview(BaseModel):
"""A pending human-in-the-loop review."""
id: str # node_exec_id
run_id: str
graph_id: str
graph_version: int
payload: SafeJsonData = Field(description="Data to be reviewed")
instructions: Optional[str] = Field(
default=None, description="Instructions for the reviewer"
)
editable: bool = Field(
default=True, description="Whether the reviewer can edit the data"
)
status: str = Field(description="One of: WAITING, APPROVED, REJECTED")
created_at: datetime
class PendingReviewsResponse(BaseModel):
"""Response for listing pending reviews."""
reviews: list[PendingReview]
total_count: int
page: int
page_size: int
total_pages: int
class ReviewDecision(BaseModel):
"""Decision for a single review item."""
node_exec_id: str = Field(description="Node execution ID (review ID)")
approved: bool = Field(description="Whether to approve the data")
edited_payload: Optional[SafeJsonData] = Field(
default=None, description="Modified payload data (if editing)"
)
message: Optional[str] = Field(
default=None, description="Optional message from reviewer", max_length=2000
)
class SubmitReviewsRequest(BaseModel):
"""Request to submit review responses for all pending reviews of an execution."""
reviews: list[ReviewDecision] = Field(
description="All review decisions for the execution"
)
class SubmitReviewsResponse(BaseModel):
"""Response after submitting reviews."""
run_id: str
approved_count: int = Field(description="Number of reviews approved")
rejected_count: int = Field(description="Number of reviews rejected")
# ============================================================================
# Conversion Functions
# ============================================================================
def _convert_execution_to_run(exec: execution_db.GraphExecutionMeta) -> Run:
"""Convert internal execution to v2 API Run model."""
return Run(
id=exec.id,
graph_id=exec.graph_id,
graph_version=exec.graph_version,
status=exec.status.value,
started_at=exec.started_at,
ended_at=exec.ended_at,
inputs=exec.inputs,
cost=exec.stats.cost if exec.stats else 0,
duration=exec.stats.duration if exec.stats else 0,
node_count=exec.stats.node_exec_count if exec.stats else 0,
)
def _convert_execution_to_run_details(
exec: execution_db.GraphExecutionWithNodes,
) -> RunDetails:
"""Convert internal execution with nodes to v2 API RunDetails model."""
return RunDetails(
id=exec.id,
graph_id=exec.graph_id,
graph_version=exec.graph_version,
status=exec.status.value,
started_at=exec.started_at,
ended_at=exec.ended_at,
inputs=exec.inputs,
outputs=exec.outputs,
cost=exec.stats.cost if exec.stats else 0,
duration=exec.stats.duration if exec.stats else 0,
node_count=exec.stats.node_exec_count if exec.stats else 0,
node_executions=[
{
"node_id": node.node_id,
"status": node.status.value,
"input_data": node.input_data,
"output_data": node.output_data,
"started_at": node.start_time,
"ended_at": node.end_time,
}
for node in exec.node_executions
],
)
def _convert_pending_review(review: PendingHumanReviewModel) -> PendingReview:
"""Convert internal PendingHumanReviewModel to v2 API PendingReview model."""
return PendingReview(
id=review.node_exec_id,
run_id=review.graph_exec_id,
graph_id=review.graph_id,
graph_version=review.graph_version,
payload=review.payload,
instructions=review.instructions,
editable=review.editable,
status=review.status.value,
created_at=review.created_at,
)
# ============================================================================
# Endpoints - Runs
# ============================================================================
@runs_router.get(
path="",
summary="List all runs",
response_model=RunsListResponse,
)
async def list_runs(
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_RUN)
),
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = Query(
default=DEFAULT_PAGE_SIZE,
ge=1,
le=MAX_PAGE_SIZE,
description=f"Items per page (max {MAX_PAGE_SIZE})",
),
) -> RunsListResponse:
"""
List all execution runs for the authenticated user.
Returns runs across all agents, sorted with the most recent first.
"""
result = await execution_db.get_graph_executions_paginated(
user_id=auth.user_id,
page=page,
page_size=page_size,
)
return RunsListResponse(
runs=[_convert_execution_to_run(e) for e in result.executions],
total_count=result.pagination.total_items,
page=result.pagination.current_page,
page_size=result.pagination.page_size,
total_pages=result.pagination.total_pages,
)
@runs_router.get(
path="/{run_id}",
summary="Get run details",
response_model=RunDetails,
)
async def get_run(
run_id: str = Path(description="Run ID"),
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_RUN)
),
) -> RunDetails:
"""
Get detailed information about a specific run.
Includes outputs and individual node execution results.
"""
result = await execution_db.get_graph_execution(
user_id=auth.user_id,
execution_id=run_id,
include_node_executions=True,
)
if not result:
raise HTTPException(status_code=404, detail=f"Run #{run_id} not found")
return _convert_execution_to_run_details(result)
@runs_router.post(
path="/{run_id}/stop",
summary="Stop a run",
)
async def stop_run(
run_id: str = Path(description="Run ID"),
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.WRITE_RUN)
),
) -> Run:
"""
Stop a running execution.
Only runs in QUEUED or RUNNING status can be stopped.
"""
# Verify the run exists and belongs to the user
exec = await execution_db.get_graph_execution(
user_id=auth.user_id,
execution_id=run_id,
)
if not exec:
raise HTTPException(status_code=404, detail=f"Run #{run_id} not found")
# Stop the execution
await execution_utils.stop_graph_execution(
graph_exec_id=run_id,
user_id=auth.user_id,
)
# Fetch updated execution
updated_exec = await execution_db.get_graph_execution(
user_id=auth.user_id,
execution_id=run_id,
)
if not updated_exec:
raise HTTPException(status_code=404, detail=f"Run #{run_id} not found")
return _convert_execution_to_run(updated_exec)
@runs_router.delete(
path="/{run_id}",
summary="Delete a run",
)
async def delete_run(
run_id: str = Path(description="Run ID"),
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.WRITE_RUN)
),
) -> None:
"""
Delete an execution run.
This marks the run as deleted. The data may still be retained for
some time for recovery purposes.
"""
await execution_db.delete_graph_execution(
graph_exec_id=run_id,
user_id=auth.user_id,
)
# ============================================================================
# Endpoints - Reviews (Human-in-the-loop)
# ============================================================================
@runs_router.get(
path="/reviews",
summary="List all pending reviews",
response_model=PendingReviewsResponse,
)
async def list_pending_reviews(
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_RUN_REVIEW)
),
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = Query(
default=DEFAULT_PAGE_SIZE,
ge=1,
le=MAX_PAGE_SIZE,
description=f"Items per page (max {MAX_PAGE_SIZE})",
),
) -> PendingReviewsResponse:
"""
List all pending human-in-the-loop reviews.
These are blocks that require human approval or input before the
agent can continue execution.
"""
reviews = await review_db.get_pending_reviews_for_user(
user_id=auth.user_id,
page=page,
page_size=page_size,
)
# Note: get_pending_reviews_for_user returns a plain list rather than a paginated
# result, so the pagination fields below are computed from the items returned for
# this page only.
total_count = len(reviews)
total_pages = max(1, (total_count + page_size - 1) // page_size)
return PendingReviewsResponse(
reviews=[_convert_pending_review(r) for r in reviews],
total_count=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
@runs_router.get(
path="/{run_id}/reviews",
summary="List reviews for a run",
response_model=list[PendingReview],
)
async def list_run_reviews(
run_id: str = Path(description="Run ID"),
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_RUN_REVIEW)
),
) -> list[PendingReview]:
"""
List all human-in-the-loop reviews for a specific run.
"""
reviews = await review_db.get_pending_reviews_for_execution(
graph_exec_id=run_id,
user_id=auth.user_id,
)
return [_convert_pending_review(r) for r in reviews]
@runs_router.post(
path="/{run_id}/reviews",
summary="Submit review responses for a run",
response_model=SubmitReviewsResponse,
)
async def submit_reviews(
request: SubmitReviewsRequest,
run_id: str = Path(description="Run ID"),
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.WRITE_RUN_REVIEW)
),
) -> SubmitReviewsResponse:
"""
Submit responses to all pending human-in-the-loop reviews for a run.
All pending reviews for the execution must be included. Approving
a review will allow the agent to continue; rejecting will terminate
execution at that point.
"""
# Build review decisions dict for process_all_reviews_for_execution
review_decisions: dict[
str, tuple[ReviewStatus, SafeJsonData | None, str | None]
] = {}
for decision in request.reviews:
status = ReviewStatus.APPROVED if decision.approved else ReviewStatus.REJECTED
review_decisions[decision.node_exec_id] = (
status,
decision.edited_payload,
decision.message,
)
try:
results = await review_db.process_all_reviews_for_execution(
user_id=auth.user_id,
review_decisions=review_decisions,
)
approved_count = sum(
1 for r in results.values() if r.status == ReviewStatus.APPROVED
)
rejected_count = sum(
1 for r in results.values() if r.status == ReviewStatus.REJECTED
)
return SubmitReviewsResponse(
run_id=run_id,
approved_count=approved_count,
rejected_count=rejected_count,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
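# Illustrative usage sketch (not part of this module): answering every pending
# review for a run in one call. The node_exec_id values come from the pending-review
# listings above; the base URL, header name, and payloads are assumptions.
import httpx

run_id = "<run-uuid>"
body = {
    "reviews": [
        {"node_exec_id": "<review-id-1>", "approved": True},
        {
            "node_exec_id": "<review-id-2>",
            "approved": True,
            "edited_payload": {"subject": "Corrected subject line"},
        },
        {"node_exec_id": "<review-id-3>", "approved": False, "message": "Wrong recipient"},
    ]
}
result = httpx.post(
    f"https://platform.example.com/api/v2/runs/{run_id}/reviews",
    headers={"X-API-Key": "<api-key-with-WRITE_RUN_REVIEW-permission>"},
    json=body,
).json()  # keys: run_id, approved_count, rejected_count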

View File

@@ -1,250 +0,0 @@
"""
V2 External API - Schedules Endpoints
Provides endpoints for managing execution schedules.
"""
import logging
from datetime import datetime
from typing import Any, Optional
from fastapi import APIRouter, HTTPException, Path, Query, Security
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field
from backend.api.external.middleware import require_permission
from backend.data import graph as graph_db
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.user import get_user_by_id
from backend.executor import scheduler
from backend.util.clients import get_scheduler_client
from backend.util.timezone_utils import get_user_timezone_or_utc
from .common import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
logger = logging.getLogger(__name__)
schedules_router = APIRouter()
# ============================================================================
# Request/Response Models
# ============================================================================
class Schedule(BaseModel):
"""An execution schedule for a graph."""
id: str
name: str
graph_id: str
graph_version: int
cron: str = Field(description="Cron expression for the schedule")
input_data: dict[str, Any] = Field(
default_factory=dict, description="Input data for scheduled executions"
)
next_run_time: Optional[datetime] = Field(
default=None, description="Next scheduled run time"
)
is_enabled: bool = Field(default=True, description="Whether schedule is enabled")
class SchedulesListResponse(BaseModel):
"""Response for listing schedules."""
schedules: list[Schedule]
total_count: int
page: int
page_size: int
total_pages: int
class CreateScheduleRequest(BaseModel):
"""Request to create a schedule."""
name: str = Field(description="Display name for the schedule")
cron: str = Field(description="Cron expression (e.g., '0 9 * * *' for 9am daily)")
input_data: dict[str, Any] = Field(
default_factory=dict, description="Input data for scheduled executions"
)
credentials_inputs: dict[str, Any] = Field(
default_factory=dict, description="Credentials for the schedule"
)
graph_version: Optional[int] = Field(
default=None, description="Graph version (default: active version)"
)
timezone: Optional[str] = Field(
default=None,
description=(
"Timezone for schedule (e.g., 'America/New_York'). "
"Defaults to user's timezone."
),
)
def _convert_schedule(job: scheduler.GraphExecutionJobInfo) -> Schedule:
"""Convert internal schedule job info to v2 API model."""
# Parse the ISO format string to datetime
next_run = datetime.fromisoformat(job.next_run_time) if job.next_run_time else None
return Schedule(
id=job.id,
name=job.name or "",
graph_id=job.graph_id,
graph_version=job.graph_version,
cron=job.cron,
input_data=job.input_data,
next_run_time=next_run,
is_enabled=True, # All returned schedules are enabled
)
# ============================================================================
# Endpoints
# ============================================================================
@schedules_router.get(
path="",
summary="List all user schedules",
response_model=SchedulesListResponse,
)
async def list_all_schedules(
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_SCHEDULE)
),
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = Query(
default=DEFAULT_PAGE_SIZE,
ge=1,
le=MAX_PAGE_SIZE,
description=f"Items per page (max {MAX_PAGE_SIZE})",
),
) -> SchedulesListResponse:
"""
List all schedules for the authenticated user across all graphs.
"""
schedules = await get_scheduler_client().get_execution_schedules(
user_id=auth.user_id
)
converted = [_convert_schedule(s) for s in schedules]
# Manual pagination (scheduler doesn't support pagination natively)
total_count = len(converted)
total_pages = (total_count + page_size - 1) // page_size if total_count > 0 else 1
start = (page - 1) * page_size
end = start + page_size
paginated = converted[start:end]
return SchedulesListResponse(
schedules=paginated,
total_count=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
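# Worked example of the manual pagination above (illustrative): with 23 schedules
# and page_size=10, total_pages is (23 + 10 - 1) // 10 == 3, and page=3 slices
# converted[20:30], returning the final 3 items.
assert (23 + 10 - 1) // 10 == 3
assert list(range(23))[(3 - 1) * 10 : (3 - 1) * 10 + 10] == [20, 21, 22]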
@schedules_router.delete(
path="/{schedule_id}",
summary="Delete a schedule",
)
async def delete_schedule(
schedule_id: str = Path(description="Schedule ID to delete"),
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.WRITE_SCHEDULE)
),
) -> None:
"""
Delete an execution schedule.
"""
try:
await get_scheduler_client().delete_schedule(
schedule_id=schedule_id,
user_id=auth.user_id,
)
except Exception as e:
if "not found" in str(e).lower():
raise HTTPException(
status_code=404, detail=f"Schedule #{schedule_id} not found"
)
raise
# ============================================================================
# Graph-specific Schedule Endpoints (nested under /graphs)
# These are included in the graphs router via include_router
# ============================================================================
graph_schedules_router = APIRouter()
@graph_schedules_router.get(
path="/{graph_id}/schedules",
summary="List schedules for a graph",
response_model=list[Schedule],
)
async def list_graph_schedules(
graph_id: str = Path(description="Graph ID"),
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_SCHEDULE)
),
) -> list[Schedule]:
"""
List all schedules for a specific graph.
"""
schedules = await get_scheduler_client().get_execution_schedules(
user_id=auth.user_id,
graph_id=graph_id,
)
return [_convert_schedule(s) for s in schedules]
@graph_schedules_router.post(
path="/{graph_id}/schedules",
summary="Create a schedule for a graph",
response_model=Schedule,
)
async def create_graph_schedule(
request: CreateScheduleRequest,
graph_id: str = Path(description="Graph ID"),
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.WRITE_SCHEDULE)
),
) -> Schedule:
"""
Create a new execution schedule for a graph.
The schedule will execute the graph at times matching the cron expression,
using the provided input data.
"""
graph = await graph_db.get_graph(
graph_id=graph_id,
version=request.graph_version,
user_id=auth.user_id,
)
if not graph:
raise HTTPException(
status_code=404,
detail=f"Graph #{graph_id} v{request.graph_version} not found.",
)
# Determine timezone
if request.timezone:
user_timezone = request.timezone
else:
user = await get_user_by_id(auth.user_id)
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
result = await get_scheduler_client().add_execution_schedule(
user_id=auth.user_id,
graph_id=graph_id,
graph_version=graph.version,
name=request.name,
cron=request.cron,
input_data=request.input_data,
input_credentials=request.credentials_inputs,
user_timezone=user_timezone,
)
return _convert_schedule(result)
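# Illustrative usage sketch (not part of this module): creating a daily 07:30
# schedule for a graph over HTTP. The base URL, header name, and graph ID are
# assumptions; omitting "timezone" falls back to the user's stored timezone (or UTC),
# and omitting "graph_version" targets the active version.
import httpx

graph_id = "<graph-uuid>"
schedule = httpx.post(
    f"https://platform.example.com/api/v2/graphs/{graph_id}/schedules",
    headers={"X-API-Key": "<api-key-with-WRITE_SCHEDULE-permission>"},
    json={
        "name": "Morning run",
        "cron": "30 7 * * *",
        "input_data": {},
        "credentials_inputs": {},
        "timezone": "Europe/Berlin",
    },
).json()  # Schedule: id, name, graph_id, graph_version, cron, next_run_time, ...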

View File

@@ -1,6 +1,7 @@
"""Configuration management for chat system."""
import os
from pathlib import Path
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
@@ -11,11 +12,7 @@ class ChatConfig(BaseSettings):
# OpenAI API Configuration
model: str = Field(
default="anthropic/claude-opus-4.5", description="Default model to use"
)
title_model: str = Field(
default="openai/gpt-4o-mini",
description="Model to use for generating session titles (should be fast/cheap)",
default="qwen/qwen3-235b-a22b-2507", description="Default model to use"
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
@@ -26,6 +23,12 @@ class ChatConfig(BaseSettings):
# Session TTL Configuration - 12 hours
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# System Prompt Configuration
system_prompt_path: str = Field(
default="prompts/chat_system.md",
description="Path to system prompt file relative to chat module",
)
# Streaming Configuration
max_context_messages: int = Field(
default=50, ge=1, le=200, description="Maximum context messages"
@@ -38,13 +41,6 @@ class ChatConfig(BaseSettings):
default=3, description="Maximum number of agent schedules"
)
# Langfuse Prompt Management Configuration
# Note: Langfuse credentials are in Settings().secrets (settings.py)
langfuse_prompt_name: str = Field(
default="CoPilot Prompt",
description="Name of the prompt in Langfuse to fetch",
)
@field_validator("api_key", mode="before")
@classmethod
def get_api_key(cls, v):
@@ -76,11 +72,43 @@ class ChatConfig(BaseSettings):
v = "https://openrouter.ai/api/v1"
return v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",
"onboarding": "prompts/onboarding_system.md",
}
def get_system_prompt(self, **template_vars) -> str:
"""Load and render the system prompt from file.
Args:
**template_vars: Variables to substitute in the template
Returns:
Rendered system prompt string
"""
# Get the path relative to this module
module_dir = Path(__file__).parent
prompt_path = module_dir / self.system_prompt_path
# Check for .j2 extension first (Jinja2 template)
j2_path = Path(str(prompt_path) + ".j2")
if j2_path.exists():
try:
from jinja2 import Template
template = Template(j2_path.read_text())
return template.render(**template_vars)
except ImportError:
# Jinja2 not installed, fall back to reading as plain text
return j2_path.read_text()
# Check for markdown file
if prompt_path.exists():
content = prompt_path.read_text()
# Plain prompt file (no .j2 template found): do simple {placeholder} substitution
for key, value in template_vars.items():
placeholder = f"{{{key}}}"
content = content.replace(placeholder, str(value))
return content
raise FileNotFoundError(f"System prompt file not found: {prompt_path}")
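# Illustrative usage sketch for get_system_prompt: a .j2 Jinja2 template beside the
# chat module takes precedence; otherwise the plain markdown file gets simple
# {placeholder} substitution. The "user_name" variable is an assumption about the
# prompt file's placeholders.
def _example_prompts() -> tuple[str, str]:
    config = ChatConfig()
    default_prompt = config.get_system_prompt(user_name="Ada")
    onboarding = ChatConfig(system_prompt_path=config.PROMPT_PATHS["onboarding"])
    return default_prompt, onboarding.get_system_prompt(user_name="Ada")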
class Config:
"""Pydantic config."""

View File

@@ -1,249 +0,0 @@
"""Database operations for chat sessions."""
import asyncio
import logging
from datetime import UTC, datetime
from typing import Any, cast
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from prisma.types import (
ChatMessageCreateInput,
ChatSessionCreateInput,
ChatSessionUpdateInput,
ChatSessionWhereInput,
)
from backend.data.db import transaction
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
"""Get a chat session by ID from the database."""
session = await PrismaChatSession.prisma().find_unique(
where={"id": session_id},
include={"Messages": True},
)
if session and session.Messages:
# Sort messages by sequence in Python - Prisma Python client doesn't support
# order_by in include clauses (unlike Prisma JS), so we sort after fetching
session.Messages.sort(key=lambda m: m.sequence)
return session
async def create_chat_session(
session_id: str,
user_id: str,
) -> PrismaChatSession:
"""Create a new chat session in the database."""
data = ChatSessionCreateInput(
id=session_id,
userId=user_id,
credentials=SafeJson({}),
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
)
return await PrismaChatSession.prisma().create(
data=data,
include={"Messages": True},
)
async def update_chat_session(
session_id: str,
credentials: dict[str, Any] | None = None,
successful_agent_runs: dict[str, Any] | None = None,
successful_agent_schedules: dict[str, Any] | None = None,
total_prompt_tokens: int | None = None,
total_completion_tokens: int | None = None,
title: str | None = None,
) -> PrismaChatSession | None:
"""Update a chat session's metadata."""
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
if credentials is not None:
data["credentials"] = SafeJson(credentials)
if successful_agent_runs is not None:
data["successfulAgentRuns"] = SafeJson(successful_agent_runs)
if successful_agent_schedules is not None:
data["successfulAgentSchedules"] = SafeJson(successful_agent_schedules)
if total_prompt_tokens is not None:
data["totalPromptTokens"] = total_prompt_tokens
if total_completion_tokens is not None:
data["totalCompletionTokens"] = total_completion_tokens
if title is not None:
data["title"] = title
session = await PrismaChatSession.prisma().update(
where={"id": session_id},
data=data,
include={"Messages": True},
)
if session and session.Messages:
# Sort in Python - Prisma Python doesn't support order_by in include clauses
session.Messages.sort(key=lambda m: m.sequence)
return session
async def add_chat_message(
session_id: str,
role: str,
sequence: int,
content: str | None = None,
name: str | None = None,
tool_call_id: str | None = None,
refusal: str | None = None,
tool_calls: list[dict[str, Any]] | None = None,
function_call: dict[str, Any] | None = None,
) -> PrismaChatMessage:
"""Add a message to a chat session."""
# Build input dict dynamically rather than using ChatMessageCreateInput directly
# because Prisma's TypedDict validation rejects optional fields set to None.
# We only include fields that have values, then cast at the end.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": role,
"sequence": sequence,
}
# Add optional string fields
if content is not None:
data["content"] = content
if name is not None:
data["name"] = name
if tool_call_id is not None:
data["toolCallId"] = tool_call_id
if refusal is not None:
data["refusal"] = refusal
# Add optional JSON fields only when they have values
if tool_calls is not None:
data["toolCalls"] = SafeJson(tool_calls)
if function_call is not None:
data["functionCall"] = SafeJson(function_call)
# Run message create and session timestamp update in parallel for lower latency
_, message = await asyncio.gather(
PrismaChatSession.prisma().update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
),
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
)
return message
async def add_chat_messages_batch(
session_id: str,
messages: list[dict[str, Any]],
start_sequence: int,
) -> list[PrismaChatMessage]:
"""Add multiple messages to a chat session in a batch.
Uses a transaction for atomicity - if any message creation fails,
the entire batch is rolled back.
"""
if not messages:
return []
created_messages = []
async with transaction() as tx:
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields
# set to None. We only include fields that have values, then cast.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": start_sequence + i,
}
# Add optional string fields
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
created = await PrismaChatMessage.prisma(tx).create(
data=cast(ChatMessageCreateInput, data)
)
created_messages.append(created)
# Update session's updatedAt timestamp within the same transaction.
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
# separately via update_chat_session() after streaming completes.
await PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
)
return created_messages
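# Illustrative usage sketch: persisting an assistant turn and its tool result
# atomically, numbering them after the messages already stored for the session.
# The session ID, tool-call payload, and contents are placeholders.
async def _example_store_turn() -> None:
    turn = [
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": "call_1",
                    "type": "function",
                    "function": {"name": "run_agent", "arguments": "{}"},
                }
            ],
        },
        {"role": "tool", "tool_call_id": "call_1", "content": "Agent run started."},
    ]
    next_sequence = await get_chat_session_message_count("session-123")
    await add_chat_messages_batch("session-123", turn, start_sequence=next_sequence)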
async def get_user_chat_sessions(
user_id: str,
limit: int = 50,
offset: int = 0,
) -> list[PrismaChatSession]:
"""Get chat sessions for a user, ordered by most recent."""
return await PrismaChatSession.prisma().find_many(
where={"userId": user_id},
order={"updatedAt": "desc"},
take=limit,
skip=offset,
)
async def get_user_session_count(user_id: str) -> int:
"""Get the total number of chat sessions for a user."""
return await PrismaChatSession.prisma().count(where={"userId": user_id})
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
"""Delete a chat session and all its messages.
Args:
session_id: The session ID to delete.
user_id: If provided, validates that the session belongs to this user
before deletion. This prevents unauthorized deletion of other
users' sessions.
Returns:
True if deleted successfully, False otherwise.
"""
try:
# Build typed where clause with optional user_id validation
where_clause: ChatSessionWhereInput = {"id": session_id}
if user_id is not None:
where_clause["userId"] = user_id
result = await PrismaChatSession.prisma().delete_many(where=where_clause)
if result == 0:
logger.warning(
f"No session deleted for {session_id} "
f"(user_id validation: {user_id is not None})"
)
return False
return True
except Exception as e:
logger.error(f"Failed to delete chat session {session_id}: {e}")
return False
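# Illustrative usage sketch: scoping deletion to the requesting user so one user's
# API call cannot remove another user's session. IDs are placeholders.
async def _example_delete_own_session(session_id: str, requesting_user_id: str) -> bool:
    return await delete_chat_session(session_id, user_id=requesting_user_id)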
async def get_chat_session_message_count(session_id: str) -> int:
"""Get the number of messages in a chat session."""
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
return count

View File

@@ -1,9 +1,6 @@
import asyncio
import logging
import uuid
from datetime import UTC, datetime
from typing import Any
from weakref import WeakValueDictionary
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
@@ -19,63 +16,17 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
ChatCompletionMessageToolCallParam,
Function,
)
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from pydantic import BaseModel
from backend.data.redis_client import get_redis_async
from backend.util import json
from backend.util.exceptions import DatabaseError, RedisError
from backend.util.exceptions import RedisError
from . import db as chat_db
from .config import ChatConfig
logger = logging.getLogger(__name__)
config = ChatConfig()
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
"""Parse a JSON field that may be stored as string or already parsed."""
if value is None:
return default
if isinstance(value, str):
return json.loads(value)
return value
# Redis cache key prefix for chat sessions
CHAT_SESSION_CACHE_PREFIX = "chat:session:"
def _get_session_cache_key(session_id: str) -> str:
"""Get the Redis cache key for a chat session."""
return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
# Session-level locks to prevent race conditions during concurrent upserts.
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()
async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks.
"""
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock
class ChatMessage(BaseModel):
role: str
content: str | None = None
@@ -94,8 +45,7 @@ class Usage(BaseModel):
class ChatSession(BaseModel):
session_id: str
user_id: str
title: str | None = None
user_id: str | None
messages: list[ChatMessage]
usage: list[Usage]
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
@@ -105,11 +55,10 @@ class ChatSession(BaseModel):
successful_agent_schedules: dict[str, int] = {}
@staticmethod
def new(user_id: str) -> "ChatSession":
def new(user_id: str | None) -> "ChatSession":
return ChatSession(
session_id=str(uuid.uuid4()),
user_id=user_id,
title=None,
messages=[],
usage=[],
credentials={},
@@ -117,61 +66,6 @@ class ChatSession(BaseModel):
updated_at=datetime.now(UTC),
)
@staticmethod
def from_db(
prisma_session: PrismaChatSession,
prisma_messages: list[PrismaChatMessage] | None = None,
) -> "ChatSession":
"""Convert Prisma models to Pydantic ChatSession."""
messages = []
if prisma_messages:
for msg in prisma_messages:
messages.append(
ChatMessage(
role=msg.role,
content=msg.content,
name=msg.name,
tool_call_id=msg.toolCallId,
refusal=msg.refusal,
tool_calls=_parse_json_field(msg.toolCalls),
function_call=_parse_json_field(msg.functionCall),
)
)
# Parse JSON fields from Prisma
credentials = _parse_json_field(prisma_session.credentials, default={})
successful_agent_runs = _parse_json_field(
prisma_session.successfulAgentRuns, default={}
)
successful_agent_schedules = _parse_json_field(
prisma_session.successfulAgentSchedules, default={}
)
# Calculate usage from token counts
usage = []
if prisma_session.totalPromptTokens or prisma_session.totalCompletionTokens:
usage.append(
Usage(
prompt_tokens=prisma_session.totalPromptTokens or 0,
completion_tokens=prisma_session.totalCompletionTokens or 0,
total_tokens=(prisma_session.totalPromptTokens or 0)
+ (prisma_session.totalCompletionTokens or 0),
)
)
return ChatSession(
session_id=prisma_session.id,
user_id=prisma_session.userId,
title=prisma_session.title,
messages=messages,
usage=usage,
credentials=credentials,
started_at=prisma_session.createdAt,
updated_at=prisma_session.updatedAt,
successful_agent_runs=successful_agent_runs,
successful_agent_schedules=successful_agent_schedules,
)
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
messages = []
for message in self.messages:
@@ -261,337 +155,50 @@ class ChatSession(BaseModel):
return messages
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
"""Get a chat session from Redis cache."""
redis_key = _get_session_cache_key(session_id)
async def get_chat_session(
session_id: str,
user_id: str | None,
) -> ChatSession | None:
"""Get a chat session by ID."""
redis_key = f"chat:session:{session_id}"
async_redis = await get_redis_async()
raw_session: bytes | None = await async_redis.get(redis_key)
if raw_session is None:
logger.warning(f"Session {session_id} not found in Redis")
return None
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
except Exception as e:
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
raise RedisError(f"Corrupted session data for {session_id}") from e
async def _cache_session(session: ChatSession) -> None:
"""Cache a chat session in Redis."""
redis_key = _get_session_cache_key(session.session_id)
async_redis = await get_redis_async()
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
prisma_session = await chat_db.get_chat_session(session_id)
if not prisma_session:
return None
messages = prisma_session.Messages
logger.info(
f"Loading session {session_id} from DB: "
f"has_messages={messages is not None}, "
f"message_count={len(messages) if messages else 0}, "
f"roles={[m.role for m in messages] if messages else []}"
)
return ChatSession.from_db(prisma_session, messages)
async def _save_session_to_db(
session: ChatSession, existing_message_count: int
) -> None:
"""Save or update a chat session in the database."""
# Check if session exists in DB
existing = await chat_db.get_chat_session(session.session_id)
if not existing:
# Create new session
await chat_db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
)
existing_message_count = 0
# Calculate total tokens from usage
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
# Update session metadata
await chat_db.update_chat_session(
session_id=session.session_id,
credentials=session.credentials,
successful_agent_runs=session.successful_agent_runs,
successful_agent_schedules=session.successful_agent_schedules,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
# Add new messages (only those after existing count)
new_messages = session.messages[existing_message_count:]
if new_messages:
messages_data = []
for msg in new_messages:
messages_data.append(
{
"role": msg.role,
"content": msg.content,
"name": msg.name,
"tool_call_id": msg.tool_call_id,
"refusal": msg.refusal,
"tool_calls": msg.tool_calls,
"function_call": msg.function_call,
}
)
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await chat_db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,
start_sequence=existing_message_count,
)
async def get_chat_session(
session_id: str,
user_id: str | None = None,
) -> ChatSession | None:
"""Get a chat session by ID.
Checks Redis cache first, falls back to database if not found.
Caches database results back to Redis.
Args:
session_id: The session ID to fetch.
user_id: If provided, validates that the session belongs to this user.
If None, ownership is not validated (admin/system access).
"""
# Try cache first
try:
session = await _get_session_from_cache(session_id)
if session:
# Verify user ownership if user_id was provided for validation
if user_id is not None and session.user_id != user_id:
logger.warning(
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
)
return None
return session
except RedisError:
logger.warning(f"Cache error for session {session_id}, trying database")
except Exception as e:
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
# Fall back to database
logger.info(f"Session {session_id} not in cache, checking database")
session = await _get_session_from_db(session_id)
if session is None:
logger.warning(f"Session {session_id} not found in cache or database")
return None
# Verify user ownership if user_id was provided for validation
if user_id is not None and session.user_id != user_id:
if session.user_id is not None and session.user_id != user_id:
logger.warning(
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
)
return None
# Cache the session from DB
try:
await _cache_session(session)
logger.info(f"Cached session {session_id} from database")
except Exception as e:
logger.warning(f"Failed to cache session {session_id}: {e}")
return session
async def upsert_chat_session(
session: ChatSession,
) -> ChatSession:
"""Update a chat session in both cache and database.
"""Update a chat session with the given messages."""
Uses session-level locking to prevent race conditions when concurrent
operations (e.g., background title update and main stream handler)
attempt to upsert the same session simultaneously.
redis_key = f"chat:session:{session.session_id}"
Raises:
DatabaseError: If the database write fails. The cache is still updated
as a best-effort optimization, but the error is propagated to ensure
callers are aware of the persistence failure.
RedisError: If the cache write fails (after successful DB write).
"""
# Acquire session-specific lock to prevent concurrent upserts
lock = await _get_session_lock(session.session_id)
async_redis = await get_redis_async()
resp = await async_redis.setex(
redis_key, config.session_ttl, session.model_dump_json()
)
async with lock:
# Get existing message count from DB for incremental saves
existing_message_count = await chat_db.get_chat_session_message_count(
session.session_id
if not resp:
raise RedisError(
f"Failed to persist chat session {session.session_id} to Redis: {resp}"
)
db_error: Exception | None = None
# Save to database (primary storage)
try:
await _save_session_to_db(session, existing_message_count)
except Exception as e:
logger.error(
f"Failed to save session {session.session_id} to database: {e}"
)
db_error = e
# Save to cache (best-effort, even if DB failed)
try:
await _cache_session(session)
except Exception as e:
# If DB succeeded but cache failed, raise cache error
if db_error is None:
raise RedisError(
f"Failed to persist chat session {session.session_id} to Redis: {e}"
) from e
# If both failed, log cache error but raise DB error (more critical)
logger.warning(
f"Cache write also failed for session {session.session_id}: {e}"
)
# Propagate DB error after attempting cache (prevents data loss)
if db_error is not None:
raise DatabaseError(
f"Failed to persist chat session {session.session_id} to database"
) from db_error
return session
async def create_chat_session(user_id: str) -> ChatSession:
"""Create a new chat session and persist it.
Raises:
DatabaseError: If the database write fails. We fail fast to ensure
callers never receive a non-persisted session that only exists
in cache (which would be lost when the cache expires).
"""
session = ChatSession.new(user_id)
# Create in database first - fail fast if this fails
try:
await chat_db.create_chat_session(
session_id=session.session_id,
user_id=user_id,
)
except Exception as e:
logger.error(f"Failed to create session {session.session_id} in database: {e}")
raise DatabaseError(
f"Failed to create chat session {session.session_id} in database"
) from e
# Cache the session (best-effort optimization, DB is source of truth)
try:
await _cache_session(session)
except Exception as e:
logger.warning(f"Failed to cache new session {session.session_id}: {e}")
return session
async def get_user_sessions(
user_id: str,
limit: int = 50,
offset: int = 0,
) -> tuple[list[ChatSession], int]:
"""Get chat sessions for a user from the database with total count.
Returns:
A tuple of (sessions, total_count) where total_count is the overall
number of sessions for the user (not just the current page).
"""
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
total_count = await chat_db.get_user_session_count(user_id)
sessions = []
for prisma_session in prisma_sessions:
# Convert without messages for listing (lighter weight)
sessions.append(ChatSession.from_db(prisma_session, None))
return sessions, total_count
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
"""Delete a chat session from both cache and database.
Args:
session_id: The session ID to delete.
user_id: If provided, validates that the session belongs to this user
before deletion. This prevents unauthorized deletion.
Returns:
True if deleted successfully, False otherwise.
"""
# Delete from database first (with optional user_id validation)
# This confirms ownership before invalidating cache
deleted = await chat_db.delete_chat_session(session_id, user_id)
if not deleted:
return False
# Only invalidate cache and clean up lock after DB confirms deletion
try:
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
except Exception as e:
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
# Clean up session lock (belt-and-suspenders with WeakValueDictionary)
async with _session_locks_mutex:
_session_locks.pop(session_id, None)
return True
async def update_session_title(session_id: str, title: str) -> bool:
"""Update only the title of a chat session.
This is a lightweight operation that doesn't touch messages, avoiding
race conditions with concurrent message updates. Use this for background
title generation instead of upsert_chat_session.
Args:
session_id: The session ID to update.
title: The new title to set.
Returns:
True if updated successfully, False otherwise.
"""
try:
result = await chat_db.update_chat_session(session_id=session_id, title=title)
if result is None:
logger.warning(f"Session {session_id} not found for title update")
return False
# Invalidate cache so next fetch gets updated title
try:
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
except Exception as e:
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
return True
except Exception as e:
logger.error(f"Failed to update title for session {session_id}: {e}")
return False
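
Taken together, a short lifecycle sketch of the DB-backed helpers in this hunk (illustrative: the import path and the caller's event loop are assumed, and error handling is omitted):

```python
# Illustrative session lifecycle using the module-level helpers above.
from backend.api.features.chat.model import (
    ChatMessage,
    create_chat_session,
    delete_chat_session,
    get_chat_session,
    update_session_title,
    upsert_chat_session,
)


async def demo(user_id: str) -> None:
    # The database write is the source of truth; Redis is a best-effort cache.
    session = await create_chat_session(user_id)

    session.messages.append(ChatMessage(role="user", content="Hello"))
    await upsert_chat_session(session)  # incremental message save + cache refresh

    # Read path: cache first, DB fallback, ownership check via user_id.
    fetched = await get_chat_session(session.session_id, user_id)
    assert fetched is not None and len(fetched.messages) == 1

    # Title updates bypass the message path and just invalidate the cache.
    await update_session_title(session.session_id, "Greeting")

    await delete_chat_session(session.session_id, user_id)
```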

View File

@@ -43,9 +43,9 @@ async def test_chatsession_serialization_deserialization():
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage(setup_test_user, test_user_id):
async def test_chatsession_redis_storage():
s = ChatSession.new(user_id=test_user_id)
s = ChatSession.new(user_id=None)
s.messages = messages
s = await upsert_chat_session(s)
@@ -59,61 +59,12 @@ async def test_chatsession_redis_storage(setup_test_user, test_user_id):
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage_user_id_mismatch(
setup_test_user, test_user_id
):
async def test_chatsession_redis_storage_user_id_mismatch():
s = ChatSession.new(user_id=test_user_id)
s = ChatSession.new(user_id="abc123")
s.messages = messages
s = await upsert_chat_session(s)
s2 = await get_chat_session(s.session_id, "different_user_id")
s2 = await get_chat_session(s.session_id, None)
assert s2 is None
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_db_storage(setup_test_user, test_user_id):
"""Test that messages are correctly saved to and loaded from DB (not cache)."""
from backend.data.redis_client import get_redis_async
# Create session with messages including assistant message
s = ChatSession.new(user_id=test_user_id)
s.messages = messages # Contains user, assistant, and tool messages
assert s.session_id is not None, "Session id is not set"
# Upsert to save to both cache and DB
s = await upsert_chat_session(s)
# Clear the Redis cache to force DB load
redis_key = f"chat:session:{s.session_id}"
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
# Load from DB (cache was cleared)
s2 = await get_chat_session(
session_id=s.session_id,
user_id=s.user_id,
)
assert s2 is not None, "Session not found after loading from DB"
assert len(s2.messages) == len(
s.messages
), f"Message count mismatch: expected {len(s.messages)}, got {len(s2.messages)}"
# Verify all roles are present
roles = [m.role for m in s2.messages]
assert "user" in roles, f"User message missing. Roles found: {roles}"
assert "assistant" in roles, f"Assistant message missing. Roles found: {roles}"
assert "tool" in roles, f"Tool message missing. Roles found: {roles}"
# Verify message content
for orig, loaded in zip(s.messages, s2.messages):
assert orig.role == loaded.role, f"Role mismatch: {orig.role} != {loaded.role}"
assert (
orig.content == loaded.content
), f"Content mismatch for {orig.role}: {orig.content} != {loaded.content}"
if orig.tool_calls:
assert (
loaded.tool_calls is not None
), f"Tool calls missing for {orig.role} message"
assert len(orig.tool_calls) == len(loaded.tool_calls)

View File

@@ -0,0 +1,104 @@
You are Otto, an AI Co-Pilot and Forward Deployed Engineer for AutoGPT, an AI Business Automation tool. Your mission is to help users quickly find and set up AutoGPT agents to solve their business problems.
Here are the functions available to you:
<functions>
1. **find_agent** - Search for agents that solve the user's problem
2. **run_agent** - Run or schedule an agent (automatically handles setup)
</functions>
## HOW run_agent WORKS
The `run_agent` tool automatically handles the entire setup flow:
1. **First call** (no inputs) → Returns available inputs so user can decide what values to use
2. **Credentials check** → If missing, UI automatically prompts user to add them (you don't need to mention this)
3. **Execution** → Runs when you provide `inputs` OR set `use_defaults=true`
Parameters:
- `username_agent_slug` (required): Agent identifier like "creator/agent-name"
- `inputs`: Object with input values for the agent
- `use_defaults`: Set to `true` to run with default values (only after user confirms)
- `schedule_name` + `cron`: For scheduled execution
## WORKFLOW
1. **find_agent** - Search for agents that solve the user's problem
2. **run_agent** (first call, no inputs) - Get available inputs for the agent
3. **Ask user** what values they want to use OR if they want to use defaults
4. **run_agent** (second call) - Either with `inputs={...}` or `use_defaults=true`
## YOUR APPROACH
**Step 1: Understand the Problem**
- Ask maximum 1-2 targeted questions
- Focus on: What business problem are they solving?
- Move quickly to searching for solutions
**Step 2: Find Agents**
- Use `find_agent` immediately with relevant keywords
- Suggest the best option from search results
- Explain briefly how it solves their problem
**Step 3: Get Agent Inputs**
- Call `run_agent(username_agent_slug="creator/agent-name")` without inputs
- This returns the available inputs (required and optional)
- Present these to the user and ask what values they want
**Step 4: Run with User's Choice**
- If user provides values: `run_agent(username_agent_slug="...", inputs={...})`
- If user says "use defaults": `run_agent(username_agent_slug="...", use_defaults=true)`
- On success, share the agent link with the user
**For Scheduled Execution:**
- Add `schedule_name` and `cron` parameters
- Example: `run_agent(username_agent_slug="...", inputs={...}, schedule_name="Daily Report", cron="0 9 * * *")`
## FUNCTION CALL FORMAT
To call a function, use this exact format:
`<function_call>function_name(parameter="value")</function_call>`
Examples:
- `<function_call>find_agent(query="social media automation")</function_call>`
- `<function_call>run_agent(username_agent_slug="creator/agent-name")</function_call>` (get inputs)
- `<function_call>run_agent(username_agent_slug="creator/agent-name", inputs={"topic": "AI news"})</function_call>`
- `<function_call>run_agent(username_agent_slug="creator/agent-name", use_defaults=true)</function_call>`
## KEY RULES
**What You DON'T Do:**
- Don't help with login (frontend handles this)
- Don't mention or explain credentials to the user (frontend handles this automatically)
- Don't run agents without first showing available inputs to the user
- Don't use `use_defaults=true` without user explicitly confirming
- Don't write responses longer than 3 sentences
**What You DO:**
- Always call run_agent first without inputs to see what's available
- Ask user what values they want OR if they want to use defaults
- Keep all responses to maximum 3 sentences
- Include the agent link in your response after successful execution
**Error Handling:**
- Authentication needed → "Please sign in via the interface"
- Credentials missing → The UI handles this automatically. Focus on asking the user about input values instead.
## RESPONSE STRUCTURE
Before responding, wrap your analysis in <thinking> tags to systematically plan your approach:
- Extract the key business problem or request from the user's message
- Determine what function call (if any) you need to make next
- Plan your response to stay under the 3-sentence maximum
Example interaction:
```
User: "Run the AI news agent for me"
Otto: <function_call>run_agent(username_agent_slug="autogpt/ai-news")</function_call>
[Tool returns: Agent accepts inputs - Required: topic. Optional: num_articles (default: 5)]
Otto: The AI News agent needs a topic. What topic would you like news about, or should I use the defaults?
User: "Use defaults"
Otto: <function_call>run_agent(username_agent_slug="autogpt/ai-news", use_defaults=true)</function_call>
```
KEEP ANSWERS TO 3 SENTENCES
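
Since the prompt constrains tool calls to a single `<function_call>...</function_call>` line with Python-style keyword arguments, a consumer can extract them with a small parser. The sketch below is hypothetical; it is not the backend's actual parser and only roughly handles JSON-style booleans and malformed input:

```python
# Hypothetical parser for the <function_call>name(arg="value")</function_call>
# syntax defined above; illustrative only, not the production implementation.
import ast
import re

CALL_RE = re.compile(r"<function_call>(\w+)\((.*)\)</function_call>", re.DOTALL)


def parse_function_call(text: str) -> tuple[str, dict] | None:
    match = CALL_RE.search(text)
    if not match:
        return None
    name, raw_args = match.groups()
    # Naive normalization of JSON-style booleans (e.g. use_defaults=true).
    raw_args = re.sub(r"\btrue\b", "True", raw_args)
    raw_args = re.sub(r"\bfalse\b", "False", raw_args)
    # Re-use Python's own parser to read the keyword arguments safely.
    call = ast.parse(f"f({raw_args})", mode="eval").body
    kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
    return name, kwargs


# parse_function_call('<function_call>find_agent(query="ai news")</function_call>')
# -> ("find_agent", {"query": "ai news"})
```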

View File

@@ -1,10 +1,3 @@
"""
Response models for Vercel AI SDK UI Stream Protocol.
This module implements the AI SDK UI Stream Protocol (v1) for streaming chat responses.
See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
"""
from enum import Enum
from typing import Any
@@ -12,133 +5,97 @@ from pydantic import BaseModel, Field
class ResponseType(str, Enum):
"""Types of streaming responses following AI SDK protocol."""
"""Types of streaming responses."""
# Message lifecycle
START = "start"
FINISH = "finish"
# Text streaming
TEXT_START = "text-start"
TEXT_DELTA = "text-delta"
TEXT_END = "text-end"
# Tool interaction
TOOL_INPUT_START = "tool-input-start"
TOOL_INPUT_AVAILABLE = "tool-input-available"
TOOL_OUTPUT_AVAILABLE = "tool-output-available"
# Other
TEXT_CHUNK = "text_chunk"
TEXT_ENDED = "text_ended"
TOOL_CALL = "tool_call"
TOOL_CALL_START = "tool_call_start"
TOOL_RESPONSE = "tool_response"
ERROR = "error"
USAGE = "usage"
STREAM_END = "stream_end"
class StreamBaseResponse(BaseModel):
"""Base response model for all streaming responses."""
type: ResponseType
timestamp: str | None = None
def to_sse(self) -> str:
"""Convert to SSE format."""
return f"data: {self.model_dump_json()}\n\n"
# ========== Message Lifecycle ==========
class StreamTextChunk(StreamBaseResponse):
"""Streaming text content from the assistant."""
type: ResponseType = ResponseType.TEXT_CHUNK
content: str = Field(..., description="Text content chunk")
class StreamStart(StreamBaseResponse):
"""Start of a new message."""
type: ResponseType = ResponseType.START
messageId: str = Field(..., description="Unique message ID")
class StreamFinish(StreamBaseResponse):
"""End of message/stream."""
type: ResponseType = ResponseType.FINISH
# ========== Text Streaming ==========
class StreamTextStart(StreamBaseResponse):
"""Start of a text block."""
type: ResponseType = ResponseType.TEXT_START
id: str = Field(..., description="Text block ID")
class StreamTextDelta(StreamBaseResponse):
"""Streaming text content delta."""
type: ResponseType = ResponseType.TEXT_DELTA
id: str = Field(..., description="Text block ID")
delta: str = Field(..., description="Text content delta")
class StreamTextEnd(StreamBaseResponse):
"""End of a text block."""
type: ResponseType = ResponseType.TEXT_END
id: str = Field(..., description="Text block ID")
# ========== Tool Interaction ==========
class StreamToolInputStart(StreamBaseResponse):
class StreamToolCallStart(StreamBaseResponse):
"""Tool call started notification."""
type: ResponseType = ResponseType.TOOL_INPUT_START
toolCallId: str = Field(..., description="Unique tool call ID")
toolName: str = Field(..., description="Name of the tool being called")
type: ResponseType = ResponseType.TOOL_CALL_START
tool_name: str = Field(..., description="Name of the tool that was executed")
tool_id: str = Field(..., description="Unique tool call ID")
class StreamToolInputAvailable(StreamBaseResponse):
"""Tool input is ready for execution."""
class StreamToolCall(StreamBaseResponse):
"""Tool invocation notification."""
type: ResponseType = ResponseType.TOOL_INPUT_AVAILABLE
toolCallId: str = Field(..., description="Unique tool call ID")
toolName: str = Field(..., description="Name of the tool being called")
input: dict[str, Any] = Field(
default_factory=dict, description="Tool input arguments"
type: ResponseType = ResponseType.TOOL_CALL
tool_id: str = Field(..., description="Unique tool call ID")
tool_name: str = Field(..., description="Name of the tool being called")
arguments: dict[str, Any] = Field(
default_factory=dict, description="Tool arguments"
)
class StreamToolOutputAvailable(StreamBaseResponse):
class StreamToolExecutionResult(StreamBaseResponse):
"""Tool execution result."""
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
toolCallId: str = Field(..., description="Tool call ID this responds to")
output: str | dict[str, Any] = Field(..., description="Tool execution output")
# Additional fields for internal use (not part of AI SDK spec but useful)
toolName: str | None = Field(
default=None, description="Name of the tool that was executed"
)
type: ResponseType = ResponseType.TOOL_RESPONSE
tool_id: str = Field(..., description="Tool call ID this responds to")
tool_name: str = Field(..., description="Name of the tool that was executed")
result: str | dict[str, Any] = Field(..., description="Tool execution result")
success: bool = Field(
default=True, description="Whether the tool execution succeeded"
)
# ========== Other ==========
class StreamUsage(StreamBaseResponse):
"""Token usage statistics."""
type: ResponseType = ResponseType.USAGE
promptTokens: int = Field(..., description="Number of prompt tokens")
completionTokens: int = Field(..., description="Number of completion tokens")
totalTokens: int = Field(..., description="Total number of tokens")
prompt_tokens: int
completion_tokens: int
total_tokens: int
class StreamError(StreamBaseResponse):
"""Error response."""
type: ResponseType = ResponseType.ERROR
errorText: str = Field(..., description="Error message text")
message: str = Field(..., description="Error message")
code: str | None = Field(default=None, description="Error code")
details: dict[str, Any] | None = Field(
default=None, description="Additional error details"
)
class StreamTextEnded(StreamBaseResponse):
"""Text streaming completed marker."""
type: ResponseType = ResponseType.TEXT_ENDED
class StreamEnd(StreamBaseResponse):
"""End of stream marker."""
type: ResponseType = ResponseType.STREAM_END
summary: dict[str, Any] | None = Field(
default=None, description="Stream summary statistics"
)
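
A minimal sketch of how these models go over the wire via `to_sse()`, assuming the AI SDK-style lifecycle variants (`StreamStart` → `StreamTextStart`/`StreamTextDelta`/`StreamTextEnd` → `StreamFinish`) and an import path inferred from this diff; the real event sequence is produced by the streaming handler, which is not part of this hunk:

```python
# Illustrative only: renders a text-only stream as SSE frames.
from backend.api.features.chat.response_model import (
    StreamFinish,
    StreamStart,
    StreamTextDelta,
    StreamTextEnd,
    StreamTextStart,
)


def render_text_stream(message_id: str, block_id: str, chunks: list[str]) -> str:
    events = [StreamStart(messageId=message_id), StreamTextStart(id=block_id)]
    events += [StreamTextDelta(id=block_id, delta=chunk) for chunk in chunks]
    events += [StreamTextEnd(id=block_id), StreamFinish()]
    # Each event becomes one frame: 'data: {"type": "text-delta", ...}\n\n'
    return "".join(event.to_sse() for event in events)


# render_text_stream("msg_1", "txt_1", ["Hel", "lo"]) produces five frames;
# the route handler may append a final 'data: [DONE]\n\n' terminator.
```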

View File

@@ -13,25 +13,12 @@ from backend.util.exceptions import NotFoundError
from . import service as chat_service
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
config = ChatConfig()
logger = logging.getLogger(__name__)
async def _validate_and_get_session(
session_id: str,
user_id: str | None,
) -> ChatSession:
"""Validate session exists and belongs to user."""
session = await get_chat_session(session_id, user_id)
if not session:
raise NotFoundError(f"Session {session_id} not found.")
return session
router = APIRouter(
tags=["chat"],
)
@@ -39,14 +26,6 @@ router = APIRouter(
# ========== Request/Response Models ==========
class StreamChatRequest(BaseModel):
"""Request model for streaming chat with optional context."""
message: str
is_user_message: bool = True
context: dict[str, str] | None = None # {url: str, content: str}
class CreateSessionResponse(BaseModel):
"""Response model containing information on a newly created chat session."""
@@ -65,77 +44,22 @@ class SessionDetailResponse(BaseModel):
messages: list[dict]
class SessionSummaryResponse(BaseModel):
"""Response model for a session summary (without messages)."""
id: str
created_at: str
updated_at: str
title: str | None = None
class ListSessionsResponse(BaseModel):
"""Response model for listing chat sessions."""
sessions: list[SessionSummaryResponse]
total: int
# ========== Routes ==========
@router.get(
"/sessions",
dependencies=[Security(auth.requires_user)],
)
async def list_sessions(
user_id: Annotated[str, Security(auth.get_user_id)],
limit: int = Query(default=50, ge=1, le=100),
offset: int = Query(default=0, ge=0),
) -> ListSessionsResponse:
"""
List chat sessions for the authenticated user.
Returns a paginated list of chat sessions belonging to the current user,
ordered by most recently updated.
Args:
user_id: The authenticated user's ID.
limit: Maximum number of sessions to return (1-100).
offset: Number of sessions to skip for pagination.
Returns:
ListSessionsResponse: List of session summaries and total count.
"""
sessions, total_count = await get_user_sessions(user_id, limit, offset)
return ListSessionsResponse(
sessions=[
SessionSummaryResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
title=session.title,
)
for session in sessions
],
total=total_count,
)
@router.post(
"/sessions",
)
async def create_session(
user_id: Annotated[str, Depends(auth.get_user_id)],
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CreateSessionResponse:
"""
Create a new chat session.
Initiates a new chat session for the authenticated user.
Initiates a new chat session for either an authenticated or anonymous user.
Args:
user_id: The authenticated user ID parsed from the JWT (required).
user_id: The optional authenticated user ID parsed from the JWT. If missing, creates an anonymous session.
Returns:
CreateSessionResponse: Details of the created session.
@@ -143,15 +67,15 @@ async def create_session(
"""
logger.info(
f"Creating session with user_id: "
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
f"...{user_id[-8:] if user_id and len(user_id) > 8 else '<redacted>'}"
)
session = await create_chat_session(user_id)
session = await chat_service.create_chat_session(user_id)
return CreateSessionResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
user_id=session.user_id,
user_id=session.user_id or None,
)
@@ -175,88 +99,29 @@ async def get_session(
SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.
"""
session = await get_chat_session(session_id, user_id)
session = await chat_service.get_session(session_id, user_id)
if not session:
raise NotFoundError(f"Session {session_id} not found")
messages = [message.model_dump() for message in session.messages]
logger.info(
f"Returning session {session_id}: "
f"message_count={len(messages)}, "
f"roles={[m.get('role') for m in messages]}"
)
return SessionDetailResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
messages=messages,
)
@router.post(
"/sessions/{session_id}/stream",
)
async def stream_chat_post(
session_id: str,
request: StreamChatRequest,
user_id: str | None = Depends(auth.get_user_id),
):
"""
Stream chat responses for a session (POST with context support).
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
- Text fragments as they are generated
- Tool call UI elements (if invoked)
- Tool execution results
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks.
"""
session = await _validate_and_get_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
yield chunk.to_sse()
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Disable nginx buffering
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
},
messages=[message.model_dump() for message in session.messages],
)
@router.get(
"/sessions/{session_id}/stream",
)
async def stream_chat_get(
async def stream_chat(
session_id: str,
message: Annotated[str, Query(min_length=1, max_length=10000)],
user_id: str | None = Depends(auth.get_user_id),
is_user_message: bool = Query(default=True),
):
"""
Stream chat responses for a session (GET - legacy endpoint).
Stream chat responses for a session.
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
- Text fragments as they are generated
@@ -272,7 +137,14 @@ async def stream_chat_get(
StreamingResponse: SSE-formatted response chunks.
"""
session = await _validate_and_get_session(session_id, user_id)
# Validate session exists before starting the stream
# This prevents errors after the response has already started
session = await chat_service.get_session(session_id, user_id)
if not session:
raise NotFoundError(f"Session {session_id} not found. ")
if session.user_id is None and user_id is not None:
session = await chat_service.assign_user_to_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
async for chunk in chat_service.stream_chat_completion(
@@ -283,8 +155,6 @@ async def stream_chat_get(
session=session, # Pass pre-fetched session to avoid double-fetch
):
yield chunk.to_sse()
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
@@ -293,7 +163,6 @@ async def stream_chat_get(
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Disable nginx buffering
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
},
)
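
For completeness, a client-side sketch of reading this SSE endpoint; the base URL, route prefix, and auth header are assumptions, and `httpx` is just one possible async client:

```python
# Illustrative SSE consumer for GET /sessions/{session_id}/stream.
# Base URL, auth, and route prefix are placeholders, not taken from this diff.
import json

import httpx


async def read_stream(base_url: str, session_id: str, message: str, token: str) -> None:
    headers = {"Authorization": f"Bearer {token}"}
    async with httpx.AsyncClient(base_url=base_url, timeout=None) as client:
        async with client.stream(
            "GET",
            f"/sessions/{session_id}/stream",
            params={"message": message},
            headers=headers,
        ) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip blank keep-alive lines
                payload = line.removeprefix("data: ")
                if payload == "[DONE]":
                    break  # optional AI SDK-style terminator
                event = json.loads(payload)
                print(event.get("type"), event)
```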
@@ -332,28 +201,16 @@ async def health_check() -> dict:
"""
Health check endpoint for the chat service.
Performs a full cycle test of session creation and retrieval. Should always return healthy
Performs a full cycle test of session creation, assignment, and retrieval. Should always return healthy
if the service and data layer are operational.
Returns:
dict: A status dictionary indicating health, service name, and API version.
"""
from backend.data.user import get_or_create_user
# Ensure health check user exists (required for FK constraint)
health_check_user_id = "health-check-user"
await get_or_create_user(
{
"sub": health_check_user_id,
"email": "health-check@system.local",
"user_metadata": {"name": "Health Check User"},
}
)
# Create and retrieve session to verify full data layer
session = await create_chat_session(health_check_user_id)
await get_chat_session(session.session_id, health_check_user_id)
session = await chat_service.create_chat_session(None)
await chat_service.assign_user_to_session(session.session_id, "test_user")
await chat_service.get_session(session.session_id, "test_user")
return {
"status": "healthy",

File diff suppressed because it is too large

View File

@@ -4,19 +4,18 @@ from os import getenv
import pytest
from . import service as chat_service
from .model import create_chat_session, get_chat_session, upsert_chat_session
from .response_model import (
StreamEnd,
StreamError,
StreamFinish,
StreamTextDelta,
StreamToolOutputAvailable,
StreamTextChunk,
StreamToolExecutionResult,
)
logger = logging.getLogger(__name__)
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion(setup_test_user, test_user_id):
async def test_stream_chat_completion():
"""
Test the stream_chat_completion function.
"""
@@ -24,7 +23,7 @@ async def test_stream_chat_completion(setup_test_user, test_user_id):
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
session = await chat_service.create_chat_session()
has_errors = False
has_ended = False
@@ -35,9 +34,9 @@ async def test_stream_chat_completion(setup_test_user, test_user_id):
logger.info(chunk)
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamTextDelta):
assistant_message += chunk.delta
if isinstance(chunk, StreamFinish):
if isinstance(chunk, StreamTextChunk):
assistant_message += chunk.content
if isinstance(chunk, StreamEnd):
has_ended = True
assert has_ended, "Chat completion did not end"
@@ -46,7 +45,7 @@ async def test_stream_chat_completion(setup_test_user, test_user_id):
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
async def test_stream_chat_completion_with_tool_calls():
"""
Test the stream_chat_completion function.
"""
@@ -54,8 +53,8 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
session = await chat_service.create_chat_session()
session = await chat_service.upsert_chat_session(session)
has_errors = False
has_ended = False
@@ -69,14 +68,14 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamFinish):
if isinstance(chunk, StreamEnd):
has_ended = True
if isinstance(chunk, StreamToolOutputAvailable):
if isinstance(chunk, StreamToolExecutionResult):
had_tool_calls = True
assert has_ended, "Chat completion did not end"
assert not has_errors, "Error occurred while streaming chat completion"
assert had_tool_calls, "Tool calls did not occur"
session = await get_chat_session(session.session_id)
session = await chat_service.get_session(session.session_id)
assert session, "Session not found"
assert session.usage, "Usage is empty"

View File

@@ -4,32 +4,21 @@ from openai.types.chat import ChatCompletionToolParam
from backend.api.features.chat.model import ChatSession
from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .find_agent import FindAgentTool
from .find_library_agent import FindLibraryAgentTool
from .run_agent import RunAgentTool
if TYPE_CHECKING:
from backend.api.features.chat.response_model import StreamToolOutputAvailable
from backend.api.features.chat.response_model import StreamToolExecutionResult
# Single source of truth for all tools
TOOL_REGISTRY: dict[str, BaseTool] = {
"add_understanding": AddUnderstandingTool(),
"find_agent": FindAgentTool(),
"find_library_agent": FindLibraryAgentTool(),
"run_agent": RunAgentTool(),
"agent_output": AgentOutputTool(),
}
# Initialize tool instances
find_agent_tool = FindAgentTool()
run_agent_tool = RunAgentTool()
# Export individual tool instances for backwards compatibility
find_agent_tool = TOOL_REGISTRY["find_agent"]
run_agent_tool = TOOL_REGISTRY["run_agent"]
# Generated from registry for OpenAI API
# Export tools as OpenAI format
tools: list[ChatCompletionToolParam] = [
tool.as_openai_tool() for tool in TOOL_REGISTRY.values()
find_agent_tool.as_openai_tool(),
run_agent_tool.as_openai_tool(),
]
@@ -39,9 +28,14 @@ async def execute_tool(
user_id: str | None,
session: ChatSession,
tool_call_id: str,
) -> "StreamToolOutputAvailable":
"""Execute a tool by name."""
tool = TOOL_REGISTRY.get(tool_name)
if not tool:
) -> "StreamToolExecutionResult":
tool_map: dict[str, BaseTool] = {
"find_agent": find_agent_tool,
"run_agent": run_agent_tool,
}
if tool_name not in tool_map:
raise ValueError(f"Tool {tool_name} not found")
return await tool.execute(user_id, session, tool_call_id, **parameters)
return await tool_map[tool_name].execute(
user_id, session, tool_call_id, **parameters
)
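
A hedged sketch of dispatching through `execute_tool`; the session object, tool-call ID, argument order of the first two parameters, and import paths are placeholders inferred from the code above:

```python
# Illustrative dispatch only; ChatSession construction and IDs are placeholders.
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import execute_tool


async def handle_tool_call(user_id: str | None, session: ChatSession) -> None:
    result = await execute_tool(
        "find_agent",
        {"query": "social media automation"},
        user_id=user_id,
        session=session,
        tool_call_id="call_1",
    )
    # The returned stream model carries the tool output plus the originating
    # tool_call_id so the streaming layer can pair it with the request.
    print(result.model_dump_json())
```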

View File

@@ -3,7 +3,6 @@ from datetime import UTC, datetime
from os import getenv
import pytest
from prisma.types import ProfileCreateInput
from pydantic import SecretStr
from backend.api.features.chat.model import ChatSession
@@ -18,7 +17,7 @@ from backend.data.user import get_or_create_user
from backend.integrations.credentials_store import IntegrationCredentialsStore
def make_session(user_id: str):
def make_session(user_id: str | None = None):
return ChatSession(
session_id=str(uuid.uuid4()),
user_id=user_id,
@@ -50,13 +49,13 @@ async def setup_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data=ProfileCreateInput(
userId=user.id,
username=username,
name=f"Test User {username}",
description="Test user profile",
links=[], # Required field - empty array for test profiles
)
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile",
"links": [], # Required field - empty array for test profiles
}
)
# 2. Create a test graph with agent input -> agent output
@@ -173,13 +172,13 @@ async def setup_llm_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data=ProfileCreateInput(
userId=user.id,
username=username,
name=f"Test User {username}",
description="Test user profile for LLM tests",
links=[], # Required field - empty array for test profiles
)
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile for LLM tests",
"links": [], # Required field - empty array for test profiles
}
)
# 2. Create test OpenAI credentials for the user
@@ -333,13 +332,13 @@ async def setup_firecrawl_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data=ProfileCreateInput(
userId=user.id,
username=username,
name=f"Test User {username}",
description="Test user profile for Firecrawl tests",
links=[], # Required field - empty array for test profiles
)
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile for Firecrawl tests",
"links": [], # Required field - empty array for test profiles
}
)
# NOTE: We deliberately do NOT create Firecrawl credentials for this user

View File

@@ -1,119 +0,0 @@
"""Tool for capturing user business understanding incrementally."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.data.understanding import (
BusinessUnderstandingInput,
upsert_business_understanding,
)
from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
logger = logging.getLogger(__name__)
class AddUnderstandingTool(BaseTool):
"""Tool for capturing user's business understanding incrementally."""
@property
def name(self) -> str:
return "add_understanding"
@property
def description(self) -> str:
return """Capture and store information about the user's business context,
workflows, pain points, and automation goals. Call this tool whenever the user
shares information about their business. Each call incrementally adds to the
existing understanding - you don't need to provide all fields at once.
Use this to build a comprehensive profile that helps recommend better agents
and automations for the user's specific needs."""
@property
def parameters(self) -> dict[str, Any]:
# Auto-generate from Pydantic model schema
schema = BusinessUnderstandingInput.model_json_schema()
properties = {}
for field_name, field_schema in schema.get("properties", {}).items():
prop: dict[str, Any] = {"description": field_schema.get("description", "")}
# Handle anyOf for Optional types
if "anyOf" in field_schema:
for option in field_schema["anyOf"]:
if option.get("type") != "null":
prop["type"] = option.get("type", "string")
if "items" in option:
prop["items"] = option["items"]
break
else:
prop["type"] = field_schema.get("type", "string")
if "items" in field_schema:
prop["items"] = field_schema["items"]
properties[field_name] = prop
return {"type": "object", "properties": properties, "required": []}
@property
def requires_auth(self) -> bool:
"""Requires authentication to store user-specific data."""
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""
Capture and store business understanding incrementally.
Each call merges new data with existing understanding:
- String fields are overwritten if provided
- List fields are appended (with deduplication)
"""
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required to save business understanding.",
session_id=session_id,
)
# Check if any data was provided
if not any(v is not None for v in kwargs.values()):
return ErrorResponse(
message="Please provide at least one field to update.",
session_id=session_id,
)
# Build input model from kwargs (only include fields defined in the model)
valid_fields = set(BusinessUnderstandingInput.model_fields.keys())
input_data = BusinessUnderstandingInput(
**{k: v for k, v in kwargs.items() if k in valid_fields}
)
# Track which fields were updated
updated_fields = [
k for k, v in kwargs.items() if k in valid_fields and v is not None
]
# Upsert with merge
understanding = await upsert_business_understanding(user_id, input_data)
# Build current understanding summary (filter out empty values)
current_understanding = {
k: v
for k, v in understanding.model_dump(
exclude={"id", "user_id", "created_at", "updated_at"}
).items()
if v is not None and v != [] and v != ""
}
return UnderstandingUpdatedResponse(
message=f"Updated understanding with: {', '.join(updated_fields)}. "
"I now have a better picture of your business context.",
session_id=session_id,
updated_fields=updated_fields,
current_understanding=current_understanding,
)

View File

@@ -1,446 +0,0 @@
"""Tool for retrieving agent execution outputs from user's library."""
import logging
import re
from datetime import datetime, timedelta, timezone
from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.library import db as library_db
from backend.api.features.library.model import LibraryAgent
from backend.data import execution as execution_db
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
from .base import BaseTool
from .models import (
AgentOutputResponse,
ErrorResponse,
ExecutionOutputInfo,
NoResultsResponse,
ToolResponseBase,
)
from .utils import fetch_graph_from_store_slug
logger = logging.getLogger(__name__)
class AgentOutputInput(BaseModel):
"""Input parameters for the agent_output tool."""
agent_name: str = ""
library_agent_id: str = ""
store_slug: str = ""
execution_id: str = ""
run_time: str = "latest"
@field_validator(
"agent_name",
"library_agent_id",
"store_slug",
"execution_id",
"run_time",
mode="before",
)
@classmethod
def strip_strings(cls, v: Any) -> Any:
"""Strip whitespace from string fields."""
return v.strip() if isinstance(v, str) else v
def parse_time_expression(
time_expr: str | None,
) -> tuple[datetime | None, datetime | None]:
"""
Parse time expression into datetime range (start, end).
Supports: "latest", "yesterday", "today", "last week", "last 7 days",
"last month", "last 30 days", ISO date "YYYY-MM-DD", ISO datetime.
"""
if not time_expr or time_expr.lower() == "latest":
return None, None
now = datetime.now(timezone.utc)
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
expr = time_expr.lower().strip()
# Relative time expressions lookup
relative_times: dict[str, tuple[datetime, datetime]] = {
"yesterday": (today_start - timedelta(days=1), today_start),
"today": (today_start, now),
"last week": (now - timedelta(days=7), now),
"last 7 days": (now - timedelta(days=7), now),
"last month": (now - timedelta(days=30), now),
"last 30 days": (now - timedelta(days=30), now),
}
if expr in relative_times:
return relative_times[expr]
# Try ISO date format (YYYY-MM-DD)
date_match = re.match(r"^(\d{4})-(\d{2})-(\d{2})$", expr)
if date_match:
try:
year, month, day = map(int, date_match.groups())
start = datetime(year, month, day, 0, 0, 0, tzinfo=timezone.utc)
return start, start + timedelta(days=1)
except ValueError:
# Invalid date components (e.g., month=13, day=32)
pass
# Try ISO datetime
try:
parsed = datetime.fromisoformat(expr.replace("Z", "+00:00"))
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed - timedelta(hours=1), parsed + timedelta(hours=1)
except ValueError:
return None, None
class AgentOutputTool(BaseTool):
"""Tool for retrieving execution outputs from user's library agents."""
@property
def name(self) -> str:
return "agent_output"
@property
def description(self) -> str:
return """Retrieve execution outputs from agents in the user's library.
Identify the agent using one of:
- agent_name: Fuzzy search in user's library
- library_agent_id: Exact library agent ID
- store_slug: Marketplace format 'username/agent-name'
Select which run to retrieve using:
- execution_id: Specific execution ID
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
"""
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"agent_name": {
"type": "string",
"description": "Agent name to search for in user's library (fuzzy match)",
},
"library_agent_id": {
"type": "string",
"description": "Exact library agent ID",
},
"store_slug": {
"type": "string",
"description": "Marketplace identifier: 'username/agent-slug'",
},
"execution_id": {
"type": "string",
"description": "Specific execution ID to retrieve",
},
"run_time": {
"type": "string",
"description": (
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return True
async def _resolve_agent(
self,
user_id: str,
agent_name: str | None,
library_agent_id: str | None,
store_slug: str | None,
) -> tuple[LibraryAgent | None, str | None]:
"""
Resolve agent from provided identifiers.
Returns (library_agent, error_message).
"""
# Priority 1: Exact library agent ID
if library_agent_id:
try:
agent = await library_db.get_library_agent(library_agent_id, user_id)
return agent, None
except Exception as e:
logger.warning(f"Failed to get library agent by ID: {e}")
return None, f"Library agent '{library_agent_id}' not found"
# Priority 2: Store slug (username/agent-name)
if store_slug and "/" in store_slug:
username, agent_slug = store_slug.split("/", 1)
graph, _ = await fetch_graph_from_store_slug(username, agent_slug)
if not graph:
return None, f"Agent '{store_slug}' not found in marketplace"
# Find in user's library by graph_id
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
if not agent:
return (
None,
f"Agent '{store_slug}' is not in your library. "
"Add it first to see outputs.",
)
return agent, None
# Priority 3: Fuzzy name search in library
if agent_name:
try:
response = await library_db.list_library_agents(
user_id=user_id,
search_term=agent_name,
page_size=5,
)
if not response.agents:
return (
None,
f"No agents matching '{agent_name}' found in your library",
)
# Return best match (first result from search)
return response.agents[0], None
except Exception as e:
logger.error(f"Error searching library agents: {e}")
return None, f"Error searching for agent: {e}"
return (
None,
"Please specify an agent name, library_agent_id, or store_slug",
)
async def _get_execution(
self,
user_id: str,
graph_id: str,
execution_id: str | None,
time_start: datetime | None,
time_end: datetime | None,
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
"""
Fetch execution(s) based on filters.
Returns (single_execution, available_executions_meta, error_message).
"""
# If specific execution_id provided, fetch it directly
if execution_id:
execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
if not execution:
return None, [], f"Execution '{execution_id}' not found"
return execution, [], None
# Get completed executions with time filters
executions = await execution_db.get_graph_executions(
graph_id=graph_id,
user_id=user_id,
statuses=[ExecutionStatus.COMPLETED],
created_time_gte=time_start,
created_time_lte=time_end,
limit=10,
)
if not executions:
return None, [], None # No error, just no executions
# If only one execution, fetch full details
if len(executions) == 1:
full_execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=executions[0].id,
include_node_executions=False,
)
return full_execution, [], None
# Multiple executions - return latest with full details, plus list of available
full_execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=executions[0].id,
include_node_executions=False,
)
return full_execution, executions, None
def _build_response(
self,
agent: LibraryAgent,
execution: GraphExecution | None,
available_executions: list[GraphExecutionMeta],
session_id: str | None,
) -> AgentOutputResponse:
"""Build the response based on execution data."""
library_agent_link = f"/library/agents/{agent.id}"
if not execution:
return AgentOutputResponse(
message=f"No completed executions found for agent '{agent.name}'",
session_id=session_id,
agent_name=agent.name,
agent_id=agent.graph_id,
library_agent_id=agent.id,
library_agent_link=library_agent_link,
total_executions=0,
)
execution_info = ExecutionOutputInfo(
execution_id=execution.id,
status=execution.status.value,
started_at=execution.started_at,
ended_at=execution.ended_at,
outputs=dict(execution.outputs),
inputs_summary=execution.inputs if execution.inputs else None,
)
available_list = None
if len(available_executions) > 1:
available_list = [
{
"id": e.id,
"status": e.status.value,
"started_at": e.started_at.isoformat() if e.started_at else None,
}
for e in available_executions[:5]
]
message = f"Found execution outputs for agent '{agent.name}'"
if len(available_executions) > 1:
message += (
f". Showing latest of {len(available_executions)} matching executions."
)
return AgentOutputResponse(
message=message,
session_id=session_id,
agent_name=agent.name,
agent_id=agent.graph_id,
library_agent_id=agent.id,
library_agent_link=library_agent_link,
execution=execution_info,
available_executions=available_list,
total_executions=len(available_executions) if available_executions else 1,
)
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the agent_output tool."""
session_id = session.session_id
# Parse and validate input
try:
input_data = AgentOutputInput(**kwargs)
except Exception as e:
logger.error(f"Invalid input: {e}")
return ErrorResponse(
message="Invalid input parameters",
error=str(e),
session_id=session_id,
)
# Ensure user_id is present (should be guaranteed by requires_auth)
if not user_id:
return ErrorResponse(
message="User authentication required",
session_id=session_id,
)
# Check if at least one identifier is provided
if not any(
[
input_data.agent_name,
input_data.library_agent_id,
input_data.store_slug,
input_data.execution_id,
]
):
return ErrorResponse(
message=(
"Please specify at least one of: agent_name, "
"library_agent_id, store_slug, or execution_id"
),
session_id=session_id,
)
# If only execution_id provided, we need to find the agent differently
if (
input_data.execution_id
and not input_data.agent_name
and not input_data.library_agent_id
and not input_data.store_slug
):
# Fetch execution directly to get graph_id
execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=input_data.execution_id,
include_node_executions=False,
)
if not execution:
return ErrorResponse(
message=f"Execution '{input_data.execution_id}' not found",
session_id=session_id,
)
# Find library agent by graph_id
agent = await library_db.get_library_agent_by_graph_id(
user_id, execution.graph_id
)
if not agent:
return NoResultsResponse(
message=(
f"Execution found but agent not in your library. "
f"Graph ID: {execution.graph_id}"
),
session_id=session_id,
suggestions=["Add the agent to your library to see more details"],
)
return self._build_response(agent, execution, [], session_id)
# Resolve agent from identifiers
agent, error = await self._resolve_agent(
user_id=user_id,
agent_name=input_data.agent_name or None,
library_agent_id=input_data.library_agent_id or None,
store_slug=input_data.store_slug or None,
)
if error or not agent:
return NoResultsResponse(
message=error or "Agent not found",
session_id=session_id,
suggestions=[
"Check the agent name or ID",
"Make sure the agent is in your library",
],
)
# Parse time expression
time_start, time_end = parse_time_expression(input_data.run_time)
# Fetch execution(s)
execution, available_executions, exec_error = await self._get_execution(
user_id=user_id,
graph_id=agent.graph_id,
execution_id=input_data.execution_id or None,
time_start=time_start,
time_end=time_end,
)
if exec_error:
return ErrorResponse(
message=exec_error,
session_id=session_id,
)
return self._build_response(agent, execution, available_executions, session_id)

View File

@@ -1,151 +0,0 @@
"""Shared agent search functionality for find_agent and find_library_agent tools."""
import logging
from typing import Literal
from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db
from backend.util.exceptions import DatabaseError, NotFoundError
from .models import (
AgentInfo,
AgentsFoundResponse,
ErrorResponse,
NoResultsResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
SearchSource = Literal["marketplace", "library"]
async def search_agents(
query: str,
source: SearchSource,
session_id: str | None,
user_id: str | None = None,
) -> ToolResponseBase:
"""
Search for agents in marketplace or user library.
Args:
query: Search query string
source: "marketplace" or "library"
session_id: Chat session ID
user_id: User ID (required for library search)
Returns:
AgentsFoundResponse, NoResultsResponse, or ErrorResponse
"""
if not query:
return ErrorResponse(
message="Please provide a search query", session_id=session_id
)
if source == "library" and not user_id:
return ErrorResponse(
message="User authentication required to search library",
session_id=session_id,
)
agents: list[AgentInfo] = []
try:
if source == "marketplace":
logger.info(f"Searching marketplace for: {query}")
results = await store_db.get_store_agents(search_query=query, page_size=5)
for agent in results.agents:
agents.append(
AgentInfo(
id=f"{agent.creator}/{agent.slug}",
name=agent.agent_name,
description=agent.description or "",
source="marketplace",
in_library=False,
creator=agent.creator,
category="general",
rating=agent.rating,
runs=agent.runs,
is_featured=False,
)
)
else: # library
logger.info(f"Searching user library for: {query}")
results = await library_db.list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=query,
page_size=10,
)
for agent in results.agents:
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
)
logger.info(f"Found {len(agents)} agents in {source}")
except NotFoundError:
pass
except DatabaseError as e:
logger.error(f"Error searching {source}: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to search {source}. Please try again.",
error=str(e),
session_id=session_id,
)
if not agents:
suggestions = (
[
"Try more general terms",
"Browse categories in the marketplace",
"Check spelling",
]
if source == "marketplace"
else [
"Try different keywords",
"Use find_agent to search the marketplace",
"Check your library at /library",
]
)
no_results_msg = (
f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
if source == "marketplace"
else f"No agents matching '{query}' found in your library."
)
return NoResultsResponse(
message=no_results_msg, session_id=session_id, suggestions=suggestions
)
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
title += (
f"for '{query}'"
if source == "marketplace"
else f"in your library for '{query}'"
)
message = (
"Now you have found some options for the user to choose from. "
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
"Please ask the user if they would like to use any of these agents."
if source == "marketplace"
else "Found agents in the user's library. You can provide a link to view an agent at: "
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
)
return AgentsFoundResponse(
message=message,
title=title,
agents=agents,
count=len(agents),
session_id=session_id,
)
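A brief usage sketch for the shared helper above (the absolute import path is assumed from the relative imports in this file; the IDs are placeholders):
from backend.api.features.chat.tools.agent_search import search_agents  # assumed path

async def demo(session_id: str, user_id: str):
    # Marketplace search needs no authentication.
    market = await search_agents(query="summarize", source="marketplace", session_id=session_id)
    # Library search requires user_id; without it an ErrorResponse is returned.
    library = await search_agents(
        query="summarize", source="library", session_id=session_id, user_id=user_id
    )
    return market, library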

View File

@@ -6,7 +6,7 @@ from typing import Any
from openai.types.chat import ChatCompletionToolParam
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.response_model import StreamToolOutputAvailable
from backend.api.features.chat.response_model import StreamToolExecutionResult
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
@@ -53,7 +53,7 @@ class BaseTool:
session: ChatSession,
tool_call_id: str,
**kwargs,
) -> StreamToolOutputAvailable:
) -> StreamToolExecutionResult:
"""Execute the tool with authentication check.
Args:
@@ -69,10 +69,10 @@ class BaseTool:
logger.error(
f"Attempted tool call for {self.name} but user not authenticated"
)
return StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=self.name,
output=NeedLoginResponse(
return StreamToolExecutionResult(
tool_id=tool_call_id,
tool_name=self.name,
result=NeedLoginResponse(
message=f"Please sign in to use {self.name}",
session_id=session.session_id,
).model_dump_json(),
@@ -81,17 +81,17 @@ class BaseTool:
try:
result = await self._execute(user_id, session, **kwargs)
return StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=self.name,
output=result.model_dump_json(),
return StreamToolExecutionResult(
tool_id=tool_call_id,
tool_name=self.name,
result=result.model_dump_json(),
)
except Exception as e:
logger.error(f"Error in {self.name}: {e}", exc_info=True)
return StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=self.name,
output=ErrorResponse(
return StreamToolExecutionResult(
tool_id=tool_call_id,
tool_name=self.name,
result=ErrorResponse(
message=f"An error occurred while executing {self.name}",
error=str(e),
session_id=session.session_id,
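A small sketch of consuming the renamed result model on the caller side (the field names and import path are taken from this diff):
import orjson
from backend.api.features.chat.response_model import StreamToolExecutionResult

def unpack(execution: StreamToolExecutionResult) -> dict:
    # result holds the tool response serialized with model_dump_json()
    payload = orjson.loads(execution.result)
    return {"tool": execution.tool_name, "call_id": execution.tool_id, **payload}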

View File

@@ -1,16 +1,26 @@
"""Tool for discovering agents from marketplace."""
"""Tool for discovering agents from marketplace and user library."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.util.exceptions import DatabaseError, NotFoundError
from .agent_search import search_agents
from .base import BaseTool
from .models import ToolResponseBase
from .models import (
AgentCarouselResponse,
AgentInfo,
ErrorResponse,
NoResultsResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
class FindAgentTool(BaseTool):
"""Tool for discovering agents from the marketplace."""
"""Tool for discovering agents based on user needs."""
@property
def name(self) -> str:
@@ -36,11 +46,84 @@ class FindAgentTool(BaseTool):
}
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
return await search_agents(
query=kwargs.get("query", "").strip(),
source="marketplace",
session_id=session.session_id,
user_id=user_id,
"""Search for agents in the marketplace.
Args:
user_id: User ID (may be anonymous)
session: Chat session (session_id is read from it)
query: Search query string, passed via **kwargs
Returns:
AgentCarouselResponse: List of agents found in the marketplace
NoResultsResponse: No agents found in the marketplace
ErrorResponse: Error message
"""
query = kwargs.get("query", "").strip()
session_id = session.session_id
if not query:
return ErrorResponse(
message="Please provide a search query",
session_id=session_id,
)
agents = []
try:
logger.info(f"Searching marketplace for: {query}")
store_results = await store_db.get_store_agents(
search_query=query,
page_size=5,
)
logger.info(f"Find agents tool found {len(store_results.agents)} agents")
for agent in store_results.agents:
agent_id = f"{agent.creator}/{agent.slug}"
logger.info(f"Building agent ID = {agent_id}")
agents.append(
AgentInfo(
id=agent_id,
name=agent.agent_name,
description=agent.description or "",
source="marketplace",
in_library=False,
creator=agent.creator,
category="general",
rating=agent.rating,
runs=agent.runs,
is_featured=False,
),
)
except NotFoundError:
pass
except DatabaseError as e:
logger.error(f"Error searching agents: {e}", exc_info=True)
return ErrorResponse(
message="Failed to search for agents. Please try again.",
error=str(e),
session_id=session_id,
)
if not agents:
return NoResultsResponse(
message=f"No agents found matching '{query}'. Try different keywords or browse the marketplace. If you have 3 consecutive find_agent tool calls results and found no agents. Please stop trying and ask the user if there is anything else you can help with.",
session_id=session_id,
suggestions=[
"Try more general terms",
"Browse categories in the marketplace",
"Check spelling",
],
)
# Return formatted carousel
title = (
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'"
)
return AgentCarouselResponse(
message="Now you have found some options for the user to choose from. You can add a link to a recommended agent at: /marketplace/agent/agent_id Please ask the user if they would like to use any of these agents. If they do, please call the get_agent_details tool for this agent.",
title=title,
agents=agents,
count=len(agents),
session_id=session_id,
)
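A hedged sketch of rendering the carousel produced above, using only fields visible in this diff; the link shape follows the /marketplace/agent/... hint in the message string:
def render_carousel(response) -> str:
    # response is an AgentCarouselResponse; agent.id is "creator/slug" as built above.
    lines = [response.title]
    for agent in response.agents:
        lines.append(f"- {agent.name} by {agent.creator}: /marketplace/agent/{agent.id}")
    return "\n".join(lines)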

View File

@@ -1,52 +0,0 @@
"""Tool for searching agents in the user's library."""
from typing import Any
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
from .base import BaseTool
from .models import ToolResponseBase
class FindLibraryAgentTool(BaseTool):
"""Tool for searching agents in the user's library."""
@property
def name(self) -> str:
return "find_library_agent"
@property
def description(self) -> str:
return (
"Search for agents in the user's library. Use this to find agents "
"the user has already added to their library, including agents they "
"created or added from the marketplace."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query to find agents by name or description.",
},
},
"required": ["query"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
return await search_agents(
query=kwargs.get("query", "").strip(),
source="library",
session_id=session.session_id,
user_id=user_id,
)

View File

@@ -1,6 +1,5 @@
"""Pydantic models for tool responses."""
from datetime import datetime
from enum import Enum
from typing import Any
@@ -12,15 +11,14 @@ from backend.data.model import CredentialsMetaInput
class ResponseType(str, Enum):
"""Types of tool responses."""
AGENTS_FOUND = "agents_found"
AGENT_CAROUSEL = "agent_carousel"
AGENT_DETAILS = "agent_details"
SETUP_REQUIREMENTS = "setup_requirements"
EXECUTION_STARTED = "execution_started"
NEED_LOGIN = "need_login"
ERROR = "error"
NO_RESULTS = "no_results"
AGENT_OUTPUT = "agent_output"
UNDERSTANDING_UPDATED = "understanding_updated"
SUCCESS = "success"
# Base response model
@@ -53,14 +51,14 @@ class AgentInfo(BaseModel):
graph_id: str | None = None
class AgentsFoundResponse(ToolResponseBase):
class AgentCarouselResponse(ToolResponseBase):
"""Response for find_agent tool."""
type: ResponseType = ResponseType.AGENTS_FOUND
type: ResponseType = ResponseType.AGENT_CAROUSEL
title: str = "Available Agents"
agents: list[AgentInfo]
count: int
name: str = "agents_found"
name: str = "agent_carousel"
class NoResultsResponse(ToolResponseBase):
@@ -175,37 +173,3 @@ class ErrorResponse(ToolResponseBase):
type: ResponseType = ResponseType.ERROR
error: str | None = None
details: dict[str, Any] | None = None
# Agent output models
class ExecutionOutputInfo(BaseModel):
"""Summary of a single execution's outputs."""
execution_id: str
status: str
started_at: datetime | None = None
ended_at: datetime | None = None
outputs: dict[str, list[Any]]
inputs_summary: dict[str, Any] | None = None
class AgentOutputResponse(ToolResponseBase):
"""Response for agent_output tool."""
type: ResponseType = ResponseType.AGENT_OUTPUT
agent_name: str
agent_id: str
library_agent_id: str | None = None
library_agent_link: str | None = None
execution: ExecutionOutputInfo | None = None
available_executions: list[dict[str, Any]] | None = None
total_executions: int = 0
# Business understanding models
class UnderstandingUpdatedResponse(ToolResponseBase):
"""Response for add_understanding tool."""
type: ResponseType = ResponseType.UNDERSTANDING_UPDATED
updated_fields: list[str] = Field(default_factory=list)
current_understanding: dict[str, Any] = Field(default_factory=dict)
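For reference, a sketch constructing the agent_output models listed above (the status string is an assumption; the field names come from the classes shown):
from datetime import datetime, timezone

info = ExecutionOutputInfo(
    execution_id="exec-123",
    status="COMPLETED",  # assumed value; the tool uses execution.status.value
    started_at=datetime.now(timezone.utc),
    outputs={"result": ["hello"]},
)
response = AgentOutputResponse(
    message="Found execution outputs for agent 'Demo'",
    session_id="session-1",
    agent_name="Demo",
    agent_id="graph-1",
    execution=info,
    total_executions=1,
)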

View File

@@ -7,7 +7,6 @@ from pydantic import BaseModel, Field, field_validator
from backend.api.features.chat.config import ChatConfig
from backend.api.features.chat.model import ChatSession
from backend.api.features.library import db as library_db
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.data.user import get_user_by_id
@@ -58,7 +57,6 @@ class RunAgentInput(BaseModel):
"""Input parameters for the run_agent tool."""
username_agent_slug: str = ""
library_agent_id: str = ""
inputs: dict[str, Any] = Field(default_factory=dict)
use_defaults: bool = False
schedule_name: str = ""
@@ -66,12 +64,7 @@ class RunAgentInput(BaseModel):
timezone: str = "UTC"
@field_validator(
"username_agent_slug",
"library_agent_id",
"schedule_name",
"cron",
"timezone",
mode="before",
"username_agent_slug", "schedule_name", "cron", "timezone", mode="before"
)
@classmethod
def strip_strings(cls, v: Any) -> Any:
@@ -97,7 +90,7 @@ class RunAgentTool(BaseTool):
@property
def description(self) -> str:
return """Run or schedule an agent from the marketplace or user's library.
return """Run or schedule an agent from the marketplace.
The tool automatically handles the setup flow:
- Returns missing inputs if required fields are not provided
@@ -105,10 +98,6 @@ class RunAgentTool(BaseTool):
- Executes immediately if all requirements are met
- Schedules execution if cron expression is provided
Identify the agent using either:
- username_agent_slug: Marketplace format 'username/agent-name'
- library_agent_id: ID of an agent in the user's library
For scheduled execution, provide: schedule_name, cron, and optionally timezone."""
@property
@@ -120,10 +109,6 @@ class RunAgentTool(BaseTool):
"type": "string",
"description": "Agent identifier in format 'username/agent-name'",
},
"library_agent_id": {
"type": "string",
"description": "Library agent ID from user's library",
},
"inputs": {
"type": "object",
"description": "Input values for the agent",
@@ -146,7 +131,7 @@ class RunAgentTool(BaseTool):
"description": "IANA timezone for schedule (default: UTC)",
},
},
"required": [],
"required": ["username_agent_slug"],
}
@property
@@ -164,16 +149,10 @@ class RunAgentTool(BaseTool):
params = RunAgentInput(**kwargs)
session_id = session.session_id
# Validate at least one identifier is provided
has_slug = params.username_agent_slug and "/" in params.username_agent_slug
has_library_id = bool(params.library_agent_id)
if not has_slug and not has_library_id:
# Validate agent slug format
if not params.username_agent_slug or "/" not in params.username_agent_slug:
return ErrorResponse(
message=(
"Please provide either a username_agent_slug "
"(format 'username/agent-name') or a library_agent_id"
),
message="Please provide an agent slug in format 'username/agent-name'",
session_id=session_id,
)
@@ -188,41 +167,13 @@ class RunAgentTool(BaseTool):
is_schedule = bool(params.schedule_name or params.cron)
try:
# Step 1: Fetch agent details
graph: GraphModel | None = None
library_agent = None
# Priority: library_agent_id if provided
if has_library_id:
library_agent = await library_db.get_library_agent(
params.library_agent_id, user_id
)
if not library_agent:
return ErrorResponse(
message=f"Library agent '{params.library_agent_id}' not found",
session_id=session_id,
)
# Get the graph from the library agent
from backend.data.graph import get_graph
graph = await get_graph(
library_agent.graph_id,
library_agent.graph_version,
user_id=user_id,
)
else:
# Fetch from marketplace slug
username, agent_name = params.username_agent_slug.split("/", 1)
graph, _ = await fetch_graph_from_store_slug(username, agent_name)
# Step 1: Fetch agent details (always happens first)
username, agent_name = params.username_agent_slug.split("/", 1)
graph, store_agent = await fetch_graph_from_store_slug(username, agent_name)
if not graph:
identifier = (
params.library_agent_id
if has_library_id
else params.username_agent_slug
)
return ErrorResponse(
message=f"Agent '{identifier}' not found",
message=f"Agent '{params.username_agent_slug}' not found in marketplace",
session_id=session_id,
)
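A usage sketch for the simplified flow above; RunAgentTool and make_session are taken from the tests further down in this diff, while the slug, inputs, and cron values are placeholders:
async def schedule_daily_run(user_id: str):
    tool = RunAgentTool()
    session = make_session(user_id=user_id)
    return await tool._execute(
        user_id=user_id,
        session=session,
        username_agent_slug="creator/daily-report",  # must contain '/'
        inputs={"topic": "AI news"},
        schedule_name="daily-report",
        cron="0 9 * * *",  # every day at 09:00
        timezone="UTC",
    )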

View File

@@ -1,5 +1,4 @@
import uuid
from unittest.mock import AsyncMock, patch
import orjson
import pytest
@@ -18,17 +17,6 @@ setup_test_data = setup_test_data
setup_firecrawl_test_data = setup_firecrawl_test_data
@pytest.fixture(scope="session", autouse=True)
def mock_embedding_functions():
"""Mock embedding functions for all tests to avoid database/API dependencies."""
with patch(
"backend.api.features.store.db.ensure_embedding",
new_callable=AsyncMock,
return_value=True,
):
yield
@pytest.mark.asyncio(scope="session")
async def test_run_agent(setup_test_data):
"""Test that the run_agent tool successfully executes an approved agent"""
@@ -58,11 +46,11 @@ async def test_run_agent(setup_test_data):
# Verify the response
assert response is not None
assert hasattr(response, "output")
assert hasattr(response, "result")
# Parse the result JSON to verify the execution started
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
assert "execution_id" in result_data
assert "graph_id" in result_data
assert result_data["graph_id"] == graph.id
@@ -98,11 +86,11 @@ async def test_run_agent_missing_inputs(setup_test_data):
# Verify that we get an error response
assert response is not None
assert hasattr(response, "output")
assert hasattr(response, "result")
# The tool should return an ErrorResponse when setup info indicates not ready
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
assert "message" in result_data
@@ -130,10 +118,10 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
# Verify that we get an error response
assert response is not None
assert hasattr(response, "output")
assert hasattr(response, "result")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
assert "message" in result_data
# Should get an error about failed setup or not found
assert any(
@@ -170,12 +158,12 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
# Verify the response
assert response is not None
assert hasattr(response, "output")
assert hasattr(response, "result")
# Parse the result JSON to verify the execution started
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
# Should successfully start execution since credentials are available
assert "execution_id" in result_data
@@ -207,9 +195,9 @@ async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_da
)
assert response is not None
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
assert hasattr(response, "result")
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
# Should return agent_details type showing available inputs
assert result_data.get("type") == "agent_details"
@@ -242,9 +230,9 @@ async def test_run_agent_with_use_defaults(setup_test_data):
)
assert response is not None
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
assert hasattr(response, "result")
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
# Should execute successfully
assert "execution_id" in result_data
@@ -272,9 +260,9 @@ async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
)
assert response is not None
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
assert hasattr(response, "result")
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
# Should return setup_requirements type with missing credentials
assert result_data.get("type") == "setup_requirements"
@@ -304,9 +292,9 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
)
assert response is not None
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
assert hasattr(response, "result")
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
# Should return error
assert result_data.get("type") == "error"
@@ -317,10 +305,9 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
async def test_run_agent_unauthenticated():
"""Test that run_agent returns need_login for unauthenticated users."""
tool = RunAgentTool()
# Session has a user_id (session owner), but we test tool execution without user_id
session = make_session(user_id="test-session-owner")
session = make_session(user_id=None)
# Execute without user_id to test unauthenticated behavior
# Execute without user_id
response = await tool.execute(
user_id=None,
session_id=str(uuid.uuid4()),
@@ -331,9 +318,9 @@ async def test_run_agent_unauthenticated():
)
assert response is not None
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
assert hasattr(response, "result")
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
# Base tool returns need_login type for unauthenticated users
assert result_data.get("type") == "need_login"
@@ -363,9 +350,9 @@ async def test_run_agent_schedule_without_cron(setup_test_data):
)
assert response is not None
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
assert hasattr(response, "result")
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
# Should return error about missing cron
assert result_data.get("type") == "error"
@@ -395,9 +382,9 @@ async def test_run_agent_schedule_without_name(setup_test_data):
)
assert response is not None
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
assert hasattr(response, "result")
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
# Should return error about missing schedule_name
assert result_data.get("type") == "error"

View File

@@ -35,7 +35,11 @@ from backend.data.model import (
OAuth2Credentials,
UserIntegrations,
)
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.onboarding import (
OnboardingStep,
complete_onboarding_step,
increment_runs,
)
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
@@ -171,7 +175,6 @@ async def callback(
f"Successfully processed OAuth callback for user {user_id} "
f"and provider {provider.value}"
)
return CredentialsMetaResponse(
id=credentials.id,
provider=credentials.provider,
@@ -190,7 +193,6 @@ async def list_credentials(
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
credentials = await creds_manager.store.get_all_creds(user_id)
return [
CredentialsMetaResponse(
id=cred.id,
@@ -213,7 +215,6 @@ async def list_credentials_by_provider(
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
return [
CredentialsMetaResponse(
id=cred.id,
@@ -377,6 +378,7 @@ async def webhook_ingress_generic(
return
await complete_onboarding_step(user_id, OnboardingStep.TRIGGER_WEBHOOK)
await increment_runs(user_id)
# Execute all triggers concurrently for better performance
tasks = []
@@ -829,18 +831,6 @@ async def list_providers() -> List[str]:
return all_providers
@router.get("/providers/system", response_model=List[str])
async def list_system_providers() -> List[str]:
"""
Get a list of providers that have platform credits (system credentials) available.
These providers can be used without the user providing their own API keys.
"""
from backend.integrations.credentials_store import SYSTEM_PROVIDERS
return list(SYSTEM_PROVIDERS)
@router.get("/providers/names", response_model=ProviderNamesResponse)
async def get_provider_names() -> ProviderNamesResponse:
"""

View File

@@ -8,6 +8,7 @@ from backend.data.execution import GraphExecutionMeta
from backend.data.graph import get_graph
from backend.data.integrations import get_webhook
from backend.data.model import CredentialsMetaInput
from backend.data.onboarding import increment_runs
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks import get_webhook_manager
@@ -402,6 +403,8 @@ async def execute_preset(
merged_node_input = preset.inputs | inputs
merged_credential_inputs = preset.credentials | credential_inputs
await increment_runs(user_id)
return await add_graph_execution(
user_id=user_id,
graph_id=preset.graph_id,

View File

@@ -1,431 +0,0 @@
"""
Content Type Handlers for Unified Embeddings
Pluggable system for different content sources (store agents, blocks, docs).
Each handler knows how to fetch and process its content type for embedding.
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from prisma.enums import ContentType
from backend.data.db import query_raw_with_schema
logger = logging.getLogger(__name__)
@dataclass
class ContentItem:
"""Represents a piece of content to be embedded."""
content_id: str # Unique identifier (DB ID or file path)
content_type: ContentType
searchable_text: str # Combined text for embedding
metadata: dict[str, Any] # Content-specific metadata
user_id: str | None = None # For user-scoped content
class ContentHandler(ABC):
"""Base handler for fetching and processing content for embeddings."""
@property
@abstractmethod
def content_type(self) -> ContentType:
"""The ContentType this handler manages."""
pass
@abstractmethod
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
"""
Fetch items that don't have embeddings yet.
Args:
batch_size: Maximum number of items to return
Returns:
List of ContentItem objects ready for embedding
"""
pass
@abstractmethod
async def get_stats(self) -> dict[str, int]:
"""
Get statistics about embedding coverage.
Returns:
Dict with keys: total, with_embeddings, without_embeddings
"""
pass
class StoreAgentHandler(ContentHandler):
"""Handler for marketplace store agent listings."""
@property
def content_type(self) -> ContentType:
return ContentType.STORE_AGENT
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
"""Fetch approved store listings without embeddings."""
from backend.api.features.store.embeddings import build_searchable_text
missing = await query_raw_with_schema(
"""
SELECT
slv.id,
slv.name,
slv.description,
slv."subHeading",
slv.categories
FROM {schema_prefix}"StoreListingVersion" slv
LEFT JOIN {schema_prefix}"UnifiedContentEmbedding" uce
ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
WHERE slv."submissionStatus" = 'APPROVED'
AND slv."isDeleted" = false
AND uce."contentId" IS NULL
LIMIT $1
""",
batch_size,
)
return [
ContentItem(
content_id=row["id"],
content_type=ContentType.STORE_AGENT,
searchable_text=build_searchable_text(
name=row["name"],
description=row["description"],
sub_heading=row["subHeading"],
categories=row["categories"] or [],
),
metadata={
"name": row["name"],
"categories": row["categories"] or [],
},
user_id=None, # Store agents are public
)
for row in missing
]
async def get_stats(self) -> dict[str, int]:
"""Get statistics about store agent embedding coverage."""
# Count approved versions
approved_result = await query_raw_with_schema(
"""
SELECT COUNT(*) as count
FROM {schema_prefix}"StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
AND "isDeleted" = false
"""
)
total_approved = approved_result[0]["count"] if approved_result else 0
# Count versions with embeddings
embedded_result = await query_raw_with_schema(
"""
SELECT COUNT(*) as count
FROM {schema_prefix}"StoreListingVersion" slv
JOIN {schema_prefix}"UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
WHERE slv."submissionStatus" = 'APPROVED'
AND slv."isDeleted" = false
"""
)
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
return {
"total": total_approved,
"with_embeddings": with_embeddings,
"without_embeddings": total_approved - with_embeddings,
}
class BlockHandler(ContentHandler):
"""Handler for block definitions (Python classes)."""
@property
def content_type(self) -> ContentType:
return ContentType.BLOCK
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
"""Fetch blocks without embeddings."""
from backend.data.block import get_blocks
# Get all available blocks
all_blocks = get_blocks()
# Check which ones have embeddings
if not all_blocks:
return []
block_ids = list(all_blocks.keys())
# Query for existing embeddings
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
existing_result = await query_raw_with_schema(
f"""
SELECT "contentId"
FROM {{schema_prefix}}"UnifiedContentEmbedding"
WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
AND "contentId" = ANY(ARRAY[{placeholders}])
""",
*block_ids,
)
existing_ids = {row["contentId"] for row in existing_result}
missing_blocks = [
(block_id, block_cls)
for block_id, block_cls in all_blocks.items()
if block_id not in existing_ids
]
# Convert to ContentItem
items = []
for block_id, block_cls in missing_blocks[:batch_size]:
try:
block_instance = block_cls()
# Build searchable text from block metadata
parts = []
if hasattr(block_instance, "name") and block_instance.name:
parts.append(block_instance.name)
if (
hasattr(block_instance, "description")
and block_instance.description
):
parts.append(block_instance.description)
if hasattr(block_instance, "categories") and block_instance.categories:
# Convert BlockCategory enum to strings
parts.append(
" ".join(str(cat.value) for cat in block_instance.categories)
)
# Add input/output schema info
if hasattr(block_instance, "input_schema"):
schema = block_instance.input_schema
if hasattr(schema, "model_json_schema"):
schema_dict = schema.model_json_schema()
if "properties" in schema_dict:
for prop_name, prop_info in schema_dict[
"properties"
].items():
if "description" in prop_info:
parts.append(
f"{prop_name}: {prop_info['description']}"
)
searchable_text = " ".join(parts)
# Convert categories set of enums to list of strings for JSON serialization
categories = getattr(block_instance, "categories", set())
categories_list = (
[cat.value for cat in categories] if categories else []
)
items.append(
ContentItem(
content_id=block_id,
content_type=ContentType.BLOCK,
searchable_text=searchable_text,
metadata={
"name": getattr(block_instance, "name", ""),
"categories": categories_list,
},
user_id=None, # Blocks are public
)
)
except Exception as e:
logger.warning(f"Failed to process block {block_id}: {e}")
continue
return items
async def get_stats(self) -> dict[str, int]:
"""Get statistics about block embedding coverage."""
from backend.data.block import get_blocks
all_blocks = get_blocks()
total_blocks = len(all_blocks)
if total_blocks == 0:
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
block_ids = list(all_blocks.keys())
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
embedded_result = await query_raw_with_schema(
f"""
SELECT COUNT(*) as count
FROM {{schema_prefix}}"UnifiedContentEmbedding"
WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
AND "contentId" = ANY(ARRAY[{placeholders}])
""",
*block_ids,
)
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
return {
"total": total_blocks,
"with_embeddings": with_embeddings,
"without_embeddings": total_blocks - with_embeddings,
}
class DocumentationHandler(ContentHandler):
"""Handler for documentation files (.md/.mdx)."""
@property
def content_type(self) -> ContentType:
return ContentType.DOCUMENTATION
def _get_docs_root(self) -> Path:
"""Get the documentation root directory."""
# content_handlers.py is at: backend/backend/api/features/store/content_handlers.py
# Need to go up to project root then into docs/
# In container: /app/autogpt_platform/backend/backend/api/features/store -> /app/docs
# In development: /repo/autogpt_platform/backend/backend/api/features/store -> /repo/docs
this_file = Path(
__file__
) # .../backend/backend/api/features/store/content_handlers.py
project_root = (
this_file.parent.parent.parent.parent.parent.parent.parent
) # -> /app or /repo
docs_root = project_root / "docs"
return docs_root
def _extract_title_and_content(self, file_path: Path) -> tuple[str, str]:
"""Extract title and content from markdown file."""
try:
content = file_path.read_text(encoding="utf-8")
# Try to extract title from first # heading
lines = content.split("\n")
title = ""
body_lines = []
for line in lines:
if line.startswith("# ") and not title:
title = line[2:].strip()
else:
body_lines.append(line)
# If no title found, use filename
if not title:
title = file_path.stem.replace("-", " ").replace("_", " ").title()
body = "\n".join(body_lines)
return title, body
except Exception as e:
logger.warning(f"Failed to read {file_path}: {e}")
return file_path.stem, ""
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
"""Fetch documentation files without embeddings."""
docs_root = self._get_docs_root()
if not docs_root.exists():
logger.warning(f"Documentation root not found: {docs_root}")
return []
# Find all .md and .mdx files
all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
# Get relative paths for content IDs
doc_paths = [str(doc.relative_to(docs_root)) for doc in all_docs]
if not doc_paths:
return []
# Check which ones have embeddings
placeholders = ",".join([f"${i+1}" for i in range(len(doc_paths))])
existing_result = await query_raw_with_schema(
f"""
SELECT "contentId"
FROM {{schema_prefix}}"UnifiedContentEmbedding"
WHERE "contentType" = 'DOCUMENTATION'::{{schema_prefix}}"ContentType"
AND "contentId" = ANY(ARRAY[{placeholders}])
""",
*doc_paths,
)
existing_ids = {row["contentId"] for row in existing_result}
missing_docs = [
(doc_path, doc_file)
for doc_path, doc_file in zip(doc_paths, all_docs)
if doc_path not in existing_ids
]
# Convert to ContentItem
items = []
for doc_path, doc_file in missing_docs[:batch_size]:
try:
title, content = self._extract_title_and_content(doc_file)
# Build searchable text
searchable_text = f"{title} {content}"
items.append(
ContentItem(
content_id=doc_path,
content_type=ContentType.DOCUMENTATION,
searchable_text=searchable_text,
metadata={
"title": title,
"path": doc_path,
},
user_id=None, # Documentation is public
)
)
except Exception as e:
logger.warning(f"Failed to process doc {doc_path}: {e}")
continue
return items
async def get_stats(self) -> dict[str, int]:
"""Get statistics about documentation embedding coverage."""
docs_root = self._get_docs_root()
if not docs_root.exists():
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
# Count all .md and .mdx files
all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
total_docs = len(all_docs)
if total_docs == 0:
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
doc_paths = [str(doc.relative_to(docs_root)) for doc in all_docs]
placeholders = ",".join([f"${i+1}" for i in range(len(doc_paths))])
embedded_result = await query_raw_with_schema(
f"""
SELECT COUNT(*) as count
FROM {{schema_prefix}}"UnifiedContentEmbedding"
WHERE "contentType" = 'DOCUMENTATION'::{{schema_prefix}}"ContentType"
AND "contentId" = ANY(ARRAY[{placeholders}])
""",
*doc_paths,
)
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
return {
"total": total_docs,
"with_embeddings": with_embeddings,
"without_embeddings": total_docs - with_embeddings,
}
# Content handler registry
CONTENT_HANDLERS: dict[ContentType, ContentHandler] = {
ContentType.STORE_AGENT: StoreAgentHandler(),
ContentType.BLOCK: BlockHandler(),
ContentType.DOCUMENTATION: DocumentationHandler(),
}
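A minimal sketch of driving the registry above; it mirrors what get_embedding_stats appears to aggregate (per the tests below) but is not the actual implementation:
async def coverage_report() -> dict[str, dict[str, int]]:
    # Ask each registered handler for its embedding coverage numbers.
    report: dict[str, dict[str, int]] = {}
    for content_type, handler in CONTENT_HANDLERS.items():
        report[content_type.value] = await handler.get_stats()
    return report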

View File

@@ -1,215 +0,0 @@
"""
Integration tests for content handlers using real DB.
Run with: poetry run pytest backend/api/features/store/content_handlers_integration_test.py -xvs
These tests use the real database but mock OpenAI calls.
"""
from unittest.mock import patch
import pytest
from backend.api.features.store.content_handlers import (
CONTENT_HANDLERS,
BlockHandler,
DocumentationHandler,
StoreAgentHandler,
)
from backend.api.features.store.embeddings import (
EMBEDDING_DIM,
backfill_all_content_types,
ensure_content_embedding,
get_embedding_stats,
)
@pytest.mark.asyncio(loop_scope="session")
async def test_store_agent_handler_real_db():
"""Test StoreAgentHandler with real database queries."""
handler = StoreAgentHandler()
# Get stats from real DB
stats = await handler.get_stats()
# Stats should have correct structure
assert "total" in stats
assert "with_embeddings" in stats
assert "without_embeddings" in stats
assert stats["total"] >= 0
assert stats["with_embeddings"] >= 0
assert stats["without_embeddings"] >= 0
# Get missing items (max 1 to keep test fast)
items = await handler.get_missing_items(batch_size=1)
# Items should be list (may be empty if all have embeddings)
assert isinstance(items, list)
if items:
item = items[0]
assert item.content_id is not None
assert item.content_type.value == "STORE_AGENT"
assert item.searchable_text != ""
assert item.user_id is None
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_real_db():
"""Test BlockHandler with real database queries."""
handler = BlockHandler()
# Get stats from real DB
stats = await handler.get_stats()
# Stats should have correct structure
assert "total" in stats
assert "with_embeddings" in stats
assert "without_embeddings" in stats
assert stats["total"] >= 0 # Should have at least some blocks
assert stats["with_embeddings"] >= 0
assert stats["without_embeddings"] >= 0
# Get missing items (max 1 to keep test fast)
items = await handler.get_missing_items(batch_size=1)
# Items should be list
assert isinstance(items, list)
if items:
item = items[0]
assert item.content_id is not None # Should be block UUID
assert item.content_type.value == "BLOCK"
assert item.searchable_text != ""
assert item.user_id is None
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_real_fs():
"""Test DocumentationHandler with real filesystem."""
handler = DocumentationHandler()
# Get stats from real filesystem
stats = await handler.get_stats()
# Stats should have correct structure
assert "total" in stats
assert "with_embeddings" in stats
assert "without_embeddings" in stats
assert stats["total"] >= 0
assert stats["with_embeddings"] >= 0
assert stats["without_embeddings"] >= 0
# Get missing items (max 1 to keep test fast)
items = await handler.get_missing_items(batch_size=1)
# Items should be list
assert isinstance(items, list)
if items:
item = items[0]
assert item.content_id is not None # Should be relative path
assert item.content_type.value == "DOCUMENTATION"
assert item.searchable_text != ""
assert item.user_id is None
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_stats_all_types():
"""Test get_embedding_stats aggregates all content types."""
stats = await get_embedding_stats()
# Should have structure with by_type and totals
assert "by_type" in stats
assert "totals" in stats
# Check each content type is present
by_type = stats["by_type"]
assert "STORE_AGENT" in by_type
assert "BLOCK" in by_type
assert "DOCUMENTATION" in by_type
# Check totals are aggregated
totals = stats["totals"]
assert totals["total"] >= 0
assert totals["with_embeddings"] >= 0
assert totals["without_embeddings"] >= 0
assert "coverage_percent" in totals
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
async def test_ensure_content_embedding_blocks(mock_generate):
"""Test creating embeddings for blocks (mocked OpenAI)."""
# Mock OpenAI to return fake embedding
mock_generate.return_value = [0.1] * EMBEDDING_DIM
# Get one block without embedding
handler = BlockHandler()
items = await handler.get_missing_items(batch_size=1)
if not items:
pytest.skip("No blocks without embeddings")
item = items[0]
# Try to create embedding (OpenAI mocked)
result = await ensure_content_embedding(
content_type=item.content_type,
content_id=item.content_id,
searchable_text=item.searchable_text,
metadata=item.metadata,
user_id=item.user_id,
)
# Should succeed with mocked OpenAI
assert result is True
mock_generate.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
async def test_backfill_all_content_types_dry_run(mock_generate):
"""Test backfill_all_content_types processes all handlers in order."""
# Mock OpenAI to return fake embedding
mock_generate.return_value = [0.1] * EMBEDDING_DIM
# Run backfill with batch_size=1 to process max 1 per type
result = await backfill_all_content_types(batch_size=1)
# Should have results for all content types
assert "by_type" in result
assert "totals" in result
by_type = result["by_type"]
assert "BLOCK" in by_type
assert "STORE_AGENT" in by_type
assert "DOCUMENTATION" in by_type
# Each type should have correct structure
for content_type, type_result in by_type.items():
assert "processed" in type_result
assert "success" in type_result
assert "failed" in type_result
# Totals should aggregate
totals = result["totals"]
assert totals["processed"] >= 0
assert totals["success"] >= 0
assert totals["failed"] >= 0
@pytest.mark.asyncio(loop_scope="session")
async def test_content_handler_registry():
"""Test all handlers are registered in correct order."""
from prisma.enums import ContentType
# All three types should be registered
assert ContentType.STORE_AGENT in CONTENT_HANDLERS
assert ContentType.BLOCK in CONTENT_HANDLERS
assert ContentType.DOCUMENTATION in CONTENT_HANDLERS
# Check handler types
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)

View File

@@ -1,324 +0,0 @@
"""
E2E tests for content handlers (blocks, store agents, documentation).
Tests the full flow: discovering content → generating embeddings → storing.
"""
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from prisma.enums import ContentType
from backend.api.features.store.content_handlers import (
CONTENT_HANDLERS,
BlockHandler,
DocumentationHandler,
StoreAgentHandler,
)
@pytest.mark.asyncio(loop_scope="session")
async def test_store_agent_handler_get_missing_items(mocker):
"""Test StoreAgentHandler fetches approved agents without embeddings."""
handler = StoreAgentHandler()
# Mock database query
mock_missing = [
{
"id": "agent-1",
"name": "Test Agent",
"description": "A test agent",
"subHeading": "Test heading",
"categories": ["AI", "Testing"],
}
]
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=mock_missing,
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 1
assert items[0].content_id == "agent-1"
assert items[0].content_type == ContentType.STORE_AGENT
assert "Test Agent" in items[0].searchable_text
assert "A test agent" in items[0].searchable_text
assert items[0].metadata["name"] == "Test Agent"
assert items[0].user_id is None
@pytest.mark.asyncio(loop_scope="session")
async def test_store_agent_handler_get_stats(mocker):
"""Test StoreAgentHandler returns correct stats."""
handler = StoreAgentHandler()
# Mock approved count query
mock_approved = [{"count": 50}]
# Mock embedded count query
mock_embedded = [{"count": 30}]
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
side_effect=[mock_approved, mock_embedded],
):
stats = await handler.get_stats()
assert stats["total"] == 50
assert stats["with_embeddings"] == 30
assert stats["without_embeddings"] == 20
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_get_missing_items(mocker):
"""Test BlockHandler discovers blocks without embeddings."""
handler = BlockHandler()
# Mock get_blocks to return test blocks
mock_block_class = MagicMock()
mock_block_instance = MagicMock()
mock_block_instance.name = "Calculator Block"
mock_block_instance.description = "Performs calculations"
mock_block_instance.categories = [MagicMock(value="MATH")]
mock_block_instance.input_schema.model_json_schema.return_value = {
"properties": {"expression": {"description": "Math expression to evaluate"}}
}
mock_block_class.return_value = mock_block_instance
mock_blocks = {"block-uuid-1": mock_block_class}
# Mock existing embeddings query (no embeddings exist)
mock_existing = []
with patch(
"backend.data.block.get_blocks",
return_value=mock_blocks,
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=mock_existing,
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 1
assert items[0].content_id == "block-uuid-1"
assert items[0].content_type == ContentType.BLOCK
assert "Calculator Block" in items[0].searchable_text
assert "Performs calculations" in items[0].searchable_text
assert "MATH" in items[0].searchable_text
assert "expression: Math expression" in items[0].searchable_text
assert items[0].user_id is None
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_get_stats(mocker):
"""Test BlockHandler returns correct stats."""
handler = BlockHandler()
# Mock get_blocks
mock_blocks = {
"block-1": MagicMock(),
"block-2": MagicMock(),
"block-3": MagicMock(),
}
# Mock embedded count query (2 blocks have embeddings)
mock_embedded = [{"count": 2}]
with patch(
"backend.data.block.get_blocks",
return_value=mock_blocks,
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=mock_embedded,
):
stats = await handler.get_stats()
assert stats["total"] == 3
assert stats["with_embeddings"] == 2
assert stats["without_embeddings"] == 1
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_get_missing_items(tmp_path, mocker):
"""Test DocumentationHandler discovers docs without embeddings."""
handler = DocumentationHandler()
# Create temporary docs directory with test files
docs_root = tmp_path / "docs"
docs_root.mkdir()
(docs_root / "guide.md").write_text("# Getting Started\n\nThis is a guide.")
(docs_root / "api.mdx").write_text("# API Reference\n\nAPI documentation.")
# Mock _get_docs_root to return temp dir
with patch.object(handler, "_get_docs_root", return_value=docs_root):
# Mock existing embeddings query (no embeddings exist)
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 2
# Check guide.md
guide_item = next(
(item for item in items if item.content_id == "guide.md"), None
)
assert guide_item is not None
assert guide_item.content_type == ContentType.DOCUMENTATION
assert "Getting Started" in guide_item.searchable_text
assert "This is a guide" in guide_item.searchable_text
assert guide_item.metadata["title"] == "Getting Started"
assert guide_item.user_id is None
# Check api.mdx
api_item = next(
(item for item in items if item.content_id == "api.mdx"), None
)
assert api_item is not None
assert "API Reference" in api_item.searchable_text
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_get_stats(tmp_path, mocker):
"""Test DocumentationHandler returns correct stats."""
handler = DocumentationHandler()
# Create temporary docs directory
docs_root = tmp_path / "docs"
docs_root.mkdir()
(docs_root / "doc1.md").write_text("# Doc 1")
(docs_root / "doc2.md").write_text("# Doc 2")
(docs_root / "doc3.mdx").write_text("# Doc 3")
# Mock embedded count query (1 doc has embedding)
mock_embedded = [{"count": 1}]
with patch.object(handler, "_get_docs_root", return_value=docs_root):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=mock_embedded,
):
stats = await handler.get_stats()
assert stats["total"] == 3
assert stats["with_embeddings"] == 1
assert stats["without_embeddings"] == 2
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_title_extraction(tmp_path):
"""Test DocumentationHandler extracts title from markdown heading."""
handler = DocumentationHandler()
# Test with heading
doc_with_heading = tmp_path / "with_heading.md"
doc_with_heading.write_text("# My Title\n\nContent here")
title, content = handler._extract_title_and_content(doc_with_heading)
assert title == "My Title"
assert "# My Title" not in content
assert "Content here" in content
# Test without heading
doc_without_heading = tmp_path / "no-heading.md"
doc_without_heading.write_text("Just content, no heading")
title, content = handler._extract_title_and_content(doc_without_heading)
assert title == "No Heading" # Uses filename
assert "Just content" in content
@pytest.mark.asyncio(loop_scope="session")
async def test_content_handlers_registry():
"""Test all content types are registered."""
assert ContentType.STORE_AGENT in CONTENT_HANDLERS
assert ContentType.BLOCK in CONTENT_HANDLERS
assert ContentType.DOCUMENTATION in CONTENT_HANDLERS
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_handles_missing_attributes():
"""Test BlockHandler gracefully handles blocks with missing attributes."""
handler = BlockHandler()
# Mock block with minimal attributes
mock_block_class = MagicMock()
mock_block_instance = MagicMock()
mock_block_instance.name = "Minimal Block"
# No description, categories, or schema
del mock_block_instance.description
del mock_block_instance.categories
del mock_block_instance.input_schema
mock_block_class.return_value = mock_block_instance
mock_blocks = {"block-minimal": mock_block_class}
with patch(
"backend.data.block.get_blocks",
return_value=mock_blocks,
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 1
assert items[0].searchable_text == "Minimal Block"
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_skips_failed_blocks():
"""Test BlockHandler skips blocks that fail to instantiate."""
handler = BlockHandler()
# Mock one good block and one bad block
good_block = MagicMock()
good_instance = MagicMock()
good_instance.name = "Good Block"
good_instance.description = "Works fine"
good_instance.categories = []
good_block.return_value = good_instance
bad_block = MagicMock()
bad_block.side_effect = Exception("Instantiation failed")
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
with patch(
"backend.data.block.get_blocks",
return_value=mock_blocks,
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=10)
# Should only get the good block
assert len(items) == 1
assert items[0].content_id == "good-block"
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_missing_docs_directory():
"""Test DocumentationHandler handles missing docs directory gracefully."""
handler = DocumentationHandler()
# Mock _get_docs_root to return non-existent path
fake_path = Path("/nonexistent/docs")
with patch.object(handler, "_get_docs_root", return_value=fake_path):
items = await handler.get_missing_items(batch_size=10)
assert items == []
stats = await handler.get_stats()
assert stats["total"] == 0
assert stats["with_embeddings"] == 0
assert stats["without_embeddings"] == 0

View File

@@ -1,7 +1,8 @@
import asyncio
import logging
import typing
from datetime import datetime, timezone
from typing import Any, Literal
from typing import Literal
import fastapi
import prisma.enums
@@ -9,7 +10,7 @@ import prisma.errors
import prisma.models
import prisma.types
from backend.data.db import transaction
from backend.data.db import query_raw_with_schema, transaction
from backend.data.graph import (
GraphMeta,
GraphModel,
@@ -29,8 +30,6 @@ from backend.util.settings import Settings
from . import exceptions as store_exceptions
from . import model as store_model
from .embeddings import ensure_embedding
from .hybrid_search import hybrid_search
logger = logging.getLogger(__name__)
settings = Settings()
@@ -51,77 +50,128 @@ async def get_store_agents(
page_size: int = 20,
) -> store_model.StoreAgentsResponse:
"""
Get PUBLIC store agents from the StoreAgent view.
Search behavior:
- With search_query: Uses hybrid search (semantic + lexical)
- Fallback: If embeddings unavailable, gracefully degrades to lexical-only
- Rationale: User-facing endpoint prioritizes availability over accuracy
Note: Admin operations (approval) use fail-fast to prevent inconsistent state.
Get PUBLIC store agents from the StoreAgent view
"""
logger.debug(
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
)
search_used_hybrid = False
store_agents: list[store_model.StoreAgent] = []
agents: list[dict[str, Any]] = []
total = 0
total_pages = 0
try:
# If search_query is provided, use hybrid search (embeddings + tsvector)
# If search_query is provided, use full-text search
if search_query:
# Try hybrid search combining semantic and lexical signals
# Falls back to lexical-only if OpenAI unavailable (user-facing, high SLA)
try:
agents, total = await hybrid_search(
query=search_query,
featured=featured,
creators=creators,
category=category,
sorted_by="relevance", # Use hybrid scoring for relevance
page=page,
page_size=page_size,
)
search_used_hybrid = True
except Exception as e:
# Log error but fall back to lexical search for better UX
logger.error(
f"Hybrid search failed (likely OpenAI unavailable), "
f"falling back to lexical search: {e}"
)
# search_used_hybrid remains False, will use fallback path below
offset = (page - 1) * page_size
# Convert hybrid search results (dict format) if hybrid succeeded
if search_used_hybrid:
total_pages = (total + page_size - 1) // page_size
store_agents: list[store_model.StoreAgent] = []
for agent in agents:
try:
store_agent = store_model.StoreAgent(
slug=agent["slug"],
agent_name=agent["agent_name"],
agent_image=(
agent["agent_image"][0] if agent["agent_image"] else ""
),
creator=agent["creator_username"] or "Needs Profile",
creator_avatar=agent["creator_avatar"] or "",
sub_heading=agent["sub_heading"],
description=agent["description"],
runs=agent["runs"],
rating=agent["rating"],
)
store_agents.append(store_agent)
except Exception as e:
logger.error(
f"Error parsing Store agent from hybrid search results: {e}"
)
continue
# Whitelist allowed order_by columns
ALLOWED_ORDER_BY = {
"rating": "rating DESC, rank DESC",
"runs": "runs DESC, rank DESC",
"name": "agent_name ASC, rank ASC",
"updated_at": "updated_at DESC, rank DESC",
}
if not search_used_hybrid:
# Fallback path - use basic search or no search
# Validate and get order clause
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
else:
order_by_clause = "updated_at DESC, rank DESC"
# Build WHERE conditions and parameters list
where_parts: list[str] = []
params: list[typing.Any] = [search_query] # $1 - search term
param_index = 2 # Start at $2 for next parameter
# Always filter for available agents
where_parts.append("is_available = true")
if featured:
where_parts.append("featured = true")
if creators:
# Use ANY with array parameter
where_parts.append(f"creator_username = ANY(${param_index})")
params.append(creators)
param_index += 1
if category:
where_parts.append(f"${param_index} = ANY(categories)")
params.append(category)
param_index += 1
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
# Add pagination params
params.extend([page_size, offset])
limit_param = f"${param_index}"
offset_param = f"${param_index + 1}"
# Execute full-text search query with parameterized values
sql_query = f"""
SELECT
slug,
agent_name,
agent_image,
creator_username,
creator_avatar,
sub_heading,
description,
runs,
rating,
categories,
featured,
is_available,
updated_at,
ts_rank_cd(search, query) AS rank
FROM {{schema_prefix}}"StoreAgent",
plainto_tsquery('english', $1) AS query
WHERE {sql_where_clause}
AND search @@ query
ORDER BY {order_by_clause}
LIMIT {limit_param} OFFSET {offset_param}
"""
# Count query for pagination - only uses search term parameter
count_query = f"""
SELECT COUNT(*) as count
FROM {{schema_prefix}}"StoreAgent",
plainto_tsquery('english', $1) AS query
WHERE {sql_where_clause}
AND search @@ query
"""
# Execute both queries with parameters
agents = await query_raw_with_schema(sql_query, *params)
# For count, use params without pagination (last 2 params)
count_params = params[:-2]
count_result = await query_raw_with_schema(count_query, *count_params)
total = count_result[0]["count"] if count_result else 0
total_pages = (total + page_size - 1) // page_size
# Convert raw results to StoreAgent models
store_agents: list[store_model.StoreAgent] = []
for agent in agents:
try:
store_agent = store_model.StoreAgent(
slug=agent["slug"],
agent_name=agent["agent_name"],
agent_image=(
agent["agent_image"][0] if agent["agent_image"] else ""
),
creator=agent["creator_username"] or "Needs Profile",
creator_avatar=agent["creator_avatar"] or "",
sub_heading=agent["sub_heading"],
description=agent["description"],
runs=agent["runs"],
rating=agent["rating"],
)
store_agents.append(store_agent)
except Exception as e:
logger.error(f"Error parsing Store agent from search results: {e}")
continue
else:
# Non-search query path (original logic)
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
if featured:
where_clause["featured"] = featured
@@ -130,14 +180,6 @@ async def get_store_agents(
if category:
where_clause["categories"] = {"has": category}
# Add basic text search if search_query provided but hybrid failed
if search_query:
where_clause["OR"] = [
{"agent_name": {"contains": search_query, "mode": "insensitive"}},
{"sub_heading": {"contains": search_query, "mode": "insensitive"}},
{"description": {"contains": search_query, "mode": "insensitive"}},
]
order_by = []
if sorted_by == "rating":
order_by.append({"rating": "desc"})
@@ -146,7 +188,7 @@ async def get_store_agents(
elif sorted_by == "name":
order_by.append({"agent_name": "asc"})
db_agents = await prisma.models.StoreAgent.prisma().find_many(
agents = await prisma.models.StoreAgent.prisma().find_many(
where=where_clause,
order=order_by,
skip=(page - 1) * page_size,
@@ -157,7 +199,7 @@ async def get_store_agents(
total_pages = (total + page_size - 1) // page_size
store_agents: list[store_model.StoreAgent] = []
for agent in db_agents:
for agent in agents:
try:
# Create the StoreAgent object safely
store_agent = store_model.StoreAgent(
@@ -1535,7 +1577,7 @@ async def review_store_submission(
)
# Update the AgentGraph with store listing data
await prisma.models.AgentGraph.prisma(tx).update(
await prisma.models.AgentGraph.prisma().update(
where={
"graphVersionId": {
"id": store_listing_version.agentGraphId,
@@ -1550,23 +1592,6 @@ async def review_store_submission(
},
)
# Generate embedding for approved listing (blocking - admin operation)
# Inside transaction: if embedding fails, entire transaction rolls back
embedding_success = await ensure_embedding(
version_id=store_listing_version_id,
name=store_listing_version.name,
description=store_listing_version.description,
sub_heading=store_listing_version.subHeading,
categories=store_listing_version.categories or [],
tx=tx,
)
if not embedding_success:
raise ValueError(
f"Failed to generate embedding for listing {store_listing_version_id}. "
"This is likely due to OpenAI API being unavailable. "
"Please try again later or contact support if the issue persists."
)
await prisma.models.StoreListing.prisma(tx).update(
where={"id": store_listing_version.StoreListing.id},
data={

View File

@@ -1,962 +0,0 @@
"""
Unified Content Embeddings Service
Handles generation and storage of OpenAI embeddings for all content types
(store listings, blocks, documentation, library agents) to enable semantic/hybrid search.
"""
import asyncio
import logging
import time
from typing import Any
import prisma
from prisma.enums import ContentType
from tiktoken import encoding_for_model
from backend.api.features.store.content_handlers import CONTENT_HANDLERS
from backend.data.db import execute_raw_with_schema, query_raw_with_schema
from backend.util.clients import get_openai_client
from backend.util.json import dumps
logger = logging.getLogger(__name__)
# OpenAI embedding model configuration
EMBEDDING_MODEL = "text-embedding-3-small"
# Embedding dimension for the model above
# text-embedding-3-small: 1536, text-embedding-3-large: 3072
EMBEDDING_DIM = 1536
# OpenAI embedding token limit (8,191 with 1 token buffer for safety)
EMBEDDING_MAX_TOKENS = 8191
def build_searchable_text(
name: str,
description: str,
sub_heading: str,
categories: list[str],
) -> str:
"""
Build searchable text from listing version fields.
Combines relevant fields into a single string for embedding.
"""
parts = []
# Name is important - include it
if name:
parts.append(name)
# Sub-heading provides context
if sub_heading:
parts.append(sub_heading)
# Description is the main content
if description:
parts.append(description)
# Categories help with semantic matching
if categories:
parts.append(" ".join(categories))
return " ".join(parts)
async def generate_embedding(text: str) -> list[float] | None:
"""
Generate embedding for text using OpenAI API.
Returns None if embedding generation fails.
Fail-fast: no retries to maintain consistency with approval flow.
"""
try:
client = get_openai_client()
if not client:
logger.error("openai_internal_api_key not set, cannot generate embedding")
return None
# Truncate text to token limit using tiktoken
# Character-based truncation is insufficient because token ratios vary by content type
enc = encoding_for_model(EMBEDDING_MODEL)
tokens = enc.encode(text)
if len(tokens) > EMBEDDING_MAX_TOKENS:
tokens = tokens[:EMBEDDING_MAX_TOKENS]
truncated_text = enc.decode(tokens)
logger.info(
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
)
else:
truncated_text = text
start_time = time.time()
response = await client.embeddings.create(
model=EMBEDDING_MODEL,
input=truncated_text,
)
latency_ms = (time.time() - start_time) * 1000
embedding = response.data[0].embedding
logger.info(
f"Generated embedding: {len(embedding)} dims, "
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
)
return embedding
except Exception as e:
logger.error(f"Failed to generate embedding: {e}")
return None
async def store_embedding(
version_id: str,
embedding: list[float],
tx: prisma.Prisma | None = None,
) -> bool:
"""
Store embedding in the database.
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
DEPRECATED: Use ensure_embedding() instead (includes searchable_text).
"""
return await store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=version_id,
embedding=embedding,
searchable_text="", # Empty for backward compat; ensure_embedding() populates this
metadata=None,
user_id=None, # Store agents are public
tx=tx,
)
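# Migration sketch (assumes `listing` is a hypothetical StoreListingVersion record):
# prefer ensure_embedding(), which also populates searchable_text:
#   await ensure_embedding(
#       version_id=listing.id,
#       name=listing.name,
#       description=listing.description,
#       sub_heading=listing.subHeading,
#       categories=listing.categories or [],
#   )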
async def store_content_embedding(
content_type: ContentType,
content_id: str,
embedding: list[float],
searchable_text: str,
metadata: dict | None = None,
user_id: str | None = None,
tx: prisma.Prisma | None = None,
) -> bool:
"""
Store embedding in the unified content embeddings table.
New function for unified content embedding storage.
Uses raw SQL since Prisma doesn't natively support pgvector.
"""
try:
client = tx if tx else prisma.get_client()
# Convert embedding to PostgreSQL vector format
embedding_str = embedding_to_vector_string(embedding)
metadata_json = dumps(metadata or {})
# Upsert the embedding
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
await execute_raw_with_schema(
"""
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
)
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
ON CONFLICT ("contentType", "contentId", "userId")
DO UPDATE SET
"embedding" = $4::vector,
"searchableText" = $5,
"metadata" = $6::jsonb,
"updatedAt" = NOW()
WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
""",
content_type,
content_id,
user_id,
embedding_str,
searchable_text,
metadata_json,
client=client,
set_public_search_path=True,
)
logger.info(f"Stored embedding for {content_type}:{content_id}")
return True
except Exception as e:
logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}")
return False
async def get_embedding(version_id: str) -> dict[str, Any] | None:
"""
Retrieve embedding record for a listing version.
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
Returns dict with storeListingVersionId, embedding, timestamps or None if not found.
"""
result = await get_content_embedding(
ContentType.STORE_AGENT, version_id, user_id=None
)
if result:
# Transform to old format for backward compatibility
return {
"storeListingVersionId": result["contentId"],
"embedding": result["embedding"],
"createdAt": result["createdAt"],
"updatedAt": result["updatedAt"],
}
return None
async def get_content_embedding(
content_type: ContentType, content_id: str, user_id: str | None = None
) -> dict[str, Any] | None:
"""
Retrieve embedding record for any content type.
New function for unified content embedding retrieval.
Returns dict with contentType, contentId, embedding, timestamps or None if not found.
"""
try:
result = await query_raw_with_schema(
"""
SELECT
"contentType",
"contentId",
"userId",
"embedding"::text as "embedding",
"searchableText",
"metadata",
"createdAt",
"updatedAt"
FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
""",
content_type,
content_id,
user_id,
set_public_search_path=True,
)
if result and len(result) > 0:
return result[0]
return None
except Exception as e:
logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}")
return None
async def ensure_embedding(
version_id: str,
name: str,
description: str,
sub_heading: str,
categories: list[str],
force: bool = False,
tx: prisma.Prisma | None = None,
) -> bool:
"""
Ensure an embedding exists for the listing version.
Creates embedding if missing. Use force=True to regenerate.
Backward-compatible wrapper for store listings.
Args:
version_id: The StoreListingVersion ID
name: Agent name
description: Agent description
sub_heading: Agent sub-heading
categories: Agent categories
force: Force regeneration even if embedding exists
tx: Optional transaction client
Returns:
True if embedding exists/was created, False on failure
"""
try:
# Check if embedding already exists
if not force:
existing = await get_embedding(version_id)
if existing and existing.get("embedding"):
logger.debug(f"Embedding for version {version_id} already exists")
return True
# Build searchable text for embedding
searchable_text = build_searchable_text(
name, description, sub_heading, categories
)
# Generate new embedding
embedding = await generate_embedding(searchable_text)
if embedding is None:
logger.warning(f"Could not generate embedding for version {version_id}")
return False
# Store the embedding with metadata using new function
metadata = {
"name": name,
"subHeading": sub_heading,
"categories": categories,
}
return await store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=version_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata,
user_id=None, # Store agents are public
tx=tx,
)
except Exception as e:
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
return False
async def delete_embedding(version_id: str) -> bool:
"""
Delete embedding for a listing version.
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
Note: This is usually handled automatically by CASCADE delete,
but provided for manual cleanup if needed.
"""
return await delete_content_embedding(ContentType.STORE_AGENT, version_id)
async def delete_content_embedding(
content_type: ContentType, content_id: str, user_id: str | None = None
) -> bool:
"""
Delete embedding for any content type.
New function for unified content embedding deletion.
Note: This is usually handled automatically by CASCADE delete,
but provided for manual cleanup if needed.
Args:
content_type: The type of content (STORE_AGENT, LIBRARY_AGENT, etc.)
content_id: The unique identifier for the content
user_id: Optional user ID. For public content (STORE_AGENT, BLOCK), pass None.
For user-scoped content (LIBRARY_AGENT), pass the user's ID to avoid
deleting embeddings belonging to other users.
Returns:
True if deletion succeeded, False otherwise
"""
try:
client = prisma.get_client()
await execute_raw_with_schema(
"""
DELETE FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" = $1::{schema_prefix}"ContentType"
AND "contentId" = $2
AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
""",
content_type,
content_id,
user_id,
client=client,
)
user_str = f" (user: {user_id})" if user_id else ""
logger.info(f"Deleted embedding for {content_type}:{content_id}{user_str}")
return True
except Exception as e:
logger.error(f"Failed to delete embedding for {content_type}:{content_id}: {e}")
return False
async def get_embedding_stats() -> dict[str, Any]:
"""
Get statistics about embedding coverage for all content types.
Returns stats per content type and overall totals.
"""
try:
stats_by_type = {}
total_items = 0
total_with_embeddings = 0
total_without_embeddings = 0
# Aggregate stats from all handlers
for content_type, handler in CONTENT_HANDLERS.items():
try:
stats = await handler.get_stats()
stats_by_type[content_type.value] = {
"total": stats["total"],
"with_embeddings": stats["with_embeddings"],
"without_embeddings": stats["without_embeddings"],
"coverage_percent": (
round(stats["with_embeddings"] / stats["total"] * 100, 1)
if stats["total"] > 0
else 0
),
}
total_items += stats["total"]
total_with_embeddings += stats["with_embeddings"]
total_without_embeddings += stats["without_embeddings"]
except Exception as e:
logger.error(f"Failed to get stats for {content_type.value}: {e}")
stats_by_type[content_type.value] = {
"total": 0,
"with_embeddings": 0,
"without_embeddings": 0,
"coverage_percent": 0,
"error": str(e),
}
return {
"by_type": stats_by_type,
"totals": {
"total": total_items,
"with_embeddings": total_with_embeddings,
"without_embeddings": total_without_embeddings,
"coverage_percent": (
round(total_with_embeddings / total_items * 100, 1)
if total_items > 0
else 0
),
},
}
except Exception as e:
logger.error(f"Failed to get embedding stats: {e}")
return {
"by_type": {},
"totals": {
"total": 0,
"with_embeddings": 0,
"without_embeddings": 0,
"coverage_percent": 0,
},
"error": str(e),
}
async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]:
"""
Generate embeddings for approved listings that don't have them.
BACKWARD COMPATIBILITY: Maintained for existing usage.
This now delegates to backfill_all_content_types() to process all content types.
Args:
batch_size: Number of embeddings to generate per content type
Returns:
Dict with success/failure counts aggregated across all content types
"""
# Delegate to the new generic backfill system
result = await backfill_all_content_types(batch_size)
# Return in the old format for backward compatibility
return result["totals"]
async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
"""
Generate embeddings for all content types using registered handlers.
Processes content types in order: BLOCK → STORE_AGENT → DOCUMENTATION.
This ensures foundational content (blocks) are searchable first.
Args:
batch_size: Number of embeddings to generate per content type
Returns:
Dict with stats per content type and overall totals
"""
results_by_type = {}
total_processed = 0
total_success = 0
total_failed = 0
# Process content types in explicit order
processing_order = [
ContentType.BLOCK,
ContentType.STORE_AGENT,
ContentType.DOCUMENTATION,
]
for content_type in processing_order:
handler = CONTENT_HANDLERS.get(content_type)
if not handler:
logger.warning(f"No handler registered for {content_type.value}")
continue
try:
logger.info(f"Processing {content_type.value} content type...")
# Get missing items from handler
missing_items = await handler.get_missing_items(batch_size)
if not missing_items:
results_by_type[content_type.value] = {
"processed": 0,
"success": 0,
"failed": 0,
"message": "No missing embeddings",
}
continue
# Process embeddings concurrently for better performance
embedding_tasks = [
ensure_content_embedding(
content_type=item.content_type,
content_id=item.content_id,
searchable_text=item.searchable_text,
metadata=item.metadata,
user_id=item.user_id,
)
for item in missing_items
]
results = await asyncio.gather(*embedding_tasks, return_exceptions=True)
success = sum(1 for result in results if result is True)
failed = len(results) - success
results_by_type[content_type.value] = {
"processed": len(missing_items),
"success": success,
"failed": failed,
"message": f"Backfilled {success} embeddings, {failed} failed",
}
total_processed += len(missing_items)
total_success += success
total_failed += failed
logger.info(
f"{content_type.value}: processed {len(missing_items)}, "
f"success {success}, failed {failed}"
)
except Exception as e:
logger.error(f"Failed to process {content_type.value}: {e}")
results_by_type[content_type.value] = {
"processed": 0,
"success": 0,
"failed": 0,
"error": str(e),
}
return {
"by_type": results_by_type,
"totals": {
"processed": total_processed,
"success": total_success,
"failed": total_failed,
"message": f"Overall: {total_success} succeeded, {total_failed} failed",
},
}
async def embed_query(query: str) -> list[float] | None:
"""
Generate embedding for a search query.
Same as generate_embedding but with clearer intent.
"""
return await generate_embedding(query)
def embedding_to_vector_string(embedding: list[float]) -> str:
"""Convert embedding list to PostgreSQL vector string format."""
return "[" + ",".join(str(x) for x in embedding) + "]"
async def ensure_content_embedding(
content_type: ContentType,
content_id: str,
searchable_text: str,
metadata: dict | None = None,
user_id: str | None = None,
force: bool = False,
tx: prisma.Prisma | None = None,
) -> bool:
"""
Ensure an embedding exists for any content type.
Generic function for creating embeddings for store agents, blocks, docs, etc.
Args:
content_type: ContentType enum value (STORE_AGENT, BLOCK, etc.)
content_id: Unique identifier for the content
searchable_text: Combined text for embedding generation
metadata: Optional metadata to store with embedding
force: Force regeneration even if embedding exists
tx: Optional transaction client
Returns:
True if embedding exists/was created, False on failure
"""
try:
# Check if embedding already exists
if not force:
existing = await get_content_embedding(content_type, content_id, user_id)
if existing and existing.get("embedding"):
logger.debug(
f"Embedding for {content_type}:{content_id} already exists"
)
return True
# Generate new embedding
embedding = await generate_embedding(searchable_text)
if embedding is None:
logger.warning(
f"Could not generate embedding for {content_type}:{content_id}"
)
return False
# Store the embedding
return await store_content_embedding(
content_type=content_type,
content_id=content_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata or {},
user_id=user_id,
tx=tx,
)
except Exception as e:
logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
return False
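# Illustrative call (hypothetical block id and text), mirroring how the backfill
# path invokes this function for public content:
#   ok = await ensure_content_embedding(
#       content_type=ContentType.BLOCK,
#       content_id="http-request-block",
#       searchable_text="HTTP request block for API calls",
#       metadata={"name": "HTTP Request Block"},
#       user_id=None,  # blocks are public
#   )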
async def cleanup_orphaned_embeddings() -> dict[str, Any]:
"""
Clean up embeddings for content that no longer exists or is no longer valid.
Compares current content with embeddings in database and removes orphaned records:
- STORE_AGENT: Removes embeddings for rejected/deleted store listings
- BLOCK: Removes embeddings for blocks no longer registered
- DOCUMENTATION: Removes embeddings for deleted doc files
Returns:
Dict with cleanup statistics per content type
"""
results_by_type = {}
total_deleted = 0
# Cleanup orphaned embeddings for all content types
cleanup_types = [
ContentType.STORE_AGENT,
ContentType.BLOCK,
ContentType.DOCUMENTATION,
]
for content_type in cleanup_types:
try:
handler = CONTENT_HANDLERS.get(content_type)
if not handler:
logger.warning(f"No handler registered for {content_type}")
results_by_type[content_type.value] = {
"deleted": 0,
"error": "No handler registered",
}
continue
# Get all current content IDs from handler
if content_type == ContentType.STORE_AGENT:
# Get IDs of approved store listing versions from non-deleted listings
valid_agents = await query_raw_with_schema(
"""
SELECT slv.id
FROM {schema_prefix}"StoreListingVersion" slv
JOIN {schema_prefix}"StoreListing" sl ON slv."storeListingId" = sl.id
WHERE slv."submissionStatus" = 'APPROVED'
AND slv."isDeleted" = false
AND sl."isDeleted" = false
""",
)
current_ids = {row["id"] for row in valid_agents}
elif content_type == ContentType.BLOCK:
from backend.data.block import get_blocks
current_ids = set(get_blocks().keys())
elif content_type == ContentType.DOCUMENTATION:
from pathlib import Path
# embeddings.py is at: backend/backend/api/features/store/embeddings.py
# Need to go up to project root then into docs/
this_file = Path(__file__)
project_root = (
this_file.parent.parent.parent.parent.parent.parent.parent
)
docs_root = project_root / "docs"
if docs_root.exists():
all_docs = list(docs_root.rglob("*.md")) + list(
docs_root.rglob("*.mdx")
)
current_ids = {str(doc.relative_to(docs_root)) for doc in all_docs}
else:
current_ids = set()
else:
# Skip unknown content types to avoid accidental deletion
logger.warning(
f"Skipping cleanup for unknown content type: {content_type}"
)
results_by_type[content_type.value] = {
"deleted": 0,
"error": "Unknown content type - skipped for safety",
}
continue
# Get all embedding IDs from database
db_embeddings = await query_raw_with_schema(
"""
SELECT "contentId"
FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" = $1::{schema_prefix}"ContentType"
""",
content_type,
)
db_ids = {row["contentId"] for row in db_embeddings}
# Find orphaned embeddings (in DB but not in current content)
orphaned_ids = db_ids - current_ids
if not orphaned_ids:
logger.info(f"{content_type.value}: No orphaned embeddings found")
results_by_type[content_type.value] = {
"deleted": 0,
"message": "No orphaned embeddings",
}
continue
# Delete orphaned embeddings in batch for better performance
orphaned_list = list(orphaned_ids)
try:
await execute_raw_with_schema(
"""
DELETE FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" = $1::{schema_prefix}"ContentType"
AND "contentId" = ANY($2::text[])
""",
content_type,
orphaned_list,
)
deleted = len(orphaned_list)
except Exception as e:
logger.error(f"Failed to batch delete orphaned embeddings: {e}")
deleted = 0
logger.info(
f"{content_type.value}: Deleted {deleted}/{len(orphaned_ids)} orphaned embeddings"
)
results_by_type[content_type.value] = {
"deleted": deleted,
"orphaned": len(orphaned_ids),
"message": f"Deleted {deleted} orphaned embeddings",
}
total_deleted += deleted
except Exception as e:
logger.error(f"Failed to cleanup {content_type.value}: {e}")
results_by_type[content_type.value] = {
"deleted": 0,
"error": str(e),
}
return {
"by_type": results_by_type,
"totals": {
"deleted": total_deleted,
"message": f"Deleted {total_deleted} orphaned embeddings",
},
}
async def semantic_search(
query: str,
content_types: list[ContentType] | None = None,
user_id: str | None = None,
limit: int = 20,
min_similarity: float = 0.5,
) -> list[dict[str, Any]]:
"""
Semantic search across content types using embeddings.
Performs vector similarity search on UnifiedContentEmbedding table.
Used directly for blocks/docs/library agents, or as the semantic component
within hybrid_search for store agents.
If embedding generation fails, falls back to lexical search on searchableText.
Args:
query: Search query string
content_types: List of ContentType to search. Defaults to [BLOCK, STORE_AGENT, DOCUMENTATION]
user_id: Optional user ID for searching private content (library agents)
limit: Maximum number of results to return (default: 20)
min_similarity: Minimum cosine similarity threshold (0-1, default: 0.5)
Returns:
List of search results with the following structure:
[
{
"content_id": str,
"content_type": str, # "BLOCK", "STORE_AGENT", "DOCUMENTATION", or "LIBRARY_AGENT"
"searchable_text": str,
"metadata": dict,
"similarity": float, # Cosine similarity score (0-1)
},
...
]
Examples:
# Search blocks only
results = await semantic_search("calculate", content_types=[ContentType.BLOCK])
# Search blocks and documentation
results = await semantic_search(
"how to use API",
content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION]
)
# Search all public content (default)
results = await semantic_search("AI agent")
# Search user's library agents
results = await semantic_search(
"my custom agent",
content_types=[ContentType.LIBRARY_AGENT],
user_id="user123"
)
"""
# Default to searching all public content types
if content_types is None:
content_types = [
ContentType.BLOCK,
ContentType.STORE_AGENT,
ContentType.DOCUMENTATION,
]
# Validate inputs
if not content_types:
return [] # Empty content_types would cause invalid SQL (IN ())
query = query.strip()
if not query:
return []
if limit < 1:
limit = 1
if limit > 100:
limit = 100
# Generate query embedding
query_embedding = await embed_query(query)
if query_embedding is not None:
# Semantic search with embeddings
embedding_str = embedding_to_vector_string(query_embedding)
# Build params in order: limit, then user_id (if provided), then content types
params: list[Any] = [limit]
user_filter = ""
if user_id is not None:
user_filter = 'AND "userId" = ${}'.format(len(params) + 1)
params.append(user_id)
# Add content type parameters and build placeholders dynamically
content_type_start_idx = len(params) + 1
content_type_placeholders = ", ".join(
f'${content_type_start_idx + i}::{{{{schema_prefix}}}}"ContentType"'
for i in range(len(content_types))
)
params.extend([ct.value for ct in content_types])
sql = f"""
SELECT
"contentId" as content_id,
"contentType" as content_type,
"searchableText" as searchable_text,
metadata,
1 - (embedding <=> '{embedding_str}'::vector) as similarity
FROM {{{{schema_prefix}}}}"UnifiedContentEmbedding"
WHERE "contentType" IN ({content_type_placeholders})
{user_filter}
AND 1 - (embedding <=> '{embedding_str}'::vector) >= ${len(params) + 1}
ORDER BY similarity DESC
LIMIT $1
"""
params.append(min_similarity)
try:
results = await query_raw_with_schema(
sql, *params, set_public_search_path=True
)
return [
{
"content_id": row["content_id"],
"content_type": row["content_type"],
"searchable_text": row["searchable_text"],
"metadata": row["metadata"],
"similarity": float(row["similarity"]),
}
for row in results
]
except Exception as e:
logger.error(f"Semantic search failed: {e}")
# Fall through to lexical search below
# Fallback to lexical search if embeddings unavailable
logger.warning("Falling back to lexical search (embeddings unavailable)")
params_lexical: list[Any] = [limit]
user_filter = ""
if user_id is not None:
user_filter = 'AND "userId" = ${}'.format(len(params_lexical) + 1)
params_lexical.append(user_id)
# Add content type parameters and build placeholders dynamically
content_type_start_idx = len(params_lexical) + 1
content_type_placeholders_lexical = ", ".join(
f'${content_type_start_idx + i}::{{{{schema_prefix}}}}"ContentType"'
for i in range(len(content_types))
)
params_lexical.extend([ct.value for ct in content_types])
sql_lexical = f"""
SELECT
"contentId" as content_id,
"contentType" as content_type,
"searchableText" as searchable_text,
metadata,
0.0 as similarity
FROM {{{{schema_prefix}}}}"UnifiedContentEmbedding"
WHERE "contentType" IN ({content_type_placeholders_lexical})
{user_filter}
AND "searchableText" ILIKE ${len(params_lexical) + 1}
ORDER BY "updatedAt" DESC
LIMIT $1
"""
params_lexical.append(f"%{query}%")
try:
results = await query_raw_with_schema(
sql_lexical, *params_lexical, set_public_search_path=True
)
return [
{
"content_id": row["content_id"],
"content_type": row["content_type"],
"searchable_text": row["searchable_text"],
"metadata": row["metadata"],
"similarity": 0.0, # Lexical search doesn't provide similarity
}
for row in results
]
except Exception as e:
logger.error(f"Lexical search failed: {e}")
return []

View File

@@ -1,666 +0,0 @@
"""
End-to-end database tests for embeddings and hybrid search.
These tests hit the actual database to verify SQL queries work correctly.
Tests cover:
1. Embedding storage (store_content_embedding)
2. Embedding retrieval (get_content_embedding)
3. Embedding deletion (delete_content_embedding)
4. Unified hybrid search across content types
5. Store agent hybrid search
"""
import uuid
from typing import AsyncGenerator
import pytest
from prisma.enums import ContentType
from backend.api.features.store import embeddings
from backend.api.features.store.embeddings import EMBEDDING_DIM
from backend.api.features.store.hybrid_search import (
hybrid_search,
unified_hybrid_search,
)
# ============================================================================
# Test Fixtures
# ============================================================================
@pytest.fixture
def test_content_id() -> str:
"""Generate unique content ID for test isolation."""
return f"test-content-{uuid.uuid4()}"
@pytest.fixture
def test_user_id() -> str:
"""Generate unique user ID for test isolation."""
return f"test-user-{uuid.uuid4()}"
@pytest.fixture
def mock_embedding() -> list[float]:
"""Generate a mock embedding vector."""
# Create a normalized embedding vector
import math
raw = [float(i % 10) / 10.0 for i in range(EMBEDDING_DIM)]
# Normalize to unit length (required for cosine similarity)
magnitude = math.sqrt(sum(x * x for x in raw))
return [x / magnitude for x in raw]
@pytest.fixture
def similar_embedding() -> list[float]:
"""Generate an embedding similar to mock_embedding."""
import math
# Similar but slightly different values
raw = [float(i % 10) / 10.0 + 0.01 for i in range(EMBEDDING_DIM)]
magnitude = math.sqrt(sum(x * x for x in raw))
return [x / magnitude for x in raw]
@pytest.fixture
def different_embedding() -> list[float]:
"""Generate an embedding very different from mock_embedding."""
import math
# Reversed pattern to be maximally different
raw = [float((EMBEDDING_DIM - i) % 10) / 10.0 for i in range(EMBEDDING_DIM)]
magnitude = math.sqrt(sum(x * x for x in raw))
return [x / magnitude for x in raw]
@pytest.fixture
async def cleanup_embeddings(
server,
) -> AsyncGenerator[list[tuple[ContentType, str, str | None]], None]:
"""
Fixture that tracks created embeddings and cleans them up after tests.
Yields a list to which tests can append (content_type, content_id, user_id) tuples.
"""
created_embeddings: list[tuple[ContentType, str, str | None]] = []
yield created_embeddings
# Cleanup all created embeddings
for content_type, content_id, user_id in created_embeddings:
try:
await embeddings.delete_content_embedding(content_type, content_id, user_id)
except Exception:
pass # Ignore cleanup errors
# ============================================================================
# store_content_embedding Tests
# ============================================================================
@pytest.mark.asyncio(loop_scope="session")
async def test_store_content_embedding_store_agent(
server,
test_content_id: str,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test storing embedding for STORE_AGENT content type."""
# Track for cleanup
cleanup_embeddings.append((ContentType.STORE_AGENT, test_content_id, None))
result = await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=test_content_id,
embedding=mock_embedding,
searchable_text="AI assistant for productivity tasks",
metadata={"name": "Test Agent", "categories": ["productivity"]},
user_id=None, # Store agents are public
)
assert result is True
# Verify it was stored
stored = await embeddings.get_content_embedding(
ContentType.STORE_AGENT, test_content_id, user_id=None
)
assert stored is not None
assert stored["contentId"] == test_content_id
assert stored["contentType"] == "STORE_AGENT"
assert stored["searchableText"] == "AI assistant for productivity tasks"
@pytest.mark.asyncio(loop_scope="session")
async def test_store_content_embedding_block(
server,
test_content_id: str,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test storing embedding for BLOCK content type."""
cleanup_embeddings.append((ContentType.BLOCK, test_content_id, None))
result = await embeddings.store_content_embedding(
content_type=ContentType.BLOCK,
content_id=test_content_id,
embedding=mock_embedding,
searchable_text="HTTP request block for API calls",
metadata={"name": "HTTP Request Block"},
user_id=None, # Blocks are public
)
assert result is True
stored = await embeddings.get_content_embedding(
ContentType.BLOCK, test_content_id, user_id=None
)
assert stored is not None
assert stored["contentType"] == "BLOCK"
@pytest.mark.asyncio(loop_scope="session")
async def test_store_content_embedding_documentation(
server,
test_content_id: str,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test storing embedding for DOCUMENTATION content type."""
cleanup_embeddings.append((ContentType.DOCUMENTATION, test_content_id, None))
result = await embeddings.store_content_embedding(
content_type=ContentType.DOCUMENTATION,
content_id=test_content_id,
embedding=mock_embedding,
searchable_text="Getting started guide for AutoGPT platform",
metadata={"title": "Getting Started", "url": "/docs/getting-started"},
user_id=None, # Docs are public
)
assert result is True
stored = await embeddings.get_content_embedding(
ContentType.DOCUMENTATION, test_content_id, user_id=None
)
assert stored is not None
assert stored["contentType"] == "DOCUMENTATION"
@pytest.mark.asyncio(loop_scope="session")
async def test_store_content_embedding_upsert(
server,
test_content_id: str,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test that storing embedding twice updates instead of duplicates."""
cleanup_embeddings.append((ContentType.BLOCK, test_content_id, None))
# Store first time
result1 = await embeddings.store_content_embedding(
content_type=ContentType.BLOCK,
content_id=test_content_id,
embedding=mock_embedding,
searchable_text="Original text",
metadata={"version": 1},
user_id=None,
)
assert result1 is True
# Store again with different text (upsert)
result2 = await embeddings.store_content_embedding(
content_type=ContentType.BLOCK,
content_id=test_content_id,
embedding=mock_embedding,
searchable_text="Updated text",
metadata={"version": 2},
user_id=None,
)
assert result2 is True
# Verify only one record with updated text
stored = await embeddings.get_content_embedding(
ContentType.BLOCK, test_content_id, user_id=None
)
assert stored is not None
assert stored["searchableText"] == "Updated text"
# ============================================================================
# get_content_embedding Tests
# ============================================================================
@pytest.mark.asyncio(loop_scope="session")
async def test_get_content_embedding_not_found(server):
"""Test retrieving non-existent embedding returns None."""
result = await embeddings.get_content_embedding(
ContentType.STORE_AGENT, "non-existent-id", user_id=None
)
assert result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_get_content_embedding_with_metadata(
server,
test_content_id: str,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test that metadata is correctly stored and retrieved."""
cleanup_embeddings.append((ContentType.STORE_AGENT, test_content_id, None))
metadata = {
"name": "Test Agent",
"subHeading": "A test agent",
"categories": ["ai", "productivity"],
"customField": 123,
}
await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=test_content_id,
embedding=mock_embedding,
searchable_text="test",
metadata=metadata,
user_id=None,
)
stored = await embeddings.get_content_embedding(
ContentType.STORE_AGENT, test_content_id, user_id=None
)
assert stored is not None
assert stored["metadata"]["name"] == "Test Agent"
assert stored["metadata"]["categories"] == ["ai", "productivity"]
assert stored["metadata"]["customField"] == 123
# ============================================================================
# delete_content_embedding Tests
# ============================================================================
@pytest.mark.asyncio(loop_scope="session")
async def test_delete_content_embedding(
server,
test_content_id: str,
mock_embedding: list[float],
):
"""Test deleting embedding removes it from database."""
# Store embedding
await embeddings.store_content_embedding(
content_type=ContentType.BLOCK,
content_id=test_content_id,
embedding=mock_embedding,
searchable_text="To be deleted",
metadata=None,
user_id=None,
)
# Verify it exists
stored = await embeddings.get_content_embedding(
ContentType.BLOCK, test_content_id, user_id=None
)
assert stored is not None
# Delete it
result = await embeddings.delete_content_embedding(
ContentType.BLOCK, test_content_id, user_id=None
)
assert result is True
# Verify it's gone
stored = await embeddings.get_content_embedding(
ContentType.BLOCK, test_content_id, user_id=None
)
assert stored is None
@pytest.mark.asyncio(loop_scope="session")
async def test_delete_content_embedding_not_found(server):
"""Test deleting non-existent embedding doesn't error."""
result = await embeddings.delete_content_embedding(
ContentType.BLOCK, "non-existent-id", user_id=None
)
# Should succeed even if nothing to delete
assert result is True
# ============================================================================
# unified_hybrid_search Tests
# ============================================================================
@pytest.mark.asyncio(loop_scope="session")
async def test_unified_hybrid_search_finds_matching_content(
server,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test unified search finds content matching the query."""
# Create unique content IDs
agent_id = f"test-agent-{uuid.uuid4()}"
block_id = f"test-block-{uuid.uuid4()}"
doc_id = f"test-doc-{uuid.uuid4()}"
cleanup_embeddings.append((ContentType.STORE_AGENT, agent_id, None))
cleanup_embeddings.append((ContentType.BLOCK, block_id, None))
cleanup_embeddings.append((ContentType.DOCUMENTATION, doc_id, None))
# Store embeddings for different content types
await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=agent_id,
embedding=mock_embedding,
searchable_text="AI writing assistant for blog posts",
metadata={"name": "Writing Assistant"},
user_id=None,
)
await embeddings.store_content_embedding(
content_type=ContentType.BLOCK,
content_id=block_id,
embedding=mock_embedding,
searchable_text="Text generation block for creative writing",
metadata={"name": "Text Generator"},
user_id=None,
)
await embeddings.store_content_embedding(
content_type=ContentType.DOCUMENTATION,
content_id=doc_id,
embedding=mock_embedding,
searchable_text="How to use writing blocks in AutoGPT",
metadata={"title": "Writing Guide"},
user_id=None,
)
# Search for "writing" - should find all three
results, total = await unified_hybrid_search(
query="writing",
page=1,
page_size=20,
)
# Should find at least our test content (may find others too)
content_ids = [r["content_id"] for r in results]
assert agent_id in content_ids or total >= 1 # Lexical search should find it
@pytest.mark.asyncio(loop_scope="session")
async def test_unified_hybrid_search_filter_by_content_type(
server,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test unified search can filter by content type."""
agent_id = f"test-agent-{uuid.uuid4()}"
block_id = f"test-block-{uuid.uuid4()}"
cleanup_embeddings.append((ContentType.STORE_AGENT, agent_id, None))
cleanup_embeddings.append((ContentType.BLOCK, block_id, None))
# Store both types with same searchable text
await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=agent_id,
embedding=mock_embedding,
searchable_text="unique_search_term_xyz123",
metadata={},
user_id=None,
)
await embeddings.store_content_embedding(
content_type=ContentType.BLOCK,
content_id=block_id,
embedding=mock_embedding,
searchable_text="unique_search_term_xyz123",
metadata={},
user_id=None,
)
# Search only for BLOCK type
results, total = await unified_hybrid_search(
query="unique_search_term_xyz123",
content_types=[ContentType.BLOCK],
page=1,
page_size=20,
)
# All results should be BLOCK type
for r in results:
assert r["content_type"] == "BLOCK"
@pytest.mark.asyncio(loop_scope="session")
async def test_unified_hybrid_search_empty_query(server):
"""Test unified search with empty query returns empty results."""
results, total = await unified_hybrid_search(
query="",
page=1,
page_size=20,
)
assert results == []
assert total == 0
@pytest.mark.asyncio(loop_scope="session")
async def test_unified_hybrid_search_pagination(
server,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test unified search pagination works correctly."""
# Create multiple items
content_ids = []
for i in range(5):
content_id = f"test-pagination-{uuid.uuid4()}"
content_ids.append(content_id)
cleanup_embeddings.append((ContentType.BLOCK, content_id, None))
await embeddings.store_content_embedding(
content_type=ContentType.BLOCK,
content_id=content_id,
embedding=mock_embedding,
searchable_text=f"pagination test item number {i}",
metadata={"index": i},
user_id=None,
)
# Get first page
page1_results, total1 = await unified_hybrid_search(
query="pagination test",
content_types=[ContentType.BLOCK],
page=1,
page_size=2,
)
# Get second page
page2_results, total2 = await unified_hybrid_search(
query="pagination test",
content_types=[ContentType.BLOCK],
page=2,
page_size=2,
)
# Total should be consistent
assert total1 == total2
# Pages should have different content (if we have enough results)
if len(page1_results) > 0 and len(page2_results) > 0:
page1_ids = {r["content_id"] for r in page1_results}
page2_ids = {r["content_id"] for r in page2_results}
# No overlap between pages
assert page1_ids.isdisjoint(page2_ids)
@pytest.mark.asyncio(loop_scope="session")
async def test_unified_hybrid_search_min_score_filtering(
server,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test unified search respects min_score threshold."""
content_id = f"test-minscore-{uuid.uuid4()}"
cleanup_embeddings.append((ContentType.BLOCK, content_id, None))
await embeddings.store_content_embedding(
content_type=ContentType.BLOCK,
content_id=content_id,
embedding=mock_embedding,
searchable_text="completely unrelated content about bananas",
metadata={},
user_id=None,
)
# Search with very high min_score - should filter out low relevance
results_high, _ = await unified_hybrid_search(
query="quantum computing algorithms",
content_types=[ContentType.BLOCK],
min_score=0.9, # Very high threshold
page=1,
page_size=20,
)
# Search with low min_score
results_low, _ = await unified_hybrid_search(
query="quantum computing algorithms",
content_types=[ContentType.BLOCK],
min_score=0.01, # Very low threshold
page=1,
page_size=20,
)
# High threshold should have fewer or equal results
assert len(results_high) <= len(results_low)
# ============================================================================
# hybrid_search (Store Agents) Tests
# ============================================================================
@pytest.mark.asyncio(loop_scope="session")
async def test_hybrid_search_store_agents_sql_valid(server):
"""Test that hybrid_search SQL executes without errors."""
# This test verifies the SQL is syntactically correct
# even if no results are found
results, total = await hybrid_search(
query="test agent",
page=1,
page_size=20,
)
# Should not raise - verifies SQL is valid
assert isinstance(results, list)
assert isinstance(total, int)
assert total >= 0
@pytest.mark.asyncio(loop_scope="session")
async def test_hybrid_search_with_filters(server):
"""Test hybrid_search with various filter options."""
# Test with all filter types
results, total = await hybrid_search(
query="productivity",
featured=True,
creators=["test-creator"],
category="productivity",
page=1,
page_size=10,
)
# Should not raise - verifies filter SQL is valid
assert isinstance(results, list)
assert isinstance(total, int)
@pytest.mark.asyncio(loop_scope="session")
async def test_hybrid_search_pagination(server):
"""Test hybrid_search pagination."""
# Page 1
results1, total1 = await hybrid_search(
query="agent",
page=1,
page_size=5,
)
# Page 2
results2, total2 = await hybrid_search(
query="agent",
page=2,
page_size=5,
)
# Verify SQL executes without error
assert isinstance(results1, list)
assert isinstance(results2, list)
assert isinstance(total1, int)
assert isinstance(total2, int)
# If page 1 has results, total should be > 0
# Note: total from page 2 may be 0 if no results on that page (COUNT(*) OVER limitation)
if results1:
assert total1 > 0
# ============================================================================
# SQL Validity Tests (verify queries don't break)
# ============================================================================
@pytest.mark.asyncio(loop_scope="session")
async def test_all_content_types_searchable(server):
"""Test that all content types can be searched without SQL errors."""
for content_type in [
ContentType.STORE_AGENT,
ContentType.BLOCK,
ContentType.DOCUMENTATION,
]:
results, total = await unified_hybrid_search(
query="test",
content_types=[content_type],
page=1,
page_size=10,
)
# Should not raise
assert isinstance(results, list)
assert isinstance(total, int)
@pytest.mark.asyncio(loop_scope="session")
async def test_multiple_content_types_searchable(server):
"""Test searching multiple content types at once."""
results, total = await unified_hybrid_search(
query="test",
content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION],
page=1,
page_size=20,
)
# Should not raise
assert isinstance(results, list)
assert isinstance(total, int)
@pytest.mark.asyncio(loop_scope="session")
async def test_search_all_content_types_default(server):
"""Test searching all content types (default behavior)."""
results, total = await unified_hybrid_search(
query="test",
content_types=None, # Should search all
page=1,
page_size=20,
)
# Should not raise
assert isinstance(results, list)
assert isinstance(total, int)
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])

View File

@@ -1,315 +0,0 @@
"""
Integration tests for embeddings with schema handling.
These tests verify that embeddings operations work correctly across different database schemas.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from prisma.enums import ContentType
from backend.api.features.store import embeddings
from backend.api.features.store.embeddings import EMBEDDING_DIM
# Schema prefix tests removed - functionality moved to db.raw_with_schema() helper
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_store_content_embedding_with_schema():
"""Test storing embeddings with proper schema handling."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch("prisma.get_client") as mock_get_client:
mock_client = AsyncMock()
mock_get_client.return_value = mock_client
result = await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
embedding=[0.1] * EMBEDDING_DIM,
searchable_text="test text",
metadata={"test": "data"},
user_id=None,
)
# Verify the query was called
assert mock_client.execute_raw.called
# Get the SQL query that was executed
call_args = mock_client.execute_raw.call_args
sql_query = call_args[0][0]
# Verify schema prefix is in the query
assert '"platform"."UnifiedContentEmbedding"' in sql_query
# Verify result
assert result is True
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_get_content_embedding_with_schema():
"""Test retrieving embeddings with proper schema handling."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch("prisma.get_client") as mock_get_client:
mock_client = AsyncMock()
mock_client.query_raw.return_value = [
{
"contentType": "STORE_AGENT",
"contentId": "test-id",
"userId": None,
"embedding": "[0.1, 0.2]",
"searchableText": "test",
"metadata": {},
"createdAt": "2024-01-01",
"updatedAt": "2024-01-01",
}
]
mock_get_client.return_value = mock_client
result = await embeddings.get_content_embedding(
ContentType.STORE_AGENT,
"test-id",
user_id=None,
)
# Verify the query was called
assert mock_client.query_raw.called
# Get the SQL query that was executed
call_args = mock_client.query_raw.call_args
sql_query = call_args[0][0]
# Verify schema prefix is in the query
assert '"platform"."UnifiedContentEmbedding"' in sql_query
# Verify result
assert result is not None
assert result["contentId"] == "test-id"
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_delete_content_embedding_with_schema():
"""Test deleting embeddings with proper schema handling."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch("prisma.get_client") as mock_get_client:
mock_client = AsyncMock()
mock_get_client.return_value = mock_client
result = await embeddings.delete_content_embedding(
ContentType.STORE_AGENT,
"test-id",
)
# Verify the query was called
assert mock_client.execute_raw.called
# Get the SQL query that was executed
call_args = mock_client.execute_raw.call_args
sql_query = call_args[0][0]
# Verify schema prefix is in the query
assert '"platform"."UnifiedContentEmbedding"' in sql_query
# Verify result
assert result is True
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_get_embedding_stats_with_schema():
"""Test embedding statistics with proper schema handling via content handlers."""
# Mock handler to return stats
mock_handler = MagicMock()
mock_handler.get_stats = AsyncMock(
return_value={
"total": 100,
"with_embeddings": 80,
"without_embeddings": 20,
}
)
with patch(
"backend.api.features.store.embeddings.CONTENT_HANDLERS",
{ContentType.STORE_AGENT: mock_handler},
):
result = await embeddings.get_embedding_stats()
# Verify handler was called
mock_handler.get_stats.assert_called_once()
# Verify new result structure
assert "by_type" in result
assert "totals" in result
assert result["totals"]["total"] == 100
assert result["totals"]["with_embeddings"] == 80
assert result["totals"]["without_embeddings"] == 20
assert result["totals"]["coverage_percent"] == 80.0
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_backfill_missing_embeddings_with_schema():
"""Test backfilling embeddings via content handlers."""
from backend.api.features.store.content_handlers import ContentItem
# Create mock content item
mock_item = ContentItem(
content_id="version-1",
content_type=ContentType.STORE_AGENT,
searchable_text="Test Agent Test description",
metadata={"name": "Test Agent"},
)
# Mock handler
mock_handler = MagicMock()
mock_handler.get_missing_items = AsyncMock(return_value=[mock_item])
with patch(
"backend.api.features.store.embeddings.CONTENT_HANDLERS",
{ContentType.STORE_AGENT: mock_handler},
):
with patch(
"backend.api.features.store.embeddings.generate_embedding",
return_value=[0.1] * EMBEDDING_DIM,
):
with patch(
"backend.api.features.store.embeddings.store_content_embedding",
return_value=True,
):
result = await embeddings.backfill_missing_embeddings(batch_size=10)
# Verify handler was called
mock_handler.get_missing_items.assert_called_once_with(10)
# Verify results
assert result["processed"] == 1
assert result["success"] == 1
assert result["failed"] == 0
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_ensure_content_embedding_with_schema():
"""Test ensuring embeddings exist with proper schema handling."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch(
"backend.api.features.store.embeddings.get_content_embedding"
) as mock_get:
# Simulate no existing embedding
mock_get.return_value = None
with patch(
"backend.api.features.store.embeddings.generate_embedding"
) as mock_generate:
mock_generate.return_value = [0.1] * EMBEDDING_DIM
with patch(
"backend.api.features.store.embeddings.store_content_embedding"
) as mock_store:
mock_store.return_value = True
result = await embeddings.ensure_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
searchable_text="test text",
metadata={"test": "data"},
user_id=None,
force=False,
)
# Verify the flow
assert mock_get.called
assert mock_generate.called
assert mock_store.called
assert result is True
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_backward_compatibility_store_embedding():
"""Test backward compatibility wrapper for store_embedding."""
with patch(
"backend.api.features.store.embeddings.store_content_embedding"
) as mock_store:
mock_store.return_value = True
result = await embeddings.store_embedding(
version_id="test-version-id",
embedding=[0.1] * EMBEDDING_DIM,
tx=None,
)
# Verify it calls the new function with correct parameters
assert mock_store.called
call_args = mock_store.call_args
assert call_args[1]["content_type"] == ContentType.STORE_AGENT
assert call_args[1]["content_id"] == "test-version-id"
assert call_args[1]["user_id"] is None
assert result is True
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_backward_compatibility_get_embedding():
"""Test backward compatibility wrapper for get_embedding."""
with patch(
"backend.api.features.store.embeddings.get_content_embedding"
) as mock_get:
mock_get.return_value = {
"contentType": "STORE_AGENT",
"contentId": "test-version-id",
"embedding": "[0.1, 0.2]",
"createdAt": "2024-01-01",
"updatedAt": "2024-01-01",
}
result = await embeddings.get_embedding("test-version-id")
# Verify it calls the new function
assert mock_get.called
# Verify it transforms to old format
assert result is not None
assert result["storeListingVersionId"] == "test-version-id"
assert "embedding" in result
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_schema_handling_error_cases():
"""Test error handling in schema-aware operations."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch("prisma.get_client") as mock_get_client:
mock_client = AsyncMock()
mock_client.execute_raw.side_effect = Exception("Database error")
mock_get_client.return_value = mock_client
result = await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
embedding=[0.1] * EMBEDDING_DIM,
searchable_text="test",
metadata=None,
user_id=None,
)
# Should return False on error, not raise
assert result is False
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])

View File

@@ -1,407 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch
import prisma
import pytest
from prisma import Prisma
from prisma.enums import ContentType
from backend.api.features.store import embeddings
@pytest.fixture(autouse=True)
async def setup_prisma():
"""Setup Prisma client for tests."""
try:
Prisma()
except prisma.errors.ClientAlreadyRegisteredError:
pass
yield
@pytest.mark.asyncio(loop_scope="session")
async def test_build_searchable_text():
"""Test searchable text building from listing fields."""
result = embeddings.build_searchable_text(
name="AI Assistant",
description="A helpful AI assistant for productivity",
sub_heading="Boost your productivity",
categories=["AI", "Productivity"],
)
expected = "AI Assistant Boost your productivity A helpful AI assistant for productivity AI Productivity"
assert result == expected
@pytest.mark.asyncio(loop_scope="session")
async def test_build_searchable_text_empty_fields():
"""Test searchable text building with empty fields."""
result = embeddings.build_searchable_text(
name="", description="Test description", sub_heading="", categories=[]
)
assert result == "Test description"
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_success():
"""Test successful embedding generation."""
# Mock OpenAI response
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.data = [MagicMock()]
mock_response.data[0].embedding = [0.1, 0.2, 0.3] * 512 # 1536 dimensions
# Use AsyncMock for async embeddings.create method
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
# Patch at the point of use in embeddings.py
with patch(
"backend.api.features.store.embeddings.get_openai_client"
) as mock_get_client:
mock_get_client.return_value = mock_client
result = await embeddings.generate_embedding("test text")
assert result is not None
assert len(result) == embeddings.EMBEDDING_DIM
assert result[0] == 0.1
mock_client.embeddings.create.assert_called_once_with(
model="text-embedding-3-small", input="test text"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_no_api_key():
"""Test embedding generation without API key."""
# Patch at the point of use in embeddings.py
with patch(
"backend.api.features.store.embeddings.get_openai_client"
) as mock_get_client:
mock_get_client.return_value = None
result = await embeddings.generate_embedding("test text")
assert result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_api_error():
"""Test embedding generation with API error."""
mock_client = MagicMock()
mock_client.embeddings.create = AsyncMock(side_effect=Exception("API Error"))
# Patch at the point of use in embeddings.py
with patch(
"backend.api.features.store.embeddings.get_openai_client"
) as mock_get_client:
mock_get_client.return_value = mock_client
result = await embeddings.generate_embedding("test text")
assert result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_text_truncation():
"""Test that long text is properly truncated using tiktoken."""
from tiktoken import encoding_for_model
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.data = [MagicMock()]
mock_response.data[0].embedding = [0.1] * embeddings.EMBEDDING_DIM
# Use AsyncMock for async embeddings.create method
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
# Patch at the point of use in embeddings.py
with patch(
"backend.api.features.store.embeddings.get_openai_client"
) as mock_get_client:
mock_get_client.return_value = mock_client
# Create text that will exceed 8191 tokens
# Use varied words so the text comfortably exceeds the limit (each "wordN" encodes to at least one token)
words = [f"word{i}" for i in range(10000)]
long_text = " ".join(words)  # well over 8191 tokens
await embeddings.generate_embedding(long_text)
# Verify text was truncated to 8191 tokens
call_args = mock_client.embeddings.create.call_args
truncated_text = call_args.kwargs["input"]
# Count actual tokens in truncated text
enc = encoding_for_model("text-embedding-3-small")
actual_tokens = len(enc.encode(truncated_text))
# Should be at or just under 8191 tokens
assert actual_tokens <= 8191
# Should be close to the limit (not over-truncated)
assert actual_tokens >= 8100
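# A minimal sketch (an assumption -- the actual implementation in embeddings.py is not
# shown in this diff) of the tiktoken-based truncation that the test above exercises:
def _truncate_to_token_limit(text: str, limit: int = 8191) -> str:
    from tiktoken import encoding_for_model

    enc = encoding_for_model("text-embedding-3-small")
    tokens = enc.encode(text)
    # Keep at most `limit` tokens, then decode back to a string.
    return enc.decode(tokens[:limit]) if len(tokens) > limit else text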
@pytest.mark.asyncio(loop_scope="session")
async def test_store_embedding_success(mocker):
"""Test successful embedding storage."""
mock_client = mocker.AsyncMock()
mock_client.execute_raw = mocker.AsyncMock()
embedding = [0.1, 0.2, 0.3]
result = await embeddings.store_embedding(
version_id="test-version-id", embedding=embedding, tx=mock_client
)
assert result is True
# execute_raw is called twice: once for SET search_path, once for INSERT
assert mock_client.execute_raw.call_count == 2
# First call: SET search_path
first_call_args = mock_client.execute_raw.call_args_list[0][0]
assert "SET search_path" in first_call_args[0]
# Second call: INSERT query with the actual data
second_call_args = mock_client.execute_raw.call_args_list[1][0]
assert "test-version-id" in second_call_args
assert "[0.1,0.2,0.3]" in second_call_args
assert None in second_call_args # userId should be None for store agents
@pytest.mark.asyncio(loop_scope="session")
async def test_store_embedding_database_error(mocker):
"""Test embedding storage with database error."""
mock_client = mocker.AsyncMock()
mock_client.execute_raw.side_effect = Exception("Database error")
embedding = [0.1, 0.2, 0.3]
result = await embeddings.store_embedding(
version_id="test-version-id", embedding=embedding, tx=mock_client
)
assert result is False
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_success():
"""Test successful embedding retrieval."""
mock_result = [
{
"contentType": "STORE_AGENT",
"contentId": "test-version-id",
"userId": None,
"embedding": "[0.1,0.2,0.3]",
"searchableText": "Test text",
"metadata": {},
"createdAt": "2024-01-01T00:00:00Z",
"updatedAt": "2024-01-01T00:00:00Z",
}
]
with patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=mock_result,
):
result = await embeddings.get_embedding("test-version-id")
assert result is not None
assert result["storeListingVersionId"] == "test-version-id"
assert result["embedding"] == "[0.1,0.2,0.3]"
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_not_found():
"""Test embedding retrieval when not found."""
with patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=[],
):
result = await embeddings.get_embedding("test-version-id")
assert result is None
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
@patch("backend.api.features.store.embeddings.store_embedding")
@patch("backend.api.features.store.embeddings.get_embedding")
async def test_ensure_embedding_already_exists(mock_get, mock_store, mock_generate):
"""Test ensure_embedding when embedding already exists."""
mock_get.return_value = {"embedding": "[0.1,0.2,0.3]"}
result = await embeddings.ensure_embedding(
version_id="test-id",
name="Test",
description="Test description",
sub_heading="Test heading",
categories=["test"],
)
assert result is True
mock_generate.assert_not_called()
mock_store.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
@patch("backend.api.features.store.embeddings.store_content_embedding")
@patch("backend.api.features.store.embeddings.get_embedding")
async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate):
"""Test ensure_embedding creating new embedding."""
mock_get.return_value = None
mock_generate.return_value = [0.1, 0.2, 0.3]
mock_store.return_value = True
result = await embeddings.ensure_embedding(
version_id="test-id",
name="Test",
description="Test description",
sub_heading="Test heading",
categories=["test"],
)
assert result is True
mock_generate.assert_called_once_with("Test Test heading Test description test")
mock_store.assert_called_once_with(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
embedding=[0.1, 0.2, 0.3],
searchable_text="Test Test heading Test description test",
metadata={"name": "Test", "subHeading": "Test heading", "categories": ["test"]},
user_id=None,
tx=None,
)
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
@patch("backend.api.features.store.embeddings.get_embedding")
async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
"""Test ensure_embedding when generation fails."""
mock_get.return_value = None
mock_generate.return_value = None
result = await embeddings.ensure_embedding(
version_id="test-id",
name="Test",
description="Test description",
sub_heading="Test heading",
categories=["test"],
)
assert result is False
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_stats():
"""Test embedding statistics retrieval."""
# Mock handler stats for each content type
mock_handler = MagicMock()
mock_handler.get_stats = AsyncMock(
return_value={
"total": 100,
"with_embeddings": 75,
"without_embeddings": 25,
}
)
# Patch the CONTENT_HANDLERS where it's used (in embeddings module)
with patch(
"backend.api.features.store.embeddings.CONTENT_HANDLERS",
{ContentType.STORE_AGENT: mock_handler},
):
result = await embeddings.get_embedding_stats()
assert "by_type" in result
assert "totals" in result
assert result["totals"]["total"] == 100
assert result["totals"]["with_embeddings"] == 75
assert result["totals"]["without_embeddings"] == 25
assert result["totals"]["coverage_percent"] == 75.0
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.store_content_embedding")
async def test_backfill_missing_embeddings_success(mock_store):
"""Test backfill with successful embedding generation."""
# Mock ContentItem from handlers
from backend.api.features.store.content_handlers import ContentItem
mock_items = [
ContentItem(
content_id="version-1",
content_type=ContentType.STORE_AGENT,
searchable_text="Agent 1 Description 1",
metadata={"name": "Agent 1"},
),
ContentItem(
content_id="version-2",
content_type=ContentType.STORE_AGENT,
searchable_text="Agent 2 Description 2",
metadata={"name": "Agent 2"},
),
]
# Mock handler to return missing items
mock_handler = MagicMock()
mock_handler.get_missing_items = AsyncMock(return_value=mock_items)
# Mock store_content_embedding to succeed for first, fail for second
mock_store.side_effect = [True, False]
with patch(
"backend.api.features.store.embeddings.CONTENT_HANDLERS",
{ContentType.STORE_AGENT: mock_handler},
):
with patch(
"backend.api.features.store.embeddings.generate_embedding",
return_value=[0.1] * embeddings.EMBEDDING_DIM,
):
result = await embeddings.backfill_missing_embeddings(batch_size=5)
assert result["processed"] == 2
assert result["success"] == 1
assert result["failed"] == 1
assert mock_store.call_count == 2
@pytest.mark.asyncio(loop_scope="session")
async def test_backfill_missing_embeddings_no_missing():
"""Test backfill when no embeddings are missing."""
# Mock handler to return no missing items
mock_handler = MagicMock()
mock_handler.get_missing_items = AsyncMock(return_value=[])
with patch(
"backend.api.features.store.embeddings.CONTENT_HANDLERS",
{ContentType.STORE_AGENT: mock_handler},
):
result = await embeddings.backfill_missing_embeddings(batch_size=5)
assert result["processed"] == 0
assert result["success"] == 0
assert result["failed"] == 0
@pytest.mark.asyncio(loop_scope="session")
async def test_embedding_to_vector_string():
"""Test embedding to PostgreSQL vector string conversion."""
embedding = [0.1, 0.2, 0.3, -0.4]
result = embeddings.embedding_to_vector_string(embedding)
assert result == "[0.1,0.2,0.3,-0.4]"
@pytest.mark.asyncio(loop_scope="session")
async def test_embed_query():
"""Test embed_query function (alias for generate_embedding)."""
with patch(
"backend.api.features.store.embeddings.generate_embedding"
) as mock_generate:
mock_generate.return_value = [0.1, 0.2, 0.3]
result = await embeddings.embed_query("test query")
assert result == [0.1, 0.2, 0.3]
mock_generate.assert_called_once_with("test query")


@@ -1,625 +0,0 @@
"""
Unified Hybrid Search
Combines semantic (embedding) search with lexical (tsvector) search
for improved relevance across all content types (agents, blocks, docs).
"""
import logging
from dataclasses import dataclass
from typing import Any, Literal
from prisma.enums import ContentType
from backend.api.features.store.embeddings import (
EMBEDDING_DIM,
embed_query,
embedding_to_vector_string,
)
from backend.data.db import query_raw_with_schema
logger = logging.getLogger(__name__)
@dataclass
class UnifiedSearchWeights:
"""Weights for unified search (no popularity signal)."""
semantic: float = 0.40 # Embedding cosine similarity
lexical: float = 0.40 # tsvector ts_rank_cd score
category: float = 0.10 # Category match boost (for types that have categories)
recency: float = 0.10 # Newer content ranked higher
def __post_init__(self):
"""Validate weights are non-negative and sum to approximately 1.0."""
total = self.semantic + self.lexical + self.category + self.recency
if any(
w < 0 for w in [self.semantic, self.lexical, self.category, self.recency]
):
raise ValueError("All weights must be non-negative")
if not (0.99 <= total <= 1.01):
raise ValueError(f"Weights must sum to ~1.0, got {total:.3f}")
# Default weights for unified search
DEFAULT_UNIFIED_WEIGHTS = UnifiedSearchWeights()
# Minimum relevance score thresholds
DEFAULT_MIN_SCORE = 0.15 # For unified search (more permissive)
DEFAULT_STORE_AGENT_MIN_SCORE = 0.20 # For store agent search (original threshold)
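# Illustrative only: UnifiedSearchWeights validates itself on construction, so a weight
# set that is negative or does not sum to ~1.0 fails fast. The numbers below are an
# example, not values used elsewhere in this module.
_example_weights = UnifiedSearchWeights(
    semantic=0.5, lexical=0.3, category=0.1, recency=0.1
)  # OK: non-negative and sums to 1.0
# UnifiedSearchWeights(semantic=0.9, lexical=0.3, category=0.1, recency=0.1) would
# raise ValueError because the weights sum to 1.4.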
async def unified_hybrid_search(
query: str,
content_types: list[ContentType] | None = None,
category: str | None = None,
page: int = 1,
page_size: int = 20,
weights: UnifiedSearchWeights | None = None,
min_score: float | None = None,
user_id: str | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""
Unified hybrid search across all content types.
Searches UnifiedContentEmbedding using both semantic (vector) and lexical (tsvector) signals.
Args:
query: Search query string
content_types: List of content types to search. Defaults to all public types.
category: Filter by category (for content types that support it)
page: Page number (1-indexed)
page_size: Results per page
weights: Custom weights for search signals
min_score: Minimum relevance score threshold (0-1)
user_id: User ID for searching private content (library agents)
Returns:
Tuple of (results list, total count)
"""
# Validate inputs
query = query.strip()
if not query:
return [], 0
if page < 1:
page = 1
if page_size < 1:
page_size = 1
if page_size > 100:
page_size = 100
if content_types is None:
content_types = [
ContentType.STORE_AGENT,
ContentType.BLOCK,
ContentType.DOCUMENTATION,
]
if weights is None:
weights = DEFAULT_UNIFIED_WEIGHTS
if min_score is None:
min_score = DEFAULT_MIN_SCORE
offset = (page - 1) * page_size
# Generate query embedding
query_embedding = await embed_query(query)
# Graceful degradation if embedding unavailable
if query_embedding is None or not query_embedding:
logger.warning(
"Failed to generate query embedding - falling back to lexical-only search. "
"Check that openai_internal_api_key is configured and OpenAI API is accessible."
)
query_embedding = [0.0] * EMBEDDING_DIM
# Redistribute the semantic weight proportionally across the remaining signals
total_non_semantic = weights.lexical + weights.category + weights.recency
if total_non_semantic > 0:
factor = 1.0 / total_non_semantic
weights = UnifiedSearchWeights(
semantic=0.0,
lexical=weights.lexical * factor,
category=weights.category * factor,
recency=weights.recency * factor,
)
else:
weights = UnifiedSearchWeights(
semantic=0.0, lexical=1.0, category=0.0, recency=0.0
)
# Build parameters
params: list[Any] = []
param_idx = 1
# Query for lexical search
params.append(query)
query_param = f"${param_idx}"
param_idx += 1
# Query lowercase for category matching
params.append(query.lower())
query_lower_param = f"${param_idx}"
param_idx += 1
# Embedding
embedding_str = embedding_to_vector_string(query_embedding)
params.append(embedding_str)
embedding_param = f"${param_idx}"
param_idx += 1
# Content types
content_type_values = [ct.value for ct in content_types]
params.append(content_type_values)
content_types_param = f"${param_idx}"
param_idx += 1
# User ID filter (for private content)
user_filter = ""
if user_id is not None:
params.append(user_id)
user_filter = f'AND (uce."userId" = ${param_idx} OR uce."userId" IS NULL)'
param_idx += 1
else:
user_filter = 'AND uce."userId" IS NULL'
# Weights
params.append(weights.semantic)
w_semantic = f"${param_idx}"
param_idx += 1
params.append(weights.lexical)
w_lexical = f"${param_idx}"
param_idx += 1
params.append(weights.category)
w_category = f"${param_idx}"
param_idx += 1
params.append(weights.recency)
w_recency = f"${param_idx}"
param_idx += 1
# Min score
params.append(min_score)
min_score_param = f"${param_idx}"
param_idx += 1
# Pagination
params.append(page_size)
limit_param = f"${param_idx}"
param_idx += 1
params.append(offset)
offset_param = f"${param_idx}"
param_idx += 1
# Unified search query on UnifiedContentEmbedding
sql_query = f"""
WITH candidates AS (
-- Lexical matches (uses GIN index on search column)
SELECT uce.id, uce."contentType", uce."contentId"
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
WHERE uce."contentType" = ANY({content_types_param}::{{schema_prefix}}"ContentType"[])
{user_filter}
AND uce.search @@ plainto_tsquery('english', {query_param})
UNION
-- Semantic matches (uses HNSW index on embedding)
(
SELECT uce.id, uce."contentType", uce."contentId"
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
WHERE uce."contentType" = ANY({content_types_param}::{{schema_prefix}}"ContentType"[])
{user_filter}
ORDER BY uce.embedding <=> {embedding_param}::vector
LIMIT 200
)
),
search_scores AS (
SELECT
uce."contentType" as content_type,
uce."contentId" as content_id,
uce."searchableText" as searchable_text,
uce.metadata,
uce."updatedAt" as updated_at,
-- Semantic score: cosine similarity (1 - distance)
COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
-- Lexical score: ts_rank_cd
COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
-- Category match from metadata
CASE
WHEN uce.metadata ? 'categories' AND EXISTS (
SELECT 1 FROM jsonb_array_elements_text(uce.metadata->'categories') cat
WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%'
)
THEN 1.0
ELSE 0.0
END as category_score,
-- Recency score: linear decay over 90 days
GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - uce."updatedAt")) / (90 * 24 * 3600)) as recency_score
FROM candidates c
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce ON c.id = uce.id
),
max_lexical AS (
SELECT GREATEST(MAX(lexical_raw), 0.001) as max_val FROM search_scores
),
normalized AS (
SELECT
ss.*,
ss.lexical_raw / ml.max_val as lexical_score
FROM search_scores ss
CROSS JOIN max_lexical ml
),
scored AS (
SELECT
content_type,
content_id,
searchable_text,
metadata,
updated_at,
semantic_score,
lexical_score,
category_score,
recency_score,
(
{w_semantic} * semantic_score +
{w_lexical} * lexical_score +
{w_category} * category_score +
{w_recency} * recency_score
) as combined_score
FROM normalized
),
filtered AS (
SELECT
*,
COUNT(*) OVER () as total_count
FROM scored
WHERE combined_score >= {min_score_param}
)
SELECT * FROM filtered
ORDER BY combined_score DESC
LIMIT {limit_param} OFFSET {offset_param}
"""
results = await query_raw_with_schema(
sql_query, *params, set_public_search_path=True
)
total = results[0]["total_count"] if results else 0
# Clean up results
for result in results:
result.pop("total_count", None)
logger.info(f"Unified hybrid search: {len(results)} results, {total} total")
return results, total
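# Usage sketch (illustrative; the query string and user id are hypothetical values):
async def _example_unified_search() -> None:
    # Public rows (userId IS NULL) are always searched; passing user_id additionally
    # includes that user's private content such as library agents.
    results, total = await unified_hybrid_search(
        query="summarize a webpage",
        content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION],
        user_id="user-123",
        page=1,
        page_size=10,
    )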
# ============================================================================
# Store Agent specific search (with full metadata)
# ============================================================================
@dataclass
class StoreAgentSearchWeights:
"""Weights for store agent search including popularity."""
semantic: float = 0.30
lexical: float = 0.30
category: float = 0.20
recency: float = 0.10
popularity: float = 0.10
def __post_init__(self):
total = (
self.semantic
+ self.lexical
+ self.category
+ self.recency
+ self.popularity
)
if any(
w < 0
for w in [
self.semantic,
self.lexical,
self.category,
self.recency,
self.popularity,
]
):
raise ValueError("All weights must be non-negative")
if not (0.99 <= total <= 1.01):
raise ValueError(f"Weights must sum to ~1.0, got {total:.3f}")
DEFAULT_STORE_AGENT_WEIGHTS = StoreAgentSearchWeights()
async def hybrid_search(
query: str,
featured: bool = False,
creators: list[str] | None = None,
category: str | None = None,
sorted_by: (
Literal["relevance", "rating", "runs", "name", "updated_at"] | None
) = None,
page: int = 1,
page_size: int = 20,
weights: StoreAgentSearchWeights | None = None,
min_score: float | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""
Hybrid search for store agents with full metadata.
Uses UnifiedContentEmbedding for search, joins to StoreAgent for metadata.
"""
query = query.strip()
if not query:
return [], 0
if page < 1:
page = 1
if page_size < 1:
page_size = 1
if page_size > 100:
page_size = 100
if weights is None:
weights = DEFAULT_STORE_AGENT_WEIGHTS
if min_score is None:
min_score = (
DEFAULT_STORE_AGENT_MIN_SCORE # Use original threshold for store agents
)
offset = (page - 1) * page_size
# Generate query embedding
query_embedding = await embed_query(query)
# Graceful degradation
if query_embedding is None or not query_embedding:
logger.warning(
"Failed to generate query embedding - falling back to lexical-only search."
)
query_embedding = [0.0] * EMBEDDING_DIM
total_non_semantic = (
weights.lexical + weights.category + weights.recency + weights.popularity
)
if total_non_semantic > 0:
factor = 1.0 / total_non_semantic
weights = StoreAgentSearchWeights(
semantic=0.0,
lexical=weights.lexical * factor,
category=weights.category * factor,
recency=weights.recency * factor,
popularity=weights.popularity * factor,
)
else:
weights = StoreAgentSearchWeights(
semantic=0.0, lexical=1.0, category=0.0, recency=0.0, popularity=0.0
)
# Build parameters
params: list[Any] = []
param_idx = 1
params.append(query)
query_param = f"${param_idx}"
param_idx += 1
params.append(query.lower())
query_lower_param = f"${param_idx}"
param_idx += 1
embedding_str = embedding_to_vector_string(query_embedding)
params.append(embedding_str)
embedding_param = f"${param_idx}"
param_idx += 1
# Build WHERE clause for StoreAgent filters
where_parts = ["sa.is_available = true"]
if featured:
where_parts.append("sa.featured = true")
if creators:
params.append(creators)
where_parts.append(f"sa.creator_username = ANY(${param_idx})")
param_idx += 1
if category:
params.append(category)
where_parts.append(f"${param_idx} = ANY(sa.categories)")
param_idx += 1
where_clause = " AND ".join(where_parts)
# Weights
params.append(weights.semantic)
w_semantic = f"${param_idx}"
param_idx += 1
params.append(weights.lexical)
w_lexical = f"${param_idx}"
param_idx += 1
params.append(weights.category)
w_category = f"${param_idx}"
param_idx += 1
params.append(weights.recency)
w_recency = f"${param_idx}"
param_idx += 1
params.append(weights.popularity)
w_popularity = f"${param_idx}"
param_idx += 1
params.append(min_score)
min_score_param = f"${param_idx}"
param_idx += 1
params.append(page_size)
limit_param = f"${param_idx}"
param_idx += 1
params.append(offset)
offset_param = f"${param_idx}"
param_idx += 1
# Query using UnifiedContentEmbedding for search, StoreAgent for metadata
sql_query = f"""
WITH candidates AS (
-- Lexical matches via UnifiedContentEmbedding.search
SELECT uce."contentId" as "storeListingVersionId"
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON uce."contentId" = sa."storeListingVersionId"
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
AND uce."userId" IS NULL
AND uce.search @@ plainto_tsquery('english', {query_param})
AND {where_clause}
UNION
-- Semantic matches via UnifiedContentEmbedding.embedding
SELECT uce."contentId" as "storeListingVersionId"
FROM (
SELECT uce."contentId", uce.embedding
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON uce."contentId" = sa."storeListingVersionId"
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
AND uce."userId" IS NULL
AND {where_clause}
ORDER BY uce.embedding <=> {embedding_param}::vector
LIMIT 200
) uce
),
search_scores AS (
SELECT
sa.slug,
sa.agent_name,
sa.agent_image,
sa.creator_username,
sa.creator_avatar,
sa.sub_heading,
sa.description,
sa.runs,
sa.rating,
sa.categories,
sa.featured,
sa.is_available,
sa.updated_at,
-- Semantic score
COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
-- Lexical score (raw, will normalize)
COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
-- Category match
CASE
WHEN EXISTS (
SELECT 1 FROM unnest(sa.categories) cat
WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%'
)
THEN 1.0
ELSE 0.0
END as category_score,
-- Recency
GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score,
-- Popularity (raw)
sa.runs as popularity_raw
FROM candidates c
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON c."storeListingVersionId" = sa."storeListingVersionId"
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
ON sa."storeListingVersionId" = uce."contentId"
AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
),
max_vals AS (
SELECT
GREATEST(MAX(lexical_raw), 0.001) as max_lexical,
GREATEST(MAX(popularity_raw), 1) as max_popularity
FROM search_scores
),
normalized AS (
SELECT
ss.*,
ss.lexical_raw / mv.max_lexical as lexical_score,
CASE
WHEN ss.popularity_raw > 0
THEN LN(1 + ss.popularity_raw) / LN(1 + mv.max_popularity)
ELSE 0
END as popularity_score
FROM search_scores ss
CROSS JOIN max_vals mv
),
scored AS (
SELECT
slug,
agent_name,
agent_image,
creator_username,
creator_avatar,
sub_heading,
description,
runs,
rating,
categories,
featured,
is_available,
updated_at,
semantic_score,
lexical_score,
category_score,
recency_score,
popularity_score,
(
{w_semantic} * semantic_score +
{w_lexical} * lexical_score +
{w_category} * category_score +
{w_recency} * recency_score +
{w_popularity} * popularity_score
) as combined_score
FROM normalized
),
filtered AS (
SELECT *, COUNT(*) OVER () as total_count
FROM scored
WHERE combined_score >= {min_score_param}
)
SELECT * FROM filtered
ORDER BY combined_score DESC
LIMIT {limit_param} OFFSET {offset_param}
"""
results = await query_raw_with_schema(
sql_query, *params, set_public_search_path=True
)
total = results[0]["total_count"] if results else 0
for result in results:
result.pop("total_count", None)
logger.info(f"Hybrid search (store agents): {len(results)} results, {total} total")
return results, total
async def hybrid_search_simple(
query: str,
page: int = 1,
page_size: int = 20,
) -> tuple[list[dict[str, Any]], int]:
"""Simplified hybrid search for store agents."""
return await hybrid_search(query=query, page=page, page_size=page_size)
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
# for existing code that expects the popularity parameter
HybridSearchWeights = StoreAgentSearchWeights
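# Usage sketch (illustrative; the creator, category, and weight values are hypothetical):
async def _example_store_agent_search() -> None:
    agents, total = await hybrid_search(
        query="productivity assistant",
        featured=True,
        creators=["some-creator"],
        category="productivity",
        page=1,
        page_size=20,
        weights=StoreAgentSearchWeights(
            semantic=0.3, lexical=0.3, category=0.2, recency=0.1, popularity=0.1
        ),
    )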


@@ -1,667 +0,0 @@
"""
Integration tests for hybrid search with schema handling.
These tests verify that hybrid search works correctly across different database schemas.
"""
from unittest.mock import patch
import pytest
from prisma.enums import ContentType
from backend.api.features.store import embeddings
from backend.api.features.store.hybrid_search import (
HybridSearchWeights,
UnifiedSearchWeights,
hybrid_search,
unified_hybrid_search,
)
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_schema_handling():
"""Test that hybrid search correctly handles database schema prefixes."""
# Test with a mock query to ensure schema handling works
query = "test agent"
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
# Mock the query result
mock_query.return_value = [
{
"slug": "test/agent",
"agent_name": "Test Agent",
"agent_image": "test.png",
"creator_username": "test",
"creator_avatar": "avatar.png",
"sub_heading": "Test sub-heading",
"description": "Test description",
"runs": 10,
"rating": 4.5,
"categories": ["test"],
"featured": False,
"is_available": True,
"updated_at": "2024-01-01T00:00:00Z",
"combined_score": 0.8,
"semantic_score": 0.7,
"lexical_score": 0.6,
"category_score": 0.5,
"recency_score": 0.4,
"total_count": 1,
}
]
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM # Mock embedding
results, total = await hybrid_search(
query=query,
page=1,
page_size=20,
)
# Verify the query was called
assert mock_query.called
# Verify the SQL template uses schema_prefix placeholder
call_args = mock_query.call_args
sql_template = call_args[0][0]
assert "{schema_prefix}" in sql_template
# Verify results
assert len(results) == 1
assert total == 1
assert results[0]["slug"] == "test/agent"
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_public_schema():
"""Test hybrid search when using public schema (no prefix needed)."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "public"
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
mock_query.return_value = []
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
results, total = await hybrid_search(
query="test",
page=1,
page_size=20,
)
# Verify the mock was set up correctly
assert mock_schema.return_value == "public"
# Results should work even with empty results
assert results == []
assert total == 0
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_custom_schema():
"""Test hybrid search when using custom schema (e.g., 'platform')."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
mock_query.return_value = []
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
results, total = await hybrid_search(
query="test",
page=1,
page_size=20,
)
# Verify the mock was set up correctly
assert mock_schema.return_value == "platform"
assert results == []
assert total == 0
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_without_embeddings():
"""Test hybrid search gracefully degrades when embeddings are unavailable."""
# Mock database to return some results
mock_results = [
{
"slug": "test-agent",
"agent_name": "Test Agent",
"agent_image": "test.png",
"creator_username": "creator",
"creator_avatar": "avatar.png",
"sub_heading": "Test heading",
"description": "Test description",
"runs": 100,
"rating": 4.5,
"categories": ["AI"],
"featured": False,
"is_available": True,
"updated_at": "2025-01-01T00:00:00Z",
"semantic_score": 0.0, # Zero because no embedding
"lexical_score": 0.5,
"category_score": 0.0,
"recency_score": 0.1,
"popularity_score": 0.2,
"combined_score": 0.3,
"total_count": 1,
}
]
with patch("backend.api.features.store.hybrid_search.embed_query") as mock_embed:
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
# Simulate embedding failure
mock_embed.return_value = None
mock_query.return_value = mock_results
# Should NOT raise - graceful degradation
results, total = await hybrid_search(
query="test",
page=1,
page_size=20,
)
# Verify it returns results even without embeddings
assert len(results) == 1
assert results[0]["slug"] == "test-agent"
assert total == 1
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_filters():
"""Test hybrid search with various filters."""
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
mock_query.return_value = []
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
# Test with featured filter
results, total = await hybrid_search(
query="test",
featured=True,
creators=["user1", "user2"],
category="productivity",
page=1,
page_size=10,
)
# Verify filters were applied in the query
call_args = mock_query.call_args
params = call_args[0][1:] # Skip SQL template
# Should have query, query_lower, creators array, category
assert len(params) >= 4
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_weights():
"""Test hybrid search with custom weights."""
custom_weights = HybridSearchWeights(
semantic=0.5,
lexical=0.3,
category=0.1,
recency=0.1,
popularity=0.0,
)
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
mock_query.return_value = []
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
results, total = await hybrid_search(
query="test",
weights=custom_weights,
page=1,
page_size=20,
)
# Verify custom weights were used in the query
call_args = mock_query.call_args
sql_template = call_args[0][0]
params = call_args[0][1:] # Get all parameters passed
# Check that SQL uses parameterized weights (not f-string interpolation)
assert "$" in sql_template # Verify parameterization is used
# Check that custom weights are in the params
assert 0.5 in params # semantic weight
assert 0.3 in params # lexical weight
assert 0.1 in params # category and recency weights
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_min_score_filtering():
"""Test hybrid search minimum score threshold."""
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
# Return results with varying scores
mock_query.return_value = [
{
"slug": "high-score/agent",
"agent_name": "High Score Agent",
"combined_score": 0.8,
"total_count": 1,
# ... other fields
}
]
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
# Test with custom min_score
results, total = await hybrid_search(
query="test",
min_score=0.5, # High threshold
page=1,
page_size=20,
)
# Verify min_score was applied in query
call_args = mock_query.call_args
sql_template = call_args[0][0]
params = call_args[0][1:] # Get all parameters
# Check that SQL uses parameterized min_score
assert "combined_score >=" in sql_template
assert "$" in sql_template # Verify parameterization
# Check that custom min_score is in the params
assert 0.5 in params
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_pagination():
"""Test hybrid search pagination."""
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
mock_query.return_value = []
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
# Test page 2 with page_size 10
results, total = await hybrid_search(
query="test",
page=2,
page_size=10,
)
# Verify pagination parameters
call_args = mock_query.call_args
params = call_args[0]
# Last two params should be LIMIT and OFFSET
limit = params[-2]
offset = params[-1]
assert limit == 10 # page_size
assert offset == 10 # (page - 1) * page_size = (2 - 1) * 10
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_error_handling():
"""Test hybrid search error handling."""
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
# Simulate database error
mock_query.side_effect = Exception("Database connection error")
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
# Should raise exception
with pytest.raises(Exception) as exc_info:
await hybrid_search(
query="test",
page=1,
page_size=20,
)
assert "Database connection error" in str(exc_info.value)
# =============================================================================
# Unified Hybrid Search Tests
# =============================================================================
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_basic():
"""Test basic unified hybrid search across all content types."""
mock_results = [
{
"content_type": "STORE_AGENT",
"content_id": "agent-1",
"searchable_text": "Test Agent Description",
"metadata": {"name": "Test Agent"},
"updated_at": "2025-01-01T00:00:00Z",
"semantic_score": 0.7,
"lexical_score": 0.8,
"category_score": 0.5,
"recency_score": 0.3,
"combined_score": 0.6,
"total_count": 2,
},
{
"content_type": "BLOCK",
"content_id": "block-1",
"searchable_text": "Test Block Description",
"metadata": {"name": "Test Block"},
"updated_at": "2025-01-01T00:00:00Z",
"semantic_score": 0.6,
"lexical_score": 0.7,
"category_score": 0.4,
"recency_score": 0.2,
"combined_score": 0.5,
"total_count": 2,
},
]
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_query.return_value = mock_results
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
results, total = await unified_hybrid_search(
query="test",
page=1,
page_size=20,
)
assert len(results) == 2
assert total == 2
assert results[0]["content_type"] == "STORE_AGENT"
assert results[1]["content_type"] == "BLOCK"
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_filter_by_content_type():
"""Test unified search filtering by specific content types."""
mock_results = [
{
"content_type": "BLOCK",
"content_id": "block-1",
"searchable_text": "Test Block",
"metadata": {},
"updated_at": "2025-01-01T00:00:00Z",
"semantic_score": 0.7,
"lexical_score": 0.8,
"category_score": 0.0,
"recency_score": 0.3,
"combined_score": 0.5,
"total_count": 1,
},
]
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_query.return_value = mock_results
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
results, total = await unified_hybrid_search(
query="test",
content_types=[ContentType.BLOCK],
page=1,
page_size=20,
)
# Verify content_types parameter was passed correctly
call_args = mock_query.call_args
params = call_args[0][1:]
# The content types should be in the params as a list
assert ["BLOCK"] in params
assert len(results) == 1
assert total == 1
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_with_user_id():
"""Test unified search with user_id for private content."""
mock_results = [
{
"content_type": "STORE_AGENT",
"content_id": "agent-1",
"searchable_text": "My Private Agent",
"metadata": {},
"updated_at": "2025-01-01T00:00:00Z",
"semantic_score": 0.7,
"lexical_score": 0.8,
"category_score": 0.0,
"recency_score": 0.3,
"combined_score": 0.6,
"total_count": 1,
},
]
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_query.return_value = mock_results
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
results, total = await unified_hybrid_search(
query="test",
user_id="user-123",
page=1,
page_size=20,
)
# Verify SQL contains user_id filter
call_args = mock_query.call_args
sql_template = call_args[0][0]
params = call_args[0][1:]
assert 'uce."userId"' in sql_template
assert "user-123" in params
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_custom_weights():
"""Test unified search with custom weights."""
custom_weights = UnifiedSearchWeights(
semantic=0.6,
lexical=0.2,
category=0.1,
recency=0.1,
)
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_query.return_value = []
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
results, total = await unified_hybrid_search(
query="test",
weights=custom_weights,
page=1,
page_size=20,
)
# Verify custom weights are in parameters
call_args = mock_query.call_args
params = call_args[0][1:]
assert 0.6 in params # semantic weight
assert 0.2 in params # lexical weight
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_graceful_degradation():
"""Test unified search gracefully degrades when embeddings unavailable."""
mock_results = [
{
"content_type": "DOCUMENTATION",
"content_id": "doc-1",
"searchable_text": "API Documentation",
"metadata": {},
"updated_at": "2025-01-01T00:00:00Z",
"semantic_score": 0.0, # Zero because no embedding
"lexical_score": 0.8,
"category_score": 0.0,
"recency_score": 0.2,
"combined_score": 0.5,
"total_count": 1,
},
]
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_query.return_value = mock_results
mock_embed.return_value = None # Embedding failure
# Should NOT raise - graceful degradation
results, total = await unified_hybrid_search(
query="test",
page=1,
page_size=20,
)
assert len(results) == 1
assert total == 1
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_empty_query():
"""Test unified search with empty query returns empty results."""
results, total = await unified_hybrid_search(
query="",
page=1,
page_size=20,
)
assert results == []
assert total == 0
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_pagination():
"""Test unified search pagination."""
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_query.return_value = []
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
results, total = await unified_hybrid_search(
query="test",
page=3,
page_size=15,
)
# Verify pagination parameters (last two params are LIMIT and OFFSET)
call_args = mock_query.call_args
params = call_args[0]
limit = params[-2]
offset = params[-1]
assert limit == 15 # page_size
assert offset == 30 # (page - 1) * page_size = (3 - 1) * 15
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_schema_prefix():
"""Test unified search uses schema_prefix placeholder."""
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_query.return_value = []
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
await unified_hybrid_search(
query="test",
page=1,
page_size=20,
)
call_args = mock_query.call_args
sql_template = call_args[0][0]
# Verify schema_prefix placeholder is used for table references
assert "{schema_prefix}" in sql_template
assert '"UnifiedContentEmbedding"' in sql_template
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])


@@ -221,23 +221,3 @@ class ReviewSubmissionRequest(pydantic.BaseModel):
is_approved: bool
comments: str # External comments visible to creator
internal_comments: str | None = None # Private admin notes
class UnifiedSearchResult(pydantic.BaseModel):
"""A single result from unified hybrid search across all content types."""
content_type: str # STORE_AGENT, BLOCK, DOCUMENTATION
content_id: str
searchable_text: str
metadata: dict | None = None
updated_at: datetime.datetime | None = None
combined_score: float | None = None
semantic_score: float | None = None
lexical_score: float | None = None
class UnifiedSearchResponse(pydantic.BaseModel):
"""Response model for unified search across all content types."""
results: list[UnifiedSearchResult]
pagination: Pagination
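# Illustrative example of the shape these models describe (field values are made up;
# Pagination is the shared model already used elsewhere in this module):
_example_response = UnifiedSearchResponse(
    results=[
        UnifiedSearchResult(
            content_type="BLOCK",
            content_id="block-123",
            searchable_text="Calculator Block - performs arithmetic",
            combined_score=0.62,
        )
    ],
    pagination=Pagination(total_items=1, total_pages=1, current_page=1, page_size=20),
)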


@@ -7,15 +7,12 @@ from typing import Literal
import autogpt_libs.auth
import fastapi
import fastapi.responses
import prisma.enums
import backend.data.graph
import backend.util.json
from backend.util.models import Pagination
from . import cache as store_cache
from . import db as store_db
from . import hybrid_search as store_hybrid_search
from . import image_gen as store_image_gen
from . import media as store_media
from . import model as store_model
@@ -149,102 +146,6 @@ async def get_agents(
return agents
##############################################
############### Search Endpoints #############
##############################################
@router.get(
"/search",
summary="Unified search across all content types",
tags=["store", "public"],
response_model=store_model.UnifiedSearchResponse,
)
async def unified_search(
query: str,
content_types: list[str] | None = fastapi.Query(
default=None,
description="Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.",
),
page: int = 1,
page_size: int = 20,
user_id: str | None = fastapi.Security(
autogpt_libs.auth.get_optional_user_id, use_cache=False
),
):
"""
Search across all content types (store agents, blocks, documentation) using hybrid search.
Combines semantic (embedding-based) and lexical (text-based) search for best results.
Args:
query: The search query string
content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION)
page: Page number for pagination (default 1)
page_size: Number of results per page (default 20)
user_id: Optional authenticated user ID (for user-scoped content in future)
Returns:
UnifiedSearchResponse: Paginated list of search results with relevance scores
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
# Convert string content types to enum
content_type_enums: list[prisma.enums.ContentType] | None = None
if content_types:
try:
content_type_enums = [prisma.enums.ContentType(ct) for ct in content_types]
except ValueError as e:
raise fastapi.HTTPException(
status_code=422,
detail=f"Invalid content type. Valid values: STORE_AGENT, BLOCK, DOCUMENTATION. Error: {e}",
)
# Perform unified hybrid search
results, total = await store_hybrid_search.unified_hybrid_search(
query=query,
content_types=content_type_enums,
user_id=user_id,
page=page,
page_size=page_size,
)
# Convert results to response model
search_results = [
store_model.UnifiedSearchResult(
content_type=r["content_type"],
content_id=r["content_id"],
searchable_text=r.get("searchable_text", ""),
metadata=r.get("metadata"),
updated_at=r.get("updated_at"),
combined_score=r.get("combined_score"),
semantic_score=r.get("semantic_score"),
lexical_score=r.get("lexical_score"),
)
for r in results
]
total_pages = (total + page_size - 1) // page_size if total > 0 else 0
return store_model.UnifiedSearchResponse(
results=search_results,
pagination=Pagination(
total_items=total,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)
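# Client-side sketch (illustrative; the base URL is hypothetical and httpx is just one
# way to issue the request):
def _example_search_request() -> None:
    import httpx

    resp = httpx.get(
        "http://localhost:8006/api/store/search",
        params={"query": "calculator", "content_types": ["BLOCK"], "page": 1, "page_size": 20},
    )
    print(resp.json())  # {"results": [...], "pagination": {...}}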
@router.get(
"/agents/{username}/{agent_name}",
summary="Get specific agent",


@@ -1,272 +0,0 @@
"""Tests for the semantic_search function."""
import pytest
from prisma.enums import ContentType
from backend.api.features.store.embeddings import EMBEDDING_DIM, semantic_search
@pytest.mark.asyncio
async def test_search_blocks_only(mocker):
"""Test searching only BLOCK content type."""
# Mock embed_query to return a test embedding
mock_embedding = [0.1] * EMBEDDING_DIM
mocker.patch(
"backend.api.features.store.embeddings.embed_query",
return_value=mock_embedding,
)
# Mock query_raw_with_schema to return test results
mock_results = [
{
"content_id": "block-123",
"content_type": "BLOCK",
"searchable_text": "Calculator Block - Performs arithmetic operations",
"metadata": {"name": "Calculator", "categories": ["Math"]},
"similarity": 0.85,
}
]
mocker.patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=mock_results,
)
results = await semantic_search(
query="calculate numbers",
content_types=[ContentType.BLOCK],
)
assert len(results) == 1
assert results[0]["content_type"] == "BLOCK"
assert results[0]["content_id"] == "block-123"
assert results[0]["similarity"] == 0.85
@pytest.mark.asyncio
async def test_search_multiple_content_types(mocker):
"""Test searching multiple content types simultaneously."""
mock_embedding = [0.1] * EMBEDDING_DIM
mocker.patch(
"backend.api.features.store.embeddings.embed_query",
return_value=mock_embedding,
)
mock_results = [
{
"content_id": "block-123",
"content_type": "BLOCK",
"searchable_text": "Calculator Block",
"metadata": {},
"similarity": 0.85,
},
{
"content_id": "doc-456",
"content_type": "DOCUMENTATION",
"searchable_text": "How to use Calculator",
"metadata": {},
"similarity": 0.75,
},
]
mocker.patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=mock_results,
)
results = await semantic_search(
query="calculator",
content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION],
)
assert len(results) == 2
assert results[0]["content_type"] == "BLOCK"
assert results[1]["content_type"] == "DOCUMENTATION"
@pytest.mark.asyncio
async def test_search_with_min_similarity_threshold(mocker):
"""Test that results below min_similarity are filtered out."""
mock_embedding = [0.1] * EMBEDDING_DIM
mocker.patch(
"backend.api.features.store.embeddings.embed_query",
return_value=mock_embedding,
)
# Only return results above 0.7 similarity
mock_results = [
{
"content_id": "block-123",
"content_type": "BLOCK",
"searchable_text": "Calculator Block",
"metadata": {},
"similarity": 0.85,
}
]
mocker.patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=mock_results,
)
results = await semantic_search(
query="calculate",
content_types=[ContentType.BLOCK],
min_similarity=0.7,
)
assert len(results) == 1
assert results[0]["similarity"] >= 0.7
@pytest.mark.asyncio
async def test_search_fallback_to_lexical(mocker):
"""Test fallback to lexical search when embeddings fail."""
# Mock embed_query to return None (embeddings unavailable)
mocker.patch(
"backend.api.features.store.embeddings.embed_query",
return_value=None,
)
mock_lexical_results = [
{
"content_id": "block-123",
"content_type": "BLOCK",
"searchable_text": "Calculator Block performs calculations",
"metadata": {},
"similarity": 0.0,
}
]
mocker.patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=mock_lexical_results,
)
results = await semantic_search(
query="calculator",
content_types=[ContentType.BLOCK],
)
assert len(results) == 1
assert results[0]["similarity"] == 0.0 # Lexical search returns 0 similarity
@pytest.mark.asyncio
async def test_search_empty_query():
"""Test that empty query returns no results."""
results = await semantic_search(query="")
assert results == []
results = await semantic_search(query=" ")
assert results == []
@pytest.mark.asyncio
async def test_search_with_user_id_filter(mocker):
"""Test searching with user_id filter for private content."""
mock_embedding = [0.1] * EMBEDDING_DIM
mocker.patch(
"backend.api.features.store.embeddings.embed_query",
return_value=mock_embedding,
)
mock_results = [
{
"content_id": "agent-789",
"content_type": "LIBRARY_AGENT",
"searchable_text": "My Custom Agent",
"metadata": {},
"similarity": 0.9,
}
]
mocker.patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=mock_results,
)
results = await semantic_search(
query="custom agent",
content_types=[ContentType.LIBRARY_AGENT],
user_id="user-123",
)
assert len(results) == 1
assert results[0]["content_type"] == "LIBRARY_AGENT"
@pytest.mark.asyncio
async def test_search_limit_parameter(mocker):
"""Test that limit parameter correctly limits results."""
mock_embedding = [0.1] * EMBEDDING_DIM
mocker.patch(
"backend.api.features.store.embeddings.embed_query",
return_value=mock_embedding,
)
# Return 5 results
mock_results = [
{
"content_id": f"block-{i}",
"content_type": "BLOCK",
"searchable_text": f"Block {i}",
"metadata": {},
"similarity": 0.8,
}
for i in range(5)
]
mocker.patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=mock_results,
)
results = await semantic_search(
query="block",
content_types=[ContentType.BLOCK],
limit=5,
)
assert len(results) == 5
@pytest.mark.asyncio
async def test_search_default_content_types(mocker):
"""Test that default content_types includes BLOCK, STORE_AGENT, and DOCUMENTATION."""
mock_embedding = [0.1] * EMBEDDING_DIM
mocker.patch(
"backend.api.features.store.embeddings.embed_query",
return_value=mock_embedding,
)
mock_query_raw = mocker.patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=[],
)
await semantic_search(query="test")
# Check that the SQL query includes all three default content types
call_args = mock_query_raw.call_args
assert "BLOCK" in str(call_args)
assert "STORE_AGENT" in str(call_args)
assert "DOCUMENTATION" in str(call_args)
@pytest.mark.asyncio
async def test_search_handles_database_error(mocker):
"""Test that database errors are handled gracefully."""
mock_embedding = [0.1] * EMBEDDING_DIM
mocker.patch(
"backend.api.features.store.embeddings.embed_query",
return_value=mock_embedding,
)
# Simulate database error
mocker.patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
side_effect=Exception("Database connection failed"),
)
results = await semantic_search(
query="test",
content_types=[ContentType.BLOCK],
)
# Should return empty list on error
assert results == []
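# Consolidated usage sketch (illustrative; the argument values are hypothetical):
async def _example_semantic_search() -> None:
    results = await semantic_search(
        query="calculate numbers",
        content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION],
        min_similarity=0.7,
        limit=10,
        user_id=None,  # pass a user id to include private content such as library agents
    )
    # Each result is a dict with content_id, content_type, searchable_text, metadata,
    # and similarity (0.0 when the lexical fallback was used).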


@@ -64,6 +64,7 @@ from backend.data.onboarding import (
complete_re_run_agent,
get_recommended_agents,
get_user_onboarding,
increment_runs,
onboarding_enabled,
reset_user_onboarding,
update_user_onboarding,
@@ -693,13 +694,13 @@ class DeleteGraphResponse(TypedDict):
async def list_graphs(
user_id: Annotated[str, Security(get_user_id)],
) -> Sequence[graph_db.GraphMeta]:
graphs, _ = await graph_db.list_graphs_paginated(
paginated_result = await graph_db.list_graphs_paginated(
user_id=user_id,
page=1,
page_size=250,
filter_by="active",
)
return graphs
return paginated_result.graphs
@v1_router.get(
@@ -974,6 +975,7 @@ async def execute_graph(
# Record successful graph execution
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
record_graph_operation(operation="execute", status="success")
await increment_runs(user_id)
await complete_re_run_agent(user_id, graph_id)
if source == "library":
await complete_onboarding_step(

View File

@@ -0,0 +1,631 @@
import json
import shlex
import uuid
from typing import Literal, Optional
from e2b import AsyncSandbox as BaseAsyncSandbox
from pydantic import BaseModel, SecretStr
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
# Test credentials for E2B
TEST_E2B_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="e2b",
api_key=SecretStr("mock-e2b-api-key"),
title="Mock E2B API key",
expires_at=None,
)
TEST_E2B_CREDENTIALS_INPUT = {
"provider": TEST_E2B_CREDENTIALS.provider,
"id": TEST_E2B_CREDENTIALS.id,
"type": TEST_E2B_CREDENTIALS.type,
"title": TEST_E2B_CREDENTIALS.title,
}
# Test credentials for Anthropic
TEST_ANTHROPIC_CREDENTIALS = APIKeyCredentials(
id="2e568a2b-b2ea-475a-8564-9a676bf31c56",
provider="anthropic",
api_key=SecretStr("mock-anthropic-api-key"),
title="Mock Anthropic API key",
expires_at=None,
)
TEST_ANTHROPIC_CREDENTIALS_INPUT = {
"provider": TEST_ANTHROPIC_CREDENTIALS.provider,
"id": TEST_ANTHROPIC_CREDENTIALS.id,
"type": TEST_ANTHROPIC_CREDENTIALS.type,
"title": TEST_ANTHROPIC_CREDENTIALS.title,
}
class ClaudeCodeBlock(Block):
"""
Execute tasks using Claude Code (Anthropic's AI coding assistant) in an E2B sandbox.
Claude Code can create files, install tools, run commands, and perform complex
coding tasks autonomously within a secure sandbox environment.
"""
# Use the base template - we install Claude Code ourselves to get the latest version
DEFAULT_TEMPLATE = "base"
class Input(BlockSchemaInput):
e2b_credentials: CredentialsMetaInput[
Literal[ProviderName.E2B], Literal["api_key"]
] = CredentialsField(
description=(
"API key for the E2B platform to create the sandbox. "
"Get one on the [e2b website](https://e2b.dev/docs)"
),
)
anthropic_credentials: CredentialsMetaInput[
Literal[ProviderName.ANTHROPIC], Literal["api_key"]
] = CredentialsField(
description=(
"API key for Anthropic to power Claude Code. "
"Get one at [Anthropic's website](https://console.anthropic.com)"
),
)
prompt: str = SchemaField(
description=(
"The task or instruction for Claude Code to execute. "
"Claude Code can create files, install packages, run commands, "
"and perform complex coding tasks."
),
placeholder="Create a hello world index.html file",
default="",
advanced=False,
)
timeout: int = SchemaField(
description=(
"Sandbox timeout in seconds. Claude Code tasks can take "
"a while, so set this appropriately for your task complexity. "
"Note: This only applies when creating a new sandbox. "
"When reconnecting to an existing sandbox via sandbox_id, "
"the original timeout is retained."
),
default=300, # 5 minutes default
advanced=True,
)
setup_commands: list[str] = SchemaField(
description=(
"Optional shell commands to run before executing Claude Code. "
"Useful for installing dependencies or setting up the environment."
),
default_factory=list,
advanced=True,
)
working_directory: str = SchemaField(
description="Working directory for Claude Code to operate in.",
default="/home/user",
advanced=True,
)
# Session/continuation support
session_id: str = SchemaField(
description=(
"Session ID to resume a previous conversation. "
"Leave empty for a new conversation. "
"Use the session_id from a previous run to continue that conversation."
),
default="",
advanced=True,
)
sandbox_id: str = SchemaField(
description=(
"Sandbox ID to reconnect to an existing sandbox. "
"Required when resuming a session (along with session_id). "
"Use the sandbox_id from a previous run where dispose_sandbox was False."
),
default="",
advanced=True,
)
conversation_history: str = SchemaField(
description=(
"Previous conversation history to continue from. "
"Use this to restore context on a fresh sandbox if the previous one timed out. "
"Pass the conversation_history output from a previous run."
),
default="",
advanced=True,
)
dispose_sandbox: bool = SchemaField(
description=(
"Whether to dispose of the sandbox immediately after execution. "
"Set to False if you want to continue the conversation later "
"(you'll need both sandbox_id and session_id from the output)."
),
default=True,
advanced=True,
)
class FileOutput(BaseModel):
"""A file extracted from the sandbox."""
path: str
relative_path: str # Path relative to working directory (for GitHub, etc.)
name: str
content: str
class Output(BlockSchemaOutput):
response: str = SchemaField(
description="The output/response from Claude Code execution"
)
files: list["ClaudeCodeBlock.FileOutput"] = SchemaField(
description=(
"List of text files created/modified by Claude Code during this execution. "
"Each file has 'path', 'relative_path', 'name', and 'content' fields."
)
)
conversation_history: str = SchemaField(
description=(
"Full conversation history including this turn. "
"Pass this to conversation_history input to continue on a fresh sandbox "
"if the previous sandbox timed out."
)
)
session_id: str = SchemaField(
description=(
"Session ID for this conversation. "
"Pass this back along with sandbox_id to continue the conversation."
)
)
sandbox_id: Optional[str] = SchemaField(
description=(
"ID of the sandbox instance. "
"Pass this back along with session_id to continue the conversation. "
"This is None if dispose_sandbox was True (sandbox was disposed)."
),
default=None,
)
error: str = SchemaField(description="Error message if execution failed")
def __init__(self):
super().__init__(
id="4e34f4a5-9b89-4326-ba77-2dd6750b7194",
description=(
"Execute tasks using Claude Code in an E2B sandbox. "
"Claude Code can create files, install tools, run commands, "
"and perform complex coding tasks autonomously."
),
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.AI},
input_schema=ClaudeCodeBlock.Input,
output_schema=ClaudeCodeBlock.Output,
test_credentials={
"e2b_credentials": TEST_E2B_CREDENTIALS,
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS,
},
test_input={
"e2b_credentials": TEST_E2B_CREDENTIALS_INPUT,
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS_INPUT,
"prompt": "Create a hello world HTML file",
"timeout": 300,
"setup_commands": [],
"working_directory": "/home/user",
"session_id": "",
"sandbox_id": "",
"conversation_history": "",
"dispose_sandbox": True,
},
test_output=[
("response", "Created index.html with hello world content"),
(
"files",
[
{
"path": "/home/user/index.html",
"relative_path": "index.html",
"name": "index.html",
"content": "<html>Hello World</html>",
}
],
),
(
"conversation_history",
"User: Create a hello world HTML file\n"
"Claude: Created index.html with hello world content",
),
("session_id", str),
("sandbox_id", None), # None because dispose_sandbox=True in test_input
],
test_mock={
"execute_claude_code": lambda *args, **kwargs: (
"Created index.html with hello world content", # response
[
ClaudeCodeBlock.FileOutput(
path="/home/user/index.html",
relative_path="index.html",
name="index.html",
content="<html>Hello World</html>",
)
], # files
"User: Create a hello world HTML file\n"
"Claude: Created index.html with hello world content", # conversation_history
"test-session-id", # session_id
"sandbox_id", # sandbox_id
),
},
)
async def execute_claude_code(
self,
e2b_api_key: str,
anthropic_api_key: str,
prompt: str,
timeout: int,
setup_commands: list[str],
working_directory: str,
session_id: str,
existing_sandbox_id: str,
conversation_history: str,
dispose_sandbox: bool,
) -> tuple[str, list["ClaudeCodeBlock.FileOutput"], str, str, str]:
"""
Execute Claude Code in an E2B sandbox.
Returns:
Tuple of (response, files, conversation_history, session_id, sandbox_id)
"""
# Validate that sandbox_id is provided when resuming a session
if session_id and not existing_sandbox_id:
raise ValueError(
"sandbox_id is required when resuming a session with session_id. "
"The session state is stored in the original sandbox. "
"If the sandbox has timed out, use conversation_history instead "
"to restore context on a fresh sandbox."
)
sandbox = None
try:
# Either reconnect to existing sandbox or create a new one
if existing_sandbox_id:
# Reconnect to existing sandbox for conversation continuation
sandbox = await BaseAsyncSandbox.connect(
sandbox_id=existing_sandbox_id,
api_key=e2b_api_key,
)
else:
# Create new sandbox
sandbox = await BaseAsyncSandbox.create(
template=self.DEFAULT_TEMPLATE,
api_key=e2b_api_key,
timeout=timeout,
envs={"ANTHROPIC_API_KEY": anthropic_api_key},
)
# Install Claude Code from npm (ensures we get the latest version)
install_result = await sandbox.commands.run(
"npm install -g @anthropic-ai/claude-code@latest",
timeout=120, # 2 min timeout for install
)
if install_result.exit_code != 0:
raise Exception(
f"Failed to install Claude Code: {install_result.stderr}"
)
# Run any user-provided setup commands
for cmd in setup_commands:
setup_result = await sandbox.commands.run(cmd)
if setup_result.exit_code != 0:
raise Exception(
f"Setup command failed: {cmd}\n"
f"Exit code: {setup_result.exit_code}\n"
f"Stdout: {setup_result.stdout}\n"
f"Stderr: {setup_result.stderr}"
)
# Generate or use provided session ID
current_session_id = session_id if session_id else str(uuid.uuid4())
# Build base Claude flags
base_flags = "-p --dangerously-skip-permissions --output-format json"
# Add conversation history context if provided (for fresh sandbox continuation)
history_flag = ""
if conversation_history and not session_id:
# Inject previous conversation as context via system prompt
# Use consistent escaping via _escape_prompt helper
escaped_history = self._escape_prompt(
f"Previous conversation context: {conversation_history}"
)
history_flag = f" --append-system-prompt {escaped_history}"
# Build Claude command based on whether we're resuming or starting new
# Use shlex.quote for working_directory and session IDs to prevent injection
safe_working_dir = shlex.quote(working_directory)
if session_id:
# Resuming existing session (sandbox still alive)
safe_session_id = shlex.quote(session_id)
claude_command = (
f"cd {safe_working_dir} && "
f"echo {self._escape_prompt(prompt)} | "
f"claude --resume {safe_session_id} {base_flags}"
)
else:
# New session with specific ID
safe_current_session_id = shlex.quote(current_session_id)
claude_command = (
f"cd {safe_working_dir} && "
f"echo {self._escape_prompt(prompt)} | "
f"claude --session-id {safe_current_session_id} {base_flags}{history_flag}"
)
# Capture timestamp before running Claude Code to filter files later
# Capture timestamp 1 second in the past to avoid race condition with file creation
timestamp_result = await sandbox.commands.run(
"date -u -d '1 second ago' +%Y-%m-%dT%H:%M:%S"
)
if timestamp_result.exit_code != 0:
raise RuntimeError(
f"Failed to capture timestamp: {timestamp_result.stderr}"
)
start_timestamp = (
timestamp_result.stdout.strip() if timestamp_result.stdout else None
)
result = await sandbox.commands.run(
claude_command,
timeout=0, # No command timeout - let sandbox timeout handle it
)
# Check for command failure
if result.exit_code != 0:
error_msg = result.stderr or result.stdout or "Unknown error"
raise Exception(
f"Claude Code command failed with exit code {result.exit_code}:\n"
f"{error_msg}"
)
raw_output = result.stdout or ""
sandbox_id = sandbox.sandbox_id
# Parse JSON output to extract response and build conversation history
response = ""
new_conversation_history = conversation_history or ""
try:
# The JSON output contains the result
output_data = json.loads(raw_output)
response = output_data.get("result", raw_output)
# Build conversation history entry
turn_entry = f"User: {prompt}\nClaude: {response}"
if new_conversation_history:
new_conversation_history = (
f"{new_conversation_history}\n\n{turn_entry}"
)
else:
new_conversation_history = turn_entry
except json.JSONDecodeError:
# If not valid JSON, use raw output
response = raw_output
turn_entry = f"User: {prompt}\nClaude: {response}"
if new_conversation_history:
new_conversation_history = (
f"{new_conversation_history}\n\n{turn_entry}"
)
else:
new_conversation_history = turn_entry
# Extract files created/modified during this run
files = await self._extract_files(
sandbox, working_directory, start_timestamp
)
return (
response,
files,
new_conversation_history,
current_session_id,
sandbox_id,
)
finally:
if dispose_sandbox and sandbox:
await sandbox.kill()
async def _extract_files(
self,
sandbox: BaseAsyncSandbox,
working_directory: str,
since_timestamp: str | None = None,
) -> list["ClaudeCodeBlock.FileOutput"]:
"""
Extract text files created/modified during this Claude Code execution.
Args:
sandbox: The E2B sandbox instance
working_directory: Directory to search for files
since_timestamp: ISO timestamp - only return files modified after this time
Returns:
List of FileOutput objects with path, relative_path, name, and content
"""
files: list[ClaudeCodeBlock.FileOutput] = []
# Text file extensions we can safely read as text
text_extensions = {
".txt",
".md",
".html",
".htm",
".css",
".js",
".ts",
".jsx",
".tsx",
".json",
".xml",
".yaml",
".yml",
".toml",
".ini",
".cfg",
".conf",
".py",
".rb",
".php",
".java",
".c",
".cpp",
".h",
".hpp",
".cs",
".go",
".rs",
".swift",
".kt",
".scala",
".sh",
".bash",
".zsh",
".sql",
".graphql",
".env",
".gitignore",
".dockerfile",
"Dockerfile",
".vue",
".svelte",
".astro",
".mdx",
".rst",
".tex",
".csv",
".log",
}
try:
# List files recursively using find command
# Exclude node_modules and .git directories, but allow hidden files
# like .env and .gitignore (they're filtered by text_extensions later)
# Filter by timestamp to only get files created/modified during this run
safe_working_dir = shlex.quote(working_directory)
timestamp_filter = ""
if since_timestamp:
timestamp_filter = f"-newermt {shlex.quote(since_timestamp)} "
find_result = await sandbox.commands.run(
f"find {safe_working_dir} -type f "
f"{timestamp_filter}"
f"-not -path '*/node_modules/*' "
f"-not -path '*/.git/*' "
f"2>/dev/null"
)
if find_result.stdout:
for file_path in find_result.stdout.strip().split("\n"):
if not file_path:
continue
# Check if it's a text file we can read
is_text = any(
file_path.endswith(ext) for ext in text_extensions
) or file_path.endswith("Dockerfile")
if is_text:
try:
content = await sandbox.files.read(file_path)
# Handle bytes or string
if isinstance(content, bytes):
content = content.decode("utf-8", errors="replace")
# Extract filename from path
file_name = file_path.split("/")[-1]
# Calculate relative path by stripping working directory
relative_path = file_path
if file_path.startswith(working_directory):
relative_path = file_path[len(working_directory) :]
# Remove leading slash if present
if relative_path.startswith("/"):
relative_path = relative_path[1:]
files.append(
ClaudeCodeBlock.FileOutput(
path=file_path,
relative_path=relative_path,
name=file_name,
content=content,
)
)
except Exception:
# Skip files that can't be read
pass
except Exception:
# If file extraction fails, return empty results
pass
return files
def _escape_prompt(self, prompt: str) -> str:
"""Escape the prompt for safe shell execution."""
# Use single quotes and escape any single quotes in the prompt
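# e.g. it's done  ->  'it'"'"'s done', which the shell re-joins into: it's done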
escaped = prompt.replace("'", "'\"'\"'")
return f"'{escaped}'"
async def run(
self,
input_data: Input,
*,
e2b_credentials: APIKeyCredentials,
anthropic_credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
try:
(
response,
files,
conversation_history,
session_id,
sandbox_id,
) = await self.execute_claude_code(
e2b_api_key=e2b_credentials.api_key.get_secret_value(),
anthropic_api_key=anthropic_credentials.api_key.get_secret_value(),
prompt=input_data.prompt,
timeout=input_data.timeout,
setup_commands=input_data.setup_commands,
working_directory=input_data.working_directory,
session_id=input_data.session_id,
existing_sandbox_id=input_data.sandbox_id,
conversation_history=input_data.conversation_history,
dispose_sandbox=input_data.dispose_sandbox,
)
yield "response", response
# Always yield files (empty list if none) to match Output schema
yield "files", [f.model_dump() for f in files]
# Always yield conversation_history so user can restore context on fresh sandbox
yield "conversation_history", conversation_history
# Always yield session_id so user can continue conversation
yield "session_id", session_id
# Always yield sandbox_id (None if disposed) to match Output schema
yield "sandbox_id", sandbox_id if not input_data.dispose_sandbox else None
except Exception as e:
yield "error", str(e)


@@ -38,20 +38,6 @@ POOL_TIMEOUT = os.getenv("DB_POOL_TIMEOUT")
if POOL_TIMEOUT:
DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)
# Add public schema to search_path for pgvector type access
# The vector extension is in public schema, but search_path is determined by schema parameter
# Extract the schema from DATABASE_URL or default to 'public' (matching get_database_schema())
parsed_url = urlparse(DATABASE_URL)
url_params = dict(parse_qsl(parsed_url.query))
db_schema = url_params.get("schema", "public")
# Build search_path, avoiding duplicates if db_schema is already 'public'
search_path_schemas = list(
dict.fromkeys([db_schema, "public"])
) # Preserves order, removes duplicates
search_path = ",".join(search_path_schemas)
# This allows using ::vector without schema qualification
DATABASE_URL = add_param(DATABASE_URL, "options", f"-c search_path={search_path}")
HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None
prisma = Prisma(
@@ -122,102 +108,21 @@ def get_database_schema() -> str:
return query_params.get("schema", "public")
async def _raw_with_schema(
query_template: str,
*args,
execute: bool = False,
client: Prisma | None = None,
set_public_search_path: bool = False,
) -> list[dict] | int:
"""Internal: Execute raw SQL with proper schema handling.
Use query_raw_with_schema() or execute_raw_with_schema() instead.
Args:
query_template: SQL query with {schema_prefix} placeholder
*args: Query parameters
execute: If False, executes SELECT query. If True, executes INSERT/UPDATE/DELETE.
client: Optional Prisma client for transactions (only used when execute=True).
set_public_search_path: If True, sets search_path to include public schema.
Needed for pgvector types and other public schema objects.
Returns:
- list[dict] if execute=False (query results)
- int if execute=True (number of affected rows)
"""
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
"""Execute raw SQL query with proper schema handling."""
schema = get_database_schema()
schema_prefix = f'"{schema}".' if schema != "public" else ""
formatted_query = query_template.format(schema_prefix=schema_prefix)
import prisma as prisma_module
db_client = client if client else prisma_module.get_client()
# Set search_path to include public schema if requested
# Prisma doesn't support the 'options' connection parameter, so we set it per-session
# This is idempotent and safe to call multiple times
if set_public_search_path:
await db_client.execute_raw(f"SET search_path = {schema}, public") # type: ignore
if execute:
result = await db_client.execute_raw(formatted_query, *args) # type: ignore
else:
result = await db_client.query_raw(formatted_query, *args) # type: ignore
result = await prisma_module.get_client().query_raw(
formatted_query, *args # type: ignore
)
return result
async def query_raw_with_schema(
query_template: str, *args, set_public_search_path: bool = False
) -> list[dict]:
"""Execute raw SQL SELECT query with proper schema handling.
Args:
query_template: SQL query with {schema_prefix} placeholder
*args: Query parameters
set_public_search_path: If True, sets search_path to include public schema.
Needed for pgvector types and other public schema objects.
Returns:
List of result rows as dictionaries
Example:
results = await query_raw_with_schema(
'SELECT * FROM {schema_prefix}"User" WHERE id = $1',
user_id
)
"""
return await _raw_with_schema(query_template, *args, execute=False, set_public_search_path=set_public_search_path) # type: ignore
async def execute_raw_with_schema(
query_template: str,
*args,
client: Prisma | None = None,
set_public_search_path: bool = False,
) -> int:
"""Execute raw SQL command (INSERT/UPDATE/DELETE) with proper schema handling.
Args:
query_template: SQL query with {schema_prefix} placeholder
*args: Query parameters
client: Optional Prisma client for transactions
set_public_search_path: If True, sets search_path to include public schema.
Needed for pgvector types and other public schema objects.
Returns:
Number of affected rows
Example:
await execute_raw_with_schema(
'INSERT INTO {schema_prefix}"User" (id, name) VALUES ($1, $2)',
user_id, name,
client=tx # Optional transaction client
)
"""
return await _raw_with_schema(query_template, *args, execute=True, client=client, set_public_search_path=set_public_search_path) # type: ignore
class BaseDbModel(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
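
For context, a minimal sketch of the {schema_prefix} templating that query_raw_with_schema() (and the removed helpers) relies on; the table and parameter are illustrative, and the call assumes an async context:

rows = await query_raw_with_schema(
'SELECT * FROM {schema_prefix}"User" WHERE id = $1',
user_id,
)
# schema "platform" -> SELECT * FROM "platform"."User" WHERE id = $1
# schema "public"   -> empty prefix, the table is referenced unqualified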


@@ -804,7 +804,9 @@ class GraphModel(Graph):
)
class GraphMeta(GraphModel):
class GraphMeta(Graph):
user_id: str
# Easy work-around to prevent exposing nodes and links in the API response
nodes: list[NodeModel] = Field(default=[], exclude=True) # type: ignore
links: list[Link] = Field(default=[], exclude=True)
@@ -814,6 +816,13 @@ class GraphMeta(GraphModel):
return GraphMeta(**graph.model_dump())
class GraphsPaginated(BaseModel):
"""Response schema for paginated graphs."""
graphs: list[GraphMeta]
pagination: Pagination
# --------------------- CRUD functions --------------------- #
@@ -847,7 +856,7 @@ async def list_graphs_paginated(
page: int = 1,
page_size: int = 25,
filter_by: Literal["active"] | None = "active",
) -> tuple[list[GraphMeta], Pagination]:
) -> GraphsPaginated:
"""
Retrieves paginated graph metadata objects.
@@ -858,8 +867,7 @@ async def list_graphs_paginated(
filter_by: Optional filter; when set to "active" (the default), only active graphs are returned.
Returns:
list[GraphMeta]: List of graph info objects.
Pagination: Pagination information.
GraphsPaginated: Paginated list of graph metadata.
"""
where_clause: AgentGraphWhereInput = {"userId": user_id}
@@ -892,11 +900,14 @@ async def list_graphs_paginated(
logger.error(f"Error processing graph {graph.id}: {e}")
continue
return graph_models, Pagination(
total_items=total_count,
total_pages=total_pages,
current_page=page,
page_size=page_size,
return GraphsPaginated(
graphs=graph_models,
pagination=Pagination(
total_items=total_count,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)
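
Call sites migrate from unpacking a (graphs, pagination) tuple to reading attributes off the returned model, for example (illustrative usage inside an async context):

result = await graph_db.list_graphs_paginated(user_id=user_id, page=1, page_size=25)
graphs = result.graphs
total_pages = result.pagination.total_pages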


@@ -1,6 +1,5 @@
import json
from typing import Any
from unittest.mock import AsyncMock, patch
from uuid import UUID
import fastapi.exceptions
@@ -19,17 +18,6 @@ from backend.usecases.sample import create_test_user
from backend.util.test import SpinTestServer
@pytest.fixture(scope="session", autouse=True)
def mock_embedding_functions():
"""Mock embedding functions for all tests to avoid database/API dependencies."""
with patch(
"backend.api.features.store.db.ensure_embedding",
new_callable=AsyncMock,
return_value=True,
):
yield
@pytest.mark.asyncio(loop_scope="session")
async def test_graph_creation(server: SpinTestServer, snapshot: Snapshot):
"""


@@ -334,7 +334,7 @@ async def _get_user_timezone(user_id: str) -> str:
return get_user_timezone_or_utc(user.timezone if user else None)
async def increment_onboarding_runs(user_id: str):
async def increment_runs(user_id: str):
"""
Increment a user's run counters and trigger any onboarding milestones.
"""


@@ -1,404 +0,0 @@
"""Data models and access layer for user business understanding."""
import logging
from datetime import datetime
from typing import Any, Optional, cast
import pydantic
from prisma.models import CoPilotUnderstanding
from backend.data.redis_client import get_redis_async
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
# Cache configuration
CACHE_KEY_PREFIX = "understanding"
CACHE_TTL_SECONDS = 48 * 60 * 60 # 48 hours
def _cache_key(user_id: str) -> str:
"""Generate cache key for user business understanding."""
return f"{CACHE_KEY_PREFIX}:{user_id}"
def _json_to_list(value: Any) -> list[str]:
"""Convert Json field to list[str], handling None."""
if value is None:
return []
if isinstance(value, list):
return cast(list[str], value)
return []
class BusinessUnderstandingInput(pydantic.BaseModel):
"""Input model for updating business understanding - all fields optional for incremental updates."""
# User info
user_name: Optional[str] = pydantic.Field(None, description="The user's name")
job_title: Optional[str] = pydantic.Field(None, description="The user's job title")
# Business basics
business_name: Optional[str] = pydantic.Field(
None, description="Name of the user's business"
)
industry: Optional[str] = pydantic.Field(None, description="Industry or sector")
business_size: Optional[str] = pydantic.Field(
None, description="Company size (e.g., '1-10', '11-50')"
)
user_role: Optional[str] = pydantic.Field(
None,
description="User's role in the organization (e.g., 'decision maker', 'implementer')",
)
# Processes & activities
key_workflows: Optional[list[str]] = pydantic.Field(
None, description="Key business workflows"
)
daily_activities: Optional[list[str]] = pydantic.Field(
None, description="Daily activities performed"
)
# Pain points & goals
pain_points: Optional[list[str]] = pydantic.Field(
None, description="Current pain points"
)
bottlenecks: Optional[list[str]] = pydantic.Field(
None, description="Process bottlenecks"
)
manual_tasks: Optional[list[str]] = pydantic.Field(
None, description="Manual/repetitive tasks"
)
automation_goals: Optional[list[str]] = pydantic.Field(
None, description="Desired automation goals"
)
# Current tools
current_software: Optional[list[str]] = pydantic.Field(
None, description="Software/tools currently used"
)
existing_automation: Optional[list[str]] = pydantic.Field(
None, description="Existing automations"
)
# Additional context
additional_notes: Optional[str] = pydantic.Field(
None, description="Any additional context"
)
class BusinessUnderstanding(pydantic.BaseModel):
"""Full business understanding model returned from database."""
id: str
user_id: str
created_at: datetime
updated_at: datetime
# User info
user_name: Optional[str] = None
job_title: Optional[str] = None
# Business basics
business_name: Optional[str] = None
industry: Optional[str] = None
business_size: Optional[str] = None
user_role: Optional[str] = None
# Processes & activities
key_workflows: list[str] = pydantic.Field(default_factory=list)
daily_activities: list[str] = pydantic.Field(default_factory=list)
# Pain points & goals
pain_points: list[str] = pydantic.Field(default_factory=list)
bottlenecks: list[str] = pydantic.Field(default_factory=list)
manual_tasks: list[str] = pydantic.Field(default_factory=list)
automation_goals: list[str] = pydantic.Field(default_factory=list)
# Current tools
current_software: list[str] = pydantic.Field(default_factory=list)
existing_automation: list[str] = pydantic.Field(default_factory=list)
# Additional context
additional_notes: Optional[str] = None
@classmethod
def from_db(cls, db_record: CoPilotUnderstanding) -> "BusinessUnderstanding":
"""Convert database record to Pydantic model."""
data = db_record.data if isinstance(db_record.data, dict) else {}
business = (
data.get("business", {}) if isinstance(data.get("business"), dict) else {}
)
return cls(
id=db_record.id,
user_id=db_record.userId,
created_at=db_record.createdAt,
updated_at=db_record.updatedAt,
user_name=data.get("name"),
job_title=business.get("job_title"),
business_name=business.get("business_name"),
industry=business.get("industry"),
business_size=business.get("business_size"),
user_role=business.get("user_role"),
key_workflows=_json_to_list(business.get("key_workflows")),
daily_activities=_json_to_list(business.get("daily_activities")),
pain_points=_json_to_list(business.get("pain_points")),
bottlenecks=_json_to_list(business.get("bottlenecks")),
manual_tasks=_json_to_list(business.get("manual_tasks")),
automation_goals=_json_to_list(business.get("automation_goals")),
current_software=_json_to_list(business.get("current_software")),
existing_automation=_json_to_list(business.get("existing_automation")),
additional_notes=business.get("additional_notes"),
)
def _merge_lists(existing: list | None, new: list | None) -> list | None:
"""Merge two lists, removing duplicates while preserving order."""
if new is None:
return existing
if existing is None:
return new
# Preserve order, add new items that don't exist
merged = list(existing)
for item in new:
if item not in merged:
merged.append(item)
return merged
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
"""Get business understanding from Redis cache."""
try:
redis = await get_redis_async()
cached_data = await redis.get(_cache_key(user_id))
if cached_data:
return BusinessUnderstanding.model_validate_json(cached_data)
except Exception as e:
logger.warning(f"Failed to get understanding from cache: {e}")
return None
async def _set_cache(user_id: str, understanding: BusinessUnderstanding) -> None:
"""Set business understanding in Redis cache with TTL."""
try:
redis = await get_redis_async()
await redis.setex(
_cache_key(user_id),
CACHE_TTL_SECONDS,
understanding.model_dump_json(),
)
except Exception as e:
logger.warning(f"Failed to set understanding in cache: {e}")
async def _delete_cache(user_id: str) -> None:
"""Delete business understanding from Redis cache."""
try:
redis = await get_redis_async()
await redis.delete(_cache_key(user_id))
except Exception as e:
logger.warning(f"Failed to delete understanding from cache: {e}")
async def get_business_understanding(
user_id: str,
) -> Optional[BusinessUnderstanding]:
"""Get the business understanding for a user.
Checks cache first, falls back to database if not cached.
Results are cached for 48 hours.
"""
# Try cache first
cached = await _get_from_cache(user_id)
if cached:
logger.debug(f"Business understanding cache hit for user {user_id}")
return cached
# Cache miss - load from database
logger.debug(f"Business understanding cache miss for user {user_id}")
record = await CoPilotUnderstanding.prisma().find_unique(where={"userId": user_id})
if record is None:
return None
understanding = BusinessUnderstanding.from_db(record)
# Store in cache for next time
await _set_cache(user_id, understanding)
return understanding
async def upsert_business_understanding(
user_id: str,
input_data: BusinessUnderstandingInput,
) -> BusinessUnderstanding:
"""
Create or update business understanding with incremental merge strategy.
- String fields: new value overwrites if provided (not None)
- List fields: new items are appended to existing (deduplicated)
Data is stored as: {name: ..., business: {version: 1, ...}}
"""
# Get existing record for merge
existing = await CoPilotUnderstanding.prisma().find_unique(
where={"userId": user_id}
)
# Get existing data structure or start fresh
existing_data: dict[str, Any] = {}
if existing and isinstance(existing.data, dict):
existing_data = dict(existing.data)
existing_business: dict[str, Any] = {}
if isinstance(existing_data.get("business"), dict):
existing_business = dict(existing_data["business"])
# Business fields (stored inside business object)
business_string_fields = [
"job_title",
"business_name",
"industry",
"business_size",
"user_role",
"additional_notes",
]
business_list_fields = [
"key_workflows",
"daily_activities",
"pain_points",
"bottlenecks",
"manual_tasks",
"automation_goals",
"current_software",
"existing_automation",
]
# Handle top-level name field
if input_data.user_name is not None:
existing_data["name"] = input_data.user_name
# Business string fields - overwrite if provided
for field in business_string_fields:
value = getattr(input_data, field)
if value is not None:
existing_business[field] = value
# Business list fields - merge with existing
for field in business_list_fields:
value = getattr(input_data, field)
if value is not None:
existing_list = _json_to_list(existing_business.get(field))
merged = _merge_lists(existing_list, value)
existing_business[field] = merged
# Set version and nest business data
existing_business["version"] = 1
existing_data["business"] = existing_business
# Upsert with the merged data
record = await CoPilotUnderstanding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "data": SafeJson(existing_data)},
"update": {"data": SafeJson(existing_data)},
},
)
understanding = BusinessUnderstanding.from_db(record)
# Update cache with new understanding
await _set_cache(user_id, understanding)
return understanding
async def clear_business_understanding(user_id: str) -> bool:
"""Clear/delete business understanding for a user from both DB and cache."""
# Delete from cache first
await _delete_cache(user_id)
try:
await CoPilotUnderstanding.prisma().delete(where={"userId": user_id})
return True
except Exception:
# Record might not exist
return False
def format_understanding_for_prompt(understanding: BusinessUnderstanding) -> str:
"""Format business understanding as text for system prompt injection."""
sections = []
# User info section
user_info = []
if understanding.user_name:
user_info.append(f"Name: {understanding.user_name}")
if understanding.job_title:
user_info.append(f"Job Title: {understanding.job_title}")
if user_info:
sections.append("## User\n" + "\n".join(user_info))
# Business section
business_info = []
if understanding.business_name:
business_info.append(f"Company: {understanding.business_name}")
if understanding.industry:
business_info.append(f"Industry: {understanding.industry}")
if understanding.business_size:
business_info.append(f"Size: {understanding.business_size}")
if understanding.user_role:
business_info.append(f"Role Context: {understanding.user_role}")
if business_info:
sections.append("## Business\n" + "\n".join(business_info))
# Processes section
processes = []
if understanding.key_workflows:
processes.append(f"Key Workflows: {', '.join(understanding.key_workflows)}")
if understanding.daily_activities:
processes.append(
f"Daily Activities: {', '.join(understanding.daily_activities)}"
)
if processes:
sections.append("## Processes\n" + "\n".join(processes))
# Pain points section
pain_points = []
if understanding.pain_points:
pain_points.append(f"Pain Points: {', '.join(understanding.pain_points)}")
if understanding.bottlenecks:
pain_points.append(f"Bottlenecks: {', '.join(understanding.bottlenecks)}")
if understanding.manual_tasks:
pain_points.append(f"Manual Tasks: {', '.join(understanding.manual_tasks)}")
if pain_points:
sections.append("## Pain Points\n" + "\n".join(pain_points))
# Goals section
if understanding.automation_goals:
sections.append(
"## Automation Goals\n"
+ "\n".join(f"- {goal}" for goal in understanding.automation_goals)
)
# Current tools section
tools_info = []
if understanding.current_software:
tools_info.append(
f"Current Software: {', '.join(understanding.current_software)}"
)
if understanding.existing_automation:
tools_info.append(
f"Existing Automation: {', '.join(understanding.existing_automation)}"
)
if tools_info:
sections.append("## Current Tools\n" + "\n".join(tools_info))
# Additional notes
if understanding.additional_notes:
sections.append(f"## Additional Context\n{understanding.additional_notes}")
if not sections:
return ""
return "# User Business Context\n\n" + "\n\n".join(sections)


@@ -7,11 +7,6 @@ from backend.api.features.library.db import (
list_library_agents,
)
from backend.api.features.store.db import get_store_agent_details, get_store_agents
from backend.api.features.store.embeddings import (
backfill_missing_embeddings,
cleanup_orphaned_embeddings,
get_embedding_stats,
)
from backend.data import db
from backend.data.analytics import (
get_accuracy_trends_and_alerts,
@@ -25,7 +20,6 @@ from backend.data.execution import (
get_execution_kv_data,
get_execution_outputs_by_node_exec_id,
get_frequently_executed_graphs,
get_graph_execution,
get_graph_execution_meta,
get_graph_executions,
get_graph_executions_count,
@@ -63,7 +57,6 @@ from backend.data.notifications import (
get_user_notification_oldest_message_in_batch,
remove_notifications_from_batch,
)
from backend.data.onboarding import increment_onboarding_runs
from backend.data.user import (
get_active_user_ids_in_timerange,
get_user_by_id,
@@ -147,7 +140,6 @@ class DatabaseManager(AppService):
get_child_graph_executions = _(get_child_graph_executions)
get_graph_executions = _(get_graph_executions)
get_graph_executions_count = _(get_graph_executions_count)
get_graph_execution = _(get_graph_execution)
get_graph_execution_meta = _(get_graph_execution_meta)
create_graph_execution = _(create_graph_execution)
get_node_execution = _(get_node_execution)
@@ -212,18 +204,10 @@ class DatabaseManager(AppService):
add_store_agent_to_library = _(add_store_agent_to_library)
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
# Onboarding
increment_onboarding_runs = _(increment_onboarding_runs)
# Store
get_store_agents = _(get_store_agents)
get_store_agent_details = _(get_store_agent_details)
# Store Embeddings
get_embedding_stats = _(get_embedding_stats)
backfill_missing_embeddings = _(backfill_missing_embeddings)
cleanup_orphaned_embeddings = _(cleanup_orphaned_embeddings)
# Summary data - async
get_user_execution_summary_data = _(get_user_execution_summary_data)
@@ -275,11 +259,6 @@ class DatabaseManagerClient(AppServiceClient):
get_store_agents = _(d.get_store_agents)
get_store_agent_details = _(d.get_store_agent_details)
# Store Embeddings
get_embedding_stats = _(d.get_embedding_stats)
backfill_missing_embeddings = _(d.backfill_missing_embeddings)
cleanup_orphaned_embeddings = _(d.cleanup_orphaned_embeddings)
class DatabaseManagerAsyncClient(AppServiceClient):
d = DatabaseManager
@@ -295,7 +274,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_graph = d.get_graph
get_graph_metadata = d.get_graph_metadata
get_graph_settings = d.get_graph_settings
get_graph_execution = d.get_graph_execution
get_graph_execution_meta = d.get_graph_execution_meta
get_node = d.get_node
get_node_execution = d.get_node_execution
@@ -340,9 +318,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
add_store_agent_to_library = d.add_store_agent_to_library
validate_graph_execution_permissions = d.validate_graph_execution_permissions
# Onboarding
increment_onboarding_runs = d.increment_onboarding_runs
# Store
get_store_agents = d.get_store_agents
get_store_agent_details = d.get_store_agent_details


@@ -1,5 +1,4 @@
import logging
from unittest.mock import AsyncMock, patch
import fastapi.responses
import pytest
@@ -20,17 +19,6 @@ from backend.util.test import SpinTestServer, wait_execution
logger = logging.getLogger(__name__)
@pytest.fixture(scope="session", autouse=True)
def mock_embedding_functions():
"""Mock embedding functions for all tests to avoid database/API dependencies."""
with patch(
"backend.api.features.store.db.ensure_embedding",
new_callable=AsyncMock,
return_value=True,
):
yield
async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph:
logger.info(f"Creating graph for user {u.id}")
return await s.agent_server.test_create_graph(CreateGraph(graph=g), u.id)


@@ -2,7 +2,6 @@ import asyncio
import logging
import os
import threading
import time
import uuid
from enum import Enum
from typing import Optional
@@ -28,7 +27,7 @@ from backend.data.auth.oauth import cleanup_expired_oauth_tokens
from backend.data.block import BlockInput
from backend.data.execution import GraphExecutionWithNodes
from backend.data.model import CredentialsMetaInput
from backend.data.onboarding import increment_onboarding_runs
from backend.data.onboarding import increment_runs
from backend.executor import utils as execution_utils
from backend.monitoring import (
NotificationJobArgs,
@@ -38,7 +37,7 @@ from backend.monitoring import (
report_execution_accuracy_alerts,
report_late_executions,
)
from backend.util.clients import get_database_manager_client, get_scheduler_client
from backend.util.clients import get_scheduler_client
from backend.util.cloud_storage import cleanup_expired_files_async
from backend.util.exceptions import (
GraphNotFoundError,
@@ -157,7 +156,7 @@ async def _execute_graph(**kwargs):
inputs=args.input_data,
graph_credentials_inputs=args.input_credentials,
)
await increment_onboarding_runs(args.user_id)
await increment_runs(args.user_id)
elapsed = asyncio.get_event_loop().time() - start_time
logger.info(
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
@@ -255,114 +254,6 @@ def execution_accuracy_alerts():
return report_execution_accuracy_alerts()
def ensure_embeddings_coverage():
"""
Ensure all content types (store agents, blocks, docs) have embeddings for search.
Processes ALL missing embeddings in batches of 10 per content type until 100% coverage.
Missing embeddings = content invisible in hybrid search.
Schedule: Runs every 6 hours (balanced between coverage and API costs).
- Catches new content added between scheduled runs
- Batch size 10 per content type: gradual processing to avoid rate limits
- Manual trigger available via execute_ensure_embeddings_coverage endpoint
"""
db_client = get_database_manager_client()
stats = db_client.get_embedding_stats()
# Check for error from get_embedding_stats() first
if "error" in stats:
logger.error(
f"Failed to get embedding stats: {stats['error']} - skipping backfill"
)
return {
"backfill": {"processed": 0, "success": 0, "failed": 0},
"cleanup": {"deleted": 0},
"error": stats["error"],
}
# Extract totals from new stats structure
totals = stats.get("totals", {})
without_embeddings = totals.get("without_embeddings", 0)
coverage_percent = totals.get("coverage_percent", 0)
total_processed = 0
total_success = 0
total_failed = 0
if without_embeddings == 0:
logger.info("All content has embeddings, skipping backfill")
else:
# Log per-content-type stats for visibility
by_type = stats.get("by_type", {})
for content_type, type_stats in by_type.items():
if type_stats.get("without_embeddings", 0) > 0:
logger.info(
f"{content_type}: {type_stats['without_embeddings']} items without embeddings "
f"({type_stats['coverage_percent']}% coverage)"
)
logger.info(
f"Total: {without_embeddings} items without embeddings "
f"({coverage_percent}% coverage) - processing all"
)
# Process in batches until no more missing embeddings
while True:
result = db_client.backfill_missing_embeddings(batch_size=10)
total_processed += result["processed"]
total_success += result["success"]
total_failed += result["failed"]
if result["processed"] == 0:
# No more missing embeddings
break
if result["success"] == 0 and result["processed"] > 0:
# All attempts in this batch failed - stop to avoid infinite loop
logger.error(
f"All {result['processed']} embedding attempts failed - stopping backfill"
)
break
# Small delay between batches to avoid rate limits
time.sleep(1)
logger.info(
f"Embedding backfill completed: {total_success}/{total_processed} succeeded, "
f"{total_failed} failed"
)
# Clean up orphaned embeddings for blocks and docs
logger.info("Running cleanup for orphaned embeddings (blocks/docs)...")
cleanup_result = db_client.cleanup_orphaned_embeddings()
cleanup_totals = cleanup_result.get("totals", {})
cleanup_deleted = cleanup_totals.get("deleted", 0)
if cleanup_deleted > 0:
logger.info(f"Cleanup completed: deleted {cleanup_deleted} orphaned embeddings")
by_type = cleanup_result.get("by_type", {})
for content_type, type_result in by_type.items():
if type_result.get("deleted", 0) > 0:
logger.info(
f"{content_type}: deleted {type_result['deleted']} orphaned embeddings"
)
else:
logger.info("Cleanup completed: no orphaned embeddings found")
return {
"backfill": {
"processed": total_processed,
"success": total_success,
"failed": total_failed,
},
"cleanup": {
"deleted": cleanup_deleted,
},
}
# Monitoring functions are now imported from monitoring module
@@ -584,19 +475,6 @@ class Scheduler(AppService):
jobstore=Jobstores.EXECUTION.value,
)
# Embedding Coverage - Every 6 hours
# Ensures all approved agents have embeddings for hybrid search
# Critical: missing embeddings = agents invisible in search
self.scheduler.add_job(
ensure_embeddings_coverage,
id="ensure_embeddings_coverage",
trigger="interval",
hours=6,
replace_existing=True,
max_instances=1, # Prevent overlapping runs
jobstore=Jobstores.EXECUTION.value,
)
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
@@ -754,11 +632,6 @@ class Scheduler(AppService):
"""Manually trigger execution accuracy alert checking."""
return execution_accuracy_alerts()
@expose
def execute_ensure_embeddings_coverage(self):
"""Manually trigger embedding backfill for approved store agents."""
return ensure_embeddings_coverage()
class SchedulerClient(AppServiceClient):
@classmethod


@@ -10,7 +10,6 @@ from pydantic import BaseModel, JsonValue, ValidationError
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import onboarding as onboarding_db
from backend.data import user as user_db
from backend.data.block import (
Block,
@@ -32,6 +31,7 @@ from backend.data.execution import (
GraphExecutionStats,
GraphExecutionWithNodes,
NodesInputMasks,
get_graph_execution,
)
from backend.data.graph import GraphModel, Node
from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput
@@ -809,14 +809,13 @@ async def add_graph_execution(
edb = execution_db
udb = user_db
gdb = graph_db
odb = onboarding_db
else:
edb = udb = gdb = odb = get_database_manager_async_client()
edb = udb = gdb = get_database_manager_async_client()
# Get or create the graph execution
if graph_exec_id:
# Resume existing execution
graph_exec = await edb.get_graph_execution(
graph_exec = await get_graph_execution(
user_id=user_id,
execution_id=graph_exec_id,
include_node_executions=True,
@@ -892,7 +891,6 @@ async def add_graph_execution(
)
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
# Publish to execution queue for executor to pick up
exec_queue = await get_async_execution_queue()
await exec_queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
@@ -901,12 +899,14 @@ async def add_graph_execution(
)
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
# Update execution status to QUEUED
graph_exec.status = ExecutionStatus.QUEUED
await edb.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=graph_exec.status,
)
await get_async_execution_event_bus().publish(graph_exec)
return graph_exec
except BaseException as e:
err = str(e) or type(e).__name__
if not graph_exec:
@@ -927,24 +927,6 @@ async def add_graph_execution(
)
raise
try:
await get_async_execution_event_bus().publish(graph_exec)
logger.info(f"Published update for execution #{graph_exec.id} to event bus")
except Exception as e:
logger.error(
f"Failed to publish execution event for graph exec #{graph_exec.id}: {e}"
)
try:
await odb.increment_onboarding_runs(user_id)
logger.info(
f"Incremented user #{user_id} onboarding runs for exec #{graph_exec.id}"
)
except Exception as e:
logger.error(f"Failed to increment onboarding runs for user #{user_id}: {e}")
return graph_exec
# ============ Execution Output Helpers ============ #


@@ -245,21 +245,6 @@ DEFAULT_CREDENTIALS = [
webshare_proxy_credentials,
]
SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
# Set of providers that have system credentials available
SYSTEM_PROVIDERS = {cred.provider for cred in DEFAULT_CREDENTIALS}
def is_system_credential(credential_id: str) -> bool:
"""Check if a credential ID belongs to a system-managed credential."""
return credential_id in SYSTEM_CREDENTIAL_IDS
def is_system_provider(provider: str) -> bool:
"""Check if a provider has system-managed credentials available."""
return provider in SYSTEM_PROVIDERS
class IntegrationCredentialsStore:
def __init__(self):


@@ -10,7 +10,6 @@ from backend.util.settings import Settings
settings = Settings()
if TYPE_CHECKING:
from openai import AsyncOpenAI
from supabase import AClient, Client
from backend.data.execution import (
@@ -140,24 +139,6 @@ async def get_async_supabase() -> "AClient":
)
# ============ OpenAI Client ============ #
@cached(ttl_seconds=3600)
def get_openai_client() -> "AsyncOpenAI | None":
"""
Get a process-cached async OpenAI client for embeddings.
Returns None if API key is not configured.
"""
from openai import AsyncOpenAI
api_key = settings.secrets.openai_internal_api_key
if not api_key:
return None
return AsyncOpenAI(api_key=api_key)
# ============ Notification Queue Helpers ============ #


@@ -658,14 +658,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
ayrshare_api_key: str = Field(default="", description="Ayrshare API Key")
ayrshare_jwt_key: str = Field(default="", description="Ayrshare private Key")
# Langfuse prompt management
langfuse_public_key: str = Field(default="", description="Langfuse public key")
langfuse_secret_key: str = Field(default="", description="Langfuse secret key")
langfuse_host: str = Field(
default="https://cloud.langfuse.com", description="Langfuse host URL"
)
# Add more secret fields as needed
model_config = SettingsConfigDict(
env_file=".env",


@@ -1,48 +0,0 @@
-- CreateExtension
-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
-- Create in public schema so vector type is available across all schemas
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "vector" WITH SCHEMA "public";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'vector extension not available or already exists, skipping';
END $$;
-- CreateEnum
CREATE TYPE "ContentType" AS ENUM ('STORE_AGENT', 'BLOCK', 'INTEGRATION', 'DOCUMENTATION', 'LIBRARY_AGENT');
-- CreateTable
CREATE TABLE "UnifiedContentEmbedding" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"contentType" "ContentType" NOT NULL,
"contentId" TEXT NOT NULL,
"userId" TEXT,
"embedding" public.vector(1536) NOT NULL,
"searchableText" TEXT NOT NULL,
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "UnifiedContentEmbedding_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "UnifiedContentEmbedding_contentType_idx" ON "UnifiedContentEmbedding"("contentType");
-- CreateIndex
CREATE INDEX "UnifiedContentEmbedding_userId_idx" ON "UnifiedContentEmbedding"("userId");
-- CreateIndex
CREATE INDEX "UnifiedContentEmbedding_contentType_userId_idx" ON "UnifiedContentEmbedding"("contentType", "userId");
-- CreateIndex
-- NULLS NOT DISTINCT ensures only one public (NULL userId) embedding per contentType+contentId
-- Requires PostgreSQL 15+. Supabase uses PostgreSQL 15+.
CREATE UNIQUE INDEX "UnifiedContentEmbedding_contentType_contentId_userId_key" ON "UnifiedContentEmbedding"("contentType", "contentId", "userId") NULLS NOT DISTINCT;
-- CreateIndex
-- HNSW index for fast vector similarity search on embeddings
-- Uses cosine distance operator (<=>), which matches the query in hybrid_search.py
-- Note: Drop first in case Prisma created a btree index (Prisma doesn't support HNSW)
DROP INDEX IF EXISTS "UnifiedContentEmbedding_embedding_idx";
CREATE INDEX "UnifiedContentEmbedding_embedding_idx" ON "UnifiedContentEmbedding" USING hnsw ("embedding" public.vector_cosine_ops);


@@ -1,71 +0,0 @@
-- Acknowledge Supabase-managed extensions to prevent drift warnings
-- These extensions are pre-installed by Supabase in specific schemas
-- This migration ensures they exist where available (Supabase) or skips gracefully (CI)
-- Create schemas (safe in both CI and Supabase)
CREATE SCHEMA IF NOT EXISTS "extensions";
-- Extensions that exist in both CI and Supabase
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pgcrypto extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'uuid-ossp extension not available, skipping';
END $$;
-- Supabase-specific extensions (skip gracefully in CI)
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "pg_stat_statements" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pg_stat_statements extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "pg_net" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pg_net extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "pgjwt" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pgjwt extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE SCHEMA IF NOT EXISTS "graphql";
CREATE EXTENSION IF NOT EXISTS "pg_graphql" WITH SCHEMA "graphql";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pg_graphql extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE SCHEMA IF NOT EXISTS "pgsodium";
CREATE EXTENSION IF NOT EXISTS "pgsodium" WITH SCHEMA "pgsodium";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pgsodium extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE SCHEMA IF NOT EXISTS "vault";
CREATE EXTENSION IF NOT EXISTS "supabase_vault" WITH SCHEMA "vault";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'supabase_vault extension not available, skipping';
END $$;
-- Return to platform
CREATE SCHEMA IF NOT EXISTS "platform";


@@ -1,64 +0,0 @@
-- CreateTable
CREATE TABLE "CoPilotUnderstanding" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"userId" TEXT NOT NULL,
"data" JSONB,
CONSTRAINT "CoPilotUnderstanding_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "ChatSession" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"userId" TEXT NOT NULL,
"title" TEXT,
"credentials" JSONB NOT NULL DEFAULT '{}',
"successfulAgentRuns" JSONB NOT NULL DEFAULT '{}',
"successfulAgentSchedules" JSONB NOT NULL DEFAULT '{}',
"totalPromptTokens" INTEGER NOT NULL DEFAULT 0,
"totalCompletionTokens" INTEGER NOT NULL DEFAULT 0,
CONSTRAINT "ChatSession_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "ChatMessage" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"sessionId" TEXT NOT NULL,
"role" TEXT NOT NULL,
"content" TEXT,
"name" TEXT,
"toolCallId" TEXT,
"refusal" TEXT,
"toolCalls" JSONB,
"functionCall" JSONB,
"sequence" INTEGER NOT NULL,
CONSTRAINT "ChatMessage_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "CoPilotUnderstanding_userId_key" ON "CoPilotUnderstanding"("userId");
-- CreateIndex
CREATE INDEX "CoPilotUnderstanding_userId_idx" ON "CoPilotUnderstanding"("userId");
-- CreateIndex
CREATE INDEX "ChatSession_userId_updatedAt_idx" ON "ChatSession"("userId", "updatedAt");
-- CreateIndex
CREATE UNIQUE INDEX "ChatMessage_sessionId_sequence_key" ON "ChatMessage"("sessionId", "sequence");
-- AddForeignKey
ALTER TABLE "CoPilotUnderstanding" ADD CONSTRAINT "CoPilotUnderstanding_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "ChatSession" ADD CONSTRAINT "ChatSession_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "ChatMessage" ADD CONSTRAINT "ChatMessage_sessionId_fkey" FOREIGN KEY ("sessionId") REFERENCES "ChatSession"("id") ON DELETE CASCADE ON UPDATE CASCADE;


@@ -1,35 +0,0 @@
-- Add tsvector search column to UnifiedContentEmbedding for unified full-text search
-- This enables hybrid search (semantic + lexical) across all content types
-- Add search column (IF NOT EXISTS for idempotency)
ALTER TABLE "UnifiedContentEmbedding" ADD COLUMN IF NOT EXISTS "search" tsvector DEFAULT ''::tsvector;
-- Create GIN index for fast full-text search
-- No @@index in schema.prisma - Prisma may generate DROP INDEX on migrate dev
-- If that happens, just let it drop and this migration will recreate it, or manually re-run:
-- CREATE INDEX IF NOT EXISTS "UnifiedContentEmbedding_search_idx" ON "UnifiedContentEmbedding" USING GIN ("search");
DROP INDEX IF EXISTS "UnifiedContentEmbedding_search_idx";
CREATE INDEX "UnifiedContentEmbedding_search_idx" ON "UnifiedContentEmbedding" USING GIN ("search");
-- Drop existing trigger/function if exists
DROP TRIGGER IF EXISTS "update_unified_tsvector" ON "UnifiedContentEmbedding";
DROP FUNCTION IF EXISTS update_unified_tsvector_column();
-- Create function to auto-update tsvector from searchableText
CREATE OR REPLACE FUNCTION update_unified_tsvector_column() RETURNS TRIGGER AS $$
BEGIN
NEW.search := to_tsvector('english', COALESCE(NEW."searchableText", ''));
RETURN NEW;
END;
$$ LANGUAGE plpgsql SECURITY DEFINER SET search_path = platform, pg_temp;
-- Create trigger to auto-update search column on insert/update
CREATE TRIGGER "update_unified_tsvector"
BEFORE INSERT OR UPDATE ON "UnifiedContentEmbedding"
FOR EACH ROW
EXECUTE FUNCTION update_unified_tsvector_column();
-- Backfill existing rows
UPDATE "UnifiedContentEmbedding"
SET search = to_tsvector('english', COALESCE("searchableText", ''))
WHERE search IS NULL OR search = ''::tsvector;


@@ -1,90 +0,0 @@
-- Remove the old search column from StoreListingVersion
-- This column has been replaced by UnifiedContentEmbedding.search
-- which provides unified hybrid search across all content types
-- First drop the dependent view
DROP VIEW IF EXISTS "StoreAgent";
-- Drop the trigger and function for old search column
-- The original trigger was created in 20251016093049_add_full_text_search
DROP TRIGGER IF EXISTS "update_tsvector" ON "StoreListingVersion";
DROP FUNCTION IF EXISTS update_tsvector_column();
-- Drop the index
DROP INDEX IF EXISTS "StoreListingVersion_search_idx";
-- NOTE: Keeping search column for now to allow easy revert if needed
-- Uncomment to fully remove once migration is verified in production:
-- ALTER TABLE "StoreListingVersion" DROP COLUMN IF EXISTS "search";
-- Recreate the StoreAgent view WITHOUT the search column
-- (Search now handled by UnifiedContentEmbedding)
CREATE OR REPLACE VIEW "StoreAgent" AS
WITH latest_versions AS (
SELECT
"storeListingId",
MAX(version) AS max_version
FROM "StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
GROUP BY "storeListingId"
),
agent_versions AS (
SELECT
"storeListingId",
array_agg(DISTINCT version::text ORDER BY version::text) AS versions
FROM "StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
GROUP BY "storeListingId"
),
agent_graph_versions AS (
SELECT
"storeListingId",
array_agg(DISTINCT "agentGraphVersion"::text ORDER BY "agentGraphVersion"::text) AS graph_versions
FROM "StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
GROUP BY "storeListingId"
)
SELECT
sl.id AS listing_id,
slv.id AS "storeListingVersionId",
slv."createdAt" AS updated_at,
sl.slug,
COALESCE(slv.name, '') AS agent_name,
slv."videoUrl" AS agent_video,
slv."agentOutputDemoUrl" AS agent_output_demo,
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
slv."isFeatured" AS featured,
p.username AS creator_username,
p."avatarUrl" AS creator_avatar,
slv."subHeading" AS sub_heading,
slv.description,
slv.categories,
COALESCE(ar.run_count, 0::bigint) AS runs,
COALESCE(rs.avg_rating, 0.0)::double precision AS rating,
COALESCE(av.versions, ARRAY[slv.version::text]) AS versions,
COALESCE(agv.graph_versions, ARRAY[slv."agentGraphVersion"::text]) AS "agentGraphVersions",
slv."agentGraphId",
slv."isAvailable" AS is_available,
COALESCE(sl."useForOnboarding", false) AS "useForOnboarding"
FROM "StoreListing" sl
JOIN latest_versions lv
ON sl.id = lv."storeListingId"
JOIN "StoreListingVersion" slv
ON slv."storeListingId" = lv."storeListingId"
AND slv.version = lv.max_version
AND slv."submissionStatus" = 'APPROVED'
JOIN "AgentGraph" a
ON slv."agentGraphId" = a.id
AND slv."agentGraphVersion" = a.version
LEFT JOIN "Profile" p
ON sl."owningUserId" = p."userId"
LEFT JOIN "mv_review_stats" rs
ON sl.id = rs."storeListingId"
LEFT JOIN "mv_agent_run_counts" ar
ON a.id = ar."agentGraphId"
LEFT JOIN agent_versions av
ON sl.id = av."storeListingId"
LEFT JOIN agent_graph_versions agv
ON sl.id = agv."storeListingId"
WHERE sl."isDeleted" = false
AND sl."hasApprovedVersion" = true;

View File

@@ -2777,30 +2777,6 @@ enabler = ["pytest-enabler (>=2.2)"]
test = ["pyfakefs", "pytest (>=6,!=8.1.*)"]
type = ["pygobject-stubs", "pytest-mypy", "shtab", "types-pywin32"]
[[package]]
name = "langfuse"
version = "3.11.2"
description = "A client library for accessing langfuse"
optional = false
python-versions = "<4.0,>=3.10"
groups = ["main"]
files = [
{file = "langfuse-3.11.2-py3-none-any.whl", hash = "sha256:84faea9f909694023cc7f0eb45696be190248c8790424f22af57ca4cd7a29f2d"},
{file = "langfuse-3.11.2.tar.gz", hash = "sha256:ab5f296a8056815b7288c7f25bc308a5e79f82a8634467b25daffdde99276e09"},
]
[package.dependencies]
backoff = ">=1.10.0"
httpx = ">=0.15.4,<1.0"
openai = ">=0.27.8"
opentelemetry-api = ">=1.33.1,<2.0.0"
opentelemetry-exporter-otlp-proto-http = ">=1.33.1,<2.0.0"
opentelemetry-sdk = ">=1.33.1,<2.0.0"
packaging = ">=23.2,<26.0"
pydantic = ">=1.10.7,<3.0"
requests = ">=2,<3"
wrapt = ">=1.14,<2.0"
[[package]]
name = "launchdarkly-eventsource"
version = "1.3.0"
@@ -3492,90 +3468,6 @@ files = [
importlib-metadata = ">=6.0,<8.8.0"
typing-extensions = ">=4.5.0"
[[package]]
name = "opentelemetry-exporter-otlp-proto-common"
version = "1.35.0"
description = "OpenTelemetry Protobuf encoding"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "opentelemetry_exporter_otlp_proto_common-1.35.0-py3-none-any.whl", hash = "sha256:863465de697ae81279ede660f3918680b4480ef5f69dcdac04f30722ed7b74cc"},
{file = "opentelemetry_exporter_otlp_proto_common-1.35.0.tar.gz", hash = "sha256:6f6d8c39f629b9fa5c79ce19a2829dbd93034f8ac51243cdf40ed2196f00d7eb"},
]
[package.dependencies]
opentelemetry-proto = "1.35.0"
[[package]]
name = "opentelemetry-exporter-otlp-proto-http"
version = "1.35.0"
description = "OpenTelemetry Collector Protobuf over HTTP Exporter"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "opentelemetry_exporter_otlp_proto_http-1.35.0-py3-none-any.whl", hash = "sha256:9a001e3df3c7f160fb31056a28ed7faa2de7df68877ae909516102ae36a54e1d"},
{file = "opentelemetry_exporter_otlp_proto_http-1.35.0.tar.gz", hash = "sha256:cf940147f91b450ef5f66e9980d40eb187582eed399fa851f4a7a45bb880de79"},
]
[package.dependencies]
googleapis-common-protos = ">=1.52,<2.0"
opentelemetry-api = ">=1.15,<2.0"
opentelemetry-exporter-otlp-proto-common = "1.35.0"
opentelemetry-proto = "1.35.0"
opentelemetry-sdk = ">=1.35.0,<1.36.0"
requests = ">=2.7,<3.0"
typing-extensions = ">=4.5.0"
[[package]]
name = "opentelemetry-proto"
version = "1.35.0"
description = "OpenTelemetry Python Proto"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "opentelemetry_proto-1.35.0-py3-none-any.whl", hash = "sha256:98fffa803164499f562718384e703be8d7dfbe680192279a0429cb150a2f8809"},
{file = "opentelemetry_proto-1.35.0.tar.gz", hash = "sha256:532497341bd3e1c074def7c5b00172601b28bb83b48afc41a4b779f26eb4ee05"},
]
[package.dependencies]
protobuf = ">=5.0,<7.0"
[[package]]
name = "opentelemetry-sdk"
version = "1.35.0"
description = "OpenTelemetry Python SDK"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "opentelemetry_sdk-1.35.0-py3-none-any.whl", hash = "sha256:223d9e5f5678518f4842311bb73966e0b6db5d1e0b74e35074c052cd2487f800"},
{file = "opentelemetry_sdk-1.35.0.tar.gz", hash = "sha256:2a400b415ab68aaa6f04e8a6a9f6552908fb3090ae2ff78d6ae0c597ac581954"},
]
[package.dependencies]
opentelemetry-api = "1.35.0"
opentelemetry-semantic-conventions = "0.56b0"
typing-extensions = ">=4.5.0"
[[package]]
name = "opentelemetry-semantic-conventions"
version = "0.56b0"
description = "OpenTelemetry Semantic Conventions"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "opentelemetry_semantic_conventions-0.56b0-py3-none-any.whl", hash = "sha256:df44492868fd6b482511cc43a942e7194be64e94945f572db24df2e279a001a2"},
{file = "opentelemetry_semantic_conventions-0.56b0.tar.gz", hash = "sha256:c114c2eacc8ff6d3908cb328c811eaf64e6d68623840be9224dc829c4fd6c2ea"},
]
[package.dependencies]
opentelemetry-api = "1.35.0"
typing-extensions = ">=4.5.0"
[[package]]
name = "orjson"
version = "3.11.3"
@@ -7030,97 +6922,6 @@ files = [
{file = "websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee"},
]
[[package]]
name = "wrapt"
version = "1.17.3"
description = "Module for decorators, wrappers and monkey patching."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "wrapt-1.17.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:88bbae4d40d5a46142e70d58bf664a89b6b4befaea7b2ecc14e03cedb8e06c04"},
{file = "wrapt-1.17.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e6b13af258d6a9ad602d57d889f83b9d5543acd471eee12eb51f5b01f8eb1bc2"},
{file = "wrapt-1.17.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd341868a4b6714a5962c1af0bd44f7c404ef78720c7de4892901e540417111c"},
{file = "wrapt-1.17.3-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f9b2601381be482f70e5d1051a5965c25fb3625455a2bf520b5a077b22afb775"},
{file = "wrapt-1.17.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:343e44b2a8e60e06a7e0d29c1671a0d9951f59174f3709962b5143f60a2a98bd"},
{file = "wrapt-1.17.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:33486899acd2d7d3066156b03465b949da3fd41a5da6e394ec49d271baefcf05"},
{file = "wrapt-1.17.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e6f40a8aa5a92f150bdb3e1c44b7e98fb7113955b2e5394122fa5532fec4b418"},
{file = "wrapt-1.17.3-cp310-cp310-win32.whl", hash = "sha256:a36692b8491d30a8c75f1dfee65bef119d6f39ea84ee04d9f9311f83c5ad9390"},
{file = "wrapt-1.17.3-cp310-cp310-win_amd64.whl", hash = "sha256:afd964fd43b10c12213574db492cb8f73b2f0826c8df07a68288f8f19af2ebe6"},
{file = "wrapt-1.17.3-cp310-cp310-win_arm64.whl", hash = "sha256:af338aa93554be859173c39c85243970dc6a289fa907402289eeae7543e1ae18"},
{file = "wrapt-1.17.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:273a736c4645e63ac582c60a56b0acb529ef07f78e08dc6bfadf6a46b19c0da7"},
{file = "wrapt-1.17.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5531d911795e3f935a9c23eb1c8c03c211661a5060aab167065896bbf62a5f85"},
{file = "wrapt-1.17.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0610b46293c59a3adbae3dee552b648b984176f8562ee0dba099a56cfbe4df1f"},
{file = "wrapt-1.17.3-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b32888aad8b6e68f83a8fdccbf3165f5469702a7544472bdf41f582970ed3311"},
{file = "wrapt-1.17.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cccf4f81371f257440c88faed6b74f1053eef90807b77e31ca057b2db74edb1"},
{file = "wrapt-1.17.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8a210b158a34164de8bb68b0e7780041a903d7b00c87e906fb69928bf7890d5"},
{file = "wrapt-1.17.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:79573c24a46ce11aab457b472efd8d125e5a51da2d1d24387666cd85f54c05b2"},
{file = "wrapt-1.17.3-cp311-cp311-win32.whl", hash = "sha256:c31eebe420a9a5d2887b13000b043ff6ca27c452a9a22fa71f35f118e8d4bf89"},
{file = "wrapt-1.17.3-cp311-cp311-win_amd64.whl", hash = "sha256:0b1831115c97f0663cb77aa27d381237e73ad4f721391a9bfb2fe8bc25fa6e77"},
{file = "wrapt-1.17.3-cp311-cp311-win_arm64.whl", hash = "sha256:5a7b3c1ee8265eb4c8f1b7d29943f195c00673f5ab60c192eba2d4a7eae5f46a"},
{file = "wrapt-1.17.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ab232e7fdb44cdfbf55fc3afa31bcdb0d8980b9b95c38b6405df2acb672af0e0"},
{file = "wrapt-1.17.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9baa544e6acc91130e926e8c802a17f3b16fbea0fd441b5a60f5cf2cc5c3deba"},
{file = "wrapt-1.17.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6b538e31eca1a7ea4605e44f81a48aa24c4632a277431a6ed3f328835901f4fd"},
{file = "wrapt-1.17.3-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:042ec3bb8f319c147b1301f2393bc19dba6e176b7da446853406d041c36c7828"},
{file = "wrapt-1.17.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3af60380ba0b7b5aeb329bc4e402acd25bd877e98b3727b0135cb5c2efdaefe9"},
{file = "wrapt-1.17.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0b02e424deef65c9f7326d8c19220a2c9040c51dc165cddb732f16198c168396"},
{file = "wrapt-1.17.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:74afa28374a3c3a11b3b5e5fca0ae03bef8450d6aa3ab3a1e2c30e3a75d023dc"},
{file = "wrapt-1.17.3-cp312-cp312-win32.whl", hash = "sha256:4da9f45279fff3543c371d5ababc57a0384f70be244de7759c85a7f989cb4ebe"},
{file = "wrapt-1.17.3-cp312-cp312-win_amd64.whl", hash = "sha256:e71d5c6ebac14875668a1e90baf2ea0ef5b7ac7918355850c0908ae82bcb297c"},
{file = "wrapt-1.17.3-cp312-cp312-win_arm64.whl", hash = "sha256:604d076c55e2fdd4c1c03d06dc1a31b95130010517b5019db15365ec4a405fc6"},
{file = "wrapt-1.17.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a47681378a0439215912ef542c45a783484d4dd82bac412b71e59cf9c0e1cea0"},
{file = "wrapt-1.17.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:54a30837587c6ee3cd1a4d1c2ec5d24e77984d44e2f34547e2323ddb4e22eb77"},
{file = "wrapt-1.17.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:16ecf15d6af39246fe33e507105d67e4b81d8f8d2c6598ff7e3ca1b8a37213f7"},
{file = "wrapt-1.17.3-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6fd1ad24dc235e4ab88cda009e19bf347aabb975e44fd5c2fb22a3f6e4141277"},
{file = "wrapt-1.17.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ed61b7c2d49cee3c027372df5809a59d60cf1b6c2f81ee980a091f3afed6a2d"},
{file = "wrapt-1.17.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:423ed5420ad5f5529db9ce89eac09c8a2f97da18eb1c870237e84c5a5c2d60aa"},
{file = "wrapt-1.17.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e01375f275f010fcbf7f643b4279896d04e571889b8a5b3f848423d91bf07050"},
{file = "wrapt-1.17.3-cp313-cp313-win32.whl", hash = "sha256:53e5e39ff71b3fc484df8a522c933ea2b7cdd0d5d15ae82e5b23fde87d44cbd8"},
{file = "wrapt-1.17.3-cp313-cp313-win_amd64.whl", hash = "sha256:1f0b2f40cf341ee8cc1a97d51ff50dddb9fcc73241b9143ec74b30fc4f44f6cb"},
{file = "wrapt-1.17.3-cp313-cp313-win_arm64.whl", hash = "sha256:7425ac3c54430f5fc5e7b6f41d41e704db073309acfc09305816bc6a0b26bb16"},
{file = "wrapt-1.17.3-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:cf30f6e3c077c8e6a9a7809c94551203c8843e74ba0c960f4a98cd80d4665d39"},
{file = "wrapt-1.17.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:e228514a06843cae89621384cfe3a80418f3c04aadf8a3b14e46a7be704e4235"},
{file = "wrapt-1.17.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:5ea5eb3c0c071862997d6f3e02af1d055f381b1d25b286b9d6644b79db77657c"},
{file = "wrapt-1.17.3-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:281262213373b6d5e4bb4353bc36d1ba4084e6d6b5d242863721ef2bf2c2930b"},
{file = "wrapt-1.17.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dc4a8d2b25efb6681ecacad42fca8859f88092d8732b170de6a5dddd80a1c8fa"},
{file = "wrapt-1.17.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:373342dd05b1d07d752cecbec0c41817231f29f3a89aa8b8843f7b95992ed0c7"},
{file = "wrapt-1.17.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d40770d7c0fd5cbed9d84b2c3f2e156431a12c9a37dc6284060fb4bec0b7ffd4"},
{file = "wrapt-1.17.3-cp314-cp314-win32.whl", hash = "sha256:fbd3c8319de8e1dc79d346929cd71d523622da527cca14e0c1d257e31c2b8b10"},
{file = "wrapt-1.17.3-cp314-cp314-win_amd64.whl", hash = "sha256:e1a4120ae5705f673727d3253de3ed0e016f7cd78dc463db1b31e2463e1f3cf6"},
{file = "wrapt-1.17.3-cp314-cp314-win_arm64.whl", hash = "sha256:507553480670cab08a800b9463bdb881b2edeed77dc677b0a5915e6106e91a58"},
{file = "wrapt-1.17.3-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:ed7c635ae45cfbc1a7371f708727bf74690daedc49b4dba310590ca0bd28aa8a"},
{file = "wrapt-1.17.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:249f88ed15503f6492a71f01442abddd73856a0032ae860de6d75ca62eed8067"},
{file = "wrapt-1.17.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5a03a38adec8066d5a37bea22f2ba6bbf39fcdefbe2d91419ab864c3fb515454"},
{file = "wrapt-1.17.3-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:5d4478d72eb61c36e5b446e375bbc49ed002430d17cdec3cecb36993398e1a9e"},
{file = "wrapt-1.17.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:223db574bb38637e8230eb14b185565023ab624474df94d2af18f1cdb625216f"},
{file = "wrapt-1.17.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:e405adefb53a435f01efa7ccdec012c016b5a1d3f35459990afc39b6be4d5056"},
{file = "wrapt-1.17.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:88547535b787a6c9ce4086917b6e1d291aa8ed914fdd3a838b3539dc95c12804"},
{file = "wrapt-1.17.3-cp314-cp314t-win32.whl", hash = "sha256:41b1d2bc74c2cac6f9074df52b2efbef2b30bdfe5f40cb78f8ca22963bc62977"},
{file = "wrapt-1.17.3-cp314-cp314t-win_amd64.whl", hash = "sha256:73d496de46cd2cdbdbcce4ae4bcdb4afb6a11234a1df9c085249d55166b95116"},
{file = "wrapt-1.17.3-cp314-cp314t-win_arm64.whl", hash = "sha256:f38e60678850c42461d4202739f9bf1e3a737c7ad283638251e79cc49effb6b6"},
{file = "wrapt-1.17.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:70d86fa5197b8947a2fa70260b48e400bf2ccacdcab97bb7de47e3d1e6312225"},
{file = "wrapt-1.17.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:df7d30371a2accfe4013e90445f6388c570f103d61019b6b7c57e0265250072a"},
{file = "wrapt-1.17.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:caea3e9c79d5f0d2c6d9ab96111601797ea5da8e6d0723f77eabb0d4068d2b2f"},
{file = "wrapt-1.17.3-cp38-cp38-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:758895b01d546812d1f42204bd443b8c433c44d090248bf22689df673ccafe00"},
{file = "wrapt-1.17.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:02b551d101f31694fc785e58e0720ef7d9a10c4e62c1c9358ce6f63f23e30a56"},
{file = "wrapt-1.17.3-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:656873859b3b50eeebe6db8b1455e99d90c26ab058db8e427046dbc35c3140a5"},
{file = "wrapt-1.17.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:a9a2203361a6e6404f80b99234fe7fb37d1fc73487b5a78dc1aa5b97201e0f22"},
{file = "wrapt-1.17.3-cp38-cp38-win32.whl", hash = "sha256:55cbbc356c2842f39bcc553cf695932e8b30e30e797f961860afb308e6b1bb7c"},
{file = "wrapt-1.17.3-cp38-cp38-win_amd64.whl", hash = "sha256:ad85e269fe54d506b240d2d7b9f5f2057c2aa9a2ea5b32c66f8902f768117ed2"},
{file = "wrapt-1.17.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:30ce38e66630599e1193798285706903110d4f057aab3168a34b7fdc85569afc"},
{file = "wrapt-1.17.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:65d1d00fbfb3ea5f20add88bbc0f815150dbbde3b026e6c24759466c8b5a9ef9"},
{file = "wrapt-1.17.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a7c06742645f914f26c7f1fa47b8bc4c91d222f76ee20116c43d5ef0912bba2d"},
{file = "wrapt-1.17.3-cp39-cp39-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7e18f01b0c3e4a07fe6dfdb00e29049ba17eadbc5e7609a2a3a4af83ab7d710a"},
{file = "wrapt-1.17.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0f5f51a6466667a5a356e6381d362d259125b57f059103dd9fdc8c0cf1d14139"},
{file = "wrapt-1.17.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:59923aa12d0157f6b82d686c3fd8e1166fa8cdfb3e17b42ce3b6147ff81528df"},
{file = "wrapt-1.17.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:46acc57b331e0b3bcb3e1ca3b421d65637915cfcd65eb783cb2f78a511193f9b"},
{file = "wrapt-1.17.3-cp39-cp39-win32.whl", hash = "sha256:3e62d15d3cfa26e3d0788094de7b64efa75f3a53875cdbccdf78547aed547a81"},
{file = "wrapt-1.17.3-cp39-cp39-win_amd64.whl", hash = "sha256:1f23fa283f51c890eda8e34e4937079114c74b4c81d2b2f1f1d94948f5cc3d7f"},
{file = "wrapt-1.17.3-cp39-cp39-win_arm64.whl", hash = "sha256:24c2ed34dc222ed754247a2702b1e1e89fdbaa4016f324b4b8f1a802d4ffe87f"},
{file = "wrapt-1.17.3-py3-none-any.whl", hash = "sha256:7171ae35d2c33d326ac19dd8facb1e82e5fd04ef8c6c0e394d7af55a55051c22"},
{file = "wrapt-1.17.3.tar.gz", hash = "sha256:f66eb08feaa410fe4eebd17f2a2c8e2e46d3476e9f8c783daa8e09e0faa666d0"},
]
[[package]]
name = "xattr"
version = "1.2.0"
@@ -7494,4 +7295,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
content-hash = "86838b5ae40d606d6e01a14dad8a56c389d890d7a6a0c274a6602cca80f0df84"
content-hash = "a93ba0cea3b465cb6ec3e3f258b383b09f84ea352ccfdbfa112902cde5653fc6"

View File

@@ -33,7 +33,6 @@ html2text = "^2024.2.26"
jinja2 = "^3.1.6"
jsonref = "^1.1.0"
jsonschema = "^4.25.0"
langfuse = "^3.11.0"
launchdarkly-server-sdk = "^9.12.0"
mem0ai = "^0.1.115"
moviepy = "^2.1.2"

View File

@@ -1,15 +1,14 @@
datasource db {
provider = "postgresql"
url = env("DATABASE_URL")
directUrl = env("DIRECT_URL")
extensions = [pgvector(map: "vector")]
provider = "postgresql"
url = env("DATABASE_URL")
directUrl = env("DIRECT_URL")
}
generator client {
provider = "prisma-client-py"
recursive_type_depth = -1
interface = "asyncio"
previewFeatures = ["views", "fullTextSearch", "postgresqlExtensions"]
previewFeatures = ["views", "fullTextSearch"]
partial_type_generator = "backend/data/partial_types.py"
}
@@ -48,13 +47,12 @@ model User {
AnalyticsMetrics AnalyticsMetrics[]
CreditTransactions CreditTransaction[]
UserBalance UserBalance?
ChatSessions ChatSession[]
AgentPresets AgentPreset[]
LibraryAgents LibraryAgent[]
AgentPresets AgentPreset[]
LibraryAgents LibraryAgent[]
Profile Profile[]
UserOnboarding UserOnboarding?
CoPilotUnderstanding CoPilotUnderstanding?
BuilderSearchHistory BuilderSearchHistory[]
StoreListings StoreListing[]
StoreListingReviews StoreListingReview[]
@@ -123,84 +121,19 @@ model UserOnboarding {
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
}
model CoPilotUnderstanding {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
userId String @unique
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
data Json?
@@index([userId])
}
model BuilderSearchHistory {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
searchQuery String
filter String[] @default([])
byCreator String[] @default([])
filter String[] @default([])
byCreator String[] @default([])
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////////////// CHAT SESSION TABLES ///////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
model ChatSession {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
// Session metadata
title String?
credentials Json @default("{}") // Map of provider -> credential metadata
// Rate limiting counters (stored as JSON maps)
successfulAgentRuns Json @default("{}") // Map of graph_id -> count
successfulAgentSchedules Json @default("{}") // Map of graph_id -> count
// Usage tracking
totalPromptTokens Int @default(0)
totalCompletionTokens Int @default(0)
Messages ChatMessage[]
@@index([userId, updatedAt])
}
model ChatMessage {
id String @id @default(uuid())
createdAt DateTime @default(now())
sessionId String
Session ChatSession @relation(fields: [sessionId], references: [id], onDelete: Cascade)
// Message content
role String // "user", "assistant", "system", "tool", "function"
content String?
name String?
toolCallId String?
refusal String?
toolCalls Json? // List of tool calls for assistant messages
functionCall Json? // Deprecated but kept for compatibility
// Ordering within session
sequence Int
@@unique([sessionId, sequence])
}
// This model describes the Agent Graph/Flow (Multi Agent System).
model AgentGraph {
id String @default(uuid())
@@ -788,25 +721,26 @@ view StoreAgent {
storeListingVersionId String
updated_at DateTime
slug String
agent_name String
agent_video String?
agent_output_demo String?
agent_image String[]
slug String
agent_name String
agent_video String?
agent_output_demo String?
agent_image String[]
featured Boolean @default(false)
creator_username String?
creator_avatar String?
sub_heading String
description String
categories String[]
runs Int
rating Float
versions String[]
agentGraphVersions String[]
agentGraphId String
is_available Boolean @default(true)
useForOnboarding Boolean @default(false)
featured Boolean @default(false)
creator_username String?
creator_avatar String?
sub_heading String
description String
categories String[]
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
runs Int
rating Float
versions String[]
agentGraphVersions String[]
agentGraphId String
is_available Boolean @default(true)
useForOnboarding Boolean @default(false)
// Materialized views used (refreshed every 15 minutes via pg_cron):
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
@@ -922,14 +856,14 @@ model StoreListingVersion {
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
// Content fields
name String
subHeading String
videoUrl String?
agentOutputDemoUrl String?
imageUrls String[]
description String
instructions String?
categories String[]
name String
subHeading String
videoUrl String?
agentOutputDemoUrl String?
imageUrls String[]
description String
instructions String?
categories String[]
isFeatured Boolean @default(false)
@@ -937,7 +871,7 @@ model StoreListingVersion {
// Old versions can be made unavailable by the author if desired
isAvailable Boolean @default(true)
// Note: search column removed - now using UnifiedContentEmbedding.search
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
// Version workflow state
submissionStatus SubmissionStatus @default(DRAFT)
@@ -965,9 +899,6 @@ model StoreListingVersion {
// Reviews for this specific version
Reviews StoreListingReview[]
// Note: Embeddings now stored in UnifiedContentEmbedding table
// Use contentType=STORE_AGENT and contentId=storeListingVersionId
@@unique([storeListingId, version])
@@index([storeListingId, submissionStatus, isAvailable])
@@index([submissionStatus])
@@ -975,45 +906,6 @@ model StoreListingVersion {
@@index([agentGraphId, agentGraphVersion]) // Non-unique index for efficient lookups
}
// Content type enum for unified search across store agents, blocks, docs
// Note: BLOCK/INTEGRATION are file-based (Python classes), not DB records
// DOCUMENTATION are file-based (.md files), not DB records
// Only STORE_AGENT and LIBRARY_AGENT are stored in database
enum ContentType {
STORE_AGENT // Database: StoreListingVersion
BLOCK // File-based: Python classes in /backend/blocks/
INTEGRATION // File-based: Python classes (blocks with credentials)
DOCUMENTATION // File-based: .md/.mdx files
LIBRARY_AGENT // Database: User's personal agents
}
// Unified embeddings table for all searchable content types
// Supports both public content (userId=null) and user-specific content (userId=userID)
model UnifiedContentEmbedding {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
// Content identification
contentType ContentType
contentId String // DB ID (storeListingVersionId) or file identifier (block.id, file_path)
userId String? // NULL for public content (store, blocks, docs), userId for private content (library agents)
// Search data
embedding Unsupported("vector(1536)") // pgvector embedding (extension in platform schema)
searchableText String // Combined text for search and fallback
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector")) // Full-text search (auto-populated by trigger)
metadata Json @default("{}") // Content-specific metadata
@@unique([contentType, contentId, userId], map: "UnifiedContentEmbedding_contentType_contentId_userId_key")
@@index([contentType])
@@index([userId])
@@index([contentType, userId])
@@index([embedding], map: "UnifiedContentEmbedding_embedding_idx")
// NO @@index for search - GIN index "UnifiedContentEmbedding_search_idx" created via SQL migration
// Prisma may generate DROP INDEX on migrate dev - that's okay, migration recreates it
}
model StoreListingReview {
id String @id @default(uuid())
createdAt DateTime @default(now())
@@ -1040,31 +932,16 @@ enum SubmissionStatus {
}
enum APIKeyPermission {
// Legacy v1 permissions (kept for backward compatibility)
IDENTITY // Info about the authenticated user
EXECUTE_GRAPH // Can execute agent graphs (v1 only)
EXECUTE_GRAPH // Can execute agent graphs
READ_GRAPH // Can get graph versions and details
EXECUTE_BLOCK // Can execute individual blocks (v1 only)
EXECUTE_BLOCK // Can execute individual blocks
READ_BLOCK // Can get block information
READ_STORE // Can read store/marketplace agents and creators
USE_TOOLS // Can use chat tools via external API (v1 only)
MANAGE_INTEGRATIONS // Can initiate OAuth flows and complete them (v1 only)
READ_STORE // Can read store agents and creators
USE_TOOLS // Can use chat tools via external API
MANAGE_INTEGRATIONS // Can initiate OAuth flows and complete them
READ_INTEGRATIONS // Can list credentials and providers
DELETE_INTEGRATIONS // Can delete credentials (v1 only)
// V2 permissions
WRITE_GRAPH // Can create, update, delete graphs
READ_SCHEDULE // Can list schedules
WRITE_SCHEDULE // Can create and delete schedules
WRITE_STORE // Can create, update, delete marketplace submissions
READ_LIBRARY // Can list library agents and runs
RUN_AGENT // Can run agents from library
READ_RUN // Can list and get run details
WRITE_RUN // Can stop and delete runs
READ_RUN_REVIEW // Can list pending human-in-the-loop reviews
WRITE_RUN_REVIEW // Can submit human-in-the-loop review responses
READ_CREDITS // Can get credit balance and transactions
UPLOAD_FILES // Can upload files for agent input
DELETE_INTEGRATIONS // Can delete credentials
}
model APIKey {
@@ -1121,16 +998,16 @@ model OAuthApplication {
updatedAt DateTime @updatedAt
// Application metadata
name String
description String?
logoUrl String? // URL to app logo stored in GCS
clientId String @unique
clientSecret String // Hashed with Scrypt (same as API keys)
clientSecretSalt String // Salt for Scrypt hashing
name String
description String?
logoUrl String? // URL to app logo stored in GCS
clientId String @unique
clientSecret String // Hashed with Scrypt (same as API keys)
clientSecretSalt String // Salt for Scrypt hashing
// OAuth configuration
redirectUris String[] // Allowed callback URLs
grantTypes String[] @default(["authorization_code", "refresh_token"])
grantTypes String[] @default(["authorization_code", "refresh_token"])
scopes APIKeyPermission[] // Which permissions the app can request
// Application management

View File

@@ -708,7 +708,10 @@ export function CreateButton() {
## 🧪 Testing & Storybook
- See `TESTING.md` for Playwright setup, E2E data seeding, and Storybook usage.
- End-to-end: [Playwright](https://playwright.dev/docs/intro) (`pnpm test`, `pnpm test-ui`)
- [Storybook](https://storybook.js.org/docs) for isolated UI development (`pnpm storybook` / `pnpm build-storybook`)
- For Storybook tests in CI, see [`@storybook/test-runner`](https://storybook.js.org/docs/writing-tests/test-runner) (`test-storybook:ci`)
- When changing components in `src/components`, update or add stories and visually verify in Storybook/Chromatic
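
A minimal story sketch of what "add stories" means in practice. This is a hypothetical example, not a file from this PR: the file name, `title`, and args are illustrative, the `Button` atom path is taken from the components used elsewhere in this diff, and the import may need to come from the framework package (e.g. `@storybook/nextjs`) depending on the Storybook setup:

```tsx
// Button.stories.tsx — hypothetical example; real stories live next to each component
import type { Meta, StoryObj } from "@storybook/react";
import { Button } from "@/components/atoms/Button/Button";

const meta = {
  title: "Atoms/Button",
  component: Button,
} satisfies Meta<typeof Button>;

export default meta;
type Story = StoryObj<typeof meta>;

// One named export per visual state you want to verify in Storybook/Chromatic
export const Outline: Story = {
  args: { variant: "outline", children: "Save" },
};
```

Each exported story then appears in Storybook and can be verified visually (and picked up by Chromatic) when the component changes.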
---

View File

@@ -5,7 +5,6 @@ This is the frontend for AutoGPT's next generation
This project uses [**pnpm**](https://pnpm.io/) as the package manager via **corepack**. [Corepack](https://github.com/nodejs/corepack) is a Node.js tool that automatically manages package managers without requiring global installations.
For architecture, conventions, data fetching, feature flags, design system usage, state management, and PR process, see [CONTRIBUTING.md](./CONTRIBUTING.md).
For Playwright and Storybook testing setup, see [TESTING.md](./TESTING.md).
### Prerequisites

View File

@@ -1,57 +0,0 @@
# Frontend Testing 🧪
## Quick Start (local) 🚀
1. Start the backend + Supabase stack:
- From `autogpt_platform`: `docker compose --profile local up deps_backend -d`
- Or run the full stack: `docker compose up -d`
2. Seed rich E2E data (creates `test123@gmail.com` with library agents):
- From `autogpt_platform/backend`: `poetry run python test/e2e_test_data.py`
3. Run Playwright:
- From `autogpt_platform/frontend`: `pnpm test` or `pnpm test-ui`
## How Playwright setup works 🎭
- Playwright runs from `frontend/playwright.config.ts` with a global setup step.
- The global setup creates a user pool via the real signup UI and stores it in `frontend/.auth/user-pool.json`.
- Most tests call `getTestUser()` (from `src/tests/utils/auth.ts`) which pulls a random user from that pool.
- These users do not have library agents; they are users that have just "signed up" on the platform, so some tests make use of users created via the script below, which have more data.
## Test users 👤
- **User pool (basic users)**
Created automatically by the Playwright global setup through `/signup`.
Used by `getTestUser()` in `src/tests/utils/auth.ts`.
- **Rich user with library agents**
Created by `backend/test/e2e_test_data.py`.
Accessed via `getTestUserWithLibraryAgents()` in `src/tests/credentials/index.ts`.
Use the rich user when a test needs existing library agents (e.g. `library.spec.ts`).
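
For illustration, a minimal spec sketch showing how a test might use the rich user. The helper's signature, import path, and the login selectors are assumptions here, not the repo's actual fixtures:

```ts
// library.spec.ts — hypothetical sketch, not the actual test file
import { test, expect } from "@playwright/test";
// Helper referenced above; assumed to resolve to an { email, password } pair
import { getTestUserWithLibraryAgents } from "./credentials";

test("library shows seeded agents", async ({ page }) => {
  const user = await getTestUserWithLibraryAgents();

  // Log in through the real UI (selectors are illustrative only)
  await page.goto("/login");
  await page.getByLabel("Email").fill(user.email);
  await page.getByLabel("Password").fill(user.password);
  await page.getByRole("button", { name: /log ?in/i }).click();

  // The rich user seeded by e2e_test_data.py should already have library agents
  await page.goto("/library");
  await expect(page.getByText(/agents/i).first()).toBeVisible();
});
```

Flows that do not need seeded library agents would call `getTestUser()` from the user pool instead.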
## Resetting or wiping the DB 🔁
If you reset the Docker DB and logins start failing:
1. Delete `frontend/.auth/user-pool.json` so the pool is regenerated.
2. Re-run the E2E data script to recreate the rich user + library agents:
- `poetry run python test/e2e_test_data.py`
## Storybook 📚
- `pnpm storybook` Run Storybook locally
- `pnpm build-storybook` Build a static Storybook
- CI runner: `pnpm test-storybook`
- When changing components in `src/components`, update or add stories and verify in Storybook/Chromatic.
## Flow diagram 🗺️
```mermaid
flowchart TD
    A[Start Docker stack] --> B[Run e2e_test_data.py]
    B --> C[Run Playwright tests]
    C --> D[Global setup creates user pool]
    D --> E{Test needs rich data?}
    E -->|No| F[getTestUser from user pool]
    E -->|Yes| G[getTestUserWithLibraryAgents]
```

View File

@@ -3,13 +3,6 @@ import { withSentryConfig } from "@sentry/nextjs";
/** @type {import('next').NextConfig} */
const nextConfig = {
productionBrowserSourceMaps: true,
// Externalize OpenTelemetry packages to fix Turbopack HMR issues
serverExternalPackages: [
"@opentelemetry/instrumentation",
"@opentelemetry/sdk-node",
"import-in-the-middle",
"require-in-the-middle",
],
experimental: {
serverActions: {
bodySizeLimit: "256mb",

View File

@@ -32,7 +32,6 @@
"@hookform/resolvers": "5.2.2",
"@next/third-parties": "15.4.6",
"@phosphor-icons/react": "2.1.10",
"@radix-ui/react-accordion": "1.2.12",
"@radix-ui/react-alert-dialog": "1.1.15",
"@radix-ui/react-avatar": "1.1.10",
"@radix-ui/react-checkbox": "1.3.3",
@@ -118,7 +117,6 @@
},
"devDependencies": {
"@chromatic-com/storybook": "4.1.2",
"@opentelemetry/instrumentation": "0.209.0",
"@playwright/test": "1.56.1",
"@storybook/addon-a11y": "9.1.5",
"@storybook/addon-docs": "9.1.5",
@@ -142,7 +140,6 @@
"eslint": "8.57.1",
"eslint-config-next": "15.5.7",
"eslint-plugin-storybook": "9.1.5",
"import-in-the-middle": "2.0.2",
"msw": "2.11.6",
"msw-storybook-addon": "2.0.6",
"orval": "7.13.0",
@@ -150,7 +147,7 @@
"postcss": "8.5.6",
"prettier": "3.6.2",
"prettier-plugin-tailwindcss": "0.7.1",
"require-in-the-middle": "8.0.1",
"require-in-the-middle": "7.5.2",
"storybook": "9.1.5",
"tailwindcss": "3.4.17",
"typescript": "5.9.3"
@@ -160,10 +157,5 @@
"public"
]
},
"pnpm": {
"overrides": {
"@opentelemetry/instrumentation": "0.209.0"
}
},
"packageManager": "pnpm@10.20.0+sha512.cf9998222162dd85864d0a8102e7892e7ba4ceadebbf5a31f9c2fce48dfce317a9c53b9f6464d1ef9042cba2e02ae02a9f7c143a2b438cd93c91840f0192b9dd"
}

View File

@@ -4,9 +4,6 @@ settings:
autoInstallPeers: true
excludeLinksFromLockfile: false
overrides:
'@opentelemetry/instrumentation': 0.209.0
importers:
.:
@@ -23,9 +20,6 @@ importers:
'@phosphor-icons/react':
specifier: 2.1.10
version: 2.1.10(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
'@radix-ui/react-accordion':
specifier: 1.2.12
version: 1.2.12(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
'@radix-ui/react-alert-dialog':
specifier: 1.1.15
version: 1.1.15(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
@@ -276,9 +270,6 @@ importers:
'@chromatic-com/storybook':
specifier: 4.1.2
version: 4.1.2(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2))
'@opentelemetry/instrumentation':
specifier: 0.209.0
version: 0.209.0(@opentelemetry/api@1.9.0)
'@playwright/test':
specifier: 1.56.1
version: 1.56.1
@@ -348,9 +339,6 @@ importers:
eslint-plugin-storybook:
specifier: 9.1.5
version: 9.1.5(eslint@8.57.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2))(typescript@5.9.3)
import-in-the-middle:
specifier: 2.0.2
version: 2.0.2
msw:
specifier: 2.11.6
version: 2.11.6(@types/node@24.10.0)(typescript@5.9.3)
@@ -373,8 +361,8 @@ importers:
specifier: 0.7.1
version: 0.7.1(prettier@3.6.2)
require-in-the-middle:
specifier: 8.0.1
version: 8.0.1
specifier: 7.5.2
version: 7.5.2
storybook:
specifier: 9.1.5
version: 9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)
@@ -1555,8 +1543,8 @@ packages:
'@open-draft/until@2.1.0':
resolution: {integrity: sha512-U69T3ItWHvLwGg5eJ0n3I62nWuE6ilHlmz7zM0npLBRvPRd7e6NYmg54vvRtP5mZG7kZqZCFVdsTWo7BPtBujg==}
'@opentelemetry/api-logs@0.209.0':
resolution: {integrity: sha512-xomnUNi7TiAGtOgs0tb54LyrjRZLu9shJGGwkcN7NgtiPYOpNnKLkRJtzZvTjD/w6knSZH9sFZcUSUovYOPg6A==}
'@opentelemetry/api-logs@0.208.0':
resolution: {integrity: sha512-CjruKY9V6NMssL/T1kAFgzosF1v9o6oeN+aX5JB/C/xPNtmgIJqcXHG7fA82Ou1zCpWGl4lROQUKwUNE1pMCyg==}
engines: {node: '>=8.0.0'}
'@opentelemetry/api@1.9.0':
@@ -1707,8 +1695,8 @@ packages:
peerDependencies:
'@opentelemetry/api': ^1.7.0
'@opentelemetry/instrumentation@0.209.0':
resolution: {integrity: sha512-Cwe863ojTCnFlxVuuhG7s6ODkAOzKsAEthKAcI4MDRYz1OmGWYnmSl4X2pbyS+hBxVTdvfZePfoEA01IjqcEyw==}
'@opentelemetry/instrumentation@0.208.0':
resolution: {integrity: sha512-Eju0L4qWcQS+oXxi6pgh7zvE2byogAkcsVv0OjHF/97iOz1N/aKE6etSGowYkie+YA1uo6DNwdSxaaNnLvcRlA==}
engines: {node: ^18.19.0 || >=20.6.0}
peerDependencies:
'@opentelemetry/api': ^1.3.0
@@ -1822,19 +1810,6 @@ packages:
'@radix-ui/primitive@1.1.3':
resolution: {integrity: sha512-JTF99U/6XIjCBo0wqkU5sK10glYe27MRRsfwoiq5zzOEZLHU3A3KCMa5X/azekYRCJ0HlwI0crAXS/5dEHTzDg==}
'@radix-ui/react-accordion@1.2.12':
resolution: {integrity: sha512-T4nygeh9YE9dLRPhAHSeOZi7HBXo+0kYIPJXayZfvWOWA0+n3dESrZbjfDPUABkUNym6Hd+f2IR113To8D2GPA==}
peerDependencies:
'@types/react': '*'
'@types/react-dom': '*'
react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc
react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc
peerDependenciesMeta:
'@types/react':
optional: true
'@types/react-dom':
optional: true
'@radix-ui/react-alert-dialog@1.1.15':
resolution: {integrity: sha512-oTVLkEw5GpdRe29BqJ0LSDFWI3qu0vR1M0mUkOQWDIUnY/QIkLpgDMWuKxP94c2NAC2LGcgVhG1ImF3jkZ5wXw==}
peerDependencies:
@@ -2656,7 +2631,7 @@ packages:
'@opentelemetry/api': ^1.9.0
'@opentelemetry/context-async-hooks': ^1.30.1 || ^2.1.0 || ^2.2.0
'@opentelemetry/core': ^1.30.1 || ^2.1.0 || ^2.2.0
'@opentelemetry/instrumentation': 0.209.0
'@opentelemetry/instrumentation': '>=0.57.1 <1'
'@opentelemetry/resources': ^1.30.1 || ^2.1.0 || ^2.2.0
'@opentelemetry/sdk-trace-base': ^1.30.1 || ^2.1.0 || ^2.2.0
'@opentelemetry/semantic-conventions': ^1.37.0
@@ -4982,8 +4957,8 @@ packages:
resolution: {integrity: sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ==}
engines: {node: '>=6'}
import-in-the-middle@2.0.2:
resolution: {integrity: sha512-qet/hkGt3EbNGVtbDfPu0BM+tCqBS8wT1SYrstPaDKoWtshsC6licOemz7DVtpBEyvDNzo8UTBf9/GwWuSDZ9w==}
import-in-the-middle@2.0.1:
resolution: {integrity: sha512-bruMpJ7xz+9jwGzrwEhWgvRrlKRYCRDBrfU+ur3FcasYXLJDxTruJ//8g2Noj+QFyRBeqbpj8Bhn4Fbw6HjvhA==}
imurmurhash@0.1.4:
resolution: {integrity: sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==}
@@ -6527,6 +6502,10 @@ packages:
resolution: {integrity: sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==}
engines: {node: '>=0.10.0'}
require-in-the-middle@7.5.2:
resolution: {integrity: sha512-gAZ+kLqBdHarXB64XpAe2VCjB7rIRv+mU8tfRWziHRJ5umKsIHN2tLLv6EtMw7WCdP19S0ERVMldNvxYCHnhSQ==}
engines: {node: '>=8.6.0'}
require-in-the-middle@8.0.1:
resolution: {integrity: sha512-QT7FVMXfWOYFbeRBF6nu+I6tr2Tf3u0q8RIEjNob/heKY/nh7drD/k7eeMFmSQgnTtCzLDcCu/XEnpW2wk4xCQ==}
engines: {node: '>=9.3.0 || >=8.10.0 <9.0.0'}
@@ -8737,7 +8716,7 @@ snapshots:
'@open-draft/until@2.1.0': {}
'@opentelemetry/api-logs@0.209.0':
'@opentelemetry/api-logs@0.208.0':
dependencies:
'@opentelemetry/api': 1.9.0
@@ -8756,7 +8735,7 @@ snapshots:
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
transitivePeerDependencies:
- supports-color
@@ -8764,7 +8743,7 @@ snapshots:
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
'@opentelemetry/semantic-conventions': 1.38.0
'@types/connect': 3.4.38
transitivePeerDependencies:
@@ -8773,7 +8752,7 @@ snapshots:
'@opentelemetry/instrumentation-dataloader@0.26.0(@opentelemetry/api@1.9.0)':
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
transitivePeerDependencies:
- supports-color
@@ -8781,7 +8760,7 @@ snapshots:
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
'@opentelemetry/semantic-conventions': 1.38.0
transitivePeerDependencies:
- supports-color
@@ -8790,21 +8769,21 @@ snapshots:
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
transitivePeerDependencies:
- supports-color
'@opentelemetry/instrumentation-generic-pool@0.52.0(@opentelemetry/api@1.9.0)':
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
transitivePeerDependencies:
- supports-color
'@opentelemetry/instrumentation-graphql@0.56.0(@opentelemetry/api@1.9.0)':
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
transitivePeerDependencies:
- supports-color
@@ -8812,7 +8791,7 @@ snapshots:
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
'@opentelemetry/semantic-conventions': 1.38.0
transitivePeerDependencies:
- supports-color
@@ -8821,7 +8800,7 @@ snapshots:
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
'@opentelemetry/semantic-conventions': 1.38.0
forwarded-parse: 2.1.2
transitivePeerDependencies:
@@ -8830,7 +8809,7 @@ snapshots:
'@opentelemetry/instrumentation-ioredis@0.56.0(@opentelemetry/api@1.9.0)':
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
'@opentelemetry/redis-common': 0.38.2
transitivePeerDependencies:
- supports-color
@@ -8838,7 +8817,7 @@ snapshots:
'@opentelemetry/instrumentation-kafkajs@0.18.0(@opentelemetry/api@1.9.0)':
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
'@opentelemetry/semantic-conventions': 1.38.0
transitivePeerDependencies:
- supports-color
@@ -8846,7 +8825,7 @@ snapshots:
'@opentelemetry/instrumentation-knex@0.53.0(@opentelemetry/api@1.9.0)':
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
'@opentelemetry/semantic-conventions': 1.38.0
transitivePeerDependencies:
- supports-color
@@ -8855,7 +8834,7 @@ snapshots:
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
'@opentelemetry/semantic-conventions': 1.38.0
transitivePeerDependencies:
- supports-color
@@ -8863,14 +8842,14 @@ snapshots:
'@opentelemetry/instrumentation-lru-memoizer@0.53.0(@opentelemetry/api@1.9.0)':
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
transitivePeerDependencies:
- supports-color
'@opentelemetry/instrumentation-mongodb@0.61.0(@opentelemetry/api@1.9.0)':
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
transitivePeerDependencies:
- supports-color
@@ -8878,14 +8857,14 @@ snapshots:
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
transitivePeerDependencies:
- supports-color
'@opentelemetry/instrumentation-mysql2@0.55.0(@opentelemetry/api@1.9.0)':
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
'@opentelemetry/semantic-conventions': 1.38.0
'@opentelemetry/sql-common': 0.41.2(@opentelemetry/api@1.9.0)
transitivePeerDependencies:
@@ -8894,7 +8873,7 @@ snapshots:
'@opentelemetry/instrumentation-mysql@0.54.0(@opentelemetry/api@1.9.0)':
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
'@types/mysql': 2.15.27
transitivePeerDependencies:
- supports-color
@@ -8903,7 +8882,7 @@ snapshots:
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
'@opentelemetry/semantic-conventions': 1.38.0
'@opentelemetry/sql-common': 0.41.2(@opentelemetry/api@1.9.0)
'@types/pg': 8.15.6
@@ -8914,7 +8893,7 @@ snapshots:
'@opentelemetry/instrumentation-redis@0.57.0(@opentelemetry/api@1.9.0)':
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
'@opentelemetry/redis-common': 0.38.2
'@opentelemetry/semantic-conventions': 1.38.0
transitivePeerDependencies:
@@ -8923,7 +8902,7 @@ snapshots:
'@opentelemetry/instrumentation-tedious@0.27.0(@opentelemetry/api@1.9.0)':
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
'@types/tedious': 4.0.14
transitivePeerDependencies:
- supports-color
@@ -8932,16 +8911,16 @@ snapshots:
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
'@opentelemetry/semantic-conventions': 1.38.0
transitivePeerDependencies:
- supports-color
'@opentelemetry/instrumentation@0.209.0(@opentelemetry/api@1.9.0)':
'@opentelemetry/instrumentation@0.208.0(@opentelemetry/api@1.9.0)':
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/api-logs': 0.209.0
import-in-the-middle: 2.0.2
'@opentelemetry/api-logs': 0.208.0
import-in-the-middle: 2.0.1
require-in-the-middle: 8.0.1
transitivePeerDependencies:
- supports-color
@@ -9121,7 +9100,7 @@ snapshots:
'@prisma/instrumentation@6.19.0(@opentelemetry/api@1.9.0)':
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
transitivePeerDependencies:
- supports-color
@@ -9129,23 +9108,6 @@ snapshots:
'@radix-ui/primitive@1.1.3': {}
'@radix-ui/react-accordion@1.2.12(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
dependencies:
'@radix-ui/primitive': 1.1.3
'@radix-ui/react-collapsible': 1.1.12(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
'@radix-ui/react-collection': 1.1.7(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
'@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.17)(react@18.3.1)
'@radix-ui/react-context': 1.1.2(@types/react@18.3.17)(react@18.3.1)
'@radix-ui/react-direction': 1.1.1(@types/react@18.3.17)(react@18.3.1)
'@radix-ui/react-id': 1.1.1(@types/react@18.3.17)(react@18.3.1)
'@radix-ui/react-primitive': 2.1.3(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
'@radix-ui/react-use-controllable-state': 1.2.2(@types/react@18.3.17)(react@18.3.1)
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
optionalDependencies:
'@types/react': 18.3.17
'@types/react-dom': 18.3.5(@types/react@18.3.17)
'@radix-ui/react-alert-dialog@1.1.15(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
dependencies:
'@radix-ui/primitive': 1.1.3
@@ -9970,19 +9932,19 @@ snapshots:
- supports-color
- webpack
'@sentry/node-core@10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/instrumentation@0.209.0(@opentelemetry/api@1.9.0))(@opentelemetry/resources@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)':
'@sentry/node-core@10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/instrumentation@0.208.0(@opentelemetry/api@1.9.0))(@opentelemetry/resources@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)':
dependencies:
'@apm-js-collab/tracing-hooks': 0.3.1
'@opentelemetry/api': 1.9.0
'@opentelemetry/context-async-hooks': 2.2.0(@opentelemetry/api@1.9.0)
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
'@opentelemetry/resources': 2.2.0(@opentelemetry/api@1.9.0)
'@opentelemetry/sdk-trace-base': 2.2.0(@opentelemetry/api@1.9.0)
'@opentelemetry/semantic-conventions': 1.38.0
'@sentry/core': 10.27.0
'@sentry/opentelemetry': 10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)
import-in-the-middle: 2.0.2
import-in-the-middle: 2.0.1
transitivePeerDependencies:
- supports-color
@@ -9991,7 +9953,7 @@ snapshots:
'@opentelemetry/api': 1.9.0
'@opentelemetry/context-async-hooks': 2.2.0(@opentelemetry/api@1.9.0)
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation-amqplib': 0.55.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation-connect': 0.52.0(@opentelemetry/api@1.9.0)
'@opentelemetry/instrumentation-dataloader': 0.26.0(@opentelemetry/api@1.9.0)
@@ -10019,9 +9981,9 @@ snapshots:
'@opentelemetry/semantic-conventions': 1.38.0
'@prisma/instrumentation': 6.19.0(@opentelemetry/api@1.9.0)
'@sentry/core': 10.27.0
'@sentry/node-core': 10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/instrumentation@0.209.0(@opentelemetry/api@1.9.0))(@opentelemetry/resources@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)
'@sentry/node-core': 10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/instrumentation@0.208.0(@opentelemetry/api@1.9.0))(@opentelemetry/resources@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)
'@sentry/opentelemetry': 10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)
import-in-the-middle: 2.0.2
import-in-the-middle: 2.0.1
minimatch: 9.0.5
transitivePeerDependencies:
- supports-color
@@ -12830,7 +12792,7 @@ snapshots:
parent-module: 1.0.1
resolve-from: 4.0.0
import-in-the-middle@2.0.2:
import-in-the-middle@2.0.1:
dependencies:
acorn: 8.15.0
acorn-import-attributes: 1.9.5(acorn@8.15.0)
@@ -14669,6 +14631,14 @@ snapshots:
require-from-string@2.0.2: {}
require-in-the-middle@7.5.2:
dependencies:
debug: 4.4.3
module-details-from-path: 1.0.4
resolve: 1.22.11
transitivePeerDependencies:
- supports-color
require-in-the-middle@8.0.1:
dependencies:
debug: 4.4.3

View File

@@ -1,4 +1,4 @@
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput";
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs";
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
import { useState } from "react";

View File

@@ -1,22 +1,22 @@
"use client";
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput";
import { useGetOauthGetOauthAppInfo } from "@/app/api/__generated__/endpoints/oauth/oauth";
import { okData } from "@/app/api/helpers";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import Image from "next/image";
import Link from "next/link";
import { useSearchParams } from "next/navigation";
import { useState, useMemo, useRef } from "react";
import { AuthCard } from "@/components/auth/AuthCard";
import { Text } from "@/components/atoms/Text/Text";
import { Button } from "@/components/atoms/Button/Button";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs";
import type {
BlockIOCredentialsSubSchema,
CredentialsMetaInput,
CredentialsType,
} from "@/lib/autogpt-server-api";
import { CheckIcon, CircleIcon } from "@phosphor-icons/react";
import Image from "next/image";
import Link from "next/link";
import { useSearchParams } from "next/navigation";
import { useMemo, useRef, useState } from "react";
import { useGetOauthGetOauthAppInfo } from "@/app/api/__generated__/endpoints/oauth/oauth";
import { okData } from "@/app/api/helpers";
// All credential types - we accept any type of credential
const ALL_CREDENTIAL_TYPES: CredentialsType[] = [

View File

@@ -10,10 +10,7 @@ export const BuilderActions = memo(() => {
flowID: parseAsString,
});
return (
<div
data-id="builder-actions"
className="absolute bottom-4 left-[50%] z-[100] flex -translate-x-1/2 items-center gap-4 rounded-full bg-white p-2 px-2 shadow-lg"
>
<div className="absolute bottom-4 left-[50%] z-[100] flex -translate-x-1/2 items-center gap-4 rounded-full bg-white p-2 px-2 shadow-lg">
<AgentOutputs flowID={flowID} />
<RunGraph flowID={flowID} />
<ScheduleGraph flowID={flowID} />

View File

@@ -79,7 +79,6 @@ export const AgentOutputs = ({ flowID }: { flowID: string | null }) => {
<Button
variant="outline"
size="icon"
data-id="agent-outputs-button"
disabled={!flowID || !hasOutputs()}
>
<BookOpenIcon className="size-4" />

View File

@@ -31,7 +31,6 @@ export const RunGraph = ({ flowID }: { flowID: string | null }) => {
<Button
size="icon"
variant={isGraphRunning ? "destructive" : "primary"}
data-id={isGraphRunning ? "stop-graph-button" : "run-graph-button"}
onClick={isGraphRunning ? handleStopGraph : handleRunGraph}
disabled={!flowID || isExecutingGraph || isTerminatingGraph}
loading={isExecutingGraph || isTerminatingGraph || isSaving}

View File

@@ -7,11 +7,10 @@ import { parseAsInteger, parseAsString, useQueryStates } from "nuqs";
import { GraphExecutionMeta } from "@/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/use-agent-runs";
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
import { useShallow } from "zustand/react/shallow";
import { useEffect, useState } from "react";
import { useState } from "react";
import { useSaveGraph } from "@/app/(platform)/build/hooks/useSaveGraph";
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import { ApiError } from "@/lib/autogpt-server-api/helpers"; // Check if this exists
import { useTutorialStore } from "@/app/(platform)/build/stores/tutorialStore";
export const useRunGraph = () => {
const { saveGraph, isSaving } = useSaveGraph({
@@ -34,29 +33,6 @@ export const useRunGraph = () => {
useShallow((state) => state.clearAllNodeErrors),
);
// Tutorial integration - force open dialog when tutorial requests it
const forceOpenRunInputDialog = useTutorialStore(
(state) => state.forceOpenRunInputDialog,
);
const setForceOpenRunInputDialog = useTutorialStore(
(state) => state.setForceOpenRunInputDialog,
);
// Sync tutorial state with dialog state
useEffect(() => {
if (forceOpenRunInputDialog && !openRunInputDialog) {
setOpenRunInputDialog(true);
}
}, [forceOpenRunInputDialog, openRunInputDialog]);
// Reset tutorial state when dialog closes
const handleSetOpenRunInputDialog = (isOpen: boolean) => {
setOpenRunInputDialog(isOpen);
if (!isOpen && forceOpenRunInputDialog) {
setForceOpenRunInputDialog(false);
}
};
const [{ flowID, flowVersion, flowExecutionID }, setQueryStates] =
useQueryStates({
flowID: parseAsString,
@@ -162,6 +138,6 @@ export const useRunGraph = () => {
isExecutingGraph,
isTerminatingGraph,
openRunInputDialog,
setOpenRunInputDialog: handleSetOpenRunInputDialog,
setOpenRunInputDialog,
};
};


@@ -8,8 +8,6 @@ import { Text } from "@/components/atoms/Text/Text";
import { FormRenderer } from "@/components/renderers/InputRenderer/FormRenderer";
import { useRunInputDialog } from "./useRunInputDialog";
import { CronSchedulerDialog } from "../CronSchedulerDialog/CronSchedulerDialog";
import { useTutorialStore } from "@/app/(platform)/build/stores/tutorialStore";
import { useEffect } from "react";
export const RunInputDialog = ({
isOpen,
@@ -39,21 +37,6 @@ export const RunInputDialog = ({
isExecutingGraph,
} = useRunInputDialog({ setIsOpen });
// Tutorial integration - track input values for the tutorial
const setTutorialInputValues = useTutorialStore(
(state) => state.setTutorialInputValues,
);
const isTutorialRunning = useTutorialStore(
(state) => state.isTutorialRunning,
);
// Update tutorial store when input values change
useEffect(() => {
if (isTutorialRunning) {
setTutorialInputValues(inputValues);
}
}, [inputValues, isTutorialRunning, setTutorialInputValues]);
return (
<>
<Dialog
@@ -65,16 +48,16 @@ export const RunInputDialog = ({
styling={{ maxWidth: "600px", minWidth: "600px" }}
>
<Dialog.Content>
<div className="space-y-6 p-1" data-id="run-input-dialog-content">
<div className="space-y-6 p-1">
{/* Credentials Section */}
{hasCredentials() && (
<div data-id="run-input-credentials-section">
<div>
<div className="mb-4">
<Text variant="h4" className="text-gray-900">
Credentials
</Text>
</div>
<div className="px-2" data-id="run-input-credentials-form">
<div className="px-2">
<FormRenderer
jsonSchema={credentialsSchema as RJSFSchema}
handleChange={(v) => handleCredentialChange(v.formData)}
@@ -92,13 +75,13 @@ export const RunInputDialog = ({
{/* Inputs Section */}
{hasInputs() && (
<div data-id="run-input-inputs-section">
<div>
<div className="mb-4">
<Text variant="h4" className="text-gray-900">
Inputs
</Text>
</div>
<div data-id="run-input-inputs-form">
<div className="px-2">
<FormRenderer
jsonSchema={inputSchema as RJSFSchema}
handleChange={(v) => handleInputChange(v.formData)}
@@ -114,10 +97,7 @@ export const RunInputDialog = ({
)}
{/* Action Button */}
<div
className="flex justify-end pt-2"
data-id="run-input-actions-section"
>
<div className="flex justify-end pt-2">
{purpose === "run" && (
<Button
variant="primary"
@@ -125,7 +105,6 @@ export const RunInputDialog = ({
className="group h-fit min-w-0 gap-2"
onClick={handleManualRun}
loading={isExecutingGraph}
data-id="run-input-manual-run-button"
>
{!isExecutingGraph && (
<PlayIcon className="size-5 transition-transform group-hover:scale-110" />
@@ -139,7 +118,6 @@ export const RunInputDialog = ({
size="large"
className="group h-fit min-w-0 gap-2"
onClick={() => setOpenCronSchedulerDialog(true)}
data-id="run-input-schedule-button"
>
<ClockIcon className="size-5 transition-transform group-hover:scale-110" />
<span className="font-semibold">Schedule Run</span>


@@ -1,8 +1,7 @@
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
import { usePostV1ExecuteGraphAgent } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { useToast } from "@/components/molecules/Toast/use-toast";
import {
ApiError,
CredentialsMetaInput,
GraphExecutionMeta,
} from "@/lib/autogpt-server-api";
@@ -10,9 +9,6 @@ import { parseAsInteger, parseAsString, useQueryStates } from "nuqs";
import { useMemo, useState } from "react";
import { uiSchema } from "../../../FlowEditor/nodes/uiSchema";
import { isCredentialFieldSchema } from "@/components/renderers/InputRenderer/custom/CredentialField/helpers";
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { useReactFlow } from "@xyflow/react";
export const useRunInputDialog = ({
setIsOpen,
@@ -35,7 +31,6 @@ export const useRunInputDialog = ({
flowVersion: parseAsInteger,
});
const { toast } = useToast();
const { setViewport } = useReactFlow();
const { mutateAsync: executeGraph, isPending: isExecutingGraph } =
usePostV1ExecuteGraphAgent({
@@ -47,63 +42,13 @@ export const useRunInputDialog = ({
});
},
onError: (error) => {
if (error instanceof ApiError && error.isGraphValidationError()) {
const errorData = error.response?.detail;
Object.entries(errorData.node_errors).forEach(
([nodeId, nodeErrors]) => {
useNodeStore
.getState()
.updateNodeErrors(
nodeId,
nodeErrors as { [key: string]: string },
);
},
);
toast({
title: errorData?.message || "Graph validation failed",
description:
"Please fix the validation errors on the highlighted nodes and try again.",
variant: "destructive",
});
setIsOpen(false);
const firstBackendId = Object.keys(errorData.node_errors)[0];
if (firstBackendId) {
const firstErrorNode = useNodeStore
.getState()
.nodes.find(
(n) =>
n.data.metadata?.backend_id === firstBackendId ||
n.id === firstBackendId,
);
if (firstErrorNode) {
setTimeout(() => {
setViewport(
{
x:
-firstErrorNode.position.x * 0.8 +
window.innerWidth / 2 -
150,
y: -firstErrorNode.position.y * 0.8 + 50,
zoom: 0.8,
},
{ duration: 500 },
);
}, 50);
}
}
} else {
toast({
title: "Error running graph",
description:
(error as Error).message || "An unexpected error occurred.",
variant: "destructive",
});
setIsOpen(false);
}
// Reset running state on error
setIsGraphRunning(false);
toast({
title: (error.detail as string) ?? "An unexpected error occurred.",
description: "An unexpected error occurred.",
variant: "destructive",
});
},
},
});


@@ -26,7 +26,6 @@ export const ScheduleGraph = ({ flowID }: { flowID: string | null }) => {
<Button
variant="outline"
size="icon"
data-id="schedule-graph-button"
onClick={handleScheduleGraph}
disabled={!flowID}
>


@@ -6,17 +6,12 @@ import {
TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
import {
ChalkboardIcon,
CircleNotchIcon,
FrameCornersIcon,
MinusIcon,
PlusIcon,
} from "@phosphor-icons/react/dist/ssr";
import { LockIcon, LockOpenIcon } from "lucide-react";
import { memo, useEffect, useState } from "react";
import { useSearchParams, useRouter } from "next/navigation";
import { useTutorialStore } from "@/app/(platform)/build/stores/tutorialStore";
import { startTutorial, setTutorialLoadingCallback } from "../../tutorial";
import { memo } from "react";
export const CustomControls = memo(
({
@@ -27,65 +22,27 @@ export const CustomControls = memo(
setIsLocked: (isLocked: boolean) => void;
}) => {
const { zoomIn, zoomOut, fitView } = useReactFlow();
const { isTutorialRunning, setIsTutorialRunning } = useTutorialStore();
const [isTutorialLoading, setIsTutorialLoading] = useState(false);
const searchParams = useSearchParams();
const router = useRouter();
useEffect(() => {
setTutorialLoadingCallback(setIsTutorialLoading);
return () => setTutorialLoadingCallback(() => {});
}, []);
const handleTutorialClick = () => {
if (isTutorialLoading) return;
const flowId = searchParams.get("flowID");
if (flowId) {
router.push("/build?view=new");
return;
}
startTutorial();
setIsTutorialRunning(true);
};
const controls = [
{
id: "zoom-in-button",
icon: <PlusIcon className="size-4" />,
label: "Zoom In",
onClick: () => zoomIn(),
className: "h-10 w-10 border-none",
},
{
id: "zoom-out-button",
icon: <MinusIcon className="size-4" />,
label: "Zoom Out",
onClick: () => zoomOut(),
className: "h-10 w-10 border-none",
},
{
id: "tutorial-button",
icon: isTutorialLoading ? (
<CircleNotchIcon className="size-4 animate-spin" />
) : (
<ChalkboardIcon className="size-4" />
),
label: isTutorialLoading ? "Loading Tutorial..." : "Start Tutorial",
onClick: handleTutorialClick,
className: `h-10 w-10 border-none ${isTutorialRunning || isTutorialLoading ? "bg-zinc-100" : "bg-white"}`,
disabled: isTutorialLoading,
},
{
id: "fit-view-button",
icon: <FrameCornersIcon className="size-4" />,
label: "Fit View",
onClick: () => fitView({ padding: 0.2, duration: 800, maxZoom: 1 }),
className: "h-10 w-10 border-none",
},
{
id: "lock-button",
icon: !isLocked ? (
<LockOpenIcon className="size-4" />
) : (
@@ -98,20 +55,15 @@ export const CustomControls = memo(
];
return (
<div
data-id="custom-controls"
className="absolute bottom-4 left-4 z-10 flex flex-col items-center gap-2 rounded-full bg-white px-1 py-2 shadow-lg"
>
{controls.map((control) => (
<Tooltip key={control.id} delayDuration={0}>
<div className="absolute bottom-4 left-4 z-10 flex flex-col items-center gap-2 rounded-full bg-white px-1 py-2 shadow-lg">
{controls.map((control, index) => (
<Tooltip key={index} delayDuration={300}>
<TooltipTrigger asChild>
<Button
variant="icon"
size={"small"}
onClick={control.onClick}
className={control.className}
data-id={control.id}
disabled={"disabled" in control ? control.disabled : false}
>
{control.icon}
<span className="sr-only">{control.label}</span>


@@ -3,7 +3,6 @@ import { useGetV2GetSpecificBlocks } from "@/app/api/__generated__/endpoints/def
import {
useGetV1GetExecutionDetails,
useGetV1GetSpecificGraph,
useGetV1ListUserGraphs,
} from "@/app/api/__generated__/endpoints/graphs/graphs";
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
@@ -18,7 +17,6 @@ import { useReactFlow } from "@xyflow/react";
import { useControlPanelStore } from "../../../stores/controlPanelStore";
import { useHistoryStore } from "../../../stores/historyStore";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import { okData } from "@/app/api/helpers";
export const useFlow = () => {
const [isLocked, setIsLocked] = useState(false);
@@ -38,9 +36,6 @@ export const useFlow = () => {
const setGraphExecutionStatus = useGraphStore(
useShallow((state) => state.setGraphExecutionStatus),
);
const setAvailableSubGraphs = useGraphStore(
useShallow((state) => state.setAvailableSubGraphs),
);
const updateEdgeBeads = useEdgeStore(
useShallow((state) => state.updateEdgeBeads),
);
@@ -67,11 +62,6 @@ export const useFlow = () => {
},
);
// Fetch all available graphs for sub-agent update detection
const { data: availableGraphs } = useGetV1ListUserGraphs({
query: { select: okData },
});
const { data: graph, isLoading: isGraphLoading } = useGetV1GetSpecificGraph(
flowID ?? "",
flowVersion !== null ? { version: flowVersion } : {},
@@ -126,18 +116,10 @@ export const useFlow = () => {
}
}, [graph]);
// Update available sub-graphs in store for sub-agent update detection
useEffect(() => {
if (availableGraphs) {
setAvailableSubGraphs(availableGraphs);
}
}, [availableGraphs, setAvailableSubGraphs]);
// adding nodes
useEffect(() => {
if (customNodes.length > 0) {
useNodeStore.getState().setNodes([]);
useNodeStore.getState().clearResolutionState();
addNodes(customNodes);
// Sync hardcoded values with handle IDs.
@@ -221,7 +203,6 @@ export const useFlow = () => {
useEffect(() => {
return () => {
useNodeStore.getState().setNodes([]);
useNodeStore.getState().clearResolutionState();
useEdgeStore.getState().setEdges([]);
useGraphStore.getState().reset();
useEdgeStore.getState().resetEdgeBeads();


@@ -8,7 +8,6 @@ import {
getBezierPath,
} from "@xyflow/react";
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import { XIcon } from "@phosphor-icons/react";
import { cn } from "@/lib/utils";
import { NodeExecutionResult } from "@/lib/autogpt-server-api";
@@ -36,8 +35,6 @@ const CustomEdge = ({
selected,
}: EdgeProps<CustomEdge>) => {
const removeConnection = useEdgeStore((state) => state.removeEdge);
// Subscribe to the brokenEdgeIDs map and check if this edge is broken across any node
const isBroken = useNodeStore((state) => state.isEdgeBroken(id));
const [isHovered, setIsHovered] = useState(false);
const [edgePath, labelX, labelY] = getBezierPath({
@@ -53,12 +50,6 @@ const CustomEdge = ({
const beadUp = data?.beadUp ?? 0;
const beadDown = data?.beadDown ?? 0;
const handleRemoveEdge = () => {
removeConnection(id);
// Note: broken edge tracking is cleaned up automatically by useSubAgentUpdateState
// when it detects the edge no longer exists
};
return (
<>
<BaseEdge
@@ -66,11 +57,9 @@ const CustomEdge = ({
markerEnd={markerEnd}
className={cn(
isStatic && "!stroke-[1.5px] [stroke-dasharray:6]",
isBroken
? "!stroke-red-500 !stroke-[2px] [stroke-dasharray:4]"
: selected
? "stroke-zinc-800"
: "stroke-zinc-500/50 hover:stroke-zinc-500",
selected
? "stroke-zinc-800"
: "stroke-zinc-500/50 hover:stroke-zinc-500",
)}
/>
<JSBeads
@@ -81,16 +70,12 @@ const CustomEdge = ({
/>
<EdgeLabelRenderer>
<Button
onClick={handleRemoveEdge}
onClick={() => removeConnection(id)}
className={cn(
"absolute h-fit min-w-0 p-1 transition-opacity",
isBroken
? "bg-red-500 opacity-100 hover:bg-red-600"
: isHovered
? "opacity-100"
: "opacity-0",
isHovered ? "opacity-100" : "opacity-0",
)}
variant={isBroken ? "primary" : "secondary"}
variant="secondary"
style={{
transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
pointerEvents: "all",


@@ -3,7 +3,6 @@ import { Handle, Position } from "@xyflow/react";
import { useEdgeStore } from "../../../stores/edgeStore";
import { cleanUpHandleId } from "@/components/renderers/InputRenderer/helpers";
import { cn } from "@/lib/utils";
import { useNodeStore } from "../../../stores/nodeStore";
const InputNodeHandle = ({
handleId,
@@ -16,9 +15,6 @@ const InputNodeHandle = ({
const isInputConnected = useEdgeStore((state) =>
state.isInputConnected(nodeId ?? "", cleanedHandleId),
);
const isInputBroken = useNodeStore((state) =>
state.isInputBroken(nodeId, cleanedHandleId),
);
return (
<Handle
@@ -26,16 +22,12 @@ const InputNodeHandle = ({
position={Position.Left}
id={cleanedHandleId}
className={"-ml-6 mr-2"}
data-tutorial-id={`input-handler-${nodeId}-${cleanedHandleId}`}
>
<div className="pointer-events-none">
<CircleIcon
size={16}
weight={isInputConnected ? "fill" : "duotone"}
className={cn(
"text-gray-400 opacity-100",
isInputBroken && "text-red-500",
)}
className={"text-gray-400 opacity-100"}
/>
</div>
</Handle>
@@ -46,34 +38,27 @@ const OutputNodeHandle = ({
field_name,
nodeId,
hexColor,
isBroken,
}: {
field_name: string;
nodeId: string;
hexColor: string;
isBroken: boolean;
}) => {
const isOutputConnected = useEdgeStore((state) =>
state.isOutputConnected(nodeId, field_name),
);
return (
<Handle
type={"source"}
position={Position.Right}
id={field_name}
className={"-mr-2 ml-2"}
data-tutorial-id={`output-handler-${nodeId}-${field_name}`}
>
<div className="pointer-events-none">
<CircleIcon
size={16}
weight={"duotone"}
color={isOutputConnected ? hexColor : "gray"}
className={cn(
"text-gray-400 opacity-100",
isBroken && "text-red-500",
)}
className={cn("text-gray-400 opacity-100")}
/>
</div>
</Handle>

Some files were not shown because too many files have changed in this diff.