mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-17 02:58:01 -05:00
Compare commits
14 Commits
claude-cod
...
pwuts/spli
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
162c6b1224 | ||
|
|
b08851f5d7 | ||
|
|
8b1720e61d | ||
|
|
aa5a039c5e | ||
|
|
8b83bb8647 | ||
|
|
e80e4d9cbb | ||
|
|
375d33cca9 | ||
|
|
3b1b2fe30c | ||
|
|
af63b3678e | ||
|
|
631f1bd50a | ||
|
|
5ac941fe2f | ||
|
|
b01ea3fcbd | ||
|
|
3b09a94e3f | ||
|
|
61efee4139 |
@@ -1,6 +1,9 @@
|
||||
# Ignore everything by default, selectively add things to context
|
||||
*
|
||||
|
||||
# Documentation (for embeddings/search)
|
||||
!docs/
|
||||
|
||||
# Platform - Libs
|
||||
!autogpt_platform/autogpt_libs/autogpt_libs/
|
||||
!autogpt_platform/autogpt_libs/pyproject.toml
|
||||
|
||||
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -176,7 +176,7 @@ jobs:
|
||||
}
|
||||
|
||||
- name: Run Database Migrations
|
||||
run: poetry run prisma migrate dev --name updates
|
||||
run: poetry run prisma migrate deploy
|
||||
env:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
25
.github/workflows/platform-frontend-ci.yml
vendored
25
.github/workflows/platform-frontend-ci.yml
vendored
@@ -11,6 +11,7 @@ on:
|
||||
- ".github/workflows/platform-frontend-ci.yml"
|
||||
- "autogpt_platform/frontend/**"
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event_name == 'merge_group' && format('merge-queue-{0}', github.ref) || format('{0}-{1}', github.ref, github.event.pull_request.number || github.sha) }}
|
||||
@@ -151,6 +152,14 @@ jobs:
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env and set OpenAI API key
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
@@ -226,13 +235,25 @@ jobs:
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload Playwright artifacts
|
||||
if: failure()
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Upload Playwright test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -178,4 +178,5 @@ autogpt_platform/backend/settings.py
|
||||
*.ign.*
|
||||
.test-contents
|
||||
.claude/settings.local.json
|
||||
CLAUDE.local.md
|
||||
/autogpt_platform/backend/logs
|
||||
|
||||
@@ -6,152 +6,30 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
||||
|
||||
AutoGPT Platform is a monorepo containing:
|
||||
|
||||
- **Backend** (`/backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`/frontend`): Next.js React application
|
||||
- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
|
||||
- **Backend** (`backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`frontend`): Next.js React application
|
||||
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
|
||||
|
||||
## Essential Commands
|
||||
## Component Documentation
|
||||
|
||||
### Backend Development
|
||||
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
|
||||
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd backend && poetry install
|
||||
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend server
|
||||
poetry run serve
|
||||
|
||||
# Run tests
|
||||
poetry run test
|
||||
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in TESTING.md
|
||||
|
||||
#### Creating/Updating Snapshots
|
||||
|
||||
When you first write a test or when the expected output changes:
|
||||
|
||||
```bash
|
||||
poetry run pytest path/to/test.py --snapshot-update
|
||||
```
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
### Frontend Development
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd frontend && pnpm i
|
||||
|
||||
# Generate API client from OpenAPI spec
|
||||
pnpm generate:api
|
||||
|
||||
# Start development server
|
||||
pnpm dev
|
||||
|
||||
# Run E2E tests
|
||||
pnpm test
|
||||
|
||||
# Run Storybook for component development
|
||||
pnpm storybook
|
||||
|
||||
# Build production
|
||||
pnpm build
|
||||
|
||||
# Format and lint
|
||||
pnpm format
|
||||
|
||||
# Type checking
|
||||
pnpm types
|
||||
```
|
||||
|
||||
**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.
|
||||
|
||||
**Key Frontend Conventions:**
|
||||
|
||||
- Separate render logic from data/behavior in components
|
||||
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||
- Use function declarations (not arrow functions) for components/handlers
|
||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||
- Only use Phosphor Icons
|
||||
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Backend Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
### Frontend Architecture
|
||||
|
||||
- **Framework**: Next.js 15 App Router (client-first approach)
|
||||
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
|
||||
- **State Management**: React Query for server state, co-located UI state in components/hooks
|
||||
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
|
||||
- **Workflow Builder**: Visual graph editor using @xyflow/react
|
||||
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
|
||||
- **Icons**: Phosphor Icons only
|
||||
- **Feature Flags**: LaunchDarkly integration
|
||||
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
|
||||
- **Testing**: Playwright for E2E, Storybook for component development
|
||||
|
||||
### Key Concepts
|
||||
## Key Concepts
|
||||
|
||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
|
||||
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
|
||||
3. **Integrations**: OAuth and API connections stored per user
|
||||
4. **Store**: Marketplace for sharing agent templates
|
||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||
|
||||
### Testing Approach
|
||||
|
||||
- Backend uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Frontend uses Playwright for E2E tests
|
||||
- Component testing via Storybook
|
||||
|
||||
### Database Schema
|
||||
|
||||
Key models (defined in `/backend/schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
- `AgentNode`: Individual nodes in a workflow
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
#### Configuration Files
|
||||
|
||||
- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
|
||||
- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
|
||||
- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
|
||||
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
|
||||
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
|
||||
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
|
||||
|
||||
#### Docker Environment Loading Order
|
||||
|
||||
@@ -167,75 +45,12 @@ Key models (defined in `/backend/schema.prisma`):
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||
|
||||
### Common Development Tasks
|
||||
|
||||
**Adding a new block:**
|
||||
|
||||
Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Quick steps:
|
||||
|
||||
1. Create new file in `/backend/backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||
|
||||
**Modifying the API:**
|
||||
|
||||
1. Update route in `/backend/backend/server/routers/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
**Frontend feature development:**
|
||||
|
||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
|
||||
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
|
||||
- Add `usePageName.ts` hook for logic
|
||||
- Put sub-components in local `components/` folder
|
||||
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
|
||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||
- Never use `src/components/__legacy__/*`
|
||||
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||
- Regenerate with `pnpm generate:api`
|
||||
- Pattern: `use{Method}{Version}{OperationName}`
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||
|
||||
### Security Implementation
|
||||
|
||||
**Cache Protection Middleware:**
|
||||
|
||||
- Located in `/backend/backend/server/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR aginst the `dev` branch of the repository.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)/
|
||||
- Use conventional commit messages (see below)/
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description/
|
||||
- Create the PR against the `dev` branch of the repository.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
||||
- Use conventional commit messages (see below)
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
||||
- Run the github pre-commit hooks to ensure code quality.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
@@ -6,9 +6,10 @@ start-core:
|
||||
|
||||
# Stop core services
|
||||
stop-core:
|
||||
docker compose stop deps
|
||||
docker compose stop
|
||||
|
||||
reset-db:
|
||||
docker compose stop db
|
||||
rm -rf db/docker/volumes/db/data
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
@@ -60,4 +61,4 @@ help:
|
||||
@echo " run-backend - Run the backend FastAPI server"
|
||||
@echo " run-frontend - Run the frontend Next.js development server"
|
||||
@echo " test-data - Run the test data creator"
|
||||
@echo " load-store-agents - Load store agents from agents/ folder into test database"
|
||||
@echo " load-store-agents - Load store agents from agents/ folder into test database"
|
||||
|
||||
@@ -58,6 +58,13 @@ V0_API_KEY=
|
||||
OPEN_ROUTER_API_KEY=
|
||||
NVIDIA_API_KEY=
|
||||
|
||||
# Langfuse Prompt Management
|
||||
# Used for managing the CoPilot system prompt externally
|
||||
# Get credentials from https://cloud.langfuse.com or your self-hosted instance
|
||||
LANGFUSE_PUBLIC_KEY=
|
||||
LANGFUSE_SECRET_KEY=
|
||||
LANGFUSE_HOST=https://cloud.langfuse.com
|
||||
|
||||
# OAuth Credentials
|
||||
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
|
||||
# e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||
|
||||
1
autogpt_platform/backend/.gitignore
vendored
1
autogpt_platform/backend/.gitignore
vendored
@@ -18,3 +18,4 @@ load-tests/results/
|
||||
load-tests/*.json
|
||||
load-tests/*.log
|
||||
load-tests/node_modules/*
|
||||
migrations/*/rollback*.sql
|
||||
|
||||
124
autogpt_platform/backend/CLAUDE.md
Normal file
124
autogpt_platform/backend/CLAUDE.md
Normal file
@@ -0,0 +1,124 @@
|
||||
# CLAUDE.md - Backend
|
||||
|
||||
This file provides guidance to Claude Code when working with the backend.
|
||||
|
||||
## Essential Commands
|
||||
|
||||
To run something with Python package dependencies you MUST use `poetry run ...`.
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd backend && poetry install
|
||||
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend as a whole
|
||||
poetry run app
|
||||
|
||||
# Run tests
|
||||
poetry run test
|
||||
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in @TESTING.md
|
||||
|
||||
### Creating/Updating Snapshots
|
||||
|
||||
When you first write a test or when the expected output changes:
|
||||
|
||||
```bash
|
||||
poetry run pytest path/to/test.py --snapshot-update
|
||||
```
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
## Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- Uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
|
||||
## Database Schema
|
||||
|
||||
Key models (defined in `schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
- `AgentNode`: Individual nodes in a workflow
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a new block
|
||||
|
||||
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Quick steps:
|
||||
|
||||
1. Create new file in `backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||
|
||||
### Modifying the API
|
||||
|
||||
1. Update route in `backend/api/features/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
## Security Implementation
|
||||
|
||||
### Cache Protection Middleware
|
||||
|
||||
- Located in `backend/server/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
@@ -100,6 +100,7 @@ COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migration
|
||||
FROM server_dependencies AS server
|
||||
|
||||
COPY autogpt_platform/backend /app/autogpt_platform/backend
|
||||
COPY docs /app/docs
|
||||
RUN poetry install --no-ansi --only-root
|
||||
|
||||
ENV PORT=8000
|
||||
|
||||
@@ -70,7 +70,7 @@ class RunAgentRequest(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
def _create_ephemeral_session(user_id: str | None) -> ChatSession:
|
||||
def _create_ephemeral_session(user_id: str) -> ChatSession:
|
||||
"""Create an ephemeral session for stateless API requests."""
|
||||
return ChatSession.new(user_id)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Configuration management for chat system."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
@@ -12,7 +11,11 @@ class ChatConfig(BaseSettings):
|
||||
|
||||
# OpenAI API Configuration
|
||||
model: str = Field(
|
||||
default="qwen/qwen3-235b-a22b-2507", description="Default model to use"
|
||||
default="anthropic/claude-opus-4.5", description="Default model to use"
|
||||
)
|
||||
title_model: str = Field(
|
||||
default="openai/gpt-4o-mini",
|
||||
description="Model to use for generating session titles (should be fast/cheap)",
|
||||
)
|
||||
api_key: str | None = Field(default=None, description="OpenAI API key")
|
||||
base_url: str | None = Field(
|
||||
@@ -23,12 +26,6 @@ class ChatConfig(BaseSettings):
|
||||
# Session TTL Configuration - 12 hours
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
# System Prompt Configuration
|
||||
system_prompt_path: str = Field(
|
||||
default="prompts/chat_system.md",
|
||||
description="Path to system prompt file relative to chat module",
|
||||
)
|
||||
|
||||
# Streaming Configuration
|
||||
max_context_messages: int = Field(
|
||||
default=50, ge=1, le=200, description="Maximum context messages"
|
||||
@@ -41,6 +38,13 @@ class ChatConfig(BaseSettings):
|
||||
default=3, description="Maximum number of agent schedules"
|
||||
)
|
||||
|
||||
# Langfuse Prompt Management Configuration
|
||||
# Note: Langfuse credentials are in Settings().secrets (settings.py)
|
||||
langfuse_prompt_name: str = Field(
|
||||
default="CoPilot Prompt",
|
||||
description="Name of the prompt in Langfuse to fetch",
|
||||
)
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@classmethod
|
||||
def get_api_key(cls, v):
|
||||
@@ -72,43 +76,11 @@ class ChatConfig(BaseSettings):
|
||||
v = "https://openrouter.ai/api/v1"
|
||||
return v
|
||||
|
||||
def get_system_prompt(self, **template_vars) -> str:
|
||||
"""Load and render the system prompt from file.
|
||||
|
||||
Args:
|
||||
**template_vars: Variables to substitute in the template
|
||||
|
||||
Returns:
|
||||
Rendered system prompt string
|
||||
|
||||
"""
|
||||
# Get the path relative to this module
|
||||
module_dir = Path(__file__).parent
|
||||
prompt_path = module_dir / self.system_prompt_path
|
||||
|
||||
# Check for .j2 extension first (Jinja2 template)
|
||||
j2_path = Path(str(prompt_path) + ".j2")
|
||||
if j2_path.exists():
|
||||
try:
|
||||
from jinja2 import Template
|
||||
|
||||
template = Template(j2_path.read_text())
|
||||
return template.render(**template_vars)
|
||||
except ImportError:
|
||||
# Jinja2 not installed, fall back to reading as plain text
|
||||
return j2_path.read_text()
|
||||
|
||||
# Check for markdown file
|
||||
if prompt_path.exists():
|
||||
content = prompt_path.read_text()
|
||||
|
||||
# Simple variable substitution if Jinja2 is not available
|
||||
for key, value in template_vars.items():
|
||||
placeholder = f"{{{key}}}"
|
||||
content = content.replace(placeholder, str(value))
|
||||
|
||||
return content
|
||||
raise FileNotFoundError(f"System prompt file not found: {prompt_path}")
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
"onboarding": "prompts/onboarding_system.md",
|
||||
}
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
249
autogpt_platform/backend/backend/api/features/chat/db.py
Normal file
249
autogpt_platform/backend/backend/api/features/chat/db.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""Database operations for chat sessions."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import (
|
||||
ChatMessageCreateInput,
|
||||
ChatSessionCreateInput,
|
||||
ChatSessionUpdateInput,
|
||||
ChatSessionWhereInput,
|
||||
)
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
|
||||
"""Get a chat session by ID from the database."""
|
||||
session = await PrismaChatSession.prisma().find_unique(
|
||||
where={"id": session_id},
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
# Sort messages by sequence in Python - Prisma Python client doesn't support
|
||||
# order_by in include clauses (unlike Prisma JS), so we sort after fetching
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
) -> PrismaChatSession:
|
||||
"""Create a new chat session in the database."""
|
||||
data = ChatSessionCreateInput(
|
||||
id=session_id,
|
||||
userId=user_id,
|
||||
credentials=SafeJson({}),
|
||||
successfulAgentRuns=SafeJson({}),
|
||||
successfulAgentSchedules=SafeJson({}),
|
||||
)
|
||||
return await PrismaChatSession.prisma().create(
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
)
|
||||
|
||||
|
||||
async def update_chat_session(
|
||||
session_id: str,
|
||||
credentials: dict[str, Any] | None = None,
|
||||
successful_agent_runs: dict[str, Any] | None = None,
|
||||
successful_agent_schedules: dict[str, Any] | None = None,
|
||||
total_prompt_tokens: int | None = None,
|
||||
total_completion_tokens: int | None = None,
|
||||
title: str | None = None,
|
||||
) -> PrismaChatSession | None:
|
||||
"""Update a chat session's metadata."""
|
||||
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
|
||||
|
||||
if credentials is not None:
|
||||
data["credentials"] = SafeJson(credentials)
|
||||
if successful_agent_runs is not None:
|
||||
data["successfulAgentRuns"] = SafeJson(successful_agent_runs)
|
||||
if successful_agent_schedules is not None:
|
||||
data["successfulAgentSchedules"] = SafeJson(successful_agent_schedules)
|
||||
if total_prompt_tokens is not None:
|
||||
data["totalPromptTokens"] = total_prompt_tokens
|
||||
if total_completion_tokens is not None:
|
||||
data["totalCompletionTokens"] = total_completion_tokens
|
||||
if title is not None:
|
||||
data["title"] = title
|
||||
|
||||
session = await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
# Sort in Python - Prisma Python doesn't support order_by in include clauses
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
|
||||
async def add_chat_message(
|
||||
session_id: str,
|
||||
role: str,
|
||||
sequence: int,
|
||||
content: str | None = None,
|
||||
name: str | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
refusal: str | None = None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
function_call: dict[str, Any] | None = None,
|
||||
) -> PrismaChatMessage:
|
||||
"""Add a message to a chat session."""
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput directly
|
||||
# because Prisma's TypedDict validation rejects optional fields set to None.
|
||||
# We only include fields that have values, then cast at the end.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": role,
|
||||
"sequence": sequence,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if content is not None:
|
||||
data["content"] = content
|
||||
if name is not None:
|
||||
data["name"] = name
|
||||
if tool_call_id is not None:
|
||||
data["toolCallId"] = tool_call_id
|
||||
if refusal is not None:
|
||||
data["refusal"] = refusal
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if tool_calls is not None:
|
||||
data["toolCalls"] = SafeJson(tool_calls)
|
||||
if function_call is not None:
|
||||
data["functionCall"] = SafeJson(function_call)
|
||||
|
||||
# Run message create and session timestamp update in parallel for lower latency
|
||||
_, message = await asyncio.gather(
|
||||
PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
),
|
||||
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
async def add_chat_messages_batch(
|
||||
session_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
start_sequence: int,
|
||||
) -> list[PrismaChatMessage]:
|
||||
"""Add multiple messages to a chat session in a batch.
|
||||
|
||||
Uses a transaction for atomicity - if any message creation fails,
|
||||
the entire batch is rolled back.
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
created_messages = []
|
||||
|
||||
async with transaction() as tx:
|
||||
for i, msg in enumerate(messages):
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput
|
||||
# directly because Prisma's TypedDict validation rejects optional fields
|
||||
# set to None. We only include fields that have values, then cast.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
created = await PrismaChatMessage.prisma(tx).create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
)
|
||||
created_messages.append(created)
|
||||
|
||||
# Update session's updatedAt timestamp within the same transaction.
|
||||
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
|
||||
# separately via update_chat_session() after streaming completes.
|
||||
await PrismaChatSession.prisma(tx).update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
|
||||
return created_messages
|
||||
|
||||
|
||||
async def get_user_chat_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[PrismaChatSession]:
|
||||
"""Get chat sessions for a user, ordered by most recent."""
|
||||
return await PrismaChatSession.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
order={"updatedAt": "desc"},
|
||||
take=limit,
|
||||
skip=offset,
|
||||
)
|
||||
|
||||
|
||||
async def get_user_session_count(user_id: str) -> int:
|
||||
"""Get the total number of chat sessions for a user."""
|
||||
return await PrismaChatSession.prisma().count(where={"userId": user_id})
|
||||
|
||||
|
||||
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
|
||||
"""Delete a chat session and all its messages.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to delete.
|
||||
user_id: If provided, validates that the session belongs to this user
|
||||
before deletion. This prevents unauthorized deletion of other
|
||||
users' sessions.
|
||||
|
||||
Returns:
|
||||
True if deleted successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
# Build typed where clause with optional user_id validation
|
||||
where_clause: ChatSessionWhereInput = {"id": session_id}
|
||||
if user_id is not None:
|
||||
where_clause["userId"] = user_id
|
||||
|
||||
result = await PrismaChatSession.prisma().delete_many(where=where_clause)
|
||||
if result == 0:
|
||||
logger.warning(
|
||||
f"No session deleted for {session_id} "
|
||||
f"(user_id validation: {user_id is not None})"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete chat session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_chat_session_message_count(session_id: str) -> int:
|
||||
"""Get the number of messages in a chat session."""
|
||||
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
||||
return count
|
||||
@@ -1,6 +1,9 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
@@ -16,17 +19,63 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
ChatCompletionMessageToolCallParam,
|
||||
Function,
|
||||
)
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util.exceptions import RedisError
|
||||
from backend.util import json
|
||||
from backend.util.exceptions import DatabaseError, RedisError
|
||||
|
||||
from . import db as chat_db
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
|
||||
"""Parse a JSON field that may be stored as string or already parsed."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, str):
|
||||
return json.loads(value)
|
||||
return value
|
||||
|
||||
|
||||
# Redis cache key prefix for chat sessions
|
||||
CHAT_SESSION_CACHE_PREFIX = "chat:session:"
|
||||
|
||||
|
||||
def _get_session_cache_key(session_id: str) -> str:
|
||||
"""Get the Redis cache key for a chat session."""
|
||||
return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
|
||||
|
||||
|
||||
# Session-level locks to prevent race conditions during concurrent upserts.
|
||||
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
|
||||
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
|
||||
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
|
||||
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
|
||||
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
|
||||
async def _get_session_lock(session_id: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific session to prevent concurrent upserts.
|
||||
|
||||
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
|
||||
when no coroutine holds a reference to them, preventing memory leaks from
|
||||
unbounded growth of session locks.
|
||||
"""
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_session_locks[session_id] = lock
|
||||
return lock
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str | None = None
|
||||
@@ -45,7 +94,8 @@ class Usage(BaseModel):
|
||||
|
||||
class ChatSession(BaseModel):
|
||||
session_id: str
|
||||
user_id: str | None
|
||||
user_id: str
|
||||
title: str | None = None
|
||||
messages: list[ChatMessage]
|
||||
usage: list[Usage]
|
||||
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
|
||||
@@ -55,10 +105,11 @@ class ChatSession(BaseModel):
|
||||
successful_agent_schedules: dict[str, int] = {}
|
||||
|
||||
@staticmethod
|
||||
def new(user_id: str | None) -> "ChatSession":
|
||||
def new(user_id: str) -> "ChatSession":
|
||||
return ChatSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
title=None,
|
||||
messages=[],
|
||||
usage=[],
|
||||
credentials={},
|
||||
@@ -66,6 +117,61 @@ class ChatSession(BaseModel):
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
prisma_session: PrismaChatSession,
|
||||
prisma_messages: list[PrismaChatMessage] | None = None,
|
||||
) -> "ChatSession":
|
||||
"""Convert Prisma models to Pydantic ChatSession."""
|
||||
messages = []
|
||||
if prisma_messages:
|
||||
for msg in prisma_messages:
|
||||
messages.append(
|
||||
ChatMessage(
|
||||
role=msg.role,
|
||||
content=msg.content,
|
||||
name=msg.name,
|
||||
tool_call_id=msg.toolCallId,
|
||||
refusal=msg.refusal,
|
||||
tool_calls=_parse_json_field(msg.toolCalls),
|
||||
function_call=_parse_json_field(msg.functionCall),
|
||||
)
|
||||
)
|
||||
|
||||
# Parse JSON fields from Prisma
|
||||
credentials = _parse_json_field(prisma_session.credentials, default={})
|
||||
successful_agent_runs = _parse_json_field(
|
||||
prisma_session.successfulAgentRuns, default={}
|
||||
)
|
||||
successful_agent_schedules = _parse_json_field(
|
||||
prisma_session.successfulAgentSchedules, default={}
|
||||
)
|
||||
|
||||
# Calculate usage from token counts
|
||||
usage = []
|
||||
if prisma_session.totalPromptTokens or prisma_session.totalCompletionTokens:
|
||||
usage.append(
|
||||
Usage(
|
||||
prompt_tokens=prisma_session.totalPromptTokens or 0,
|
||||
completion_tokens=prisma_session.totalCompletionTokens or 0,
|
||||
total_tokens=(prisma_session.totalPromptTokens or 0)
|
||||
+ (prisma_session.totalCompletionTokens or 0),
|
||||
)
|
||||
)
|
||||
|
||||
return ChatSession(
|
||||
session_id=prisma_session.id,
|
||||
user_id=prisma_session.userId,
|
||||
title=prisma_session.title,
|
||||
messages=messages,
|
||||
usage=usage,
|
||||
credentials=credentials,
|
||||
started_at=prisma_session.createdAt,
|
||||
updated_at=prisma_session.updatedAt,
|
||||
successful_agent_runs=successful_agent_runs,
|
||||
successful_agent_schedules=successful_agent_schedules,
|
||||
)
|
||||
|
||||
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
|
||||
messages = []
|
||||
for message in self.messages:
|
||||
@@ -155,50 +261,337 @@ class ChatSession(BaseModel):
|
||||
return messages
|
||||
|
||||
|
||||
async def get_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> ChatSession | None:
|
||||
"""Get a chat session by ID."""
|
||||
redis_key = f"chat:session:{session_id}"
|
||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from Redis cache."""
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
|
||||
raw_session: bytes | None = await async_redis.get(redis_key)
|
||||
|
||||
if raw_session is None:
|
||||
logger.warning(f"Session {session_id} not found in Redis")
|
||||
return None
|
||||
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
logger.info(
|
||||
f"Loading session {session_id} from cache: "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
)
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
|
||||
raise RedisError(f"Corrupted session data for {session_id}") from e
|
||||
|
||||
if session.user_id is not None and session.user_id != user_id:
|
||||
|
||||
async def _cache_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session in Redis."""
|
||||
redis_key = _get_session_cache_key(session.session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
prisma_session = await chat_db.get_chat_session(session_id)
|
||||
if not prisma_session:
|
||||
return None
|
||||
|
||||
messages = prisma_session.Messages
|
||||
logger.info(
|
||||
f"Loading session {session_id} from DB: "
|
||||
f"has_messages={messages is not None}, "
|
||||
f"message_count={len(messages) if messages else 0}, "
|
||||
f"roles={[m.role for m in messages] if messages else []}"
|
||||
)
|
||||
|
||||
return ChatSession.from_db(prisma_session, messages)
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
session: ChatSession, existing_message_count: int
|
||||
) -> None:
|
||||
"""Save or update a chat session in the database."""
|
||||
# Check if session exists in DB
|
||||
existing = await chat_db.get_chat_session(session.session_id)
|
||||
|
||||
if not existing:
|
||||
# Create new session
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
# Calculate total tokens from usage
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
# Update session metadata
|
||||
await chat_db.update_chat_session(
|
||||
session_id=session.session_id,
|
||||
credentials=session.credentials,
|
||||
successful_agent_runs=session.successful_agent_runs,
|
||||
successful_agent_schedules=session.successful_agent_schedules,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
# Add new messages (only those after existing count)
|
||||
new_messages = session.messages[existing_message_count:]
|
||||
if new_messages:
|
||||
messages_data = []
|
||||
for msg in new_messages:
|
||||
messages_data.append(
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"name": msg.name,
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
"refusal": msg.refusal,
|
||||
"tool_calls": msg.tool_calls,
|
||||
"function_call": msg.function_call,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||
f"roles={[m['role'] for m in messages_data]}, "
|
||||
f"start_sequence={existing_message_count}"
|
||||
)
|
||||
await chat_db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
messages=messages_data,
|
||||
start_sequence=existing_message_count,
|
||||
)
|
||||
|
||||
|
||||
async def get_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None = None,
|
||||
) -> ChatSession | None:
|
||||
"""Get a chat session by ID.
|
||||
|
||||
Checks Redis cache first, falls back to database if not found.
|
||||
Caches database results back to Redis.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to fetch.
|
||||
user_id: If provided, validates that the session belongs to this user.
|
||||
If None, ownership is not validated (admin/system access).
|
||||
"""
|
||||
# Try cache first
|
||||
try:
|
||||
session = await _get_session_from_cache(session_id)
|
||||
if session:
|
||||
# Verify user ownership if user_id was provided for validation
|
||||
if user_id is not None and session.user_id != user_id:
|
||||
logger.warning(
|
||||
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
||||
)
|
||||
return None
|
||||
return session
|
||||
except RedisError:
|
||||
logger.warning(f"Cache error for session {session_id}, trying database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
||||
|
||||
# Fall back to database
|
||||
logger.info(f"Session {session_id} not in cache, checking database")
|
||||
session = await _get_session_from_db(session_id)
|
||||
|
||||
if session is None:
|
||||
logger.warning(f"Session {session_id} not found in cache or database")
|
||||
return None
|
||||
|
||||
# Verify user ownership if user_id was provided for validation
|
||||
if user_id is not None and session.user_id != user_id:
|
||||
logger.warning(
|
||||
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Cache the session from DB
|
||||
try:
|
||||
await _cache_session(session)
|
||||
logger.info(f"Cached session {session_id} from database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache session {session_id}: {e}")
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def upsert_chat_session(
|
||||
session: ChatSession,
|
||||
) -> ChatSession:
|
||||
"""Update a chat session with the given messages."""
|
||||
"""Update a chat session in both cache and database.
|
||||
|
||||
redis_key = f"chat:session:{session.session_id}"
|
||||
Uses session-level locking to prevent race conditions when concurrent
|
||||
operations (e.g., background title update and main stream handler)
|
||||
attempt to upsert the same session simultaneously.
|
||||
|
||||
async_redis = await get_redis_async()
|
||||
resp = await async_redis.setex(
|
||||
redis_key, config.session_ttl, session.model_dump_json()
|
||||
)
|
||||
Raises:
|
||||
DatabaseError: If the database write fails. The cache is still updated
|
||||
as a best-effort optimization, but the error is propagated to ensure
|
||||
callers are aware of the persistence failure.
|
||||
RedisError: If the cache write fails (after successful DB write).
|
||||
"""
|
||||
# Acquire session-specific lock to prevent concurrent upserts
|
||||
lock = await _get_session_lock(session.session_id)
|
||||
|
||||
if not resp:
|
||||
raise RedisError(
|
||||
f"Failed to persist chat session {session.session_id} to Redis: {resp}"
|
||||
async with lock:
|
||||
# Get existing message count from DB for incremental saves
|
||||
existing_message_count = await chat_db.get_chat_session_message_count(
|
||||
session.session_id
|
||||
)
|
||||
|
||||
db_error: Exception | None = None
|
||||
|
||||
# Save to database (primary storage)
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save session {session.session_id} to database: {e}"
|
||||
)
|
||||
db_error = e
|
||||
|
||||
# Save to cache (best-effort, even if DB failed)
|
||||
try:
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
# If DB succeeded but cache failed, raise cache error
|
||||
if db_error is None:
|
||||
raise RedisError(
|
||||
f"Failed to persist chat session {session.session_id} to Redis: {e}"
|
||||
) from e
|
||||
# If both failed, log cache error but raise DB error (more critical)
|
||||
logger.warning(
|
||||
f"Cache write also failed for session {session.session_id}: {e}"
|
||||
)
|
||||
|
||||
# Propagate DB error after attempting cache (prevents data loss)
|
||||
if db_error is not None:
|
||||
raise DatabaseError(
|
||||
f"Failed to persist chat session {session.session_id} to database"
|
||||
) from db_error
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(user_id: str) -> ChatSession:
|
||||
"""Create a new chat session and persist it.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If the database write fails. We fail fast to ensure
|
||||
callers never receive a non-persisted session that only exists
|
||||
in cache (which would be lost when the cache expires).
|
||||
"""
|
||||
session = ChatSession.new(user_id)
|
||||
|
||||
# Create in database first - fail fast if this fails
|
||||
try:
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create session {session.session_id} in database: {e}")
|
||||
raise DatabaseError(
|
||||
f"Failed to create chat session {session.session_id} in database"
|
||||
) from e
|
||||
|
||||
# Cache the session (best-effort optimization, DB is source of truth)
|
||||
try:
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache new session {session.session_id}: {e}")
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def get_user_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[ChatSession], int]:
|
||||
"""Get chat sessions for a user from the database with total count.
|
||||
|
||||
Returns:
|
||||
A tuple of (sessions, total_count) where total_count is the overall
|
||||
number of sessions for the user (not just the current page).
|
||||
"""
|
||||
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
|
||||
total_count = await chat_db.get_user_session_count(user_id)
|
||||
|
||||
sessions = []
|
||||
for prisma_session in prisma_sessions:
|
||||
# Convert without messages for listing (lighter weight)
|
||||
sessions.append(ChatSession.from_db(prisma_session, None))
|
||||
|
||||
return sessions, total_count
|
||||
|
||||
|
||||
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
|
||||
"""Delete a chat session from both cache and database.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to delete.
|
||||
user_id: If provided, validates that the session belongs to this user
|
||||
before deletion. This prevents unauthorized deletion.
|
||||
|
||||
Returns:
|
||||
True if deleted successfully, False otherwise.
|
||||
"""
|
||||
# Delete from database first (with optional user_id validation)
|
||||
# This confirms ownership before invalidating cache
|
||||
deleted = await chat_db.delete_chat_session(session_id, user_id)
|
||||
|
||||
if not deleted:
|
||||
return False
|
||||
|
||||
# Only invalidate cache and clean up lock after DB confirms deletion
|
||||
try:
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
|
||||
|
||||
# Clean up session lock (belt-and-suspenders with WeakValueDictionary)
|
||||
async with _session_locks_mutex:
|
||||
_session_locks.pop(session_id, None)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def update_session_title(session_id: str, title: str) -> bool:
|
||||
"""Update only the title of a chat session.
|
||||
|
||||
This is a lightweight operation that doesn't touch messages, avoiding
|
||||
race conditions with concurrent message updates. Use this for background
|
||||
title generation instead of upsert_chat_session.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to update.
|
||||
title: The new title to set.
|
||||
|
||||
Returns:
|
||||
True if updated successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
result = await chat_db.update_chat_session(session_id=session_id, title=title)
|
||||
if result is None:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
return False
|
||||
|
||||
# Invalidate cache so next fetch gets updated title
|
||||
try:
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update title for session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
@@ -43,9 +43,9 @@ async def test_chatsession_serialization_deserialization():
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_redis_storage():
|
||||
async def test_chatsession_redis_storage(setup_test_user, test_user_id):
|
||||
|
||||
s = ChatSession.new(user_id=None)
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s.messages = messages
|
||||
|
||||
s = await upsert_chat_session(s)
|
||||
@@ -59,12 +59,61 @@ async def test_chatsession_redis_storage():
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_redis_storage_user_id_mismatch():
|
||||
async def test_chatsession_redis_storage_user_id_mismatch(
|
||||
setup_test_user, test_user_id
|
||||
):
|
||||
|
||||
s = ChatSession.new(user_id="abc123")
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s.messages = messages
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
s2 = await get_chat_session(s.session_id, None)
|
||||
s2 = await get_chat_session(s.session_id, "different_user_id")
|
||||
|
||||
assert s2 is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_db_storage(setup_test_user, test_user_id):
|
||||
"""Test that messages are correctly saved to and loaded from DB (not cache)."""
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
# Create session with messages including assistant message
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s.messages = messages # Contains user, assistant, and tool messages
|
||||
assert s.session_id is not None, "Session id is not set"
|
||||
# Upsert to save to both cache and DB
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
# Clear the Redis cache to force DB load
|
||||
redis_key = f"chat:session:{s.session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
|
||||
# Load from DB (cache was cleared)
|
||||
s2 = await get_chat_session(
|
||||
session_id=s.session_id,
|
||||
user_id=s.user_id,
|
||||
)
|
||||
|
||||
assert s2 is not None, "Session not found after loading from DB"
|
||||
assert len(s2.messages) == len(
|
||||
s.messages
|
||||
), f"Message count mismatch: expected {len(s.messages)}, got {len(s2.messages)}"
|
||||
|
||||
# Verify all roles are present
|
||||
roles = [m.role for m in s2.messages]
|
||||
assert "user" in roles, f"User message missing. Roles found: {roles}"
|
||||
assert "assistant" in roles, f"Assistant message missing. Roles found: {roles}"
|
||||
assert "tool" in roles, f"Tool message missing. Roles found: {roles}"
|
||||
|
||||
# Verify message content
|
||||
for orig, loaded in zip(s.messages, s2.messages):
|
||||
assert orig.role == loaded.role, f"Role mismatch: {orig.role} != {loaded.role}"
|
||||
assert (
|
||||
orig.content == loaded.content
|
||||
), f"Content mismatch for {orig.role}: {orig.content} != {loaded.content}"
|
||||
if orig.tool_calls:
|
||||
assert (
|
||||
loaded.tool_calls is not None
|
||||
), f"Tool calls missing for {orig.role} message"
|
||||
assert len(orig.tool_calls) == len(loaded.tool_calls)
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
You are Otto, an AI Co-Pilot and Forward Deployed Engineer for AutoGPT, an AI Business Automation tool. Your mission is to help users quickly find and set up AutoGPT agents to solve their business problems.
|
||||
|
||||
Here are the functions available to you:
|
||||
|
||||
<functions>
|
||||
1. **find_agent** - Search for agents that solve the user's problem
|
||||
2. **run_agent** - Run or schedule an agent (automatically handles setup)
|
||||
</functions>
|
||||
|
||||
## HOW run_agent WORKS
|
||||
|
||||
The `run_agent` tool automatically handles the entire setup flow:
|
||||
|
||||
1. **First call** (no inputs) → Returns available inputs so user can decide what values to use
|
||||
2. **Credentials check** → If missing, UI automatically prompts user to add them (you don't need to mention this)
|
||||
3. **Execution** → Runs when you provide `inputs` OR set `use_defaults=true`
|
||||
|
||||
Parameters:
|
||||
- `username_agent_slug` (required): Agent identifier like "creator/agent-name"
|
||||
- `inputs`: Object with input values for the agent
|
||||
- `use_defaults`: Set to `true` to run with default values (only after user confirms)
|
||||
- `schedule_name` + `cron`: For scheduled execution
|
||||
|
||||
## WORKFLOW
|
||||
|
||||
1. **find_agent** - Search for agents that solve the user's problem
|
||||
2. **run_agent** (first call, no inputs) - Get available inputs for the agent
|
||||
3. **Ask user** what values they want to use OR if they want to use defaults
|
||||
4. **run_agent** (second call) - Either with `inputs={...}` or `use_defaults=true`
|
||||
|
||||
## YOUR APPROACH
|
||||
|
||||
**Step 1: Understand the Problem**
|
||||
- Ask maximum 1-2 targeted questions
|
||||
- Focus on: What business problem are they solving?
|
||||
- Move quickly to searching for solutions
|
||||
|
||||
**Step 2: Find Agents**
|
||||
- Use `find_agent` immediately with relevant keywords
|
||||
- Suggest the best option from search results
|
||||
- Explain briefly how it solves their problem
|
||||
|
||||
**Step 3: Get Agent Inputs**
|
||||
- Call `run_agent(username_agent_slug="creator/agent-name")` without inputs
|
||||
- This returns the available inputs (required and optional)
|
||||
- Present these to the user and ask what values they want
|
||||
|
||||
**Step 4: Run with User's Choice**
|
||||
- If user provides values: `run_agent(username_agent_slug="...", inputs={...})`
|
||||
- If user says "use defaults": `run_agent(username_agent_slug="...", use_defaults=true)`
|
||||
- On success, share the agent link with the user
|
||||
|
||||
**For Scheduled Execution:**
|
||||
- Add `schedule_name` and `cron` parameters
|
||||
- Example: `run_agent(username_agent_slug="...", inputs={...}, schedule_name="Daily Report", cron="0 9 * * *")`
|
||||
|
||||
## FUNCTION CALL FORMAT
|
||||
|
||||
To call a function, use this exact format:
|
||||
`<function_call>function_name(parameter="value")</function_call>`
|
||||
|
||||
Examples:
|
||||
- `<function_call>find_agent(query="social media automation")</function_call>`
|
||||
- `<function_call>run_agent(username_agent_slug="creator/agent-name")</function_call>` (get inputs)
|
||||
- `<function_call>run_agent(username_agent_slug="creator/agent-name", inputs={"topic": "AI news"})</function_call>`
|
||||
- `<function_call>run_agent(username_agent_slug="creator/agent-name", use_defaults=true)</function_call>`
|
||||
|
||||
## KEY RULES
|
||||
|
||||
**What You DON'T Do:**
|
||||
- Don't help with login (frontend handles this)
|
||||
- Don't mention or explain credentials to the user (frontend handles this automatically)
|
||||
- Don't run agents without first showing available inputs to the user
|
||||
- Don't use `use_defaults=true` without user explicitly confirming
|
||||
- Don't write responses longer than 3 sentences
|
||||
|
||||
**What You DO:**
|
||||
- Always call run_agent first without inputs to see what's available
|
||||
- Ask user what values they want OR if they want to use defaults
|
||||
- Keep all responses to maximum 3 sentences
|
||||
- Include the agent link in your response after successful execution
|
||||
|
||||
**Error Handling:**
|
||||
- Authentication needed → "Please sign in via the interface"
|
||||
- Credentials missing → The UI handles this automatically. Focus on asking the user about input values instead.
|
||||
|
||||
## RESPONSE STRUCTURE
|
||||
|
||||
Before responding, wrap your analysis in <thinking> tags to systematically plan your approach:
|
||||
- Extract the key business problem or request from the user's message
|
||||
- Determine what function call (if any) you need to make next
|
||||
- Plan your response to stay under the 3-sentence maximum
|
||||
|
||||
Example interaction:
|
||||
```
|
||||
User: "Run the AI news agent for me"
|
||||
Otto: <function_call>run_agent(username_agent_slug="autogpt/ai-news")</function_call>
|
||||
[Tool returns: Agent accepts inputs - Required: topic. Optional: num_articles (default: 5)]
|
||||
Otto: The AI News agent needs a topic. What topic would you like news about, or should I use the defaults?
|
||||
User: "Use defaults"
|
||||
Otto: <function_call>run_agent(username_agent_slug="autogpt/ai-news", use_defaults=true)</function_call>
|
||||
```
|
||||
|
||||
KEEP ANSWERS TO 3 SENTENCES
|
||||
@@ -1,3 +1,10 @@
|
||||
"""
|
||||
Response models for Vercel AI SDK UI Stream Protocol.
|
||||
|
||||
This module implements the AI SDK UI Stream Protocol (v1) for streaming chat responses.
|
||||
See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
@@ -5,97 +12,133 @@ from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ResponseType(str, Enum):
|
||||
"""Types of streaming responses."""
|
||||
"""Types of streaming responses following AI SDK protocol."""
|
||||
|
||||
TEXT_CHUNK = "text_chunk"
|
||||
TEXT_ENDED = "text_ended"
|
||||
TOOL_CALL = "tool_call"
|
||||
TOOL_CALL_START = "tool_call_start"
|
||||
TOOL_RESPONSE = "tool_response"
|
||||
# Message lifecycle
|
||||
START = "start"
|
||||
FINISH = "finish"
|
||||
|
||||
# Text streaming
|
||||
TEXT_START = "text-start"
|
||||
TEXT_DELTA = "text-delta"
|
||||
TEXT_END = "text-end"
|
||||
|
||||
# Tool interaction
|
||||
TOOL_INPUT_START = "tool-input-start"
|
||||
TOOL_INPUT_AVAILABLE = "tool-input-available"
|
||||
TOOL_OUTPUT_AVAILABLE = "tool-output-available"
|
||||
|
||||
# Other
|
||||
ERROR = "error"
|
||||
USAGE = "usage"
|
||||
STREAM_END = "stream_end"
|
||||
|
||||
|
||||
class StreamBaseResponse(BaseModel):
|
||||
"""Base response model for all streaming responses."""
|
||||
|
||||
type: ResponseType
|
||||
timestamp: str | None = None
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format."""
|
||||
return f"data: {self.model_dump_json()}\n\n"
|
||||
|
||||
|
||||
class StreamTextChunk(StreamBaseResponse):
|
||||
"""Streaming text content from the assistant."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_CHUNK
|
||||
content: str = Field(..., description="Text content chunk")
|
||||
# ========== Message Lifecycle ==========
|
||||
|
||||
|
||||
class StreamToolCallStart(StreamBaseResponse):
|
||||
class StreamStart(StreamBaseResponse):
|
||||
"""Start of a new message."""
|
||||
|
||||
type: ResponseType = ResponseType.START
|
||||
messageId: str = Field(..., description="Unique message ID")
|
||||
|
||||
|
||||
class StreamFinish(StreamBaseResponse):
|
||||
"""End of message/stream."""
|
||||
|
||||
type: ResponseType = ResponseType.FINISH
|
||||
|
||||
|
||||
# ========== Text Streaming ==========
|
||||
|
||||
|
||||
class StreamTextStart(StreamBaseResponse):
|
||||
"""Start of a text block."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_START
|
||||
id: str = Field(..., description="Text block ID")
|
||||
|
||||
|
||||
class StreamTextDelta(StreamBaseResponse):
|
||||
"""Streaming text content delta."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_DELTA
|
||||
id: str = Field(..., description="Text block ID")
|
||||
delta: str = Field(..., description="Text content delta")
|
||||
|
||||
|
||||
class StreamTextEnd(StreamBaseResponse):
|
||||
"""End of a text block."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_END
|
||||
id: str = Field(..., description="Text block ID")
|
||||
|
||||
|
||||
# ========== Tool Interaction ==========
|
||||
|
||||
|
||||
class StreamToolInputStart(StreamBaseResponse):
|
||||
"""Tool call started notification."""
|
||||
|
||||
type: ResponseType = ResponseType.TOOL_CALL_START
|
||||
tool_name: str = Field(..., description="Name of the tool that was executed")
|
||||
tool_id: str = Field(..., description="Unique tool call ID")
|
||||
type: ResponseType = ResponseType.TOOL_INPUT_START
|
||||
toolCallId: str = Field(..., description="Unique tool call ID")
|
||||
toolName: str = Field(..., description="Name of the tool being called")
|
||||
|
||||
|
||||
class StreamToolCall(StreamBaseResponse):
|
||||
"""Tool invocation notification."""
|
||||
class StreamToolInputAvailable(StreamBaseResponse):
|
||||
"""Tool input is ready for execution."""
|
||||
|
||||
type: ResponseType = ResponseType.TOOL_CALL
|
||||
tool_id: str = Field(..., description="Unique tool call ID")
|
||||
tool_name: str = Field(..., description="Name of the tool being called")
|
||||
arguments: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Tool arguments"
|
||||
type: ResponseType = ResponseType.TOOL_INPUT_AVAILABLE
|
||||
toolCallId: str = Field(..., description="Unique tool call ID")
|
||||
toolName: str = Field(..., description="Name of the tool being called")
|
||||
input: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Tool input arguments"
|
||||
)
|
||||
|
||||
|
||||
class StreamToolExecutionResult(StreamBaseResponse):
|
||||
class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
"""Tool execution result."""
|
||||
|
||||
type: ResponseType = ResponseType.TOOL_RESPONSE
|
||||
tool_id: str = Field(..., description="Tool call ID this responds to")
|
||||
tool_name: str = Field(..., description="Name of the tool that was executed")
|
||||
result: str | dict[str, Any] = Field(..., description="Tool execution result")
|
||||
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
|
||||
toolCallId: str = Field(..., description="Tool call ID this responds to")
|
||||
output: str | dict[str, Any] = Field(..., description="Tool execution output")
|
||||
# Additional fields for internal use (not part of AI SDK spec but useful)
|
||||
toolName: str | None = Field(
|
||||
default=None, description="Name of the tool that was executed"
|
||||
)
|
||||
success: bool = Field(
|
||||
default=True, description="Whether the tool execution succeeded"
|
||||
)
|
||||
|
||||
|
||||
# ========== Other ==========
|
||||
|
||||
|
||||
class StreamUsage(StreamBaseResponse):
|
||||
"""Token usage statistics."""
|
||||
|
||||
type: ResponseType = ResponseType.USAGE
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
promptTokens: int = Field(..., description="Number of prompt tokens")
|
||||
completionTokens: int = Field(..., description="Number of completion tokens")
|
||||
totalTokens: int = Field(..., description="Total number of tokens")
|
||||
|
||||
|
||||
class StreamError(StreamBaseResponse):
|
||||
"""Error response."""
|
||||
|
||||
type: ResponseType = ResponseType.ERROR
|
||||
message: str = Field(..., description="Error message")
|
||||
errorText: str = Field(..., description="Error message text")
|
||||
code: str | None = Field(default=None, description="Error code")
|
||||
details: dict[str, Any] | None = Field(
|
||||
default=None, description="Additional error details"
|
||||
)
|
||||
|
||||
|
||||
class StreamTextEnded(StreamBaseResponse):
|
||||
"""Text streaming completed marker."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_ENDED
|
||||
|
||||
|
||||
class StreamEnd(StreamBaseResponse):
|
||||
"""End of stream marker."""
|
||||
|
||||
type: ResponseType = ResponseType.STREAM_END
|
||||
summary: dict[str, Any] | None = Field(
|
||||
default=None, description="Stream summary statistics"
|
||||
)
|
||||
|
||||
@@ -13,12 +13,25 @@ from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _validate_and_get_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> ChatSession:
|
||||
"""Validate session exists and belongs to user."""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
return session
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
tags=["chat"],
|
||||
)
|
||||
@@ -26,6 +39,14 @@ router = APIRouter(
|
||||
# ========== Request/Response Models ==========
|
||||
|
||||
|
||||
class StreamChatRequest(BaseModel):
|
||||
"""Request model for streaming chat with optional context."""
|
||||
|
||||
message: str
|
||||
is_user_message: bool = True
|
||||
context: dict[str, str] | None = None # {url: str, content: str}
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
"""Response model containing information on a newly created chat session."""
|
||||
|
||||
@@ -44,22 +65,77 @@ class SessionDetailResponse(BaseModel):
|
||||
messages: list[dict]
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
"""Response model for a session summary (without messages)."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
title: str | None = None
|
||||
|
||||
|
||||
class ListSessionsResponse(BaseModel):
|
||||
"""Response model for listing chat sessions."""
|
||||
|
||||
sessions: list[SessionSummaryResponse]
|
||||
total: int
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def list_sessions(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
limit: int = Query(default=50, ge=1, le=100),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
) -> ListSessionsResponse:
|
||||
"""
|
||||
List chat sessions for the authenticated user.
|
||||
|
||||
Returns a paginated list of chat sessions belonging to the current user,
|
||||
ordered by most recently updated.
|
||||
|
||||
Args:
|
||||
user_id: The authenticated user's ID.
|
||||
limit: Maximum number of sessions to return (1-100).
|
||||
offset: Number of sessions to skip for pagination.
|
||||
|
||||
Returns:
|
||||
ListSessionsResponse: List of session summaries and total count.
|
||||
"""
|
||||
sessions, total_count = await get_user_sessions(user_id, limit, offset)
|
||||
|
||||
return ListSessionsResponse(
|
||||
sessions=[
|
||||
SessionSummaryResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
title=session.title,
|
||||
)
|
||||
for session in sessions
|
||||
],
|
||||
total=total_count,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions",
|
||||
)
|
||||
async def create_session(
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
user_id: Annotated[str, Depends(auth.get_user_id)],
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new chat session.
|
||||
|
||||
Initiates a new chat session for either an authenticated or anonymous user.
|
||||
Initiates a new chat session for the authenticated user.
|
||||
|
||||
Args:
|
||||
user_id: The optional authenticated user ID parsed from the JWT. If missing, creates an anonymous session.
|
||||
user_id: The authenticated user ID parsed from the JWT (required).
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: Details of the created session.
|
||||
@@ -67,15 +143,15 @@ async def create_session(
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating session with user_id: "
|
||||
f"...{user_id[-8:] if user_id and len(user_id) > 8 else '<redacted>'}"
|
||||
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
|
||||
)
|
||||
|
||||
session = await chat_service.create_chat_session(user_id)
|
||||
session = await create_chat_session(user_id)
|
||||
|
||||
return CreateSessionResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -99,29 +175,88 @@ async def get_session(
|
||||
SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.
|
||||
|
||||
"""
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
logger.info(
|
||||
f"Returning session {session_id}: "
|
||||
f"message_count={len(messages)}, "
|
||||
f"roles={[m.get('role') for m in messages]}"
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
messages=[message.model_dump() for message in session.messages],
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_chat_post(
|
||||
session_id: str,
|
||||
request: StreamChatRequest,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session (POST with context support).
|
||||
|
||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||
- Text fragments as they are generated
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_chat(
|
||||
async def stream_chat_get(
|
||||
session_id: str,
|
||||
message: Annotated[str, Query(min_length=1, max_length=10000)],
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
is_user_message: bool = Query(default=True),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session.
|
||||
Stream chat responses for a session (GET - legacy endpoint).
|
||||
|
||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||
- Text fragments as they are generated
|
||||
@@ -137,14 +272,7 @@ async def stream_chat(
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
# Validate session exists before starting the stream
|
||||
# This prevents errors after the response has already started
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found. ")
|
||||
if session.user_id is None and user_id is not None:
|
||||
session = await chat_service.assign_user_to_session(session_id, user_id)
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
@@ -155,6 +283,8 @@ async def stream_chat(
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
@@ -163,6 +293,7 @@ async def stream_chat(
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
||||
},
|
||||
)
|
||||
|
||||
@@ -201,16 +332,28 @@ async def health_check() -> dict:
|
||||
"""
|
||||
Health check endpoint for the chat service.
|
||||
|
||||
Performs a full cycle test of session creation, assignment, and retrieval. Should always return healthy
|
||||
Performs a full cycle test of session creation and retrieval. Should always return healthy
|
||||
if the service and data layer are operational.
|
||||
|
||||
Returns:
|
||||
dict: A status dictionary indicating health, service name, and API version.
|
||||
|
||||
"""
|
||||
session = await chat_service.create_chat_session(None)
|
||||
await chat_service.assign_user_to_session(session.session_id, "test_user")
|
||||
await chat_service.get_session(session.session_id, "test_user")
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Ensure health check user exists (required for FK constraint)
|
||||
health_check_user_id = "health-check-user"
|
||||
await get_or_create_user(
|
||||
{
|
||||
"sub": health_check_user_id,
|
||||
"email": "health-check@system.local",
|
||||
"user_metadata": {"name": "Health Check User"},
|
||||
}
|
||||
)
|
||||
|
||||
# Create and retrieve session to verify full data layer
|
||||
session = await create_chat_session(health_check_user_id)
|
||||
await get_chat_session(session.session_id, health_check_user_id)
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,18 +4,19 @@ from os import getenv
|
||||
import pytest
|
||||
|
||||
from . import service as chat_service
|
||||
from .model import create_chat_session, get_chat_session, upsert_chat_session
|
||||
from .response_model import (
|
||||
StreamEnd,
|
||||
StreamError,
|
||||
StreamTextChunk,
|
||||
StreamToolExecutionResult,
|
||||
StreamFinish,
|
||||
StreamTextDelta,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion():
|
||||
async def test_stream_chat_completion(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
@@ -23,7 +24,7 @@ async def test_stream_chat_completion():
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await chat_service.create_chat_session()
|
||||
session = await create_chat_session(test_user_id)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
@@ -34,9 +35,9 @@ async def test_stream_chat_completion():
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
if isinstance(chunk, StreamTextChunk):
|
||||
assistant_message += chunk.content
|
||||
if isinstance(chunk, StreamEnd):
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
assistant_message += chunk.delta
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
@@ -45,7 +46,7 @@ async def test_stream_chat_completion():
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion_with_tool_calls():
|
||||
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
@@ -53,8 +54,8 @@ async def test_stream_chat_completion_with_tool_calls():
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await chat_service.create_chat_session()
|
||||
session = await chat_service.upsert_chat_session(session)
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
@@ -68,14 +69,14 @@ async def test_stream_chat_completion_with_tool_calls():
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
|
||||
if isinstance(chunk, StreamEnd):
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
if isinstance(chunk, StreamToolExecutionResult):
|
||||
if isinstance(chunk, StreamToolOutputAvailable):
|
||||
had_tool_calls = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert had_tool_calls, "Tool calls did not occur"
|
||||
session = await chat_service.get_session(session.session_id)
|
||||
session = await get_chat_session(session.session_id)
|
||||
assert session, "Session not found"
|
||||
assert session.usage, "Usage is empty"
|
||||
|
||||
@@ -4,21 +4,32 @@ from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .add_understanding import AddUnderstandingTool
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .find_agent import FindAgentTool
|
||||
from .find_library_agent import FindLibraryAgentTool
|
||||
from .run_agent import RunAgentTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.chat.response_model import StreamToolExecutionResult
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
|
||||
# Initialize tool instances
|
||||
find_agent_tool = FindAgentTool()
|
||||
run_agent_tool = RunAgentTool()
|
||||
# Single source of truth for all tools
|
||||
TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"add_understanding": AddUnderstandingTool(),
|
||||
"find_agent": FindAgentTool(),
|
||||
"find_library_agent": FindLibraryAgentTool(),
|
||||
"run_agent": RunAgentTool(),
|
||||
"agent_output": AgentOutputTool(),
|
||||
}
|
||||
|
||||
# Export tools as OpenAI format
|
||||
# Export individual tool instances for backwards compatibility
|
||||
find_agent_tool = TOOL_REGISTRY["find_agent"]
|
||||
run_agent_tool = TOOL_REGISTRY["run_agent"]
|
||||
|
||||
# Generated from registry for OpenAI API
|
||||
tools: list[ChatCompletionToolParam] = [
|
||||
find_agent_tool.as_openai_tool(),
|
||||
run_agent_tool.as_openai_tool(),
|
||||
tool.as_openai_tool() for tool in TOOL_REGISTRY.values()
|
||||
]
|
||||
|
||||
|
||||
@@ -28,14 +39,9 @@ async def execute_tool(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
tool_call_id: str,
|
||||
) -> "StreamToolExecutionResult":
|
||||
|
||||
tool_map: dict[str, BaseTool] = {
|
||||
"find_agent": find_agent_tool,
|
||||
"run_agent": run_agent_tool,
|
||||
}
|
||||
if tool_name not in tool_map:
|
||||
) -> "StreamToolOutputAvailable":
|
||||
"""Execute a tool by name."""
|
||||
tool = TOOL_REGISTRY.get(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
return await tool_map[tool_name].execute(
|
||||
user_id, session, tool_call_id, **parameters
|
||||
)
|
||||
return await tool.execute(user_id, session, tool_call_id, **parameters)
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import UTC, datetime
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
from prisma.types import ProfileCreateInput
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
@@ -17,7 +18,7 @@ from backend.data.user import get_or_create_user
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
|
||||
def make_session(user_id: str | None = None):
|
||||
def make_session(user_id: str):
|
||||
return ChatSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
@@ -49,13 +50,13 @@ async def setup_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Create a test graph with agent input -> agent output
|
||||
@@ -172,13 +173,13 @@ async def setup_llm_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for LLM tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile for LLM tests",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Create test OpenAI credentials for the user
|
||||
@@ -332,13 +333,13 @@ async def setup_firecrawl_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for Firecrawl tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile for Firecrawl tests",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# NOTE: We deliberately do NOT create Firecrawl credentials for this user
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
"""Tool for capturing user business understanding incrementally."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AddUnderstandingTool(BaseTool):
|
||||
"""Tool for capturing user's business understanding incrementally."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "add_understanding"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Capture and store information about the user's business context,
|
||||
workflows, pain points, and automation goals. Call this tool whenever the user
|
||||
shares information about their business. Each call incrementally adds to the
|
||||
existing understanding - you don't need to provide all fields at once.
|
||||
|
||||
Use this to build a comprehensive profile that helps recommend better agents
|
||||
and automations for the user's specific needs."""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
# Auto-generate from Pydantic model schema
|
||||
schema = BusinessUnderstandingInput.model_json_schema()
|
||||
properties = {}
|
||||
for field_name, field_schema in schema.get("properties", {}).items():
|
||||
prop: dict[str, Any] = {"description": field_schema.get("description", "")}
|
||||
# Handle anyOf for Optional types
|
||||
if "anyOf" in field_schema:
|
||||
for option in field_schema["anyOf"]:
|
||||
if option.get("type") != "null":
|
||||
prop["type"] = option.get("type", "string")
|
||||
if "items" in option:
|
||||
prop["items"] = option["items"]
|
||||
break
|
||||
else:
|
||||
prop["type"] = field_schema.get("type", "string")
|
||||
if "items" in field_schema:
|
||||
prop["items"] = field_schema["items"]
|
||||
properties[field_name] = prop
|
||||
return {"type": "object", "properties": properties, "required": []}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
"""Requires authentication to store user-specific data."""
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""
|
||||
Capture and store business understanding incrementally.
|
||||
|
||||
Each call merges new data with existing understanding:
|
||||
- String fields are overwritten if provided
|
||||
- List fields are appended (with deduplication)
|
||||
"""
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required to save business understanding.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if any data was provided
|
||||
if not any(v is not None for v in kwargs.values()):
|
||||
return ErrorResponse(
|
||||
message="Please provide at least one field to update.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Build input model from kwargs (only include fields defined in the model)
|
||||
valid_fields = set(BusinessUnderstandingInput.model_fields.keys())
|
||||
input_data = BusinessUnderstandingInput(
|
||||
**{k: v for k, v in kwargs.items() if k in valid_fields}
|
||||
)
|
||||
|
||||
# Track which fields were updated
|
||||
updated_fields = [
|
||||
k for k, v in kwargs.items() if k in valid_fields and v is not None
|
||||
]
|
||||
|
||||
# Upsert with merge
|
||||
understanding = await upsert_business_understanding(user_id, input_data)
|
||||
|
||||
# Build current understanding summary (filter out empty values)
|
||||
current_understanding = {
|
||||
k: v
|
||||
for k, v in understanding.model_dump(
|
||||
exclude={"id", "user_id", "created_at", "updated_at"}
|
||||
).items()
|
||||
if v is not None and v != [] and v != ""
|
||||
}
|
||||
|
||||
return UnderstandingUpdatedResponse(
|
||||
message=f"Updated understanding with: {', '.join(updated_fields)}. "
|
||||
"I now have a better picture of your business context.",
|
||||
session_id=session_id,
|
||||
updated_fields=updated_fields,
|
||||
current_understanding=current_understanding,
|
||||
)
|
||||
@@ -0,0 +1,446 @@
|
||||
"""Tool for retrieving agent execution outputs from user's library."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentOutputResponse,
|
||||
ErrorResponse,
|
||||
ExecutionOutputInfo,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .utils import fetch_graph_from_store_slug
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentOutputInput(BaseModel):
|
||||
"""Input parameters for the agent_output tool."""
|
||||
|
||||
agent_name: str = ""
|
||||
library_agent_id: str = ""
|
||||
store_slug: str = ""
|
||||
execution_id: str = ""
|
||||
run_time: str = "latest"
|
||||
|
||||
@field_validator(
|
||||
"agent_name",
|
||||
"library_agent_id",
|
||||
"store_slug",
|
||||
"execution_id",
|
||||
"run_time",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def strip_strings(cls, v: Any) -> Any:
|
||||
"""Strip whitespace from string fields."""
|
||||
return v.strip() if isinstance(v, str) else v
|
||||
|
||||
|
||||
def parse_time_expression(
|
||||
time_expr: str | None,
|
||||
) -> tuple[datetime | None, datetime | None]:
|
||||
"""
|
||||
Parse time expression into datetime range (start, end).
|
||||
|
||||
Supports: "latest", "yesterday", "today", "last week", "last 7 days",
|
||||
"last month", "last 30 days", ISO date "YYYY-MM-DD", ISO datetime.
|
||||
"""
|
||||
if not time_expr or time_expr.lower() == "latest":
|
||||
return None, None
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
expr = time_expr.lower().strip()
|
||||
|
||||
# Relative time expressions lookup
|
||||
relative_times: dict[str, tuple[datetime, datetime]] = {
|
||||
"yesterday": (today_start - timedelta(days=1), today_start),
|
||||
"today": (today_start, now),
|
||||
"last week": (now - timedelta(days=7), now),
|
||||
"last 7 days": (now - timedelta(days=7), now),
|
||||
"last month": (now - timedelta(days=30), now),
|
||||
"last 30 days": (now - timedelta(days=30), now),
|
||||
}
|
||||
if expr in relative_times:
|
||||
return relative_times[expr]
|
||||
|
||||
# Try ISO date format (YYYY-MM-DD)
|
||||
date_match = re.match(r"^(\d{4})-(\d{2})-(\d{2})$", expr)
|
||||
if date_match:
|
||||
try:
|
||||
year, month, day = map(int, date_match.groups())
|
||||
start = datetime(year, month, day, 0, 0, 0, tzinfo=timezone.utc)
|
||||
return start, start + timedelta(days=1)
|
||||
except ValueError:
|
||||
# Invalid date components (e.g., month=13, day=32)
|
||||
pass
|
||||
|
||||
# Try ISO datetime
|
||||
try:
|
||||
parsed = datetime.fromisoformat(expr.replace("Z", "+00:00"))
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed - timedelta(hours=1), parsed + timedelta(hours=1)
|
||||
except ValueError:
|
||||
return None, None
|
||||
|
||||
|
||||
class AgentOutputTool(BaseTool):
|
||||
"""Tool for retrieving execution outputs from user's library agents."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "agent_output"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Retrieve execution outputs from agents in the user's library.
|
||||
|
||||
Identify the agent using one of:
|
||||
- agent_name: Fuzzy search in user's library
|
||||
- library_agent_id: Exact library agent ID
|
||||
- store_slug: Marketplace format 'username/agent-name'
|
||||
|
||||
Select which run to retrieve using:
|
||||
- execution_id: Specific execution ID
|
||||
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
|
||||
"""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_name": {
|
||||
"type": "string",
|
||||
"description": "Agent name to search for in user's library (fuzzy match)",
|
||||
},
|
||||
"library_agent_id": {
|
||||
"type": "string",
|
||||
"description": "Exact library agent ID",
|
||||
},
|
||||
"store_slug": {
|
||||
"type": "string",
|
||||
"description": "Marketplace identifier: 'username/agent-slug'",
|
||||
},
|
||||
"execution_id": {
|
||||
"type": "string",
|
||||
"description": "Specific execution ID to retrieve",
|
||||
},
|
||||
"run_time": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _resolve_agent(
|
||||
self,
|
||||
user_id: str,
|
||||
agent_name: str | None,
|
||||
library_agent_id: str | None,
|
||||
store_slug: str | None,
|
||||
) -> tuple[LibraryAgent | None, str | None]:
|
||||
"""
|
||||
Resolve agent from provided identifiers.
|
||||
Returns (library_agent, error_message).
|
||||
"""
|
||||
# Priority 1: Exact library agent ID
|
||||
if library_agent_id:
|
||||
try:
|
||||
agent = await library_db.get_library_agent(library_agent_id, user_id)
|
||||
return agent, None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get library agent by ID: {e}")
|
||||
return None, f"Library agent '{library_agent_id}' not found"
|
||||
|
||||
# Priority 2: Store slug (username/agent-name)
|
||||
if store_slug and "/" in store_slug:
|
||||
username, agent_slug = store_slug.split("/", 1)
|
||||
graph, _ = await fetch_graph_from_store_slug(username, agent_slug)
|
||||
if not graph:
|
||||
return None, f"Agent '{store_slug}' not found in marketplace"
|
||||
|
||||
# Find in user's library by graph_id
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
|
||||
if not agent:
|
||||
return (
|
||||
None,
|
||||
f"Agent '{store_slug}' is not in your library. "
|
||||
"Add it first to see outputs.",
|
||||
)
|
||||
return agent, None
|
||||
|
||||
# Priority 3: Fuzzy name search in library
|
||||
if agent_name:
|
||||
try:
|
||||
response = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=agent_name,
|
||||
page_size=5,
|
||||
)
|
||||
if not response.agents:
|
||||
return (
|
||||
None,
|
||||
f"No agents matching '{agent_name}' found in your library",
|
||||
)
|
||||
|
||||
# Return best match (first result from search)
|
||||
return response.agents[0], None
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching library agents: {e}")
|
||||
return None, f"Error searching for agent: {e}"
|
||||
|
||||
return (
|
||||
None,
|
||||
"Please specify an agent name, library_agent_id, or store_slug",
|
||||
)
|
||||
|
||||
async def _get_execution(
|
||||
self,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
execution_id: str | None,
|
||||
time_start: datetime | None,
|
||||
time_end: datetime | None,
|
||||
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
|
||||
"""
|
||||
Fetch execution(s) based on filters.
|
||||
Returns (single_execution, available_executions_meta, error_message).
|
||||
"""
|
||||
# If specific execution_id provided, fetch it directly
|
||||
if execution_id:
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
if not execution:
|
||||
return None, [], f"Execution '{execution_id}' not found"
|
||||
return execution, [], None
|
||||
|
||||
# Get completed executions with time filters
|
||||
executions = await execution_db.get_graph_executions(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
statuses=[ExecutionStatus.COMPLETED],
|
||||
created_time_gte=time_start,
|
||||
created_time_lte=time_end,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
if not executions:
|
||||
return None, [], None # No error, just no executions
|
||||
|
||||
# If only one execution, fetch full details
|
||||
if len(executions) == 1:
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
return full_execution, [], None
|
||||
|
||||
# Multiple executions - return latest with full details, plus list of available
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
return full_execution, executions, None
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
agent: LibraryAgent,
|
||||
execution: GraphExecution | None,
|
||||
available_executions: list[GraphExecutionMeta],
|
||||
session_id: str | None,
|
||||
) -> AgentOutputResponse:
|
||||
"""Build the response based on execution data."""
|
||||
library_agent_link = f"/library/agents/{agent.id}"
|
||||
|
||||
if not execution:
|
||||
return AgentOutputResponse(
|
||||
message=f"No completed executions found for agent '{agent.name}'",
|
||||
session_id=session_id,
|
||||
agent_name=agent.name,
|
||||
agent_id=agent.graph_id,
|
||||
library_agent_id=agent.id,
|
||||
library_agent_link=library_agent_link,
|
||||
total_executions=0,
|
||||
)
|
||||
|
||||
execution_info = ExecutionOutputInfo(
|
||||
execution_id=execution.id,
|
||||
status=execution.status.value,
|
||||
started_at=execution.started_at,
|
||||
ended_at=execution.ended_at,
|
||||
outputs=dict(execution.outputs),
|
||||
inputs_summary=execution.inputs if execution.inputs else None,
|
||||
)
|
||||
|
||||
available_list = None
|
||||
if len(available_executions) > 1:
|
||||
available_list = [
|
||||
{
|
||||
"id": e.id,
|
||||
"status": e.status.value,
|
||||
"started_at": e.started_at.isoformat() if e.started_at else None,
|
||||
}
|
||||
for e in available_executions[:5]
|
||||
]
|
||||
|
||||
message = f"Found execution outputs for agent '{agent.name}'"
|
||||
if len(available_executions) > 1:
|
||||
message += (
|
||||
f". Showing latest of {len(available_executions)} matching executions."
|
||||
)
|
||||
|
||||
return AgentOutputResponse(
|
||||
message=message,
|
||||
session_id=session_id,
|
||||
agent_name=agent.name,
|
||||
agent_id=agent.graph_id,
|
||||
library_agent_id=agent.id,
|
||||
library_agent_link=library_agent_link,
|
||||
execution=execution_info,
|
||||
available_executions=available_list,
|
||||
total_executions=len(available_executions) if available_executions else 1,
|
||||
)
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the agent_output tool."""
|
||||
session_id = session.session_id
|
||||
|
||||
# Parse and validate input
|
||||
try:
|
||||
input_data = AgentOutputInput(**kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid input: {e}")
|
||||
return ErrorResponse(
|
||||
message="Invalid input parameters",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Ensure user_id is present (should be guaranteed by requires_auth)
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if at least one identifier is provided
|
||||
if not any(
|
||||
[
|
||||
input_data.agent_name,
|
||||
input_data.library_agent_id,
|
||||
input_data.store_slug,
|
||||
input_data.execution_id,
|
||||
]
|
||||
):
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Please specify at least one of: agent_name, "
|
||||
"library_agent_id, store_slug, or execution_id"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# If only execution_id provided, we need to find the agent differently
|
||||
if (
|
||||
input_data.execution_id
|
||||
and not input_data.agent_name
|
||||
and not input_data.library_agent_id
|
||||
and not input_data.store_slug
|
||||
):
|
||||
# Fetch execution directly to get graph_id
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=input_data.execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
if not execution:
|
||||
return ErrorResponse(
|
||||
message=f"Execution '{input_data.execution_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Find library agent by graph_id
|
||||
agent = await library_db.get_library_agent_by_graph_id(
|
||||
user_id, execution.graph_id
|
||||
)
|
||||
if not agent:
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"Execution found but agent not in your library. "
|
||||
f"Graph ID: {execution.graph_id}"
|
||||
),
|
||||
session_id=session_id,
|
||||
suggestions=["Add the agent to your library to see more details"],
|
||||
)
|
||||
|
||||
return self._build_response(agent, execution, [], session_id)
|
||||
|
||||
# Resolve agent from identifiers
|
||||
agent, error = await self._resolve_agent(
|
||||
user_id=user_id,
|
||||
agent_name=input_data.agent_name or None,
|
||||
library_agent_id=input_data.library_agent_id or None,
|
||||
store_slug=input_data.store_slug or None,
|
||||
)
|
||||
|
||||
if error or not agent:
|
||||
return NoResultsResponse(
|
||||
message=error or "Agent not found",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Check the agent name or ID",
|
||||
"Make sure the agent is in your library",
|
||||
],
|
||||
)
|
||||
|
||||
# Parse time expression
|
||||
time_start, time_end = parse_time_expression(input_data.run_time)
|
||||
|
||||
# Fetch execution(s)
|
||||
execution, available_executions, exec_error = await self._get_execution(
|
||||
user_id=user_id,
|
||||
graph_id=agent.graph_id,
|
||||
execution_id=input_data.execution_id or None,
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
)
|
||||
|
||||
if exec_error:
|
||||
return ErrorResponse(
|
||||
message=exec_error,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return self._build_response(agent, execution, available_executions, session_id)
|
||||
@@ -0,0 +1,151 @@
|
||||
"""Shared agent search functionality for find_agent and find_library_agent tools."""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .models import (
|
||||
AgentInfo,
|
||||
AgentsFoundResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SearchSource = Literal["marketplace", "library"]
|
||||
|
||||
|
||||
async def search_agents(
|
||||
query: str,
|
||||
source: SearchSource,
|
||||
session_id: str | None,
|
||||
user_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""
|
||||
Search for agents in marketplace or user library.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
source: "marketplace" or "library"
|
||||
session_id: Chat session ID
|
||||
user_id: User ID (required for library search)
|
||||
|
||||
Returns:
|
||||
AgentsFoundResponse, NoResultsResponse, or ErrorResponse
|
||||
"""
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query", session_id=session_id
|
||||
)
|
||||
|
||||
if source == "library" and not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required to search library",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agents: list[AgentInfo] = []
|
||||
try:
|
||||
if source == "marketplace":
|
||||
logger.info(f"Searching marketplace for: {query}")
|
||||
results = await store_db.get_store_agents(search_query=query, page_size=5)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=f"{agent.creator}/{agent.slug}",
|
||||
name=agent.agent_name,
|
||||
description=agent.description or "",
|
||||
source="marketplace",
|
||||
in_library=False,
|
||||
creator=agent.creator,
|
||||
category="general",
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
is_featured=False,
|
||||
)
|
||||
)
|
||||
else: # library
|
||||
logger.info(f"Searching user library for: {query}")
|
||||
results = await library_db.list_library_agents(
|
||||
user_id=user_id, # type: ignore[arg-type]
|
||||
search_term=query,
|
||||
page_size=10,
|
||||
)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
)
|
||||
)
|
||||
logger.info(f"Found {len(agents)} agents in {source}")
|
||||
except NotFoundError:
|
||||
pass
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Error searching {source}: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to search {source}. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agents:
|
||||
suggestions = (
|
||||
[
|
||||
"Try more general terms",
|
||||
"Browse categories in the marketplace",
|
||||
"Check spelling",
|
||||
]
|
||||
if source == "marketplace"
|
||||
else [
|
||||
"Try different keywords",
|
||||
"Use find_agent to search the marketplace",
|
||||
"Check your library at /library",
|
||||
]
|
||||
)
|
||||
no_results_msg = (
|
||||
f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
|
||||
if source == "marketplace"
|
||||
else f"No agents matching '{query}' found in your library."
|
||||
)
|
||||
return NoResultsResponse(
|
||||
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
||||
)
|
||||
|
||||
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
|
||||
title += (
|
||||
f"for '{query}'"
|
||||
if source == "marketplace"
|
||||
else f"in your library for '{query}'"
|
||||
)
|
||||
|
||||
message = (
|
||||
"Now you have found some options for the user to choose from. "
|
||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||
"Please ask the user if they would like to use any of these agents."
|
||||
if source == "marketplace"
|
||||
else "Found agents in the user's library. You can provide a link to view an agent at: "
|
||||
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
|
||||
)
|
||||
|
||||
return AgentsFoundResponse(
|
||||
message=message,
|
||||
title=title,
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -6,7 +6,7 @@ from typing import Any
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.response_model import StreamToolExecutionResult
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
|
||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||
|
||||
@@ -53,7 +53,7 @@ class BaseTool:
|
||||
session: ChatSession,
|
||||
tool_call_id: str,
|
||||
**kwargs,
|
||||
) -> StreamToolExecutionResult:
|
||||
) -> StreamToolOutputAvailable:
|
||||
"""Execute the tool with authentication check.
|
||||
|
||||
Args:
|
||||
@@ -69,10 +69,10 @@ class BaseTool:
|
||||
logger.error(
|
||||
f"Attempted tool call for {self.name} but user not authenticated"
|
||||
)
|
||||
return StreamToolExecutionResult(
|
||||
tool_id=tool_call_id,
|
||||
tool_name=self.name,
|
||||
result=NeedLoginResponse(
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=NeedLoginResponse(
|
||||
message=f"Please sign in to use {self.name}",
|
||||
session_id=session.session_id,
|
||||
).model_dump_json(),
|
||||
@@ -81,17 +81,17 @@ class BaseTool:
|
||||
|
||||
try:
|
||||
result = await self._execute(user_id, session, **kwargs)
|
||||
return StreamToolExecutionResult(
|
||||
tool_id=tool_call_id,
|
||||
tool_name=self.name,
|
||||
result=result.model_dump_json(),
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=result.model_dump_json(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in {self.name}: {e}", exc_info=True)
|
||||
return StreamToolExecutionResult(
|
||||
tool_id=tool_call_id,
|
||||
tool_name=self.name,
|
||||
result=ErrorResponse(
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=ErrorResponse(
|
||||
message=f"An error occurred while executing {self.name}",
|
||||
error=str(e),
|
||||
session_id=session.session_id,
|
||||
|
||||
@@ -1,26 +1,16 @@
|
||||
"""Tool for discovering agents from marketplace and user library."""
|
||||
"""Tool for discovering agents from marketplace."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .agent_search import search_agents
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentCarouselResponse,
|
||||
AgentInfo,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from .models import ToolResponseBase
|
||||
|
||||
|
||||
class FindAgentTool(BaseTool):
|
||||
"""Tool for discovering agents based on user needs."""
|
||||
"""Tool for discovering agents from the marketplace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -46,84 +36,11 @@ class FindAgentTool(BaseTool):
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Search for agents in the marketplace.
|
||||
|
||||
Args:
|
||||
user_id: User ID (may be anonymous)
|
||||
session_id: Chat session ID
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
AgentCarouselResponse: List of agents found in the marketplace
|
||||
NoResultsResponse: No agents found in the marketplace
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
)
|
||||
agents = []
|
||||
try:
|
||||
logger.info(f"Searching marketplace for: {query}")
|
||||
store_results = await store_db.get_store_agents(
|
||||
search_query=query,
|
||||
page_size=5,
|
||||
)
|
||||
|
||||
logger.info(f"Find agents tool found {len(store_results.agents)} agents")
|
||||
for agent in store_results.agents:
|
||||
agent_id = f"{agent.creator}/{agent.slug}"
|
||||
logger.info(f"Building agent ID = {agent_id}")
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent_id,
|
||||
name=agent.agent_name,
|
||||
description=agent.description or "",
|
||||
source="marketplace",
|
||||
in_library=False,
|
||||
creator=agent.creator,
|
||||
category="general",
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
is_featured=False,
|
||||
),
|
||||
)
|
||||
except NotFoundError:
|
||||
pass
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Error searching agents: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search for agents. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
if not agents:
|
||||
return NoResultsResponse(
|
||||
message=f"No agents found matching '{query}'. Try different keywords or browse the marketplace. If you have 3 consecutive find_agent tool calls results and found no agents. Please stop trying and ask the user if there is anything else you can help with.",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Browse categories in the marketplace",
|
||||
"Check spelling",
|
||||
],
|
||||
)
|
||||
|
||||
# Return formatted carousel
|
||||
title = (
|
||||
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'"
|
||||
)
|
||||
return AgentCarouselResponse(
|
||||
message="Now you have found some options for the user to choose from. You can add a link to a recommended agent at: /marketplace/agent/agent_id Please ask the user if they would like to use any of these agents. If they do, please call the get_agent_details tool for this agent.",
|
||||
title=title,
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
session_id=session_id,
|
||||
return await search_agents(
|
||||
query=kwargs.get("query", "").strip(),
|
||||
source="marketplace",
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Tool for searching agents in the user's library."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_search import search_agents
|
||||
from .base import BaseTool
|
||||
from .models import ToolResponseBase
|
||||
|
||||
|
||||
class FindLibraryAgentTool(BaseTool):
|
||||
"""Tool for searching agents in the user's library."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "find_library_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search for agents in the user's library. Use this to find agents "
|
||||
"the user has already added to their library, including agents they "
|
||||
"created or added from the marketplace."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query to find agents by name or description.",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
return await search_agents(
|
||||
query=kwargs.get("query", "").strip(),
|
||||
source="library",
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Pydantic models for tool responses."""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
@@ -11,14 +12,15 @@ from backend.data.model import CredentialsMetaInput
|
||||
class ResponseType(str, Enum):
|
||||
"""Types of tool responses."""
|
||||
|
||||
AGENT_CAROUSEL = "agent_carousel"
|
||||
AGENTS_FOUND = "agents_found"
|
||||
AGENT_DETAILS = "agent_details"
|
||||
SETUP_REQUIREMENTS = "setup_requirements"
|
||||
EXECUTION_STARTED = "execution_started"
|
||||
NEED_LOGIN = "need_login"
|
||||
ERROR = "error"
|
||||
NO_RESULTS = "no_results"
|
||||
SUCCESS = "success"
|
||||
AGENT_OUTPUT = "agent_output"
|
||||
UNDERSTANDING_UPDATED = "understanding_updated"
|
||||
|
||||
|
||||
# Base response model
|
||||
@@ -51,14 +53,14 @@ class AgentInfo(BaseModel):
|
||||
graph_id: str | None = None
|
||||
|
||||
|
||||
class AgentCarouselResponse(ToolResponseBase):
|
||||
class AgentsFoundResponse(ToolResponseBase):
|
||||
"""Response for find_agent tool."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_CAROUSEL
|
||||
type: ResponseType = ResponseType.AGENTS_FOUND
|
||||
title: str = "Available Agents"
|
||||
agents: list[AgentInfo]
|
||||
count: int
|
||||
name: str = "agent_carousel"
|
||||
name: str = "agents_found"
|
||||
|
||||
|
||||
class NoResultsResponse(ToolResponseBase):
|
||||
@@ -173,3 +175,37 @@ class ErrorResponse(ToolResponseBase):
|
||||
type: ResponseType = ResponseType.ERROR
|
||||
error: str | None = None
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# Agent output models
|
||||
class ExecutionOutputInfo(BaseModel):
|
||||
"""Summary of a single execution's outputs."""
|
||||
|
||||
execution_id: str
|
||||
status: str
|
||||
started_at: datetime | None = None
|
||||
ended_at: datetime | None = None
|
||||
outputs: dict[str, list[Any]]
|
||||
inputs_summary: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AgentOutputResponse(ToolResponseBase):
|
||||
"""Response for agent_output tool."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_OUTPUT
|
||||
agent_name: str
|
||||
agent_id: str
|
||||
library_agent_id: str | None = None
|
||||
library_agent_link: str | None = None
|
||||
execution: ExecutionOutputInfo | None = None
|
||||
available_executions: list[dict[str, Any]] | None = None
|
||||
total_executions: int = 0
|
||||
|
||||
|
||||
# Business understanding models
|
||||
class UnderstandingUpdatedResponse(ToolResponseBase):
|
||||
"""Response for add_understanding tool."""
|
||||
|
||||
type: ResponseType = ResponseType.UNDERSTANDING_UPDATED
|
||||
updated_fields: list[str] = Field(default_factory=list)
|
||||
current_understanding: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.api.features.chat.config import ChatConfig
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.user import get_user_by_id
|
||||
@@ -57,6 +58,7 @@ class RunAgentInput(BaseModel):
|
||||
"""Input parameters for the run_agent tool."""
|
||||
|
||||
username_agent_slug: str = ""
|
||||
library_agent_id: str = ""
|
||||
inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
use_defaults: bool = False
|
||||
schedule_name: str = ""
|
||||
@@ -64,7 +66,12 @@ class RunAgentInput(BaseModel):
|
||||
timezone: str = "UTC"
|
||||
|
||||
@field_validator(
|
||||
"username_agent_slug", "schedule_name", "cron", "timezone", mode="before"
|
||||
"username_agent_slug",
|
||||
"library_agent_id",
|
||||
"schedule_name",
|
||||
"cron",
|
||||
"timezone",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def strip_strings(cls, v: Any) -> Any:
|
||||
@@ -90,7 +97,7 @@ class RunAgentTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Run or schedule an agent from the marketplace.
|
||||
return """Run or schedule an agent from the marketplace or user's library.
|
||||
|
||||
The tool automatically handles the setup flow:
|
||||
- Returns missing inputs if required fields are not provided
|
||||
@@ -98,6 +105,10 @@ class RunAgentTool(BaseTool):
|
||||
- Executes immediately if all requirements are met
|
||||
- Schedules execution if cron expression is provided
|
||||
|
||||
Identify the agent using either:
|
||||
- username_agent_slug: Marketplace format 'username/agent-name'
|
||||
- library_agent_id: ID of an agent in the user's library
|
||||
|
||||
For scheduled execution, provide: schedule_name, cron, and optionally timezone."""
|
||||
|
||||
@property
|
||||
@@ -109,6 +120,10 @@ class RunAgentTool(BaseTool):
|
||||
"type": "string",
|
||||
"description": "Agent identifier in format 'username/agent-name'",
|
||||
},
|
||||
"library_agent_id": {
|
||||
"type": "string",
|
||||
"description": "Library agent ID from user's library",
|
||||
},
|
||||
"inputs": {
|
||||
"type": "object",
|
||||
"description": "Input values for the agent",
|
||||
@@ -131,7 +146,7 @@ class RunAgentTool(BaseTool):
|
||||
"description": "IANA timezone for schedule (default: UTC)",
|
||||
},
|
||||
},
|
||||
"required": ["username_agent_slug"],
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
@@ -149,10 +164,16 @@ class RunAgentTool(BaseTool):
|
||||
params = RunAgentInput(**kwargs)
|
||||
session_id = session.session_id
|
||||
|
||||
# Validate agent slug format
|
||||
if not params.username_agent_slug or "/" not in params.username_agent_slug:
|
||||
# Validate at least one identifier is provided
|
||||
has_slug = params.username_agent_slug and "/" in params.username_agent_slug
|
||||
has_library_id = bool(params.library_agent_id)
|
||||
|
||||
if not has_slug and not has_library_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide an agent slug in format 'username/agent-name'",
|
||||
message=(
|
||||
"Please provide either a username_agent_slug "
|
||||
"(format 'username/agent-name') or a library_agent_id"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -167,13 +188,41 @@ class RunAgentTool(BaseTool):
|
||||
is_schedule = bool(params.schedule_name or params.cron)
|
||||
|
||||
try:
|
||||
# Step 1: Fetch agent details (always happens first)
|
||||
username, agent_name = params.username_agent_slug.split("/", 1)
|
||||
graph, store_agent = await fetch_graph_from_store_slug(username, agent_name)
|
||||
# Step 1: Fetch agent details
|
||||
graph: GraphModel | None = None
|
||||
library_agent = None
|
||||
|
||||
# Priority: library_agent_id if provided
|
||||
if has_library_id:
|
||||
library_agent = await library_db.get_library_agent(
|
||||
params.library_agent_id, user_id
|
||||
)
|
||||
if not library_agent:
|
||||
return ErrorResponse(
|
||||
message=f"Library agent '{params.library_agent_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
# Get the graph from the library agent
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
graph = await get_graph(
|
||||
library_agent.graph_id,
|
||||
library_agent.graph_version,
|
||||
user_id=user_id,
|
||||
)
|
||||
else:
|
||||
# Fetch from marketplace slug
|
||||
username, agent_name = params.username_agent_slug.split("/", 1)
|
||||
graph, _ = await fetch_graph_from_store_slug(username, agent_name)
|
||||
|
||||
if not graph:
|
||||
identifier = (
|
||||
params.library_agent_id
|
||||
if has_library_id
|
||||
else params.username_agent_slug
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=f"Agent '{params.username_agent_slug}' not found in marketplace",
|
||||
message=f"Agent '{identifier}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
@@ -17,6 +18,17 @@ setup_test_data = setup_test_data
|
||||
setup_firecrawl_test_data = setup_firecrawl_test_data
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def mock_embedding_functions():
|
||||
"""Mock embedding functions for all tests to avoid database/API dependencies."""
|
||||
with patch(
|
||||
"backend.api.features.store.db.ensure_embedding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_run_agent(setup_test_data):
|
||||
"""Test that the run_agent tool successfully executes an approved agent"""
|
||||
@@ -46,11 +58,11 @@ async def test_run_agent(setup_test_data):
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert hasattr(response, "output")
|
||||
# Parse the result JSON to verify the execution started
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
assert "execution_id" in result_data
|
||||
assert "graph_id" in result_data
|
||||
assert result_data["graph_id"] == graph.id
|
||||
@@ -86,11 +98,11 @@ async def test_run_agent_missing_inputs(setup_test_data):
|
||||
|
||||
# Verify that we get an error response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert hasattr(response, "output")
|
||||
# The tool should return an ErrorResponse when setup info indicates not ready
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
assert "message" in result_data
|
||||
|
||||
|
||||
@@ -118,10 +130,10 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
|
||||
|
||||
# Verify that we get an error response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert hasattr(response, "output")
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
assert "message" in result_data
|
||||
# Should get an error about failed setup or not found
|
||||
assert any(
|
||||
@@ -158,12 +170,12 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert hasattr(response, "output")
|
||||
|
||||
# Parse the result JSON to verify the execution started
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should successfully start execution since credentials are available
|
||||
assert "execution_id" in result_data
|
||||
@@ -195,9 +207,9 @@ async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_da
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return agent_details type showing available inputs
|
||||
assert result_data.get("type") == "agent_details"
|
||||
@@ -230,9 +242,9 @@ async def test_run_agent_with_use_defaults(setup_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should execute successfully
|
||||
assert "execution_id" in result_data
|
||||
@@ -260,9 +272,9 @@ async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return setup_requirements type with missing credentials
|
||||
assert result_data.get("type") == "setup_requirements"
|
||||
@@ -292,9 +304,9 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return error
|
||||
assert result_data.get("type") == "error"
|
||||
@@ -305,9 +317,10 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
|
||||
async def test_run_agent_unauthenticated():
|
||||
"""Test that run_agent returns need_login for unauthenticated users."""
|
||||
tool = RunAgentTool()
|
||||
session = make_session(user_id=None)
|
||||
# Session has a user_id (session owner), but we test tool execution without user_id
|
||||
session = make_session(user_id="test-session-owner")
|
||||
|
||||
# Execute without user_id
|
||||
# Execute without user_id to test unauthenticated behavior
|
||||
response = await tool.execute(
|
||||
user_id=None,
|
||||
session_id=str(uuid.uuid4()),
|
||||
@@ -318,9 +331,9 @@ async def test_run_agent_unauthenticated():
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Base tool returns need_login type for unauthenticated users
|
||||
assert result_data.get("type") == "need_login"
|
||||
@@ -350,9 +363,9 @@ async def test_run_agent_schedule_without_cron(setup_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return error about missing cron
|
||||
assert result_data.get("type") == "error"
|
||||
@@ -382,9 +395,9 @@ async def test_run_agent_schedule_without_name(setup_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return error about missing schedule_name
|
||||
assert result_data.get("type") == "error"
|
||||
|
||||
@@ -35,11 +35,7 @@ from backend.data.model import (
|
||||
OAuth2Credentials,
|
||||
UserIntegrations,
|
||||
)
|
||||
from backend.data.onboarding import (
|
||||
OnboardingStep,
|
||||
complete_onboarding_step,
|
||||
increment_runs,
|
||||
)
|
||||
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
||||
from backend.data.user import get_user_integrations
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||
@@ -175,6 +171,7 @@ async def callback(
|
||||
f"Successfully processed OAuth callback for user {user_id} "
|
||||
f"and provider {provider.value}"
|
||||
)
|
||||
|
||||
return CredentialsMetaResponse(
|
||||
id=credentials.id,
|
||||
provider=credentials.provider,
|
||||
@@ -193,6 +190,7 @@ async def list_credentials(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
credentials = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
return [
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
@@ -215,6 +213,7 @@ async def list_credentials_by_provider(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
|
||||
|
||||
return [
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
@@ -378,7 +377,6 @@ async def webhook_ingress_generic(
|
||||
return
|
||||
|
||||
await complete_onboarding_step(user_id, OnboardingStep.TRIGGER_WEBHOOK)
|
||||
await increment_runs(user_id)
|
||||
|
||||
# Execute all triggers concurrently for better performance
|
||||
tasks = []
|
||||
@@ -831,6 +829,18 @@ async def list_providers() -> List[str]:
|
||||
return all_providers
|
||||
|
||||
|
||||
@router.get("/providers/system", response_model=List[str])
|
||||
async def list_system_providers() -> List[str]:
|
||||
"""
|
||||
Get a list of providers that have platform credits (system credentials) available.
|
||||
|
||||
These providers can be used without the user providing their own API keys.
|
||||
"""
|
||||
from backend.integrations.credentials_store import SYSTEM_PROVIDERS
|
||||
|
||||
return list(SYSTEM_PROVIDERS)
|
||||
|
||||
|
||||
@router.get("/providers/names", response_model=ProviderNamesResponse)
|
||||
async def get_provider_names() -> ProviderNamesResponse:
|
||||
"""
|
||||
|
||||
@@ -8,7 +8,6 @@ from backend.data.execution import GraphExecutionMeta
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.integrations import get_webhook
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.onboarding import increment_runs
|
||||
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
@@ -403,8 +402,6 @@ async def execute_preset(
|
||||
merged_node_input = preset.inputs | inputs
|
||||
merged_credential_inputs = preset.credentials | credential_inputs
|
||||
|
||||
await increment_runs(user_id)
|
||||
|
||||
return await add_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=preset.graph_id,
|
||||
|
||||
@@ -0,0 +1,431 @@
|
||||
"""
|
||||
Content Type Handlers for Unified Embeddings
|
||||
|
||||
Pluggable system for different content sources (store agents, blocks, docs).
|
||||
Each handler knows how to fetch and process its content type for embedding.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.data.db import query_raw_with_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentItem:
|
||||
"""Represents a piece of content to be embedded."""
|
||||
|
||||
content_id: str # Unique identifier (DB ID or file path)
|
||||
content_type: ContentType
|
||||
searchable_text: str # Combined text for embedding
|
||||
metadata: dict[str, Any] # Content-specific metadata
|
||||
user_id: str | None = None # For user-scoped content
|
||||
|
||||
|
||||
class ContentHandler(ABC):
|
||||
"""Base handler for fetching and processing content for embeddings."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def content_type(self) -> ContentType:
|
||||
"""The ContentType this handler manages."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""
|
||||
Fetch items that don't have embeddings yet.
|
||||
|
||||
Args:
|
||||
batch_size: Maximum number of items to return
|
||||
|
||||
Returns:
|
||||
List of ContentItem objects ready for embedding
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""
|
||||
Get statistics about embedding coverage.
|
||||
|
||||
Returns:
|
||||
Dict with keys: total, with_embeddings, without_embeddings
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StoreAgentHandler(ContentHandler):
|
||||
"""Handler for marketplace store agent listings."""
|
||||
|
||||
@property
|
||||
def content_type(self) -> ContentType:
|
||||
return ContentType.STORE_AGENT
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch approved store listings without embeddings."""
|
||||
from backend.api.features.store.embeddings import build_searchable_text
|
||||
|
||||
missing = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT
|
||||
slv.id,
|
||||
slv.name,
|
||||
slv.description,
|
||||
slv."subHeading",
|
||||
slv.categories
|
||||
FROM {schema_prefix}"StoreListingVersion" slv
|
||||
LEFT JOIN {schema_prefix}"UnifiedContentEmbedding" uce
|
||||
ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
AND uce."contentId" IS NULL
|
||||
LIMIT $1
|
||||
""",
|
||||
batch_size,
|
||||
)
|
||||
|
||||
return [
|
||||
ContentItem(
|
||||
content_id=row["id"],
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
searchable_text=build_searchable_text(
|
||||
name=row["name"],
|
||||
description=row["description"],
|
||||
sub_heading=row["subHeading"],
|
||||
categories=row["categories"] or [],
|
||||
),
|
||||
metadata={
|
||||
"name": row["name"],
|
||||
"categories": row["categories"] or [],
|
||||
},
|
||||
user_id=None, # Store agents are public
|
||||
)
|
||||
for row in missing
|
||||
]
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about store agent embedding coverage."""
|
||||
# Count approved versions
|
||||
approved_result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {schema_prefix}"StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
AND "isDeleted" = false
|
||||
"""
|
||||
)
|
||||
total_approved = approved_result[0]["count"] if approved_result else 0
|
||||
|
||||
# Count versions with embeddings
|
||||
embedded_result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {schema_prefix}"StoreListingVersion" slv
|
||||
JOIN {schema_prefix}"UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
"""
|
||||
)
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total": total_approved,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_approved - with_embeddings,
|
||||
}
|
||||
|
||||
|
||||
class BlockHandler(ContentHandler):
|
||||
"""Handler for block definitions (Python classes)."""
|
||||
|
||||
@property
|
||||
def content_type(self) -> ContentType:
|
||||
return ContentType.BLOCK
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch blocks without embeddings."""
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
# Get all available blocks
|
||||
all_blocks = get_blocks()
|
||||
|
||||
# Check which ones have embeddings
|
||||
if not all_blocks:
|
||||
return []
|
||||
|
||||
block_ids = list(all_blocks.keys())
|
||||
|
||||
# Query for existing embeddings
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
existing_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT "contentId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*block_ids,
|
||||
)
|
||||
|
||||
existing_ids = {row["contentId"] for row in existing_result}
|
||||
missing_blocks = [
|
||||
(block_id, block_cls)
|
||||
for block_id, block_cls in all_blocks.items()
|
||||
if block_id not in existing_ids
|
||||
]
|
||||
|
||||
# Convert to ContentItem
|
||||
items = []
|
||||
for block_id, block_cls in missing_blocks[:batch_size]:
|
||||
try:
|
||||
block_instance = block_cls()
|
||||
|
||||
# Build searchable text from block metadata
|
||||
parts = []
|
||||
if hasattr(block_instance, "name") and block_instance.name:
|
||||
parts.append(block_instance.name)
|
||||
if (
|
||||
hasattr(block_instance, "description")
|
||||
and block_instance.description
|
||||
):
|
||||
parts.append(block_instance.description)
|
||||
if hasattr(block_instance, "categories") and block_instance.categories:
|
||||
# Convert BlockCategory enum to strings
|
||||
parts.append(
|
||||
" ".join(str(cat.value) for cat in block_instance.categories)
|
||||
)
|
||||
|
||||
# Add input/output schema info
|
||||
if hasattr(block_instance, "input_schema"):
|
||||
schema = block_instance.input_schema
|
||||
if hasattr(schema, "model_json_schema"):
|
||||
schema_dict = schema.model_json_schema()
|
||||
if "properties" in schema_dict:
|
||||
for prop_name, prop_info in schema_dict[
|
||||
"properties"
|
||||
].items():
|
||||
if "description" in prop_info:
|
||||
parts.append(
|
||||
f"{prop_name}: {prop_info['description']}"
|
||||
)
|
||||
|
||||
searchable_text = " ".join(parts)
|
||||
|
||||
# Convert categories set of enums to list of strings for JSON serialization
|
||||
categories = getattr(block_instance, "categories", set())
|
||||
categories_list = (
|
||||
[cat.value for cat in categories] if categories else []
|
||||
)
|
||||
|
||||
items.append(
|
||||
ContentItem(
|
||||
content_id=block_id,
|
||||
content_type=ContentType.BLOCK,
|
||||
searchable_text=searchable_text,
|
||||
metadata={
|
||||
"name": getattr(block_instance, "name", ""),
|
||||
"categories": categories_list,
|
||||
},
|
||||
user_id=None, # Blocks are public
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process block {block_id}: {e}")
|
||||
continue
|
||||
|
||||
return items
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about block embedding coverage."""
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
all_blocks = get_blocks()
|
||||
total_blocks = len(all_blocks)
|
||||
|
||||
if total_blocks == 0:
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
block_ids = list(all_blocks.keys())
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
|
||||
embedded_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*block_ids,
|
||||
)
|
||||
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total": total_blocks,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_blocks - with_embeddings,
|
||||
}
|
||||
|
||||
|
||||
class DocumentationHandler(ContentHandler):
|
||||
"""Handler for documentation files (.md/.mdx)."""
|
||||
|
||||
@property
|
||||
def content_type(self) -> ContentType:
|
||||
return ContentType.DOCUMENTATION
|
||||
|
||||
def _get_docs_root(self) -> Path:
|
||||
"""Get the documentation root directory."""
|
||||
# content_handlers.py is at: backend/backend/api/features/store/content_handlers.py
|
||||
# Need to go up to project root then into docs/
|
||||
# In container: /app/autogpt_platform/backend/backend/api/features/store -> /app/docs
|
||||
# In development: /repo/autogpt_platform/backend/backend/api/features/store -> /repo/docs
|
||||
this_file = Path(
|
||||
__file__
|
||||
) # .../backend/backend/api/features/store/content_handlers.py
|
||||
project_root = (
|
||||
this_file.parent.parent.parent.parent.parent.parent.parent
|
||||
) # -> /app or /repo
|
||||
docs_root = project_root / "docs"
|
||||
return docs_root
|
||||
|
||||
def _extract_title_and_content(self, file_path: Path) -> tuple[str, str]:
|
||||
"""Extract title and content from markdown file."""
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
|
||||
# Try to extract title from first # heading
|
||||
lines = content.split("\n")
|
||||
title = ""
|
||||
body_lines = []
|
||||
|
||||
for line in lines:
|
||||
if line.startswith("# ") and not title:
|
||||
title = line[2:].strip()
|
||||
else:
|
||||
body_lines.append(line)
|
||||
|
||||
# If no title found, use filename
|
||||
if not title:
|
||||
title = file_path.stem.replace("-", " ").replace("_", " ").title()
|
||||
|
||||
body = "\n".join(body_lines)
|
||||
|
||||
return title, body
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read {file_path}: {e}")
|
||||
return file_path.stem, ""
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch documentation files without embeddings."""
|
||||
docs_root = self._get_docs_root()
|
||||
|
||||
if not docs_root.exists():
|
||||
logger.warning(f"Documentation root not found: {docs_root}")
|
||||
return []
|
||||
|
||||
# Find all .md and .mdx files
|
||||
all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
|
||||
|
||||
# Get relative paths for content IDs
|
||||
doc_paths = [str(doc.relative_to(docs_root)) for doc in all_docs]
|
||||
|
||||
if not doc_paths:
|
||||
return []
|
||||
|
||||
# Check which ones have embeddings
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(doc_paths))])
|
||||
existing_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT "contentId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'DOCUMENTATION'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*doc_paths,
|
||||
)
|
||||
|
||||
existing_ids = {row["contentId"] for row in existing_result}
|
||||
missing_docs = [
|
||||
(doc_path, doc_file)
|
||||
for doc_path, doc_file in zip(doc_paths, all_docs)
|
||||
if doc_path not in existing_ids
|
||||
]
|
||||
|
||||
# Convert to ContentItem
|
||||
items = []
|
||||
for doc_path, doc_file in missing_docs[:batch_size]:
|
||||
try:
|
||||
title, content = self._extract_title_and_content(doc_file)
|
||||
|
||||
# Build searchable text
|
||||
searchable_text = f"{title} {content}"
|
||||
|
||||
items.append(
|
||||
ContentItem(
|
||||
content_id=doc_path,
|
||||
content_type=ContentType.DOCUMENTATION,
|
||||
searchable_text=searchable_text,
|
||||
metadata={
|
||||
"title": title,
|
||||
"path": doc_path,
|
||||
},
|
||||
user_id=None, # Documentation is public
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process doc {doc_path}: {e}")
|
||||
continue
|
||||
|
||||
return items
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about documentation embedding coverage."""
|
||||
docs_root = self._get_docs_root()
|
||||
|
||||
if not docs_root.exists():
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
# Count all .md and .mdx files
|
||||
all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
|
||||
total_docs = len(all_docs)
|
||||
|
||||
if total_docs == 0:
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
doc_paths = [str(doc.relative_to(docs_root)) for doc in all_docs]
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(doc_paths))])
|
||||
|
||||
embedded_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'DOCUMENTATION'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*doc_paths,
|
||||
)
|
||||
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total": total_docs,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_docs - with_embeddings,
|
||||
}
|
||||
|
||||
|
||||
# Content handler registry
|
||||
CONTENT_HANDLERS: dict[ContentType, ContentHandler] = {
|
||||
ContentType.STORE_AGENT: StoreAgentHandler(),
|
||||
ContentType.BLOCK: BlockHandler(),
|
||||
ContentType.DOCUMENTATION: DocumentationHandler(),
|
||||
}
|
||||
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
Integration tests for content handlers using real DB.
|
||||
|
||||
Run with: poetry run pytest backend/api/features/store/content_handlers_integration_test.py -xvs
|
||||
|
||||
These tests use the real database but mock OpenAI calls.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.store.content_handlers import (
|
||||
CONTENT_HANDLERS,
|
||||
BlockHandler,
|
||||
DocumentationHandler,
|
||||
StoreAgentHandler,
|
||||
)
|
||||
from backend.api.features.store.embeddings import (
|
||||
EMBEDDING_DIM,
|
||||
backfill_all_content_types,
|
||||
ensure_content_embedding,
|
||||
get_embedding_stats,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_real_db():
|
||||
"""Test StoreAgentHandler with real database queries."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Get stats from real DB
|
||||
stats = await handler.get_stats()
|
||||
|
||||
# Stats should have correct structure
|
||||
assert "total" in stats
|
||||
assert "with_embeddings" in stats
|
||||
assert "without_embeddings" in stats
|
||||
assert stats["total"] >= 0
|
||||
assert stats["with_embeddings"] >= 0
|
||||
assert stats["without_embeddings"] >= 0
|
||||
|
||||
# Get missing items (max 1 to keep test fast)
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
# Items should be list (may be empty if all have embeddings)
|
||||
assert isinstance(items, list)
|
||||
|
||||
if items:
|
||||
item = items[0]
|
||||
assert item.content_id is not None
|
||||
assert item.content_type.value == "STORE_AGENT"
|
||||
assert item.searchable_text != ""
|
||||
assert item.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_real_db():
|
||||
"""Test BlockHandler with real database queries."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Get stats from real DB
|
||||
stats = await handler.get_stats()
|
||||
|
||||
# Stats should have correct structure
|
||||
assert "total" in stats
|
||||
assert "with_embeddings" in stats
|
||||
assert "without_embeddings" in stats
|
||||
assert stats["total"] >= 0 # Should have at least some blocks
|
||||
assert stats["with_embeddings"] >= 0
|
||||
assert stats["without_embeddings"] >= 0
|
||||
|
||||
# Get missing items (max 1 to keep test fast)
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
# Items should be list
|
||||
assert isinstance(items, list)
|
||||
|
||||
if items:
|
||||
item = items[0]
|
||||
assert item.content_id is not None # Should be block UUID
|
||||
assert item.content_type.value == "BLOCK"
|
||||
assert item.searchable_text != ""
|
||||
assert item.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_real_fs():
|
||||
"""Test DocumentationHandler with real filesystem."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Get stats from real filesystem
|
||||
stats = await handler.get_stats()
|
||||
|
||||
# Stats should have correct structure
|
||||
assert "total" in stats
|
||||
assert "with_embeddings" in stats
|
||||
assert "without_embeddings" in stats
|
||||
assert stats["total"] >= 0
|
||||
assert stats["with_embeddings"] >= 0
|
||||
assert stats["without_embeddings"] >= 0
|
||||
|
||||
# Get missing items (max 1 to keep test fast)
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
# Items should be list
|
||||
assert isinstance(items, list)
|
||||
|
||||
if items:
|
||||
item = items[0]
|
||||
assert item.content_id is not None # Should be relative path
|
||||
assert item.content_type.value == "DOCUMENTATION"
|
||||
assert item.searchable_text != ""
|
||||
assert item.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_embedding_stats_all_types():
|
||||
"""Test get_embedding_stats aggregates all content types."""
|
||||
stats = await get_embedding_stats()
|
||||
|
||||
# Should have structure with by_type and totals
|
||||
assert "by_type" in stats
|
||||
assert "totals" in stats
|
||||
|
||||
# Check each content type is present
|
||||
by_type = stats["by_type"]
|
||||
assert "STORE_AGENT" in by_type
|
||||
assert "BLOCK" in by_type
|
||||
assert "DOCUMENTATION" in by_type
|
||||
|
||||
# Check totals are aggregated
|
||||
totals = stats["totals"]
|
||||
assert totals["total"] >= 0
|
||||
assert totals["with_embeddings"] >= 0
|
||||
assert totals["without_embeddings"] >= 0
|
||||
assert "coverage_percent" in totals
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
async def test_ensure_content_embedding_blocks(mock_generate):
|
||||
"""Test creating embeddings for blocks (mocked OpenAI)."""
|
||||
# Mock OpenAI to return fake embedding
|
||||
mock_generate.return_value = [0.1] * EMBEDDING_DIM
|
||||
|
||||
# Get one block without embedding
|
||||
handler = BlockHandler()
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
if not items:
|
||||
pytest.skip("No blocks without embeddings")
|
||||
|
||||
item = items[0]
|
||||
|
||||
# Try to create embedding (OpenAI mocked)
|
||||
result = await ensure_content_embedding(
|
||||
content_type=item.content_type,
|
||||
content_id=item.content_id,
|
||||
searchable_text=item.searchable_text,
|
||||
metadata=item.metadata,
|
||||
user_id=item.user_id,
|
||||
)
|
||||
|
||||
# Should succeed with mocked OpenAI
|
||||
assert result is True
|
||||
mock_generate.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
async def test_backfill_all_content_types_dry_run(mock_generate):
|
||||
"""Test backfill_all_content_types processes all handlers in order."""
|
||||
# Mock OpenAI to return fake embedding
|
||||
mock_generate.return_value = [0.1] * EMBEDDING_DIM
|
||||
|
||||
# Run backfill with batch_size=1 to process max 1 per type
|
||||
result = await backfill_all_content_types(batch_size=1)
|
||||
|
||||
# Should have results for all content types
|
||||
assert "by_type" in result
|
||||
assert "totals" in result
|
||||
|
||||
by_type = result["by_type"]
|
||||
assert "BLOCK" in by_type
|
||||
assert "STORE_AGENT" in by_type
|
||||
assert "DOCUMENTATION" in by_type
|
||||
|
||||
# Each type should have correct structure
|
||||
for content_type, type_result in by_type.items():
|
||||
assert "processed" in type_result
|
||||
assert "success" in type_result
|
||||
assert "failed" in type_result
|
||||
|
||||
# Totals should aggregate
|
||||
totals = result["totals"]
|
||||
assert totals["processed"] >= 0
|
||||
assert totals["success"] >= 0
|
||||
assert totals["failed"] >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_content_handler_registry():
|
||||
"""Test all handlers are registered in correct order."""
|
||||
from prisma.enums import ContentType
|
||||
|
||||
# All three types should be registered
|
||||
assert ContentType.STORE_AGENT in CONTENT_HANDLERS
|
||||
assert ContentType.BLOCK in CONTENT_HANDLERS
|
||||
assert ContentType.DOCUMENTATION in CONTENT_HANDLERS
|
||||
|
||||
# Check handler types
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
|
||||
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
E2E tests for content handlers (blocks, store agents, documentation).
|
||||
|
||||
Tests the full flow: discovering content → generating embeddings → storing.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store.content_handlers import (
|
||||
CONTENT_HANDLERS,
|
||||
BlockHandler,
|
||||
DocumentationHandler,
|
||||
StoreAgentHandler,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_get_missing_items(mocker):
|
||||
"""Test StoreAgentHandler fetches approved agents without embeddings."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock database query
|
||||
mock_missing = [
|
||||
{
|
||||
"id": "agent-1",
|
||||
"name": "Test Agent",
|
||||
"description": "A test agent",
|
||||
"subHeading": "Test heading",
|
||||
"categories": ["AI", "Testing"],
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_missing,
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "agent-1"
|
||||
assert items[0].content_type == ContentType.STORE_AGENT
|
||||
assert "Test Agent" in items[0].searchable_text
|
||||
assert "A test agent" in items[0].searchable_text
|
||||
assert items[0].metadata["name"] == "Test Agent"
|
||||
assert items[0].user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_get_stats(mocker):
|
||||
"""Test StoreAgentHandler returns correct stats."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock approved count query
|
||||
mock_approved = [{"count": 50}]
|
||||
# Mock embedded count query
|
||||
mock_embedded = [{"count": 30}]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
side_effect=[mock_approved, mock_embedded],
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 50
|
||||
assert stats["with_embeddings"] == 30
|
||||
assert stats["without_embeddings"] == 20
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_missing_items(mocker):
|
||||
"""Test BlockHandler discovers blocks without embeddings."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock get_blocks to return test blocks
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Calculator Block"
|
||||
mock_block_instance.description = "Performs calculations"
|
||||
mock_block_instance.categories = [MagicMock(value="MATH")]
|
||||
mock_block_instance.input_schema.model_json_schema.return_value = {
|
||||
"properties": {"expression": {"description": "Math expression to evaluate"}}
|
||||
}
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-uuid-1": mock_block_class}
|
||||
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
mock_existing = []
|
||||
|
||||
with patch(
|
||||
"backend.data.block.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_existing,
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "block-uuid-1"
|
||||
assert items[0].content_type == ContentType.BLOCK
|
||||
assert "Calculator Block" in items[0].searchable_text
|
||||
assert "Performs calculations" in items[0].searchable_text
|
||||
assert "MATH" in items[0].searchable_text
|
||||
assert "expression: Math expression" in items[0].searchable_text
|
||||
assert items[0].user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_stats(mocker):
|
||||
"""Test BlockHandler returns correct stats."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock get_blocks
|
||||
mock_blocks = {
|
||||
"block-1": MagicMock(),
|
||||
"block-2": MagicMock(),
|
||||
"block-3": MagicMock(),
|
||||
}
|
||||
|
||||
# Mock embedded count query (2 blocks have embeddings)
|
||||
mock_embedded = [{"count": 2}]
|
||||
|
||||
with patch(
|
||||
"backend.data.block.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_embedded,
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 3
|
||||
assert stats["with_embeddings"] == 2
|
||||
assert stats["without_embeddings"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
"""Test DocumentationHandler discovers docs without embeddings."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory with test files
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
|
||||
(docs_root / "guide.md").write_text("# Getting Started\n\nThis is a guide.")
|
||||
(docs_root / "api.mdx").write_text("# API Reference\n\nAPI documentation.")
|
||||
|
||||
# Mock _get_docs_root to return temp dir
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 2
|
||||
|
||||
# Check guide.md
|
||||
guide_item = next(
|
||||
(item for item in items if item.content_id == "guide.md"), None
|
||||
)
|
||||
assert guide_item is not None
|
||||
assert guide_item.content_type == ContentType.DOCUMENTATION
|
||||
assert "Getting Started" in guide_item.searchable_text
|
||||
assert "This is a guide" in guide_item.searchable_text
|
||||
assert guide_item.metadata["title"] == "Getting Started"
|
||||
assert guide_item.user_id is None
|
||||
|
||||
# Check api.mdx
|
||||
api_item = next(
|
||||
(item for item in items if item.content_id == "api.mdx"), None
|
||||
)
|
||||
assert api_item is not None
|
||||
assert "API Reference" in api_item.searchable_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_get_stats(tmp_path, mocker):
|
||||
"""Test DocumentationHandler returns correct stats."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
(docs_root / "doc1.md").write_text("# Doc 1")
|
||||
(docs_root / "doc2.md").write_text("# Doc 2")
|
||||
(docs_root / "doc3.mdx").write_text("# Doc 3")
|
||||
|
||||
# Mock embedded count query (1 doc has embedding)
|
||||
mock_embedded = [{"count": 1}]
|
||||
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_embedded,
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 3
|
||||
assert stats["with_embeddings"] == 1
|
||||
assert stats["without_embeddings"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_title_extraction(tmp_path):
|
||||
"""Test DocumentationHandler extracts title from markdown heading."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test with heading
|
||||
doc_with_heading = tmp_path / "with_heading.md"
|
||||
doc_with_heading.write_text("# My Title\n\nContent here")
|
||||
title, content = handler._extract_title_and_content(doc_with_heading)
|
||||
assert title == "My Title"
|
||||
assert "# My Title" not in content
|
||||
assert "Content here" in content
|
||||
|
||||
# Test without heading
|
||||
doc_without_heading = tmp_path / "no-heading.md"
|
||||
doc_without_heading.write_text("Just content, no heading")
|
||||
title, content = handler._extract_title_and_content(doc_without_heading)
|
||||
assert title == "No Heading" # Uses filename
|
||||
assert "Just content" in content
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_content_handlers_registry():
|
||||
"""Test all content types are registered."""
|
||||
assert ContentType.STORE_AGENT in CONTENT_HANDLERS
|
||||
assert ContentType.BLOCK in CONTENT_HANDLERS
|
||||
assert ContentType.DOCUMENTATION in CONTENT_HANDLERS
|
||||
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_missing_attributes():
|
||||
"""Test BlockHandler gracefully handles blocks with missing attributes."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock block with minimal attributes
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Minimal Block"
|
||||
# No description, categories, or schema
|
||||
del mock_block_instance.description
|
||||
del mock_block_instance.categories
|
||||
del mock_block_instance.input_schema
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-minimal": mock_block_class}
|
||||
|
||||
with patch(
|
||||
"backend.data.block.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].searchable_text == "Minimal Block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_skips_failed_blocks():
|
||||
"""Test BlockHandler skips blocks that fail to instantiate."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock one good block and one bad block
|
||||
good_block = MagicMock()
|
||||
good_instance = MagicMock()
|
||||
good_instance.name = "Good Block"
|
||||
good_instance.description = "Works fine"
|
||||
good_instance.categories = []
|
||||
good_block.return_value = good_instance
|
||||
|
||||
bad_block = MagicMock()
|
||||
bad_block.side_effect = Exception("Instantiation failed")
|
||||
|
||||
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
|
||||
|
||||
with patch(
|
||||
"backend.data.block.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
# Should only get the good block
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "good-block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_missing_docs_directory():
|
||||
"""Test DocumentationHandler handles missing docs directory gracefully."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Mock _get_docs_root to return non-existent path
|
||||
fake_path = Path("/nonexistent/docs")
|
||||
with patch.object(handler, "_get_docs_root", return_value=fake_path):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
assert items == []
|
||||
|
||||
stats = await handler.get_stats()
|
||||
assert stats["total"] == 0
|
||||
assert stats["with_embeddings"] == 0
|
||||
assert stats["without_embeddings"] == 0
|
||||
@@ -1,8 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import fastapi
|
||||
import prisma.enums
|
||||
@@ -10,7 +9,7 @@ import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
from backend.data.db import query_raw_with_schema, transaction
|
||||
from backend.data.db import transaction
|
||||
from backend.data.graph import (
|
||||
GraphMeta,
|
||||
GraphModel,
|
||||
@@ -30,6 +29,8 @@ from backend.util.settings import Settings
|
||||
|
||||
from . import exceptions as store_exceptions
|
||||
from . import model as store_model
|
||||
from .embeddings import ensure_embedding
|
||||
from .hybrid_search import hybrid_search
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
@@ -50,128 +51,77 @@ async def get_store_agents(
|
||||
page_size: int = 20,
|
||||
) -> store_model.StoreAgentsResponse:
|
||||
"""
|
||||
Get PUBLIC store agents from the StoreAgent view
|
||||
Get PUBLIC store agents from the StoreAgent view.
|
||||
|
||||
Search behavior:
|
||||
- With search_query: Uses hybrid search (semantic + lexical)
|
||||
- Fallback: If embeddings unavailable, gracefully degrades to lexical-only
|
||||
- Rationale: User-facing endpoint prioritizes availability over accuracy
|
||||
|
||||
Note: Admin operations (approval) use fail-fast to prevent inconsistent state.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
||||
)
|
||||
|
||||
search_used_hybrid = False
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
agents: list[dict[str, Any]] = []
|
||||
total = 0
|
||||
total_pages = 0
|
||||
|
||||
try:
|
||||
# If search_query is provided, use full-text search
|
||||
# If search_query is provided, use hybrid search (embeddings + tsvector)
|
||||
if search_query:
|
||||
offset = (page - 1) * page_size
|
||||
# Try hybrid search combining semantic and lexical signals
|
||||
# Falls back to lexical-only if OpenAI unavailable (user-facing, high SLA)
|
||||
try:
|
||||
agents, total = await hybrid_search(
|
||||
query=search_query,
|
||||
featured=featured,
|
||||
creators=creators,
|
||||
category=category,
|
||||
sorted_by="relevance", # Use hybrid scoring for relevance
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
search_used_hybrid = True
|
||||
except Exception as e:
|
||||
# Log error but fall back to lexical search for better UX
|
||||
logger.error(
|
||||
f"Hybrid search failed (likely OpenAI unavailable), "
|
||||
f"falling back to lexical search: {e}"
|
||||
)
|
||||
# search_used_hybrid remains False, will use fallback path below
|
||||
|
||||
# Whitelist allowed order_by columns
|
||||
ALLOWED_ORDER_BY = {
|
||||
"rating": "rating DESC, rank DESC",
|
||||
"runs": "runs DESC, rank DESC",
|
||||
"name": "agent_name ASC, rank ASC",
|
||||
"updated_at": "updated_at DESC, rank DESC",
|
||||
}
|
||||
# Convert hybrid search results (dict format) if hybrid succeeded
|
||||
if search_used_hybrid:
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
try:
|
||||
store_agent = store_model.StoreAgent(
|
||||
slug=agent["slug"],
|
||||
agent_name=agent["agent_name"],
|
||||
agent_image=(
|
||||
agent["agent_image"][0] if agent["agent_image"] else ""
|
||||
),
|
||||
creator=agent["creator_username"] or "Needs Profile",
|
||||
creator_avatar=agent["creator_avatar"] or "",
|
||||
sub_heading=agent["sub_heading"],
|
||||
description=agent["description"],
|
||||
runs=agent["runs"],
|
||||
rating=agent["rating"],
|
||||
)
|
||||
store_agents.append(store_agent)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error parsing Store agent from hybrid search results: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Validate and get order clause
|
||||
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
|
||||
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
|
||||
else:
|
||||
order_by_clause = "updated_at DESC, rank DESC"
|
||||
|
||||
# Build WHERE conditions and parameters list
|
||||
where_parts: list[str] = []
|
||||
params: list[typing.Any] = [search_query] # $1 - search term
|
||||
param_index = 2 # Start at $2 for next parameter
|
||||
|
||||
# Always filter for available agents
|
||||
where_parts.append("is_available = true")
|
||||
|
||||
if featured:
|
||||
where_parts.append("featured = true")
|
||||
|
||||
if creators and creators:
|
||||
# Use ANY with array parameter
|
||||
where_parts.append(f"creator_username = ANY(${param_index})")
|
||||
params.append(creators)
|
||||
param_index += 1
|
||||
|
||||
if category and category:
|
||||
where_parts.append(f"${param_index} = ANY(categories)")
|
||||
params.append(category)
|
||||
param_index += 1
|
||||
|
||||
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
|
||||
|
||||
# Add pagination params
|
||||
params.extend([page_size, offset])
|
||||
limit_param = f"${param_index}"
|
||||
offset_param = f"${param_index + 1}"
|
||||
|
||||
# Execute full-text search query with parameterized values
|
||||
sql_query = f"""
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
ts_rank_cd(search, query) AS rank
|
||||
FROM {{schema_prefix}}"StoreAgent",
|
||||
plainto_tsquery('english', $1) AS query
|
||||
WHERE {sql_where_clause}
|
||||
AND search @@ query
|
||||
ORDER BY {order_by_clause}
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
# Count query for pagination - only uses search term parameter
|
||||
count_query = f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {{schema_prefix}}"StoreAgent",
|
||||
plainto_tsquery('english', $1) AS query
|
||||
WHERE {sql_where_clause}
|
||||
AND search @@ query
|
||||
"""
|
||||
|
||||
# Execute both queries with parameters
|
||||
agents = await query_raw_with_schema(sql_query, *params)
|
||||
|
||||
# For count, use params without pagination (last 2 params)
|
||||
count_params = params[:-2]
|
||||
count_result = await query_raw_with_schema(count_query, *count_params)
|
||||
|
||||
total = count_result[0]["count"] if count_result else 0
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
# Convert raw results to StoreAgent models
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
try:
|
||||
store_agent = store_model.StoreAgent(
|
||||
slug=agent["slug"],
|
||||
agent_name=agent["agent_name"],
|
||||
agent_image=(
|
||||
agent["agent_image"][0] if agent["agent_image"] else ""
|
||||
),
|
||||
creator=agent["creator_username"] or "Needs Profile",
|
||||
creator_avatar=agent["creator_avatar"] or "",
|
||||
sub_heading=agent["sub_heading"],
|
||||
description=agent["description"],
|
||||
runs=agent["runs"],
|
||||
rating=agent["rating"],
|
||||
)
|
||||
store_agents.append(store_agent)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing Store agent from search results: {e}")
|
||||
continue
|
||||
|
||||
else:
|
||||
# Non-search query path (original logic)
|
||||
if not search_used_hybrid:
|
||||
# Fallback path - use basic search or no search
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
@@ -180,6 +130,14 @@ async def get_store_agents(
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
# Add basic text search if search_query provided but hybrid failed
|
||||
if search_query:
|
||||
where_clause["OR"] = [
|
||||
{"agent_name": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"sub_heading": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"description": {"contains": search_query, "mode": "insensitive"}},
|
||||
]
|
||||
|
||||
order_by = []
|
||||
if sorted_by == "rating":
|
||||
order_by.append({"rating": "desc"})
|
||||
@@ -188,7 +146,7 @@ async def get_store_agents(
|
||||
elif sorted_by == "name":
|
||||
order_by.append({"agent_name": "asc"})
|
||||
|
||||
agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
db_agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
@@ -199,7 +157,7 @@ async def get_store_agents(
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
for agent in db_agents:
|
||||
try:
|
||||
# Create the StoreAgent object safely
|
||||
store_agent = store_model.StoreAgent(
|
||||
@@ -1577,7 +1535,7 @@ async def review_store_submission(
|
||||
)
|
||||
|
||||
# Update the AgentGraph with store listing data
|
||||
await prisma.models.AgentGraph.prisma().update(
|
||||
await prisma.models.AgentGraph.prisma(tx).update(
|
||||
where={
|
||||
"graphVersionId": {
|
||||
"id": store_listing_version.agentGraphId,
|
||||
@@ -1592,6 +1550,23 @@ async def review_store_submission(
|
||||
},
|
||||
)
|
||||
|
||||
# Generate embedding for approved listing (blocking - admin operation)
|
||||
# Inside transaction: if embedding fails, entire transaction rolls back
|
||||
embedding_success = await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=store_listing_version.name,
|
||||
description=store_listing_version.description,
|
||||
sub_heading=store_listing_version.subHeading,
|
||||
categories=store_listing_version.categories or [],
|
||||
tx=tx,
|
||||
)
|
||||
if not embedding_success:
|
||||
raise ValueError(
|
||||
f"Failed to generate embedding for listing {store_listing_version_id}. "
|
||||
"This is likely due to OpenAI API being unavailable. "
|
||||
"Please try again later or contact support if the issue persists."
|
||||
)
|
||||
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": store_listing_version.StoreListing.id},
|
||||
data={
|
||||
|
||||
@@ -0,0 +1,962 @@
|
||||
"""
|
||||
Unified Content Embeddings Service
|
||||
|
||||
Handles generation and storage of OpenAI embeddings for all content types
|
||||
(store listings, blocks, documentation, library agents) to enable semantic/hybrid search.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import prisma
|
||||
from prisma.enums import ContentType
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from backend.api.features.store.content_handlers import CONTENT_HANDLERS
|
||||
from backend.data.db import execute_raw_with_schema, query_raw_with_schema
|
||||
from backend.util.clients import get_openai_client
|
||||
from backend.util.json import dumps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# OpenAI embedding model configuration
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
# Embedding dimension for the model above
|
||||
# text-embedding-3-small: 1536, text-embedding-3-large: 3072
|
||||
EMBEDDING_DIM = 1536
|
||||
# OpenAI embedding token limit (8,191 with 1 token buffer for safety)
|
||||
EMBEDDING_MAX_TOKENS = 8191
|
||||
|
||||
|
||||
def build_searchable_text(
|
||||
name: str,
|
||||
description: str,
|
||||
sub_heading: str,
|
||||
categories: list[str],
|
||||
) -> str:
|
||||
"""
|
||||
Build searchable text from listing version fields.
|
||||
|
||||
Combines relevant fields into a single string for embedding.
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# Name is important - include it
|
||||
if name:
|
||||
parts.append(name)
|
||||
|
||||
# Sub-heading provides context
|
||||
if sub_heading:
|
||||
parts.append(sub_heading)
|
||||
|
||||
# Description is the main content
|
||||
if description:
|
||||
parts.append(description)
|
||||
|
||||
# Categories help with semantic matching
|
||||
if categories:
|
||||
parts.append(" ".join(categories))
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
async def generate_embedding(text: str) -> list[float] | None:
|
||||
"""
|
||||
Generate embedding for text using OpenAI API.
|
||||
|
||||
Returns None if embedding generation fails.
|
||||
Fail-fast: no retries to maintain consistency with approval flow.
|
||||
"""
|
||||
try:
|
||||
client = get_openai_client()
|
||||
if not client:
|
||||
logger.error("openai_internal_api_key not set, cannot generate embedding")
|
||||
return None
|
||||
|
||||
# Truncate text to token limit using tiktoken
|
||||
# Character-based truncation is insufficient because token ratios vary by content type
|
||||
enc = encoding_for_model(EMBEDDING_MODEL)
|
||||
tokens = enc.encode(text)
|
||||
if len(tokens) > EMBEDDING_MAX_TOKENS:
|
||||
tokens = tokens[:EMBEDDING_MAX_TOKENS]
|
||||
truncated_text = enc.decode(tokens)
|
||||
logger.info(
|
||||
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
|
||||
)
|
||||
else:
|
||||
truncated_text = text
|
||||
|
||||
start_time = time.time()
|
||||
response = await client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=truncated_text,
|
||||
)
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
|
||||
embedding = response.data[0].embedding
|
||||
logger.info(
|
||||
f"Generated embedding: {len(embedding)} dims, "
|
||||
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
|
||||
)
|
||||
return embedding
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate embedding: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def store_embedding(
|
||||
version_id: str,
|
||||
embedding: list[float],
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Store embedding in the database.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||
DEPRECATED: Use ensure_embedding() instead (includes searchable_text).
|
||||
"""
|
||||
return await store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id=version_id,
|
||||
embedding=embedding,
|
||||
searchable_text="", # Empty for backward compat; ensure_embedding() populates this
|
||||
metadata=None,
|
||||
user_id=None, # Store agents are public
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
|
||||
async def store_content_embedding(
|
||||
content_type: ContentType,
|
||||
content_id: str,
|
||||
embedding: list[float],
|
||||
searchable_text: str,
|
||||
metadata: dict | None = None,
|
||||
user_id: str | None = None,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Store embedding in the unified content embeddings table.
|
||||
|
||||
New function for unified content embedding storage.
|
||||
Uses raw SQL since Prisma doesn't natively support pgvector.
|
||||
"""
|
||||
try:
|
||||
client = tx if tx else prisma.get_client()
|
||||
|
||||
# Convert embedding to PostgreSQL vector format
|
||||
embedding_str = embedding_to_vector_string(embedding)
|
||||
metadata_json = dumps(metadata or {})
|
||||
|
||||
# Upsert the embedding
|
||||
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
|
||||
await execute_raw_with_schema(
|
||||
"""
|
||||
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
|
||||
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
|
||||
)
|
||||
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
|
||||
ON CONFLICT ("contentType", "contentId", "userId")
|
||||
DO UPDATE SET
|
||||
"embedding" = $4::vector,
|
||||
"searchableText" = $5,
|
||||
"metadata" = $6::jsonb,
|
||||
"updatedAt" = NOW()
|
||||
WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
|
||||
AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
|
||||
AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
embedding_str,
|
||||
searchable_text,
|
||||
metadata_json,
|
||||
client=client,
|
||||
set_public_search_path=True,
|
||||
)
|
||||
|
||||
logger.info(f"Stored embedding for {content_type}:{content_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_embedding(version_id: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Retrieve embedding record for a listing version.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||
Returns dict with storeListingVersionId, embedding, timestamps or None if not found.
|
||||
"""
|
||||
result = await get_content_embedding(
|
||||
ContentType.STORE_AGENT, version_id, user_id=None
|
||||
)
|
||||
if result:
|
||||
# Transform to old format for backward compatibility
|
||||
return {
|
||||
"storeListingVersionId": result["contentId"],
|
||||
"embedding": result["embedding"],
|
||||
"createdAt": result["createdAt"],
|
||||
"updatedAt": result["updatedAt"],
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def get_content_embedding(
|
||||
content_type: ContentType, content_id: str, user_id: str | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Retrieve embedding record for any content type.
|
||||
|
||||
New function for unified content embedding retrieval.
|
||||
Returns dict with contentType, contentId, embedding, timestamps or None if not found.
|
||||
"""
|
||||
try:
|
||||
result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT
|
||||
"contentType",
|
||||
"contentId",
|
||||
"userId",
|
||||
"embedding"::text as "embedding",
|
||||
"searchableText",
|
||||
"metadata",
|
||||
"createdAt",
|
||||
"updatedAt"
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
set_public_search_path=True,
|
||||
)
|
||||
|
||||
if result and len(result) > 0:
|
||||
return result[0]
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def ensure_embedding(
|
||||
version_id: str,
|
||||
name: str,
|
||||
description: str,
|
||||
sub_heading: str,
|
||||
categories: list[str],
|
||||
force: bool = False,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Ensure an embedding exists for the listing version.
|
||||
|
||||
Creates embedding if missing. Use force=True to regenerate.
|
||||
Backward-compatible wrapper for store listings.
|
||||
|
||||
Args:
|
||||
version_id: The StoreListingVersion ID
|
||||
name: Agent name
|
||||
description: Agent description
|
||||
sub_heading: Agent sub-heading
|
||||
categories: Agent categories
|
||||
force: Force regeneration even if embedding exists
|
||||
tx: Optional transaction client
|
||||
|
||||
Returns:
|
||||
True if embedding exists/was created, False on failure
|
||||
"""
|
||||
try:
|
||||
# Check if embedding already exists
|
||||
if not force:
|
||||
existing = await get_embedding(version_id)
|
||||
if existing and existing.get("embedding"):
|
||||
logger.debug(f"Embedding for version {version_id} already exists")
|
||||
return True
|
||||
|
||||
# Build searchable text for embedding
|
||||
searchable_text = build_searchable_text(
|
||||
name, description, sub_heading, categories
|
||||
)
|
||||
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
if embedding is None:
|
||||
logger.warning(f"Could not generate embedding for version {version_id}")
|
||||
return False
|
||||
|
||||
# Store the embedding with metadata using new function
|
||||
metadata = {
|
||||
"name": name,
|
||||
"subHeading": sub_heading,
|
||||
"categories": categories,
|
||||
}
|
||||
return await store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id=version_id,
|
||||
embedding=embedding,
|
||||
searchable_text=searchable_text,
|
||||
metadata=metadata,
|
||||
user_id=None, # Store agents are public
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def delete_embedding(version_id: str) -> bool:
|
||||
"""
|
||||
Delete embedding for a listing version.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||
Note: This is usually handled automatically by CASCADE delete,
|
||||
but provided for manual cleanup if needed.
|
||||
"""
|
||||
return await delete_content_embedding(ContentType.STORE_AGENT, version_id)
|
||||
|
||||
|
||||
async def delete_content_embedding(
|
||||
content_type: ContentType, content_id: str, user_id: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Delete embedding for any content type.
|
||||
|
||||
New function for unified content embedding deletion.
|
||||
Note: This is usually handled automatically by CASCADE delete,
|
||||
but provided for manual cleanup if needed.
|
||||
|
||||
Args:
|
||||
content_type: The type of content (STORE_AGENT, LIBRARY_AGENT, etc.)
|
||||
content_id: The unique identifier for the content
|
||||
user_id: Optional user ID. For public content (STORE_AGENT, BLOCK), pass None.
|
||||
For user-scoped content (LIBRARY_AGENT), pass the user's ID to avoid
|
||||
deleting embeddings belonging to other users.
|
||||
|
||||
Returns:
|
||||
True if deletion succeeded, False otherwise
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
await execute_raw_with_schema(
|
||||
"""
|
||||
DELETE FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = $1::{schema_prefix}"ContentType"
|
||||
AND "contentId" = $2
|
||||
AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
client=client,
|
||||
)
|
||||
|
||||
user_str = f" (user: {user_id})" if user_id else ""
|
||||
logger.info(f"Deleted embedding for {content_type}:{content_id}{user_str}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete embedding for {content_type}:{content_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_embedding_stats() -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics about embedding coverage for all content types.
|
||||
|
||||
Returns stats per content type and overall totals.
|
||||
"""
|
||||
try:
|
||||
stats_by_type = {}
|
||||
total_items = 0
|
||||
total_with_embeddings = 0
|
||||
total_without_embeddings = 0
|
||||
|
||||
# Aggregate stats from all handlers
|
||||
for content_type, handler in CONTENT_HANDLERS.items():
|
||||
try:
|
||||
stats = await handler.get_stats()
|
||||
stats_by_type[content_type.value] = {
|
||||
"total": stats["total"],
|
||||
"with_embeddings": stats["with_embeddings"],
|
||||
"without_embeddings": stats["without_embeddings"],
|
||||
"coverage_percent": (
|
||||
round(stats["with_embeddings"] / stats["total"] * 100, 1)
|
||||
if stats["total"] > 0
|
||||
else 0
|
||||
),
|
||||
}
|
||||
|
||||
total_items += stats["total"]
|
||||
total_with_embeddings += stats["with_embeddings"]
|
||||
total_without_embeddings += stats["without_embeddings"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get stats for {content_type.value}: {e}")
|
||||
stats_by_type[content_type.value] = {
|
||||
"total": 0,
|
||||
"with_embeddings": 0,
|
||||
"without_embeddings": 0,
|
||||
"coverage_percent": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
return {
|
||||
"by_type": stats_by_type,
|
||||
"totals": {
|
||||
"total": total_items,
|
||||
"with_embeddings": total_with_embeddings,
|
||||
"without_embeddings": total_without_embeddings,
|
||||
"coverage_percent": (
|
||||
round(total_with_embeddings / total_items * 100, 1)
|
||||
if total_items > 0
|
||||
else 0
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embedding stats: {e}")
|
||||
return {
|
||||
"by_type": {},
|
||||
"totals": {
|
||||
"total": 0,
|
||||
"with_embeddings": 0,
|
||||
"without_embeddings": 0,
|
||||
"coverage_percent": 0,
|
||||
},
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]:
|
||||
"""
|
||||
Generate embeddings for approved listings that don't have them.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing usage.
|
||||
This now delegates to backfill_all_content_types() to process all content types.
|
||||
|
||||
Args:
|
||||
batch_size: Number of embeddings to generate per content type
|
||||
|
||||
Returns:
|
||||
Dict with success/failure counts aggregated across all content types
|
||||
"""
|
||||
# Delegate to the new generic backfill system
|
||||
result = await backfill_all_content_types(batch_size)
|
||||
|
||||
# Return in the old format for backward compatibility
|
||||
return result["totals"]
|
||||
|
||||
|
||||
async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
|
||||
"""
|
||||
Generate embeddings for all content types using registered handlers.
|
||||
|
||||
Processes content types in order: BLOCK → STORE_AGENT → DOCUMENTATION.
|
||||
This ensures foundational content (blocks) are searchable first.
|
||||
|
||||
Args:
|
||||
batch_size: Number of embeddings to generate per content type
|
||||
|
||||
Returns:
|
||||
Dict with stats per content type and overall totals
|
||||
"""
|
||||
results_by_type = {}
|
||||
total_processed = 0
|
||||
total_success = 0
|
||||
total_failed = 0
|
||||
|
||||
# Process content types in explicit order
|
||||
processing_order = [
|
||||
ContentType.BLOCK,
|
||||
ContentType.STORE_AGENT,
|
||||
ContentType.DOCUMENTATION,
|
||||
]
|
||||
|
||||
for content_type in processing_order:
|
||||
handler = CONTENT_HANDLERS.get(content_type)
|
||||
if not handler:
|
||||
logger.warning(f"No handler registered for {content_type.value}")
|
||||
continue
|
||||
try:
|
||||
logger.info(f"Processing {content_type.value} content type...")
|
||||
|
||||
# Get missing items from handler
|
||||
missing_items = await handler.get_missing_items(batch_size)
|
||||
|
||||
if not missing_items:
|
||||
results_by_type[content_type.value] = {
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"message": "No missing embeddings",
|
||||
}
|
||||
continue
|
||||
|
||||
# Process embeddings concurrently for better performance
|
||||
embedding_tasks = [
|
||||
ensure_content_embedding(
|
||||
content_type=item.content_type,
|
||||
content_id=item.content_id,
|
||||
searchable_text=item.searchable_text,
|
||||
metadata=item.metadata,
|
||||
user_id=item.user_id,
|
||||
)
|
||||
for item in missing_items
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*embedding_tasks, return_exceptions=True)
|
||||
|
||||
success = sum(1 for result in results if result is True)
|
||||
failed = len(results) - success
|
||||
|
||||
results_by_type[content_type.value] = {
|
||||
"processed": len(missing_items),
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"message": f"Backfilled {success} embeddings, {failed} failed",
|
||||
}
|
||||
|
||||
total_processed += len(missing_items)
|
||||
total_success += success
|
||||
total_failed += failed
|
||||
|
||||
logger.info(
|
||||
f"{content_type.value}: processed {len(missing_items)}, "
|
||||
f"success {success}, failed {failed}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process {content_type.value}: {e}")
|
||||
results_by_type[content_type.value] = {
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
return {
|
||||
"by_type": results_by_type,
|
||||
"totals": {
|
||||
"processed": total_processed,
|
||||
"success": total_success,
|
||||
"failed": total_failed,
|
||||
"message": f"Overall: {total_success} succeeded, {total_failed} failed",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def embed_query(query: str) -> list[float] | None:
|
||||
"""
|
||||
Generate embedding for a search query.
|
||||
|
||||
Same as generate_embedding but with clearer intent.
|
||||
"""
|
||||
return await generate_embedding(query)
|
||||
|
||||
|
||||
def embedding_to_vector_string(embedding: list[float]) -> str:
|
||||
"""Convert embedding list to PostgreSQL vector string format."""
|
||||
return "[" + ",".join(str(x) for x in embedding) + "]"
|
||||
|
||||
|
||||
async def ensure_content_embedding(
|
||||
content_type: ContentType,
|
||||
content_id: str,
|
||||
searchable_text: str,
|
||||
metadata: dict | None = None,
|
||||
user_id: str | None = None,
|
||||
force: bool = False,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Ensure an embedding exists for any content type.
|
||||
|
||||
Generic function for creating embeddings for store agents, blocks, docs, etc.
|
||||
|
||||
Args:
|
||||
content_type: ContentType enum value (STORE_AGENT, BLOCK, etc.)
|
||||
content_id: Unique identifier for the content
|
||||
searchable_text: Combined text for embedding generation
|
||||
metadata: Optional metadata to store with embedding
|
||||
force: Force regeneration even if embedding exists
|
||||
tx: Optional transaction client
|
||||
|
||||
Returns:
|
||||
True if embedding exists/was created, False on failure
|
||||
"""
|
||||
try:
|
||||
# Check if embedding already exists
|
||||
if not force:
|
||||
existing = await get_content_embedding(content_type, content_id, user_id)
|
||||
if existing and existing.get("embedding"):
|
||||
logger.debug(
|
||||
f"Embedding for {content_type}:{content_id} already exists"
|
||||
)
|
||||
return True
|
||||
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
if embedding is None:
|
||||
logger.warning(
|
||||
f"Could not generate embedding for {content_type}:{content_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Store the embedding
|
||||
return await store_content_embedding(
|
||||
content_type=content_type,
|
||||
content_id=content_id,
|
||||
embedding=embedding,
|
||||
searchable_text=searchable_text,
|
||||
metadata=metadata or {},
|
||||
user_id=user_id,
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def cleanup_orphaned_embeddings() -> dict[str, Any]:
|
||||
"""
|
||||
Clean up embeddings for content that no longer exists or is no longer valid.
|
||||
|
||||
Compares current content with embeddings in database and removes orphaned records:
|
||||
- STORE_AGENT: Removes embeddings for rejected/deleted store listings
|
||||
- BLOCK: Removes embeddings for blocks no longer registered
|
||||
- DOCUMENTATION: Removes embeddings for deleted doc files
|
||||
|
||||
Returns:
|
||||
Dict with cleanup statistics per content type
|
||||
"""
|
||||
results_by_type = {}
|
||||
total_deleted = 0
|
||||
|
||||
# Cleanup orphaned embeddings for all content types
|
||||
cleanup_types = [
|
||||
ContentType.STORE_AGENT,
|
||||
ContentType.BLOCK,
|
||||
ContentType.DOCUMENTATION,
|
||||
]
|
||||
|
||||
for content_type in cleanup_types:
|
||||
try:
|
||||
handler = CONTENT_HANDLERS.get(content_type)
|
||||
if not handler:
|
||||
logger.warning(f"No handler registered for {content_type}")
|
||||
results_by_type[content_type.value] = {
|
||||
"deleted": 0,
|
||||
"error": "No handler registered",
|
||||
}
|
||||
continue
|
||||
|
||||
# Get all current content IDs from handler
|
||||
if content_type == ContentType.STORE_AGENT:
|
||||
# Get IDs of approved store listing versions from non-deleted listings
|
||||
valid_agents = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT slv.id
|
||||
FROM {schema_prefix}"StoreListingVersion" slv
|
||||
JOIN {schema_prefix}"StoreListing" sl ON slv."storeListingId" = sl.id
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
AND sl."isDeleted" = false
|
||||
""",
|
||||
)
|
||||
current_ids = {row["id"] for row in valid_agents}
|
||||
elif content_type == ContentType.BLOCK:
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
current_ids = set(get_blocks().keys())
|
||||
elif content_type == ContentType.DOCUMENTATION:
|
||||
from pathlib import Path
|
||||
|
||||
# embeddings.py is at: backend/backend/api/features/store/embeddings.py
|
||||
# Need to go up to project root then into docs/
|
||||
this_file = Path(__file__)
|
||||
project_root = (
|
||||
this_file.parent.parent.parent.parent.parent.parent.parent
|
||||
)
|
||||
docs_root = project_root / "docs"
|
||||
if docs_root.exists():
|
||||
all_docs = list(docs_root.rglob("*.md")) + list(
|
||||
docs_root.rglob("*.mdx")
|
||||
)
|
||||
current_ids = {str(doc.relative_to(docs_root)) for doc in all_docs}
|
||||
else:
|
||||
current_ids = set()
|
||||
else:
|
||||
# Skip unknown content types to avoid accidental deletion
|
||||
logger.warning(
|
||||
f"Skipping cleanup for unknown content type: {content_type}"
|
||||
)
|
||||
results_by_type[content_type.value] = {
|
||||
"deleted": 0,
|
||||
"error": "Unknown content type - skipped for safety",
|
||||
}
|
||||
continue
|
||||
|
||||
# Get all embedding IDs from database
|
||||
db_embeddings = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT "contentId"
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = $1::{schema_prefix}"ContentType"
|
||||
""",
|
||||
content_type,
|
||||
)
|
||||
|
||||
db_ids = {row["contentId"] for row in db_embeddings}
|
||||
|
||||
# Find orphaned embeddings (in DB but not in current content)
|
||||
orphaned_ids = db_ids - current_ids
|
||||
|
||||
if not orphaned_ids:
|
||||
logger.info(f"{content_type.value}: No orphaned embeddings found")
|
||||
results_by_type[content_type.value] = {
|
||||
"deleted": 0,
|
||||
"message": "No orphaned embeddings",
|
||||
}
|
||||
continue
|
||||
|
||||
# Delete orphaned embeddings in batch for better performance
|
||||
orphaned_list = list(orphaned_ids)
|
||||
try:
|
||||
await execute_raw_with_schema(
|
||||
"""
|
||||
DELETE FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = $1::{schema_prefix}"ContentType"
|
||||
AND "contentId" = ANY($2::text[])
|
||||
""",
|
||||
content_type,
|
||||
orphaned_list,
|
||||
)
|
||||
deleted = len(orphaned_list)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to batch delete orphaned embeddings: {e}")
|
||||
deleted = 0
|
||||
|
||||
logger.info(
|
||||
f"{content_type.value}: Deleted {deleted}/{len(orphaned_ids)} orphaned embeddings"
|
||||
)
|
||||
results_by_type[content_type.value] = {
|
||||
"deleted": deleted,
|
||||
"orphaned": len(orphaned_ids),
|
||||
"message": f"Deleted {deleted} orphaned embeddings",
|
||||
}
|
||||
|
||||
total_deleted += deleted
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup {content_type.value}: {e}")
|
||||
results_by_type[content_type.value] = {
|
||||
"deleted": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
return {
|
||||
"by_type": results_by_type,
|
||||
"totals": {
|
||||
"deleted": total_deleted,
|
||||
"message": f"Deleted {total_deleted} orphaned embeddings",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def semantic_search(
|
||||
query: str,
|
||||
content_types: list[ContentType] | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int = 20,
|
||||
min_similarity: float = 0.5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Semantic search across content types using embeddings.
|
||||
|
||||
Performs vector similarity search on UnifiedContentEmbedding table.
|
||||
Used directly for blocks/docs/library agents, or as the semantic component
|
||||
within hybrid_search for store agents.
|
||||
|
||||
If embedding generation fails, falls back to lexical search on searchableText.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
content_types: List of ContentType to search. Defaults to [BLOCK, STORE_AGENT, DOCUMENTATION]
|
||||
user_id: Optional user ID for searching private content (library agents)
|
||||
limit: Maximum number of results to return (default: 20)
|
||||
min_similarity: Minimum cosine similarity threshold (0-1, default: 0.5)
|
||||
|
||||
Returns:
|
||||
List of search results with the following structure:
|
||||
[
|
||||
{
|
||||
"content_id": str,
|
||||
"content_type": str, # "BLOCK", "STORE_AGENT", "DOCUMENTATION", or "LIBRARY_AGENT"
|
||||
"searchable_text": str,
|
||||
"metadata": dict,
|
||||
"similarity": float, # Cosine similarity score (0-1)
|
||||
},
|
||||
...
|
||||
]
|
||||
|
||||
Examples:
|
||||
# Search blocks only
|
||||
results = await semantic_search("calculate", content_types=[ContentType.BLOCK])
|
||||
|
||||
# Search blocks and documentation
|
||||
results = await semantic_search(
|
||||
"how to use API",
|
||||
content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION]
|
||||
)
|
||||
|
||||
# Search all public content (default)
|
||||
results = await semantic_search("AI agent")
|
||||
|
||||
# Search user's library agents
|
||||
results = await semantic_search(
|
||||
"my custom agent",
|
||||
content_types=[ContentType.LIBRARY_AGENT],
|
||||
user_id="user123"
|
||||
)
|
||||
"""
|
||||
# Default to searching all public content types
|
||||
if content_types is None:
|
||||
content_types = [
|
||||
ContentType.BLOCK,
|
||||
ContentType.STORE_AGENT,
|
||||
ContentType.DOCUMENTATION,
|
||||
]
|
||||
|
||||
# Validate inputs
|
||||
if not content_types:
|
||||
return [] # Empty content_types would cause invalid SQL (IN ())
|
||||
|
||||
query = query.strip()
|
||||
if not query:
|
||||
return []
|
||||
|
||||
if limit < 1:
|
||||
limit = 1
|
||||
if limit > 100:
|
||||
limit = 100
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await embed_query(query)
|
||||
|
||||
if query_embedding is not None:
|
||||
# Semantic search with embeddings
|
||||
embedding_str = embedding_to_vector_string(query_embedding)
|
||||
|
||||
# Build params in order: limit, then user_id (if provided), then content types
|
||||
params: list[Any] = [limit]
|
||||
user_filter = ""
|
||||
if user_id is not None:
|
||||
user_filter = 'AND "userId" = ${}'.format(len(params) + 1)
|
||||
params.append(user_id)
|
||||
|
||||
# Add content type parameters and build placeholders dynamically
|
||||
content_type_start_idx = len(params) + 1
|
||||
content_type_placeholders = ", ".join(
|
||||
f'${content_type_start_idx + i}::{{{{schema_prefix}}}}"ContentType"'
|
||||
for i in range(len(content_types))
|
||||
)
|
||||
params.extend([ct.value for ct in content_types])
|
||||
|
||||
sql = f"""
|
||||
SELECT
|
||||
"contentId" as content_id,
|
||||
"contentType" as content_type,
|
||||
"searchableText" as searchable_text,
|
||||
metadata,
|
||||
1 - (embedding <=> '{embedding_str}'::vector) as similarity
|
||||
FROM {{{{schema_prefix}}}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" IN ({content_type_placeholders})
|
||||
{user_filter}
|
||||
AND 1 - (embedding <=> '{embedding_str}'::vector) >= ${len(params) + 1}
|
||||
ORDER BY similarity DESC
|
||||
LIMIT $1
|
||||
"""
|
||||
params.append(min_similarity)
|
||||
|
||||
try:
|
||||
results = await query_raw_with_schema(
|
||||
sql, *params, set_public_search_path=True
|
||||
)
|
||||
return [
|
||||
{
|
||||
"content_id": row["content_id"],
|
||||
"content_type": row["content_type"],
|
||||
"searchable_text": row["searchable_text"],
|
||||
"metadata": row["metadata"],
|
||||
"similarity": float(row["similarity"]),
|
||||
}
|
||||
for row in results
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Semantic search failed: {e}")
|
||||
# Fall through to lexical search below
|
||||
|
||||
# Fallback to lexical search if embeddings unavailable
|
||||
logger.warning("Falling back to lexical search (embeddings unavailable)")
|
||||
|
||||
params_lexical: list[Any] = [limit]
|
||||
user_filter = ""
|
||||
if user_id is not None:
|
||||
user_filter = 'AND "userId" = ${}'.format(len(params_lexical) + 1)
|
||||
params_lexical.append(user_id)
|
||||
|
||||
# Add content type parameters and build placeholders dynamically
|
||||
content_type_start_idx = len(params_lexical) + 1
|
||||
content_type_placeholders_lexical = ", ".join(
|
||||
f'${content_type_start_idx + i}::{{{{schema_prefix}}}}"ContentType"'
|
||||
for i in range(len(content_types))
|
||||
)
|
||||
params_lexical.extend([ct.value for ct in content_types])
|
||||
|
||||
sql_lexical = f"""
|
||||
SELECT
|
||||
"contentId" as content_id,
|
||||
"contentType" as content_type,
|
||||
"searchableText" as searchable_text,
|
||||
metadata,
|
||||
0.0 as similarity
|
||||
FROM {{{{schema_prefix}}}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" IN ({content_type_placeholders_lexical})
|
||||
{user_filter}
|
||||
AND "searchableText" ILIKE ${len(params_lexical) + 1}
|
||||
ORDER BY "updatedAt" DESC
|
||||
LIMIT $1
|
||||
"""
|
||||
params_lexical.append(f"%{query}%")
|
||||
|
||||
try:
|
||||
results = await query_raw_with_schema(
|
||||
sql_lexical, *params_lexical, set_public_search_path=True
|
||||
)
|
||||
return [
|
||||
{
|
||||
"content_id": row["content_id"],
|
||||
"content_type": row["content_type"],
|
||||
"searchable_text": row["searchable_text"],
|
||||
"metadata": row["metadata"],
|
||||
"similarity": 0.0, # Lexical search doesn't provide similarity
|
||||
}
|
||||
for row in results
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Lexical search failed: {e}")
|
||||
return []
|
||||
@@ -0,0 +1,666 @@
|
||||
"""
|
||||
End-to-end database tests for embeddings and hybrid search.
|
||||
|
||||
These tests hit the actual database to verify SQL queries work correctly.
|
||||
Tests cover:
|
||||
1. Embedding storage (store_content_embedding)
|
||||
2. Embedding retrieval (get_content_embedding)
|
||||
3. Embedding deletion (delete_content_embedding)
|
||||
4. Unified hybrid search across content types
|
||||
5. Store agent hybrid search
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store import embeddings
|
||||
from backend.api.features.store.embeddings import EMBEDDING_DIM
|
||||
from backend.api.features.store.hybrid_search import (
|
||||
hybrid_search,
|
||||
unified_hybrid_search,
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# Test Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_content_id() -> str:
|
||||
"""Generate unique content ID for test isolation."""
|
||||
return f"test-content-{uuid.uuid4()}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_id() -> str:
|
||||
"""Generate unique user ID for test isolation."""
|
||||
return f"test-user-{uuid.uuid4()}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding() -> list[float]:
|
||||
"""Generate a mock embedding vector."""
|
||||
# Create a normalized embedding vector
|
||||
import math
|
||||
|
||||
raw = [float(i % 10) / 10.0 for i in range(EMBEDDING_DIM)]
|
||||
# Normalize to unit length (required for cosine similarity)
|
||||
magnitude = math.sqrt(sum(x * x for x in raw))
|
||||
return [x / magnitude for x in raw]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def similar_embedding() -> list[float]:
|
||||
"""Generate an embedding similar to mock_embedding."""
|
||||
import math
|
||||
|
||||
# Similar but slightly different values
|
||||
raw = [float(i % 10) / 10.0 + 0.01 for i in range(EMBEDDING_DIM)]
|
||||
magnitude = math.sqrt(sum(x * x for x in raw))
|
||||
return [x / magnitude for x in raw]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def different_embedding() -> list[float]:
|
||||
"""Generate an embedding very different from mock_embedding."""
|
||||
import math
|
||||
|
||||
# Reversed pattern to be maximally different
|
||||
raw = [float((EMBEDDING_DIM - i) % 10) / 10.0 for i in range(EMBEDDING_DIM)]
|
||||
magnitude = math.sqrt(sum(x * x for x in raw))
|
||||
return [x / magnitude for x in raw]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def cleanup_embeddings(
|
||||
server,
|
||||
) -> AsyncGenerator[list[tuple[ContentType, str, str | None]], None]:
|
||||
"""
|
||||
Fixture that tracks created embeddings and cleans them up after tests.
|
||||
|
||||
Yields a list to which tests can append (content_type, content_id, user_id) tuples.
|
||||
"""
|
||||
created_embeddings: list[tuple[ContentType, str, str | None]] = []
|
||||
yield created_embeddings
|
||||
|
||||
# Cleanup all created embeddings
|
||||
for content_type, content_id, user_id in created_embeddings:
|
||||
try:
|
||||
await embeddings.delete_content_embedding(content_type, content_id, user_id)
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# store_content_embedding Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_content_embedding_store_agent(
|
||||
server,
|
||||
test_content_id: str,
|
||||
mock_embedding: list[float],
|
||||
cleanup_embeddings: list,
|
||||
):
|
||||
"""Test storing embedding for STORE_AGENT content type."""
|
||||
# Track for cleanup
|
||||
cleanup_embeddings.append((ContentType.STORE_AGENT, test_content_id, None))
|
||||
|
||||
result = await embeddings.store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id=test_content_id,
|
||||
embedding=mock_embedding,
|
||||
searchable_text="AI assistant for productivity tasks",
|
||||
metadata={"name": "Test Agent", "categories": ["productivity"]},
|
||||
user_id=None, # Store agents are public
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify it was stored
|
||||
stored = await embeddings.get_content_embedding(
|
||||
ContentType.STORE_AGENT, test_content_id, user_id=None
|
||||
)
|
||||
assert stored is not None
|
||||
assert stored["contentId"] == test_content_id
|
||||
assert stored["contentType"] == "STORE_AGENT"
|
||||
assert stored["searchableText"] == "AI assistant for productivity tasks"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_content_embedding_block(
|
||||
server,
|
||||
test_content_id: str,
|
||||
mock_embedding: list[float],
|
||||
cleanup_embeddings: list,
|
||||
):
|
||||
"""Test storing embedding for BLOCK content type."""
|
||||
cleanup_embeddings.append((ContentType.BLOCK, test_content_id, None))
|
||||
|
||||
result = await embeddings.store_content_embedding(
|
||||
content_type=ContentType.BLOCK,
|
||||
content_id=test_content_id,
|
||||
embedding=mock_embedding,
|
||||
searchable_text="HTTP request block for API calls",
|
||||
metadata={"name": "HTTP Request Block"},
|
||||
user_id=None, # Blocks are public
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
stored = await embeddings.get_content_embedding(
|
||||
ContentType.BLOCK, test_content_id, user_id=None
|
||||
)
|
||||
assert stored is not None
|
||||
assert stored["contentType"] == "BLOCK"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_content_embedding_documentation(
|
||||
server,
|
||||
test_content_id: str,
|
||||
mock_embedding: list[float],
|
||||
cleanup_embeddings: list,
|
||||
):
|
||||
"""Test storing embedding for DOCUMENTATION content type."""
|
||||
cleanup_embeddings.append((ContentType.DOCUMENTATION, test_content_id, None))
|
||||
|
||||
result = await embeddings.store_content_embedding(
|
||||
content_type=ContentType.DOCUMENTATION,
|
||||
content_id=test_content_id,
|
||||
embedding=mock_embedding,
|
||||
searchable_text="Getting started guide for AutoGPT platform",
|
||||
metadata={"title": "Getting Started", "url": "/docs/getting-started"},
|
||||
user_id=None, # Docs are public
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
stored = await embeddings.get_content_embedding(
|
||||
ContentType.DOCUMENTATION, test_content_id, user_id=None
|
||||
)
|
||||
assert stored is not None
|
||||
assert stored["contentType"] == "DOCUMENTATION"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_content_embedding_upsert(
|
||||
server,
|
||||
test_content_id: str,
|
||||
mock_embedding: list[float],
|
||||
cleanup_embeddings: list,
|
||||
):
|
||||
"""Test that storing embedding twice updates instead of duplicates."""
|
||||
cleanup_embeddings.append((ContentType.BLOCK, test_content_id, None))
|
||||
|
||||
# Store first time
|
||||
result1 = await embeddings.store_content_embedding(
|
||||
content_type=ContentType.BLOCK,
|
||||
content_id=test_content_id,
|
||||
embedding=mock_embedding,
|
||||
searchable_text="Original text",
|
||||
metadata={"version": 1},
|
||||
user_id=None,
|
||||
)
|
||||
assert result1 is True
|
||||
|
||||
# Store again with different text (upsert)
|
||||
result2 = await embeddings.store_content_embedding(
|
||||
content_type=ContentType.BLOCK,
|
||||
content_id=test_content_id,
|
||||
embedding=mock_embedding,
|
||||
searchable_text="Updated text",
|
||||
metadata={"version": 2},
|
||||
user_id=None,
|
||||
)
|
||||
assert result2 is True
|
||||
|
||||
# Verify only one record with updated text
|
||||
stored = await embeddings.get_content_embedding(
|
||||
ContentType.BLOCK, test_content_id, user_id=None
|
||||
)
|
||||
assert stored is not None
|
||||
assert stored["searchableText"] == "Updated text"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# get_content_embedding Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_content_embedding_not_found(server):
|
||||
"""Test retrieving non-existent embedding returns None."""
|
||||
result = await embeddings.get_content_embedding(
|
||||
ContentType.STORE_AGENT, "non-existent-id", user_id=None
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_content_embedding_with_metadata(
|
||||
server,
|
||||
test_content_id: str,
|
||||
mock_embedding: list[float],
|
||||
cleanup_embeddings: list,
|
||||
):
|
||||
"""Test that metadata is correctly stored and retrieved."""
|
||||
cleanup_embeddings.append((ContentType.STORE_AGENT, test_content_id, None))
|
||||
|
||||
metadata = {
|
||||
"name": "Test Agent",
|
||||
"subHeading": "A test agent",
|
||||
"categories": ["ai", "productivity"],
|
||||
"customField": 123,
|
||||
}
|
||||
|
||||
await embeddings.store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id=test_content_id,
|
||||
embedding=mock_embedding,
|
||||
searchable_text="test",
|
||||
metadata=metadata,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
stored = await embeddings.get_content_embedding(
|
||||
ContentType.STORE_AGENT, test_content_id, user_id=None
|
||||
)
|
||||
|
||||
assert stored is not None
|
||||
assert stored["metadata"]["name"] == "Test Agent"
|
||||
assert stored["metadata"]["categories"] == ["ai", "productivity"]
|
||||
assert stored["metadata"]["customField"] == 123
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# delete_content_embedding Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_delete_content_embedding(
|
||||
server,
|
||||
test_content_id: str,
|
||||
mock_embedding: list[float],
|
||||
):
|
||||
"""Test deleting embedding removes it from database."""
|
||||
# Store embedding
|
||||
await embeddings.store_content_embedding(
|
||||
content_type=ContentType.BLOCK,
|
||||
content_id=test_content_id,
|
||||
embedding=mock_embedding,
|
||||
searchable_text="To be deleted",
|
||||
metadata=None,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Verify it exists
|
||||
stored = await embeddings.get_content_embedding(
|
||||
ContentType.BLOCK, test_content_id, user_id=None
|
||||
)
|
||||
assert stored is not None
|
||||
|
||||
# Delete it
|
||||
result = await embeddings.delete_content_embedding(
|
||||
ContentType.BLOCK, test_content_id, user_id=None
|
||||
)
|
||||
assert result is True
|
||||
|
||||
# Verify it's gone
|
||||
stored = await embeddings.get_content_embedding(
|
||||
ContentType.BLOCK, test_content_id, user_id=None
|
||||
)
|
||||
assert stored is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_delete_content_embedding_not_found(server):
|
||||
"""Test deleting non-existent embedding doesn't error."""
|
||||
result = await embeddings.delete_content_embedding(
|
||||
ContentType.BLOCK, "non-existent-id", user_id=None
|
||||
)
|
||||
# Should succeed even if nothing to delete
|
||||
assert result is True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# unified_hybrid_search Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unified_hybrid_search_finds_matching_content(
|
||||
server,
|
||||
mock_embedding: list[float],
|
||||
cleanup_embeddings: list,
|
||||
):
|
||||
"""Test unified search finds content matching the query."""
|
||||
# Create unique content IDs
|
||||
agent_id = f"test-agent-{uuid.uuid4()}"
|
||||
block_id = f"test-block-{uuid.uuid4()}"
|
||||
doc_id = f"test-doc-{uuid.uuid4()}"
|
||||
|
||||
cleanup_embeddings.append((ContentType.STORE_AGENT, agent_id, None))
|
||||
cleanup_embeddings.append((ContentType.BLOCK, block_id, None))
|
||||
cleanup_embeddings.append((ContentType.DOCUMENTATION, doc_id, None))
|
||||
|
||||
# Store embeddings for different content types
|
||||
await embeddings.store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id=agent_id,
|
||||
embedding=mock_embedding,
|
||||
searchable_text="AI writing assistant for blog posts",
|
||||
metadata={"name": "Writing Assistant"},
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
await embeddings.store_content_embedding(
|
||||
content_type=ContentType.BLOCK,
|
||||
content_id=block_id,
|
||||
embedding=mock_embedding,
|
||||
searchable_text="Text generation block for creative writing",
|
||||
metadata={"name": "Text Generator"},
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
await embeddings.store_content_embedding(
|
||||
content_type=ContentType.DOCUMENTATION,
|
||||
content_id=doc_id,
|
||||
embedding=mock_embedding,
|
||||
searchable_text="How to use writing blocks in AutoGPT",
|
||||
metadata={"title": "Writing Guide"},
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Search for "writing" - should find all three
|
||||
results, total = await unified_hybrid_search(
|
||||
query="writing",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Should find at least our test content (may find others too)
|
||||
content_ids = [r["content_id"] for r in results]
|
||||
assert agent_id in content_ids or total >= 1 # Lexical search should find it
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unified_hybrid_search_filter_by_content_type(
|
||||
server,
|
||||
mock_embedding: list[float],
|
||||
cleanup_embeddings: list,
|
||||
):
|
||||
"""Test unified search can filter by content type."""
|
||||
agent_id = f"test-agent-{uuid.uuid4()}"
|
||||
block_id = f"test-block-{uuid.uuid4()}"
|
||||
|
||||
cleanup_embeddings.append((ContentType.STORE_AGENT, agent_id, None))
|
||||
cleanup_embeddings.append((ContentType.BLOCK, block_id, None))
|
||||
|
||||
# Store both types with same searchable text
|
||||
await embeddings.store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id=agent_id,
|
||||
embedding=mock_embedding,
|
||||
searchable_text="unique_search_term_xyz123",
|
||||
metadata={},
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
await embeddings.store_content_embedding(
|
||||
content_type=ContentType.BLOCK,
|
||||
content_id=block_id,
|
||||
embedding=mock_embedding,
|
||||
searchable_text="unique_search_term_xyz123",
|
||||
metadata={},
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Search only for BLOCK type
|
||||
results, total = await unified_hybrid_search(
|
||||
query="unique_search_term_xyz123",
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# All results should be BLOCK type
|
||||
for r in results:
|
||||
assert r["content_type"] == "BLOCK"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unified_hybrid_search_empty_query(server):
|
||||
"""Test unified search with empty query returns empty results."""
|
||||
results, total = await unified_hybrid_search(
|
||||
query="",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
assert results == []
|
||||
assert total == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unified_hybrid_search_pagination(
|
||||
server,
|
||||
mock_embedding: list[float],
|
||||
cleanup_embeddings: list,
|
||||
):
|
||||
"""Test unified search pagination works correctly."""
|
||||
# Create multiple items
|
||||
content_ids = []
|
||||
for i in range(5):
|
||||
content_id = f"test-pagination-{uuid.uuid4()}"
|
||||
content_ids.append(content_id)
|
||||
cleanup_embeddings.append((ContentType.BLOCK, content_id, None))
|
||||
|
||||
await embeddings.store_content_embedding(
|
||||
content_type=ContentType.BLOCK,
|
||||
content_id=content_id,
|
||||
embedding=mock_embedding,
|
||||
searchable_text=f"pagination test item number {i}",
|
||||
metadata={"index": i},
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Get first page
|
||||
page1_results, total1 = await unified_hybrid_search(
|
||||
query="pagination test",
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=2,
|
||||
)
|
||||
|
||||
# Get second page
|
||||
page2_results, total2 = await unified_hybrid_search(
|
||||
query="pagination test",
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=2,
|
||||
page_size=2,
|
||||
)
|
||||
|
||||
# Total should be consistent
|
||||
assert total1 == total2
|
||||
|
||||
# Pages should have different content (if we have enough results)
|
||||
if len(page1_results) > 0 and len(page2_results) > 0:
|
||||
page1_ids = {r["content_id"] for r in page1_results}
|
||||
page2_ids = {r["content_id"] for r in page2_results}
|
||||
# No overlap between pages
|
||||
assert page1_ids.isdisjoint(page2_ids)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unified_hybrid_search_min_score_filtering(
|
||||
server,
|
||||
mock_embedding: list[float],
|
||||
cleanup_embeddings: list,
|
||||
):
|
||||
"""Test unified search respects min_score threshold."""
|
||||
content_id = f"test-minscore-{uuid.uuid4()}"
|
||||
cleanup_embeddings.append((ContentType.BLOCK, content_id, None))
|
||||
|
||||
await embeddings.store_content_embedding(
|
||||
content_type=ContentType.BLOCK,
|
||||
content_id=content_id,
|
||||
embedding=mock_embedding,
|
||||
searchable_text="completely unrelated content about bananas",
|
||||
metadata={},
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Search with very high min_score - should filter out low relevance
|
||||
results_high, _ = await unified_hybrid_search(
|
||||
query="quantum computing algorithms",
|
||||
content_types=[ContentType.BLOCK],
|
||||
min_score=0.9, # Very high threshold
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Search with low min_score
|
||||
results_low, _ = await unified_hybrid_search(
|
||||
query="quantum computing algorithms",
|
||||
content_types=[ContentType.BLOCK],
|
||||
min_score=0.01, # Very low threshold
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# High threshold should have fewer or equal results
|
||||
assert len(results_high) <= len(results_low)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# hybrid_search (Store Agents) Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_hybrid_search_store_agents_sql_valid(server):
|
||||
"""Test that hybrid_search SQL executes without errors."""
|
||||
# This test verifies the SQL is syntactically correct
|
||||
# even if no results are found
|
||||
results, total = await hybrid_search(
|
||||
query="test agent",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Should not raise - verifies SQL is valid
|
||||
assert isinstance(results, list)
|
||||
assert isinstance(total, int)
|
||||
assert total >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_hybrid_search_with_filters(server):
|
||||
"""Test hybrid_search with various filter options."""
|
||||
# Test with all filter types
|
||||
results, total = await hybrid_search(
|
||||
query="productivity",
|
||||
featured=True,
|
||||
creators=["test-creator"],
|
||||
category="productivity",
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
# Should not raise - verifies filter SQL is valid
|
||||
assert isinstance(results, list)
|
||||
assert isinstance(total, int)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_hybrid_search_pagination(server):
|
||||
"""Test hybrid_search pagination."""
|
||||
# Page 1
|
||||
results1, total1 = await hybrid_search(
|
||||
query="agent",
|
||||
page=1,
|
||||
page_size=5,
|
||||
)
|
||||
|
||||
# Page 2
|
||||
results2, total2 = await hybrid_search(
|
||||
query="agent",
|
||||
page=2,
|
||||
page_size=5,
|
||||
)
|
||||
|
||||
# Verify SQL executes without error
|
||||
assert isinstance(results1, list)
|
||||
assert isinstance(results2, list)
|
||||
assert isinstance(total1, int)
|
||||
assert isinstance(total2, int)
|
||||
|
||||
# If page 1 has results, total should be > 0
|
||||
# Note: total from page 2 may be 0 if no results on that page (COUNT(*) OVER limitation)
|
||||
if results1:
|
||||
assert total1 > 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SQL Validity Tests (verify queries don't break)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_all_content_types_searchable(server):
|
||||
"""Test that all content types can be searched without SQL errors."""
|
||||
for content_type in [
|
||||
ContentType.STORE_AGENT,
|
||||
ContentType.BLOCK,
|
||||
ContentType.DOCUMENTATION,
|
||||
]:
|
||||
results, total = await unified_hybrid_search(
|
||||
query="test",
|
||||
content_types=[content_type],
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
# Should not raise
|
||||
assert isinstance(results, list)
|
||||
assert isinstance(total, int)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_multiple_content_types_searchable(server):
|
||||
"""Test searching multiple content types at once."""
|
||||
results, total = await unified_hybrid_search(
|
||||
query="test",
|
||||
content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION],
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Should not raise
|
||||
assert isinstance(results, list)
|
||||
assert isinstance(total, int)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_search_all_content_types_default(server):
|
||||
"""Test searching all content types (default behavior)."""
|
||||
results, total = await unified_hybrid_search(
|
||||
query="test",
|
||||
content_types=None, # Should search all
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Should not raise
|
||||
assert isinstance(results, list)
|
||||
assert isinstance(total, int)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
@@ -0,0 +1,315 @@
|
||||
"""
|
||||
Integration tests for embeddings with schema handling.
|
||||
|
||||
These tests verify that embeddings operations work correctly across different database schemas.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store import embeddings
|
||||
from backend.api.features.store.embeddings import EMBEDDING_DIM
|
||||
|
||||
# Schema prefix tests removed - functionality moved to db.raw_with_schema() helper
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_store_content_embedding_with_schema():
|
||||
"""Test storing embeddings with proper schema handling."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch("prisma.get_client") as mock_get_client:
|
||||
mock_client = AsyncMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id="test-id",
|
||||
embedding=[0.1] * EMBEDDING_DIM,
|
||||
searchable_text="test text",
|
||||
metadata={"test": "data"},
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Verify the query was called
|
||||
assert mock_client.execute_raw.called
|
||||
|
||||
# Get the SQL query that was executed
|
||||
call_args = mock_client.execute_raw.call_args
|
||||
sql_query = call_args[0][0]
|
||||
|
||||
# Verify schema prefix is in the query
|
||||
assert '"platform"."UnifiedContentEmbedding"' in sql_query
|
||||
|
||||
# Verify result
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_get_content_embedding_with_schema():
|
||||
"""Test retrieving embeddings with proper schema handling."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch("prisma.get_client") as mock_get_client:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.query_raw.return_value = [
|
||||
{
|
||||
"contentType": "STORE_AGENT",
|
||||
"contentId": "test-id",
|
||||
"userId": None,
|
||||
"embedding": "[0.1, 0.2]",
|
||||
"searchableText": "test",
|
||||
"metadata": {},
|
||||
"createdAt": "2024-01-01",
|
||||
"updatedAt": "2024-01-01",
|
||||
}
|
||||
]
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.get_content_embedding(
|
||||
ContentType.STORE_AGENT,
|
||||
"test-id",
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Verify the query was called
|
||||
assert mock_client.query_raw.called
|
||||
|
||||
# Get the SQL query that was executed
|
||||
call_args = mock_client.query_raw.call_args
|
||||
sql_query = call_args[0][0]
|
||||
|
||||
# Verify schema prefix is in the query
|
||||
assert '"platform"."UnifiedContentEmbedding"' in sql_query
|
||||
|
||||
# Verify result
|
||||
assert result is not None
|
||||
assert result["contentId"] == "test-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_delete_content_embedding_with_schema():
|
||||
"""Test deleting embeddings with proper schema handling."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch("prisma.get_client") as mock_get_client:
|
||||
mock_client = AsyncMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.delete_content_embedding(
|
||||
ContentType.STORE_AGENT,
|
||||
"test-id",
|
||||
)
|
||||
|
||||
# Verify the query was called
|
||||
assert mock_client.execute_raw.called
|
||||
|
||||
# Get the SQL query that was executed
|
||||
call_args = mock_client.execute_raw.call_args
|
||||
sql_query = call_args[0][0]
|
||||
|
||||
# Verify schema prefix is in the query
|
||||
assert '"platform"."UnifiedContentEmbedding"' in sql_query
|
||||
|
||||
# Verify result
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_get_embedding_stats_with_schema():
|
||||
"""Test embedding statistics with proper schema handling via content handlers."""
|
||||
# Mock handler to return stats
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.get_stats = AsyncMock(
|
||||
return_value={
|
||||
"total": 100,
|
||||
"with_embeddings": 80,
|
||||
"without_embeddings": 20,
|
||||
}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.CONTENT_HANDLERS",
|
||||
{ContentType.STORE_AGENT: mock_handler},
|
||||
):
|
||||
result = await embeddings.get_embedding_stats()
|
||||
|
||||
# Verify handler was called
|
||||
mock_handler.get_stats.assert_called_once()
|
||||
|
||||
# Verify new result structure
|
||||
assert "by_type" in result
|
||||
assert "totals" in result
|
||||
assert result["totals"]["total"] == 100
|
||||
assert result["totals"]["with_embeddings"] == 80
|
||||
assert result["totals"]["without_embeddings"] == 20
|
||||
assert result["totals"]["coverage_percent"] == 80.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_backfill_missing_embeddings_with_schema():
|
||||
"""Test backfilling embeddings via content handlers."""
|
||||
from backend.api.features.store.content_handlers import ContentItem
|
||||
|
||||
# Create mock content item
|
||||
mock_item = ContentItem(
|
||||
content_id="version-1",
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
searchable_text="Test Agent Test description",
|
||||
metadata={"name": "Test Agent"},
|
||||
)
|
||||
|
||||
# Mock handler
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.get_missing_items = AsyncMock(return_value=[mock_item])
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.CONTENT_HANDLERS",
|
||||
{ContentType.STORE_AGENT: mock_handler},
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.generate_embedding",
|
||||
return_value=[0.1] * EMBEDDING_DIM,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.store_content_embedding",
|
||||
return_value=True,
|
||||
):
|
||||
result = await embeddings.backfill_missing_embeddings(batch_size=10)
|
||||
|
||||
# Verify handler was called
|
||||
mock_handler.get_missing_items.assert_called_once_with(10)
|
||||
|
||||
# Verify results
|
||||
assert result["processed"] == 1
|
||||
assert result["success"] == 1
|
||||
assert result["failed"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_ensure_content_embedding_with_schema():
|
||||
"""Test ensuring embeddings exist with proper schema handling."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_content_embedding"
|
||||
) as mock_get:
|
||||
# Simulate no existing embedding
|
||||
mock_get.return_value = None
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.generate_embedding"
|
||||
) as mock_generate:
|
||||
mock_generate.return_value = [0.1] * EMBEDDING_DIM
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.store_content_embedding"
|
||||
) as mock_store:
|
||||
mock_store.return_value = True
|
||||
|
||||
result = await embeddings.ensure_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id="test-id",
|
||||
searchable_text="test text",
|
||||
metadata={"test": "data"},
|
||||
user_id=None,
|
||||
force=False,
|
||||
)
|
||||
|
||||
# Verify the flow
|
||||
assert mock_get.called
|
||||
assert mock_generate.called
|
||||
assert mock_store.called
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_backward_compatibility_store_embedding():
|
||||
"""Test backward compatibility wrapper for store_embedding."""
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.store_content_embedding"
|
||||
) as mock_store:
|
||||
mock_store.return_value = True
|
||||
|
||||
result = await embeddings.store_embedding(
|
||||
version_id="test-version-id",
|
||||
embedding=[0.1] * EMBEDDING_DIM,
|
||||
tx=None,
|
||||
)
|
||||
|
||||
# Verify it calls the new function with correct parameters
|
||||
assert mock_store.called
|
||||
call_args = mock_store.call_args
|
||||
|
||||
assert call_args[1]["content_type"] == ContentType.STORE_AGENT
|
||||
assert call_args[1]["content_id"] == "test-version-id"
|
||||
assert call_args[1]["user_id"] is None
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_backward_compatibility_get_embedding():
|
||||
"""Test backward compatibility wrapper for get_embedding."""
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_content_embedding"
|
||||
) as mock_get:
|
||||
mock_get.return_value = {
|
||||
"contentType": "STORE_AGENT",
|
||||
"contentId": "test-version-id",
|
||||
"embedding": "[0.1, 0.2]",
|
||||
"createdAt": "2024-01-01",
|
||||
"updatedAt": "2024-01-01",
|
||||
}
|
||||
|
||||
result = await embeddings.get_embedding("test-version-id")
|
||||
|
||||
# Verify it calls the new function
|
||||
assert mock_get.called
|
||||
|
||||
# Verify it transforms to old format
|
||||
assert result is not None
|
||||
assert result["storeListingVersionId"] == "test-version-id"
|
||||
assert "embedding" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_schema_handling_error_cases():
|
||||
"""Test error handling in schema-aware operations."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch("prisma.get_client") as mock_get_client:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.execute_raw.side_effect = Exception("Database error")
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id="test-id",
|
||||
embedding=[0.1] * EMBEDDING_DIM,
|
||||
searchable_text="test",
|
||||
metadata=None,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Should return False on error, not raise
|
||||
assert result is False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
@@ -0,0 +1,407 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import prisma
|
||||
import pytest
|
||||
from prisma import Prisma
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store import embeddings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_prisma():
|
||||
"""Setup Prisma client for tests."""
|
||||
try:
|
||||
Prisma()
|
||||
except prisma.errors.ClientAlreadyRegisteredError:
|
||||
pass
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_build_searchable_text():
|
||||
"""Test searchable text building from listing fields."""
|
||||
result = embeddings.build_searchable_text(
|
||||
name="AI Assistant",
|
||||
description="A helpful AI assistant for productivity",
|
||||
sub_heading="Boost your productivity",
|
||||
categories=["AI", "Productivity"],
|
||||
)
|
||||
|
||||
expected = "AI Assistant Boost your productivity A helpful AI assistant for productivity AI Productivity"
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_build_searchable_text_empty_fields():
|
||||
"""Test searchable text building with empty fields."""
|
||||
result = embeddings.build_searchable_text(
|
||||
name="", description="Test description", sub_heading="", categories=[]
|
||||
)
|
||||
|
||||
assert result == "Test description"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_generate_embedding_success():
|
||||
"""Test successful embedding generation."""
|
||||
# Mock OpenAI response
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [MagicMock()]
|
||||
mock_response.data[0].embedding = [0.1, 0.2, 0.3] * 512 # 1536 dimensions
|
||||
|
||||
# Use AsyncMock for async embeddings.create method
|
||||
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Patch at the point of use in embeddings.py
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_openai_client"
|
||||
) as mock_get_client:
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.generate_embedding("test text")
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == embeddings.EMBEDDING_DIM
|
||||
assert result[0] == 0.1
|
||||
|
||||
mock_client.embeddings.create.assert_called_once_with(
|
||||
model="text-embedding-3-small", input="test text"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_generate_embedding_no_api_key():
|
||||
"""Test embedding generation without API key."""
|
||||
# Patch at the point of use in embeddings.py
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_openai_client"
|
||||
) as mock_get_client:
|
||||
mock_get_client.return_value = None
|
||||
|
||||
result = await embeddings.generate_embedding("test text")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_generate_embedding_api_error():
|
||||
"""Test embedding generation with API error."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.embeddings.create = AsyncMock(side_effect=Exception("API Error"))
|
||||
|
||||
# Patch at the point of use in embeddings.py
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_openai_client"
|
||||
) as mock_get_client:
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.generate_embedding("test text")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_generate_embedding_text_truncation():
|
||||
"""Test that long text is properly truncated using tiktoken."""
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [MagicMock()]
|
||||
mock_response.data[0].embedding = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
# Use AsyncMock for async embeddings.create method
|
||||
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Patch at the point of use in embeddings.py
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.get_openai_client"
|
||||
) as mock_get_client:
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# Create text that will exceed 8191 tokens
|
||||
# Use varied characters to ensure token-heavy text: each word is ~1 token
|
||||
words = [f"word{i}" for i in range(10000)]
|
||||
long_text = " ".join(words) # ~10000 tokens
|
||||
|
||||
await embeddings.generate_embedding(long_text)
|
||||
|
||||
# Verify text was truncated to 8191 tokens
|
||||
call_args = mock_client.embeddings.create.call_args
|
||||
truncated_text = call_args.kwargs["input"]
|
||||
|
||||
# Count actual tokens in truncated text
|
||||
enc = encoding_for_model("text-embedding-3-small")
|
||||
actual_tokens = len(enc.encode(truncated_text))
|
||||
|
||||
# Should be at or just under 8191 tokens
|
||||
assert actual_tokens <= 8191
|
||||
# Should be close to the limit (not over-truncated)
|
||||
assert actual_tokens >= 8100
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_embedding_success(mocker):
|
||||
"""Test successful embedding storage."""
|
||||
mock_client = mocker.AsyncMock()
|
||||
mock_client.execute_raw = mocker.AsyncMock()
|
||||
|
||||
embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
result = await embeddings.store_embedding(
|
||||
version_id="test-version-id", embedding=embedding, tx=mock_client
|
||||
)
|
||||
|
||||
assert result is True
|
||||
# execute_raw is called twice: once for SET search_path, once for INSERT
|
||||
assert mock_client.execute_raw.call_count == 2
|
||||
|
||||
# First call: SET search_path
|
||||
first_call_args = mock_client.execute_raw.call_args_list[0][0]
|
||||
assert "SET search_path" in first_call_args[0]
|
||||
|
||||
# Second call: INSERT query with the actual data
|
||||
second_call_args = mock_client.execute_raw.call_args_list[1][0]
|
||||
assert "test-version-id" in second_call_args
|
||||
assert "[0.1,0.2,0.3]" in second_call_args
|
||||
assert None in second_call_args # userId should be None for store agents
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_embedding_database_error(mocker):
|
||||
"""Test embedding storage with database error."""
|
||||
mock_client = mocker.AsyncMock()
|
||||
mock_client.execute_raw.side_effect = Exception("Database error")
|
||||
|
||||
embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
result = await embeddings.store_embedding(
|
||||
version_id="test-version-id", embedding=embedding, tx=mock_client
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_embedding_success():
|
||||
"""Test successful embedding retrieval."""
|
||||
mock_result = [
|
||||
{
|
||||
"contentType": "STORE_AGENT",
|
||||
"contentId": "test-version-id",
|
||||
"userId": None,
|
||||
"embedding": "[0.1,0.2,0.3]",
|
||||
"searchableText": "Test text",
|
||||
"metadata": {},
|
||||
"createdAt": "2024-01-01T00:00:00Z",
|
||||
"updatedAt": "2024-01-01T00:00:00Z",
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await embeddings.get_embedding("test-version-id")
|
||||
|
||||
assert result is not None
|
||||
assert result["storeListingVersionId"] == "test-version-id"
|
||||
assert result["embedding"] == "[0.1,0.2,0.3]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_embedding_not_found():
|
||||
"""Test embedding retrieval when not found."""
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
result = await embeddings.get_embedding("test-version-id")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
@patch("backend.api.features.store.embeddings.store_embedding")
|
||||
@patch("backend.api.features.store.embeddings.get_embedding")
|
||||
async def test_ensure_embedding_already_exists(mock_get, mock_store, mock_generate):
|
||||
"""Test ensure_embedding when embedding already exists."""
|
||||
mock_get.return_value = {"embedding": "[0.1,0.2,0.3]"}
|
||||
|
||||
result = await embeddings.ensure_embedding(
|
||||
version_id="test-id",
|
||||
name="Test",
|
||||
description="Test description",
|
||||
sub_heading="Test heading",
|
||||
categories=["test"],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_generate.assert_not_called()
|
||||
mock_store.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
@patch("backend.api.features.store.embeddings.store_content_embedding")
|
||||
@patch("backend.api.features.store.embeddings.get_embedding")
|
||||
async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate):
|
||||
"""Test ensure_embedding creating new embedding."""
|
||||
mock_get.return_value = None
|
||||
mock_generate.return_value = [0.1, 0.2, 0.3]
|
||||
mock_store.return_value = True
|
||||
|
||||
result = await embeddings.ensure_embedding(
|
||||
version_id="test-id",
|
||||
name="Test",
|
||||
description="Test description",
|
||||
sub_heading="Test heading",
|
||||
categories=["test"],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_generate.assert_called_once_with("Test Test heading Test description test")
|
||||
mock_store.assert_called_once_with(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id="test-id",
|
||||
embedding=[0.1, 0.2, 0.3],
|
||||
searchable_text="Test Test heading Test description test",
|
||||
metadata={"name": "Test", "subHeading": "Test heading", "categories": ["test"]},
|
||||
user_id=None,
|
||||
tx=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
@patch("backend.api.features.store.embeddings.get_embedding")
|
||||
async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
|
||||
"""Test ensure_embedding when generation fails."""
|
||||
mock_get.return_value = None
|
||||
mock_generate.return_value = None
|
||||
|
||||
result = await embeddings.ensure_embedding(
|
||||
version_id="test-id",
|
||||
name="Test",
|
||||
description="Test description",
|
||||
sub_heading="Test heading",
|
||||
categories=["test"],
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_embedding_stats():
|
||||
"""Test embedding statistics retrieval."""
|
||||
# Mock handler stats for each content type
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.get_stats = AsyncMock(
|
||||
return_value={
|
||||
"total": 100,
|
||||
"with_embeddings": 75,
|
||||
"without_embeddings": 25,
|
||||
}
|
||||
)
|
||||
|
||||
# Patch the CONTENT_HANDLERS where it's used (in embeddings module)
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.CONTENT_HANDLERS",
|
||||
{ContentType.STORE_AGENT: mock_handler},
|
||||
):
|
||||
result = await embeddings.get_embedding_stats()
|
||||
|
||||
assert "by_type" in result
|
||||
assert "totals" in result
|
||||
assert result["totals"]["total"] == 100
|
||||
assert result["totals"]["with_embeddings"] == 75
|
||||
assert result["totals"]["without_embeddings"] == 25
|
||||
assert result["totals"]["coverage_percent"] == 75.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.store_content_embedding")
|
||||
async def test_backfill_missing_embeddings_success(mock_store):
|
||||
"""Test backfill with successful embedding generation."""
|
||||
# Mock ContentItem from handlers
|
||||
from backend.api.features.store.content_handlers import ContentItem
|
||||
|
||||
mock_items = [
|
||||
ContentItem(
|
||||
content_id="version-1",
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
searchable_text="Agent 1 Description 1",
|
||||
metadata={"name": "Agent 1"},
|
||||
),
|
||||
ContentItem(
|
||||
content_id="version-2",
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
searchable_text="Agent 2 Description 2",
|
||||
metadata={"name": "Agent 2"},
|
||||
),
|
||||
]
|
||||
|
||||
# Mock handler to return missing items
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.get_missing_items = AsyncMock(return_value=mock_items)
|
||||
|
||||
# Mock store_content_embedding to succeed for first, fail for second
|
||||
mock_store.side_effect = [True, False]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.CONTENT_HANDLERS",
|
||||
{ContentType.STORE_AGENT: mock_handler},
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.generate_embedding",
|
||||
return_value=[0.1] * embeddings.EMBEDDING_DIM,
|
||||
):
|
||||
result = await embeddings.backfill_missing_embeddings(batch_size=5)
|
||||
|
||||
assert result["processed"] == 2
|
||||
assert result["success"] == 1
|
||||
assert result["failed"] == 1
|
||||
assert mock_store.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_backfill_missing_embeddings_no_missing():
|
||||
"""Test backfill when no embeddings are missing."""
|
||||
# Mock handler to return no missing items
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.get_missing_items = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.CONTENT_HANDLERS",
|
||||
{ContentType.STORE_AGENT: mock_handler},
|
||||
):
|
||||
result = await embeddings.backfill_missing_embeddings(batch_size=5)
|
||||
|
||||
assert result["processed"] == 0
|
||||
assert result["success"] == 0
|
||||
assert result["failed"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_embedding_to_vector_string():
|
||||
"""Test embedding to PostgreSQL vector string conversion."""
|
||||
embedding = [0.1, 0.2, 0.3, -0.4]
|
||||
result = embeddings.embedding_to_vector_string(embedding)
|
||||
assert result == "[0.1,0.2,0.3,-0.4]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_embed_query():
|
||||
"""Test embed_query function (alias for generate_embedding)."""
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.generate_embedding"
|
||||
) as mock_generate:
|
||||
mock_generate.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
result = await embeddings.embed_query("test query")
|
||||
|
||||
assert result == [0.1, 0.2, 0.3]
|
||||
mock_generate.assert_called_once_with("test query")
|
||||
@@ -0,0 +1,625 @@
|
||||
"""
|
||||
Unified Hybrid Search
|
||||
|
||||
Combines semantic (embedding) search with lexical (tsvector) search
|
||||
for improved relevance across all content types (agents, blocks, docs).
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store.embeddings import (
|
||||
EMBEDDING_DIM,
|
||||
embed_query,
|
||||
embedding_to_vector_string,
|
||||
)
|
||||
from backend.data.db import query_raw_with_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnifiedSearchWeights:
|
||||
"""Weights for unified search (no popularity signal)."""
|
||||
|
||||
semantic: float = 0.40 # Embedding cosine similarity
|
||||
lexical: float = 0.40 # tsvector ts_rank_cd score
|
||||
category: float = 0.10 # Category match boost (for types that have categories)
|
||||
recency: float = 0.10 # Newer content ranked higher
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate weights are non-negative and sum to approximately 1.0."""
|
||||
total = self.semantic + self.lexical + self.category + self.recency
|
||||
|
||||
if any(
|
||||
w < 0 for w in [self.semantic, self.lexical, self.category, self.recency]
|
||||
):
|
||||
raise ValueError("All weights must be non-negative")
|
||||
|
||||
if not (0.99 <= total <= 1.01):
|
||||
raise ValueError(f"Weights must sum to ~1.0, got {total:.3f}")
|
||||
|
||||
|
||||
# Default weights for unified search
|
||||
DEFAULT_UNIFIED_WEIGHTS = UnifiedSearchWeights()
|
||||
|
||||
# Minimum relevance score thresholds
|
||||
DEFAULT_MIN_SCORE = 0.15 # For unified search (more permissive)
|
||||
DEFAULT_STORE_AGENT_MIN_SCORE = 0.20 # For store agent search (original threshold)
|
||||
|
||||
|
||||
async def unified_hybrid_search(
|
||||
query: str,
|
||||
content_types: list[ContentType] | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
weights: UnifiedSearchWeights | None = None,
|
||||
min_score: float | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Unified hybrid search across all content types.
|
||||
|
||||
Searches UnifiedContentEmbedding using both semantic (vector) and lexical (tsvector) signals.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
content_types: List of content types to search. Defaults to all public types.
|
||||
category: Filter by category (for content types that support it)
|
||||
page: Page number (1-indexed)
|
||||
page_size: Results per page
|
||||
weights: Custom weights for search signals
|
||||
min_score: Minimum relevance score threshold (0-1)
|
||||
user_id: User ID for searching private content (library agents)
|
||||
|
||||
Returns:
|
||||
Tuple of (results list, total count)
|
||||
"""
|
||||
# Validate inputs
|
||||
query = query.strip()
|
||||
if not query:
|
||||
return [], 0
|
||||
|
||||
if page < 1:
|
||||
page = 1
|
||||
if page_size < 1:
|
||||
page_size = 1
|
||||
if page_size > 100:
|
||||
page_size = 100
|
||||
|
||||
if content_types is None:
|
||||
content_types = [
|
||||
ContentType.STORE_AGENT,
|
||||
ContentType.BLOCK,
|
||||
ContentType.DOCUMENTATION,
|
||||
]
|
||||
|
||||
if weights is None:
|
||||
weights = DEFAULT_UNIFIED_WEIGHTS
|
||||
if min_score is None:
|
||||
min_score = DEFAULT_MIN_SCORE
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await embed_query(query)
|
||||
|
||||
# Graceful degradation if embedding unavailable
|
||||
if query_embedding is None or not query_embedding:
|
||||
logger.warning(
|
||||
"Failed to generate query embedding - falling back to lexical-only search. "
|
||||
"Check that openai_internal_api_key is configured and OpenAI API is accessible."
|
||||
)
|
||||
query_embedding = [0.0] * EMBEDDING_DIM
|
||||
# Redistribute semantic weight to lexical
|
||||
total_non_semantic = weights.lexical + weights.category + weights.recency
|
||||
if total_non_semantic > 0:
|
||||
factor = 1.0 / total_non_semantic
|
||||
weights = UnifiedSearchWeights(
|
||||
semantic=0.0,
|
||||
lexical=weights.lexical * factor,
|
||||
category=weights.category * factor,
|
||||
recency=weights.recency * factor,
|
||||
)
|
||||
else:
|
||||
weights = UnifiedSearchWeights(
|
||||
semantic=0.0, lexical=1.0, category=0.0, recency=0.0
|
||||
)
|
||||
|
||||
# Build parameters
|
||||
params: list[Any] = []
|
||||
param_idx = 1
|
||||
|
||||
# Query for lexical search
|
||||
params.append(query)
|
||||
query_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
# Query lowercase for category matching
|
||||
params.append(query.lower())
|
||||
query_lower_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
# Embedding
|
||||
embedding_str = embedding_to_vector_string(query_embedding)
|
||||
params.append(embedding_str)
|
||||
embedding_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
# Content types
|
||||
content_type_values = [ct.value for ct in content_types]
|
||||
params.append(content_type_values)
|
||||
content_types_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
# User ID filter (for private content)
|
||||
user_filter = ""
|
||||
if user_id is not None:
|
||||
params.append(user_id)
|
||||
user_filter = f'AND (uce."userId" = ${param_idx} OR uce."userId" IS NULL)'
|
||||
param_idx += 1
|
||||
else:
|
||||
user_filter = 'AND uce."userId" IS NULL'
|
||||
|
||||
# Weights
|
||||
params.append(weights.semantic)
|
||||
w_semantic = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
params.append(weights.lexical)
|
||||
w_lexical = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
params.append(weights.category)
|
||||
w_category = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
params.append(weights.recency)
|
||||
w_recency = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
# Min score
|
||||
params.append(min_score)
|
||||
min_score_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
# Pagination
|
||||
params.append(page_size)
|
||||
limit_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
params.append(offset)
|
||||
offset_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
# Unified search query on UnifiedContentEmbedding
|
||||
sql_query = f"""
|
||||
WITH candidates AS (
|
||||
-- Lexical matches (uses GIN index on search column)
|
||||
SELECT uce.id, uce."contentType", uce."contentId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
WHERE uce."contentType" = ANY({content_types_param}::{{schema_prefix}}"ContentType"[])
|
||||
{user_filter}
|
||||
AND uce.search @@ plainto_tsquery('english', {query_param})
|
||||
|
||||
UNION
|
||||
|
||||
-- Semantic matches (uses HNSW index on embedding)
|
||||
(
|
||||
SELECT uce.id, uce."contentType", uce."contentId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
WHERE uce."contentType" = ANY({content_types_param}::{{schema_prefix}}"ContentType"[])
|
||||
{user_filter}
|
||||
ORDER BY uce.embedding <=> {embedding_param}::vector
|
||||
LIMIT 200
|
||||
)
|
||||
),
|
||||
search_scores AS (
|
||||
SELECT
|
||||
uce."contentType" as content_type,
|
||||
uce."contentId" as content_id,
|
||||
uce."searchableText" as searchable_text,
|
||||
uce.metadata,
|
||||
uce."updatedAt" as updated_at,
|
||||
-- Semantic score: cosine similarity (1 - distance)
|
||||
COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
||||
-- Lexical score: ts_rank_cd
|
||||
COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||
-- Category match from metadata
|
||||
CASE
|
||||
WHEN uce.metadata ? 'categories' AND EXISTS (
|
||||
SELECT 1 FROM jsonb_array_elements_text(uce.metadata->'categories') cat
|
||||
WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%'
|
||||
)
|
||||
THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
-- Recency score: linear decay over 90 days
|
||||
GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - uce."updatedAt")) / (90 * 24 * 3600)) as recency_score
|
||||
FROM candidates c
|
||||
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce ON c.id = uce.id
|
||||
),
|
||||
max_lexical AS (
|
||||
SELECT GREATEST(MAX(lexical_raw), 0.001) as max_val FROM search_scores
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
ss.*,
|
||||
ss.lexical_raw / ml.max_val as lexical_score
|
||||
FROM search_scores ss
|
||||
CROSS JOIN max_lexical ml
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
content_type,
|
||||
content_id,
|
||||
searchable_text,
|
||||
metadata,
|
||||
updated_at,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
category_score,
|
||||
recency_score,
|
||||
(
|
||||
{w_semantic} * semantic_score +
|
||||
{w_lexical} * lexical_score +
|
||||
{w_category} * category_score +
|
||||
{w_recency} * recency_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
),
|
||||
filtered AS (
|
||||
SELECT
|
||||
*,
|
||||
COUNT(*) OVER () as total_count
|
||||
FROM scored
|
||||
WHERE combined_score >= {min_score_param}
|
||||
)
|
||||
SELECT * FROM filtered
|
||||
ORDER BY combined_score DESC
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
results = await query_raw_with_schema(
|
||||
sql_query, *params, set_public_search_path=True
|
||||
)
|
||||
|
||||
total = results[0]["total_count"] if results else 0
|
||||
|
||||
# Clean up results
|
||||
for result in results:
|
||||
result.pop("total_count", None)
|
||||
|
||||
logger.info(f"Unified hybrid search: {len(results)} results, {total} total")
|
||||
|
||||
return results, total
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Store Agent specific search (with full metadata)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class StoreAgentSearchWeights:
|
||||
"""Weights for store agent search including popularity."""
|
||||
|
||||
semantic: float = 0.30
|
||||
lexical: float = 0.30
|
||||
category: float = 0.20
|
||||
recency: float = 0.10
|
||||
popularity: float = 0.10
|
||||
|
||||
def __post_init__(self):
|
||||
total = (
|
||||
self.semantic
|
||||
+ self.lexical
|
||||
+ self.category
|
||||
+ self.recency
|
||||
+ self.popularity
|
||||
)
|
||||
if any(
|
||||
w < 0
|
||||
for w in [
|
||||
self.semantic,
|
||||
self.lexical,
|
||||
self.category,
|
||||
self.recency,
|
||||
self.popularity,
|
||||
]
|
||||
):
|
||||
raise ValueError("All weights must be non-negative")
|
||||
if not (0.99 <= total <= 1.01):
|
||||
raise ValueError(f"Weights must sum to ~1.0, got {total:.3f}")
|
||||
|
||||
|
||||
DEFAULT_STORE_AGENT_WEIGHTS = StoreAgentSearchWeights()
|
||||
|
||||
|
||||
async def hybrid_search(
|
||||
query: str,
|
||||
featured: bool = False,
|
||||
creators: list[str] | None = None,
|
||||
category: str | None = None,
|
||||
sorted_by: (
|
||||
Literal["relevance", "rating", "runs", "name", "updated_at"] | None
|
||||
) = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
weights: StoreAgentSearchWeights | None = None,
|
||||
min_score: float | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Hybrid search for store agents with full metadata.
|
||||
|
||||
Uses UnifiedContentEmbedding for search, joins to StoreAgent for metadata.
|
||||
"""
|
||||
query = query.strip()
|
||||
if not query:
|
||||
return [], 0
|
||||
|
||||
if page < 1:
|
||||
page = 1
|
||||
if page_size < 1:
|
||||
page_size = 1
|
||||
if page_size > 100:
|
||||
page_size = 100
|
||||
|
||||
if weights is None:
|
||||
weights = DEFAULT_STORE_AGENT_WEIGHTS
|
||||
if min_score is None:
|
||||
min_score = (
|
||||
DEFAULT_STORE_AGENT_MIN_SCORE # Use original threshold for store agents
|
||||
)
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await embed_query(query)
|
||||
|
||||
# Graceful degradation
|
||||
if query_embedding is None or not query_embedding:
|
||||
logger.warning(
|
||||
"Failed to generate query embedding - falling back to lexical-only search."
|
||||
)
|
||||
query_embedding = [0.0] * EMBEDDING_DIM
|
||||
total_non_semantic = (
|
||||
weights.lexical + weights.category + weights.recency + weights.popularity
|
||||
)
|
||||
if total_non_semantic > 0:
|
||||
factor = 1.0 / total_non_semantic
|
||||
weights = StoreAgentSearchWeights(
|
||||
semantic=0.0,
|
||||
lexical=weights.lexical * factor,
|
||||
category=weights.category * factor,
|
||||
recency=weights.recency * factor,
|
||||
popularity=weights.popularity * factor,
|
||||
)
|
||||
else:
|
||||
weights = StoreAgentSearchWeights(
|
||||
semantic=0.0, lexical=1.0, category=0.0, recency=0.0, popularity=0.0
|
||||
)
|
||||
|
||||
# Build parameters
|
||||
params: list[Any] = []
|
||||
param_idx = 1
|
||||
|
||||
params.append(query)
|
||||
query_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
params.append(query.lower())
|
||||
query_lower_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
embedding_str = embedding_to_vector_string(query_embedding)
|
||||
params.append(embedding_str)
|
||||
embedding_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
# Build WHERE clause for StoreAgent filters
|
||||
where_parts = ["sa.is_available = true"]
|
||||
|
||||
if featured:
|
||||
where_parts.append("sa.featured = true")
|
||||
|
||||
if creators:
|
||||
params.append(creators)
|
||||
where_parts.append(f"sa.creator_username = ANY(${param_idx})")
|
||||
param_idx += 1
|
||||
|
||||
if category:
|
||||
params.append(category)
|
||||
where_parts.append(f"${param_idx} = ANY(sa.categories)")
|
||||
param_idx += 1
|
||||
|
||||
where_clause = " AND ".join(where_parts)
|
||||
|
||||
# Weights
|
||||
params.append(weights.semantic)
|
||||
w_semantic = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
params.append(weights.lexical)
|
||||
w_lexical = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
params.append(weights.category)
|
||||
w_category = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
params.append(weights.recency)
|
||||
w_recency = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
params.append(weights.popularity)
|
||||
w_popularity = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
params.append(min_score)
|
||||
min_score_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
params.append(page_size)
|
||||
limit_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
params.append(offset)
|
||||
offset_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
|
||||
# Query using UnifiedContentEmbedding for search, StoreAgent for metadata
|
||||
sql_query = f"""
|
||||
WITH candidates AS (
|
||||
-- Lexical matches via UnifiedContentEmbedding.search
|
||||
SELECT uce."contentId" as "storeListingVersionId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON uce."contentId" = sa."storeListingVersionId"
|
||||
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
AND uce."userId" IS NULL
|
||||
AND uce.search @@ plainto_tsquery('english', {query_param})
|
||||
AND {where_clause}
|
||||
|
||||
UNION
|
||||
|
||||
-- Semantic matches via UnifiedContentEmbedding.embedding
|
||||
SELECT uce."contentId" as "storeListingVersionId"
|
||||
FROM (
|
||||
SELECT uce."contentId", uce.embedding
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON uce."contentId" = sa."storeListingVersionId"
|
||||
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
AND uce."userId" IS NULL
|
||||
AND {where_clause}
|
||||
ORDER BY uce.embedding <=> {embedding_param}::vector
|
||||
LIMIT 200
|
||||
) uce
|
||||
),
|
||||
search_scores AS (
|
||||
SELECT
|
||||
sa.slug,
|
||||
sa.agent_name,
|
||||
sa.agent_image,
|
||||
sa.creator_username,
|
||||
sa.creator_avatar,
|
||||
sa.sub_heading,
|
||||
sa.description,
|
||||
sa.runs,
|
||||
sa.rating,
|
||||
sa.categories,
|
||||
sa.featured,
|
||||
sa.is_available,
|
||||
sa.updated_at,
|
||||
-- Semantic score
|
||||
COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
||||
-- Lexical score (raw, will normalize)
|
||||
COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||
-- Category match
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM unnest(sa.categories) cat
|
||||
WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%'
|
||||
)
|
||||
THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
-- Recency
|
||||
GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score,
|
||||
-- Popularity (raw)
|
||||
sa.runs as popularity_raw
|
||||
FROM candidates c
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON c."storeListingVersionId" = sa."storeListingVersionId"
|
||||
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
ON sa."storeListingVersionId" = uce."contentId"
|
||||
AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
),
|
||||
max_vals AS (
|
||||
SELECT
|
||||
GREATEST(MAX(lexical_raw), 0.001) as max_lexical,
|
||||
GREATEST(MAX(popularity_raw), 1) as max_popularity
|
||||
FROM search_scores
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
ss.*,
|
||||
ss.lexical_raw / mv.max_lexical as lexical_score,
|
||||
CASE
|
||||
WHEN ss.popularity_raw > 0
|
||||
THEN LN(1 + ss.popularity_raw) / LN(1 + mv.max_popularity)
|
||||
ELSE 0
|
||||
END as popularity_score
|
||||
FROM search_scores ss
|
||||
CROSS JOIN max_vals mv
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
category_score,
|
||||
recency_score,
|
||||
popularity_score,
|
||||
(
|
||||
{w_semantic} * semantic_score +
|
||||
{w_lexical} * lexical_score +
|
||||
{w_category} * category_score +
|
||||
{w_recency} * recency_score +
|
||||
{w_popularity} * popularity_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
),
|
||||
filtered AS (
|
||||
SELECT *, COUNT(*) OVER () as total_count
|
||||
FROM scored
|
||||
WHERE combined_score >= {min_score_param}
|
||||
)
|
||||
SELECT * FROM filtered
|
||||
ORDER BY combined_score DESC
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
results = await query_raw_with_schema(
|
||||
sql_query, *params, set_public_search_path=True
|
||||
)
|
||||
|
||||
total = results[0]["total_count"] if results else 0
|
||||
|
||||
for result in results:
|
||||
result.pop("total_count", None)
|
||||
|
||||
logger.info(f"Hybrid search (store agents): {len(results)} results, {total} total")
|
||||
|
||||
return results, total
|
||||
|
||||
|
||||
async def hybrid_search_simple(
|
||||
query: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Simplified hybrid search for store agents."""
|
||||
return await hybrid_search(query=query, page=page, page_size=page_size)
|
||||
|
||||
|
||||
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
|
||||
# for existing code that expects the popularity parameter
|
||||
HybridSearchWeights = StoreAgentSearchWeights
|
||||
@@ -0,0 +1,667 @@
|
||||
"""
|
||||
Integration tests for hybrid search with schema handling.
|
||||
|
||||
These tests verify that hybrid search works correctly across different database schemas.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store import embeddings
|
||||
from backend.api.features.store.hybrid_search import (
|
||||
HybridSearchWeights,
|
||||
UnifiedSearchWeights,
|
||||
hybrid_search,
|
||||
unified_hybrid_search,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_with_schema_handling():
|
||||
"""Test that hybrid search correctly handles database schema prefixes."""
|
||||
# Test with a mock query to ensure schema handling works
|
||||
query = "test agent"
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
# Mock the query result
|
||||
mock_query.return_value = [
|
||||
{
|
||||
"slug": "test/agent",
|
||||
"agent_name": "Test Agent",
|
||||
"agent_image": "test.png",
|
||||
"creator_username": "test",
|
||||
"creator_avatar": "avatar.png",
|
||||
"sub_heading": "Test sub-heading",
|
||||
"description": "Test description",
|
||||
"runs": 10,
|
||||
"rating": 4.5,
|
||||
"categories": ["test"],
|
||||
"featured": False,
|
||||
"is_available": True,
|
||||
"updated_at": "2024-01-01T00:00:00Z",
|
||||
"combined_score": 0.8,
|
||||
"semantic_score": 0.7,
|
||||
"lexical_score": 0.6,
|
||||
"category_score": 0.5,
|
||||
"recency_score": 0.4,
|
||||
"total_count": 1,
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM # Mock embedding
|
||||
|
||||
results, total = await hybrid_search(
|
||||
query=query,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify the query was called
|
||||
assert mock_query.called
|
||||
# Verify the SQL template uses schema_prefix placeholder
|
||||
call_args = mock_query.call_args
|
||||
sql_template = call_args[0][0]
|
||||
assert "{schema_prefix}" in sql_template
|
||||
|
||||
# Verify results
|
||||
assert len(results) == 1
|
||||
assert total == 1
|
||||
assert results[0]["slug"] == "test/agent"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_with_public_schema():
|
||||
"""Test hybrid search when using public schema (no prefix needed)."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "public"
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
mock_query.return_value = []
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify the mock was set up correctly
|
||||
assert mock_schema.return_value == "public"
|
||||
|
||||
# Results should work even with empty results
|
||||
assert results == []
|
||||
assert total == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_with_custom_schema():
|
||||
"""Test hybrid search when using custom schema (e.g., 'platform')."""
|
||||
with patch("backend.data.db.get_database_schema") as mock_schema:
|
||||
mock_schema.return_value = "platform"
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
mock_query.return_value = []
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify the mock was set up correctly
|
||||
assert mock_schema.return_value == "platform"
|
||||
|
||||
assert results == []
|
||||
assert total == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_without_embeddings():
|
||||
"""Test hybrid search gracefully degrades when embeddings are unavailable."""
|
||||
# Mock database to return some results
|
||||
mock_results = [
|
||||
{
|
||||
"slug": "test-agent",
|
||||
"agent_name": "Test Agent",
|
||||
"agent_image": "test.png",
|
||||
"creator_username": "creator",
|
||||
"creator_avatar": "avatar.png",
|
||||
"sub_heading": "Test heading",
|
||||
"description": "Test description",
|
||||
"runs": 100,
|
||||
"rating": 4.5,
|
||||
"categories": ["AI"],
|
||||
"featured": False,
|
||||
"is_available": True,
|
||||
"updated_at": "2025-01-01T00:00:00Z",
|
||||
"semantic_score": 0.0, # Zero because no embedding
|
||||
"lexical_score": 0.5,
|
||||
"category_score": 0.0,
|
||||
"recency_score": 0.1,
|
||||
"popularity_score": 0.2,
|
||||
"combined_score": 0.3,
|
||||
"total_count": 1,
|
||||
}
|
||||
]
|
||||
|
||||
with patch("backend.api.features.store.hybrid_search.embed_query") as mock_embed:
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
# Simulate embedding failure
|
||||
mock_embed.return_value = None
|
||||
mock_query.return_value = mock_results
|
||||
|
||||
# Should NOT raise - graceful degradation
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify it returns results even without embeddings
|
||||
assert len(results) == 1
|
||||
assert results[0]["slug"] == "test-agent"
|
||||
assert total == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_with_filters():
|
||||
"""Test hybrid search with various filters."""
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
mock_query.return_value = []
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
# Test with featured filter
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
featured=True,
|
||||
creators=["user1", "user2"],
|
||||
category="productivity",
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
# Verify filters were applied in the query
|
||||
call_args = mock_query.call_args
|
||||
params = call_args[0][1:] # Skip SQL template
|
||||
|
||||
# Should have query, query_lower, creators array, category
|
||||
assert len(params) >= 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_weights():
|
||||
"""Test hybrid search with custom weights."""
|
||||
custom_weights = HybridSearchWeights(
|
||||
semantic=0.5,
|
||||
lexical=0.3,
|
||||
category=0.1,
|
||||
recency=0.1,
|
||||
popularity=0.0,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
mock_query.return_value = []
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
weights=custom_weights,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify custom weights were used in the query
|
||||
call_args = mock_query.call_args
|
||||
sql_template = call_args[0][0]
|
||||
params = call_args[0][1:] # Get all parameters passed
|
||||
|
||||
# Check that SQL uses parameterized weights (not f-string interpolation)
|
||||
assert "$" in sql_template # Verify parameterization is used
|
||||
|
||||
# Check that custom weights are in the params
|
||||
assert 0.5 in params # semantic weight
|
||||
assert 0.3 in params # lexical weight
|
||||
assert 0.1 in params # category and recency weights
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_min_score_filtering():
|
||||
"""Test hybrid search minimum score threshold."""
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
# Return results with varying scores
|
||||
mock_query.return_value = [
|
||||
{
|
||||
"slug": "high-score/agent",
|
||||
"agent_name": "High Score Agent",
|
||||
"combined_score": 0.8,
|
||||
"total_count": 1,
|
||||
# ... other fields
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
# Test with custom min_score
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
min_score=0.5, # High threshold
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify min_score was applied in query
|
||||
call_args = mock_query.call_args
|
||||
sql_template = call_args[0][0]
|
||||
params = call_args[0][1:] # Get all parameters
|
||||
|
||||
# Check that SQL uses parameterized min_score
|
||||
assert "combined_score >=" in sql_template
|
||||
assert "$" in sql_template # Verify parameterization
|
||||
|
||||
# Check that custom min_score is in the params
|
||||
assert 0.5 in params
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_pagination():
|
||||
"""Test hybrid search pagination."""
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
mock_query.return_value = []
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
# Test page 2 with page_size 10
|
||||
results, total = await hybrid_search(
|
||||
query="test",
|
||||
page=2,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
# Verify pagination parameters
|
||||
call_args = mock_query.call_args
|
||||
params = call_args[0]
|
||||
|
||||
# Last two params should be LIMIT and OFFSET
|
||||
limit = params[-2]
|
||||
offset = params[-1]
|
||||
|
||||
assert limit == 10 # page_size
|
||||
assert offset == 10 # (page - 1) * page_size = (2 - 1) * 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_error_handling():
|
||||
"""Test hybrid search error handling."""
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
# Simulate database error
|
||||
mock_query.side_effect = Exception("Database connection error")
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
# Should raise exception
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await hybrid_search(
|
||||
query="test",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
assert "Database connection error" in str(exc_info.value)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Unified Hybrid Search Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_unified_hybrid_search_basic():
|
||||
"""Test basic unified hybrid search across all content types."""
|
||||
mock_results = [
|
||||
{
|
||||
"content_type": "STORE_AGENT",
|
||||
"content_id": "agent-1",
|
||||
"searchable_text": "Test Agent Description",
|
||||
"metadata": {"name": "Test Agent"},
|
||||
"updated_at": "2025-01-01T00:00:00Z",
|
||||
"semantic_score": 0.7,
|
||||
"lexical_score": 0.8,
|
||||
"category_score": 0.5,
|
||||
"recency_score": 0.3,
|
||||
"combined_score": 0.6,
|
||||
"total_count": 2,
|
||||
},
|
||||
{
|
||||
"content_type": "BLOCK",
|
||||
"content_id": "block-1",
|
||||
"searchable_text": "Test Block Description",
|
||||
"metadata": {"name": "Test Block"},
|
||||
"updated_at": "2025-01-01T00:00:00Z",
|
||||
"semantic_score": 0.6,
|
||||
"lexical_score": 0.7,
|
||||
"category_score": 0.4,
|
||||
"recency_score": 0.2,
|
||||
"combined_score": 0.5,
|
||||
"total_count": 2,
|
||||
},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_query.return_value = mock_results
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
results, total = await unified_hybrid_search(
|
||||
query="test",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert total == 2
|
||||
assert results[0]["content_type"] == "STORE_AGENT"
|
||||
assert results[1]["content_type"] == "BLOCK"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_unified_hybrid_search_filter_by_content_type():
|
||||
"""Test unified search filtering by specific content types."""
|
||||
mock_results = [
|
||||
{
|
||||
"content_type": "BLOCK",
|
||||
"content_id": "block-1",
|
||||
"searchable_text": "Test Block",
|
||||
"metadata": {},
|
||||
"updated_at": "2025-01-01T00:00:00Z",
|
||||
"semantic_score": 0.7,
|
||||
"lexical_score": 0.8,
|
||||
"category_score": 0.0,
|
||||
"recency_score": 0.3,
|
||||
"combined_score": 0.5,
|
||||
"total_count": 1,
|
||||
},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_query.return_value = mock_results
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
results, total = await unified_hybrid_search(
|
||||
query="test",
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify content_types parameter was passed correctly
|
||||
call_args = mock_query.call_args
|
||||
params = call_args[0][1:]
|
||||
# The content types should be in the params as a list
|
||||
assert ["BLOCK"] in params
|
||||
|
||||
assert len(results) == 1
|
||||
assert total == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_unified_hybrid_search_with_user_id():
|
||||
"""Test unified search with user_id for private content."""
|
||||
mock_results = [
|
||||
{
|
||||
"content_type": "STORE_AGENT",
|
||||
"content_id": "agent-1",
|
||||
"searchable_text": "My Private Agent",
|
||||
"metadata": {},
|
||||
"updated_at": "2025-01-01T00:00:00Z",
|
||||
"semantic_score": 0.7,
|
||||
"lexical_score": 0.8,
|
||||
"category_score": 0.0,
|
||||
"recency_score": 0.3,
|
||||
"combined_score": 0.6,
|
||||
"total_count": 1,
|
||||
},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_query.return_value = mock_results
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
results, total = await unified_hybrid_search(
|
||||
query="test",
|
||||
user_id="user-123",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify SQL contains user_id filter
|
||||
call_args = mock_query.call_args
|
||||
sql_template = call_args[0][0]
|
||||
params = call_args[0][1:]
|
||||
|
||||
assert 'uce."userId"' in sql_template
|
||||
assert "user-123" in params
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_unified_hybrid_search_custom_weights():
|
||||
"""Test unified search with custom weights."""
|
||||
custom_weights = UnifiedSearchWeights(
|
||||
semantic=0.6,
|
||||
lexical=0.2,
|
||||
category=0.1,
|
||||
recency=0.1,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_query.return_value = []
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
results, total = await unified_hybrid_search(
|
||||
query="test",
|
||||
weights=custom_weights,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify custom weights are in parameters
|
||||
call_args = mock_query.call_args
|
||||
params = call_args[0][1:]
|
||||
|
||||
assert 0.6 in params # semantic weight
|
||||
assert 0.2 in params # lexical weight
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_unified_hybrid_search_graceful_degradation():
|
||||
"""Test unified search gracefully degrades when embeddings unavailable."""
|
||||
mock_results = [
|
||||
{
|
||||
"content_type": "DOCUMENTATION",
|
||||
"content_id": "doc-1",
|
||||
"searchable_text": "API Documentation",
|
||||
"metadata": {},
|
||||
"updated_at": "2025-01-01T00:00:00Z",
|
||||
"semantic_score": 0.0, # Zero because no embedding
|
||||
"lexical_score": 0.8,
|
||||
"category_score": 0.0,
|
||||
"recency_score": 0.2,
|
||||
"combined_score": 0.5,
|
||||
"total_count": 1,
|
||||
},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_query.return_value = mock_results
|
||||
mock_embed.return_value = None # Embedding failure
|
||||
|
||||
# Should NOT raise - graceful degradation
|
||||
results, total = await unified_hybrid_search(
|
||||
query="test",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert total == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_unified_hybrid_search_empty_query():
|
||||
"""Test unified search with empty query returns empty results."""
|
||||
results, total = await unified_hybrid_search(
|
||||
query="",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
assert results == []
|
||||
assert total == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_unified_hybrid_search_pagination():
|
||||
"""Test unified search pagination."""
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_query.return_value = []
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
results, total = await unified_hybrid_search(
|
||||
query="test",
|
||||
page=3,
|
||||
page_size=15,
|
||||
)
|
||||
|
||||
# Verify pagination parameters (last two params are LIMIT and OFFSET)
|
||||
call_args = mock_query.call_args
|
||||
params = call_args[0]
|
||||
|
||||
limit = params[-2]
|
||||
offset = params[-1]
|
||||
|
||||
assert limit == 15 # page_size
|
||||
assert offset == 30 # (page - 1) * page_size = (3 - 1) * 15
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_unified_hybrid_search_schema_prefix():
|
||||
"""Test unified search uses schema_prefix placeholder."""
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_query.return_value = []
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
await unified_hybrid_search(
|
||||
query="test",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
call_args = mock_query.call_args
|
||||
sql_template = call_args[0][0]
|
||||
|
||||
# Verify schema_prefix placeholder is used for table references
|
||||
assert "{schema_prefix}" in sql_template
|
||||
assert '"UnifiedContentEmbedding"' in sql_template
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
@@ -221,3 +221,23 @@ class ReviewSubmissionRequest(pydantic.BaseModel):
|
||||
is_approved: bool
|
||||
comments: str # External comments visible to creator
|
||||
internal_comments: str | None = None # Private admin notes
|
||||
|
||||
|
||||
class UnifiedSearchResult(pydantic.BaseModel):
|
||||
"""A single result from unified hybrid search across all content types."""
|
||||
|
||||
content_type: str # STORE_AGENT, BLOCK, DOCUMENTATION
|
||||
content_id: str
|
||||
searchable_text: str
|
||||
metadata: dict | None = None
|
||||
updated_at: datetime.datetime | None = None
|
||||
combined_score: float | None = None
|
||||
semantic_score: float | None = None
|
||||
lexical_score: float | None = None
|
||||
|
||||
|
||||
class UnifiedSearchResponse(pydantic.BaseModel):
|
||||
"""Response model for unified search across all content types."""
|
||||
|
||||
results: list[UnifiedSearchResult]
|
||||
pagination: Pagination
|
||||
|
||||
@@ -7,12 +7,15 @@ from typing import Literal
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
import prisma.enums
|
||||
|
||||
import backend.data.graph
|
||||
import backend.util.json
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import cache as store_cache
|
||||
from . import db as store_db
|
||||
from . import hybrid_search as store_hybrid_search
|
||||
from . import image_gen as store_image_gen
|
||||
from . import media as store_media
|
||||
from . import model as store_model
|
||||
@@ -146,6 +149,102 @@ async def get_agents(
|
||||
return agents
|
||||
|
||||
|
||||
##############################################
|
||||
############### Search Endpoints #############
|
||||
##############################################
|
||||
|
||||
|
||||
@router.get(
|
||||
"/search",
|
||||
summary="Unified search across all content types",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.UnifiedSearchResponse,
|
||||
)
|
||||
async def unified_search(
|
||||
query: str,
|
||||
content_types: list[str] | None = fastapi.Query(
|
||||
default=None,
|
||||
description="Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.",
|
||||
),
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
user_id: str | None = fastapi.Security(
|
||||
autogpt_libs.auth.get_optional_user_id, use_cache=False
|
||||
),
|
||||
):
|
||||
"""
|
||||
Search across all content types (store agents, blocks, documentation) using hybrid search.
|
||||
|
||||
Combines semantic (embedding-based) and lexical (text-based) search for best results.
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION)
|
||||
page: Page number for pagination (default 1)
|
||||
page_size: Number of results per page (default 20)
|
||||
user_id: Optional authenticated user ID (for user-scoped content in future)
|
||||
|
||||
Returns:
|
||||
UnifiedSearchResponse: Paginated list of search results with relevance scores
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
# Convert string content types to enum
|
||||
content_type_enums: list[prisma.enums.ContentType] | None = None
|
||||
if content_types:
|
||||
try:
|
||||
content_type_enums = [prisma.enums.ContentType(ct) for ct in content_types]
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Invalid content type. Valid values: STORE_AGENT, BLOCK, DOCUMENTATION. Error: {e}",
|
||||
)
|
||||
|
||||
# Perform unified hybrid search
|
||||
results, total = await store_hybrid_search.unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=content_type_enums,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
# Convert results to response model
|
||||
search_results = [
|
||||
store_model.UnifiedSearchResult(
|
||||
content_type=r["content_type"],
|
||||
content_id=r["content_id"],
|
||||
searchable_text=r.get("searchable_text", ""),
|
||||
metadata=r.get("metadata"),
|
||||
updated_at=r.get("updated_at"),
|
||||
combined_score=r.get("combined_score"),
|
||||
semantic_score=r.get("semantic_score"),
|
||||
lexical_score=r.get("lexical_score"),
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
|
||||
total_pages = (total + page_size - 1) // page_size if total > 0 else 0
|
||||
|
||||
return store_model.UnifiedSearchResponse(
|
||||
results=search_results,
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=total_pages,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents/{username}/{agent_name}",
|
||||
summary="Get specific agent",
|
||||
|
||||
@@ -0,0 +1,272 @@
|
||||
"""Tests for the semantic_search function."""
|
||||
|
||||
import pytest
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store.embeddings import EMBEDDING_DIM, semantic_search
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_blocks_only(mocker):
|
||||
"""Test searching only BLOCK content type."""
|
||||
# Mock embed_query to return a test embedding
|
||||
mock_embedding = [0.1] * EMBEDDING_DIM
|
||||
mocker.patch(
|
||||
"backend.api.features.store.embeddings.embed_query",
|
||||
return_value=mock_embedding,
|
||||
)
|
||||
|
||||
# Mock query_raw_with_schema to return test results
|
||||
mock_results = [
|
||||
{
|
||||
"content_id": "block-123",
|
||||
"content_type": "BLOCK",
|
||||
"searchable_text": "Calculator Block - Performs arithmetic operations",
|
||||
"metadata": {"name": "Calculator", "categories": ["Math"]},
|
||||
"similarity": 0.85,
|
||||
}
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
return_value=mock_results,
|
||||
)
|
||||
|
||||
results = await semantic_search(
|
||||
query="calculate numbers",
|
||||
content_types=[ContentType.BLOCK],
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["content_type"] == "BLOCK"
|
||||
assert results[0]["content_id"] == "block-123"
|
||||
assert results[0]["similarity"] == 0.85
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_multiple_content_types(mocker):
|
||||
"""Test searching multiple content types simultaneously."""
|
||||
mock_embedding = [0.1] * EMBEDDING_DIM
|
||||
mocker.patch(
|
||||
"backend.api.features.store.embeddings.embed_query",
|
||||
return_value=mock_embedding,
|
||||
)
|
||||
|
||||
mock_results = [
|
||||
{
|
||||
"content_id": "block-123",
|
||||
"content_type": "BLOCK",
|
||||
"searchable_text": "Calculator Block",
|
||||
"metadata": {},
|
||||
"similarity": 0.85,
|
||||
},
|
||||
{
|
||||
"content_id": "doc-456",
|
||||
"content_type": "DOCUMENTATION",
|
||||
"searchable_text": "How to use Calculator",
|
||||
"metadata": {},
|
||||
"similarity": 0.75,
|
||||
},
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
return_value=mock_results,
|
||||
)
|
||||
|
||||
results = await semantic_search(
|
||||
query="calculator",
|
||||
content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION],
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["content_type"] == "BLOCK"
|
||||
assert results[1]["content_type"] == "DOCUMENTATION"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_min_similarity_threshold(mocker):
|
||||
"""Test that results below min_similarity are filtered out."""
|
||||
mock_embedding = [0.1] * EMBEDDING_DIM
|
||||
mocker.patch(
|
||||
"backend.api.features.store.embeddings.embed_query",
|
||||
return_value=mock_embedding,
|
||||
)
|
||||
|
||||
# Only return results above 0.7 similarity
|
||||
mock_results = [
|
||||
{
|
||||
"content_id": "block-123",
|
||||
"content_type": "BLOCK",
|
||||
"searchable_text": "Calculator Block",
|
||||
"metadata": {},
|
||||
"similarity": 0.85,
|
||||
}
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
return_value=mock_results,
|
||||
)
|
||||
|
||||
results = await semantic_search(
|
||||
query="calculate",
|
||||
content_types=[ContentType.BLOCK],
|
||||
min_similarity=0.7,
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["similarity"] >= 0.7
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_fallback_to_lexical(mocker):
|
||||
"""Test fallback to lexical search when embeddings fail."""
|
||||
# Mock embed_query to return None (embeddings unavailable)
|
||||
mocker.patch(
|
||||
"backend.api.features.store.embeddings.embed_query",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
mock_lexical_results = [
|
||||
{
|
||||
"content_id": "block-123",
|
||||
"content_type": "BLOCK",
|
||||
"searchable_text": "Calculator Block performs calculations",
|
||||
"metadata": {},
|
||||
"similarity": 0.0,
|
||||
}
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
return_value=mock_lexical_results,
|
||||
)
|
||||
|
||||
results = await semantic_search(
|
||||
query="calculator",
|
||||
content_types=[ContentType.BLOCK],
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["similarity"] == 0.0 # Lexical search returns 0 similarity
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_empty_query():
|
||||
"""Test that empty query returns no results."""
|
||||
results = await semantic_search(query="")
|
||||
assert results == []
|
||||
|
||||
results = await semantic_search(query=" ")
|
||||
assert results == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_user_id_filter(mocker):
|
||||
"""Test searching with user_id filter for private content."""
|
||||
mock_embedding = [0.1] * EMBEDDING_DIM
|
||||
mocker.patch(
|
||||
"backend.api.features.store.embeddings.embed_query",
|
||||
return_value=mock_embedding,
|
||||
)
|
||||
|
||||
mock_results = [
|
||||
{
|
||||
"content_id": "agent-789",
|
||||
"content_type": "LIBRARY_AGENT",
|
||||
"searchable_text": "My Custom Agent",
|
||||
"metadata": {},
|
||||
"similarity": 0.9,
|
||||
}
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
return_value=mock_results,
|
||||
)
|
||||
|
||||
results = await semantic_search(
|
||||
query="custom agent",
|
||||
content_types=[ContentType.LIBRARY_AGENT],
|
||||
user_id="user-123",
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["content_type"] == "LIBRARY_AGENT"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_limit_parameter(mocker):
|
||||
"""Test that limit parameter correctly limits results."""
|
||||
mock_embedding = [0.1] * EMBEDDING_DIM
|
||||
mocker.patch(
|
||||
"backend.api.features.store.embeddings.embed_query",
|
||||
return_value=mock_embedding,
|
||||
)
|
||||
|
||||
# Return 5 results
|
||||
mock_results = [
|
||||
{
|
||||
"content_id": f"block-{i}",
|
||||
"content_type": "BLOCK",
|
||||
"searchable_text": f"Block {i}",
|
||||
"metadata": {},
|
||||
"similarity": 0.8,
|
||||
}
|
||||
for i in range(5)
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
return_value=mock_results,
|
||||
)
|
||||
|
||||
results = await semantic_search(
|
||||
query="block",
|
||||
content_types=[ContentType.BLOCK],
|
||||
limit=5,
|
||||
)
|
||||
|
||||
assert len(results) == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_default_content_types(mocker):
|
||||
"""Test that default content_types includes BLOCK, STORE_AGENT, and DOCUMENTATION."""
|
||||
mock_embedding = [0.1] * EMBEDDING_DIM
|
||||
mocker.patch(
|
||||
"backend.api.features.store.embeddings.embed_query",
|
||||
return_value=mock_embedding,
|
||||
)
|
||||
|
||||
mock_query_raw = mocker.patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
await semantic_search(query="test")
|
||||
|
||||
# Check that the SQL query includes all three default content types
|
||||
call_args = mock_query_raw.call_args
|
||||
assert "BLOCK" in str(call_args)
|
||||
assert "STORE_AGENT" in str(call_args)
|
||||
assert "DOCUMENTATION" in str(call_args)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_handles_database_error(mocker):
|
||||
"""Test that database errors are handled gracefully."""
|
||||
mock_embedding = [0.1] * EMBEDDING_DIM
|
||||
mocker.patch(
|
||||
"backend.api.features.store.embeddings.embed_query",
|
||||
return_value=mock_embedding,
|
||||
)
|
||||
|
||||
# Simulate database error
|
||||
mocker.patch(
|
||||
"backend.api.features.store.embeddings.query_raw_with_schema",
|
||||
side_effect=Exception("Database connection failed"),
|
||||
)
|
||||
|
||||
results = await semantic_search(
|
||||
query="test",
|
||||
content_types=[ContentType.BLOCK],
|
||||
)
|
||||
|
||||
# Should return empty list on error
|
||||
assert results == []
|
||||
@@ -64,7 +64,6 @@ from backend.data.onboarding import (
|
||||
complete_re_run_agent,
|
||||
get_recommended_agents,
|
||||
get_user_onboarding,
|
||||
increment_runs,
|
||||
onboarding_enabled,
|
||||
reset_user_onboarding,
|
||||
update_user_onboarding,
|
||||
@@ -975,7 +974,6 @@ async def execute_graph(
|
||||
# Record successful graph execution
|
||||
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
|
||||
record_graph_operation(operation="execute", status="success")
|
||||
await increment_runs(user_id)
|
||||
await complete_re_run_agent(user_id, graph_id)
|
||||
if source == "library":
|
||||
await complete_onboarding_step(
|
||||
|
||||
@@ -38,6 +38,20 @@ POOL_TIMEOUT = os.getenv("DB_POOL_TIMEOUT")
|
||||
if POOL_TIMEOUT:
|
||||
DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)
|
||||
|
||||
# Add public schema to search_path for pgvector type access
|
||||
# The vector extension is in public schema, but search_path is determined by schema parameter
|
||||
# Extract the schema from DATABASE_URL or default to 'public' (matching get_database_schema())
|
||||
parsed_url = urlparse(DATABASE_URL)
|
||||
url_params = dict(parse_qsl(parsed_url.query))
|
||||
db_schema = url_params.get("schema", "public")
|
||||
# Build search_path, avoiding duplicates if db_schema is already 'public'
|
||||
search_path_schemas = list(
|
||||
dict.fromkeys([db_schema, "public"])
|
||||
) # Preserves order, removes duplicates
|
||||
search_path = ",".join(search_path_schemas)
|
||||
# This allows using ::vector without schema qualification
|
||||
DATABASE_URL = add_param(DATABASE_URL, "options", f"-c search_path={search_path}")
|
||||
|
||||
HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None
|
||||
|
||||
prisma = Prisma(
|
||||
@@ -108,21 +122,102 @@ def get_database_schema() -> str:
|
||||
return query_params.get("schema", "public")
|
||||
|
||||
|
||||
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
|
||||
"""Execute raw SQL query with proper schema handling."""
|
||||
async def _raw_with_schema(
|
||||
query_template: str,
|
||||
*args,
|
||||
execute: bool = False,
|
||||
client: Prisma | None = None,
|
||||
set_public_search_path: bool = False,
|
||||
) -> list[dict] | int:
|
||||
"""Internal: Execute raw SQL with proper schema handling.
|
||||
|
||||
Use query_raw_with_schema() or execute_raw_with_schema() instead.
|
||||
|
||||
Args:
|
||||
query_template: SQL query with {schema_prefix} placeholder
|
||||
*args: Query parameters
|
||||
execute: If False, executes SELECT query. If True, executes INSERT/UPDATE/DELETE.
|
||||
client: Optional Prisma client for transactions (only used when execute=True).
|
||||
set_public_search_path: If True, sets search_path to include public schema.
|
||||
Needed for pgvector types and other public schema objects.
|
||||
|
||||
Returns:
|
||||
- list[dict] if execute=False (query results)
|
||||
- int if execute=True (number of affected rows)
|
||||
"""
|
||||
schema = get_database_schema()
|
||||
schema_prefix = f'"{schema}".' if schema != "public" else ""
|
||||
formatted_query = query_template.format(schema_prefix=schema_prefix)
|
||||
|
||||
import prisma as prisma_module
|
||||
|
||||
result = await prisma_module.get_client().query_raw(
|
||||
formatted_query, *args # type: ignore
|
||||
)
|
||||
db_client = client if client else prisma_module.get_client()
|
||||
|
||||
# Set search_path to include public schema if requested
|
||||
# Prisma doesn't support the 'options' connection parameter, so we set it per-session
|
||||
# This is idempotent and safe to call multiple times
|
||||
if set_public_search_path:
|
||||
await db_client.execute_raw(f"SET search_path = {schema}, public") # type: ignore
|
||||
|
||||
if execute:
|
||||
result = await db_client.execute_raw(formatted_query, *args) # type: ignore
|
||||
else:
|
||||
result = await db_client.query_raw(formatted_query, *args) # type: ignore
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def query_raw_with_schema(
|
||||
query_template: str, *args, set_public_search_path: bool = False
|
||||
) -> list[dict]:
|
||||
"""Execute raw SQL SELECT query with proper schema handling.
|
||||
|
||||
Args:
|
||||
query_template: SQL query with {schema_prefix} placeholder
|
||||
*args: Query parameters
|
||||
set_public_search_path: If True, sets search_path to include public schema.
|
||||
Needed for pgvector types and other public schema objects.
|
||||
|
||||
Returns:
|
||||
List of result rows as dictionaries
|
||||
|
||||
Example:
|
||||
results = await query_raw_with_schema(
|
||||
'SELECT * FROM {schema_prefix}"User" WHERE id = $1',
|
||||
user_id
|
||||
)
|
||||
"""
|
||||
return await _raw_with_schema(query_template, *args, execute=False, set_public_search_path=set_public_search_path) # type: ignore
|
||||
|
||||
|
||||
async def execute_raw_with_schema(
|
||||
query_template: str,
|
||||
*args,
|
||||
client: Prisma | None = None,
|
||||
set_public_search_path: bool = False,
|
||||
) -> int:
|
||||
"""Execute raw SQL command (INSERT/UPDATE/DELETE) with proper schema handling.
|
||||
|
||||
Args:
|
||||
query_template: SQL query with {schema_prefix} placeholder
|
||||
*args: Query parameters
|
||||
client: Optional Prisma client for transactions
|
||||
set_public_search_path: If True, sets search_path to include public schema.
|
||||
Needed for pgvector types and other public schema objects.
|
||||
|
||||
Returns:
|
||||
Number of affected rows
|
||||
|
||||
Example:
|
||||
await execute_raw_with_schema(
|
||||
'INSERT INTO {schema_prefix}"User" (id, name) VALUES ($1, $2)',
|
||||
user_id, name,
|
||||
client=tx # Optional transaction client
|
||||
)
|
||||
"""
|
||||
return await _raw_with_schema(query_template, *args, execute=True, client=client, set_public_search_path=set_public_search_path) # type: ignore
|
||||
|
||||
|
||||
class BaseDbModel(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import fastapi.exceptions
|
||||
@@ -18,6 +19,17 @@ from backend.usecases.sample import create_test_user
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def mock_embedding_functions():
|
||||
"""Mock embedding functions for all tests to avoid database/API dependencies."""
|
||||
with patch(
|
||||
"backend.api.features.store.db.ensure_embedding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_graph_creation(server: SpinTestServer, snapshot: Snapshot):
|
||||
"""
|
||||
|
||||
@@ -334,7 +334,7 @@ async def _get_user_timezone(user_id: str) -> str:
|
||||
return get_user_timezone_or_utc(user.timezone if user else None)
|
||||
|
||||
|
||||
async def increment_runs(user_id: str):
|
||||
async def increment_onboarding_runs(user_id: str):
|
||||
"""
|
||||
Increment a user's run counters and trigger any onboarding milestones.
|
||||
"""
|
||||
|
||||
404
autogpt_platform/backend/backend/data/understanding.py
Normal file
404
autogpt_platform/backend/backend/data/understanding.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""Data models and access layer for user business understanding."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import pydantic
|
||||
from prisma.models import CoPilotUnderstanding
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache configuration
|
||||
CACHE_KEY_PREFIX = "understanding"
|
||||
CACHE_TTL_SECONDS = 48 * 60 * 60 # 48 hours
|
||||
|
||||
|
||||
def _cache_key(user_id: str) -> str:
|
||||
"""Generate cache key for user business understanding."""
|
||||
return f"{CACHE_KEY_PREFIX}:{user_id}"
|
||||
|
||||
|
||||
def _json_to_list(value: Any) -> list[str]:
|
||||
"""Convert Json field to list[str], handling None."""
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return cast(list[str], value)
|
||||
return []
|
||||
|
||||
|
||||
class BusinessUnderstandingInput(pydantic.BaseModel):
|
||||
"""Input model for updating business understanding - all fields optional for incremental updates."""
|
||||
|
||||
# User info
|
||||
user_name: Optional[str] = pydantic.Field(None, description="The user's name")
|
||||
job_title: Optional[str] = pydantic.Field(None, description="The user's job title")
|
||||
|
||||
# Business basics
|
||||
business_name: Optional[str] = pydantic.Field(
|
||||
None, description="Name of the user's business"
|
||||
)
|
||||
industry: Optional[str] = pydantic.Field(None, description="Industry or sector")
|
||||
business_size: Optional[str] = pydantic.Field(
|
||||
None, description="Company size (e.g., '1-10', '11-50')"
|
||||
)
|
||||
user_role: Optional[str] = pydantic.Field(
|
||||
None,
|
||||
description="User's role in the organization (e.g., 'decision maker', 'implementer')",
|
||||
)
|
||||
|
||||
# Processes & activities
|
||||
key_workflows: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Key business workflows"
|
||||
)
|
||||
daily_activities: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Daily activities performed"
|
||||
)
|
||||
|
||||
# Pain points & goals
|
||||
pain_points: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Current pain points"
|
||||
)
|
||||
bottlenecks: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Process bottlenecks"
|
||||
)
|
||||
manual_tasks: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Manual/repetitive tasks"
|
||||
)
|
||||
automation_goals: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Desired automation goals"
|
||||
)
|
||||
|
||||
# Current tools
|
||||
current_software: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Software/tools currently used"
|
||||
)
|
||||
existing_automation: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Existing automations"
|
||||
)
|
||||
|
||||
# Additional context
|
||||
additional_notes: Optional[str] = pydantic.Field(
|
||||
None, description="Any additional context"
|
||||
)
|
||||
|
||||
|
||||
class BusinessUnderstanding(pydantic.BaseModel):
|
||||
"""Full business understanding model returned from database."""
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# User info
|
||||
user_name: Optional[str] = None
|
||||
job_title: Optional[str] = None
|
||||
|
||||
# Business basics
|
||||
business_name: Optional[str] = None
|
||||
industry: Optional[str] = None
|
||||
business_size: Optional[str] = None
|
||||
user_role: Optional[str] = None
|
||||
|
||||
# Processes & activities
|
||||
key_workflows: list[str] = pydantic.Field(default_factory=list)
|
||||
daily_activities: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
# Pain points & goals
|
||||
pain_points: list[str] = pydantic.Field(default_factory=list)
|
||||
bottlenecks: list[str] = pydantic.Field(default_factory=list)
|
||||
manual_tasks: list[str] = pydantic.Field(default_factory=list)
|
||||
automation_goals: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
# Current tools
|
||||
current_software: list[str] = pydantic.Field(default_factory=list)
|
||||
existing_automation: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
# Additional context
|
||||
additional_notes: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_record: CoPilotUnderstanding) -> "BusinessUnderstanding":
|
||||
"""Convert database record to Pydantic model."""
|
||||
data = db_record.data if isinstance(db_record.data, dict) else {}
|
||||
business = (
|
||||
data.get("business", {}) if isinstance(data.get("business"), dict) else {}
|
||||
)
|
||||
return cls(
|
||||
id=db_record.id,
|
||||
user_id=db_record.userId,
|
||||
created_at=db_record.createdAt,
|
||||
updated_at=db_record.updatedAt,
|
||||
user_name=data.get("name"),
|
||||
job_title=business.get("job_title"),
|
||||
business_name=business.get("business_name"),
|
||||
industry=business.get("industry"),
|
||||
business_size=business.get("business_size"),
|
||||
user_role=business.get("user_role"),
|
||||
key_workflows=_json_to_list(business.get("key_workflows")),
|
||||
daily_activities=_json_to_list(business.get("daily_activities")),
|
||||
pain_points=_json_to_list(business.get("pain_points")),
|
||||
bottlenecks=_json_to_list(business.get("bottlenecks")),
|
||||
manual_tasks=_json_to_list(business.get("manual_tasks")),
|
||||
automation_goals=_json_to_list(business.get("automation_goals")),
|
||||
current_software=_json_to_list(business.get("current_software")),
|
||||
existing_automation=_json_to_list(business.get("existing_automation")),
|
||||
additional_notes=business.get("additional_notes"),
|
||||
)
|
||||
|
||||
|
||||
def _merge_lists(existing: list | None, new: list | None) -> list | None:
|
||||
"""Merge two lists, removing duplicates while preserving order."""
|
||||
if new is None:
|
||||
return existing
|
||||
if existing is None:
|
||||
return new
|
||||
# Preserve order, add new items that don't exist
|
||||
merged = list(existing)
|
||||
for item in new:
|
||||
if item not in merged:
|
||||
merged.append(item)
|
||||
return merged
|
||||
|
||||
|
||||
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
|
||||
"""Get business understanding from Redis cache."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
cached_data = await redis.get(_cache_key(user_id))
|
||||
if cached_data:
|
||||
return BusinessUnderstanding.model_validate_json(cached_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get understanding from cache: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _set_cache(user_id: str, understanding: BusinessUnderstanding) -> None:
|
||||
"""Set business understanding in Redis cache with TTL."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.setex(
|
||||
_cache_key(user_id),
|
||||
CACHE_TTL_SECONDS,
|
||||
understanding.model_dump_json(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to set understanding in cache: {e}")
|
||||
|
||||
|
||||
async def _delete_cache(user_id: str) -> None:
|
||||
"""Delete business understanding from Redis cache."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(_cache_key(user_id))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete understanding from cache: {e}")
|
||||
|
||||
|
||||
async def get_business_understanding(
|
||||
user_id: str,
|
||||
) -> Optional[BusinessUnderstanding]:
|
||||
"""Get the business understanding for a user.
|
||||
|
||||
Checks cache first, falls back to database if not cached.
|
||||
Results are cached for 48 hours.
|
||||
"""
|
||||
# Try cache first
|
||||
cached = await _get_from_cache(user_id)
|
||||
if cached:
|
||||
logger.debug(f"Business understanding cache hit for user {user_id}")
|
||||
return cached
|
||||
|
||||
# Cache miss - load from database
|
||||
logger.debug(f"Business understanding cache miss for user {user_id}")
|
||||
record = await CoPilotUnderstanding.prisma().find_unique(where={"userId": user_id})
|
||||
if record is None:
|
||||
return None
|
||||
|
||||
understanding = BusinessUnderstanding.from_db(record)
|
||||
|
||||
# Store in cache for next time
|
||||
await _set_cache(user_id, understanding)
|
||||
|
||||
return understanding
|
||||
|
||||
|
||||
async def upsert_business_understanding(
|
||||
user_id: str,
|
||||
input_data: BusinessUnderstandingInput,
|
||||
) -> BusinessUnderstanding:
|
||||
"""
|
||||
Create or update business understanding with incremental merge strategy.
|
||||
|
||||
- String fields: new value overwrites if provided (not None)
|
||||
- List fields: new items are appended to existing (deduplicated)
|
||||
|
||||
Data is stored as: {name: ..., business: {version: 1, ...}}
|
||||
"""
|
||||
# Get existing record for merge
|
||||
existing = await CoPilotUnderstanding.prisma().find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
|
||||
# Get existing data structure or start fresh
|
||||
existing_data: dict[str, Any] = {}
|
||||
if existing and isinstance(existing.data, dict):
|
||||
existing_data = dict(existing.data)
|
||||
|
||||
existing_business: dict[str, Any] = {}
|
||||
if isinstance(existing_data.get("business"), dict):
|
||||
existing_business = dict(existing_data["business"])
|
||||
|
||||
# Business fields (stored inside business object)
|
||||
business_string_fields = [
|
||||
"job_title",
|
||||
"business_name",
|
||||
"industry",
|
||||
"business_size",
|
||||
"user_role",
|
||||
"additional_notes",
|
||||
]
|
||||
business_list_fields = [
|
||||
"key_workflows",
|
||||
"daily_activities",
|
||||
"pain_points",
|
||||
"bottlenecks",
|
||||
"manual_tasks",
|
||||
"automation_goals",
|
||||
"current_software",
|
||||
"existing_automation",
|
||||
]
|
||||
|
||||
# Handle top-level name field
|
||||
if input_data.user_name is not None:
|
||||
existing_data["name"] = input_data.user_name
|
||||
|
||||
# Business string fields - overwrite if provided
|
||||
for field in business_string_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_business[field] = value
|
||||
|
||||
# Business list fields - merge with existing
|
||||
for field in business_list_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_list = _json_to_list(existing_business.get(field))
|
||||
merged = _merge_lists(existing_list, value)
|
||||
existing_business[field] = merged
|
||||
|
||||
# Set version and nest business data
|
||||
existing_business["version"] = 1
|
||||
existing_data["business"] = existing_business
|
||||
|
||||
# Upsert with the merged data
|
||||
record = await CoPilotUnderstanding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "data": SafeJson(existing_data)},
|
||||
"update": {"data": SafeJson(existing_data)},
|
||||
},
|
||||
)
|
||||
|
||||
understanding = BusinessUnderstanding.from_db(record)
|
||||
|
||||
# Update cache with new understanding
|
||||
await _set_cache(user_id, understanding)
|
||||
|
||||
return understanding
|
||||
|
||||
|
||||
async def clear_business_understanding(user_id: str) -> bool:
|
||||
"""Clear/delete business understanding for a user from both DB and cache."""
|
||||
# Delete from cache first
|
||||
await _delete_cache(user_id)
|
||||
|
||||
try:
|
||||
await CoPilotUnderstanding.prisma().delete(where={"userId": user_id})
|
||||
return True
|
||||
except Exception:
|
||||
# Record might not exist
|
||||
return False
|
||||
|
||||
|
||||
def format_understanding_for_prompt(understanding: BusinessUnderstanding) -> str:
|
||||
"""Format business understanding as text for system prompt injection."""
|
||||
sections = []
|
||||
|
||||
# User info section
|
||||
user_info = []
|
||||
if understanding.user_name:
|
||||
user_info.append(f"Name: {understanding.user_name}")
|
||||
if understanding.job_title:
|
||||
user_info.append(f"Job Title: {understanding.job_title}")
|
||||
if user_info:
|
||||
sections.append("## User\n" + "\n".join(user_info))
|
||||
|
||||
# Business section
|
||||
business_info = []
|
||||
if understanding.business_name:
|
||||
business_info.append(f"Company: {understanding.business_name}")
|
||||
if understanding.industry:
|
||||
business_info.append(f"Industry: {understanding.industry}")
|
||||
if understanding.business_size:
|
||||
business_info.append(f"Size: {understanding.business_size}")
|
||||
if understanding.user_role:
|
||||
business_info.append(f"Role Context: {understanding.user_role}")
|
||||
if business_info:
|
||||
sections.append("## Business\n" + "\n".join(business_info))
|
||||
|
||||
# Processes section
|
||||
processes = []
|
||||
if understanding.key_workflows:
|
||||
processes.append(f"Key Workflows: {', '.join(understanding.key_workflows)}")
|
||||
if understanding.daily_activities:
|
||||
processes.append(
|
||||
f"Daily Activities: {', '.join(understanding.daily_activities)}"
|
||||
)
|
||||
if processes:
|
||||
sections.append("## Processes\n" + "\n".join(processes))
|
||||
|
||||
# Pain points section
|
||||
pain_points = []
|
||||
if understanding.pain_points:
|
||||
pain_points.append(f"Pain Points: {', '.join(understanding.pain_points)}")
|
||||
if understanding.bottlenecks:
|
||||
pain_points.append(f"Bottlenecks: {', '.join(understanding.bottlenecks)}")
|
||||
if understanding.manual_tasks:
|
||||
pain_points.append(f"Manual Tasks: {', '.join(understanding.manual_tasks)}")
|
||||
if pain_points:
|
||||
sections.append("## Pain Points\n" + "\n".join(pain_points))
|
||||
|
||||
# Goals section
|
||||
if understanding.automation_goals:
|
||||
sections.append(
|
||||
"## Automation Goals\n"
|
||||
+ "\n".join(f"- {goal}" for goal in understanding.automation_goals)
|
||||
)
|
||||
|
||||
# Current tools section
|
||||
tools_info = []
|
||||
if understanding.current_software:
|
||||
tools_info.append(
|
||||
f"Current Software: {', '.join(understanding.current_software)}"
|
||||
)
|
||||
if understanding.existing_automation:
|
||||
tools_info.append(
|
||||
f"Existing Automation: {', '.join(understanding.existing_automation)}"
|
||||
)
|
||||
if tools_info:
|
||||
sections.append("## Current Tools\n" + "\n".join(tools_info))
|
||||
|
||||
# Additional notes
|
||||
if understanding.additional_notes:
|
||||
sections.append(f"## Additional Context\n{understanding.additional_notes}")
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
return "# User Business Context\n\n" + "\n\n".join(sections)
|
||||
@@ -7,6 +7,11 @@ from backend.api.features.library.db import (
|
||||
list_library_agents,
|
||||
)
|
||||
from backend.api.features.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.api.features.store.embeddings import (
|
||||
backfill_missing_embeddings,
|
||||
cleanup_orphaned_embeddings,
|
||||
get_embedding_stats,
|
||||
)
|
||||
from backend.data import db
|
||||
from backend.data.analytics import (
|
||||
get_accuracy_trends_and_alerts,
|
||||
@@ -20,6 +25,7 @@ from backend.data.execution import (
|
||||
get_execution_kv_data,
|
||||
get_execution_outputs_by_node_exec_id,
|
||||
get_frequently_executed_graphs,
|
||||
get_graph_execution,
|
||||
get_graph_execution_meta,
|
||||
get_graph_executions,
|
||||
get_graph_executions_count,
|
||||
@@ -57,6 +63,7 @@ from backend.data.notifications import (
|
||||
get_user_notification_oldest_message_in_batch,
|
||||
remove_notifications_from_batch,
|
||||
)
|
||||
from backend.data.onboarding import increment_onboarding_runs
|
||||
from backend.data.user import (
|
||||
get_active_user_ids_in_timerange,
|
||||
get_user_by_id,
|
||||
@@ -140,6 +147,7 @@ class DatabaseManager(AppService):
|
||||
get_child_graph_executions = _(get_child_graph_executions)
|
||||
get_graph_executions = _(get_graph_executions)
|
||||
get_graph_executions_count = _(get_graph_executions_count)
|
||||
get_graph_execution = _(get_graph_execution)
|
||||
get_graph_execution_meta = _(get_graph_execution_meta)
|
||||
create_graph_execution = _(create_graph_execution)
|
||||
get_node_execution = _(get_node_execution)
|
||||
@@ -204,10 +212,18 @@ class DatabaseManager(AppService):
|
||||
add_store_agent_to_library = _(add_store_agent_to_library)
|
||||
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
|
||||
|
||||
# Onboarding
|
||||
increment_onboarding_runs = _(increment_onboarding_runs)
|
||||
|
||||
# Store
|
||||
get_store_agents = _(get_store_agents)
|
||||
get_store_agent_details = _(get_store_agent_details)
|
||||
|
||||
# Store Embeddings
|
||||
get_embedding_stats = _(get_embedding_stats)
|
||||
backfill_missing_embeddings = _(backfill_missing_embeddings)
|
||||
cleanup_orphaned_embeddings = _(cleanup_orphaned_embeddings)
|
||||
|
||||
# Summary data - async
|
||||
get_user_execution_summary_data = _(get_user_execution_summary_data)
|
||||
|
||||
@@ -259,6 +275,11 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
get_store_agents = _(d.get_store_agents)
|
||||
get_store_agent_details = _(d.get_store_agent_details)
|
||||
|
||||
# Store Embeddings
|
||||
get_embedding_stats = _(d.get_embedding_stats)
|
||||
backfill_missing_embeddings = _(d.backfill_missing_embeddings)
|
||||
cleanup_orphaned_embeddings = _(d.cleanup_orphaned_embeddings)
|
||||
|
||||
|
||||
class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
@@ -274,6 +295,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
get_graph = d.get_graph
|
||||
get_graph_metadata = d.get_graph_metadata
|
||||
get_graph_settings = d.get_graph_settings
|
||||
get_graph_execution = d.get_graph_execution
|
||||
get_graph_execution_meta = d.get_graph_execution_meta
|
||||
get_node = d.get_node
|
||||
get_node_execution = d.get_node_execution
|
||||
@@ -318,6 +340,9 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
add_store_agent_to_library = d.add_store_agent_to_library
|
||||
validate_graph_execution_permissions = d.validate_graph_execution_permissions
|
||||
|
||||
# Onboarding
|
||||
increment_onboarding_runs = d.increment_onboarding_runs
|
||||
|
||||
# Store
|
||||
get_store_agents = d.get_store_agents
|
||||
get_store_agent_details = d.get_store_agent_details
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import fastapi.responses
|
||||
import pytest
|
||||
@@ -19,6 +20,17 @@ from backend.util.test import SpinTestServer, wait_execution
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def mock_embedding_functions():
|
||||
"""Mock embedding functions for all tests to avoid database/API dependencies."""
|
||||
with patch(
|
||||
"backend.api.features.store.db.ensure_embedding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph:
|
||||
logger.info(f"Creating graph for user {u.id}")
|
||||
return await s.agent_server.test_create_graph(CreateGraph(graph=g), u.id)
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
@@ -27,7 +28,7 @@ from backend.data.auth.oauth import cleanup_expired_oauth_tokens
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.onboarding import increment_runs
|
||||
from backend.data.onboarding import increment_onboarding_runs
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.monitoring import (
|
||||
NotificationJobArgs,
|
||||
@@ -37,7 +38,7 @@ from backend.monitoring import (
|
||||
report_execution_accuracy_alerts,
|
||||
report_late_executions,
|
||||
)
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.clients import get_database_manager_client, get_scheduler_client
|
||||
from backend.util.cloud_storage import cleanup_expired_files_async
|
||||
from backend.util.exceptions import (
|
||||
GraphNotFoundError,
|
||||
@@ -156,7 +157,7 @@ async def _execute_graph(**kwargs):
|
||||
inputs=args.input_data,
|
||||
graph_credentials_inputs=args.input_credentials,
|
||||
)
|
||||
await increment_runs(args.user_id)
|
||||
await increment_onboarding_runs(args.user_id)
|
||||
elapsed = asyncio.get_event_loop().time() - start_time
|
||||
logger.info(
|
||||
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
|
||||
@@ -254,6 +255,114 @@ def execution_accuracy_alerts():
|
||||
return report_execution_accuracy_alerts()
|
||||
|
||||
|
||||
def ensure_embeddings_coverage():
|
||||
"""
|
||||
Ensure all content types (store agents, blocks, docs) have embeddings for search.
|
||||
|
||||
Processes ALL missing embeddings in batches of 10 per content type until 100% coverage.
|
||||
Missing embeddings = content invisible in hybrid search.
|
||||
|
||||
Schedule: Runs every 6 hours (balanced between coverage and API costs).
|
||||
- Catches new content added between scheduled runs
|
||||
- Batch size 10 per content type: gradual processing to avoid rate limits
|
||||
- Manual trigger available via execute_ensure_embeddings_coverage endpoint
|
||||
"""
|
||||
db_client = get_database_manager_client()
|
||||
stats = db_client.get_embedding_stats()
|
||||
|
||||
# Check for error from get_embedding_stats() first
|
||||
if "error" in stats:
|
||||
logger.error(
|
||||
f"Failed to get embedding stats: {stats['error']} - skipping backfill"
|
||||
)
|
||||
return {
|
||||
"backfill": {"processed": 0, "success": 0, "failed": 0},
|
||||
"cleanup": {"deleted": 0},
|
||||
"error": stats["error"],
|
||||
}
|
||||
|
||||
# Extract totals from new stats structure
|
||||
totals = stats.get("totals", {})
|
||||
without_embeddings = totals.get("without_embeddings", 0)
|
||||
coverage_percent = totals.get("coverage_percent", 0)
|
||||
|
||||
total_processed = 0
|
||||
total_success = 0
|
||||
total_failed = 0
|
||||
|
||||
if without_embeddings == 0:
|
||||
logger.info("All content has embeddings, skipping backfill")
|
||||
else:
|
||||
# Log per-content-type stats for visibility
|
||||
by_type = stats.get("by_type", {})
|
||||
for content_type, type_stats in by_type.items():
|
||||
if type_stats.get("without_embeddings", 0) > 0:
|
||||
logger.info(
|
||||
f"{content_type}: {type_stats['without_embeddings']} items without embeddings "
|
||||
f"({type_stats['coverage_percent']}% coverage)"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Total: {without_embeddings} items without embeddings "
|
||||
f"({coverage_percent}% coverage) - processing all"
|
||||
)
|
||||
|
||||
# Process in batches until no more missing embeddings
|
||||
while True:
|
||||
result = db_client.backfill_missing_embeddings(batch_size=10)
|
||||
|
||||
total_processed += result["processed"]
|
||||
total_success += result["success"]
|
||||
total_failed += result["failed"]
|
||||
|
||||
if result["processed"] == 0:
|
||||
# No more missing embeddings
|
||||
break
|
||||
|
||||
if result["success"] == 0 and result["processed"] > 0:
|
||||
# All attempts in this batch failed - stop to avoid infinite loop
|
||||
logger.error(
|
||||
f"All {result['processed']} embedding attempts failed - stopping backfill"
|
||||
)
|
||||
break
|
||||
|
||||
# Small delay between batches to avoid rate limits
|
||||
time.sleep(1)
|
||||
|
||||
logger.info(
|
||||
f"Embedding backfill completed: {total_success}/{total_processed} succeeded, "
|
||||
f"{total_failed} failed"
|
||||
)
|
||||
|
||||
# Clean up orphaned embeddings for blocks and docs
|
||||
logger.info("Running cleanup for orphaned embeddings (blocks/docs)...")
|
||||
cleanup_result = db_client.cleanup_orphaned_embeddings()
|
||||
cleanup_totals = cleanup_result.get("totals", {})
|
||||
cleanup_deleted = cleanup_totals.get("deleted", 0)
|
||||
|
||||
if cleanup_deleted > 0:
|
||||
logger.info(f"Cleanup completed: deleted {cleanup_deleted} orphaned embeddings")
|
||||
by_type = cleanup_result.get("by_type", {})
|
||||
for content_type, type_result in by_type.items():
|
||||
if type_result.get("deleted", 0) > 0:
|
||||
logger.info(
|
||||
f"{content_type}: deleted {type_result['deleted']} orphaned embeddings"
|
||||
)
|
||||
else:
|
||||
logger.info("Cleanup completed: no orphaned embeddings found")
|
||||
|
||||
return {
|
||||
"backfill": {
|
||||
"processed": total_processed,
|
||||
"success": total_success,
|
||||
"failed": total_failed,
|
||||
},
|
||||
"cleanup": {
|
||||
"deleted": cleanup_deleted,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Monitoring functions are now imported from monitoring module
|
||||
|
||||
|
||||
@@ -475,6 +584,19 @@ class Scheduler(AppService):
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
# Embedding Coverage - Every 6 hours
|
||||
# Ensures all approved agents have embeddings for hybrid search
|
||||
# Critical: missing embeddings = agents invisible in search
|
||||
self.scheduler.add_job(
|
||||
ensure_embeddings_coverage,
|
||||
id="ensure_embeddings_coverage",
|
||||
trigger="interval",
|
||||
hours=6,
|
||||
replace_existing=True,
|
||||
max_instances=1, # Prevent overlapping runs
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
||||
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
|
||||
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
|
||||
@@ -632,6 +754,11 @@ class Scheduler(AppService):
|
||||
"""Manually trigger execution accuracy alert checking."""
|
||||
return execution_accuracy_alerts()
|
||||
|
||||
@expose
|
||||
def execute_ensure_embeddings_coverage(self):
|
||||
"""Manually trigger embedding backfill for approved store agents."""
|
||||
return ensure_embeddings_coverage()
|
||||
|
||||
|
||||
class SchedulerClient(AppServiceClient):
|
||||
@classmethod
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import BaseModel, JsonValue, ValidationError
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data import onboarding as onboarding_db
|
||||
from backend.data import user as user_db
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
@@ -31,7 +32,6 @@ from backend.data.execution import (
|
||||
GraphExecutionStats,
|
||||
GraphExecutionWithNodes,
|
||||
NodesInputMasks,
|
||||
get_graph_execution,
|
||||
)
|
||||
from backend.data.graph import GraphModel, Node
|
||||
from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput
|
||||
@@ -809,13 +809,14 @@ async def add_graph_execution(
|
||||
edb = execution_db
|
||||
udb = user_db
|
||||
gdb = graph_db
|
||||
odb = onboarding_db
|
||||
else:
|
||||
edb = udb = gdb = get_database_manager_async_client()
|
||||
edb = udb = gdb = odb = get_database_manager_async_client()
|
||||
|
||||
# Get or create the graph execution
|
||||
if graph_exec_id:
|
||||
# Resume existing execution
|
||||
graph_exec = await get_graph_execution(
|
||||
graph_exec = await edb.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=graph_exec_id,
|
||||
include_node_executions=True,
|
||||
@@ -891,6 +892,7 @@ async def add_graph_execution(
|
||||
)
|
||||
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
|
||||
|
||||
# Publish to execution queue for executor to pick up
|
||||
exec_queue = await get_async_execution_queue()
|
||||
await exec_queue.publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
@@ -899,14 +901,12 @@ async def add_graph_execution(
|
||||
)
|
||||
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
|
||||
|
||||
# Update execution status to QUEUED
|
||||
graph_exec.status = ExecutionStatus.QUEUED
|
||||
await edb.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=graph_exec.status,
|
||||
)
|
||||
await get_async_execution_event_bus().publish(graph_exec)
|
||||
|
||||
return graph_exec
|
||||
except BaseException as e:
|
||||
err = str(e) or type(e).__name__
|
||||
if not graph_exec:
|
||||
@@ -927,6 +927,24 @@ async def add_graph_execution(
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
await get_async_execution_event_bus().publish(graph_exec)
|
||||
logger.info(f"Published update for execution #{graph_exec.id} to event bus")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to publish execution event for graph exec #{graph_exec.id}: {e}"
|
||||
)
|
||||
|
||||
try:
|
||||
await odb.increment_onboarding_runs(user_id)
|
||||
logger.info(
|
||||
f"Incremented user #{user_id} onboarding runs for exec #{graph_exec.id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to increment onboarding runs for user #{user_id}: {e}")
|
||||
|
||||
return graph_exec
|
||||
|
||||
|
||||
# ============ Execution Output Helpers ============ #
|
||||
|
||||
|
||||
@@ -245,6 +245,21 @@ DEFAULT_CREDENTIALS = [
|
||||
webshare_proxy_credentials,
|
||||
]
|
||||
|
||||
SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
|
||||
|
||||
# Set of providers that have system credentials available
|
||||
SYSTEM_PROVIDERS = {cred.provider for cred in DEFAULT_CREDENTIALS}
|
||||
|
||||
|
||||
def is_system_credential(credential_id: str) -> bool:
|
||||
"""Check if a credential ID belongs to a system-managed credential."""
|
||||
return credential_id in SYSTEM_CREDENTIAL_IDS
|
||||
|
||||
|
||||
def is_system_provider(provider: str) -> bool:
|
||||
"""Check if a provider has system-managed credentials available."""
|
||||
return provider in SYSTEM_PROVIDERS
|
||||
|
||||
|
||||
class IntegrationCredentialsStore:
|
||||
def __init__(self):
|
||||
|
||||
@@ -10,6 +10,7 @@ from backend.util.settings import Settings
|
||||
settings = Settings()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI
|
||||
from supabase import AClient, Client
|
||||
|
||||
from backend.data.execution import (
|
||||
@@ -139,6 +140,24 @@ async def get_async_supabase() -> "AClient":
|
||||
)
|
||||
|
||||
|
||||
# ============ OpenAI Client ============ #
|
||||
|
||||
|
||||
@cached(ttl_seconds=3600)
|
||||
def get_openai_client() -> "AsyncOpenAI | None":
|
||||
"""
|
||||
Get a process-cached async OpenAI client for embeddings.
|
||||
|
||||
Returns None if API key is not configured.
|
||||
"""
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
api_key = settings.secrets.openai_internal_api_key
|
||||
if not api_key:
|
||||
return None
|
||||
return AsyncOpenAI(api_key=api_key)
|
||||
|
||||
|
||||
# ============ Notification Queue Helpers ============ #
|
||||
|
||||
|
||||
|
||||
@@ -658,6 +658,14 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
|
||||
ayrshare_api_key: str = Field(default="", description="Ayrshare API Key")
|
||||
ayrshare_jwt_key: str = Field(default="", description="Ayrshare private Key")
|
||||
|
||||
# Langfuse prompt management
|
||||
langfuse_public_key: str = Field(default="", description="Langfuse public key")
|
||||
langfuse_secret_key: str = Field(default="", description="Langfuse secret key")
|
||||
langfuse_host: str = Field(
|
||||
default="https://cloud.langfuse.com", description="Langfuse host URL"
|
||||
)
|
||||
|
||||
# Add more secret fields as needed
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
-- CreateExtension
|
||||
-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
|
||||
-- Create in public schema so vector type is available across all schemas
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "vector" WITH SCHEMA "public";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'vector extension not available or already exists, skipping';
|
||||
END $$;
|
||||
|
||||
-- CreateEnum
|
||||
CREATE TYPE "ContentType" AS ENUM ('STORE_AGENT', 'BLOCK', 'INTEGRATION', 'DOCUMENTATION', 'LIBRARY_AGENT');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "UnifiedContentEmbedding" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"contentType" "ContentType" NOT NULL,
|
||||
"contentId" TEXT NOT NULL,
|
||||
"userId" TEXT,
|
||||
"embedding" public.vector(1536) NOT NULL,
|
||||
"searchableText" TEXT NOT NULL,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
|
||||
CONSTRAINT "UnifiedContentEmbedding_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "UnifiedContentEmbedding_contentType_idx" ON "UnifiedContentEmbedding"("contentType");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "UnifiedContentEmbedding_userId_idx" ON "UnifiedContentEmbedding"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "UnifiedContentEmbedding_contentType_userId_idx" ON "UnifiedContentEmbedding"("contentType", "userId");
|
||||
|
||||
-- CreateIndex
|
||||
-- NULLS NOT DISTINCT ensures only one public (NULL userId) embedding per contentType+contentId
|
||||
-- Requires PostgreSQL 15+. Supabase uses PostgreSQL 15+.
|
||||
CREATE UNIQUE INDEX "UnifiedContentEmbedding_contentType_contentId_userId_key" ON "UnifiedContentEmbedding"("contentType", "contentId", "userId") NULLS NOT DISTINCT;
|
||||
|
||||
-- CreateIndex
|
||||
-- HNSW index for fast vector similarity search on embeddings
|
||||
-- Uses cosine distance operator (<=>), which matches the query in hybrid_search.py
|
||||
-- Note: Drop first in case Prisma created a btree index (Prisma doesn't support HNSW)
|
||||
DROP INDEX IF EXISTS "UnifiedContentEmbedding_embedding_idx";
|
||||
CREATE INDEX "UnifiedContentEmbedding_embedding_idx" ON "UnifiedContentEmbedding" USING hnsw ("embedding" public.vector_cosine_ops);
|
||||
@@ -0,0 +1,71 @@
|
||||
-- Acknowledge Supabase-managed extensions to prevent drift warnings
|
||||
-- These extensions are pre-installed by Supabase in specific schemas
|
||||
-- This migration ensures they exist where available (Supabase) or skips gracefully (CI)
|
||||
|
||||
-- Create schemas (safe in both CI and Supabase)
|
||||
CREATE SCHEMA IF NOT EXISTS "extensions";
|
||||
|
||||
-- Extensions that exist in both CI and Supabase
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgcrypto extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'uuid-ossp extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
-- Supabase-specific extensions (skip gracefully in CI)
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_stat_statements" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_stat_statements extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_net" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_net extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pgjwt" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgjwt extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "graphql";
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_graphql" WITH SCHEMA "graphql";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_graphql extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "pgsodium";
|
||||
CREATE EXTENSION IF NOT EXISTS "pgsodium" WITH SCHEMA "pgsodium";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgsodium extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "vault";
|
||||
CREATE EXTENSION IF NOT EXISTS "supabase_vault" WITH SCHEMA "vault";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'supabase_vault extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
|
||||
-- Return to platform
|
||||
CREATE SCHEMA IF NOT EXISTS "platform";
|
||||
@@ -0,0 +1,64 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "CoPilotUnderstanding" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"userId" TEXT NOT NULL,
|
||||
"data" JSONB,
|
||||
|
||||
CONSTRAINT "CoPilotUnderstanding_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "ChatSession" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"userId" TEXT NOT NULL,
|
||||
"title" TEXT,
|
||||
"credentials" JSONB NOT NULL DEFAULT '{}',
|
||||
"successfulAgentRuns" JSONB NOT NULL DEFAULT '{}',
|
||||
"successfulAgentSchedules" JSONB NOT NULL DEFAULT '{}',
|
||||
"totalPromptTokens" INTEGER NOT NULL DEFAULT 0,
|
||||
"totalCompletionTokens" INTEGER NOT NULL DEFAULT 0,
|
||||
|
||||
CONSTRAINT "ChatSession_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "ChatMessage" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"sessionId" TEXT NOT NULL,
|
||||
"role" TEXT NOT NULL,
|
||||
"content" TEXT,
|
||||
"name" TEXT,
|
||||
"toolCallId" TEXT,
|
||||
"refusal" TEXT,
|
||||
"toolCalls" JSONB,
|
||||
"functionCall" JSONB,
|
||||
"sequence" INTEGER NOT NULL,
|
||||
|
||||
CONSTRAINT "ChatMessage_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "CoPilotUnderstanding_userId_key" ON "CoPilotUnderstanding"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "CoPilotUnderstanding_userId_idx" ON "CoPilotUnderstanding"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ChatSession_userId_updatedAt_idx" ON "ChatSession"("userId", "updatedAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "ChatMessage_sessionId_sequence_key" ON "ChatMessage"("sessionId", "sequence");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "CoPilotUnderstanding" ADD CONSTRAINT "CoPilotUnderstanding_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "ChatSession" ADD CONSTRAINT "ChatSession_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "ChatMessage" ADD CONSTRAINT "ChatMessage_sessionId_fkey" FOREIGN KEY ("sessionId") REFERENCES "ChatSession"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
@@ -0,0 +1,35 @@
|
||||
-- Add tsvector search column to UnifiedContentEmbedding for unified full-text search
|
||||
-- This enables hybrid search (semantic + lexical) across all content types
|
||||
|
||||
-- Add search column (IF NOT EXISTS for idempotency)
|
||||
ALTER TABLE "UnifiedContentEmbedding" ADD COLUMN IF NOT EXISTS "search" tsvector DEFAULT ''::tsvector;
|
||||
|
||||
-- Create GIN index for fast full-text search
|
||||
-- No @@index in schema.prisma - Prisma may generate DROP INDEX on migrate dev
|
||||
-- If that happens, just let it drop and this migration will recreate it, or manually re-run:
|
||||
-- CREATE INDEX IF NOT EXISTS "UnifiedContentEmbedding_search_idx" ON "UnifiedContentEmbedding" USING GIN ("search");
|
||||
DROP INDEX IF EXISTS "UnifiedContentEmbedding_search_idx";
|
||||
CREATE INDEX "UnifiedContentEmbedding_search_idx" ON "UnifiedContentEmbedding" USING GIN ("search");
|
||||
|
||||
-- Drop existing trigger/function if exists
|
||||
DROP TRIGGER IF EXISTS "update_unified_tsvector" ON "UnifiedContentEmbedding";
|
||||
DROP FUNCTION IF EXISTS update_unified_tsvector_column();
|
||||
|
||||
-- Create function to auto-update tsvector from searchableText
|
||||
CREATE OR REPLACE FUNCTION update_unified_tsvector_column() RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
NEW.search := to_tsvector('english', COALESCE(NEW."searchableText", ''));
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql SECURITY DEFINER SET search_path = platform, pg_temp;
|
||||
|
||||
-- Create trigger to auto-update search column on insert/update
|
||||
CREATE TRIGGER "update_unified_tsvector"
|
||||
BEFORE INSERT OR UPDATE ON "UnifiedContentEmbedding"
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_unified_tsvector_column();
|
||||
|
||||
-- Backfill existing rows
|
||||
UPDATE "UnifiedContentEmbedding"
|
||||
SET search = to_tsvector('english', COALESCE("searchableText", ''))
|
||||
WHERE search IS NULL OR search = ''::tsvector;
|
||||
@@ -0,0 +1,90 @@
|
||||
-- Remove the old search column from StoreListingVersion
|
||||
-- This column has been replaced by UnifiedContentEmbedding.search
|
||||
-- which provides unified hybrid search across all content types
|
||||
|
||||
-- First drop the dependent view
|
||||
DROP VIEW IF EXISTS "StoreAgent";
|
||||
|
||||
-- Drop the trigger and function for old search column
|
||||
-- The original trigger was created in 20251016093049_add_full_text_search
|
||||
DROP TRIGGER IF EXISTS "update_tsvector" ON "StoreListingVersion";
|
||||
DROP FUNCTION IF EXISTS update_tsvector_column();
|
||||
|
||||
-- Drop the index
|
||||
DROP INDEX IF EXISTS "StoreListingVersion_search_idx";
|
||||
|
||||
-- NOTE: Keeping search column for now to allow easy revert if needed
|
||||
-- Uncomment to fully remove once migration is verified in production:
|
||||
-- ALTER TABLE "StoreListingVersion" DROP COLUMN IF EXISTS "search";
|
||||
|
||||
-- Recreate the StoreAgent view WITHOUT the search column
|
||||
-- (Search now handled by UnifiedContentEmbedding)
|
||||
CREATE OR REPLACE VIEW "StoreAgent" AS
|
||||
WITH latest_versions AS (
|
||||
SELECT
|
||||
"storeListingId",
|
||||
MAX(version) AS max_version
|
||||
FROM "StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
GROUP BY "storeListingId"
|
||||
),
|
||||
agent_versions AS (
|
||||
SELECT
|
||||
"storeListingId",
|
||||
array_agg(DISTINCT version::text ORDER BY version::text) AS versions
|
||||
FROM "StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
GROUP BY "storeListingId"
|
||||
),
|
||||
agent_graph_versions AS (
|
||||
SELECT
|
||||
"storeListingId",
|
||||
array_agg(DISTINCT "agentGraphVersion"::text ORDER BY "agentGraphVersion"::text) AS graph_versions
|
||||
FROM "StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
GROUP BY "storeListingId"
|
||||
)
|
||||
SELECT
|
||||
sl.id AS listing_id,
|
||||
slv.id AS "storeListingVersionId",
|
||||
slv."createdAt" AS updated_at,
|
||||
sl.slug,
|
||||
COALESCE(slv.name, '') AS agent_name,
|
||||
slv."videoUrl" AS agent_video,
|
||||
slv."agentOutputDemoUrl" AS agent_output_demo,
|
||||
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
|
||||
slv."isFeatured" AS featured,
|
||||
p.username AS creator_username,
|
||||
p."avatarUrl" AS creator_avatar,
|
||||
slv."subHeading" AS sub_heading,
|
||||
slv.description,
|
||||
slv.categories,
|
||||
COALESCE(ar.run_count, 0::bigint) AS runs,
|
||||
COALESCE(rs.avg_rating, 0.0)::double precision AS rating,
|
||||
COALESCE(av.versions, ARRAY[slv.version::text]) AS versions,
|
||||
COALESCE(agv.graph_versions, ARRAY[slv."agentGraphVersion"::text]) AS "agentGraphVersions",
|
||||
slv."agentGraphId",
|
||||
slv."isAvailable" AS is_available,
|
||||
COALESCE(sl."useForOnboarding", false) AS "useForOnboarding"
|
||||
FROM "StoreListing" sl
|
||||
JOIN latest_versions lv
|
||||
ON sl.id = lv."storeListingId"
|
||||
JOIN "StoreListingVersion" slv
|
||||
ON slv."storeListingId" = lv."storeListingId"
|
||||
AND slv.version = lv.max_version
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
JOIN "AgentGraph" a
|
||||
ON slv."agentGraphId" = a.id
|
||||
AND slv."agentGraphVersion" = a.version
|
||||
LEFT JOIN "Profile" p
|
||||
ON sl."owningUserId" = p."userId"
|
||||
LEFT JOIN "mv_review_stats" rs
|
||||
ON sl.id = rs."storeListingId"
|
||||
LEFT JOIN "mv_agent_run_counts" ar
|
||||
ON a.id = ar."agentGraphId"
|
||||
LEFT JOIN agent_versions av
|
||||
ON sl.id = av."storeListingId"
|
||||
LEFT JOIN agent_graph_versions agv
|
||||
ON sl.id = agv."storeListingId"
|
||||
WHERE sl."isDeleted" = false
|
||||
AND sl."hasApprovedVersion" = true;
|
||||
201
autogpt_platform/backend/poetry.lock
generated
201
autogpt_platform/backend/poetry.lock
generated
@@ -2777,6 +2777,30 @@ enabler = ["pytest-enabler (>=2.2)"]
|
||||
test = ["pyfakefs", "pytest (>=6,!=8.1.*)"]
|
||||
type = ["pygobject-stubs", "pytest-mypy", "shtab", "types-pywin32"]
|
||||
|
||||
[[package]]
|
||||
name = "langfuse"
|
||||
version = "3.11.2"
|
||||
description = "A client library for accessing langfuse"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "langfuse-3.11.2-py3-none-any.whl", hash = "sha256:84faea9f909694023cc7f0eb45696be190248c8790424f22af57ca4cd7a29f2d"},
|
||||
{file = "langfuse-3.11.2.tar.gz", hash = "sha256:ab5f296a8056815b7288c7f25bc308a5e79f82a8634467b25daffdde99276e09"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
backoff = ">=1.10.0"
|
||||
httpx = ">=0.15.4,<1.0"
|
||||
openai = ">=0.27.8"
|
||||
opentelemetry-api = ">=1.33.1,<2.0.0"
|
||||
opentelemetry-exporter-otlp-proto-http = ">=1.33.1,<2.0.0"
|
||||
opentelemetry-sdk = ">=1.33.1,<2.0.0"
|
||||
packaging = ">=23.2,<26.0"
|
||||
pydantic = ">=1.10.7,<3.0"
|
||||
requests = ">=2,<3"
|
||||
wrapt = ">=1.14,<2.0"
|
||||
|
||||
[[package]]
|
||||
name = "launchdarkly-eventsource"
|
||||
version = "1.3.0"
|
||||
@@ -3468,6 +3492,90 @@ files = [
|
||||
importlib-metadata = ">=6.0,<8.8.0"
|
||||
typing-extensions = ">=4.5.0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-exporter-otlp-proto-common"
|
||||
version = "1.35.0"
|
||||
description = "OpenTelemetry Protobuf encoding"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "opentelemetry_exporter_otlp_proto_common-1.35.0-py3-none-any.whl", hash = "sha256:863465de697ae81279ede660f3918680b4480ef5f69dcdac04f30722ed7b74cc"},
|
||||
{file = "opentelemetry_exporter_otlp_proto_common-1.35.0.tar.gz", hash = "sha256:6f6d8c39f629b9fa5c79ce19a2829dbd93034f8ac51243cdf40ed2196f00d7eb"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
opentelemetry-proto = "1.35.0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-exporter-otlp-proto-http"
|
||||
version = "1.35.0"
|
||||
description = "OpenTelemetry Collector Protobuf over HTTP Exporter"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "opentelemetry_exporter_otlp_proto_http-1.35.0-py3-none-any.whl", hash = "sha256:9a001e3df3c7f160fb31056a28ed7faa2de7df68877ae909516102ae36a54e1d"},
|
||||
{file = "opentelemetry_exporter_otlp_proto_http-1.35.0.tar.gz", hash = "sha256:cf940147f91b450ef5f66e9980d40eb187582eed399fa851f4a7a45bb880de79"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
googleapis-common-protos = ">=1.52,<2.0"
|
||||
opentelemetry-api = ">=1.15,<2.0"
|
||||
opentelemetry-exporter-otlp-proto-common = "1.35.0"
|
||||
opentelemetry-proto = "1.35.0"
|
||||
opentelemetry-sdk = ">=1.35.0,<1.36.0"
|
||||
requests = ">=2.7,<3.0"
|
||||
typing-extensions = ">=4.5.0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-proto"
|
||||
version = "1.35.0"
|
||||
description = "OpenTelemetry Python Proto"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "opentelemetry_proto-1.35.0-py3-none-any.whl", hash = "sha256:98fffa803164499f562718384e703be8d7dfbe680192279a0429cb150a2f8809"},
|
||||
{file = "opentelemetry_proto-1.35.0.tar.gz", hash = "sha256:532497341bd3e1c074def7c5b00172601b28bb83b48afc41a4b779f26eb4ee05"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
protobuf = ">=5.0,<7.0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-sdk"
|
||||
version = "1.35.0"
|
||||
description = "OpenTelemetry Python SDK"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "opentelemetry_sdk-1.35.0-py3-none-any.whl", hash = "sha256:223d9e5f5678518f4842311bb73966e0b6db5d1e0b74e35074c052cd2487f800"},
|
||||
{file = "opentelemetry_sdk-1.35.0.tar.gz", hash = "sha256:2a400b415ab68aaa6f04e8a6a9f6552908fb3090ae2ff78d6ae0c597ac581954"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
opentelemetry-api = "1.35.0"
|
||||
opentelemetry-semantic-conventions = "0.56b0"
|
||||
typing-extensions = ">=4.5.0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-semantic-conventions"
|
||||
version = "0.56b0"
|
||||
description = "OpenTelemetry Semantic Conventions"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "opentelemetry_semantic_conventions-0.56b0-py3-none-any.whl", hash = "sha256:df44492868fd6b482511cc43a942e7194be64e94945f572db24df2e279a001a2"},
|
||||
{file = "opentelemetry_semantic_conventions-0.56b0.tar.gz", hash = "sha256:c114c2eacc8ff6d3908cb328c811eaf64e6d68623840be9224dc829c4fd6c2ea"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
opentelemetry-api = "1.35.0"
|
||||
typing-extensions = ">=4.5.0"
|
||||
|
||||
[[package]]
|
||||
name = "orjson"
|
||||
version = "3.11.3"
|
||||
@@ -6922,6 +7030,97 @@ files = [
|
||||
{file = "websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wrapt"
|
||||
version = "1.17.3"
|
||||
description = "Module for decorators, wrappers and monkey patching."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "wrapt-1.17.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:88bbae4d40d5a46142e70d58bf664a89b6b4befaea7b2ecc14e03cedb8e06c04"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e6b13af258d6a9ad602d57d889f83b9d5543acd471eee12eb51f5b01f8eb1bc2"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd341868a4b6714a5962c1af0bd44f7c404ef78720c7de4892901e540417111c"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f9b2601381be482f70e5d1051a5965c25fb3625455a2bf520b5a077b22afb775"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:343e44b2a8e60e06a7e0d29c1671a0d9951f59174f3709962b5143f60a2a98bd"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:33486899acd2d7d3066156b03465b949da3fd41a5da6e394ec49d271baefcf05"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e6f40a8aa5a92f150bdb3e1c44b7e98fb7113955b2e5394122fa5532fec4b418"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-win32.whl", hash = "sha256:a36692b8491d30a8c75f1dfee65bef119d6f39ea84ee04d9f9311f83c5ad9390"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-win_amd64.whl", hash = "sha256:afd964fd43b10c12213574db492cb8f73b2f0826c8df07a68288f8f19af2ebe6"},
|
||||
{file = "wrapt-1.17.3-cp310-cp310-win_arm64.whl", hash = "sha256:af338aa93554be859173c39c85243970dc6a289fa907402289eeae7543e1ae18"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:273a736c4645e63ac582c60a56b0acb529ef07f78e08dc6bfadf6a46b19c0da7"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5531d911795e3f935a9c23eb1c8c03c211661a5060aab167065896bbf62a5f85"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0610b46293c59a3adbae3dee552b648b984176f8562ee0dba099a56cfbe4df1f"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b32888aad8b6e68f83a8fdccbf3165f5469702a7544472bdf41f582970ed3311"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cccf4f81371f257440c88faed6b74f1053eef90807b77e31ca057b2db74edb1"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8a210b158a34164de8bb68b0e7780041a903d7b00c87e906fb69928bf7890d5"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:79573c24a46ce11aab457b472efd8d125e5a51da2d1d24387666cd85f54c05b2"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-win32.whl", hash = "sha256:c31eebe420a9a5d2887b13000b043ff6ca27c452a9a22fa71f35f118e8d4bf89"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-win_amd64.whl", hash = "sha256:0b1831115c97f0663cb77aa27d381237e73ad4f721391a9bfb2fe8bc25fa6e77"},
|
||||
{file = "wrapt-1.17.3-cp311-cp311-win_arm64.whl", hash = "sha256:5a7b3c1ee8265eb4c8f1b7d29943f195c00673f5ab60c192eba2d4a7eae5f46a"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ab232e7fdb44cdfbf55fc3afa31bcdb0d8980b9b95c38b6405df2acb672af0e0"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9baa544e6acc91130e926e8c802a17f3b16fbea0fd441b5a60f5cf2cc5c3deba"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6b538e31eca1a7ea4605e44f81a48aa24c4632a277431a6ed3f328835901f4fd"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:042ec3bb8f319c147b1301f2393bc19dba6e176b7da446853406d041c36c7828"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3af60380ba0b7b5aeb329bc4e402acd25bd877e98b3727b0135cb5c2efdaefe9"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0b02e424deef65c9f7326d8c19220a2c9040c51dc165cddb732f16198c168396"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:74afa28374a3c3a11b3b5e5fca0ae03bef8450d6aa3ab3a1e2c30e3a75d023dc"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-win32.whl", hash = "sha256:4da9f45279fff3543c371d5ababc57a0384f70be244de7759c85a7f989cb4ebe"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-win_amd64.whl", hash = "sha256:e71d5c6ebac14875668a1e90baf2ea0ef5b7ac7918355850c0908ae82bcb297c"},
|
||||
{file = "wrapt-1.17.3-cp312-cp312-win_arm64.whl", hash = "sha256:604d076c55e2fdd4c1c03d06dc1a31b95130010517b5019db15365ec4a405fc6"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a47681378a0439215912ef542c45a783484d4dd82bac412b71e59cf9c0e1cea0"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:54a30837587c6ee3cd1a4d1c2ec5d24e77984d44e2f34547e2323ddb4e22eb77"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:16ecf15d6af39246fe33e507105d67e4b81d8f8d2c6598ff7e3ca1b8a37213f7"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6fd1ad24dc235e4ab88cda009e19bf347aabb975e44fd5c2fb22a3f6e4141277"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ed61b7c2d49cee3c027372df5809a59d60cf1b6c2f81ee980a091f3afed6a2d"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:423ed5420ad5f5529db9ce89eac09c8a2f97da18eb1c870237e84c5a5c2d60aa"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e01375f275f010fcbf7f643b4279896d04e571889b8a5b3f848423d91bf07050"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-win32.whl", hash = "sha256:53e5e39ff71b3fc484df8a522c933ea2b7cdd0d5d15ae82e5b23fde87d44cbd8"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-win_amd64.whl", hash = "sha256:1f0b2f40cf341ee8cc1a97d51ff50dddb9fcc73241b9143ec74b30fc4f44f6cb"},
|
||||
{file = "wrapt-1.17.3-cp313-cp313-win_arm64.whl", hash = "sha256:7425ac3c54430f5fc5e7b6f41d41e704db073309acfc09305816bc6a0b26bb16"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:cf30f6e3c077c8e6a9a7809c94551203c8843e74ba0c960f4a98cd80d4665d39"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:e228514a06843cae89621384cfe3a80418f3c04aadf8a3b14e46a7be704e4235"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:5ea5eb3c0c071862997d6f3e02af1d055f381b1d25b286b9d6644b79db77657c"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:281262213373b6d5e4bb4353bc36d1ba4084e6d6b5d242863721ef2bf2c2930b"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dc4a8d2b25efb6681ecacad42fca8859f88092d8732b170de6a5dddd80a1c8fa"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:373342dd05b1d07d752cecbec0c41817231f29f3a89aa8b8843f7b95992ed0c7"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d40770d7c0fd5cbed9d84b2c3f2e156431a12c9a37dc6284060fb4bec0b7ffd4"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-win32.whl", hash = "sha256:fbd3c8319de8e1dc79d346929cd71d523622da527cca14e0c1d257e31c2b8b10"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-win_amd64.whl", hash = "sha256:e1a4120ae5705f673727d3253de3ed0e016f7cd78dc463db1b31e2463e1f3cf6"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314-win_arm64.whl", hash = "sha256:507553480670cab08a800b9463bdb881b2edeed77dc677b0a5915e6106e91a58"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:ed7c635ae45cfbc1a7371f708727bf74690daedc49b4dba310590ca0bd28aa8a"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:249f88ed15503f6492a71f01442abddd73856a0032ae860de6d75ca62eed8067"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5a03a38adec8066d5a37bea22f2ba6bbf39fcdefbe2d91419ab864c3fb515454"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:5d4478d72eb61c36e5b446e375bbc49ed002430d17cdec3cecb36993398e1a9e"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:223db574bb38637e8230eb14b185565023ab624474df94d2af18f1cdb625216f"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:e405adefb53a435f01efa7ccdec012c016b5a1d3f35459990afc39b6be4d5056"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:88547535b787a6c9ce4086917b6e1d291aa8ed914fdd3a838b3539dc95c12804"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-win32.whl", hash = "sha256:41b1d2bc74c2cac6f9074df52b2efbef2b30bdfe5f40cb78f8ca22963bc62977"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-win_amd64.whl", hash = "sha256:73d496de46cd2cdbdbcce4ae4bcdb4afb6a11234a1df9c085249d55166b95116"},
|
||||
{file = "wrapt-1.17.3-cp314-cp314t-win_arm64.whl", hash = "sha256:f38e60678850c42461d4202739f9bf1e3a737c7ad283638251e79cc49effb6b6"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:70d86fa5197b8947a2fa70260b48e400bf2ccacdcab97bb7de47e3d1e6312225"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:df7d30371a2accfe4013e90445f6388c570f103d61019b6b7c57e0265250072a"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:caea3e9c79d5f0d2c6d9ab96111601797ea5da8e6d0723f77eabb0d4068d2b2f"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:758895b01d546812d1f42204bd443b8c433c44d090248bf22689df673ccafe00"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:02b551d101f31694fc785e58e0720ef7d9a10c4e62c1c9358ce6f63f23e30a56"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:656873859b3b50eeebe6db8b1455e99d90c26ab058db8e427046dbc35c3140a5"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:a9a2203361a6e6404f80b99234fe7fb37d1fc73487b5a78dc1aa5b97201e0f22"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-win32.whl", hash = "sha256:55cbbc356c2842f39bcc553cf695932e8b30e30e797f961860afb308e6b1bb7c"},
|
||||
{file = "wrapt-1.17.3-cp38-cp38-win_amd64.whl", hash = "sha256:ad85e269fe54d506b240d2d7b9f5f2057c2aa9a2ea5b32c66f8902f768117ed2"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:30ce38e66630599e1193798285706903110d4f057aab3168a34b7fdc85569afc"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:65d1d00fbfb3ea5f20add88bbc0f815150dbbde3b026e6c24759466c8b5a9ef9"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a7c06742645f914f26c7f1fa47b8bc4c91d222f76ee20116c43d5ef0912bba2d"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7e18f01b0c3e4a07fe6dfdb00e29049ba17eadbc5e7609a2a3a4af83ab7d710a"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0f5f51a6466667a5a356e6381d362d259125b57f059103dd9fdc8c0cf1d14139"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:59923aa12d0157f6b82d686c3fd8e1166fa8cdfb3e17b42ce3b6147ff81528df"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:46acc57b331e0b3bcb3e1ca3b421d65637915cfcd65eb783cb2f78a511193f9b"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-win32.whl", hash = "sha256:3e62d15d3cfa26e3d0788094de7b64efa75f3a53875cdbccdf78547aed547a81"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-win_amd64.whl", hash = "sha256:1f23fa283f51c890eda8e34e4937079114c74b4c81d2b2f1f1d94948f5cc3d7f"},
|
||||
{file = "wrapt-1.17.3-cp39-cp39-win_arm64.whl", hash = "sha256:24c2ed34dc222ed754247a2702b1e1e89fdbaa4016f324b4b8f1a802d4ffe87f"},
|
||||
{file = "wrapt-1.17.3-py3-none-any.whl", hash = "sha256:7171ae35d2c33d326ac19dd8facb1e82e5fd04ef8c6c0e394d7af55a55051c22"},
|
||||
{file = "wrapt-1.17.3.tar.gz", hash = "sha256:f66eb08feaa410fe4eebd17f2a2c8e2e46d3476e9f8c783daa8e09e0faa666d0"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xattr"
|
||||
version = "1.2.0"
|
||||
@@ -7295,4 +7494,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "a93ba0cea3b465cb6ec3e3f258b383b09f84ea352ccfdbfa112902cde5653fc6"
|
||||
content-hash = "86838b5ae40d606d6e01a14dad8a56c389d890d7a6a0c274a6602cca80f0df84"
|
||||
|
||||
@@ -33,6 +33,7 @@ html2text = "^2024.2.26"
|
||||
jinja2 = "^3.1.6"
|
||||
jsonref = "^1.1.0"
|
||||
jsonschema = "^4.25.0"
|
||||
langfuse = "^3.11.0"
|
||||
launchdarkly-server-sdk = "^9.12.0"
|
||||
mem0ai = "^0.1.115"
|
||||
moviepy = "^2.1.2"
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
datasource db {
|
||||
provider = "postgresql"
|
||||
url = env("DATABASE_URL")
|
||||
directUrl = env("DIRECT_URL")
|
||||
provider = "postgresql"
|
||||
url = env("DATABASE_URL")
|
||||
directUrl = env("DIRECT_URL")
|
||||
extensions = [pgvector(map: "vector")]
|
||||
}
|
||||
|
||||
generator client {
|
||||
provider = "prisma-client-py"
|
||||
recursive_type_depth = -1
|
||||
interface = "asyncio"
|
||||
previewFeatures = ["views", "fullTextSearch"]
|
||||
previewFeatures = ["views", "fullTextSearch", "postgresqlExtensions"]
|
||||
partial_type_generator = "backend/data/partial_types.py"
|
||||
}
|
||||
|
||||
@@ -47,12 +48,13 @@ model User {
|
||||
AnalyticsMetrics AnalyticsMetrics[]
|
||||
CreditTransactions CreditTransaction[]
|
||||
UserBalance UserBalance?
|
||||
|
||||
AgentPresets AgentPreset[]
|
||||
LibraryAgents LibraryAgent[]
|
||||
ChatSessions ChatSession[]
|
||||
AgentPresets AgentPreset[]
|
||||
LibraryAgents LibraryAgent[]
|
||||
|
||||
Profile Profile[]
|
||||
UserOnboarding UserOnboarding?
|
||||
CoPilotUnderstanding CoPilotUnderstanding?
|
||||
BuilderSearchHistory BuilderSearchHistory[]
|
||||
StoreListings StoreListing[]
|
||||
StoreListingReviews StoreListingReview[]
|
||||
@@ -121,19 +123,84 @@ model UserOnboarding {
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
}
|
||||
|
||||
model CoPilotUnderstanding {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
userId String @unique
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
data Json?
|
||||
|
||||
@@index([userId])
|
||||
}
|
||||
|
||||
model BuilderSearchHistory {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
searchQuery String
|
||||
filter String[] @default([])
|
||||
byCreator String[] @default([])
|
||||
filter String[] @default([])
|
||||
byCreator String[] @default([])
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
//////////////// CHAT SESSION TABLES ///////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
model ChatSession {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
// Session metadata
|
||||
title String?
|
||||
credentials Json @default("{}") // Map of provider -> credential metadata
|
||||
|
||||
// Rate limiting counters (stored as JSON maps)
|
||||
successfulAgentRuns Json @default("{}") // Map of graph_id -> count
|
||||
successfulAgentSchedules Json @default("{}") // Map of graph_id -> count
|
||||
|
||||
// Usage tracking
|
||||
totalPromptTokens Int @default(0)
|
||||
totalCompletionTokens Int @default(0)
|
||||
|
||||
Messages ChatMessage[]
|
||||
|
||||
@@index([userId, updatedAt])
|
||||
}
|
||||
|
||||
model ChatMessage {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
sessionId String
|
||||
Session ChatSession @relation(fields: [sessionId], references: [id], onDelete: Cascade)
|
||||
|
||||
// Message content
|
||||
role String // "user", "assistant", "system", "tool", "function"
|
||||
content String?
|
||||
name String?
|
||||
toolCallId String?
|
||||
refusal String?
|
||||
toolCalls Json? // List of tool calls for assistant messages
|
||||
functionCall Json? // Deprecated but kept for compatibility
|
||||
|
||||
// Ordering within session
|
||||
sequence Int
|
||||
|
||||
@@unique([sessionId, sequence])
|
||||
}
|
||||
|
||||
// This model describes the Agent Graph/Flow (Multi Agent System).
|
||||
model AgentGraph {
|
||||
id String @default(uuid())
|
||||
@@ -721,26 +788,25 @@ view StoreAgent {
|
||||
storeListingVersionId String
|
||||
updated_at DateTime
|
||||
|
||||
slug String
|
||||
agent_name String
|
||||
agent_video String?
|
||||
agent_output_demo String?
|
||||
agent_image String[]
|
||||
slug String
|
||||
agent_name String
|
||||
agent_video String?
|
||||
agent_output_demo String?
|
||||
agent_image String[]
|
||||
|
||||
featured Boolean @default(false)
|
||||
creator_username String?
|
||||
creator_avatar String?
|
||||
sub_heading String
|
||||
description String
|
||||
categories String[]
|
||||
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
|
||||
runs Int
|
||||
rating Float
|
||||
versions String[]
|
||||
agentGraphVersions String[]
|
||||
agentGraphId String
|
||||
is_available Boolean @default(true)
|
||||
useForOnboarding Boolean @default(false)
|
||||
featured Boolean @default(false)
|
||||
creator_username String?
|
||||
creator_avatar String?
|
||||
sub_heading String
|
||||
description String
|
||||
categories String[]
|
||||
runs Int
|
||||
rating Float
|
||||
versions String[]
|
||||
agentGraphVersions String[]
|
||||
agentGraphId String
|
||||
is_available Boolean @default(true)
|
||||
useForOnboarding Boolean @default(false)
|
||||
|
||||
// Materialized views used (refreshed every 15 minutes via pg_cron):
|
||||
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
|
||||
@@ -856,14 +922,14 @@ model StoreListingVersion {
|
||||
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
|
||||
|
||||
// Content fields
|
||||
name String
|
||||
subHeading String
|
||||
videoUrl String?
|
||||
agentOutputDemoUrl String?
|
||||
imageUrls String[]
|
||||
description String
|
||||
instructions String?
|
||||
categories String[]
|
||||
name String
|
||||
subHeading String
|
||||
videoUrl String?
|
||||
agentOutputDemoUrl String?
|
||||
imageUrls String[]
|
||||
description String
|
||||
instructions String?
|
||||
categories String[]
|
||||
|
||||
isFeatured Boolean @default(false)
|
||||
|
||||
@@ -871,7 +937,7 @@ model StoreListingVersion {
|
||||
// Old versions can be made unavailable by the author if desired
|
||||
isAvailable Boolean @default(true)
|
||||
|
||||
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
|
||||
// Note: search column removed - now using UnifiedContentEmbedding.search
|
||||
|
||||
// Version workflow state
|
||||
submissionStatus SubmissionStatus @default(DRAFT)
|
||||
@@ -899,6 +965,9 @@ model StoreListingVersion {
|
||||
// Reviews for this specific version
|
||||
Reviews StoreListingReview[]
|
||||
|
||||
// Note: Embeddings now stored in UnifiedContentEmbedding table
|
||||
// Use contentType=STORE_AGENT and contentId=storeListingVersionId
|
||||
|
||||
@@unique([storeListingId, version])
|
||||
@@index([storeListingId, submissionStatus, isAvailable])
|
||||
@@index([submissionStatus])
|
||||
@@ -906,6 +975,45 @@ model StoreListingVersion {
|
||||
@@index([agentGraphId, agentGraphVersion]) // Non-unique index for efficient lookups
|
||||
}
|
||||
|
||||
// Content type enum for unified search across store agents, blocks, docs
|
||||
// Note: BLOCK/INTEGRATION are file-based (Python classes), not DB records
|
||||
// DOCUMENTATION are file-based (.md files), not DB records
|
||||
// Only STORE_AGENT and LIBRARY_AGENT are stored in database
|
||||
enum ContentType {
|
||||
STORE_AGENT // Database: StoreListingVersion
|
||||
BLOCK // File-based: Python classes in /backend/blocks/
|
||||
INTEGRATION // File-based: Python classes (blocks with credentials)
|
||||
DOCUMENTATION // File-based: .md/.mdx files
|
||||
LIBRARY_AGENT // Database: User's personal agents
|
||||
}
|
||||
|
||||
// Unified embeddings table for all searchable content types
|
||||
// Supports both public content (userId=null) and user-specific content (userId=userID)
|
||||
model UnifiedContentEmbedding {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
// Content identification
|
||||
contentType ContentType
|
||||
contentId String // DB ID (storeListingVersionId) or file identifier (block.id, file_path)
|
||||
userId String? // NULL for public content (store, blocks, docs), userId for private content (library agents)
|
||||
|
||||
// Search data
|
||||
embedding Unsupported("vector(1536)") // pgvector embedding (extension in platform schema)
|
||||
searchableText String // Combined text for search and fallback
|
||||
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector")) // Full-text search (auto-populated by trigger)
|
||||
metadata Json @default("{}") // Content-specific metadata
|
||||
|
||||
@@unique([contentType, contentId, userId], map: "UnifiedContentEmbedding_contentType_contentId_userId_key")
|
||||
@@index([contentType])
|
||||
@@index([userId])
|
||||
@@index([contentType, userId])
|
||||
@@index([embedding], map: "UnifiedContentEmbedding_embedding_idx")
|
||||
// NO @@index for search - GIN index "UnifiedContentEmbedding_search_idx" created via SQL migration
|
||||
// Prisma may generate DROP INDEX on migrate dev - that's okay, migration recreates it
|
||||
}
|
||||
|
||||
model StoreListingReview {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
@@ -998,16 +1106,16 @@ model OAuthApplication {
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
// Application metadata
|
||||
name String
|
||||
description String?
|
||||
logoUrl String? // URL to app logo stored in GCS
|
||||
clientId String @unique
|
||||
clientSecret String // Hashed with Scrypt (same as API keys)
|
||||
clientSecretSalt String // Salt for Scrypt hashing
|
||||
name String
|
||||
description String?
|
||||
logoUrl String? // URL to app logo stored in GCS
|
||||
clientId String @unique
|
||||
clientSecret String // Hashed with Scrypt (same as API keys)
|
||||
clientSecretSalt String // Salt for Scrypt hashing
|
||||
|
||||
// OAuth configuration
|
||||
redirectUris String[] // Allowed callback URLs
|
||||
grantTypes String[] @default(["authorization_code", "refresh_token"])
|
||||
grantTypes String[] @default(["authorization_code", "refresh_token"])
|
||||
scopes APIKeyPermission[] // Which permissions the app can request
|
||||
|
||||
// Application management
|
||||
|
||||
81
autogpt_platform/frontend/CLAUDE.md
Normal file
81
autogpt_platform/frontend/CLAUDE.md
Normal file
@@ -0,0 +1,81 @@
|
||||
# CLAUDE.md - Frontend
|
||||
|
||||
This file provides guidance to Claude Code when working with the frontend.
|
||||
|
||||
## Essential Commands
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd frontend && pnpm i
|
||||
|
||||
# Generate API client from OpenAPI spec
|
||||
pnpm generate:api
|
||||
|
||||
# Start development server
|
||||
pnpm dev
|
||||
|
||||
# Run E2E tests
|
||||
pnpm test
|
||||
|
||||
# Run Storybook for component development
|
||||
pnpm storybook
|
||||
|
||||
# Build production
|
||||
pnpm build
|
||||
|
||||
# Format and lint
|
||||
pnpm format
|
||||
|
||||
# Type checking
|
||||
pnpm types
|
||||
```
|
||||
|
||||
**📖 Complete Guide**: See @CONTRIBUTING.md and @.cursorrules for comprehensive frontend patterns.
|
||||
|
||||
## Key Conventions
|
||||
|
||||
- Separate render logic from data/behavior in components
|
||||
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||
- Only use Phosphor Icons
|
||||
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`
|
||||
|
||||
### Code Style
|
||||
|
||||
- Fully capitalize acronyms in symbols, e.g. `graphID`, `useBackendAPI`
|
||||
- Use function declarations (not arrow functions) for components/handlers
|
||||
|
||||
## Architecture
|
||||
|
||||
- **Framework**: Next.js 15 App Router (client-first approach)
|
||||
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
|
||||
- **State Management**: React Query for server state, co-located UI state in components/hooks
|
||||
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
|
||||
- **Workflow Builder**: Visual graph editor using @xyflow/react
|
||||
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
|
||||
- **Icons**: Phosphor Icons only
|
||||
- **Feature Flags**: LaunchDarkly integration
|
||||
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
|
||||
- **Testing**: Playwright for E2E, Storybook for component development
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
`.env.default` (defaults) → `.env` (user overrides)
|
||||
|
||||
## Feature Development
|
||||
|
||||
See @CONTRIBUTING.md for complete patterns. Quick reference:
|
||||
|
||||
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
|
||||
- Extract component logic into custom hooks grouped by concern, not by component. Each hook should represent a cohesive domain of functionality (e.g., useSearch, useFilters, usePagination) rather than bundling all state into one useComponentState hook.
|
||||
- Put each hook in its own `.ts` file
|
||||
- Put sub-components in local `components/` folder
|
||||
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
|
||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||
- Never use `src/components/__legacy__/*`
|
||||
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||
- Regenerate with `pnpm generate:api`
|
||||
- Pattern: `use{Method}{Version}{OperationName}`
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||
@@ -708,10 +708,7 @@ export function CreateButton() {
|
||||
|
||||
## 🧪 Testing & Storybook
|
||||
|
||||
- End-to-end: [Playwright](https://playwright.dev/docs/intro) (`pnpm test`, `pnpm test-ui`)
|
||||
- [Storybook](https://storybook.js.org/docs) for isolated UI development (`pnpm storybook` / `pnpm build-storybook`)
|
||||
- For Storybook tests in CI, see [`@storybook/test-runner`](https://storybook.js.org/docs/writing-tests/test-runner) (`test-storybook:ci`)
|
||||
- When changing components in `src/components`, update or add stories and visually verify in Storybook/Chromatic
|
||||
- See `TESTING.md` for Playwright setup, E2E data seeding, and Storybook usage.
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ This is the frontend for AutoGPT's next generation
|
||||
This project uses [**pnpm**](https://pnpm.io/) as the package manager via **corepack**. [Corepack](https://github.com/nodejs/corepack) is a Node.js tool that automatically manages package managers without requiring global installations.
|
||||
|
||||
For architecture, conventions, data fetching, feature flags, design system usage, state management, and PR process, see [CONTRIBUTING.md](./CONTRIBUTING.md).
|
||||
For Playwright and Storybook testing setup, see [TESTING.md](./TESTING.md).
|
||||
|
||||
### Prerequisites
|
||||
|
||||
|
||||
57
autogpt_platform/frontend/TESTING.md
Normal file
57
autogpt_platform/frontend/TESTING.md
Normal file
@@ -0,0 +1,57 @@
|
||||
# Frontend Testing 🧪
|
||||
|
||||
## Quick Start (local) 🚀
|
||||
|
||||
1. Start the backend + Supabase stack:
|
||||
- From `autogpt_platform`: `docker compose --profile local up deps_backend -d`
|
||||
- Or run the full stack: `docker compose up -d`
|
||||
2. Seed rich E2E data (creates `test123@gmail.com` with library agents):
|
||||
- From `autogpt_platform/backend`: `poetry run python test/e2e_test_data.py`
|
||||
3. Run Playwright:
|
||||
- From `autogpt_platform/frontend`: `pnpm test` or `pnpm test-ui`
|
||||
|
||||
## How Playwright setup works 🎭
|
||||
|
||||
- Playwright runs from `frontend/playwright.config.ts` with a global setup step.
|
||||
- The global setup creates a user pool via the real signup UI and stores it in `frontend/.auth/user-pool.json`.
|
||||
- Most tests call `getTestUser()` (from `src/tests/utils/auth.ts`) which pulls a random user from that pool.
|
||||
- these users do not contain library agents, it's user that just "signed up" on the platform, hence some tests to make use of users created via script (see below) with more data
|
||||
|
||||
## Test users 👤
|
||||
|
||||
- **User pool (basic users)**
|
||||
Created automatically by the Playwright global setup through `/signup`.
|
||||
Used by `getTestUser()` in `src/tests/utils/auth.ts`.
|
||||
|
||||
- **Rich user with library agents**
|
||||
Created by `backend/test/e2e_test_data.py`.
|
||||
Accessed via `getTestUserWithLibraryAgents()` in `src/tests/credentials/index.ts`.
|
||||
|
||||
Use the rich user when a test needs existing library agents (e.g. `library.spec.ts`).
|
||||
|
||||
## Resetting or wiping the DB 🔁
|
||||
|
||||
If you reset the Docker DB and logins start failing:
|
||||
|
||||
1. Delete `frontend/.auth/user-pool.json` so the pool is regenerated.
|
||||
2. Re-run the E2E data script to recreate the rich user + library agents:
|
||||
- `poetry run python test/e2e_test_data.py`
|
||||
|
||||
## Storybook 📚
|
||||
|
||||
## Flow diagram 🗺️
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Start Docker stack] --> B[Run e2e_test_data.py]
|
||||
B --> C[Run Playwright tests]
|
||||
C --> D[Global setup creates user pool]
|
||||
D --> E{Test needs rich data?}
|
||||
E -->|No| F[getTestUser from user pool]
|
||||
E -->|Yes| G[getTestUserWithLibraryAgents]
|
||||
```
|
||||
|
||||
- `pnpm storybook` – Run Storybook locally
|
||||
- `pnpm build-storybook` – Build a static Storybook
|
||||
- CI runner: `pnpm test-storybook`
|
||||
- When changing components in `src/components`, update or add stories and verify in Storybook/Chromatic.
|
||||
@@ -3,6 +3,13 @@ import { withSentryConfig } from "@sentry/nextjs";
|
||||
/** @type {import('next').NextConfig} */
|
||||
const nextConfig = {
|
||||
productionBrowserSourceMaps: true,
|
||||
// Externalize OpenTelemetry packages to fix Turbopack HMR issues
|
||||
serverExternalPackages: [
|
||||
"@opentelemetry/instrumentation",
|
||||
"@opentelemetry/sdk-node",
|
||||
"import-in-the-middle",
|
||||
"require-in-the-middle",
|
||||
],
|
||||
experimental: {
|
||||
serverActions: {
|
||||
bodySizeLimit: "256mb",
|
||||
|
||||
@@ -32,6 +32,7 @@
|
||||
"@hookform/resolvers": "5.2.2",
|
||||
"@next/third-parties": "15.4.6",
|
||||
"@phosphor-icons/react": "2.1.10",
|
||||
"@radix-ui/react-accordion": "1.2.12",
|
||||
"@radix-ui/react-alert-dialog": "1.1.15",
|
||||
"@radix-ui/react-avatar": "1.1.10",
|
||||
"@radix-ui/react-checkbox": "1.3.3",
|
||||
@@ -117,6 +118,7 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@chromatic-com/storybook": "4.1.2",
|
||||
"@opentelemetry/instrumentation": "0.209.0",
|
||||
"@playwright/test": "1.56.1",
|
||||
"@storybook/addon-a11y": "9.1.5",
|
||||
"@storybook/addon-docs": "9.1.5",
|
||||
@@ -140,6 +142,7 @@
|
||||
"eslint": "8.57.1",
|
||||
"eslint-config-next": "15.5.7",
|
||||
"eslint-plugin-storybook": "9.1.5",
|
||||
"import-in-the-middle": "2.0.2",
|
||||
"msw": "2.11.6",
|
||||
"msw-storybook-addon": "2.0.6",
|
||||
"orval": "7.13.0",
|
||||
@@ -147,7 +150,7 @@
|
||||
"postcss": "8.5.6",
|
||||
"prettier": "3.6.2",
|
||||
"prettier-plugin-tailwindcss": "0.7.1",
|
||||
"require-in-the-middle": "7.5.2",
|
||||
"require-in-the-middle": "8.0.1",
|
||||
"storybook": "9.1.5",
|
||||
"tailwindcss": "3.4.17",
|
||||
"typescript": "5.9.3"
|
||||
@@ -157,5 +160,10 @@
|
||||
"public"
|
||||
]
|
||||
},
|
||||
"pnpm": {
|
||||
"overrides": {
|
||||
"@opentelemetry/instrumentation": "0.209.0"
|
||||
}
|
||||
},
|
||||
"packageManager": "pnpm@10.20.0+sha512.cf9998222162dd85864d0a8102e7892e7ba4ceadebbf5a31f9c2fce48dfce317a9c53b9f6464d1ef9042cba2e02ae02a9f7c143a2b438cd93c91840f0192b9dd"
|
||||
}
|
||||
|
||||
140
autogpt_platform/frontend/pnpm-lock.yaml
generated
140
autogpt_platform/frontend/pnpm-lock.yaml
generated
@@ -4,6 +4,9 @@ settings:
|
||||
autoInstallPeers: true
|
||||
excludeLinksFromLockfile: false
|
||||
|
||||
overrides:
|
||||
'@opentelemetry/instrumentation': 0.209.0
|
||||
|
||||
importers:
|
||||
|
||||
.:
|
||||
@@ -20,6 +23,9 @@ importers:
|
||||
'@phosphor-icons/react':
|
||||
specifier: 2.1.10
|
||||
version: 2.1.10(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
'@radix-ui/react-accordion':
|
||||
specifier: 1.2.12
|
||||
version: 1.2.12(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
'@radix-ui/react-alert-dialog':
|
||||
specifier: 1.1.15
|
||||
version: 1.1.15(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
@@ -270,6 +276,9 @@ importers:
|
||||
'@chromatic-com/storybook':
|
||||
specifier: 4.1.2
|
||||
version: 4.1.2(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2))
|
||||
'@opentelemetry/instrumentation':
|
||||
specifier: 0.209.0
|
||||
version: 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@playwright/test':
|
||||
specifier: 1.56.1
|
||||
version: 1.56.1
|
||||
@@ -339,6 +348,9 @@ importers:
|
||||
eslint-plugin-storybook:
|
||||
specifier: 9.1.5
|
||||
version: 9.1.5(eslint@8.57.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2))(typescript@5.9.3)
|
||||
import-in-the-middle:
|
||||
specifier: 2.0.2
|
||||
version: 2.0.2
|
||||
msw:
|
||||
specifier: 2.11.6
|
||||
version: 2.11.6(@types/node@24.10.0)(typescript@5.9.3)
|
||||
@@ -361,8 +373,8 @@ importers:
|
||||
specifier: 0.7.1
|
||||
version: 0.7.1(prettier@3.6.2)
|
||||
require-in-the-middle:
|
||||
specifier: 7.5.2
|
||||
version: 7.5.2
|
||||
specifier: 8.0.1
|
||||
version: 8.0.1
|
||||
storybook:
|
||||
specifier: 9.1.5
|
||||
version: 9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)
|
||||
@@ -1543,8 +1555,8 @@ packages:
|
||||
'@open-draft/until@2.1.0':
|
||||
resolution: {integrity: sha512-U69T3ItWHvLwGg5eJ0n3I62nWuE6ilHlmz7zM0npLBRvPRd7e6NYmg54vvRtP5mZG7kZqZCFVdsTWo7BPtBujg==}
|
||||
|
||||
'@opentelemetry/api-logs@0.208.0':
|
||||
resolution: {integrity: sha512-CjruKY9V6NMssL/T1kAFgzosF1v9o6oeN+aX5JB/C/xPNtmgIJqcXHG7fA82Ou1zCpWGl4lROQUKwUNE1pMCyg==}
|
||||
'@opentelemetry/api-logs@0.209.0':
|
||||
resolution: {integrity: sha512-xomnUNi7TiAGtOgs0tb54LyrjRZLu9shJGGwkcN7NgtiPYOpNnKLkRJtzZvTjD/w6knSZH9sFZcUSUovYOPg6A==}
|
||||
engines: {node: '>=8.0.0'}
|
||||
|
||||
'@opentelemetry/api@1.9.0':
|
||||
@@ -1695,8 +1707,8 @@ packages:
|
||||
peerDependencies:
|
||||
'@opentelemetry/api': ^1.7.0
|
||||
|
||||
'@opentelemetry/instrumentation@0.208.0':
|
||||
resolution: {integrity: sha512-Eju0L4qWcQS+oXxi6pgh7zvE2byogAkcsVv0OjHF/97iOz1N/aKE6etSGowYkie+YA1uo6DNwdSxaaNnLvcRlA==}
|
||||
'@opentelemetry/instrumentation@0.209.0':
|
||||
resolution: {integrity: sha512-Cwe863ojTCnFlxVuuhG7s6ODkAOzKsAEthKAcI4MDRYz1OmGWYnmSl4X2pbyS+hBxVTdvfZePfoEA01IjqcEyw==}
|
||||
engines: {node: ^18.19.0 || >=20.6.0}
|
||||
peerDependencies:
|
||||
'@opentelemetry/api': ^1.3.0
|
||||
@@ -1810,6 +1822,19 @@ packages:
|
||||
'@radix-ui/primitive@1.1.3':
|
||||
resolution: {integrity: sha512-JTF99U/6XIjCBo0wqkU5sK10glYe27MRRsfwoiq5zzOEZLHU3A3KCMa5X/azekYRCJ0HlwI0crAXS/5dEHTzDg==}
|
||||
|
||||
'@radix-ui/react-accordion@1.2.12':
|
||||
resolution: {integrity: sha512-T4nygeh9YE9dLRPhAHSeOZi7HBXo+0kYIPJXayZfvWOWA0+n3dESrZbjfDPUABkUNym6Hd+f2IR113To8D2GPA==}
|
||||
peerDependencies:
|
||||
'@types/react': '*'
|
||||
'@types/react-dom': '*'
|
||||
react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc
|
||||
react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
'@types/react-dom':
|
||||
optional: true
|
||||
|
||||
'@radix-ui/react-alert-dialog@1.1.15':
|
||||
resolution: {integrity: sha512-oTVLkEw5GpdRe29BqJ0LSDFWI3qu0vR1M0mUkOQWDIUnY/QIkLpgDMWuKxP94c2NAC2LGcgVhG1ImF3jkZ5wXw==}
|
||||
peerDependencies:
|
||||
@@ -2631,7 +2656,7 @@ packages:
|
||||
'@opentelemetry/api': ^1.9.0
|
||||
'@opentelemetry/context-async-hooks': ^1.30.1 || ^2.1.0 || ^2.2.0
|
||||
'@opentelemetry/core': ^1.30.1 || ^2.1.0 || ^2.2.0
|
||||
'@opentelemetry/instrumentation': '>=0.57.1 <1'
|
||||
'@opentelemetry/instrumentation': 0.209.0
|
||||
'@opentelemetry/resources': ^1.30.1 || ^2.1.0 || ^2.2.0
|
||||
'@opentelemetry/sdk-trace-base': ^1.30.1 || ^2.1.0 || ^2.2.0
|
||||
'@opentelemetry/semantic-conventions': ^1.37.0
|
||||
@@ -4957,8 +4982,8 @@ packages:
|
||||
resolution: {integrity: sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ==}
|
||||
engines: {node: '>=6'}
|
||||
|
||||
import-in-the-middle@2.0.1:
|
||||
resolution: {integrity: sha512-bruMpJ7xz+9jwGzrwEhWgvRrlKRYCRDBrfU+ur3FcasYXLJDxTruJ//8g2Noj+QFyRBeqbpj8Bhn4Fbw6HjvhA==}
|
||||
import-in-the-middle@2.0.2:
|
||||
resolution: {integrity: sha512-qet/hkGt3EbNGVtbDfPu0BM+tCqBS8wT1SYrstPaDKoWtshsC6licOemz7DVtpBEyvDNzo8UTBf9/GwWuSDZ9w==}
|
||||
|
||||
imurmurhash@0.1.4:
|
||||
resolution: {integrity: sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==}
|
||||
@@ -6502,10 +6527,6 @@ packages:
|
||||
resolution: {integrity: sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==}
|
||||
engines: {node: '>=0.10.0'}
|
||||
|
||||
require-in-the-middle@7.5.2:
|
||||
resolution: {integrity: sha512-gAZ+kLqBdHarXB64XpAe2VCjB7rIRv+mU8tfRWziHRJ5umKsIHN2tLLv6EtMw7WCdP19S0ERVMldNvxYCHnhSQ==}
|
||||
engines: {node: '>=8.6.0'}
|
||||
|
||||
require-in-the-middle@8.0.1:
|
||||
resolution: {integrity: sha512-QT7FVMXfWOYFbeRBF6nu+I6tr2Tf3u0q8RIEjNob/heKY/nh7drD/k7eeMFmSQgnTtCzLDcCu/XEnpW2wk4xCQ==}
|
||||
engines: {node: '>=9.3.0 || >=8.10.0 <9.0.0'}
|
||||
@@ -8716,7 +8737,7 @@ snapshots:
|
||||
|
||||
'@open-draft/until@2.1.0': {}
|
||||
|
||||
'@opentelemetry/api-logs@0.208.0':
|
||||
'@opentelemetry/api-logs@0.209.0':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
|
||||
@@ -8735,7 +8756,7 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
@@ -8743,7 +8764,7 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
'@types/connect': 3.4.38
|
||||
transitivePeerDependencies:
|
||||
@@ -8752,7 +8773,7 @@ snapshots:
|
||||
'@opentelemetry/instrumentation-dataloader@0.26.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
@@ -8760,7 +8781,7 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -8769,21 +8790,21 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
'@opentelemetry/instrumentation-generic-pool@0.52.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
'@opentelemetry/instrumentation-graphql@0.56.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
@@ -8791,7 +8812,7 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -8800,7 +8821,7 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
forwarded-parse: 2.1.2
|
||||
transitivePeerDependencies:
|
||||
@@ -8809,7 +8830,7 @@ snapshots:
|
||||
'@opentelemetry/instrumentation-ioredis@0.56.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/redis-common': 0.38.2
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -8817,7 +8838,7 @@ snapshots:
|
||||
'@opentelemetry/instrumentation-kafkajs@0.18.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -8825,7 +8846,7 @@ snapshots:
|
||||
'@opentelemetry/instrumentation-knex@0.53.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -8834,7 +8855,7 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -8842,14 +8863,14 @@ snapshots:
|
||||
'@opentelemetry/instrumentation-lru-memoizer@0.53.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
'@opentelemetry/instrumentation-mongodb@0.61.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
@@ -8857,14 +8878,14 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
'@opentelemetry/instrumentation-mysql2@0.55.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
'@opentelemetry/sql-common': 0.41.2(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
@@ -8873,7 +8894,7 @@ snapshots:
|
||||
'@opentelemetry/instrumentation-mysql@0.54.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@types/mysql': 2.15.27
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -8882,7 +8903,7 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
'@opentelemetry/sql-common': 0.41.2(@opentelemetry/api@1.9.0)
|
||||
'@types/pg': 8.15.6
|
||||
@@ -8893,7 +8914,7 @@ snapshots:
|
||||
'@opentelemetry/instrumentation-redis@0.57.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/redis-common': 0.38.2
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
transitivePeerDependencies:
|
||||
@@ -8902,7 +8923,7 @@ snapshots:
|
||||
'@opentelemetry/instrumentation-tedious@0.27.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@types/tedious': 4.0.14
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -8911,16 +8932,16 @@ snapshots:
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
'@opentelemetry/instrumentation@0.208.0(@opentelemetry/api@1.9.0)':
|
||||
'@opentelemetry/instrumentation@0.209.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/api-logs': 0.208.0
|
||||
import-in-the-middle: 2.0.1
|
||||
'@opentelemetry/api-logs': 0.209.0
|
||||
import-in-the-middle: 2.0.2
|
||||
require-in-the-middle: 8.0.1
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -9100,7 +9121,7 @@ snapshots:
|
||||
'@prisma/instrumentation@6.19.0(@opentelemetry/api@1.9.0)':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
@@ -9108,6 +9129,23 @@ snapshots:
|
||||
|
||||
'@radix-ui/primitive@1.1.3': {}
|
||||
|
||||
'@radix-ui/react-accordion@1.2.12(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
|
||||
dependencies:
|
||||
'@radix-ui/primitive': 1.1.3
|
||||
'@radix-ui/react-collapsible': 1.1.12(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
'@radix-ui/react-collection': 1.1.7(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
'@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.17)(react@18.3.1)
|
||||
'@radix-ui/react-context': 1.1.2(@types/react@18.3.17)(react@18.3.1)
|
||||
'@radix-ui/react-direction': 1.1.1(@types/react@18.3.17)(react@18.3.1)
|
||||
'@radix-ui/react-id': 1.1.1(@types/react@18.3.17)(react@18.3.1)
|
||||
'@radix-ui/react-primitive': 2.1.3(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
'@radix-ui/react-use-controllable-state': 1.2.2(@types/react@18.3.17)(react@18.3.1)
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
optionalDependencies:
|
||||
'@types/react': 18.3.17
|
||||
'@types/react-dom': 18.3.5(@types/react@18.3.17)
|
||||
|
||||
'@radix-ui/react-alert-dialog@1.1.15(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
|
||||
dependencies:
|
||||
'@radix-ui/primitive': 1.1.3
|
||||
@@ -9932,19 +9970,19 @@ snapshots:
|
||||
- supports-color
|
||||
- webpack
|
||||
|
||||
'@sentry/node-core@10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/instrumentation@0.208.0(@opentelemetry/api@1.9.0))(@opentelemetry/resources@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)':
|
||||
'@sentry/node-core@10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/instrumentation@0.209.0(@opentelemetry/api@1.9.0))(@opentelemetry/resources@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)':
|
||||
dependencies:
|
||||
'@apm-js-collab/tracing-hooks': 0.3.1
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/context-async-hooks': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/resources': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/sdk-trace-base': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
'@sentry/core': 10.27.0
|
||||
'@sentry/opentelemetry': 10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)
|
||||
import-in-the-middle: 2.0.1
|
||||
import-in-the-middle: 2.0.2
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
@@ -9953,7 +9991,7 @@ snapshots:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/context-async-hooks': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.208.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation': 0.209.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation-amqplib': 0.55.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation-connect': 0.52.0(@opentelemetry/api@1.9.0)
|
||||
'@opentelemetry/instrumentation-dataloader': 0.26.0(@opentelemetry/api@1.9.0)
|
||||
@@ -9981,9 +10019,9 @@ snapshots:
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
'@prisma/instrumentation': 6.19.0(@opentelemetry/api@1.9.0)
|
||||
'@sentry/core': 10.27.0
|
||||
'@sentry/node-core': 10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/instrumentation@0.208.0(@opentelemetry/api@1.9.0))(@opentelemetry/resources@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)
|
||||
'@sentry/node-core': 10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/instrumentation@0.209.0(@opentelemetry/api@1.9.0))(@opentelemetry/resources@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)
|
||||
'@sentry/opentelemetry': 10.27.0(@opentelemetry/api@1.9.0)(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/semantic-conventions@1.38.0)
|
||||
import-in-the-middle: 2.0.1
|
||||
import-in-the-middle: 2.0.2
|
||||
minimatch: 9.0.5
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -12792,7 +12830,7 @@ snapshots:
|
||||
parent-module: 1.0.1
|
||||
resolve-from: 4.0.0
|
||||
|
||||
import-in-the-middle@2.0.1:
|
||||
import-in-the-middle@2.0.2:
|
||||
dependencies:
|
||||
acorn: 8.15.0
|
||||
acorn-import-attributes: 1.9.5(acorn@8.15.0)
|
||||
@@ -14631,14 +14669,6 @@ snapshots:
|
||||
|
||||
require-from-string@2.0.2: {}
|
||||
|
||||
require-in-the-middle@7.5.2:
|
||||
dependencies:
|
||||
debug: 4.4.3
|
||||
module-details-from-path: 1.0.4
|
||||
resolve: 1.22.11
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
require-in-the-middle@8.0.1:
|
||||
dependencies:
|
||||
debug: 4.4.3
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs";
|
||||
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput";
|
||||
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
|
||||
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
|
||||
import { useState } from "react";
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
"use client";
|
||||
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { useState, useMemo, useRef } from "react";
|
||||
import { AuthCard } from "@/components/auth/AuthCard";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInput";
|
||||
import { useGetOauthGetOauthAppInfo } from "@/app/api/__generated__/endpoints/oauth/oauth";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { AuthCard } from "@/components/auth/AuthCard";
|
||||
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs";
|
||||
import type {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsMetaInput,
|
||||
CredentialsType,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
import { CheckIcon, CircleIcon } from "@phosphor-icons/react";
|
||||
import { useGetOauthGetOauthAppInfo } from "@/app/api/__generated__/endpoints/oauth/oauth";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { useMemo, useRef, useState } from "react";
|
||||
|
||||
// All credential types - we accept any type of credential
|
||||
const ALL_CREDENTIAL_TYPES: CredentialsType[] = [
|
||||
|
||||
@@ -10,7 +10,10 @@ export const BuilderActions = memo(() => {
|
||||
flowID: parseAsString,
|
||||
});
|
||||
return (
|
||||
<div className="absolute bottom-4 left-[50%] z-[100] flex -translate-x-1/2 items-center gap-4 rounded-full bg-white p-2 px-2 shadow-lg">
|
||||
<div
|
||||
data-id="builder-actions"
|
||||
className="absolute bottom-4 left-[50%] z-[100] flex -translate-x-1/2 items-center gap-4 rounded-full bg-white p-2 px-2 shadow-lg"
|
||||
>
|
||||
<AgentOutputs flowID={flowID} />
|
||||
<RunGraph flowID={flowID} />
|
||||
<ScheduleGraph flowID={flowID} />
|
||||
|
||||
@@ -79,6 +79,7 @@ export const AgentOutputs = ({ flowID }: { flowID: string | null }) => {
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
data-id="agent-outputs-button"
|
||||
disabled={!flowID || !hasOutputs()}
|
||||
>
|
||||
<BookOpenIcon className="size-4" />
|
||||
|
||||
@@ -31,6 +31,7 @@ export const RunGraph = ({ flowID }: { flowID: string | null }) => {
|
||||
<Button
|
||||
size="icon"
|
||||
variant={isGraphRunning ? "destructive" : "primary"}
|
||||
data-id={isGraphRunning ? "stop-graph-button" : "run-graph-button"}
|
||||
onClick={isGraphRunning ? handleStopGraph : handleRunGraph}
|
||||
disabled={!flowID || isExecutingGraph || isTerminatingGraph}
|
||||
loading={isExecutingGraph || isTerminatingGraph || isSaving}
|
||||
|
||||
@@ -7,10 +7,11 @@ import { parseAsInteger, parseAsString, useQueryStates } from "nuqs";
|
||||
import { GraphExecutionMeta } from "@/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/use-agent-runs";
|
||||
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
|
||||
import { useShallow } from "zustand/react/shallow";
|
||||
import { useState } from "react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useSaveGraph } from "@/app/(platform)/build/hooks/useSaveGraph";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { ApiError } from "@/lib/autogpt-server-api/helpers"; // Check if this exists
|
||||
import { useTutorialStore } from "@/app/(platform)/build/stores/tutorialStore";
|
||||
|
||||
export const useRunGraph = () => {
|
||||
const { saveGraph, isSaving } = useSaveGraph({
|
||||
@@ -33,6 +34,29 @@ export const useRunGraph = () => {
|
||||
useShallow((state) => state.clearAllNodeErrors),
|
||||
);
|
||||
|
||||
// Tutorial integration - force open dialog when tutorial requests it
|
||||
const forceOpenRunInputDialog = useTutorialStore(
|
||||
(state) => state.forceOpenRunInputDialog,
|
||||
);
|
||||
const setForceOpenRunInputDialog = useTutorialStore(
|
||||
(state) => state.setForceOpenRunInputDialog,
|
||||
);
|
||||
|
||||
// Sync tutorial state with dialog state
|
||||
useEffect(() => {
|
||||
if (forceOpenRunInputDialog && !openRunInputDialog) {
|
||||
setOpenRunInputDialog(true);
|
||||
}
|
||||
}, [forceOpenRunInputDialog, openRunInputDialog]);
|
||||
|
||||
// Reset tutorial state when dialog closes
|
||||
const handleSetOpenRunInputDialog = (isOpen: boolean) => {
|
||||
setOpenRunInputDialog(isOpen);
|
||||
if (!isOpen && forceOpenRunInputDialog) {
|
||||
setForceOpenRunInputDialog(false);
|
||||
}
|
||||
};
|
||||
|
||||
const [{ flowID, flowVersion, flowExecutionID }, setQueryStates] =
|
||||
useQueryStates({
|
||||
flowID: parseAsString,
|
||||
@@ -138,6 +162,6 @@ export const useRunGraph = () => {
|
||||
isExecutingGraph,
|
||||
isTerminatingGraph,
|
||||
openRunInputDialog,
|
||||
setOpenRunInputDialog,
|
||||
setOpenRunInputDialog: handleSetOpenRunInputDialog,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -8,6 +8,8 @@ import { Text } from "@/components/atoms/Text/Text";
|
||||
import { FormRenderer } from "@/components/renderers/InputRenderer/FormRenderer";
|
||||
import { useRunInputDialog } from "./useRunInputDialog";
|
||||
import { CronSchedulerDialog } from "../CronSchedulerDialog/CronSchedulerDialog";
|
||||
import { useTutorialStore } from "@/app/(platform)/build/stores/tutorialStore";
|
||||
import { useEffect } from "react";
|
||||
|
||||
export const RunInputDialog = ({
|
||||
isOpen,
|
||||
@@ -37,6 +39,21 @@ export const RunInputDialog = ({
|
||||
isExecutingGraph,
|
||||
} = useRunInputDialog({ setIsOpen });
|
||||
|
||||
// Tutorial integration - track input values for the tutorial
|
||||
const setTutorialInputValues = useTutorialStore(
|
||||
(state) => state.setTutorialInputValues,
|
||||
);
|
||||
const isTutorialRunning = useTutorialStore(
|
||||
(state) => state.isTutorialRunning,
|
||||
);
|
||||
|
||||
// Update tutorial store when input values change
|
||||
useEffect(() => {
|
||||
if (isTutorialRunning) {
|
||||
setTutorialInputValues(inputValues);
|
||||
}
|
||||
}, [inputValues, isTutorialRunning, setTutorialInputValues]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Dialog
|
||||
@@ -48,16 +65,16 @@ export const RunInputDialog = ({
|
||||
styling={{ maxWidth: "600px", minWidth: "600px" }}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="space-y-6 p-1">
|
||||
<div className="space-y-6 p-1" data-id="run-input-dialog-content">
|
||||
{/* Credentials Section */}
|
||||
{hasCredentials() && (
|
||||
<div>
|
||||
<div data-id="run-input-credentials-section">
|
||||
<div className="mb-4">
|
||||
<Text variant="h4" className="text-gray-900">
|
||||
Credentials
|
||||
</Text>
|
||||
</div>
|
||||
<div className="px-2">
|
||||
<div className="px-2" data-id="run-input-credentials-form">
|
||||
<FormRenderer
|
||||
jsonSchema={credentialsSchema as RJSFSchema}
|
||||
handleChange={(v) => handleCredentialChange(v.formData)}
|
||||
@@ -75,13 +92,13 @@ export const RunInputDialog = ({
|
||||
|
||||
{/* Inputs Section */}
|
||||
{hasInputs() && (
|
||||
<div>
|
||||
<div data-id="run-input-inputs-section">
|
||||
<div className="mb-4">
|
||||
<Text variant="h4" className="text-gray-900">
|
||||
Inputs
|
||||
</Text>
|
||||
</div>
|
||||
<div className="px-2">
|
||||
<div data-id="run-input-inputs-form">
|
||||
<FormRenderer
|
||||
jsonSchema={inputSchema as RJSFSchema}
|
||||
handleChange={(v) => handleInputChange(v.formData)}
|
||||
@@ -97,7 +114,10 @@ export const RunInputDialog = ({
|
||||
)}
|
||||
|
||||
{/* Action Button */}
|
||||
<div className="flex justify-end pt-2">
|
||||
<div
|
||||
className="flex justify-end pt-2"
|
||||
data-id="run-input-actions-section"
|
||||
>
|
||||
{purpose === "run" && (
|
||||
<Button
|
||||
variant="primary"
|
||||
@@ -105,6 +125,7 @@ export const RunInputDialog = ({
|
||||
className="group h-fit min-w-0 gap-2"
|
||||
onClick={handleManualRun}
|
||||
loading={isExecutingGraph}
|
||||
data-id="run-input-manual-run-button"
|
||||
>
|
||||
{!isExecutingGraph && (
|
||||
<PlayIcon className="size-5 transition-transform group-hover:scale-110" />
|
||||
@@ -118,6 +139,7 @@ export const RunInputDialog = ({
|
||||
size="large"
|
||||
className="group h-fit min-w-0 gap-2"
|
||||
onClick={() => setOpenCronSchedulerDialog(true)}
|
||||
data-id="run-input-schedule-button"
|
||||
>
|
||||
<ClockIcon className="size-5 transition-transform group-hover:scale-110" />
|
||||
<span className="font-semibold">Schedule Run</span>
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
|
||||
import { usePostV1ExecuteGraphAgent } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
|
||||
import {
|
||||
ApiError,
|
||||
CredentialsMetaInput,
|
||||
GraphExecutionMeta,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
@@ -9,6 +10,9 @@ import { parseAsInteger, parseAsString, useQueryStates } from "nuqs";
|
||||
import { useMemo, useState } from "react";
|
||||
import { uiSchema } from "../../../FlowEditor/nodes/uiSchema";
|
||||
import { isCredentialFieldSchema } from "@/components/renderers/InputRenderer/custom/CredentialField/helpers";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useReactFlow } from "@xyflow/react";
|
||||
|
||||
export const useRunInputDialog = ({
|
||||
setIsOpen,
|
||||
@@ -31,6 +35,7 @@ export const useRunInputDialog = ({
|
||||
flowVersion: parseAsInteger,
|
||||
});
|
||||
const { toast } = useToast();
|
||||
const { setViewport } = useReactFlow();
|
||||
|
||||
const { mutateAsync: executeGraph, isPending: isExecutingGraph } =
|
||||
usePostV1ExecuteGraphAgent({
|
||||
@@ -42,13 +47,63 @@ export const useRunInputDialog = ({
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
// Reset running state on error
|
||||
if (error instanceof ApiError && error.isGraphValidationError()) {
|
||||
const errorData = error.response?.detail;
|
||||
Object.entries(errorData.node_errors).forEach(
|
||||
([nodeId, nodeErrors]) => {
|
||||
useNodeStore
|
||||
.getState()
|
||||
.updateNodeErrors(
|
||||
nodeId,
|
||||
nodeErrors as { [key: string]: string },
|
||||
);
|
||||
},
|
||||
);
|
||||
toast({
|
||||
title: errorData?.message || "Graph validation failed",
|
||||
description:
|
||||
"Please fix the validation errors on the highlighted nodes and try again.",
|
||||
variant: "destructive",
|
||||
});
|
||||
setIsOpen(false);
|
||||
|
||||
const firstBackendId = Object.keys(errorData.node_errors)[0];
|
||||
|
||||
if (firstBackendId) {
|
||||
const firstErrorNode = useNodeStore
|
||||
.getState()
|
||||
.nodes.find(
|
||||
(n) =>
|
||||
n.data.metadata?.backend_id === firstBackendId ||
|
||||
n.id === firstBackendId,
|
||||
);
|
||||
|
||||
if (firstErrorNode) {
|
||||
setTimeout(() => {
|
||||
setViewport(
|
||||
{
|
||||
x:
|
||||
-firstErrorNode.position.x * 0.8 +
|
||||
window.innerWidth / 2 -
|
||||
150,
|
||||
y: -firstErrorNode.position.y * 0.8 + 50,
|
||||
zoom: 0.8,
|
||||
},
|
||||
{ duration: 500 },
|
||||
);
|
||||
}, 50);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
toast({
|
||||
title: "Error running graph",
|
||||
description:
|
||||
(error as Error).message || "An unexpected error occurred.",
|
||||
variant: "destructive",
|
||||
});
|
||||
setIsOpen(false);
|
||||
}
|
||||
setIsGraphRunning(false);
|
||||
toast({
|
||||
title: (error.detail as string) ?? "An unexpected error occurred.",
|
||||
description: "An unexpected error occurred.",
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@@ -26,6 +26,7 @@ export const ScheduleGraph = ({ flowID }: { flowID: string | null }) => {
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
data-id="schedule-graph-button"
|
||||
onClick={handleScheduleGraph}
|
||||
disabled={!flowID}
|
||||
>
|
||||
|
||||
@@ -6,12 +6,17 @@ import {
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import {
|
||||
ChalkboardIcon,
|
||||
CircleNotchIcon,
|
||||
FrameCornersIcon,
|
||||
MinusIcon,
|
||||
PlusIcon,
|
||||
} from "@phosphor-icons/react/dist/ssr";
|
||||
import { LockIcon, LockOpenIcon } from "lucide-react";
|
||||
import { memo } from "react";
|
||||
import { memo, useEffect, useState } from "react";
|
||||
import { useSearchParams, useRouter } from "next/navigation";
|
||||
import { useTutorialStore } from "@/app/(platform)/build/stores/tutorialStore";
|
||||
import { startTutorial, setTutorialLoadingCallback } from "../../tutorial";
|
||||
|
||||
export const CustomControls = memo(
|
||||
({
|
||||
@@ -22,27 +27,65 @@ export const CustomControls = memo(
|
||||
setIsLocked: (isLocked: boolean) => void;
|
||||
}) => {
|
||||
const { zoomIn, zoomOut, fitView } = useReactFlow();
|
||||
const { isTutorialRunning, setIsTutorialRunning } = useTutorialStore();
|
||||
const [isTutorialLoading, setIsTutorialLoading] = useState(false);
|
||||
const searchParams = useSearchParams();
|
||||
const router = useRouter();
|
||||
|
||||
useEffect(() => {
|
||||
setTutorialLoadingCallback(setIsTutorialLoading);
|
||||
return () => setTutorialLoadingCallback(() => {});
|
||||
}, []);
|
||||
|
||||
const handleTutorialClick = () => {
|
||||
if (isTutorialLoading) return;
|
||||
|
||||
const flowId = searchParams.get("flowID");
|
||||
if (flowId) {
|
||||
router.push("/build?view=new");
|
||||
return;
|
||||
}
|
||||
|
||||
startTutorial();
|
||||
setIsTutorialRunning(true);
|
||||
};
|
||||
|
||||
const controls = [
|
||||
{
|
||||
id: "zoom-in-button",
|
||||
icon: <PlusIcon className="size-4" />,
|
||||
label: "Zoom In",
|
||||
onClick: () => zoomIn(),
|
||||
className: "h-10 w-10 border-none",
|
||||
},
|
||||
{
|
||||
id: "zoom-out-button",
|
||||
icon: <MinusIcon className="size-4" />,
|
||||
label: "Zoom Out",
|
||||
onClick: () => zoomOut(),
|
||||
className: "h-10 w-10 border-none",
|
||||
},
|
||||
{
|
||||
id: "tutorial-button",
|
||||
icon: isTutorialLoading ? (
|
||||
<CircleNotchIcon className="size-4 animate-spin" />
|
||||
) : (
|
||||
<ChalkboardIcon className="size-4" />
|
||||
),
|
||||
label: isTutorialLoading ? "Loading Tutorial..." : "Start Tutorial",
|
||||
onClick: handleTutorialClick,
|
||||
className: `h-10 w-10 border-none ${isTutorialRunning || isTutorialLoading ? "bg-zinc-100" : "bg-white"}`,
|
||||
disabled: isTutorialLoading,
|
||||
},
|
||||
{
|
||||
id: "fit-view-button",
|
||||
icon: <FrameCornersIcon className="size-4" />,
|
||||
label: "Fit View",
|
||||
onClick: () => fitView({ padding: 0.2, duration: 800, maxZoom: 1 }),
|
||||
className: "h-10 w-10 border-none",
|
||||
},
|
||||
{
|
||||
id: "lock-button",
|
||||
icon: !isLocked ? (
|
||||
<LockOpenIcon className="size-4" />
|
||||
) : (
|
||||
@@ -55,15 +98,20 @@ export const CustomControls = memo(
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="absolute bottom-4 left-4 z-10 flex flex-col items-center gap-2 rounded-full bg-white px-1 py-2 shadow-lg">
|
||||
{controls.map((control, index) => (
|
||||
<Tooltip key={index} delayDuration={300}>
|
||||
<div
|
||||
data-id="custom-controls"
|
||||
className="absolute bottom-4 left-4 z-10 flex flex-col items-center gap-2 rounded-full bg-white px-1 py-2 shadow-lg"
|
||||
>
|
||||
{controls.map((control) => (
|
||||
<Tooltip key={control.id} delayDuration={0}>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant="icon"
|
||||
size={"small"}
|
||||
onClick={control.onClick}
|
||||
className={control.className}
|
||||
data-id={control.id}
|
||||
disabled={"disabled" in control ? control.disabled : false}
|
||||
>
|
||||
{control.icon}
|
||||
<span className="sr-only">{control.label}</span>
|
||||
|
||||
@@ -3,6 +3,7 @@ import { useGetV2GetSpecificBlocks } from "@/app/api/__generated__/endpoints/def
|
||||
import {
|
||||
useGetV1GetExecutionDetails,
|
||||
useGetV1GetSpecificGraph,
|
||||
useGetV1ListUserGraphs,
|
||||
} from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
|
||||
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
|
||||
@@ -17,6 +18,7 @@ import { useReactFlow } from "@xyflow/react";
|
||||
import { useControlPanelStore } from "../../../stores/controlPanelStore";
|
||||
import { useHistoryStore } from "../../../stores/historyStore";
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
|
||||
export const useFlow = () => {
|
||||
const [isLocked, setIsLocked] = useState(false);
|
||||
@@ -36,6 +38,9 @@ export const useFlow = () => {
|
||||
const setGraphExecutionStatus = useGraphStore(
|
||||
useShallow((state) => state.setGraphExecutionStatus),
|
||||
);
|
||||
const setAvailableSubGraphs = useGraphStore(
|
||||
useShallow((state) => state.setAvailableSubGraphs),
|
||||
);
|
||||
const updateEdgeBeads = useEdgeStore(
|
||||
useShallow((state) => state.updateEdgeBeads),
|
||||
);
|
||||
@@ -62,6 +67,11 @@ export const useFlow = () => {
|
||||
},
|
||||
);
|
||||
|
||||
// Fetch all available graphs for sub-agent update detection
|
||||
const { data: availableGraphs } = useGetV1ListUserGraphs({
|
||||
query: { select: okData },
|
||||
});
|
||||
|
||||
const { data: graph, isLoading: isGraphLoading } = useGetV1GetSpecificGraph(
|
||||
flowID ?? "",
|
||||
flowVersion !== null ? { version: flowVersion } : {},
|
||||
@@ -116,10 +126,18 @@ export const useFlow = () => {
|
||||
}
|
||||
}, [graph]);
|
||||
|
||||
// Update available sub-graphs in store for sub-agent update detection
|
||||
useEffect(() => {
|
||||
if (availableGraphs) {
|
||||
setAvailableSubGraphs(availableGraphs);
|
||||
}
|
||||
}, [availableGraphs, setAvailableSubGraphs]);
|
||||
|
||||
// adding nodes
|
||||
useEffect(() => {
|
||||
if (customNodes.length > 0) {
|
||||
useNodeStore.getState().setNodes([]);
|
||||
useNodeStore.getState().clearResolutionState();
|
||||
addNodes(customNodes);
|
||||
|
||||
// Sync hardcoded values with handle IDs.
|
||||
@@ -203,6 +221,7 @@ export const useFlow = () => {
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
useNodeStore.getState().setNodes([]);
|
||||
useNodeStore.getState().clearResolutionState();
|
||||
useEdgeStore.getState().setEdges([]);
|
||||
useGraphStore.getState().reset();
|
||||
useEdgeStore.getState().resetEdgeBeads();
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
getBezierPath,
|
||||
} from "@xyflow/react";
|
||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { XIcon } from "@phosphor-icons/react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { NodeExecutionResult } from "@/lib/autogpt-server-api";
|
||||
@@ -35,6 +36,8 @@ const CustomEdge = ({
|
||||
selected,
|
||||
}: EdgeProps<CustomEdge>) => {
|
||||
const removeConnection = useEdgeStore((state) => state.removeEdge);
|
||||
// Subscribe to the brokenEdgeIDs map and check if this edge is broken across any node
|
||||
const isBroken = useNodeStore((state) => state.isEdgeBroken(id));
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
|
||||
const [edgePath, labelX, labelY] = getBezierPath({
|
||||
@@ -50,6 +53,12 @@ const CustomEdge = ({
|
||||
const beadUp = data?.beadUp ?? 0;
|
||||
const beadDown = data?.beadDown ?? 0;
|
||||
|
||||
const handleRemoveEdge = () => {
|
||||
removeConnection(id);
|
||||
// Note: broken edge tracking is cleaned up automatically by useSubAgentUpdateState
|
||||
// when it detects the edge no longer exists
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<BaseEdge
|
||||
@@ -57,9 +66,11 @@ const CustomEdge = ({
|
||||
markerEnd={markerEnd}
|
||||
className={cn(
|
||||
isStatic && "!stroke-[1.5px] [stroke-dasharray:6]",
|
||||
selected
|
||||
? "stroke-zinc-800"
|
||||
: "stroke-zinc-500/50 hover:stroke-zinc-500",
|
||||
isBroken
|
||||
? "!stroke-red-500 !stroke-[2px] [stroke-dasharray:4]"
|
||||
: selected
|
||||
? "stroke-zinc-800"
|
||||
: "stroke-zinc-500/50 hover:stroke-zinc-500",
|
||||
)}
|
||||
/>
|
||||
<JSBeads
|
||||
@@ -70,12 +81,16 @@ const CustomEdge = ({
|
||||
/>
|
||||
<EdgeLabelRenderer>
|
||||
<Button
|
||||
onClick={() => removeConnection(id)}
|
||||
onClick={handleRemoveEdge}
|
||||
className={cn(
|
||||
"absolute h-fit min-w-0 p-1 transition-opacity",
|
||||
isHovered ? "opacity-100" : "opacity-0",
|
||||
isBroken
|
||||
? "bg-red-500 opacity-100 hover:bg-red-600"
|
||||
: isHovered
|
||||
? "opacity-100"
|
||||
: "opacity-0",
|
||||
)}
|
||||
variant="secondary"
|
||||
variant={isBroken ? "primary" : "secondary"}
|
||||
style={{
|
||||
transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
|
||||
pointerEvents: "all",
|
||||
|
||||
@@ -3,6 +3,7 @@ import { Handle, Position } from "@xyflow/react";
|
||||
import { useEdgeStore } from "../../../stores/edgeStore";
|
||||
import { cleanUpHandleId } from "@/components/renderers/InputRenderer/helpers";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useNodeStore } from "../../../stores/nodeStore";
|
||||
|
||||
const InputNodeHandle = ({
|
||||
handleId,
|
||||
@@ -15,6 +16,9 @@ const InputNodeHandle = ({
|
||||
const isInputConnected = useEdgeStore((state) =>
|
||||
state.isInputConnected(nodeId ?? "", cleanedHandleId),
|
||||
);
|
||||
const isInputBroken = useNodeStore((state) =>
|
||||
state.isInputBroken(nodeId, cleanedHandleId),
|
||||
);
|
||||
|
||||
return (
|
||||
<Handle
|
||||
@@ -22,12 +26,16 @@ const InputNodeHandle = ({
|
||||
position={Position.Left}
|
||||
id={cleanedHandleId}
|
||||
className={"-ml-6 mr-2"}
|
||||
data-tutorial-id={`input-handler-${nodeId}-${cleanedHandleId}`}
|
||||
>
|
||||
<div className="pointer-events-none">
|
||||
<CircleIcon
|
||||
size={16}
|
||||
weight={isInputConnected ? "fill" : "duotone"}
|
||||
className={"text-gray-400 opacity-100"}
|
||||
className={cn(
|
||||
"text-gray-400 opacity-100",
|
||||
isInputBroken && "text-red-500",
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
</Handle>
|
||||
@@ -38,27 +46,34 @@ const OutputNodeHandle = ({
|
||||
field_name,
|
||||
nodeId,
|
||||
hexColor,
|
||||
isBroken,
|
||||
}: {
|
||||
field_name: string;
|
||||
nodeId: string;
|
||||
hexColor: string;
|
||||
isBroken: boolean;
|
||||
}) => {
|
||||
const isOutputConnected = useEdgeStore((state) =>
|
||||
state.isOutputConnected(nodeId, field_name),
|
||||
);
|
||||
|
||||
return (
|
||||
<Handle
|
||||
type={"source"}
|
||||
position={Position.Right}
|
||||
id={field_name}
|
||||
className={"-mr-2 ml-2"}
|
||||
data-tutorial-id={`output-handler-${nodeId}-${field_name}`}
|
||||
>
|
||||
<div className="pointer-events-none">
|
||||
<CircleIcon
|
||||
size={16}
|
||||
weight={"duotone"}
|
||||
color={isOutputConnected ? hexColor : "gray"}
|
||||
className={cn("text-gray-400 opacity-100")}
|
||||
className={cn(
|
||||
"text-gray-400 opacity-100",
|
||||
isBroken && "text-red-500",
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
</Handle>
|
||||
|
||||
@@ -20,6 +20,8 @@ import { NodeDataRenderer } from "./components/NodeOutput/NodeOutput";
|
||||
import { NodeRightClickMenu } from "./components/NodeRightClickMenu";
|
||||
import { StickyNoteBlock } from "./components/StickyNoteBlock";
|
||||
import { WebhookDisclaimer } from "./components/WebhookDisclaimer";
|
||||
import { SubAgentUpdateFeature } from "./components/SubAgentUpdate/SubAgentUpdateFeature";
|
||||
import { useCustomNode } from "./useCustomNode";
|
||||
|
||||
export type CustomNodeData = {
|
||||
hardcodedValues: {
|
||||
@@ -45,6 +47,10 @@ export type CustomNode = XYNode<CustomNodeData, "custom">;
|
||||
|
||||
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
({ data, id: nodeId, selected }) => {
|
||||
const { inputSchema, outputSchema } = useCustomNode({ data, nodeId });
|
||||
|
||||
const isAgent = data.uiType === BlockUIType.AGENT;
|
||||
|
||||
if (data.uiType === BlockUIType.NOTE) {
|
||||
return (
|
||||
<StickyNoteBlock data={data} selected={selected} nodeId={nodeId} />
|
||||
@@ -63,16 +69,6 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
|
||||
const isAyrshare = data.uiType === BlockUIType.AYRSHARE;
|
||||
|
||||
const inputSchema =
|
||||
data.uiType === BlockUIType.AGENT
|
||||
? (data.hardcodedValues.input_schema ?? {})
|
||||
: data.inputSchema;
|
||||
|
||||
const outputSchema =
|
||||
data.uiType === BlockUIType.AGENT
|
||||
? (data.hardcodedValues.output_schema ?? {})
|
||||
: data.outputSchema;
|
||||
|
||||
const hasConfigErrors =
|
||||
data.errors &&
|
||||
Object.values(data.errors).some(
|
||||
@@ -87,12 +83,11 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
|
||||
const hasErrors = hasConfigErrors || hasOutputError;
|
||||
|
||||
// Currently all blockTypes design are similar - that's why i am using the same component for all of them
|
||||
// If in future - if we need some drastic change in some blockTypes design - we can create separate components for them
|
||||
const node = (
|
||||
<NodeContainer selected={selected} nodeId={nodeId} hasErrors={hasErrors}>
|
||||
<div className="rounded-xlarge bg-white">
|
||||
<NodeHeader data={data} nodeId={nodeId} />
|
||||
{isAgent && <SubAgentUpdateFeature nodeID={nodeId} nodeData={data} />}
|
||||
{isWebhook && <WebhookDisclaimer nodeId={nodeId} />}
|
||||
{isAyrshare && <AyrshareConnectButton />}
|
||||
<FormCreator
|
||||
|
||||
@@ -27,6 +27,7 @@ export const NodeContainer = ({
|
||||
status && nodeStyleBasedOnStatus[status],
|
||||
hasErrors ? nodeStyleBasedOnStatus[AgentExecutionStatus.FAILED] : "",
|
||||
)}
|
||||
data-id={`custom-node-${nodeId}`}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
|
||||
@@ -20,11 +20,13 @@ type Props = {
|
||||
|
||||
export const NodeHeader = ({ data, nodeId }: Props) => {
|
||||
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
||||
const title = (data.metadata?.customized_name as string) || data.title;
|
||||
const title =
|
||||
(data.metadata?.customized_name as string) ||
|
||||
data.hardcodedValues.agent_name ||
|
||||
data.title;
|
||||
|
||||
const [isEditingTitle, setIsEditingTitle] = useState(false);
|
||||
const [editedTitle, setEditedTitle] = useState(
|
||||
beautifyString(title).replace("Block", "").trim(),
|
||||
);
|
||||
const [editedTitle, setEditedTitle] = useState(title);
|
||||
|
||||
const handleTitleEdit = () => {
|
||||
updateNodeData(nodeId, {
|
||||
|
||||
@@ -23,7 +23,10 @@ export const NodeDataRenderer = ({ nodeId }: { nodeId: string }) => {
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-3 rounded-b-xl border-t border-zinc-200 px-4 py-4">
|
||||
<div
|
||||
data-tutorial-id={`node-output`}
|
||||
className="flex flex-col gap-3 rounded-b-xl border-t border-zinc-200 px-4 py-4"
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<Text variant="body-medium" className="!font-semibold text-slate-700">
|
||||
Node Output
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
import React from "react";
|
||||
import { ArrowUpIcon, WarningIcon } from "@phosphor-icons/react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { cn, beautifyString } from "@/lib/utils";
|
||||
import { CustomNodeData } from "../../CustomNode";
|
||||
import { useSubAgentUpdateState } from "./useSubAgentUpdateState";
|
||||
import { IncompatibleUpdateDialog } from "./components/IncompatibleUpdateDialog";
|
||||
import { ResolutionModeBar } from "./components/ResolutionModeBar";
|
||||
|
||||
/**
|
||||
* Inline component for the update bar that can be placed after the header.
|
||||
* Use this inside the node content where you want the bar to appear.
|
||||
*/
|
||||
type SubAgentUpdateFeatureProps = {
|
||||
nodeID: string;
|
||||
nodeData: CustomNodeData;
|
||||
};
|
||||
|
||||
export function SubAgentUpdateFeature({
|
||||
nodeID,
|
||||
nodeData,
|
||||
}: SubAgentUpdateFeatureProps) {
|
||||
const {
|
||||
updateInfo,
|
||||
isInResolutionMode,
|
||||
handleUpdateClick,
|
||||
showIncompatibilityDialog,
|
||||
setShowIncompatibilityDialog,
|
||||
handleConfirmIncompatibleUpdate,
|
||||
} = useSubAgentUpdateState({ nodeID: nodeID, nodeData: nodeData });
|
||||
|
||||
const agentName = nodeData.title || "Agent";
|
||||
|
||||
if (!updateInfo.hasUpdate && !isInResolutionMode) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
{isInResolutionMode ? (
|
||||
<ResolutionModeBar incompatibilities={updateInfo.incompatibilities} />
|
||||
) : (
|
||||
<SubAgentUpdateAvailableBar
|
||||
currentVersion={updateInfo.currentVersion}
|
||||
latestVersion={updateInfo.latestVersion}
|
||||
isCompatible={updateInfo.isCompatible}
|
||||
onUpdate={handleUpdateClick}
|
||||
/>
|
||||
)}
|
||||
{/* Incompatibility dialog - rendered here since this component owns the state */}
|
||||
{updateInfo.incompatibilities && (
|
||||
<IncompatibleUpdateDialog
|
||||
isOpen={showIncompatibilityDialog}
|
||||
onClose={() => setShowIncompatibilityDialog(false)}
|
||||
onConfirm={handleConfirmIncompatibleUpdate}
|
||||
currentVersion={updateInfo.currentVersion}
|
||||
latestVersion={updateInfo.latestVersion}
|
||||
agentName={beautifyString(agentName)}
|
||||
incompatibilities={updateInfo.incompatibilities}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
type SubAgentUpdateAvailableBarProps = {
|
||||
currentVersion: number;
|
||||
latestVersion: number;
|
||||
isCompatible: boolean;
|
||||
onUpdate: () => void;
|
||||
};
|
||||
|
||||
function SubAgentUpdateAvailableBar({
|
||||
currentVersion,
|
||||
latestVersion,
|
||||
isCompatible,
|
||||
onUpdate,
|
||||
}: SubAgentUpdateAvailableBarProps): React.ReactElement {
|
||||
return (
|
||||
<div className="flex items-center justify-between gap-2 rounded-t-xl bg-blue-50 px-3 py-2 dark:bg-blue-900/30">
|
||||
<div className="flex items-center gap-2">
|
||||
<ArrowUpIcon className="h-4 w-4 text-blue-600 dark:text-blue-400" />
|
||||
<span className="text-sm text-blue-700 dark:text-blue-300">
|
||||
Update available (v{currentVersion} → v{latestVersion})
|
||||
</span>
|
||||
{!isCompatible && (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<WarningIcon className="h-4 w-4 text-amber-500" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className="max-w-xs">
|
||||
<p className="font-medium">Incompatible changes detected</p>
|
||||
<p className="text-xs text-gray-400">
|
||||
Click Update to see details
|
||||
</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
</div>
|
||||
<Button
|
||||
size="small"
|
||||
variant={isCompatible ? "primary" : "outline"}
|
||||
onClick={onUpdate}
|
||||
className={cn(
|
||||
"h-7 text-xs",
|
||||
!isCompatible && "border-amber-500 text-amber-600 hover:bg-amber-50",
|
||||
)}
|
||||
>
|
||||
Update
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,274 @@
|
||||
import React from "react";
|
||||
import {
|
||||
WarningIcon,
|
||||
XCircleIcon,
|
||||
PlusCircleIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { beautifyString } from "@/lib/utils";
|
||||
import { IncompatibilityInfo } from "@/app/(platform)/build/hooks/useSubAgentUpdate/types";
|
||||
|
||||
type IncompatibleUpdateDialogProps = {
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
onConfirm: () => void;
|
||||
currentVersion: number;
|
||||
latestVersion: number;
|
||||
agentName: string;
|
||||
incompatibilities: IncompatibilityInfo;
|
||||
};
|
||||
|
||||
export function IncompatibleUpdateDialog({
|
||||
isOpen,
|
||||
onClose,
|
||||
onConfirm,
|
||||
currentVersion,
|
||||
latestVersion,
|
||||
agentName,
|
||||
incompatibilities,
|
||||
}: IncompatibleUpdateDialogProps) {
|
||||
const hasMissingInputs = incompatibilities.missingInputs.length > 0;
|
||||
const hasMissingOutputs = incompatibilities.missingOutputs.length > 0;
|
||||
const hasNewInputs = incompatibilities.newInputs.length > 0;
|
||||
const hasNewOutputs = incompatibilities.newOutputs.length > 0;
|
||||
const hasNewRequired = incompatibilities.newRequiredInputs.length > 0;
|
||||
const hasTypeMismatches = incompatibilities.inputTypeMismatches.length > 0;
|
||||
|
||||
const hasInputChanges = hasMissingInputs || hasNewInputs;
|
||||
const hasOutputChanges = hasMissingOutputs || hasNewOutputs;
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title={
|
||||
<div className="flex items-center gap-2">
|
||||
<WarningIcon className="h-5 w-5 text-amber-500" weight="fill" />
|
||||
Incompatible Update
|
||||
</div>
|
||||
}
|
||||
controlled={{
|
||||
isOpen,
|
||||
set: async (open) => {
|
||||
if (!open) onClose();
|
||||
},
|
||||
}}
|
||||
onClose={onClose}
|
||||
styling={{ maxWidth: "32rem" }}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="space-y-4">
|
||||
<p className="text-sm text-gray-600 dark:text-gray-400">
|
||||
Updating <strong>{beautifyString(agentName)}</strong> from v
|
||||
{currentVersion} to v{latestVersion} will break some connections.
|
||||
</p>
|
||||
|
||||
{/* Input changes - two column layout */}
|
||||
{hasInputChanges && (
|
||||
<TwoColumnSection
|
||||
title="Input Changes"
|
||||
leftIcon={
|
||||
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
|
||||
}
|
||||
leftTitle="Removed"
|
||||
leftItems={incompatibilities.missingInputs}
|
||||
rightIcon={
|
||||
<PlusCircleIcon
|
||||
className="h-4 w-4 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
}
|
||||
rightTitle="Added"
|
||||
rightItems={incompatibilities.newInputs}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Output changes - two column layout */}
|
||||
{hasOutputChanges && (
|
||||
<TwoColumnSection
|
||||
title="Output Changes"
|
||||
leftIcon={
|
||||
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
|
||||
}
|
||||
leftTitle="Removed"
|
||||
leftItems={incompatibilities.missingOutputs}
|
||||
rightIcon={
|
||||
<PlusCircleIcon
|
||||
className="h-4 w-4 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
}
|
||||
rightTitle="Added"
|
||||
rightItems={incompatibilities.newOutputs}
|
||||
/>
|
||||
)}
|
||||
|
||||
{hasTypeMismatches && (
|
||||
<SingleColumnSection
|
||||
icon={
|
||||
<XCircleIcon className="h-4 w-4 text-red-500" weight="fill" />
|
||||
}
|
||||
title="Type Changed"
|
||||
description="These connected inputs have a different type:"
|
||||
items={incompatibilities.inputTypeMismatches.map(
|
||||
(m) => `${m.name} (${m.oldType} → ${m.newType})`,
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{hasNewRequired && (
|
||||
<SingleColumnSection
|
||||
icon={
|
||||
<PlusCircleIcon
|
||||
className="h-4 w-4 text-amber-500"
|
||||
weight="fill"
|
||||
/>
|
||||
}
|
||||
title="New Required Inputs"
|
||||
description="These inputs are now required:"
|
||||
items={incompatibilities.newRequiredInputs}
|
||||
/>
|
||||
)}
|
||||
|
||||
<Alert variant="warning">
|
||||
<AlertDescription>
|
||||
If you proceed, you'll need to remove the broken connections
|
||||
before you can save or run your agent.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button variant="ghost" size="small" onClick={onClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={onConfirm}
|
||||
className="border-amber-700 bg-amber-600 hover:bg-amber-700"
|
||||
>
|
||||
Update Anyway
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
|
||||
type TwoColumnSectionProps = {
|
||||
title: string;
|
||||
leftIcon: React.ReactNode;
|
||||
leftTitle: string;
|
||||
leftItems: string[];
|
||||
rightIcon: React.ReactNode;
|
||||
rightTitle: string;
|
||||
rightItems: string[];
|
||||
};
|
||||
|
||||
function TwoColumnSection({
|
||||
title,
|
||||
leftIcon,
|
||||
leftTitle,
|
||||
leftItems,
|
||||
rightIcon,
|
||||
rightTitle,
|
||||
rightItems,
|
||||
}: TwoColumnSectionProps) {
|
||||
return (
|
||||
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
|
||||
<span className="font-medium">{title}</span>
|
||||
<div className="mt-2 grid grid-cols-2 items-start gap-4">
|
||||
{/* Left column - Breaking changes */}
|
||||
<div className="min-w-0">
|
||||
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
|
||||
{leftIcon}
|
||||
<span>{leftTitle}</span>
|
||||
</div>
|
||||
<ul className="mt-1.5 space-y-1">
|
||||
{leftItems.length > 0 ? (
|
||||
leftItems.map((item) => (
|
||||
<li
|
||||
key={item}
|
||||
className="text-sm text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<code className="rounded bg-red-50 px-1 py-0.5 font-mono text-xs text-red-700 dark:bg-red-900/30 dark:text-red-300">
|
||||
{item}
|
||||
</code>
|
||||
</li>
|
||||
))
|
||||
) : (
|
||||
<li className="text-sm italic text-gray-400 dark:text-gray-500">
|
||||
None
|
||||
</li>
|
||||
)}
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
{/* Right column - Possible solutions */}
|
||||
<div className="min-w-0">
|
||||
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
|
||||
{rightIcon}
|
||||
<span>{rightTitle}</span>
|
||||
</div>
|
||||
<ul className="mt-1.5 space-y-1">
|
||||
{rightItems.length > 0 ? (
|
||||
rightItems.map((item) => (
|
||||
<li
|
||||
key={item}
|
||||
className="text-sm text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<code className="rounded bg-green-50 px-1 py-0.5 font-mono text-xs text-green-700 dark:bg-green-900/30 dark:text-green-300">
|
||||
{item}
|
||||
</code>
|
||||
</li>
|
||||
))
|
||||
) : (
|
||||
<li className="text-sm italic text-gray-400 dark:text-gray-500">
|
||||
None
|
||||
</li>
|
||||
)}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
type SingleColumnSectionProps = {
|
||||
icon: React.ReactNode;
|
||||
title: string;
|
||||
description: string;
|
||||
items: string[];
|
||||
};
|
||||
|
||||
function SingleColumnSection({
|
||||
icon,
|
||||
title,
|
||||
description,
|
||||
items,
|
||||
}: SingleColumnSectionProps) {
|
||||
return (
|
||||
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
|
||||
<div className="flex items-center gap-2">
|
||||
{icon}
|
||||
<span className="font-medium">{title}</span>
|
||||
</div>
|
||||
<p className="mt-1 text-sm text-gray-500 dark:text-gray-400">
|
||||
{description}
|
||||
</p>
|
||||
<ul className="mt-2 space-y-1">
|
||||
{items.map((item) => (
|
||||
<li
|
||||
key={item}
|
||||
className="ml-4 list-disc text-sm text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<code className="rounded bg-gray-100 px-1 py-0.5 font-mono text-xs dark:bg-gray-800">
|
||||
{item}
|
||||
</code>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
import React from "react";
|
||||
import { InfoIcon, WarningIcon } from "@phosphor-icons/react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { IncompatibilityInfo } from "@/app/(platform)/build/hooks/useSubAgentUpdate/types";
|
||||
|
||||
type ResolutionModeBarProps = {
|
||||
incompatibilities: IncompatibilityInfo | null;
|
||||
};
|
||||
|
||||
export function ResolutionModeBar({
|
||||
incompatibilities,
|
||||
}: ResolutionModeBarProps): React.ReactElement {
|
||||
const renderIncompatibilities = () => {
|
||||
if (!incompatibilities) return <span>No incompatibilities</span>;
|
||||
|
||||
const sections: React.ReactNode[] = [];
|
||||
|
||||
if (incompatibilities.missingInputs.length > 0) {
|
||||
sections.push(
|
||||
<div key="missing-inputs" className="mb-1">
|
||||
<span className="font-semibold">Missing inputs: </span>
|
||||
{incompatibilities.missingInputs.map((name, i) => (
|
||||
<React.Fragment key={name}>
|
||||
<code className="font-mono">{name}</code>
|
||||
{i < incompatibilities.missingInputs.length - 1 && ", "}
|
||||
</React.Fragment>
|
||||
))}
|
||||
</div>,
|
||||
);
|
||||
}
|
||||
if (incompatibilities.missingOutputs.length > 0) {
|
||||
sections.push(
|
||||
<div key="missing-outputs" className="mb-1">
|
||||
<span className="font-semibold">Missing outputs: </span>
|
||||
{incompatibilities.missingOutputs.map((name, i) => (
|
||||
<React.Fragment key={name}>
|
||||
<code className="font-mono">{name}</code>
|
||||
{i < incompatibilities.missingOutputs.length - 1 && ", "}
|
||||
</React.Fragment>
|
||||
))}
|
||||
</div>,
|
||||
);
|
||||
}
|
||||
if (incompatibilities.newRequiredInputs.length > 0) {
|
||||
sections.push(
|
||||
<div key="new-required" className="mb-1">
|
||||
<span className="font-semibold">New required inputs: </span>
|
||||
{incompatibilities.newRequiredInputs.map((name, i) => (
|
||||
<React.Fragment key={name}>
|
||||
<code className="font-mono">{name}</code>
|
||||
{i < incompatibilities.newRequiredInputs.length - 1 && ", "}
|
||||
</React.Fragment>
|
||||
))}
|
||||
</div>,
|
||||
);
|
||||
}
|
||||
if (incompatibilities.inputTypeMismatches.length > 0) {
|
||||
sections.push(
|
||||
<div key="type-mismatches" className="mb-1">
|
||||
<span className="font-semibold">Type changed: </span>
|
||||
{incompatibilities.inputTypeMismatches.map((m, i) => (
|
||||
<React.Fragment key={m.name}>
|
||||
<code className="font-mono">{m.name}</code>
|
||||
<span className="text-gray-400">
|
||||
{" "}
|
||||
({m.oldType} → {m.newType})
|
||||
</span>
|
||||
{i < incompatibilities.inputTypeMismatches.length - 1 && ", "}
|
||||
</React.Fragment>
|
||||
))}
|
||||
</div>,
|
||||
);
|
||||
}
|
||||
|
||||
return <>{sections}</>;
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-between gap-2 rounded-t-xl bg-amber-50 px-3 py-2 dark:bg-amber-900/30">
|
||||
<div className="flex items-center gap-2">
|
||||
<WarningIcon className="h-4 w-4 text-amber-600 dark:text-amber-400" />
|
||||
<span className="text-sm text-amber-700 dark:text-amber-300">
|
||||
Remove incompatible connections
|
||||
</span>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<InfoIcon className="h-4 w-4 cursor-help text-amber-500" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className="max-w-sm">
|
||||
<p className="mb-2 font-semibold">Incompatible changes:</p>
|
||||
<div className="text-xs">{renderIncompatibilities()}</div>
|
||||
<p className="mt-2 text-xs text-gray-400">
|
||||
{(incompatibilities?.newRequiredInputs.length ?? 0) > 0
|
||||
? "Replace / delete"
|
||||
: "Delete"}{" "}
|
||||
the red connections to continue
|
||||
</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,194 @@
|
||||
import { useState, useCallback, useEffect } from "react";
|
||||
import { useShallow } from "zustand/react/shallow";
|
||||
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
|
||||
import {
|
||||
useNodeStore,
|
||||
NodeResolutionData,
|
||||
} from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||
import {
|
||||
useSubAgentUpdate,
|
||||
createUpdatedAgentNodeInputs,
|
||||
getBrokenEdgeIDs,
|
||||
} from "@/app/(platform)/build/hooks/useSubAgentUpdate";
|
||||
import { GraphInputSchema, GraphOutputSchema } from "@/lib/autogpt-server-api";
|
||||
import { CustomNodeData } from "../../CustomNode";
|
||||
|
||||
// Stable empty set to avoid creating new references in selectors
|
||||
const EMPTY_SET: Set<string> = new Set();
|
||||
|
||||
type UseSubAgentUpdateParams = {
|
||||
nodeID: string;
|
||||
nodeData: CustomNodeData;
|
||||
};
|
||||
|
||||
export function useSubAgentUpdateState({
|
||||
nodeID,
|
||||
nodeData,
|
||||
}: UseSubAgentUpdateParams) {
|
||||
const [showIncompatibilityDialog, setShowIncompatibilityDialog] =
|
||||
useState(false);
|
||||
|
||||
// Get store actions
|
||||
const updateNodeData = useNodeStore(
|
||||
useShallow((state) => state.updateNodeData),
|
||||
);
|
||||
const setNodeResolutionMode = useNodeStore(
|
||||
useShallow((state) => state.setNodeResolutionMode),
|
||||
);
|
||||
const isNodeInResolutionMode = useNodeStore(
|
||||
useShallow((state) => state.isNodeInResolutionMode),
|
||||
);
|
||||
const setBrokenEdgeIDs = useNodeStore(
|
||||
useShallow((state) => state.setBrokenEdgeIDs),
|
||||
);
|
||||
// Get this node's broken edge IDs from the per-node map
|
||||
// Use EMPTY_SET as fallback to maintain referential stability
|
||||
const brokenEdgeIDs = useNodeStore(
|
||||
(state) => state.brokenEdgeIDs.get(nodeID) || EMPTY_SET,
|
||||
);
|
||||
const getNodeResolutionData = useNodeStore(
|
||||
useShallow((state) => state.getNodeResolutionData),
|
||||
);
|
||||
const connectedEdges = useEdgeStore(
|
||||
useShallow((state) => state.getNodeEdges(nodeID)),
|
||||
);
|
||||
const availableSubGraphs = useGraphStore(
|
||||
useShallow((state) => state.availableSubGraphs),
|
||||
);
|
||||
|
||||
// Extract agent-specific data
|
||||
const graphID = nodeData.hardcodedValues?.graph_id as string | undefined;
|
||||
const graphVersion = nodeData.hardcodedValues?.graph_version as
|
||||
| number
|
||||
| undefined;
|
||||
const currentInputSchema = nodeData.hardcodedValues?.input_schema as
|
||||
| GraphInputSchema
|
||||
| undefined;
|
||||
const currentOutputSchema = nodeData.hardcodedValues?.output_schema as
|
||||
| GraphOutputSchema
|
||||
| undefined;
|
||||
|
||||
// Use the sub-agent update hook
|
||||
const updateInfo = useSubAgentUpdate(
|
||||
nodeID,
|
||||
graphID,
|
||||
graphVersion,
|
||||
currentInputSchema,
|
||||
currentOutputSchema,
|
||||
connectedEdges,
|
||||
availableSubGraphs,
|
||||
);
|
||||
|
||||
const isInResolutionMode = isNodeInResolutionMode(nodeID);
|
||||
|
||||
// Handle update button click
|
||||
const handleUpdateClick = useCallback(() => {
|
||||
if (!updateInfo.hasUpdate || !updateInfo.latestGraph) return;
|
||||
|
||||
if (updateInfo.isCompatible) {
|
||||
// Compatible update - apply directly
|
||||
const newHardcodedValues = createUpdatedAgentNodeInputs(
|
||||
nodeData.hardcodedValues,
|
||||
updateInfo.latestGraph,
|
||||
);
|
||||
updateNodeData(nodeID, { hardcodedValues: newHardcodedValues });
|
||||
} else {
|
||||
// Incompatible update - show dialog
|
||||
setShowIncompatibilityDialog(true);
|
||||
}
|
||||
}, [
|
||||
updateInfo.hasUpdate,
|
||||
updateInfo.latestGraph,
|
||||
updateInfo.isCompatible,
|
||||
nodeData.hardcodedValues,
|
||||
updateNodeData,
|
||||
nodeID,
|
||||
]);
|
||||
|
||||
// Handle confirming an incompatible update
|
||||
function handleConfirmIncompatibleUpdate() {
|
||||
if (!updateInfo.latestGraph || !updateInfo.incompatibilities) return;
|
||||
|
||||
const latestGraph = updateInfo.latestGraph;
|
||||
|
||||
// Get the new schemas from the latest graph version
|
||||
const newInputSchema =
|
||||
(latestGraph.input_schema as Record<string, unknown>) || {};
|
||||
const newOutputSchema =
|
||||
(latestGraph.output_schema as Record<string, unknown>) || {};
|
||||
|
||||
// Create the updated hardcoded values but DON'T apply them yet
|
||||
// We'll apply them when resolution is complete
|
||||
const pendingHardcodedValues = createUpdatedAgentNodeInputs(
|
||||
nodeData.hardcodedValues,
|
||||
latestGraph,
|
||||
);
|
||||
|
||||
// Get broken edge IDs and store them for this node
|
||||
const brokenIds = getBrokenEdgeIDs(
|
||||
connectedEdges,
|
||||
updateInfo.incompatibilities,
|
||||
nodeID,
|
||||
);
|
||||
setBrokenEdgeIDs(nodeID, brokenIds);
|
||||
|
||||
// Enter resolution mode with both old and new schemas
|
||||
// DON'T apply the update yet - keep old schema so connections remain visible
|
||||
const resolutionData: NodeResolutionData = {
|
||||
incompatibilities: updateInfo.incompatibilities,
|
||||
pendingUpdate: {
|
||||
input_schema: newInputSchema,
|
||||
output_schema: newOutputSchema,
|
||||
},
|
||||
currentSchema: {
|
||||
input_schema: (currentInputSchema as Record<string, unknown>) || {},
|
||||
output_schema: (currentOutputSchema as Record<string, unknown>) || {},
|
||||
},
|
||||
pendingHardcodedValues,
|
||||
};
|
||||
setNodeResolutionMode(nodeID, true, resolutionData);
|
||||
|
||||
setShowIncompatibilityDialog(false);
|
||||
}
|
||||
|
||||
// Check if resolution is complete (all broken edges removed)
|
||||
const resolutionData = getNodeResolutionData(nodeID);
|
||||
|
||||
// Auto-check resolution on edge changes
|
||||
useEffect(() => {
|
||||
if (!isInResolutionMode) return;
|
||||
|
||||
// Check if any broken edges still exist
|
||||
const remainingBroken = Array.from(brokenEdgeIDs).filter((edgeId) =>
|
||||
connectedEdges.some((e) => e.id === edgeId),
|
||||
);
|
||||
|
||||
if (remainingBroken.length === 0) {
|
||||
// Resolution complete - now apply the pending update
|
||||
if (resolutionData?.pendingHardcodedValues) {
|
||||
updateNodeData(nodeID, {
|
||||
hardcodedValues: resolutionData.pendingHardcodedValues,
|
||||
});
|
||||
}
|
||||
// setNodeResolutionMode will clean up this node's broken edges automatically
|
||||
setNodeResolutionMode(nodeID, false);
|
||||
}
|
||||
}, [
|
||||
isInResolutionMode,
|
||||
brokenEdgeIDs,
|
||||
connectedEdges,
|
||||
resolutionData,
|
||||
nodeID,
|
||||
]);
|
||||
|
||||
return {
|
||||
updateInfo,
|
||||
isInResolutionMode,
|
||||
resolutionData,
|
||||
showIncompatibilityDialog,
|
||||
setShowIncompatibilityDialog,
|
||||
handleUpdateClick,
|
||||
handleConfirmIncompatibleUpdate,
|
||||
};
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import { NodeResolutionData } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { RJSFSchema } from "@rjsf/utils";
|
||||
|
||||
export const nodeStyleBasedOnStatus: Record<AgentExecutionStatus, string> = {
|
||||
INCOMPLETE: "ring-slate-300 bg-slate-300",
|
||||
@@ -9,3 +11,48 @@ export const nodeStyleBasedOnStatus: Record<AgentExecutionStatus, string> = {
|
||||
TERMINATED: "ring-orange-300 bg-orange-300 ",
|
||||
FAILED: "ring-red-300 bg-red-300",
|
||||
};
|
||||
|
||||
/**
|
||||
* Merges schemas during resolution mode to include removed inputs/outputs
|
||||
* that still have connections, so users can see and delete them.
|
||||
*/
|
||||
export function mergeSchemaForResolution(
|
||||
currentSchema: Record<string, unknown>,
|
||||
newSchema: Record<string, unknown>,
|
||||
resolutionData: NodeResolutionData,
|
||||
type: "input" | "output",
|
||||
): Record<string, unknown> {
|
||||
const newProps = (newSchema.properties as RJSFSchema) || {};
|
||||
const currentProps = (currentSchema.properties as RJSFSchema) || {};
|
||||
const mergedProps = { ...newProps };
|
||||
const incomp = resolutionData.incompatibilities;
|
||||
|
||||
if (type === "input") {
|
||||
// Add back missing inputs that have connections
|
||||
incomp.missingInputs.forEach((inputName: string) => {
|
||||
if (currentProps[inputName]) {
|
||||
mergedProps[inputName] = currentProps[inputName];
|
||||
}
|
||||
});
|
||||
// Add back inputs with type mismatches (keep old type so connection works visually)
|
||||
incomp.inputTypeMismatches.forEach(
|
||||
(mismatch: { name: string; oldType: string; newType: string }) => {
|
||||
if (currentProps[mismatch.name]) {
|
||||
mergedProps[mismatch.name] = currentProps[mismatch.name];
|
||||
}
|
||||
},
|
||||
);
|
||||
} else {
|
||||
// Add back missing outputs that have connections
|
||||
incomp.missingOutputs.forEach((outputName: string) => {
|
||||
if (currentProps[outputName]) {
|
||||
mergedProps[outputName] = currentProps[outputName];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
...newSchema,
|
||||
properties: mergedProps,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { CustomNodeData } from "./CustomNode";
|
||||
import { BlockUIType } from "../../../types";
|
||||
import { useMemo } from "react";
|
||||
import { mergeSchemaForResolution } from "./helpers";
|
||||
|
||||
export const useCustomNode = ({
|
||||
data,
|
||||
nodeId,
|
||||
}: {
|
||||
data: CustomNodeData;
|
||||
nodeId: string;
|
||||
}) => {
|
||||
const isInResolutionMode = useNodeStore((state) =>
|
||||
state.nodesInResolutionMode.has(nodeId),
|
||||
);
|
||||
const resolutionData = useNodeStore((state) =>
|
||||
state.nodeResolutionData.get(nodeId),
|
||||
);
|
||||
|
||||
const isAgent = data.uiType === BlockUIType.AGENT;
|
||||
|
||||
const currentInputSchema = isAgent
|
||||
? (data.hardcodedValues.input_schema ?? {})
|
||||
: data.inputSchema;
|
||||
const currentOutputSchema = isAgent
|
||||
? (data.hardcodedValues.output_schema ?? {})
|
||||
: data.outputSchema;
|
||||
|
||||
const inputSchema = useMemo(() => {
|
||||
if (isAgent && isInResolutionMode && resolutionData) {
|
||||
return mergeSchemaForResolution(
|
||||
resolutionData.currentSchema.input_schema,
|
||||
resolutionData.pendingUpdate.input_schema,
|
||||
resolutionData,
|
||||
"input",
|
||||
);
|
||||
}
|
||||
return currentInputSchema;
|
||||
}, [isAgent, isInResolutionMode, resolutionData, currentInputSchema]);
|
||||
|
||||
const outputSchema = useMemo(() => {
|
||||
if (isAgent && isInResolutionMode && resolutionData) {
|
||||
return mergeSchemaForResolution(
|
||||
resolutionData.currentSchema.output_schema,
|
||||
resolutionData.pendingUpdate.output_schema,
|
||||
resolutionData,
|
||||
"output",
|
||||
);
|
||||
}
|
||||
return currentOutputSchema;
|
||||
}, [isAgent, isInResolutionMode, resolutionData, currentOutputSchema]);
|
||||
|
||||
return {
|
||||
inputSchema,
|
||||
outputSchema,
|
||||
};
|
||||
};
|
||||
@@ -5,20 +5,16 @@ import { useNodeStore } from "../../../stores/nodeStore";
|
||||
import { BlockUIType } from "../../types";
|
||||
import { FormRenderer } from "@/components/renderers/InputRenderer/FormRenderer";
|
||||
|
||||
export const FormCreator = React.memo(
|
||||
({
|
||||
jsonSchema,
|
||||
nodeId,
|
||||
uiType,
|
||||
showHandles = true,
|
||||
className,
|
||||
}: {
|
||||
jsonSchema: RJSFSchema;
|
||||
nodeId: string;
|
||||
uiType: BlockUIType;
|
||||
showHandles?: boolean;
|
||||
className?: string;
|
||||
}) => {
|
||||
interface FormCreatorProps {
|
||||
jsonSchema: RJSFSchema;
|
||||
nodeId: string;
|
||||
uiType: BlockUIType;
|
||||
showHandles?: boolean;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export const FormCreator: React.FC<FormCreatorProps> = React.memo(
|
||||
({ jsonSchema, nodeId, uiType, showHandles = true, className }) => {
|
||||
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
||||
|
||||
const getHardCodedValues = useNodeStore(
|
||||
@@ -48,7 +44,10 @@ export const FormCreator = React.memo(
|
||||
: hardcodedValues;
|
||||
|
||||
return (
|
||||
<div className={className}>
|
||||
<div
|
||||
className={className}
|
||||
data-id={`form-creator-container-${nodeId}-node`}
|
||||
>
|
||||
<FormRenderer
|
||||
jsonSchema={jsonSchema}
|
||||
handleChange={handleChange}
|
||||
|
||||
@@ -14,6 +14,8 @@ import {
|
||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||
import { getTypeDisplayInfo } from "./helpers";
|
||||
import { BlockUIType } from "../../types";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useBrokenOutputs } from "./useBrokenOutputs";
|
||||
|
||||
export const OutputHandler = ({
|
||||
outputSchema,
|
||||
@@ -27,6 +29,7 @@ export const OutputHandler = ({
|
||||
const { isOutputConnected } = useEdgeStore();
|
||||
const properties = outputSchema?.properties || {};
|
||||
const [isOutputVisible, setIsOutputVisible] = useState(true);
|
||||
const brokenOutputs = useBrokenOutputs(nodeId);
|
||||
|
||||
const showHandles = uiType !== BlockUIType.OUTPUT;
|
||||
|
||||
@@ -44,9 +47,14 @@ export const OutputHandler = ({
|
||||
const shouldShow = isConnected || isOutputVisible;
|
||||
const { displayType, colorClass, hexColor } =
|
||||
getTypeDisplayInfo(fieldSchema);
|
||||
const isBroken = brokenOutputs.has(fullKey);
|
||||
|
||||
return shouldShow ? (
|
||||
<div key={fullKey} className="flex flex-col items-end gap-2">
|
||||
<div
|
||||
key={fullKey}
|
||||
className="flex flex-col items-end gap-2"
|
||||
data-tutorial-id={`output-handler-${nodeId}-${fieldTitle}`}
|
||||
>
|
||||
<div className="relative flex items-center gap-2">
|
||||
{fieldSchema?.description && (
|
||||
<TooltipProvider>
|
||||
@@ -64,15 +72,29 @@ export const OutputHandler = ({
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
)}
|
||||
<Text variant="body" className="text-slate-700">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"text-slate-700",
|
||||
isBroken && "text-red-500 line-through",
|
||||
)}
|
||||
>
|
||||
{fieldTitle}
|
||||
</Text>
|
||||
<Text variant="small" as="span" className={colorClass}>
|
||||
<Text
|
||||
variant="small"
|
||||
as="span"
|
||||
className={cn(
|
||||
colorClass,
|
||||
isBroken && "!text-red-500 line-through",
|
||||
)}
|
||||
>
|
||||
({displayType})
|
||||
</Text>
|
||||
|
||||
{showHandles && (
|
||||
<OutputNodeHandle
|
||||
isBroken={isBroken}
|
||||
field_name={fullKey}
|
||||
nodeId={nodeId}
|
||||
hexColor={hexColor}
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
import { useMemo } from "react";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
|
||||
/**
|
||||
* Hook to get the set of broken output names for a node in resolution mode.
|
||||
*/
|
||||
export function useBrokenOutputs(nodeID: string): Set<string> {
|
||||
// Subscribe to the actual state values, not just methods
|
||||
const isInResolution = useNodeStore((state) =>
|
||||
state.nodesInResolutionMode.has(nodeID),
|
||||
);
|
||||
const resolutionData = useNodeStore((state) =>
|
||||
state.nodeResolutionData.get(nodeID),
|
||||
);
|
||||
|
||||
return useMemo(() => {
|
||||
if (!isInResolution || !resolutionData) {
|
||||
return new Set<string>();
|
||||
}
|
||||
|
||||
return new Set(resolutionData.incompatibilities.missingOutputs);
|
||||
}, [isInResolution, resolutionData]);
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
// Block IDs for tutorial blocks
|
||||
export const BLOCK_IDS = {
|
||||
CALCULATOR: "b1ab9b19-67a6-406d-abf5-2dba76d00c79",
|
||||
AGENT_INPUT: "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
|
||||
AGENT_OUTPUT: "363ae599-353e-4804-937e-b2ee3cef3da4",
|
||||
} as const;
|
||||
|
||||
export const TUTORIAL_SELECTORS = {
|
||||
// Custom nodes - These are all before saving
|
||||
INPUT_NODE: '[data-id="custom-node-2"]',
|
||||
OUTPUT_NODE: '[data-id="custom-node-3 "]',
|
||||
CALCULATOR_NODE: '[data-id="custom-node-1"]',
|
||||
|
||||
// Paricular field selector
|
||||
NAME_FIELD_OUTPUT_NODE: '[data-id="field-3-root_name"]',
|
||||
|
||||
// Output Handlers
|
||||
SECOND_CALCULATOR_RESULT_OUTPUT_HANDLER:
|
||||
'[data-tutorial-id="output-handler-2-result"]',
|
||||
FIRST_CALCULATOR_RESULT_OUTPUT_HANDLER:
|
||||
'[data-tutorial-id="output-handler-1-result"]',
|
||||
|
||||
// Input Handler
|
||||
SECOND_CALCULATOR_NUMBER_A_INPUT_HANDLER:
|
||||
'[data-tutorial-id="input-handler-2-a"]',
|
||||
OUTPUT_VALUE_INPUT_HANDLEER: '[data-tutorial-id="label-3-root_value"]',
|
||||
|
||||
// Block Menu
|
||||
BLOCKS_TRIGGER: '[data-id="blocks-control-popover-trigger"]',
|
||||
BLOCKS_CONTENT: '[data-id="blocks-control-popover-content"]',
|
||||
BLOCKS_SEARCH_INPUT:
|
||||
'[data-id="blocks-control-search-bar"] input[type="text"]',
|
||||
BLOCKS_SEARCH_INPUT_BOX: '[data-id="blocks-control-search-bar"]',
|
||||
|
||||
// Add a new selector that checks within search results
|
||||
|
||||
// Block Menu Sidebar
|
||||
MENU_ITEM_INPUT_BLOCKS: '[data-id="menu-item-input_blocks"]',
|
||||
MENU_ITEM_ALL_BLOCKS: '[data-id="menu-item-all_blocks"]',
|
||||
MENU_ITEM_ACTION_BLOCKS: '[data-id="menu-item-action_blocks"]',
|
||||
MENU_ITEM_OUTPUT_BLOCKS: '[data-id="menu-item-output_blocks"]',
|
||||
MENU_ITEM_INTEGRATIONS: '[data-id="menu-item-integrations"]',
|
||||
MENU_ITEM_MY_AGENTS: '[data-id="menu-item-my_agents"]',
|
||||
MENU_ITEM_MARKETPLACE: '[data-id="menu-item-marketplace_agents"]',
|
||||
MENU_ITEM_SUGGESTION: '[data-id="menu-item-suggestion"]',
|
||||
|
||||
// Block Cards
|
||||
BLOCK_CARD_PREFIX: '[data-id^="block-card-"]',
|
||||
BLOCK_CARD_AGENT_INPUT: '[data-id="block-card-AgentInputBlock"]',
|
||||
// Calculator block - legacy ID used in old tutorial
|
||||
BLOCK_CARD_CALCULATOR:
|
||||
'[data-id="block-card-b1ab9b1967a6406dabf52dba76d00c79"]',
|
||||
BLOCK_CARD_CALCULATOR_IN_SEARCH:
|
||||
'[data-id="blocks-control-search-results"] [data-id="block-card-b1ab9b1967a6406dabf52dba76d00c79"]',
|
||||
|
||||
// Save Control
|
||||
SAVE_TRIGGER: '[data-id="save-control-popover-trigger"]',
|
||||
SAVE_CONTENT: '[data-id="save-control-popover-content"]',
|
||||
SAVE_AGENT_BUTTON: '[data-id="save-control-save-agent"]',
|
||||
SAVE_NAME_INPUT: '[data-id="save-control-name-input"]',
|
||||
SAVE_DESCRIPTION_INPUT: '[data-id="save-control-description-input"]',
|
||||
|
||||
// Builder Actions (Run, Schedule, Outputs)
|
||||
BUILDER_ACTIONS: '[data-id="builder-actions"]',
|
||||
RUN_BUTTON: '[data-id="run-graph-button"]',
|
||||
STOP_BUTTON: '[data-id="stop-graph-button"]',
|
||||
SCHEDULE_BUTTON: '[data-id="schedule-graph-button"]',
|
||||
AGENT_OUTPUTS_BUTTON: '[data-id="agent-outputs-button"]',
|
||||
|
||||
// Run Input Dialog
|
||||
RUN_INPUT_DIALOG_CONTENT: '[data-id="run-input-dialog-content"]',
|
||||
RUN_INPUT_CREDENTIALS_SECTION: '[data-id="run-input-credentials-section"]',
|
||||
RUN_INPUT_CREDENTIALS_FORM: '[data-id="run-input-credentials-form"]',
|
||||
RUN_INPUT_INPUTS_SECTION: '[data-id="run-input-inputs-section"]',
|
||||
RUN_INPUT_INPUTS_FORM: '[data-id="run-input-inputs-form"]',
|
||||
RUN_INPUT_ACTIONS_SECTION: '[data-id="run-input-actions-section"]',
|
||||
RUN_INPUT_MANUAL_RUN_BUTTON: '[data-id="run-input-manual-run-button"]',
|
||||
RUN_INPUT_SCHEDULE_BUTTON: '[data-id="run-input-schedule-button"]',
|
||||
|
||||
// Custom Controls (bottom left)
|
||||
CUSTOM_CONTROLS: '[data-id="custom-controls"]',
|
||||
ZOOM_IN_BUTTON: '[data-id="zoom-in-button"]',
|
||||
ZOOM_OUT_BUTTON: '[data-id="zoom-out-button"]',
|
||||
FIT_VIEW_BUTTON: '[data-id="fit-view-button"]',
|
||||
LOCK_BUTTON: '[data-id="lock-button"]',
|
||||
TUTORIAL_BUTTON: '[data-id="tutorial-button"]',
|
||||
|
||||
// Canvas
|
||||
REACT_FLOW_CANVAS: ".react-flow__pane",
|
||||
REACT_FLOW_NODE: ".react-flow__node",
|
||||
REACT_FLOW_NODE_FIRST: '[data-testid^="rf__node-"]:first-child',
|
||||
REACT_FLOW_EDGE: '[data-testid^="rf__edge-"]',
|
||||
|
||||
// Node elements
|
||||
NODE_CONTAINER: '[data-id^="custom-node-"]',
|
||||
NODE_HEADER: '[data-id^="node-header-"]',
|
||||
NODE_INPUT_HANDLES: '[data-tutorial-id="input-handles"]',
|
||||
NODE_OUTPUT_HANDLE: '[data-handlepos="right"]',
|
||||
NODE_INPUT_HANDLE: "[data-nodeid]",
|
||||
FIRST_CALCULATOR_NODE_OUTPUT: '[data-tutorial-id="node-output"]',
|
||||
// These are the Id's of the nodes before saving
|
||||
CALCULATOR_NODE_FORM_CONTAINER: '[data-id^="form-creator-container-1-node"]', // <-- Add this line
|
||||
AGENT_INPUT_NODE_FORM_CONTAINER: '[data-id^="form-creator-container-2-node"]', // <-- Add this line
|
||||
AGENT_OUTPUT_NODE_FORM_CONTAINER:
|
||||
'[data-id^="form-creator-container-3-node"]', // <-- Add this line
|
||||
|
||||
// Execution badges
|
||||
BADGE_QUEUED: '[data-id^="badge-"][data-id$="-QUEUED"]',
|
||||
BADGE_COMPLETED: '[data-id^="badge-"][data-id$="-COMPLETED"]',
|
||||
|
||||
// Undo/Redo
|
||||
UNDO_BUTTON: '[data-id="undo-button"]',
|
||||
REDO_BUTTON: '[data-id="redo-button"]',
|
||||
} as const;
|
||||
|
||||
export const CSS_CLASSES = {
|
||||
DISABLE: "new-builder-tutorial-disable",
|
||||
HIGHLIGHT: "new-builder-tutorial-highlight",
|
||||
PULSE: "new-builder-tutorial-pulse",
|
||||
} as const;
|
||||
|
||||
export const TUTORIAL_CONFIG = {
|
||||
ELEMENT_CHECK_INTERVAL: 50, // ms
|
||||
INPUT_CHECK_INTERVAL: 100, // ms
|
||||
USE_MODAL_OVERLAY: true,
|
||||
SCROLL_BEHAVIOR: "smooth" as const,
|
||||
SCROLL_BLOCK: "center" as const,
|
||||
SEARCH_TERM_CALCULATOR: "Calculator",
|
||||
} as const;
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user