mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-12 08:38:09 -05:00
Compare commits
24 Commits
hackathon/
...
native-aut
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
87e3d7eaad | ||
|
|
974c14a7b9 | ||
|
|
af014ea19d | ||
|
|
9ecf8bcb08 | ||
|
|
a7a521cedd | ||
|
|
84244c0b56 | ||
|
|
9e83985b5b | ||
|
|
4ef3eab89d | ||
|
|
c68b53b6c1 | ||
|
|
23fb3ad8a4 | ||
|
|
175ba13ebe | ||
|
|
a415f471c6 | ||
|
|
3dd6e5cb04 | ||
|
|
3f1e66b317 | ||
|
|
8f722bd9cd | ||
|
|
65026fc9d3 | ||
|
|
af98bc1081 | ||
|
|
e92459fc5f | ||
|
|
1775286f59 | ||
|
|
f6af700f1a | ||
|
|
a80b06d459 | ||
|
|
17c9e7c8b4 | ||
|
|
f83c9391c8 | ||
|
|
7a0a90e421 |
8
.github/copilot-instructions.md
vendored
8
.github/copilot-instructions.md
vendored
@@ -142,7 +142,7 @@ pnpm storybook # Start component development server
|
||||
### Security & Middleware
|
||||
|
||||
**Cache Protection**: Backend includes middleware preventing sensitive data caching in browsers/proxies
|
||||
**Authentication**: JWT-based with Supabase integration
|
||||
**Authentication**: JWT-based with native authentication
|
||||
**User ID Validation**: All data access requires user ID checks - verify this for any `data/*.py` changes
|
||||
|
||||
### Development Workflow
|
||||
@@ -168,9 +168,9 @@ pnpm storybook # Start component development server
|
||||
|
||||
- `frontend/src/app/layout.tsx` - Root application layout
|
||||
- `frontend/src/app/page.tsx` - Home page
|
||||
- `frontend/src/lib/supabase/` - Authentication and database client
|
||||
- `frontend/src/lib/auth/` - Authentication client
|
||||
|
||||
**Protected Routes**: Update `frontend/lib/supabase/middleware.ts` when adding protected routes
|
||||
**Protected Routes**: Update `frontend/middleware.ts` when adding protected routes
|
||||
|
||||
### Agent Block System
|
||||
|
||||
@@ -194,7 +194,7 @@ Agents are built using a visual block-based system where each block performs a s
|
||||
|
||||
1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
|
||||
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
|
||||
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
|
||||
3. **Platform**: `/.env.default` (shared) → `/.env` (user overrides)
|
||||
4. Docker Compose `environment:` sections override file-based config
|
||||
5. Shell environment variables have highest precedence
|
||||
|
||||
|
||||
6
.github/workflows/claude-dependabot.yml
vendored
6
.github/workflows/claude-dependabot.yml
vendored
@@ -144,11 +144,7 @@ jobs:
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
"pgvector/pgvector:pg18"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
|
||||
6
.github/workflows/claude.yml
vendored
6
.github/workflows/claude.yml
vendored
@@ -160,11 +160,7 @@ jobs:
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
"pgvector/pgvector:pg18"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
|
||||
6
.github/workflows/copilot-setup-steps.yml
vendored
6
.github/workflows/copilot-setup-steps.yml
vendored
@@ -142,11 +142,7 @@ jobs:
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
"pgvector/pgvector:pg18"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
|
||||
44
.github/workflows/platform-backend-ci.yml
vendored
44
.github/workflows/platform-backend-ci.yml
vendored
@@ -2,13 +2,13 @@ name: AutoGPT Platform - Backend CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev, ci-test*]
|
||||
branches: [master, dev, ci-test*, native-auth]
|
||||
paths:
|
||||
- ".github/workflows/platform-backend-ci.yml"
|
||||
- "autogpt_platform/backend/**"
|
||||
- "autogpt_platform/autogpt_libs/**"
|
||||
pull_request:
|
||||
branches: [master, dev, release-*]
|
||||
branches: [master, dev, release-*, native-auth]
|
||||
paths:
|
||||
- ".github/workflows/platform-backend-ci.yml"
|
||||
- "autogpt_platform/backend/**"
|
||||
@@ -36,6 +36,19 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: pgvector/pgvector:pg18
|
||||
ports:
|
||||
- 5432:5432
|
||||
env:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: your-super-secret-and-long-postgres-password
|
||||
POSTGRES_DB: postgres
|
||||
options: >-
|
||||
--health-cmd "pg_isready -U postgres"
|
||||
--health-interval 5s
|
||||
--health-timeout 5s
|
||||
--health-retries 10
|
||||
redis:
|
||||
image: redis:latest
|
||||
ports:
|
||||
@@ -78,11 +91,6 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Setup Supabase
|
||||
uses: supabase/setup-cli@v1
|
||||
with:
|
||||
version: 1.178.1
|
||||
|
||||
- id: get_date
|
||||
name: Get date
|
||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||
@@ -136,16 +144,6 @@ jobs:
|
||||
- name: Generate Prisma Client
|
||||
run: poetry run prisma generate
|
||||
|
||||
- id: supabase
|
||||
name: Start Supabase
|
||||
working-directory: .
|
||||
run: |
|
||||
supabase init
|
||||
supabase start --exclude postgres-meta,realtime,storage-api,imgproxy,inbucket,studio,edge-runtime,logflare,vector,supavisor
|
||||
supabase status -o env | sed 's/="/=/; s/"$//' >> $GITHUB_OUTPUT
|
||||
# outputs:
|
||||
# DB_URL, API_URL, GRAPHQL_URL, ANON_KEY, SERVICE_ROLE_KEY, JWT_SECRET
|
||||
|
||||
- name: Wait for ClamAV to be ready
|
||||
run: |
|
||||
echo "Waiting for ClamAV daemon to start..."
|
||||
@@ -178,8 +176,8 @@ jobs:
|
||||
- name: Run Database Migrations
|
||||
run: poetry run prisma migrate dev --name updates
|
||||
env:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DATABASE_URL: postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/postgres
|
||||
DIRECT_URL: postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/postgres
|
||||
|
||||
- id: lint
|
||||
name: Run Linter
|
||||
@@ -195,11 +193,9 @@ jobs:
|
||||
if: success() || (failure() && steps.lint.outcome == 'failure')
|
||||
env:
|
||||
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
|
||||
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
|
||||
JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
DATABASE_URL: postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/postgres
|
||||
DIRECT_URL: postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/postgres
|
||||
JWT_SECRET: your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
REDIS_HOST: "localhost"
|
||||
REDIS_PORT: "6379"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
5
.github/workflows/platform-frontend-ci.yml
vendored
5
.github/workflows/platform-frontend-ci.yml
vendored
@@ -2,11 +2,12 @@ name: AutoGPT Platform - Frontend CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev]
|
||||
branches: [master, dev, native-auth]
|
||||
paths:
|
||||
- ".github/workflows/platform-frontend-ci.yml"
|
||||
- "autogpt_platform/frontend/**"
|
||||
pull_request:
|
||||
branches: [master, dev, native-auth]
|
||||
paths:
|
||||
- ".github/workflows/platform-frontend-ci.yml"
|
||||
- "autogpt_platform/frontend/**"
|
||||
@@ -147,7 +148,7 @@ jobs:
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Copy default supabase .env
|
||||
- name: Copy default platform .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
|
||||
56
.github/workflows/platform-fullstack-ci.yml
vendored
56
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -1,12 +1,13 @@
|
||||
name: AutoGPT Platform - Frontend CI
|
||||
name: AutoGPT Platform - Fullstack CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev]
|
||||
branches: [master, dev, native-auth]
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- "autogpt_platform/**"
|
||||
pull_request:
|
||||
branches: [master, dev, native-auth]
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- "autogpt_platform/**"
|
||||
@@ -58,14 +59,11 @@ jobs:
|
||||
types:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
timeout-minutes: 10
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
@@ -75,18 +73,6 @@ jobs:
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
@@ -101,36 +87,12 @@ jobs:
|
||||
- name: Setup .env
|
||||
run: cp .env.default .env
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||
|
||||
- name: Generate API queries
|
||||
run: pnpm generate:api:force
|
||||
|
||||
- name: Check for API schema changes
|
||||
run: |
|
||||
if ! git diff --exit-code src/app/api/openapi.json; then
|
||||
echo "❌ API schema changes detected in src/app/api/openapi.json"
|
||||
echo ""
|
||||
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
|
||||
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
|
||||
echo "The API schema is now out of sync with the Front-end queries."
|
||||
echo ""
|
||||
echo "To fix this:"
|
||||
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
|
||||
echo "2. Run 'pnpm generate:api' locally"
|
||||
echo "3. Run 'pnpm types' locally"
|
||||
echo "4. Fix any TypeScript errors that may have been introduced"
|
||||
echo "5. Commit and push your changes"
|
||||
echo ""
|
||||
exit 1
|
||||
else
|
||||
echo "✅ No API schema changes detected"
|
||||
fi
|
||||
run: pnpm generate:api
|
||||
|
||||
- name: Run Typescript checks
|
||||
run: pnpm types
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
|
||||
@@ -49,5 +49,5 @@ Use conventional commit messages for all commits (e.g. `feat(backend): add API`)
|
||||
- Keep out-of-scope changes under 20% of the PR.
|
||||
- Ensure PR descriptions are complete.
|
||||
- For changes touching `data/*.py`, validate user ID checks or explain why not needed.
|
||||
- If adding protected frontend routes, update `frontend/lib/supabase/middleware.ts`.
|
||||
- If adding protected frontend routes, update `frontend/lib/auth/helpers.ts`.
|
||||
- Use the linear ticket branch structure if given codex/open-1668-resume-dropped-runs
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(ls:*)",
|
||||
"WebFetch(domain:langfuse.com)",
|
||||
"Bash(poetry install:*)"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -5,12 +5,6 @@
|
||||
|
||||
POSTGRES_PASSWORD=your-super-secret-and-long-postgres-password
|
||||
JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
|
||||
SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
DASHBOARD_USERNAME=supabase
|
||||
DASHBOARD_PASSWORD=this_password_is_insecure_and_should_be_updated
|
||||
SECRET_KEY_BASE=UpNVntn3cDxHJpq99YMc1T1AQgQpc8kfYTuRgBiYa15BLrx8etQoXz3gZv1/u2oq
|
||||
VAULT_ENC_KEY=your-encryption-key-32-chars-min
|
||||
|
||||
|
||||
############
|
||||
@@ -24,100 +18,31 @@ POSTGRES_PORT=5432
|
||||
|
||||
|
||||
############
|
||||
# Supavisor -- Database pooler
|
||||
############
|
||||
POOLER_PROXY_PORT_TRANSACTION=6543
|
||||
POOLER_DEFAULT_POOL_SIZE=20
|
||||
POOLER_MAX_CLIENT_CONN=100
|
||||
POOLER_TENANT_ID=your-tenant-id
|
||||
|
||||
|
||||
############
|
||||
# API Proxy - Configuration for the Kong Reverse proxy.
|
||||
# Auth - Native authentication configuration
|
||||
############
|
||||
|
||||
KONG_HTTP_PORT=8000
|
||||
KONG_HTTPS_PORT=8443
|
||||
|
||||
|
||||
############
|
||||
# API - Configuration for PostgREST.
|
||||
############
|
||||
|
||||
PGRST_DB_SCHEMAS=public,storage,graphql_public
|
||||
|
||||
|
||||
############
|
||||
# Auth - Configuration for the GoTrue authentication server.
|
||||
############
|
||||
|
||||
## General
|
||||
SITE_URL=http://localhost:3000
|
||||
ADDITIONAL_REDIRECT_URLS=
|
||||
JWT_EXPIRY=3600
|
||||
DISABLE_SIGNUP=false
|
||||
API_EXTERNAL_URL=http://localhost:8000
|
||||
|
||||
## Mailer Config
|
||||
MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify"
|
||||
MAILER_URLPATHS_INVITE="/auth/v1/verify"
|
||||
MAILER_URLPATHS_RECOVERY="/auth/v1/verify"
|
||||
MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify"
|
||||
# JWT token configuration
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=15
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
JWT_ISSUER=autogpt-platform
|
||||
|
||||
## Email auth
|
||||
ENABLE_EMAIL_SIGNUP=true
|
||||
ENABLE_EMAIL_AUTOCONFIRM=false
|
||||
SMTP_ADMIN_EMAIL=admin@example.com
|
||||
SMTP_HOST=supabase-mail
|
||||
SMTP_PORT=2500
|
||||
SMTP_USER=fake_mail_user
|
||||
SMTP_PASS=fake_mail_password
|
||||
SMTP_SENDER_NAME=fake_sender
|
||||
ENABLE_ANONYMOUS_USERS=false
|
||||
|
||||
## Phone auth
|
||||
ENABLE_PHONE_SIGNUP=true
|
||||
ENABLE_PHONE_AUTOCONFIRM=true
|
||||
# Google OAuth (optional)
|
||||
GOOGLE_CLIENT_ID=
|
||||
GOOGLE_CLIENT_SECRET=
|
||||
|
||||
|
||||
############
|
||||
# Studio - Configuration for the Dashboard
|
||||
# Email configuration (optional)
|
||||
############
|
||||
|
||||
STUDIO_DEFAULT_ORGANIZATION=Default Organization
|
||||
STUDIO_DEFAULT_PROJECT=Default Project
|
||||
SMTP_HOST=
|
||||
SMTP_PORT=587
|
||||
SMTP_USER=
|
||||
SMTP_PASS=
|
||||
SMTP_FROM_EMAIL=noreply@example.com
|
||||
|
||||
STUDIO_PORT=3000
|
||||
# replace if you intend to use Studio outside of localhost
|
||||
SUPABASE_PUBLIC_URL=http://localhost:8000
|
||||
|
||||
# Enable webp support
|
||||
IMGPROXY_ENABLE_WEBP_DETECTION=true
|
||||
|
||||
# Add your OpenAI API key to enable SQL Editor Assistant
|
||||
OPENAI_API_KEY=
|
||||
|
||||
|
||||
############
|
||||
# Functions - Configuration for Functions
|
||||
############
|
||||
# NOTE: VERIFY_JWT applies to all functions. Per-function VERIFY_JWT is not supported yet.
|
||||
FUNCTIONS_VERIFY_JWT=false
|
||||
|
||||
|
||||
############
|
||||
# Logs - Configuration for Logflare
|
||||
# Please refer to https://supabase.com/docs/reference/self-hosting-analytics/introduction
|
||||
############
|
||||
|
||||
LOGFLARE_LOGGER_BACKEND_API_KEY=your-super-secret-and-long-logflare-key
|
||||
|
||||
# Change vector.toml sinks to reflect this change
|
||||
LOGFLARE_API_KEY=your-super-secret-and-long-logflare-key
|
||||
|
||||
# Docker socket location - this value will differ depending on your OS
|
||||
DOCKER_SOCKET_LOCATION=/var/run/docker.sock
|
||||
|
||||
# Google Cloud Project details
|
||||
GOOGLE_PROJECT_ID=GOOGLE_PROJECT_ID
|
||||
GOOGLE_PROJECT_NUMBER=GOOGLE_PROJECT_NUMBER
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.PHONY: start-core stop-core logs-core format lint migrate run-backend stop-backend run-frontend load-store-agents backfill-store-embeddings
|
||||
.PHONY: start-core stop-core logs-core format lint migrate run-backend run-frontend load-store-agents
|
||||
|
||||
# Run just Supabase + Redis + RabbitMQ
|
||||
# Run just PostgreSQL + Redis + RabbitMQ + ClamAV
|
||||
start-core:
|
||||
docker compose up -d deps
|
||||
|
||||
@@ -34,14 +34,7 @@ migrate:
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
|
||||
stop-backend:
|
||||
@echo "Stopping backend processes..."
|
||||
@cd backend && poetry run cli stop 2>/dev/null || true
|
||||
@echo "Killing any processes using backend ports..."
|
||||
@lsof -ti:8001,8002,8003,8004,8005,8006,8007 | xargs kill -9 2>/dev/null || true
|
||||
@echo "Backend stopped"
|
||||
|
||||
run-backend: stop-backend
|
||||
run-backend:
|
||||
cd backend && poetry run app
|
||||
|
||||
run-frontend:
|
||||
@@ -53,21 +46,16 @@ test-data:
|
||||
load-store-agents:
|
||||
cd backend && poetry run load-store-agents
|
||||
|
||||
backfill-store-embeddings:
|
||||
cd backend && poetry run python -m backend.api.features.store.backfill_embeddings
|
||||
|
||||
help:
|
||||
@echo "Usage: make <target>"
|
||||
@echo "Targets:"
|
||||
@echo " start-core - Start just the core services (Supabase, Redis, RabbitMQ) in background"
|
||||
@echo " start-core - Start just the core services (PostgreSQL, Redis, RabbitMQ, ClamAV) in background"
|
||||
@echo " stop-core - Stop the core services"
|
||||
@echo " reset-db - Reset the database by deleting the volume"
|
||||
@echo " logs-core - Tail the logs for core services"
|
||||
@echo " format - Format & lint backend (Python) and frontend (TypeScript) code"
|
||||
@echo " migrate - Run backend database migrations"
|
||||
@echo " stop-backend - Stop any running backend processes"
|
||||
@echo " run-backend - Run the backend FastAPI server (stops existing processes first)"
|
||||
@echo " run-backend - Run the backend FastAPI server"
|
||||
@echo " run-frontend - Run the frontend Next.js development server"
|
||||
@echo " test-data - Run the test data creator"
|
||||
@echo " load-store-agents - Load store agents from agents/ folder into test database"
|
||||
@echo " backfill-store-embeddings - Generate embeddings for store agents that don't have them"
|
||||
@echo " load-store-agents - Load store agents from agents/ folder into test database"
|
||||
@@ -16,17 +16,37 @@ ALGO_RECOMMENDATION = (
|
||||
"We highly recommend using an asymmetric algorithm such as ES256, "
|
||||
"because when leaked, a shared secret would allow anyone to "
|
||||
"forge valid tokens and impersonate users. "
|
||||
"More info: https://supabase.com/docs/guides/auth/signing-keys#choosing-the-right-signing-algorithm" # noqa
|
||||
"More info: https://pyjwt.readthedocs.io/en/stable/algorithms.html"
|
||||
)
|
||||
|
||||
|
||||
class Settings:
|
||||
def __init__(self):
|
||||
# JWT verification key (public key for asymmetric, shared secret for symmetric)
|
||||
self.JWT_VERIFY_KEY: str = os.getenv(
|
||||
"JWT_VERIFY_KEY", os.getenv("SUPABASE_JWT_SECRET", "")
|
||||
).strip()
|
||||
|
||||
# JWT signing key (private key for asymmetric, shared secret for symmetric)
|
||||
# Falls back to JWT_VERIFY_KEY for symmetric algorithms like HS256
|
||||
self.JWT_SIGN_KEY: str = os.getenv("JWT_SIGN_KEY", self.JWT_VERIFY_KEY).strip()
|
||||
|
||||
self.JWT_ALGORITHM: str = os.getenv("JWT_SIGN_ALGORITHM", "HS256").strip()
|
||||
|
||||
# Token expiration settings
|
||||
self.ACCESS_TOKEN_EXPIRE_MINUTES: int = int(
|
||||
os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "15")
|
||||
)
|
||||
self.REFRESH_TOKEN_EXPIRE_DAYS: int = int(
|
||||
os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7")
|
||||
)
|
||||
|
||||
# JWT issuer claim
|
||||
self.JWT_ISSUER: str = os.getenv("JWT_ISSUER", "autogpt-platform").strip()
|
||||
|
||||
# JWT audience claim
|
||||
self.JWT_AUDIENCE: str = os.getenv("JWT_AUDIENCE", "authenticated").strip()
|
||||
|
||||
self.validate()
|
||||
|
||||
def validate(self):
|
||||
|
||||
@@ -1,25 +1,29 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from .jwt_utils import bearer_jwt_auth
|
||||
|
||||
|
||||
def add_auth_responses_to_openapi(app: FastAPI) -> None:
|
||||
"""
|
||||
Patch a FastAPI instance's `openapi()` method to add 401 responses
|
||||
Set up custom OpenAPI schema generation that adds 401 responses
|
||||
to all authenticated endpoints.
|
||||
|
||||
This is needed when using HTTPBearer with auto_error=False to get proper
|
||||
401 responses instead of 403, but FastAPI only automatically adds security
|
||||
responses when auto_error=True.
|
||||
"""
|
||||
# Wrap current method to allow stacking OpenAPI schema modifiers like this
|
||||
wrapped_openapi = app.openapi
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = wrapped_openapi()
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
# Add 401 response to all endpoints that have security requirements
|
||||
for path, methods in openapi_schema["paths"].items():
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
@@ -16,6 +20,57 @@ bearer_jwt_auth = HTTPBearer(
|
||||
)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
user_id: str,
|
||||
email: str,
|
||||
role: str = "authenticated",
|
||||
email_verified: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a new JWT access token.
|
||||
|
||||
:param user_id: The user's unique identifier
|
||||
:param email: The user's email address
|
||||
:param role: The user's role (default: "authenticated")
|
||||
:param email_verified: Whether the user's email is verified
|
||||
:return: Encoded JWT token
|
||||
"""
|
||||
settings = get_settings()
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"email": email,
|
||||
"role": role,
|
||||
"email_verified": email_verified,
|
||||
"aud": settings.JWT_AUDIENCE,
|
||||
"iss": settings.JWT_ISSUER,
|
||||
"iat": now,
|
||||
"exp": now + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
"jti": str(uuid.uuid4()), # Unique token ID
|
||||
}
|
||||
|
||||
return jwt.encode(payload, settings.JWT_SIGN_KEY, algorithm=settings.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token() -> tuple[str, str]:
|
||||
"""
|
||||
Generate a new refresh token.
|
||||
|
||||
Returns a tuple of (raw_token, hashed_token).
|
||||
The raw token should be sent to the client.
|
||||
The hashed token should be stored in the database.
|
||||
"""
|
||||
raw_token = secrets.token_urlsafe(64)
|
||||
hashed_token = hashlib.sha256(raw_token.encode()).hexdigest()
|
||||
return raw_token, hashed_token
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""Hash a token using SHA-256."""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
async def get_jwt_payload(
|
||||
credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
|
||||
) -> dict[str, Any]:
|
||||
@@ -52,11 +107,19 @@ def parse_jwt_token(token: str) -> dict[str, Any]:
|
||||
"""
|
||||
settings = get_settings()
|
||||
try:
|
||||
# Build decode options
|
||||
options = {
|
||||
"verify_aud": True,
|
||||
"verify_iss": bool(settings.JWT_ISSUER),
|
||||
}
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.JWT_VERIFY_KEY,
|
||||
algorithms=[settings.JWT_ALGORITHM],
|
||||
audience="authenticated",
|
||||
audience=settings.JWT_AUDIENCE,
|
||||
issuer=settings.JWT_ISSUER if settings.JWT_ISSUER else None,
|
||||
options=options,
|
||||
)
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
|
||||
@@ -11,6 +11,7 @@ class User:
|
||||
email: str
|
||||
phone_number: str
|
||||
role: str
|
||||
email_verified: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_payload(cls, payload):
|
||||
@@ -18,5 +19,6 @@ class User:
|
||||
user_id=payload["sub"],
|
||||
email=payload.get("email", ""),
|
||||
phone_number=payload.get("phone", ""),
|
||||
role=payload["role"],
|
||||
role=payload.get("role", "authenticated"),
|
||||
email_verified=payload.get("email_verified", False),
|
||||
)
|
||||
|
||||
414
autogpt_platform/autogpt_libs/poetry.lock
generated
414
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -48,6 +48,21 @@ files = [
|
||||
{file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "authlib"
|
||||
version = "1.6.6"
|
||||
description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "authlib-1.6.6-py2.py3-none-any.whl", hash = "sha256:7d9e9bc535c13974313a87f53e8430eb6ea3d1cf6ae4f6efcd793f2e949143fd"},
|
||||
{file = "authlib-1.6.6.tar.gz", hash = "sha256:45770e8e056d0f283451d9996fbb59b70d45722b45d854d58f32878d0a40c38e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cryptography = "*"
|
||||
|
||||
[[package]]
|
||||
name = "backports-asyncio-runner"
|
||||
version = "1.2.0"
|
||||
@@ -61,6 +76,71 @@ files = [
|
||||
{file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bcrypt"
|
||||
version = "4.3.0"
|
||||
description = "Modern password hashing for your software and your servers"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-macosx_10_12_universal2.whl", hash = "sha256:f01e060f14b6b57bbb72fc5b4a83ac21c443c9a2ee708e04a10e9192f90a6281"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5eeac541cefd0bb887a371ef73c62c3cd78535e4887b310626036a7c0a817bb"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59e1aa0e2cd871b08ca146ed08445038f42ff75968c7ae50d2fdd7860ade2180"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:0042b2e342e9ae3d2ed22727c1262f76cc4f345683b5c1715f0250cf4277294f"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74a8d21a09f5e025a9a23e7c0fd2c7fe8e7503e4d356c0a2c1486ba010619f09"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:0142b2cb84a009f8452c8c5a33ace5e3dfec4159e7735f5afe9a4d50a8ea722d"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_34_aarch64.whl", hash = "sha256:12fa6ce40cde3f0b899729dbd7d5e8811cb892d31b6f7d0334a1f37748b789fd"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_34_x86_64.whl", hash = "sha256:5bd3cca1f2aa5dbcf39e2aa13dd094ea181f48959e1071265de49cc2b82525af"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:335a420cfd63fc5bc27308e929bee231c15c85cc4c496610ffb17923abf7f231"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:0e30e5e67aed0187a1764911af023043b4542e70a7461ad20e837e94d23e1d6c"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:3b8d62290ebefd49ee0b3ce7500f5dbdcf13b81402c05f6dafab9a1e1b27212f"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2ef6630e0ec01376f59a006dc72918b1bf436c3b571b80fa1968d775fa02fe7d"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-win32.whl", hash = "sha256:7a4be4cbf241afee43f1c3969b9103a41b40bcb3a3f467ab19f891d9bc4642e4"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-win_amd64.whl", hash = "sha256:5c1949bf259a388863ced887c7861da1df681cb2388645766c89fdfd9004c669"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-macosx_10_12_universal2.whl", hash = "sha256:f81b0ed2639568bf14749112298f9e4e2b28853dab50a8b357e31798686a036d"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:864f8f19adbe13b7de11ba15d85d4a428c7e2f344bac110f667676a0ff84924b"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e36506d001e93bffe59754397572f21bb5dc7c83f54454c990c74a468cd589e"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:842d08d75d9fe9fb94b18b071090220697f9f184d4547179b60734846461ed59"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7c03296b85cb87db865d91da79bf63d5609284fc0cab9472fdd8367bbd830753"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:62f26585e8b219cdc909b6a0069efc5e4267e25d4a3770a364ac58024f62a761"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:beeefe437218a65322fbd0069eb437e7c98137e08f22c4660ac2dc795c31f8bb"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:97eea7408db3a5bcce4a55d13245ab3fa566e23b4c67cd227062bb49e26c585d"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:191354ebfe305e84f344c5964c7cd5f924a3bfc5d405c75ad07f232b6dffb49f"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:41261d64150858eeb5ff43c753c4b216991e0ae16614a308a15d909503617732"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:33752b1ba962ee793fa2b6321404bf20011fe45b9afd2a842139de3011898fef"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:50e6e80a4bfd23a25f5c05b90167c19030cf9f87930f7cb2eacb99f45d1c3304"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-win32.whl", hash = "sha256:67a561c4d9fb9465ec866177e7aebcad08fe23aaf6fbd692a6fab69088abfc51"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-win_amd64.whl", hash = "sha256:584027857bc2843772114717a7490a37f68da563b3620f78a849bcb54dc11e62"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:0d3efb1157edebfd9128e4e46e2ac1a64e0c1fe46fb023158a407c7892b0f8c3"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08bacc884fd302b611226c01014eca277d48f0a05187666bca23aac0dad6fe24"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6746e6fec103fcd509b96bacdfdaa2fbde9a553245dbada284435173a6f1aef"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:afe327968aaf13fc143a56a3360cb27d4ad0345e34da12c7290f1b00b8fe9a8b"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:d9af79d322e735b1fc33404b5765108ae0ff232d4b54666d46730f8ac1a43676"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f1e3ffa1365e8702dc48c8b360fef8d7afeca482809c5e45e653af82ccd088c1"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:3004df1b323d10021fda07a813fd33e0fd57bef0e9a480bb143877f6cba996fe"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:531457e5c839d8caea9b589a1bcfe3756b0547d7814e9ce3d437f17da75c32b0"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:17a854d9a7a476a89dcef6c8bd119ad23e0f82557afbd2c442777a16408e614f"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:6fb1fd3ab08c0cbc6826a2e0447610c6f09e983a281b919ed721ad32236b8b23"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e965a9c1e9a393b8005031ff52583cedc15b7884fce7deb8b0346388837d6cfe"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:79e70b8342a33b52b55d93b3a59223a844962bef479f6a0ea318ebbcadf71505"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-win32.whl", hash = "sha256:b4d4e57f0a63fd0b358eb765063ff661328f69a04494427265950c71b992a39a"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-win_amd64.whl", hash = "sha256:e53e074b120f2877a35cc6c736b8eb161377caae8925c17688bd46ba56daaa5b"},
|
||||
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c950d682f0952bafcceaf709761da0a32a942272fad381081b51096ffa46cea1"},
|
||||
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:107d53b5c67e0bbc3f03ebf5b030e0403d24dda980f8e244795335ba7b4a027d"},
|
||||
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:b693dbb82b3c27a1604a3dff5bfc5418a7e6a781bb795288141e5f80cf3a3492"},
|
||||
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:b6354d3760fcd31994a14c89659dee887f1351a06e5dac3c1142307172a79f90"},
|
||||
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:a839320bf27d474e52ef8cb16449bb2ce0ba03ca9f44daba6d93fa1d8828e48a"},
|
||||
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:bdc6a24e754a555d7316fa4774e64c6c3997d27ed2d1964d55920c7c227bc4ce"},
|
||||
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:55a935b8e9a1d2def0626c4269db3fcd26728cbff1e84f0341465c31c4ee56d8"},
|
||||
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:57967b7a28d855313a963aaea51bf6df89f833db4320da458e5b3c5ab6d4c938"},
|
||||
{file = "bcrypt-4.3.0.tar.gz", hash = "sha256:3a3fd2204178b6d2adcf09cb4f6426ffef54762577a7c9b54c159008cb288c18"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
tests = ["pytest (>=3.2.1,!=3.3.0)"]
|
||||
typecheck = ["mypy"]
|
||||
|
||||
[[package]]
|
||||
name = "cachetools"
|
||||
version = "5.5.2"
|
||||
@@ -459,21 +539,6 @@ ssh = ["bcrypt (>=3.1.5)"]
|
||||
test = ["certifi (>=2024)", "cryptography-vectors (==45.0.6)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
|
||||
test-randomorder = ["pytest-randomly"]
|
||||
|
||||
[[package]]
|
||||
name = "deprecation"
|
||||
version = "2.1.0"
|
||||
description = "A library to handle automated deprecations"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"},
|
||||
{file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
packaging = "*"
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.3.0"
|
||||
@@ -695,23 +760,6 @@ protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4
|
||||
[package.extras]
|
||||
grpc = ["grpcio (>=1.44.0,<2.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "gotrue"
|
||||
version = "2.12.3"
|
||||
description = "Python Client Library for Supabase Auth"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "gotrue-2.12.3-py3-none-any.whl", hash = "sha256:b1a3c6a5fe3f92e854a026c4c19de58706a96fd5fbdcc3d620b2802f6a46a26b"},
|
||||
{file = "gotrue-2.12.3.tar.gz", hash = "sha256:f874cf9d0b2f0335bfbd0d6e29e3f7aff79998cd1c14d2ad814db8c06cee3852"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
|
||||
pydantic = ">=1.10,<3"
|
||||
pyjwt = ">=2.10.1,<3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "grpc-google-iam-v1"
|
||||
version = "0.14.2"
|
||||
@@ -822,94 +870,6 @@ files = [
|
||||
{file = "h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "4.2.0"
|
||||
description = "Pure-Python HTTP/2 protocol implementation"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "h2-4.2.0-py3-none-any.whl", hash = "sha256:479a53ad425bb29af087f3458a61d30780bc818e4ebcf01f0b536ba916462ed0"},
|
||||
{file = "h2-4.2.0.tar.gz", hash = "sha256:c8a52129695e88b1a0578d8d2cc6842bbd79128ac685463b887ee278126ad01f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
hpack = ">=4.1,<5"
|
||||
hyperframe = ">=6.1,<7"
|
||||
|
||||
[[package]]
|
||||
name = "hpack"
|
||||
version = "4.1.0"
|
||||
description = "Pure-Python HPACK header encoding"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496"},
|
||||
{file = "hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httpcore"
|
||||
version = "1.0.9"
|
||||
description = "A minimal low-level HTTP client."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55"},
|
||||
{file = "httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
certifi = "*"
|
||||
h11 = ">=0.16"
|
||||
|
||||
[package.extras]
|
||||
asyncio = ["anyio (>=4.0,<5.0)"]
|
||||
http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
trio = ["trio (>=0.22.0,<1.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "httpx"
|
||||
version = "0.28.1"
|
||||
description = "The next generation HTTP client."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"},
|
||||
{file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = "*"
|
||||
certifi = "*"
|
||||
h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""}
|
||||
httpcore = "==1.*"
|
||||
idna = "*"
|
||||
|
||||
[package.extras]
|
||||
brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
|
||||
cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
|
||||
http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
zstd = ["zstandard (>=0.18.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "hyperframe"
|
||||
version = "6.1.0"
|
||||
description = "Pure-Python HTTP/2 framing"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5"},
|
||||
{file = "hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.10"
|
||||
@@ -1036,7 +996,7 @@ version = "25.0"
|
||||
description = "Core utilities for Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main", "dev"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"},
|
||||
{file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"},
|
||||
@@ -1058,24 +1018,6 @@ files = [
|
||||
dev = ["pre-commit", "tox"]
|
||||
testing = ["coverage", "pytest", "pytest-benchmark"]
|
||||
|
||||
[[package]]
|
||||
name = "postgrest"
|
||||
version = "1.1.1"
|
||||
description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "postgrest-1.1.1-py3-none-any.whl", hash = "sha256:98a6035ee1d14288484bfe36235942c5fb2d26af6d8120dfe3efbe007859251a"},
|
||||
{file = "postgrest-1.1.1.tar.gz", hash = "sha256:f3bb3e8c4602775c75c844a31f565f5f3dd584df4d36d683f0b67d01a86be322"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
deprecation = ">=2.1.0,<3.0.0"
|
||||
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
|
||||
pydantic = ">=1.9,<3.0"
|
||||
strenum = {version = ">=0.4.9,<0.5.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "proto-plus"
|
||||
version = "1.26.1"
|
||||
@@ -1462,21 +1404,6 @@ pytest = ">=6.2.5"
|
||||
[package.extras]
|
||||
dev = ["pre-commit", "pytest-asyncio", "tox"]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.9.0.post0"
|
||||
description = "Extensions to the standard Python datetime module"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
|
||||
{file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
six = ">=1.5"
|
||||
|
||||
[[package]]
|
||||
name = "python-dotenv"
|
||||
version = "1.1.1"
|
||||
@@ -1492,22 +1419,6 @@ files = [
|
||||
[package.extras]
|
||||
cli = ["click (>=5.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "realtime"
|
||||
version = "2.5.3"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "realtime-2.5.3-py3-none-any.whl", hash = "sha256:eb0994636946eff04c4c7f044f980c8c633c7eb632994f549f61053a474ac970"},
|
||||
{file = "realtime-2.5.3.tar.gz", hash = "sha256:0587594f3bc1c84bf007ff625075b86db6528843e03250dc84f4f2808be3d99a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
typing-extensions = ">=4.14.0,<5.0.0"
|
||||
websockets = ">=11,<16"
|
||||
|
||||
[[package]]
|
||||
name = "redis"
|
||||
version = "6.2.0"
|
||||
@@ -1606,18 +1517,6 @@ files = [
|
||||
{file = "semver-3.0.4.tar.gz", hash = "sha256:afc7d8c584a5ed0a11033af086e8af226a9c0b206f313e0301f8dd7b6b589602"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "six"
|
||||
version = "1.17.0"
|
||||
description = "Python 2 and 3 compatibility utilities"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"},
|
||||
{file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sniffio"
|
||||
version = "1.3.1"
|
||||
@@ -1649,76 +1548,6 @@ typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\""
|
||||
[package.extras]
|
||||
full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"]
|
||||
|
||||
[[package]]
|
||||
name = "storage3"
|
||||
version = "0.12.0"
|
||||
description = "Supabase Storage client for Python."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "storage3-0.12.0-py3-none-any.whl", hash = "sha256:1c4585693ca42243ded1512b58e54c697111e91a20916cd14783eebc37e7c87d"},
|
||||
{file = "storage3-0.12.0.tar.gz", hash = "sha256:94243f20922d57738bf42e96b9f5582b4d166e8bf209eccf20b146909f3f71b0"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
deprecation = ">=2.1.0,<3.0.0"
|
||||
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
|
||||
python-dateutil = ">=2.8.2,<3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "strenum"
|
||||
version = "0.4.15"
|
||||
description = "An Enum that inherits from str."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659"},
|
||||
{file = "StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"]
|
||||
release = ["twine"]
|
||||
test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"]
|
||||
|
||||
[[package]]
|
||||
name = "supabase"
|
||||
version = "2.16.0"
|
||||
description = "Supabase client for Python."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "supabase-2.16.0-py3-none-any.whl", hash = "sha256:99065caab3d90a56650bf39fbd0e49740995da3738ab28706c61bd7f2401db55"},
|
||||
{file = "supabase-2.16.0.tar.gz", hash = "sha256:98f3810158012d4ec0e3083f2e5515f5e10b32bd71e7d458662140e963c1d164"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
gotrue = ">=2.11.0,<3.0.0"
|
||||
httpx = ">=0.26,<0.29"
|
||||
postgrest = ">0.19,<1.2"
|
||||
realtime = ">=2.4.0,<2.6.0"
|
||||
storage3 = ">=0.10,<0.13"
|
||||
supafunc = ">=0.9,<0.11"
|
||||
|
||||
[[package]]
|
||||
name = "supafunc"
|
||||
version = "0.10.1"
|
||||
description = "Library for Supabase Functions"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "supafunc-0.10.1-py3-none-any.whl", hash = "sha256:26df9bd25ff2ef56cb5bfb8962de98f43331f7f8ff69572bac3ed9c3a9672040"},
|
||||
{file = "supafunc-0.10.1.tar.gz", hash = "sha256:a5b33c8baecb6b5297d25da29a2503e2ec67ee6986f3d44c137e651b8a59a17d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
|
||||
strenum = ">=0.4.15,<0.5.0"
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.2.1"
|
||||
@@ -1827,85 +1656,6 @@ typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""}
|
||||
[package.extras]
|
||||
standard = ["colorama (>=0.4) ; sys_platform == \"win32\"", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.15.1) ; sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\"", "watchfiles (>=0.13)", "websockets (>=10.4)"]
|
||||
|
||||
[[package]]
|
||||
name = "websockets"
|
||||
version = "15.0.1"
|
||||
description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "websockets-15.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d63efaa0cd96cf0c5fe4d581521d9fa87744540d4bc999ae6e08595a1014b45b"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac60e3b188ec7574cb761b08d50fcedf9d77f1530352db4eef1707fe9dee7205"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5756779642579d902eed757b21b0164cd6fe338506a8083eb58af5c372e39d9a"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fdfe3e2a29e4db3659dbd5bbf04560cea53dd9610273917799f1cde46aa725e"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c2529b320eb9e35af0fa3016c187dffb84a3ecc572bcee7c3ce302bfeba52bf"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac1e5c9054fe23226fb11e05a6e630837f074174c4c2f0fe442996112a6de4fb"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5df592cd503496351d6dc14f7cdad49f268d8e618f80dce0cd5a36b93c3fc08d"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0a34631031a8f05657e8e90903e656959234f3a04552259458aac0b0f9ae6fd9"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3d00075aa65772e7ce9e990cab3ff1de702aa09be3940d1dc88d5abf1ab8a09c"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-win32.whl", hash = "sha256:1234d4ef35db82f5446dca8e35a7da7964d02c127b095e172e54397fb6a6c256"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:39c1fec2c11dc8d89bba6b2bf1556af381611a173ac2b511cf7231622058af41"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:823c248b690b2fd9303ba00c4f66cd5e2d8c3ba4aa968b2779be9532a4dad431"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678999709e68425ae2593acf2e3ebcbcf2e69885a5ee78f9eb80e6e371f1bf57"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d50fd1ee42388dcfb2b3676132c78116490976f1300da28eb629272d5d93e905"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d99e5546bf73dbad5bf3547174cd6cb8ba7273062a23808ffea025ecb1cf8562"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:66dd88c918e3287efc22409d426c8f729688d89a0c587c88971a0faa2c2f3792"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8dd8327c795b3e3f219760fa603dcae1dcc148172290a8ab15158cf85a953413"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8fdc51055e6ff4adeb88d58a11042ec9a5eae317a0a53d12c062c8a8865909e8"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:693f0192126df6c2327cce3baa7c06f2a117575e32ab2308f7f8216c29d9e2e3"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:54479983bd5fb469c38f2f5c7e3a24f9a4e70594cd68cd1fa6b9340dadaff7cf"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-win32.whl", hash = "sha256:16b6c1b3e57799b9d38427dda63edcbe4926352c47cf88588c0be4ace18dac85"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:27ccee0071a0e75d22cb35849b1db43f2ecd3e161041ac1ee9d2352ddf72f065"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:3e90baa811a5d73f3ca0bcbf32064d663ed81318ab225ee4f427ad4e26e5aff3"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:592f1a9fe869c778694f0aa806ba0374e97648ab57936f092fd9d87f8bc03665"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0701bc3cfcb9164d04a14b149fd74be7347a530ad3bbf15ab2c678a2cd3dd9a2"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8b56bdcdb4505c8078cb6c7157d9811a85790f2f2b3632c7d1462ab5783d215"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0af68c55afbd5f07986df82831c7bff04846928ea8d1fd7f30052638788bc9b5"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64dee438fed052b52e4f98f76c5790513235efaa1ef7f3f2192c392cd7c91b65"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d5f6b181bb38171a8ad1d6aa58a67a6aa9d4b38d0f8c5f496b9e42561dfc62fe"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:5d54b09eba2bada6011aea5375542a157637b91029687eb4fdb2dab11059c1b4"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3be571a8b5afed347da347bfcf27ba12b069d9d7f42cb8c7028b5e98bbb12597"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-win32.whl", hash = "sha256:c338ffa0520bdb12fbc527265235639fb76e7bc7faafbb93f6ba80d9c06578a9"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:fcd5cf9e305d7b8338754470cf69cf81f420459dbae8a3b40cee57417f4614a7"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ee443ef070bb3b6ed74514f5efaa37a252af57c90eb33b956d35c8e9c10a1931"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5a939de6b7b4e18ca683218320fc67ea886038265fd1ed30173f5ce3f8e85675"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:746ee8dba912cd6fc889a8147168991d50ed70447bf18bcda7039f7d2e3d9151"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:595b6c3969023ecf9041b2936ac3827e4623bfa3ccf007575f04c5a6aa318c22"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c714d2fc58b5ca3e285461a4cc0c9a66bd0e24c5da9911e30158286c9b5be7f"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f3c1e2ab208db911594ae5b4f79addeb3501604a165019dd221c0bdcabe4db8"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:229cf1d3ca6c1804400b0a9790dc66528e08a6a1feec0d5040e8b9eb14422375"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:756c56e867a90fb00177d530dca4b097dd753cde348448a1012ed6c5131f8b7d"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:558d023b3df0bffe50a04e710bc87742de35060580a293c2a984299ed83bc4e4"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-win32.whl", hash = "sha256:ba9e56e8ceeeedb2e080147ba85ffcd5cd0711b89576b83784d8605a7df455fa"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:e09473f095a819042ecb2ab9465aee615bd9c2028e4ef7d933600a8401c79561"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5f4c04ead5aed67c8a1a20491d54cdfba5884507a48dd798ecaf13c74c4489f5"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:abdc0c6c8c648b4805c5eacd131910d2a7f6455dfd3becab248ef108e89ab16a"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a625e06551975f4b7ea7102bc43895b90742746797e2e14b70ed61c43a90f09b"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d591f8de75824cbb7acad4e05d2d710484f15f29d4a915092675ad3456f11770"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:47819cea040f31d670cc8d324bb6435c6f133b8c7a19ec3d61634e62f8d8f9eb"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac017dd64572e5c3bd01939121e4d16cf30e5d7e110a119399cf3133b63ad054"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:4a9fac8e469d04ce6c25bb2610dc535235bd4aa14996b4e6dbebf5e007eba5ee"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:363c6f671b761efcb30608d24925a382497c12c506b51661883c3e22337265ed"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:2034693ad3097d5355bfdacfffcbd3ef5694f9718ab7f29c29689a9eae841880"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-win32.whl", hash = "sha256:3b1ac0d3e594bf121308112697cf4b32be538fb1444468fb0a6ae4feebc83411"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:b7643a03db5c95c799b89b31c036d5f27eeb4d259c798e878d6937d71832b1e4"},
|
||||
{file = "websockets-15.0.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:0c9e74d766f2818bb95f84c25be4dea09841ac0f734d1966f415e4edfc4ef1c3"},
|
||||
{file = "websockets-15.0.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1009ee0c7739c08a0cd59de430d6de452a55e42d6b522de7aa15e6f67db0b8e1"},
|
||||
{file = "websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76d1f20b1c7a2fa82367e04982e708723ba0e7b8d43aa643d3dcd404d74f1475"},
|
||||
{file = "websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f29d80eb9a9263b8d109135351caf568cc3f80b9928bccde535c235de55c22d9"},
|
||||
{file = "websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b359ed09954d7c18bbc1680f380c7301f92c60bf924171629c5db97febb12f04"},
|
||||
{file = "websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122"},
|
||||
{file = "websockets-15.0.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7f493881579c90fc262d9cdbaa05a6b54b3811c2f300766748db79f098db9940"},
|
||||
{file = "websockets-15.0.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:47b099e1f4fbc95b701b6e85768e1fcdaf1630f3cbe4765fa216596f12310e2e"},
|
||||
{file = "websockets-15.0.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:67f2b6de947f8c757db2db9c71527933ad0019737ec374a8a6be9a956786aaf9"},
|
||||
{file = "websockets-15.0.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d08eb4c2b7d6c41da6ca0600c077e93f5adcfd979cd777d747e9ee624556da4b"},
|
||||
{file = "websockets-15.0.1-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b826973a4a2ae47ba357e4e82fa44a463b8f168e1ca775ac64521442b19e87f"},
|
||||
{file = "websockets-15.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:21c1fa28a6a7e3cbdc171c694398b6df4744613ce9b36b1a498e816787e28123"},
|
||||
{file = "websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f"},
|
||||
{file = "websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zipp"
|
||||
version = "3.23.0"
|
||||
@@ -1929,4 +1679,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "0c40b63c3c921846cf05ccfb4e685d4959854b29c2c302245f9832e20aac6954"
|
||||
content-hash = "de209c97aa0feb29d669a20e4422d51bdf3a0872ec37e85ce9b88ce726fcee7a"
|
||||
|
||||
@@ -18,7 +18,8 @@ pydantic = "^2.11.7"
|
||||
pydantic-settings = "^2.10.1"
|
||||
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
|
||||
redis = "^6.2.0"
|
||||
supabase = "^2.16.0"
|
||||
bcrypt = "^4.1.0"
|
||||
authlib = "^1.3.0"
|
||||
uvicorn = "^0.35.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
@@ -27,10 +27,15 @@ REDIS_PORT=6379
|
||||
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
|
||||
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
|
||||
|
||||
# Supabase Authentication
|
||||
SUPABASE_URL=http://localhost:8000
|
||||
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
# JWT Authentication
|
||||
# Generate a secure random key: python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
JWT_SIGN_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
JWT_SIGN_ALGORITHM=HS256
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=15
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
JWT_ISSUER=autogpt-platform
|
||||
JWT_AUDIENCE=authenticated
|
||||
|
||||
## ===== REQUIRED SECURITY KEYS ===== ##
|
||||
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()
|
||||
@@ -58,13 +63,6 @@ V0_API_KEY=
|
||||
OPEN_ROUTER_API_KEY=
|
||||
NVIDIA_API_KEY=
|
||||
|
||||
# Langfuse Prompt Management
|
||||
# Used for managing the CoPilot system prompt externally
|
||||
# Get credentials from https://cloud.langfuse.com or your self-hosted instance
|
||||
LANGFUSE_PUBLIC_KEY=
|
||||
LANGFUSE_SECRET_KEY=
|
||||
LANGFUSE_HOST=https://cloud.langfuse.com
|
||||
|
||||
# OAuth Credentials
|
||||
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
|
||||
# e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||
|
||||
3
autogpt_platform/backend/.gitignore
vendored
3
autogpt_platform/backend/.gitignore
vendored
@@ -18,3 +18,6 @@ load-tests/results/
|
||||
load-tests/*.json
|
||||
load-tests/*.log
|
||||
load-tests/node_modules/*
|
||||
|
||||
# Migration backups (contain user data)
|
||||
migration_backups/
|
||||
|
||||
@@ -108,7 +108,7 @@ import fastapi.testclient
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.api.features.myroute import router
|
||||
from backend.server.v2.myroute import router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
@@ -149,7 +149,7 @@ These provide the easiest way to set up authentication mocking in test modules:
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
from backend.api.features.myroute import router
|
||||
from backend.server.v2.myroute import router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.api.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
|
||||
from .v1.routes import v1_router
|
||||
|
||||
external_api = FastAPI(
|
||||
title="AutoGPT External API",
|
||||
description="External API for AutoGPT integrations",
|
||||
docs_url="/docs",
|
||||
version="1.0",
|
||||
)
|
||||
|
||||
external_api.add_middleware(SecurityHeadersMiddleware)
|
||||
external_api.include_router(v1_router, prefix="/v1")
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
external_api,
|
||||
service_name="external-api",
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=True,
|
||||
)
|
||||
@@ -1,340 +0,0 @@
|
||||
"""Tests for analytics API endpoints."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from .analytics import router as analytics_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(analytics_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module."""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# /log_raw_metric endpoint tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_log_raw_metric_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw metric logging."""
|
||||
mock_result = Mock(id="metric-123-uuid")
|
||||
mock_log_metric = mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "page_load_time",
|
||||
"metric_value": 2.5,
|
||||
"data_string": "/dashboard",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||
assert response.json() == "metric-123-uuid"
|
||||
|
||||
mock_log_metric.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
metric_name="page_load_time",
|
||||
metric_value=2.5,
|
||||
data_string="/dashboard",
|
||||
)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"metric_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_metric_success",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"metric_value,metric_name,data_string,test_id",
|
||||
[
|
||||
(100, "api_calls_count", "external_api", "integer_value"),
|
||||
(0, "error_count", "no_errors", "zero_value"),
|
||||
(-5.2, "temperature_delta", "cooling", "negative_value"),
|
||||
(1.23456789, "precision_test", "float_precision", "float_precision"),
|
||||
(999999999, "large_number", "max_value", "large_number"),
|
||||
(0.0000001, "tiny_number", "min_value", "tiny_number"),
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_various_values(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
metric_value: float,
|
||||
metric_name: str,
|
||||
data_string: str,
|
||||
test_id: str,
|
||||
) -> None:
|
||||
"""Test raw metric logging with various metric values."""
|
||||
mock_result = Mock(id=f"metric-{test_id}-uuid")
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": metric_name,
|
||||
"metric_value": metric_value,
|
||||
"data_string": data_string,
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Failed for {test_id}: {response.text}"
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{"metric_id": response.json(), "test_case": test_id},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
f"analytics_metric_{test_id}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,expected_error",
|
||||
[
|
||||
({}, "Field required"),
|
||||
({"metric_name": "test"}, "Field required"),
|
||||
(
|
||||
{"metric_name": "test", "metric_value": "not_a_number", "data_string": "x"},
|
||||
"Input should be a valid number",
|
||||
),
|
||||
(
|
||||
{"metric_name": "", "metric_value": 1.0, "data_string": "test"},
|
||||
"String should have at least 1 character",
|
||||
),
|
||||
(
|
||||
{"metric_name": "test", "metric_value": 1.0, "data_string": ""},
|
||||
"String should have at least 1 character",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"empty_request",
|
||||
"missing_metric_value_and_data_string",
|
||||
"invalid_metric_value_type",
|
||||
"empty_metric_name",
|
||||
"empty_data_string",
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_validation_errors(
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test validation errors for invalid metric requests."""
|
||||
response = client.post("/log_raw_metric", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()
|
||||
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||
|
||||
error_text = json.dumps(error_detail)
|
||||
assert (
|
||||
expected_error in error_text
|
||||
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||
|
||||
|
||||
def test_log_raw_metric_service_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error handling when analytics service fails."""
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Database connection failed"),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "test_metric",
|
||||
"metric_value": 1.0,
|
||||
"data_string": "test",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
error_detail = response.json()["detail"]
|
||||
assert "Database connection failed" in error_detail["message"]
|
||||
assert "hint" in error_detail
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# /log_raw_analytics endpoint tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_log_raw_analytics_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw analytics logging."""
|
||||
mock_result = Mock(id="analytics-789-uuid")
|
||||
mock_log_analytics = mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "user_action",
|
||||
"data": {
|
||||
"action": "button_click",
|
||||
"button_id": "submit_form",
|
||||
"timestamp": "2023-01-01T00:00:00Z",
|
||||
"metadata": {"form_type": "registration", "fields_filled": 5},
|
||||
},
|
||||
"data_index": "button_click_submit_form",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||
assert response.json() == "analytics-789-uuid"
|
||||
|
||||
mock_log_analytics.assert_called_once_with(
|
||||
test_user_id,
|
||||
"user_action",
|
||||
request_data["data"],
|
||||
"button_click_submit_form",
|
||||
)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"analytics_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_analytics_success",
|
||||
)
|
||||
|
||||
|
||||
def test_log_raw_analytics_complex_data(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test raw analytics logging with complex nested data structures."""
|
||||
mock_result = Mock(id="analytics-complex-uuid")
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "agent_execution",
|
||||
"data": {
|
||||
"agent_id": "agent_123",
|
||||
"execution_id": "exec_456",
|
||||
"status": "completed",
|
||||
"duration_ms": 3500,
|
||||
"nodes_executed": 15,
|
||||
"blocks_used": [
|
||||
{"block_id": "llm_block", "count": 3},
|
||||
{"block_id": "http_block", "count": 5},
|
||||
{"block_id": "code_block", "count": 2},
|
||||
],
|
||||
"errors": [],
|
||||
"metadata": {
|
||||
"trigger": "manual",
|
||||
"user_tier": "premium",
|
||||
"environment": "production",
|
||||
},
|
||||
},
|
||||
"data_index": "agent_123_exec_456",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{"analytics_id": response.json(), "logged_data": request_data["data"]},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
"analytics_log_analytics_complex_data",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,expected_error",
|
||||
[
|
||||
({}, "Field required"),
|
||||
({"type": "test"}, "Field required"),
|
||||
(
|
||||
{"type": "test", "data": "not_a_dict", "data_index": "test"},
|
||||
"Input should be a valid dictionary",
|
||||
),
|
||||
({"type": "test", "data": {"key": "value"}}, "Field required"),
|
||||
],
|
||||
ids=[
|
||||
"empty_request",
|
||||
"missing_data_and_data_index",
|
||||
"invalid_data_type",
|
||||
"missing_data_index",
|
||||
],
|
||||
)
|
||||
def test_log_raw_analytics_validation_errors(
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test validation errors for invalid analytics requests."""
|
||||
response = client.post("/log_raw_analytics", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()
|
||||
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||
|
||||
error_text = json.dumps(error_detail)
|
||||
assert (
|
||||
expected_error in error_text
|
||||
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||
|
||||
|
||||
def test_log_raw_analytics_service_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error handling when analytics service fails."""
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Analytics DB unreachable"),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "test_event",
|
||||
"data": {"key": "value"},
|
||||
"data_index": "test_index",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
error_detail = response.json()["detail"]
|
||||
assert "Analytics DB unreachable" in error_detail["message"]
|
||||
assert "hint" in error_detail
|
||||
@@ -1,195 +0,0 @@
|
||||
"""Database operations for chat sessions."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import ChatSessionUpdateInput
|
||||
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
|
||||
"""Get a chat session by ID from the database."""
|
||||
session = await PrismaChatSession.prisma().find_unique(
|
||||
where={"id": session_id},
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
# Sort messages by sequence in Python since Prisma doesn't support order_by in include
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> PrismaChatSession:
|
||||
"""Create a new chat session in the database."""
|
||||
data = {
|
||||
"id": session_id,
|
||||
"userId": user_id,
|
||||
"credentials": SafeJson({}),
|
||||
"successfulAgentRuns": SafeJson({}),
|
||||
"successfulAgentSchedules": SafeJson({}),
|
||||
}
|
||||
return await PrismaChatSession.prisma().create(
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
)
|
||||
|
||||
|
||||
async def update_chat_session(
|
||||
session_id: str,
|
||||
credentials: dict[str, Any] | None = None,
|
||||
successful_agent_runs: dict[str, Any] | None = None,
|
||||
successful_agent_schedules: dict[str, Any] | None = None,
|
||||
total_prompt_tokens: int | None = None,
|
||||
total_completion_tokens: int | None = None,
|
||||
title: str | None = None,
|
||||
) -> PrismaChatSession | None:
|
||||
"""Update a chat session's metadata."""
|
||||
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
|
||||
|
||||
if credentials is not None:
|
||||
data["credentials"] = SafeJson(credentials)
|
||||
if successful_agent_runs is not None:
|
||||
data["successfulAgentRuns"] = SafeJson(successful_agent_runs)
|
||||
if successful_agent_schedules is not None:
|
||||
data["successfulAgentSchedules"] = SafeJson(successful_agent_schedules)
|
||||
if total_prompt_tokens is not None:
|
||||
data["totalPromptTokens"] = total_prompt_tokens
|
||||
if total_completion_tokens is not None:
|
||||
data["totalCompletionTokens"] = total_completion_tokens
|
||||
if title is not None:
|
||||
data["title"] = title
|
||||
|
||||
session = await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
|
||||
async def add_chat_message(
|
||||
session_id: str,
|
||||
role: str,
|
||||
sequence: int,
|
||||
content: str | None = None,
|
||||
name: str | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
refusal: str | None = None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
function_call: dict[str, Any] | None = None,
|
||||
) -> PrismaChatMessage:
|
||||
"""Add a message to a chat session."""
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": role,
|
||||
"sequence": sequence,
|
||||
}
|
||||
|
||||
if content is not None:
|
||||
data["content"] = content
|
||||
if name is not None:
|
||||
data["name"] = name
|
||||
if tool_call_id is not None:
|
||||
data["toolCallId"] = tool_call_id
|
||||
if refusal is not None:
|
||||
data["refusal"] = refusal
|
||||
if tool_calls is not None:
|
||||
data["toolCalls"] = SafeJson(tool_calls)
|
||||
if function_call is not None:
|
||||
data["functionCall"] = SafeJson(function_call)
|
||||
|
||||
# Update session's updatedAt timestamp
|
||||
await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
|
||||
return await PrismaChatMessage.prisma().create(data=data)
|
||||
|
||||
|
||||
async def add_chat_messages_batch(
|
||||
session_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
start_sequence: int,
|
||||
) -> list[PrismaChatMessage]:
|
||||
"""Add multiple messages to a chat session in a batch."""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
created_messages = []
|
||||
for i, msg in enumerate(messages):
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
created = await PrismaChatMessage.prisma().create(data=data)
|
||||
created_messages.append(created)
|
||||
|
||||
# Update session's updatedAt timestamp
|
||||
await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
|
||||
return created_messages
|
||||
|
||||
|
||||
async def get_user_chat_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[PrismaChatSession]:
|
||||
"""Get chat sessions for a user, ordered by most recent."""
|
||||
return await PrismaChatSession.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
order={"updatedAt": "desc"},
|
||||
take=limit,
|
||||
skip=offset,
|
||||
)
|
||||
|
||||
|
||||
async def get_user_session_count(user_id: str) -> int:
|
||||
"""Get the total number of chat sessions for a user."""
|
||||
return await PrismaChatSession.prisma().count(where={"userId": user_id})
|
||||
|
||||
|
||||
async def delete_chat_session(session_id: str) -> bool:
|
||||
"""Delete a chat session and all its messages."""
|
||||
try:
|
||||
await PrismaChatSession.prisma().delete(where={"id": session_id})
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete chat session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_chat_session_message_count(session_id: str) -> int:
|
||||
"""Get the number of messages in a chat session."""
|
||||
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
||||
return count
|
||||
@@ -1,473 +0,0 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionDeveloperMessageParam,
|
||||
ChatCompletionFunctionMessageParam,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_assistant_message_param import FunctionCall
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
ChatCompletionMessageToolCallParam,
|
||||
Function,
|
||||
)
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util import json
|
||||
from backend.util.exceptions import RedisError
|
||||
|
||||
from . import db as chat_db
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str | None = None
|
||||
name: str | None = None
|
||||
tool_call_id: str | None = None
|
||||
refusal: str | None = None
|
||||
tool_calls: list[dict] | None = None
|
||||
function_call: dict | None = None
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class ChatSession(BaseModel):
|
||||
session_id: str
|
||||
user_id: str | None
|
||||
title: str | None = None
|
||||
messages: list[ChatMessage]
|
||||
usage: list[Usage]
|
||||
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
|
||||
started_at: datetime
|
||||
updated_at: datetime
|
||||
successful_agent_runs: dict[str, int] = {}
|
||||
successful_agent_schedules: dict[str, int] = {}
|
||||
|
||||
@staticmethod
|
||||
def new(user_id: str | None) -> "ChatSession":
|
||||
return ChatSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
title=None,
|
||||
messages=[],
|
||||
usage=[],
|
||||
credentials={},
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_prisma(
|
||||
prisma_session: PrismaChatSession,
|
||||
prisma_messages: list[PrismaChatMessage] | None = None,
|
||||
) -> "ChatSession":
|
||||
"""Convert Prisma models to Pydantic ChatSession."""
|
||||
messages = []
|
||||
if prisma_messages:
|
||||
for msg in prisma_messages:
|
||||
tool_calls = None
|
||||
if msg.toolCalls:
|
||||
tool_calls = (
|
||||
json.loads(msg.toolCalls)
|
||||
if isinstance(msg.toolCalls, str)
|
||||
else msg.toolCalls
|
||||
)
|
||||
|
||||
function_call = None
|
||||
if msg.functionCall:
|
||||
function_call = (
|
||||
json.loads(msg.functionCall)
|
||||
if isinstance(msg.functionCall, str)
|
||||
else msg.functionCall
|
||||
)
|
||||
|
||||
messages.append(
|
||||
ChatMessage(
|
||||
role=msg.role,
|
||||
content=msg.content,
|
||||
name=msg.name,
|
||||
tool_call_id=msg.toolCallId,
|
||||
refusal=msg.refusal,
|
||||
tool_calls=tool_calls,
|
||||
function_call=function_call,
|
||||
)
|
||||
)
|
||||
|
||||
# Parse JSON fields from Prisma
|
||||
credentials = (
|
||||
json.loads(prisma_session.credentials)
|
||||
if isinstance(prisma_session.credentials, str)
|
||||
else prisma_session.credentials or {}
|
||||
)
|
||||
successful_agent_runs = (
|
||||
json.loads(prisma_session.successfulAgentRuns)
|
||||
if isinstance(prisma_session.successfulAgentRuns, str)
|
||||
else prisma_session.successfulAgentRuns or {}
|
||||
)
|
||||
successful_agent_schedules = (
|
||||
json.loads(prisma_session.successfulAgentSchedules)
|
||||
if isinstance(prisma_session.successfulAgentSchedules, str)
|
||||
else prisma_session.successfulAgentSchedules or {}
|
||||
)
|
||||
|
||||
# Calculate usage from token counts
|
||||
usage = []
|
||||
if prisma_session.totalPromptTokens or prisma_session.totalCompletionTokens:
|
||||
usage.append(
|
||||
Usage(
|
||||
prompt_tokens=prisma_session.totalPromptTokens or 0,
|
||||
completion_tokens=prisma_session.totalCompletionTokens or 0,
|
||||
total_tokens=(prisma_session.totalPromptTokens or 0)
|
||||
+ (prisma_session.totalCompletionTokens or 0),
|
||||
)
|
||||
)
|
||||
|
||||
return ChatSession(
|
||||
session_id=prisma_session.id,
|
||||
user_id=prisma_session.userId,
|
||||
title=prisma_session.title,
|
||||
messages=messages,
|
||||
usage=usage,
|
||||
credentials=credentials,
|
||||
started_at=prisma_session.createdAt,
|
||||
updated_at=prisma_session.updatedAt,
|
||||
successful_agent_runs=successful_agent_runs,
|
||||
successful_agent_schedules=successful_agent_schedules,
|
||||
)
|
||||
|
||||
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
|
||||
messages = []
|
||||
for message in self.messages:
|
||||
if message.role == "developer":
|
||||
m = ChatCompletionDeveloperMessageParam(
|
||||
role="developer",
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
messages.append(m)
|
||||
elif message.role == "system":
|
||||
m = ChatCompletionSystemMessageParam(
|
||||
role="system",
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
messages.append(m)
|
||||
elif message.role == "user":
|
||||
m = ChatCompletionUserMessageParam(
|
||||
role="user",
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
messages.append(m)
|
||||
elif message.role == "assistant":
|
||||
m = ChatCompletionAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.function_call:
|
||||
m["function_call"] = FunctionCall(
|
||||
arguments=message.function_call["arguments"],
|
||||
name=message.function_call["name"],
|
||||
)
|
||||
if message.refusal:
|
||||
m["refusal"] = message.refusal
|
||||
if message.tool_calls:
|
||||
t: list[ChatCompletionMessageToolCallParam] = []
|
||||
for tool_call in message.tool_calls:
|
||||
# Tool calls are stored with nested structure: {id, type, function: {name, arguments}}
|
||||
function_data = tool_call.get("function", {})
|
||||
|
||||
# Skip tool calls that are missing required fields
|
||||
if "id" not in tool_call or "name" not in function_data:
|
||||
logger.warning(
|
||||
f"Skipping invalid tool call: missing required fields. "
|
||||
f"Got: {tool_call.keys()}, function keys: {function_data.keys()}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Arguments are stored as a JSON string
|
||||
arguments_str = function_data.get("arguments", "{}")
|
||||
|
||||
t.append(
|
||||
ChatCompletionMessageToolCallParam(
|
||||
id=tool_call["id"],
|
||||
type="function",
|
||||
function=Function(
|
||||
arguments=arguments_str,
|
||||
name=function_data["name"],
|
||||
),
|
||||
)
|
||||
)
|
||||
m["tool_calls"] = t
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
messages.append(m)
|
||||
elif message.role == "tool":
|
||||
messages.append(
|
||||
ChatCompletionToolMessageParam(
|
||||
role="tool",
|
||||
content=message.content or "",
|
||||
tool_call_id=message.tool_call_id or "",
|
||||
)
|
||||
)
|
||||
elif message.role == "function":
|
||||
messages.append(
|
||||
ChatCompletionFunctionMessageParam(
|
||||
role="function",
|
||||
content=message.content,
|
||||
name=message.name or "",
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from Redis cache."""
|
||||
redis_key = f"chat:session:{session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
raw_session: bytes | None = await async_redis.get(redis_key)
|
||||
|
||||
if raw_session is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
logger.info(
|
||||
f"Loading session {session_id} from cache: "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
)
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
|
||||
raise RedisError(f"Corrupted session data for {session_id}") from e
|
||||
|
||||
|
||||
async def _cache_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session in Redis."""
|
||||
redis_key = f"chat:session:{session.session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
prisma_session = await chat_db.get_chat_session(session_id)
|
||||
if not prisma_session:
|
||||
return None
|
||||
|
||||
messages = prisma_session.Messages
|
||||
logger.info(
|
||||
f"Loading session {session_id} from DB: "
|
||||
f"has_messages={messages is not None}, "
|
||||
f"message_count={len(messages) if messages else 0}, "
|
||||
f"roles={[m.role for m in messages] if messages else []}"
|
||||
)
|
||||
|
||||
return ChatSession.from_prisma(prisma_session, messages)
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
session: ChatSession, existing_message_count: int
|
||||
) -> None:
|
||||
"""Save or update a chat session in the database."""
|
||||
# Check if session exists in DB
|
||||
existing = await chat_db.get_chat_session(session.session_id)
|
||||
|
||||
if not existing:
|
||||
# Create new session
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
# Calculate total tokens from usage
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
# Update session metadata
|
||||
await chat_db.update_chat_session(
|
||||
session_id=session.session_id,
|
||||
credentials=session.credentials,
|
||||
successful_agent_runs=session.successful_agent_runs,
|
||||
successful_agent_schedules=session.successful_agent_schedules,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
# Add new messages (only those after existing count)
|
||||
new_messages = session.messages[existing_message_count:]
|
||||
if new_messages:
|
||||
messages_data = []
|
||||
for msg in new_messages:
|
||||
messages_data.append(
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"name": msg.name,
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
"refusal": msg.refusal,
|
||||
"tool_calls": msg.tool_calls,
|
||||
"function_call": msg.function_call,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||
f"roles={[m['role'] for m in messages_data]}, "
|
||||
f"start_sequence={existing_message_count}"
|
||||
)
|
||||
await chat_db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
messages=messages_data,
|
||||
start_sequence=existing_message_count,
|
||||
)
|
||||
|
||||
|
||||
async def get_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> ChatSession | None:
|
||||
"""Get a chat session by ID.
|
||||
|
||||
Checks Redis cache first, falls back to database if not found.
|
||||
Caches database results back to Redis.
|
||||
"""
|
||||
# Try cache first
|
||||
try:
|
||||
session = await _get_session_from_cache(session_id)
|
||||
if session:
|
||||
# Verify user ownership
|
||||
if session.user_id is not None and session.user_id != user_id:
|
||||
logger.warning(
|
||||
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
||||
)
|
||||
return None
|
||||
return session
|
||||
except RedisError:
|
||||
logger.warning(f"Cache error for session {session_id}, trying database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
||||
|
||||
# Fall back to database
|
||||
logger.info(f"Session {session_id} not in cache, checking database")
|
||||
session = await _get_session_from_db(session_id)
|
||||
|
||||
if session is None:
|
||||
logger.warning(f"Session {session_id} not found in cache or database")
|
||||
return None
|
||||
|
||||
# Verify user ownership
|
||||
if session.user_id is not None and session.user_id != user_id:
|
||||
logger.warning(
|
||||
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Cache the session from DB
|
||||
try:
|
||||
await _cache_session(session)
|
||||
logger.info(f"Cached session {session_id} from database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache session {session_id}: {e}")
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def upsert_chat_session(
|
||||
session: ChatSession,
|
||||
) -> ChatSession:
|
||||
"""Update a chat session in both cache and database."""
|
||||
# Get existing message count from DB for incremental saves
|
||||
existing_message_count = await chat_db.get_chat_session_message_count(
|
||||
session.session_id
|
||||
)
|
||||
|
||||
# Save to database
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save session {session.session_id} to database: {e}")
|
||||
# Continue to cache even if DB fails
|
||||
|
||||
# Save to cache
|
||||
try:
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
raise RedisError(
|
||||
f"Failed to persist chat session {session.session_id} to Redis: {e}"
|
||||
) from e
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(user_id: str | None) -> ChatSession:
|
||||
"""Create a new chat session and persist it."""
|
||||
session = ChatSession.new(user_id)
|
||||
|
||||
# Create in database first
|
||||
try:
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create session in database: {e}")
|
||||
# Continue even if DB fails - cache will still work
|
||||
|
||||
# Cache the session
|
||||
try:
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache new session: {e}")
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def get_user_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[ChatSession]:
|
||||
"""Get all chat sessions for a user from the database."""
|
||||
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
|
||||
|
||||
sessions = []
|
||||
for prisma_session in prisma_sessions:
|
||||
# Convert without messages for listing (lighter weight)
|
||||
sessions.append(ChatSession.from_prisma(prisma_session, None))
|
||||
|
||||
return sessions
|
||||
|
||||
|
||||
async def delete_chat_session(session_id: str) -> bool:
|
||||
"""Delete a chat session from both cache and database."""
|
||||
# Delete from cache
|
||||
try:
|
||||
redis_key = f"chat:session:{session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
|
||||
|
||||
# Delete from database
|
||||
return await chat_db.delete_chat_session(session_id)
|
||||
@@ -1,117 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
|
||||
messages = [
|
||||
ChatMessage(content="Hello, how are you?", role="user"),
|
||||
ChatMessage(
|
||||
content="I'm fine, thank you!",
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "t123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "New York"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
ChatMessage(
|
||||
content="I'm using the tool to get the weather",
|
||||
role="tool",
|
||||
tool_call_id="t123",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_serialization_deserialization():
|
||||
s = ChatSession.new(user_id="abc123")
|
||||
s.messages = messages
|
||||
s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
|
||||
serialized = s.model_dump_json()
|
||||
s2 = ChatSession.model_validate_json(serialized)
|
||||
assert s2.model_dump() == s.model_dump()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_redis_storage():
|
||||
|
||||
s = ChatSession.new(user_id=None)
|
||||
s.messages = messages
|
||||
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
s2 = await get_chat_session(
|
||||
session_id=s.session_id,
|
||||
user_id=s.user_id,
|
||||
)
|
||||
|
||||
assert s2 == s
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_redis_storage_user_id_mismatch():
|
||||
|
||||
s = ChatSession.new(user_id="abc123")
|
||||
s.messages = messages
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
s2 = await get_chat_session(s.session_id, None)
|
||||
|
||||
assert s2 is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_db_storage():
|
||||
"""Test that messages are correctly saved to and loaded from DB (not cache)."""
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
# Create session with messages including assistant message
|
||||
s = ChatSession.new(user_id=None)
|
||||
s.messages = messages # Contains user, assistant, and tool messages
|
||||
|
||||
# Upsert to save to both cache and DB
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
# Clear the Redis cache to force DB load
|
||||
redis_key = f"chat:session:{s.session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
|
||||
# Load from DB (cache was cleared)
|
||||
s2 = await get_chat_session(
|
||||
session_id=s.session_id,
|
||||
user_id=s.user_id,
|
||||
)
|
||||
|
||||
assert s2 is not None, "Session not found after loading from DB"
|
||||
assert len(s2.messages) == len(
|
||||
s.messages
|
||||
), f"Message count mismatch: expected {len(s.messages)}, got {len(s2.messages)}"
|
||||
|
||||
# Verify all roles are present
|
||||
roles = [m.role for m in s2.messages]
|
||||
assert "user" in roles, f"User message missing. Roles found: {roles}"
|
||||
assert "assistant" in roles, f"Assistant message missing. Roles found: {roles}"
|
||||
assert "tool" in roles, f"Tool message missing. Roles found: {roles}"
|
||||
|
||||
# Verify message content
|
||||
for orig, loaded in zip(s.messages, s2.messages):
|
||||
assert orig.role == loaded.role, f"Role mismatch: {orig.role} != {loaded.role}"
|
||||
assert (
|
||||
orig.content == loaded.content
|
||||
), f"Content mismatch for {orig.role}: {orig.content} != {loaded.content}"
|
||||
if orig.tool_calls:
|
||||
assert (
|
||||
loaded.tool_calls is not None
|
||||
), f"Tool calls missing for {orig.role} message"
|
||||
assert len(orig.tool_calls) == len(loaded.tool_calls)
|
||||
@@ -1,192 +0,0 @@
|
||||
You are Otto, an AI Co-Pilot and Forward Deployed Engineer for AutoGPT, an AI Business Automation tool. Your mission is to help users quickly find, create, and set up AutoGPT agents to solve their business problems.
|
||||
|
||||
Here are the functions available to you:
|
||||
|
||||
<functions>
|
||||
**Understanding & Discovery:**
|
||||
1. **add_understanding** - Save information about the user's business context (use this as you learn about them)
|
||||
2. **find_agent** - Search the marketplace for pre-built agents that solve the user's problem
|
||||
3. **find_library_agent** - Search the user's personal library of saved agents
|
||||
4. **find_block** - Search for individual blocks (building components for agents)
|
||||
5. **search_platform_docs** - Search AutoGPT documentation for help
|
||||
|
||||
**Agent Creation & Editing:**
|
||||
6. **create_agent** - Create a new custom agent from scratch based on user requirements
|
||||
7. **edit_agent** - Modify an existing agent (add/remove blocks, change configuration)
|
||||
|
||||
**Execution & Output:**
|
||||
8. **run_agent** - Run or schedule an agent (automatically handles setup)
|
||||
9. **run_block** - Run a single block directly without creating an agent
|
||||
10. **agent_output** - Get the output/results from a running or completed agent execution
|
||||
</functions>
|
||||
|
||||
## ALWAYS GET THE USER'S NAME
|
||||
|
||||
**This is critical:** If you don't know the user's name, ask for it in your first response. Use a friendly, natural approach:
|
||||
- "Hi! I'm Otto. What's your name?"
|
||||
- "Hey there! Before we dive in, what should I call you?"
|
||||
|
||||
Once you have their name, immediately save it with `add_understanding(user_name="...")` and use it throughout the conversation.
|
||||
|
||||
## BUILDING USER UNDERSTANDING
|
||||
|
||||
**If no User Business Context is provided below**, gather information naturally during conversation - don't interrogate them.
|
||||
|
||||
**Key information to gather (in priority order):**
|
||||
1. Their name (ALWAYS first if unknown)
|
||||
2. Their job title and role
|
||||
3. Their business/company and industry
|
||||
4. Pain points and what they want to automate
|
||||
5. Tools they currently use
|
||||
|
||||
**How to gather this information:**
|
||||
- Ask naturally as part of helping them (e.g., "What's your role?" or "What industry are you in?")
|
||||
- When they share information, immediately save it using `add_understanding`
|
||||
- Don't ask all questions at once - spread them across the conversation
|
||||
- Prioritize understanding their immediate problem first
|
||||
|
||||
**Example:**
|
||||
```
|
||||
User: "I need help automating my social media"
|
||||
Otto: I can help with that! I'm Otto - what's your name?
|
||||
User: "I'm Sarah"
|
||||
Otto: [calls add_understanding with user_name="Sarah"]
|
||||
Nice to meet you, Sarah! What's your role - are you a social media manager or business owner?
|
||||
User: "I'm the marketing director at a fintech startup"
|
||||
Otto: [calls add_understanding with job_title="Marketing Director", industry="fintech", business_size="startup"]
|
||||
Great! Let me find social media automation agents for you.
|
||||
[calls find_agent with query="social media automation marketing"]
|
||||
```
|
||||
|
||||
## WHEN TO USE WHICH TOOL
|
||||
|
||||
**Finding existing agents:**
|
||||
- `find_agent` - Search the marketplace for pre-built agents others have created
|
||||
- `find_library_agent` - Search agents the user has already saved to their library
|
||||
|
||||
**Creating/editing agents:**
|
||||
- `create_agent` - When user wants a custom agent that doesn't exist, or has specific requirements
|
||||
- `edit_agent` - When user wants to modify an existing agent (change inputs, add blocks, etc.)
|
||||
|
||||
**Running agents:**
|
||||
- `run_agent` - To execute an agent (handles credentials and inputs automatically)
|
||||
- `agent_output` - To check the results of a running or completed agent execution
|
||||
|
||||
**Direct execution:**
|
||||
- `run_block` - Run a single block directly without needing a full agent
|
||||
|
||||
## HOW run_agent WORKS
|
||||
|
||||
The `run_agent` tool automatically handles the entire setup flow:
|
||||
|
||||
1. **First call** (no inputs) → Returns available inputs so user can decide what values to use
|
||||
2. **Credentials check** → If missing, UI automatically prompts user to add them (you don't need to mention this)
|
||||
3. **Execution** → Runs when you provide `inputs` OR set `use_defaults=true`
|
||||
|
||||
Parameters:
|
||||
- `username_agent_slug` (required): Agent identifier like "creator/agent-name"
|
||||
- `inputs`: Object with input values for the agent
|
||||
- `use_defaults`: Set to `true` to run with default values (only after user confirms)
|
||||
- `schedule_name` + `cron`: For scheduled execution
|
||||
|
||||
## HOW create_agent WORKS
|
||||
|
||||
Use `create_agent` when the user wants to build a custom automation:
|
||||
- Describe what the agent should do
|
||||
- The tool will create the agent structure with appropriate blocks
|
||||
- Returns the agent ID for further editing or running
|
||||
|
||||
## HOW agent_output WORKS
|
||||
|
||||
Use `agent_output` to get results from agent executions:
|
||||
- Pass the execution_id from a run_agent response
|
||||
- Returns the current status and any outputs produced
|
||||
- Useful for checking if an agent has completed and what it produced
|
||||
|
||||
## WORKFLOW
|
||||
|
||||
1. **Get their name** - If unknown, ask for it first
|
||||
2. **Understand context** - Ask 1-2 questions about their problem while helping
|
||||
3. **Find or create** - Use find_agent for existing solutions, create_agent for custom needs
|
||||
4. **Set up and run** - Use run_agent to execute, agent_output to get results
|
||||
|
||||
## YOUR APPROACH
|
||||
|
||||
**Step 1: Greet and Identify**
|
||||
- If you don't know their name, ask for it
|
||||
- Be friendly and conversational
|
||||
|
||||
**Step 2: Understand the Problem**
|
||||
- Ask maximum 1-2 targeted questions
|
||||
- Focus on: What business problem are they solving?
|
||||
- If they want to create/edit an agent, understand what it should do
|
||||
|
||||
**Step 3: Find or Create**
|
||||
- For existing solutions: Use `find_agent` with relevant keywords
|
||||
- For custom needs: Use `create_agent` with their requirements
|
||||
- For modifications: Use `edit_agent` on an existing agent
|
||||
|
||||
**Step 4: Execute**
|
||||
- Call `run_agent` without inputs first to see what's available
|
||||
- Ask user what values they want or if defaults are okay
|
||||
- Call `run_agent` again with inputs or `use_defaults=true`
|
||||
- Use `agent_output` to check results when needed
|
||||
|
||||
## USING add_understanding
|
||||
|
||||
Call `add_understanding` whenever you learn something about the user:
|
||||
|
||||
**User info:** `user_name`, `job_title`
|
||||
**Business:** `business_name`, `industry`, `business_size` (1-10, 11-50, 51-200, 201-1000, 1000+), `user_role` (decision maker, implementer, end user)
|
||||
**Processes:** `key_workflows` (array), `daily_activities` (array)
|
||||
**Pain points:** `pain_points` (array), `bottlenecks` (array), `manual_tasks` (array), `automation_goals` (array)
|
||||
**Tools:** `current_software` (array), `existing_automation` (array)
|
||||
**Other:** `additional_notes`
|
||||
|
||||
Example: `add_understanding(user_name="Sarah", job_title="Marketing Director", industry="fintech")`
|
||||
|
||||
## KEY RULES
|
||||
|
||||
**What You DON'T Do:**
|
||||
- Don't help with login (frontend handles this)
|
||||
- Don't mention or explain credentials to the user (frontend handles this automatically)
|
||||
- Don't run agents without first showing available inputs to the user
|
||||
- Don't use `use_defaults=true` without user explicitly confirming
|
||||
- Don't write responses longer than 3 sentences
|
||||
- Don't interrogate users with many questions - gather info naturally
|
||||
|
||||
**What You DO:**
|
||||
- ALWAYS ask for user's name if you don't have it
|
||||
- Save user information with `add_understanding` as you learn it
|
||||
- Use their name when addressing them
|
||||
- Always call run_agent first without inputs to see what's available
|
||||
- Ask user what values they want OR if they want to use defaults
|
||||
- Keep all responses to maximum 3 sentences
|
||||
- Include the agent link in your response after successful execution
|
||||
|
||||
**Error Handling:**
|
||||
- Authentication needed → "Please sign in via the interface"
|
||||
- Credentials missing → The UI handles this automatically. Focus on asking the user about input values instead.
|
||||
|
||||
## RESPONSE STRUCTURE
|
||||
|
||||
Before responding, wrap your analysis in <thinking> tags to systematically plan your approach:
|
||||
- Check if you know the user's name - if not, ask for it
|
||||
- Check if you have user context - if not, plan to gather some naturally
|
||||
- Extract the key business problem or request from the user's message
|
||||
- Determine what function call (if any) you need to make next
|
||||
- Plan your response to stay under the 3-sentence maximum
|
||||
|
||||
Example interaction:
|
||||
```
|
||||
User: "Hi, I want to build an agent that monitors my competitors"
|
||||
Otto: <thinking>I don't know this user's name. I should ask for it while acknowledging their request.</thinking>
|
||||
Hi! I'm Otto and I'd love to help you build a competitor monitoring agent. What's your name?
|
||||
User: "I'm Mike"
|
||||
Otto: [calls add_understanding with user_name="Mike"]
|
||||
<thinking>Now I know Mike wants competitor monitoring. I should search for existing agents first.</thinking>
|
||||
Great to meet you, Mike! Let me search for competitor monitoring agents.
|
||||
[calls find_agent with query="competitor monitoring analysis"]
|
||||
```
|
||||
|
||||
KEEP ANSWERS TO 3 SENTENCES
|
||||
@@ -1,155 +0,0 @@
|
||||
You are Otto, an AI Co-Pilot helping new users get started with AutoGPT, an AI Business Automation platform. Your mission is to welcome them, learn about their needs, and help them run their first successful agent.
|
||||
|
||||
Here are the functions available to you:
|
||||
|
||||
<functions>
|
||||
**Understanding & Discovery:**
|
||||
1. **add_understanding** - Save information about the user's business context (use this as you learn about them)
|
||||
2. **find_agent** - Search the marketplace for pre-built agents that solve the user's problem
|
||||
3. **find_library_agent** - Search the user's personal library of saved agents
|
||||
4. **find_block** - Search for individual blocks (building components for agents)
|
||||
5. **search_platform_docs** - Search AutoGPT documentation for help
|
||||
|
||||
**Agent Creation & Editing:**
|
||||
6. **create_agent** - Create a new custom agent from scratch based on user requirements
|
||||
7. **edit_agent** - Modify an existing agent (add/remove blocks, change configuration)
|
||||
|
||||
**Execution & Output:**
|
||||
8. **run_agent** - Run or schedule an agent (automatically handles setup)
|
||||
9. **run_block** - Run a single block directly without creating an agent
|
||||
10. **agent_output** - Get the output/results from a running or completed agent execution
|
||||
</functions>
|
||||
|
||||
## YOUR ONBOARDING MISSION
|
||||
|
||||
You are guiding a new user through their first experience with AutoGPT. Your goal is to:
|
||||
1. Welcome them warmly and get their name
|
||||
2. Learn about them and their business
|
||||
3. Find or create an agent that solves a real problem for them
|
||||
4. Get that agent running successfully
|
||||
5. Celebrate their success and point them to next steps
|
||||
|
||||
## PHASE 1: WELCOME & INTRODUCTION
|
||||
|
||||
**Start every conversation by:**
|
||||
- Giving a warm, friendly greeting
|
||||
- Introducing yourself as Otto, their AI assistant
|
||||
- Asking for their name immediately
|
||||
|
||||
**Example opening:**
|
||||
```
|
||||
Hi! I'm Otto, your AI assistant. Welcome to AutoGPT! I'm here to help you set up your first automation. What's your name?
|
||||
```
|
||||
|
||||
Once you have their name, save it immediately with `add_understanding(user_name="...")` and use it throughout.
|
||||
|
||||
## PHASE 2: DISCOVERY
|
||||
|
||||
**After getting their name, learn about them:**
|
||||
- What's their role/job title?
|
||||
- What industry/business are they in?
|
||||
- What's one thing they'd love to automate?
|
||||
|
||||
**Keep it conversational - don't interrogate. Example:**
|
||||
```
|
||||
Nice to meet you, Sarah! What do you do for work, and what's one task you wish you could automate?
|
||||
```
|
||||
|
||||
Save everything you learn with `add_understanding`.
|
||||
|
||||
## PHASE 3: FIND OR CREATE AN AGENT
|
||||
|
||||
**Once you understand their need:**
|
||||
- Search for existing agents with `find_agent`
|
||||
- Present the best match and explain how it helps them
|
||||
- If nothing fits, offer to create a custom agent with `create_agent`
|
||||
|
||||
**Be enthusiastic about the solution:**
|
||||
```
|
||||
I found a great agent for you! The "Social Media Scheduler" can automatically post to your accounts on a schedule. Want to try it?
|
||||
```
|
||||
|
||||
## PHASE 4: SETUP & RUN
|
||||
|
||||
**Guide them through running the agent:**
|
||||
1. Call `run_agent` without inputs first to see what's needed
|
||||
2. Explain each input in simple terms
|
||||
3. Ask what values they want to use
|
||||
4. Run the agent with their inputs or defaults
|
||||
|
||||
**Don't mention credentials** - the UI handles that automatically.
|
||||
|
||||
## PHASE 5: CELEBRATE & HANDOFF
|
||||
|
||||
**After successful execution:**
|
||||
- Congratulate them on their first automation!
|
||||
- Tell them where to find this agent (their Library)
|
||||
- Mention they can explore more agents in the Marketplace
|
||||
- Offer to help with anything else
|
||||
|
||||
**Example:**
|
||||
```
|
||||
You did it! Your first agent is running. You can find it anytime in your Library. Ready to explore more automations?
|
||||
```
|
||||
|
||||
## KEY RULES
|
||||
|
||||
**What You DON'T Do:**
|
||||
- Don't help with login (frontend handles this)
|
||||
- Don't mention credentials (UI handles automatically)
|
||||
- Don't run agents without showing inputs first
|
||||
- Don't use `use_defaults=true` without explicit confirmation
|
||||
- Don't write responses longer than 3 sentences
|
||||
- Don't overwhelm with too many questions at once
|
||||
|
||||
**What You DO:**
|
||||
- ALWAYS get the user's name first
|
||||
- Be warm, encouraging, and celebratory
|
||||
- Save info with `add_understanding` as you learn it
|
||||
- Use their name when addressing them
|
||||
- Keep responses to maximum 3 sentences
|
||||
- Make them feel successful at each step
|
||||
|
||||
## USING add_understanding
|
||||
|
||||
Save information as you learn it:
|
||||
|
||||
**User info:** `user_name`, `job_title`
|
||||
**Business:** `business_name`, `industry`, `business_size`, `user_role`
|
||||
**Pain points:** `pain_points`, `manual_tasks`, `automation_goals`
|
||||
**Tools:** `current_software`
|
||||
|
||||
Example: `add_understanding(user_name="Sarah", job_title="Marketing Manager", automation_goals=["social media scheduling"])`
|
||||
|
||||
## HOW run_agent WORKS
|
||||
|
||||
1. **First call** (no inputs) → Shows available inputs
|
||||
2. **Credentials** → UI handles automatically (don't mention)
|
||||
3. **Execution** → Run with `inputs={...}` or `use_defaults=true`
|
||||
|
||||
## RESPONSE STRUCTURE
|
||||
|
||||
Before responding, plan your approach in <thinking> tags:
|
||||
- What phase am I in? (Welcome/Discovery/Find/Setup/Celebrate)
|
||||
- Do I know their name? If not, ask for it
|
||||
- What's the next step to move them forward?
|
||||
- Keep response under 3 sentences
|
||||
|
||||
**Example flow:**
|
||||
```
|
||||
User: "Hi"
|
||||
Otto: <thinking>Phase 1 - I need to welcome them and get their name.</thinking>
|
||||
Hi! I'm Otto, welcome to AutoGPT! I'm here to help you set up your first automation - what's your name?
|
||||
|
||||
User: "I'm Alex"
|
||||
Otto: [calls add_understanding with user_name="Alex"]
|
||||
<thinking>Got their name. Phase 2 - learn about them.</thinking>
|
||||
Great to meet you, Alex! What do you do for work, and what's one task you'd love to automate?
|
||||
|
||||
User: "I run an e-commerce store and spend hours on customer support emails"
|
||||
Otto: [calls add_understanding with industry="e-commerce", pain_points=["customer support emails"]]
|
||||
<thinking>Phase 3 - search for agents.</thinking>
|
||||
[calls find_agent with query="customer support email automation"]
|
||||
```
|
||||
|
||||
KEEP ANSWERS TO 3 SENTENCES - Be warm, helpful, and focused on their success!
|
||||
@@ -1,472 +0,0 @@
|
||||
"""Chat API routes for chat session management and streaming via SSE."""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, Query, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from .config import ChatConfig
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
tags=["chat"],
|
||||
)
|
||||
|
||||
# ========== Request/Response Models ==========
|
||||
|
||||
|
||||
class StreamChatRequest(BaseModel):
|
||||
"""Request model for streaming chat with optional context."""
|
||||
|
||||
message: str
|
||||
is_user_message: bool = True
|
||||
context: dict[str, str] | None = None # {url: str, content: str}
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
"""Response model containing information on a newly created chat session."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
user_id: str | None
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
|
||||
"""Response model providing complete details for a chat session, including messages."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
"""Response model for a session summary (without messages)."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
title: str | None = None
|
||||
|
||||
|
||||
class ListSessionsResponse(BaseModel):
|
||||
"""Response model for listing chat sessions."""
|
||||
|
||||
sessions: list[SessionSummaryResponse]
|
||||
total: int
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def list_sessions(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
limit: int = Query(default=50, ge=1, le=100),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
) -> ListSessionsResponse:
|
||||
"""
|
||||
List chat sessions for the authenticated user.
|
||||
|
||||
Returns a paginated list of chat sessions belonging to the current user,
|
||||
ordered by most recently updated.
|
||||
|
||||
Args:
|
||||
user_id: The authenticated user's ID.
|
||||
limit: Maximum number of sessions to return (1-100).
|
||||
offset: Number of sessions to skip for pagination.
|
||||
|
||||
Returns:
|
||||
ListSessionsResponse: List of session summaries and total count.
|
||||
"""
|
||||
sessions = await chat_service.get_user_sessions(user_id, limit, offset)
|
||||
|
||||
return ListSessionsResponse(
|
||||
sessions=[
|
||||
SessionSummaryResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
title=None, # TODO: Add title support
|
||||
)
|
||||
for session in sessions
|
||||
],
|
||||
total=len(sessions),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions",
|
||||
)
|
||||
async def create_session(
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new chat session.
|
||||
|
||||
Initiates a new chat session for either an authenticated or anonymous user.
|
||||
|
||||
Args:
|
||||
user_id: The optional authenticated user ID parsed from the JWT. If missing, creates an anonymous session.
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: Details of the created session.
|
||||
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating session with user_id: "
|
||||
f"...{user_id[-8:] if user_id and len(user_id) > 8 else '<redacted>'}"
|
||||
)
|
||||
|
||||
session = await chat_service.create_chat_session(user_id)
|
||||
|
||||
return CreateSessionResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}",
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.
|
||||
|
||||
"""
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
logger.info(
|
||||
f"Returning session {session_id}: "
|
||||
f"message_count={len(messages)}, "
|
||||
f"roles={[m.get('role') for m in messages]}"
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_chat_post(
|
||||
session_id: str,
|
||||
request: StreamChatRequest,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session (POST with context support).
|
||||
|
||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||
- Text fragments as they are generated
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
# Validate session exists before starting the stream
|
||||
# This prevents errors after the response has already started
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found. ")
|
||||
if session.user_id is None and user_id is not None:
|
||||
session = await chat_service.assign_user_to_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_chat_get(
|
||||
session_id: str,
|
||||
message: Annotated[str, Query(min_length=1, max_length=10000)],
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
is_user_message: bool = Query(default=True),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session (GET - legacy endpoint).
|
||||
|
||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||
- Text fragments as they are generated
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
message: The user's new message to process.
|
||||
user_id: Optional authenticated user ID.
|
||||
is_user_message: Whether the message is a user message.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
# Validate session exists before starting the stream
|
||||
# This prevents errors after the response has already started
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found. ")
|
||||
if session.user_id is None and user_id is not None:
|
||||
session = await chat_service.assign_user_to_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
message,
|
||||
is_user_message=is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/assign-user",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=200,
|
||||
)
|
||||
async def session_assign_user(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> dict:
|
||||
"""
|
||||
Assign an authenticated user to a chat session.
|
||||
|
||||
Used (typically post-login) to claim an existing anonymous session as the current authenticated user.
|
||||
|
||||
Args:
|
||||
session_id: The identifier for the (previously anonymous) session.
|
||||
user_id: The authenticated user's ID to associate with the session.
|
||||
|
||||
Returns:
|
||||
dict: Status of the assignment.
|
||||
|
||||
"""
|
||||
await chat_service.assign_user_to_session(session_id, user_id)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Onboarding Routes ==========
|
||||
# These routes use a specialized onboarding system prompt
|
||||
|
||||
|
||||
@router.post(
|
||||
"/onboarding/sessions",
|
||||
)
|
||||
async def create_onboarding_session(
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new onboarding chat session.
|
||||
|
||||
Initiates a new chat session specifically for user onboarding,
|
||||
using a specialized prompt that guides users through their first
|
||||
experience with AutoGPT.
|
||||
|
||||
Args:
|
||||
user_id: The optional authenticated user ID parsed from the JWT.
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: Details of the created onboarding session.
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating onboarding session with user_id: "
|
||||
f"...{user_id[-8:] if user_id and len(user_id) > 8 else '<redacted>'}"
|
||||
)
|
||||
|
||||
session = await chat_service.create_chat_session(user_id)
|
||||
|
||||
return CreateSessionResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/onboarding/sessions/{session_id}",
|
||||
)
|
||||
async def get_onboarding_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of an onboarding chat session.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the onboarding session.
|
||||
user_id: The optional authenticated user ID.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session.
|
||||
"""
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
logger.info(
|
||||
f"Returning onboarding session {session_id}: "
|
||||
f"message_count={len(messages)}, "
|
||||
f"roles={[m.get('role') for m in messages]}"
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/onboarding/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_onboarding_chat(
|
||||
session_id: str,
|
||||
request: StreamChatRequest,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Stream onboarding chat responses for a session.
|
||||
|
||||
Uses the specialized onboarding system prompt to guide new users
|
||||
through their first experience with AutoGPT. Streams AI responses
|
||||
in real time over Server-Sent Events (SSE).
|
||||
|
||||
Args:
|
||||
session_id: The onboarding session identifier.
|
||||
request: Request body containing message and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
"""
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
if session.user_id is None and user_id is not None:
|
||||
session = await chat_service.assign_user_to_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
context=request.context,
|
||||
prompt_type="onboarding", # Use onboarding system prompt
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ========== Health Check ==========
|
||||
|
||||
|
||||
@router.get("/health", status_code=200)
|
||||
async def health_check() -> dict:
|
||||
"""
|
||||
Health check endpoint for the chat service.
|
||||
|
||||
Performs a full cycle test of session creation, assignment, and retrieval. Should always return healthy
|
||||
if the service and data layer are operational.
|
||||
|
||||
Returns:
|
||||
dict: A status dictionary indicating health, service name, and API version.
|
||||
|
||||
"""
|
||||
session = await chat_service.create_chat_session(None)
|
||||
await chat_service.assign_user_to_session(session.session_id, "test_user")
|
||||
await chat_service.get_session(session.session_id, "test_user")
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "chat",
|
||||
"version": "0.1.0",
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .add_understanding import AddUnderstandingTool
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .edit_agent import EditAgentTool
|
||||
from .find_agent import FindAgentTool
|
||||
from .find_block import FindBlockTool
|
||||
from .find_library_agent import FindLibraryAgentTool
|
||||
from .run_agent import RunAgentTool
|
||||
from .run_block import RunBlockTool
|
||||
from .search_docs import SearchDocsTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.chat.response_model import StreamToolExecutionResult
|
||||
|
||||
# Initialize tool instances
|
||||
add_understanding_tool = AddUnderstandingTool()
|
||||
create_agent_tool = CreateAgentTool()
|
||||
edit_agent_tool = EditAgentTool()
|
||||
find_agent_tool = FindAgentTool()
|
||||
find_block_tool = FindBlockTool()
|
||||
find_library_agent_tool = FindLibraryAgentTool()
|
||||
run_agent_tool = RunAgentTool()
|
||||
run_block_tool = RunBlockTool()
|
||||
search_docs_tool = SearchDocsTool()
|
||||
agent_output_tool = AgentOutputTool()
|
||||
|
||||
# Export tools as OpenAI format
|
||||
tools: list[ChatCompletionToolParam] = [
|
||||
add_understanding_tool.as_openai_tool(),
|
||||
create_agent_tool.as_openai_tool(),
|
||||
edit_agent_tool.as_openai_tool(),
|
||||
find_agent_tool.as_openai_tool(),
|
||||
find_block_tool.as_openai_tool(),
|
||||
find_library_agent_tool.as_openai_tool(),
|
||||
run_agent_tool.as_openai_tool(),
|
||||
run_block_tool.as_openai_tool(),
|
||||
search_docs_tool.as_openai_tool(),
|
||||
agent_output_tool.as_openai_tool(),
|
||||
]
|
||||
|
||||
|
||||
async def execute_tool(
|
||||
tool_name: str,
|
||||
parameters: dict[str, Any],
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
tool_call_id: str,
|
||||
) -> "StreamToolExecutionResult":
|
||||
|
||||
tool_map: dict[str, BaseTool] = {
|
||||
"add_understanding": add_understanding_tool,
|
||||
"create_agent": create_agent_tool,
|
||||
"edit_agent": edit_agent_tool,
|
||||
"find_agent": find_agent_tool,
|
||||
"find_block": find_block_tool,
|
||||
"find_library_agent": find_library_agent_tool,
|
||||
"run_agent": run_agent_tool,
|
||||
"run_block": run_block_tool,
|
||||
"search_platform_docs": search_docs_tool,
|
||||
"agent_output": agent_output_tool,
|
||||
}
|
||||
if tool_name not in tool_map:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
return await tool_map[tool_name].execute(
|
||||
user_id, session, tool_call_id, **parameters
|
||||
)
|
||||
@@ -1,206 +0,0 @@
|
||||
"""Tool for capturing user business understanding incrementally."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AddUnderstandingTool(BaseTool):
|
||||
"""Tool for capturing user's business understanding incrementally."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "add_understanding"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Capture and store information about the user's business context,
|
||||
workflows, pain points, and automation goals. Call this tool whenever the user
|
||||
shares information about their business. Each call incrementally adds to the
|
||||
existing understanding - you don't need to provide all fields at once.
|
||||
|
||||
Use this to build a comprehensive profile that helps recommend better agents
|
||||
and automations for the user's specific needs."""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user_name": {
|
||||
"type": "string",
|
||||
"description": "The user's name",
|
||||
},
|
||||
"job_title": {
|
||||
"type": "string",
|
||||
"description": "The user's job title (e.g., 'Marketing Manager', 'CEO', 'Software Engineer')",
|
||||
},
|
||||
"business_name": {
|
||||
"type": "string",
|
||||
"description": "Name of the user's business or organization",
|
||||
},
|
||||
"industry": {
|
||||
"type": "string",
|
||||
"description": "Industry or sector (e.g., 'e-commerce', 'healthcare', 'finance')",
|
||||
},
|
||||
"business_size": {
|
||||
"type": "string",
|
||||
"description": "Company size: '1-10', '11-50', '51-200', '201-1000', or '1000+'",
|
||||
},
|
||||
"user_role": {
|
||||
"type": "string",
|
||||
"description": "User's role in organization context (e.g., 'decision maker', 'implementer', 'end user')",
|
||||
},
|
||||
"key_workflows": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Key business workflows (e.g., 'lead qualification', 'content publishing')",
|
||||
},
|
||||
"daily_activities": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Regular daily activities the user performs",
|
||||
},
|
||||
"pain_points": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Current pain points or challenges",
|
||||
},
|
||||
"bottlenecks": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Process bottlenecks slowing things down",
|
||||
},
|
||||
"manual_tasks": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Manual or repetitive tasks that could be automated",
|
||||
},
|
||||
"automation_goals": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Desired automation outcomes or goals",
|
||||
},
|
||||
"current_software": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Software and tools currently in use",
|
||||
},
|
||||
"existing_automation": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Any existing automations or integrations",
|
||||
},
|
||||
"additional_notes": {
|
||||
"type": "string",
|
||||
"description": "Any other relevant context or notes",
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
"""Requires authentication to store user-specific data."""
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""
|
||||
Capture and store business understanding incrementally.
|
||||
|
||||
Each call merges new data with existing understanding:
|
||||
- String fields are overwritten if provided
|
||||
- List fields are appended (with deduplication)
|
||||
"""
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required to save business understanding.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if any data was provided
|
||||
if not any(v is not None for v in kwargs.values()):
|
||||
return ErrorResponse(
|
||||
message="Please provide at least one field to update.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Build input model
|
||||
input_data = BusinessUnderstandingInput(
|
||||
user_name=kwargs.get("user_name"),
|
||||
job_title=kwargs.get("job_title"),
|
||||
business_name=kwargs.get("business_name"),
|
||||
industry=kwargs.get("industry"),
|
||||
business_size=kwargs.get("business_size"),
|
||||
user_role=kwargs.get("user_role"),
|
||||
key_workflows=kwargs.get("key_workflows"),
|
||||
daily_activities=kwargs.get("daily_activities"),
|
||||
pain_points=kwargs.get("pain_points"),
|
||||
bottlenecks=kwargs.get("bottlenecks"),
|
||||
manual_tasks=kwargs.get("manual_tasks"),
|
||||
automation_goals=kwargs.get("automation_goals"),
|
||||
current_software=kwargs.get("current_software"),
|
||||
existing_automation=kwargs.get("existing_automation"),
|
||||
additional_notes=kwargs.get("additional_notes"),
|
||||
)
|
||||
|
||||
# Track which fields were updated
|
||||
updated_fields = [k for k, v in kwargs.items() if v is not None]
|
||||
|
||||
# Upsert with merge
|
||||
understanding = await upsert_business_understanding(user_id, input_data)
|
||||
|
||||
# Build current understanding summary for the response
|
||||
current_understanding = {
|
||||
"user_name": understanding.user_name,
|
||||
"job_title": understanding.job_title,
|
||||
"business_name": understanding.business_name,
|
||||
"industry": understanding.industry,
|
||||
"business_size": understanding.business_size,
|
||||
"user_role": understanding.user_role,
|
||||
"key_workflows": understanding.key_workflows,
|
||||
"daily_activities": understanding.daily_activities,
|
||||
"pain_points": understanding.pain_points,
|
||||
"bottlenecks": understanding.bottlenecks,
|
||||
"manual_tasks": understanding.manual_tasks,
|
||||
"automation_goals": understanding.automation_goals,
|
||||
"current_software": understanding.current_software,
|
||||
"existing_automation": understanding.existing_automation,
|
||||
"additional_notes": understanding.additional_notes,
|
||||
}
|
||||
|
||||
# Filter out empty values for cleaner response
|
||||
current_understanding = {
|
||||
k: v
|
||||
for k, v in current_understanding.items()
|
||||
if v is not None and v != [] and v != ""
|
||||
}
|
||||
|
||||
return UnderstandingUpdatedResponse(
|
||||
message=f"Updated understanding with: {', '.join(updated_fields)}. "
|
||||
"I now have a better picture of your business context.",
|
||||
session_id=session_id,
|
||||
updated_fields=updated_fields,
|
||||
current_understanding=current_understanding,
|
||||
)
|
||||
@@ -1,29 +0,0 @@
|
||||
"""Agent generator package - Creates agents from natural language."""
|
||||
|
||||
from .core import (
|
||||
apply_agent_patch,
|
||||
decompose_goal,
|
||||
generate_agent,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .fixer import apply_all_fixes
|
||||
from .utils import get_blocks_info
|
||||
from .validator import validate_agent
|
||||
|
||||
__all__ = [
|
||||
# Core functions
|
||||
"decompose_goal",
|
||||
"generate_agent",
|
||||
"generate_agent_patch",
|
||||
"apply_agent_patch",
|
||||
"save_agent_to_library",
|
||||
"get_agent_as_json",
|
||||
# Fixer
|
||||
"apply_all_fixes",
|
||||
# Validator
|
||||
"validate_agent",
|
||||
# Utils
|
||||
"get_blocks_info",
|
||||
]
|
||||
@@ -1,25 +0,0 @@
|
||||
"""OpenRouter client configuration for agent generation."""
|
||||
|
||||
import os
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Configuration - use OPEN_ROUTER_API_KEY for consistency with chat/config.py
|
||||
OPENROUTER_API_KEY = os.getenv("OPEN_ROUTER_API_KEY") or os.getenv("OPENROUTER_API_KEY")
|
||||
AGENT_GENERATOR_MODEL = os.getenv("AGENT_GENERATOR_MODEL", "anthropic/claude-opus-4.5")
|
||||
|
||||
# OpenRouter client (OpenAI-compatible API)
|
||||
_client: AsyncOpenAI | None = None
|
||||
|
||||
|
||||
def get_client() -> AsyncOpenAI:
|
||||
"""Get or create the OpenRouter client."""
|
||||
global _client
|
||||
if _client is None:
|
||||
if not OPENROUTER_API_KEY:
|
||||
raise ValueError("OPENROUTER_API_KEY environment variable is required")
|
||||
_client = AsyncOpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=OPENROUTER_API_KEY,
|
||||
)
|
||||
return _client
|
||||
@@ -1,390 +0,0 @@
|
||||
"""Core agent generation functions."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
|
||||
from .client import AGENT_GENERATOR_MODEL, get_client
|
||||
from .prompts import DECOMPOSITION_PROMPT, GENERATION_PROMPT, PATCH_PROMPT
|
||||
from .utils import get_block_summaries, parse_json_from_llm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
|
||||
"""Break down a goal into steps or return clarifying questions.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
|
||||
Returns:
|
||||
Dict with either:
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
Or None on error
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = DECOMPOSITION_PROMPT.format(block_summaries=get_block_summaries())
|
||||
|
||||
full_description = description
|
||||
if context:
|
||||
full_description = f"{description}\n\nAdditional context:\n{context}"
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": full_description},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for decomposition")
|
||||
return None
|
||||
|
||||
result = parse_json_from_llm(content)
|
||||
|
||||
if result is None:
|
||||
logger.error(f"Failed to parse decomposition response: {content[:200]}")
|
||||
return None
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error decomposing goal: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Generate agent JSON from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
|
||||
Returns:
|
||||
Agent JSON dict or None on error
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = GENERATION_PROMPT.format(block_summaries=get_block_summaries())
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": json.dumps(instructions, indent=2)},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for agent generation")
|
||||
return None
|
||||
|
||||
result = parse_json_from_llm(content)
|
||||
|
||||
if result is None:
|
||||
logger.error(f"Failed to parse agent JSON: {content[:200]}")
|
||||
return None
|
||||
|
||||
# Ensure required fields
|
||||
if "id" not in result:
|
||||
result["id"] = str(uuid.uuid4())
|
||||
if "version" not in result:
|
||||
result["version"] = 1
|
||||
if "is_active" not in result:
|
||||
result["is_active"] = True
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating agent: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
"""Convert agent JSON dict to Graph model.
|
||||
|
||||
Args:
|
||||
agent_json: Agent JSON with nodes and links
|
||||
|
||||
Returns:
|
||||
Graph ready for saving
|
||||
"""
|
||||
nodes = []
|
||||
for n in agent_json.get("nodes", []):
|
||||
node = Node(
|
||||
id=n.get("id", str(uuid.uuid4())),
|
||||
block_id=n["block_id"],
|
||||
input_default=n.get("input_default", {}),
|
||||
metadata=n.get("metadata", {}),
|
||||
)
|
||||
nodes.append(node)
|
||||
|
||||
links = []
|
||||
for link_data in agent_json.get("links", []):
|
||||
link = Link(
|
||||
id=link_data.get("id", str(uuid.uuid4())),
|
||||
source_id=link_data["source_id"],
|
||||
sink_id=link_data["sink_id"],
|
||||
source_name=link_data["source_name"],
|
||||
sink_name=link_data["sink_name"],
|
||||
is_static=link_data.get("is_static", False),
|
||||
)
|
||||
links.append(link)
|
||||
|
||||
return Graph(
|
||||
id=agent_json.get("id", str(uuid.uuid4())),
|
||||
version=agent_json.get("version", 1),
|
||||
is_active=agent_json.get("is_active", True),
|
||||
name=agent_json.get("name", "Generated Agent"),
|
||||
description=agent_json.get("description", ""),
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
|
||||
|
||||
def _reassign_node_ids(graph: Graph) -> None:
|
||||
"""Reassign all node and link IDs to new UUIDs.
|
||||
|
||||
This is needed when creating a new version to avoid unique constraint violations.
|
||||
"""
|
||||
# Create mapping from old node IDs to new UUIDs
|
||||
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
||||
|
||||
# Reassign node IDs
|
||||
for node in graph.nodes:
|
||||
node.id = id_map[node.id]
|
||||
|
||||
# Update link references to use new node IDs
|
||||
for link in graph.links:
|
||||
link.id = str(uuid.uuid4()) # Also give links new IDs
|
||||
if link.source_id in id_map:
|
||||
link.source_id = id_map[link.source_id]
|
||||
if link.sink_id in id_map:
|
||||
link.sink_id = id_map[link.sink_id]
|
||||
|
||||
|
||||
async def save_agent_to_library(
|
||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||
) -> tuple[Graph, Any]:
|
||||
"""Save agent to database and user's library.
|
||||
|
||||
Args:
|
||||
agent_json: Agent JSON dict
|
||||
user_id: User ID
|
||||
is_update: Whether this is an update to an existing agent
|
||||
|
||||
Returns:
|
||||
Tuple of (created Graph, LibraryAgent)
|
||||
"""
|
||||
from backend.data.graph import get_graph_all_versions
|
||||
|
||||
graph = json_to_graph(agent_json)
|
||||
|
||||
if is_update:
|
||||
# For updates, keep the same graph ID but increment version
|
||||
# and reassign node/link IDs to avoid conflicts
|
||||
if graph.id:
|
||||
existing_versions = await get_graph_all_versions(graph.id, user_id)
|
||||
if existing_versions:
|
||||
latest_version = max(v.version for v in existing_versions)
|
||||
graph.version = latest_version + 1
|
||||
# Reassign node IDs (but keep graph ID the same)
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Updating agent {graph.id} to version {graph.version}")
|
||||
else:
|
||||
# For new agents, always generate a fresh UUID to avoid collisions
|
||||
graph.id = str(uuid.uuid4())
|
||||
graph.version = 1
|
||||
# Reassign all node IDs as well
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Creating new agent with ID {graph.id}")
|
||||
|
||||
# Save to database
|
||||
created_graph = await create_graph(graph, user_id)
|
||||
|
||||
# Add to user's library (or update existing library agent)
|
||||
library_agents = await library_db.create_library_agent(
|
||||
graph=created_graph,
|
||||
user_id=user_id,
|
||||
create_library_agents_for_sub_graphs=False,
|
||||
)
|
||||
|
||||
return created_graph, library_agents[0]
|
||||
|
||||
|
||||
async def get_agent_as_json(
|
||||
graph_id: str, user_id: str | None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch an agent and convert to JSON format for editing.
|
||||
|
||||
Args:
|
||||
graph_id: Graph ID or library agent ID
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Agent as JSON dict or None if not found
|
||||
"""
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
# Try to get the graph (version=None gets the active version)
|
||||
graph = await get_graph(graph_id, version=None, user_id=user_id)
|
||||
if not graph:
|
||||
return None
|
||||
|
||||
# Convert to JSON format
|
||||
nodes = []
|
||||
for node in graph.nodes:
|
||||
nodes.append(
|
||||
{
|
||||
"id": node.id,
|
||||
"block_id": node.block_id,
|
||||
"input_default": node.input_default,
|
||||
"metadata": node.metadata,
|
||||
}
|
||||
)
|
||||
|
||||
links = []
|
||||
for node in graph.nodes:
|
||||
for link in node.output_links:
|
||||
links.append(
|
||||
{
|
||||
"id": link.id,
|
||||
"source_id": link.source_id,
|
||||
"sink_id": link.sink_id,
|
||||
"source_name": link.source_name,
|
||||
"sink_name": link.sink_name,
|
||||
"is_static": link.is_static,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"id": graph.id,
|
||||
"name": graph.name,
|
||||
"description": graph.description,
|
||||
"version": graph.version,
|
||||
"is_active": graph.is_active,
|
||||
"nodes": nodes,
|
||||
"links": links,
|
||||
}
|
||||
|
||||
|
||||
async def generate_agent_patch(
|
||||
update_request: str, current_agent: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Generate a patch to update an existing agent.
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
|
||||
Returns:
|
||||
Patch dict or clarifying questions, or None on error
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = PATCH_PROMPT.format(
|
||||
current_agent=json.dumps(current_agent, indent=2),
|
||||
block_summaries=get_block_summaries(),
|
||||
)
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": update_request},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for patch generation")
|
||||
return None
|
||||
|
||||
return parse_json_from_llm(content)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating patch: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def apply_agent_patch(
|
||||
current_agent: dict[str, Any], patch: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Apply a patch to an existing agent.
|
||||
|
||||
Args:
|
||||
current_agent: Current agent JSON
|
||||
patch: Patch dict with operations
|
||||
|
||||
Returns:
|
||||
Updated agent JSON
|
||||
"""
|
||||
agent = copy.deepcopy(current_agent)
|
||||
patches = patch.get("patches", [])
|
||||
|
||||
for p in patches:
|
||||
patch_type = p.get("type")
|
||||
|
||||
if patch_type == "modify":
|
||||
node_id = p.get("node_id")
|
||||
changes = p.get("changes", {})
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
if node["id"] == node_id:
|
||||
_deep_update(node, changes)
|
||||
logger.debug(f"Modified node {node_id}")
|
||||
break
|
||||
|
||||
elif patch_type == "add":
|
||||
new_nodes = p.get("new_nodes", [])
|
||||
new_links = p.get("new_links", [])
|
||||
|
||||
agent["nodes"] = agent.get("nodes", []) + new_nodes
|
||||
agent["links"] = agent.get("links", []) + new_links
|
||||
logger.debug(f"Added {len(new_nodes)} nodes, {len(new_links)} links")
|
||||
|
||||
elif patch_type == "remove":
|
||||
node_ids_to_remove = set(p.get("node_ids", []))
|
||||
link_ids_to_remove = set(p.get("link_ids", []))
|
||||
|
||||
# Remove nodes
|
||||
agent["nodes"] = [
|
||||
n for n in agent.get("nodes", []) if n["id"] not in node_ids_to_remove
|
||||
]
|
||||
|
||||
# Remove links (both explicit and those referencing removed nodes)
|
||||
agent["links"] = [
|
||||
link
|
||||
for link in agent.get("links", [])
|
||||
if link["id"] not in link_ids_to_remove
|
||||
and link["source_id"] not in node_ids_to_remove
|
||||
and link["sink_id"] not in node_ids_to_remove
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
f"Removed {len(node_ids_to_remove)} nodes, {len(link_ids_to_remove)} links"
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def _deep_update(target: dict, source: dict) -> None:
|
||||
"""Recursively update a dict with another dict."""
|
||||
for key, value in source.items():
|
||||
if key in target and isinstance(target[key], dict) and isinstance(value, dict):
|
||||
_deep_update(target[key], value)
|
||||
else:
|
||||
target[key] = value
|
||||
@@ -1,606 +0,0 @@
|
||||
"""Agent fixer - Fixes common LLM generation errors."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from .utils import (
|
||||
ADDTODICTIONARY_BLOCK_ID,
|
||||
ADDTOLIST_BLOCK_ID,
|
||||
CODE_EXECUTION_BLOCK_ID,
|
||||
CONDITION_BLOCK_ID,
|
||||
CREATEDICT_BLOCK_ID,
|
||||
CREATELIST_BLOCK_ID,
|
||||
DATA_SAMPLING_BLOCK_ID,
|
||||
DOUBLE_CURLY_BRACES_BLOCK_IDS,
|
||||
GET_CURRENT_DATE_BLOCK_ID,
|
||||
STORE_VALUE_BLOCK_ID,
|
||||
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
|
||||
get_blocks_info,
|
||||
is_valid_uuid,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def fix_agent_ids(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix invalid UUIDs in agent and link IDs."""
|
||||
# Fix agent ID
|
||||
if not is_valid_uuid(agent.get("id", "")):
|
||||
agent["id"] = str(uuid.uuid4())
|
||||
logger.debug(f"Fixed agent ID: {agent['id']}")
|
||||
|
||||
# Fix node IDs
|
||||
id_mapping = {} # Old ID -> New ID
|
||||
for node in agent.get("nodes", []):
|
||||
if not is_valid_uuid(node.get("id", "")):
|
||||
old_id = node.get("id", "")
|
||||
new_id = str(uuid.uuid4())
|
||||
id_mapping[old_id] = new_id
|
||||
node["id"] = new_id
|
||||
logger.debug(f"Fixed node ID: {old_id} -> {new_id}")
|
||||
|
||||
# Fix link IDs and update references
|
||||
for link in agent.get("links", []):
|
||||
if not is_valid_uuid(link.get("id", "")):
|
||||
link["id"] = str(uuid.uuid4())
|
||||
logger.debug(f"Fixed link ID: {link['id']}")
|
||||
|
||||
# Update source/sink IDs if they were remapped
|
||||
if link.get("source_id") in id_mapping:
|
||||
link["source_id"] = id_mapping[link["source_id"]]
|
||||
if link.get("sink_id") in id_mapping:
|
||||
link["sink_id"] = id_mapping[link["sink_id"]]
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_double_curly_braces(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix single curly braces to double in template blocks."""
|
||||
for node in agent.get("nodes", []):
|
||||
if node.get("block_id") not in DOUBLE_CURLY_BRACES_BLOCK_IDS:
|
||||
continue
|
||||
|
||||
input_data = node.get("input_default", {})
|
||||
for key in ("prompt", "format"):
|
||||
if key in input_data and isinstance(input_data[key], str):
|
||||
original = input_data[key]
|
||||
# Fix simple variable references: {var} -> {{var}}
|
||||
fixed = re.sub(
|
||||
r"(?<!\{)\{([a-zA-Z_][a-zA-Z0-9_]*)\}(?!\})",
|
||||
r"{{\1}}",
|
||||
original,
|
||||
)
|
||||
if fixed != original:
|
||||
input_data[key] = fixed
|
||||
logger.debug(f"Fixed curly braces in {key}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_storevalue_before_condition(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Add StoreValueBlock before ConditionBlock if needed for value2."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
|
||||
# Find all ConditionBlock nodes
|
||||
condition_node_ids = {
|
||||
node["id"] for node in nodes if node.get("block_id") == CONDITION_BLOCK_ID
|
||||
}
|
||||
|
||||
if not condition_node_ids:
|
||||
return agent
|
||||
|
||||
new_nodes = []
|
||||
new_links = []
|
||||
processed_conditions = set()
|
||||
|
||||
for link in links:
|
||||
sink_id = link.get("sink_id")
|
||||
sink_name = link.get("sink_name")
|
||||
|
||||
# Check if this link goes to a ConditionBlock's value2
|
||||
if sink_id in condition_node_ids and sink_name == "value2":
|
||||
source_node = next(
|
||||
(n for n in nodes if n["id"] == link.get("source_id")), None
|
||||
)
|
||||
|
||||
# Skip if source is already a StoreValueBlock
|
||||
if source_node and source_node.get("block_id") == STORE_VALUE_BLOCK_ID:
|
||||
continue
|
||||
|
||||
# Skip if we already processed this condition
|
||||
if sink_id in processed_conditions:
|
||||
continue
|
||||
|
||||
processed_conditions.add(sink_id)
|
||||
|
||||
# Create StoreValueBlock
|
||||
store_node_id = str(uuid.uuid4())
|
||||
store_node = {
|
||||
"id": store_node_id,
|
||||
"block_id": STORE_VALUE_BLOCK_ID,
|
||||
"input_default": {"data": None},
|
||||
"metadata": {"position": {"x": 0, "y": -100}},
|
||||
}
|
||||
new_nodes.append(store_node)
|
||||
|
||||
# Create link: original source -> StoreValueBlock
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": link["source_id"],
|
||||
"source_name": link["source_name"],
|
||||
"sink_id": store_node_id,
|
||||
"sink_name": "input",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
# Update original link: StoreValueBlock -> ConditionBlock
|
||||
link["source_id"] = store_node_id
|
||||
link["source_name"] = "output"
|
||||
|
||||
logger.debug(f"Added StoreValueBlock before ConditionBlock {sink_id}")
|
||||
|
||||
if new_nodes:
|
||||
agent["nodes"] = nodes + new_nodes
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_addtolist_blocks(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix AddToList blocks by adding prerequisite empty AddToList block.
|
||||
|
||||
When an AddToList block is found:
|
||||
1. Checks if there's a CreateListBlock before it
|
||||
2. Removes CreateListBlock if linked directly to AddToList
|
||||
3. Adds an empty AddToList block before the original
|
||||
4. Ensures the original has a self-referencing link
|
||||
"""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
new_nodes = []
|
||||
original_addtolist_ids = set()
|
||||
nodes_to_remove = set()
|
||||
links_to_remove = []
|
||||
|
||||
# First pass: identify CreateListBlock nodes to remove
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
|
||||
|
||||
if (
|
||||
source_node
|
||||
and sink_node
|
||||
and source_node.get("block_id") == CREATELIST_BLOCK_ID
|
||||
and sink_node.get("block_id") == ADDTOLIST_BLOCK_ID
|
||||
):
|
||||
nodes_to_remove.add(source_node.get("id"))
|
||||
links_to_remove.append(link)
|
||||
logger.debug(f"Removing CreateListBlock {source_node.get('id')}")
|
||||
|
||||
# Second pass: process AddToList blocks
|
||||
filtered_nodes = []
|
||||
for node in nodes:
|
||||
if node.get("id") in nodes_to_remove:
|
||||
continue
|
||||
|
||||
if node.get("block_id") == ADDTOLIST_BLOCK_ID:
|
||||
original_addtolist_ids.add(node.get("id"))
|
||||
node_id = node.get("id")
|
||||
pos = node.get("metadata", {}).get("position", {"x": 0, "y": 0})
|
||||
|
||||
# Check if already has prerequisite
|
||||
has_prereq = any(
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "list"
|
||||
and link.get("source_name") == "updated_list"
|
||||
for link in links
|
||||
)
|
||||
|
||||
if not has_prereq:
|
||||
# Remove links to "list" input (except self-reference)
|
||||
for link in links:
|
||||
if (
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "list"
|
||||
and link.get("source_id") != node_id
|
||||
and link not in links_to_remove
|
||||
):
|
||||
links_to_remove.append(link)
|
||||
|
||||
# Create prerequisite AddToList block
|
||||
prereq_id = str(uuid.uuid4())
|
||||
prereq_node = {
|
||||
"id": prereq_id,
|
||||
"block_id": ADDTOLIST_BLOCK_ID,
|
||||
"input_default": {"list": [], "entry": None, "entries": []},
|
||||
"metadata": {
|
||||
"position": {"x": pos.get("x", 0) - 800, "y": pos.get("y", 0)}
|
||||
},
|
||||
}
|
||||
new_nodes.append(prereq_node)
|
||||
|
||||
# Link prerequisite to original
|
||||
links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": prereq_id,
|
||||
"source_name": "updated_list",
|
||||
"sink_id": node_id,
|
||||
"sink_name": "list",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
logger.debug(f"Added prerequisite AddToList block for {node_id}")
|
||||
|
||||
filtered_nodes.append(node)
|
||||
|
||||
# Remove marked links
|
||||
filtered_links = [link for link in links if link not in links_to_remove]
|
||||
|
||||
# Add self-referencing links for original AddToList blocks
|
||||
for node in filtered_nodes + new_nodes:
|
||||
if (
|
||||
node.get("block_id") == ADDTOLIST_BLOCK_ID
|
||||
and node.get("id") in original_addtolist_ids
|
||||
):
|
||||
node_id = node.get("id")
|
||||
has_self_ref = any(
|
||||
link["source_id"] == node_id
|
||||
and link["sink_id"] == node_id
|
||||
and link["source_name"] == "updated_list"
|
||||
and link["sink_name"] == "list"
|
||||
for link in filtered_links
|
||||
)
|
||||
if not has_self_ref:
|
||||
filtered_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": node_id,
|
||||
"source_name": "updated_list",
|
||||
"sink_id": node_id,
|
||||
"sink_name": "list",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
logger.debug(f"Added self-reference for AddToList {node_id}")
|
||||
|
||||
agent["nodes"] = filtered_nodes + new_nodes
|
||||
agent["links"] = filtered_links
|
||||
return agent
|
||||
|
||||
|
||||
def fix_addtodictionary_blocks(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix AddToDictionary blocks by removing empty CreateDictionary nodes."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
nodes_to_remove = set()
|
||||
links_to_remove = []
|
||||
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
|
||||
|
||||
if (
|
||||
source_node
|
||||
and sink_node
|
||||
and source_node.get("block_id") == CREATEDICT_BLOCK_ID
|
||||
and sink_node.get("block_id") == ADDTODICTIONARY_BLOCK_ID
|
||||
):
|
||||
nodes_to_remove.add(source_node.get("id"))
|
||||
links_to_remove.append(link)
|
||||
logger.debug(f"Removing CreateDictionary {source_node.get('id')}")
|
||||
|
||||
agent["nodes"] = [n for n in nodes if n.get("id") not in nodes_to_remove]
|
||||
agent["links"] = [link for link in links if link not in links_to_remove]
|
||||
return agent
|
||||
|
||||
|
||||
def fix_code_execution_output(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix CodeExecutionBlock output: change 'response' to 'stdout_logs'."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
if (
|
||||
source_node
|
||||
and source_node.get("block_id") == CODE_EXECUTION_BLOCK_ID
|
||||
and link.get("source_name") == "response"
|
||||
):
|
||||
link["source_name"] = "stdout_logs"
|
||||
logger.debug("Fixed CodeExecutionBlock output: response -> stdout_logs")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_data_sampling_sample_size(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix DataSamplingBlock by setting sample_size to 1 as default."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
links_to_remove = []
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") == DATA_SAMPLING_BLOCK_ID:
|
||||
node_id = node.get("id")
|
||||
input_default = node.get("input_default", {})
|
||||
|
||||
# Remove links to sample_size
|
||||
for link in links:
|
||||
if (
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "sample_size"
|
||||
):
|
||||
links_to_remove.append(link)
|
||||
|
||||
# Set default
|
||||
input_default["sample_size"] = 1
|
||||
node["input_default"] = input_default
|
||||
logger.debug(f"Fixed DataSamplingBlock {node_id} sample_size to 1")
|
||||
|
||||
if links_to_remove:
|
||||
agent["links"] = [link for link in links if link not in links_to_remove]
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_node_x_coordinates(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix node x-coordinates to ensure 800+ unit spacing between linked nodes."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
node_lookup = {n.get("id"): n for n in nodes}
|
||||
|
||||
for link in links:
|
||||
source_id = link.get("source_id")
|
||||
sink_id = link.get("sink_id")
|
||||
|
||||
source_node = node_lookup.get(source_id)
|
||||
sink_node = node_lookup.get(sink_id)
|
||||
|
||||
if not source_node or not sink_node:
|
||||
continue
|
||||
|
||||
source_pos = source_node.get("metadata", {}).get("position", {})
|
||||
sink_pos = sink_node.get("metadata", {}).get("position", {})
|
||||
|
||||
source_x = source_pos.get("x", 0)
|
||||
sink_x = sink_pos.get("x", 0)
|
||||
|
||||
if abs(sink_x - source_x) < 800:
|
||||
new_x = source_x + 800
|
||||
if "metadata" not in sink_node:
|
||||
sink_node["metadata"] = {}
|
||||
if "position" not in sink_node["metadata"]:
|
||||
sink_node["metadata"]["position"] = {}
|
||||
sink_node["metadata"]["position"]["x"] = new_x
|
||||
logger.debug(f"Fixed node {sink_id} x: {sink_x} -> {new_x}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_getcurrentdate_offset(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix GetCurrentDateBlock offset to ensure it's positive."""
|
||||
for node in agent.get("nodes", []):
|
||||
if node.get("block_id") == GET_CURRENT_DATE_BLOCK_ID:
|
||||
input_default = node.get("input_default", {})
|
||||
if "offset" in input_default:
|
||||
offset = input_default["offset"]
|
||||
if isinstance(offset, (int, float)) and offset < 0:
|
||||
input_default["offset"] = abs(offset)
|
||||
logger.debug(f"Fixed offset: {offset} -> {abs(offset)}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_ai_model_parameter(
|
||||
agent: dict[str, Any],
|
||||
blocks_info: list[dict[str, Any]],
|
||||
default_model: str = "gpt-4o",
|
||||
) -> dict[str, Any]:
|
||||
"""Add default model parameter to AI blocks if missing."""
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
block = block_map.get(block_id)
|
||||
|
||||
if not block:
|
||||
continue
|
||||
|
||||
# Check if block has AI category
|
||||
categories = block.get("categories", [])
|
||||
is_ai_block = any(
|
||||
cat.get("category") == "AI" for cat in categories if isinstance(cat, dict)
|
||||
)
|
||||
|
||||
if is_ai_block:
|
||||
input_default = node.get("input_default", {})
|
||||
if "model" not in input_default:
|
||||
input_default["model"] = default_model
|
||||
node["input_default"] = input_default
|
||||
logger.debug(
|
||||
f"Added model '{default_model}' to AI block {node.get('id')}"
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_link_static_properties(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Fix is_static property based on source block's staticOutput."""
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
if not source_node:
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
if not source_block:
|
||||
continue
|
||||
|
||||
static_output = source_block.get("staticOutput", False)
|
||||
if link.get("is_static") != static_output:
|
||||
link["is_static"] = static_output
|
||||
logger.debug(f"Fixed link {link.get('id')} is_static to {static_output}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_data_type_mismatch(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Fix data type mismatches by inserting UniversalTypeConverterBlock."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in nodes}
|
||||
|
||||
def get_property_type(schema: dict, name: str) -> str | None:
|
||||
if "_#_" in name:
|
||||
parent, child = name.split("_#_", 1)
|
||||
parent_schema = schema.get(parent, {})
|
||||
if "properties" in parent_schema:
|
||||
return parent_schema["properties"].get(child, {}).get("type")
|
||||
return None
|
||||
return schema.get(name, {}).get("type")
|
||||
|
||||
def are_types_compatible(src: str, sink: str) -> bool:
|
||||
if {src, sink} <= {"integer", "number"}:
|
||||
return True
|
||||
return src == sink
|
||||
|
||||
type_mapping = {
|
||||
"string": "string",
|
||||
"text": "string",
|
||||
"integer": "number",
|
||||
"number": "number",
|
||||
"float": "number",
|
||||
"boolean": "boolean",
|
||||
"bool": "boolean",
|
||||
"array": "list",
|
||||
"list": "list",
|
||||
"object": "dictionary",
|
||||
"dict": "dictionary",
|
||||
"dictionary": "dictionary",
|
||||
}
|
||||
|
||||
new_links = []
|
||||
nodes_to_add = []
|
||||
|
||||
for link in links:
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
|
||||
if not source_node or not sink_node:
|
||||
new_links.append(link)
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
sink_block = block_map.get(sink_node.get("block_id"))
|
||||
|
||||
if not source_block or not sink_block:
|
||||
new_links.append(link)
|
||||
continue
|
||||
|
||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
||||
|
||||
source_type = get_property_type(source_outputs, link.get("source_name", ""))
|
||||
sink_type = get_property_type(sink_inputs, link.get("sink_name", ""))
|
||||
|
||||
if (
|
||||
source_type
|
||||
and sink_type
|
||||
and not are_types_compatible(source_type, sink_type)
|
||||
):
|
||||
# Insert type converter
|
||||
converter_id = str(uuid.uuid4())
|
||||
target_type = type_mapping.get(sink_type, sink_type)
|
||||
|
||||
converter_node = {
|
||||
"id": converter_id,
|
||||
"block_id": UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
|
||||
"input_default": {"type": target_type},
|
||||
"metadata": {"position": {"x": 0, "y": 100}},
|
||||
}
|
||||
nodes_to_add.append(converter_node)
|
||||
|
||||
# source -> converter
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": link["source_id"],
|
||||
"source_name": link["source_name"],
|
||||
"sink_id": converter_id,
|
||||
"sink_name": "value",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
# converter -> sink
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": converter_id,
|
||||
"source_name": "value",
|
||||
"sink_id": link["sink_id"],
|
||||
"sink_name": link["sink_name"],
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(f"Inserted type converter: {source_type} -> {target_type}")
|
||||
else:
|
||||
new_links.append(link)
|
||||
|
||||
if nodes_to_add:
|
||||
agent["nodes"] = nodes + nodes_to_add
|
||||
agent["links"] = new_links
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def apply_all_fixes(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Apply all fixes to an agent JSON.
|
||||
|
||||
Args:
|
||||
agent: Agent JSON dict
|
||||
blocks_info: Optional list of block info dicts for advanced fixes
|
||||
|
||||
Returns:
|
||||
Fixed agent JSON
|
||||
"""
|
||||
# Basic fixes (no block info needed)
|
||||
agent = fix_agent_ids(agent)
|
||||
agent = fix_double_curly_braces(agent)
|
||||
agent = fix_storevalue_before_condition(agent)
|
||||
agent = fix_addtolist_blocks(agent)
|
||||
agent = fix_addtodictionary_blocks(agent)
|
||||
agent = fix_code_execution_output(agent)
|
||||
agent = fix_data_sampling_sample_size(agent)
|
||||
agent = fix_node_x_coordinates(agent)
|
||||
agent = fix_getcurrentdate_offset(agent)
|
||||
|
||||
# Advanced fixes (require block info)
|
||||
if blocks_info is None:
|
||||
blocks_info = get_blocks_info()
|
||||
|
||||
agent = fix_ai_model_parameter(agent, blocks_info)
|
||||
agent = fix_link_static_properties(agent, blocks_info)
|
||||
agent = fix_data_type_mismatch(agent, blocks_info)
|
||||
|
||||
return agent
|
||||
@@ -1,225 +0,0 @@
|
||||
"""Prompt templates for agent generation."""
|
||||
|
||||
DECOMPOSITION_PROMPT = """
|
||||
You are an expert AutoGPT Workflow Decomposer. Your task is to analyze a user's high-level goal and break it down into a clear, step-by-step plan using the available blocks.
|
||||
|
||||
Each step should represent a distinct, automatable action suitable for execution by an AI automation system.
|
||||
|
||||
---
|
||||
|
||||
FIRST: Analyze the user's goal and determine:
|
||||
1) Design-time configuration (fixed settings that won't change per run)
|
||||
2) Runtime inputs (values the agent's end-user will provide each time it runs)
|
||||
|
||||
For anything that can vary per run (email addresses, names, dates, search terms, etc.):
|
||||
- DO NOT ask for the actual value
|
||||
- Instead, define it as an Agent Input with a clear name, type, and description
|
||||
|
||||
Only ask clarifying questions about design-time config that affects how you build the workflow:
|
||||
- Which external service to use (e.g., "Gmail vs Outlook", "Notion vs Google Docs")
|
||||
- Required formats or structures (e.g., "CSV, JSON, or PDF output?")
|
||||
- Business rules that must be hard-coded
|
||||
|
||||
IMPORTANT CLARIFICATIONS POLICY:
|
||||
- Ask no more than five essential questions
|
||||
- Do not ask for concrete values that can be provided at runtime as Agent Inputs
|
||||
- Do not ask for API keys or credentials; the platform handles those directly
|
||||
- If there is enough information to infer reasonable defaults, prefer to propose defaults
|
||||
|
||||
---
|
||||
|
||||
GUIDELINES:
|
||||
1. List each step as a numbered item
|
||||
2. Describe the action clearly and specify inputs/outputs
|
||||
3. Ensure steps are in logical, sequential order
|
||||
4. Mention block names naturally (e.g., "Use GetWeatherByLocationBlock to...")
|
||||
5. Help the user reach their goal efficiently
|
||||
|
||||
---
|
||||
|
||||
RULES:
|
||||
1. OUTPUT FORMAT: Only output either clarifying questions OR step-by-step instructions, not both
|
||||
2. USE ONLY THE BLOCKS PROVIDED
|
||||
3. ALL required_input fields must be provided
|
||||
4. Data types of linked properties must match
|
||||
5. Write expert-level prompts for AI-related blocks
|
||||
|
||||
---
|
||||
|
||||
CRITICAL BLOCK RESTRICTIONS:
|
||||
1. AddToListBlock: Outputs updated list EVERY addition, not after all additions
|
||||
2. SendEmailBlock: Draft the email for user review; set SMTP config based on email type
|
||||
3. ConditionBlock: value2 is reference, value1 is contrast
|
||||
4. CodeExecutionBlock: DO NOT USE - use AI blocks instead
|
||||
5. ReadCsvBlock: Only use the 'rows' output, not 'row'
|
||||
|
||||
---
|
||||
|
||||
OUTPUT FORMAT:
|
||||
|
||||
If more information is needed:
|
||||
```json
|
||||
{{
|
||||
"type": "clarifying_questions",
|
||||
"questions": [
|
||||
{{
|
||||
"question": "Which email provider should be used? (Gmail, Outlook, custom SMTP)",
|
||||
"keyword": "email_provider",
|
||||
"example": "Gmail"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
If ready to proceed:
|
||||
```json
|
||||
{{
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{{
|
||||
"step_number": 1,
|
||||
"block_name": "AgentShortTextInputBlock",
|
||||
"description": "Get the URL of the content to analyze.",
|
||||
"inputs": [{{"name": "name", "value": "URL"}}],
|
||||
"outputs": [{{"name": "result", "description": "The URL entered by user"}}]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
AVAILABLE BLOCKS:
|
||||
{block_summaries}
|
||||
"""
|
||||
|
||||
GENERATION_PROMPT = """
|
||||
You are an expert AI workflow builder. Generate a valid agent JSON from the given instructions.
|
||||
|
||||
---
|
||||
|
||||
NODES:
|
||||
Each node must include:
|
||||
- `id`: Unique UUID v4 (e.g. `a8f5b1e2-c3d4-4e5f-8a9b-0c1d2e3f4a5b`)
|
||||
- `block_id`: The block identifier (must match an Allowed Block)
|
||||
- `input_default`: Dict of inputs (can be empty if no static inputs needed)
|
||||
- `metadata`: Must contain:
|
||||
- `position`: {{"x": number, "y": number}} - adjacent nodes should differ by 800+ in X
|
||||
- `customized_name`: Clear name describing this block's purpose in the workflow
|
||||
|
||||
---
|
||||
|
||||
LINKS:
|
||||
Each link connects a source node's output to a sink node's input:
|
||||
- `id`: MUST be UUID v4 (NOT "link-1", "link-2", etc.)
|
||||
- `source_id`: ID of the source node
|
||||
- `source_name`: Output field name from the source block
|
||||
- `sink_id`: ID of the sink node
|
||||
- `sink_name`: Input field name on the sink block
|
||||
- `is_static`: true only if source block has static_output: true
|
||||
|
||||
CRITICAL: All IDs must be valid UUID v4 format!
|
||||
|
||||
---
|
||||
|
||||
AGENT (GRAPH):
|
||||
Wrap nodes and links in:
|
||||
- `id`: UUID of the agent
|
||||
- `name`: Short, generic name (avoid specific company names, URLs)
|
||||
- `description`: Short, generic description
|
||||
- `nodes`: List of all nodes
|
||||
- `links`: List of all links
|
||||
- `version`: 1
|
||||
- `is_active`: true
|
||||
|
||||
---
|
||||
|
||||
TIPS:
|
||||
- All required_input fields must be provided via input_default or a valid link
|
||||
- Ensure consistent source_id and sink_id references
|
||||
- Avoid dangling links
|
||||
- Input/output pins must match block schemas
|
||||
- Do not invent unknown block_ids
|
||||
|
||||
---
|
||||
|
||||
ALLOWED BLOCKS:
|
||||
{block_summaries}
|
||||
|
||||
---
|
||||
|
||||
Generate the complete agent JSON. Output ONLY valid JSON, no explanation.
|
||||
"""
|
||||
|
||||
PATCH_PROMPT = """
|
||||
You are an expert at modifying AutoGPT agent workflows. Given the current agent and a modification request, generate a JSON patch to update the agent.
|
||||
|
||||
CURRENT AGENT:
|
||||
{current_agent}
|
||||
|
||||
AVAILABLE BLOCKS:
|
||||
{block_summaries}
|
||||
|
||||
---
|
||||
|
||||
PATCH FORMAT:
|
||||
Return a JSON object with the following structure:
|
||||
|
||||
```json
|
||||
{{
|
||||
"type": "patch",
|
||||
"intent": "Brief description of what the patch does",
|
||||
"patches": [
|
||||
{{
|
||||
"type": "modify",
|
||||
"node_id": "uuid-of-node-to-modify",
|
||||
"changes": {{
|
||||
"input_default": {{"field": "new_value"}},
|
||||
"metadata": {{"customized_name": "New Name"}}
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"type": "add",
|
||||
"new_nodes": [
|
||||
{{
|
||||
"id": "new-uuid",
|
||||
"block_id": "block-uuid",
|
||||
"input_default": {{}},
|
||||
"metadata": {{"position": {{"x": 0, "y": 0}}, "customized_name": "Name"}}
|
||||
}}
|
||||
],
|
||||
"new_links": [
|
||||
{{
|
||||
"id": "link-uuid",
|
||||
"source_id": "source-node-id",
|
||||
"source_name": "output_field",
|
||||
"sink_id": "sink-node-id",
|
||||
"sink_name": "input_field"
|
||||
}}
|
||||
]
|
||||
}},
|
||||
{{
|
||||
"type": "remove",
|
||||
"node_ids": ["uuid-of-node-to-remove"],
|
||||
"link_ids": ["uuid-of-link-to-remove"]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
If you need more information, return:
|
||||
```json
|
||||
{{
|
||||
"type": "clarifying_questions",
|
||||
"questions": [
|
||||
{{
|
||||
"question": "What specific change do you want?",
|
||||
"keyword": "change_type",
|
||||
"example": "Add error handling"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
Generate the minimal patch needed. Output ONLY valid JSON.
|
||||
"""
|
||||
@@ -1,213 +0,0 @@
|
||||
"""Utilities for agent generation."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
# UUID validation regex
|
||||
UUID_REGEX = re.compile(
|
||||
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$"
|
||||
)
|
||||
|
||||
# Block IDs for various fixes
|
||||
STORE_VALUE_BLOCK_ID = "1ff065e9-88e8-4358-9d82-8dc91f622ba9"
|
||||
CONDITION_BLOCK_ID = "715696a0-e1da-45c8-b209-c2fa9c3b0be6"
|
||||
ADDTOLIST_BLOCK_ID = "aeb08fc1-2fc1-4141-bc8e-f758f183a822"
|
||||
ADDTODICTIONARY_BLOCK_ID = "31d1064e-7446-4693-a7d4-65e5ca1180d1"
|
||||
CREATELIST_BLOCK_ID = "a912d5c7-6e00-4542-b2a9-8034136930e4"
|
||||
CREATEDICT_BLOCK_ID = "b924ddf4-de4f-4b56-9a85-358930dcbc91"
|
||||
CODE_EXECUTION_BLOCK_ID = "0b02b072-abe7-11ef-8372-fb5d162dd712"
|
||||
DATA_SAMPLING_BLOCK_ID = "4a448883-71fa-49cf-91cf-70d793bd7d87"
|
||||
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID = "95d1b990-ce13-4d88-9737-ba5c2070c97b"
|
||||
GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
|
||||
|
||||
DOUBLE_CURLY_BRACES_BLOCK_IDS = [
|
||||
"44f6c8ad-d75c-4ae1-8209-aad1c0326928", # FillTextTemplateBlock
|
||||
"6ab085e2-20b3-4055-bc3e-08036e01eca6",
|
||||
"90f8c45e-e983-4644-aa0b-b4ebe2f531bc",
|
||||
"363ae599-353e-4804-937e-b2ee3cef3da4", # AgentOutputBlock
|
||||
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||
"db7d8f02-2f44-4c55-ab7a-eae0941f0c30",
|
||||
"3a7c4b8d-6e2f-4a5d-b9c1-f8d23c5a9b0e",
|
||||
"ed1ae7a0-b770-4089-b520-1f0005fad19a",
|
||||
"a892b8d9-3e4e-4e9c-9c1e-75f8efcf1bfa",
|
||||
"b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1",
|
||||
"716a67b3-6760-42e7-86dc-18645c6e00fc",
|
||||
"530cf046-2ce0-4854-ae2c-659db17c7a46",
|
||||
"ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
"1f292d4a-41a4-4977-9684-7c8d560b9f91", # LLM blocks
|
||||
"32a87eab-381e-4dd4-bdb8-4c47151be35a",
|
||||
]
|
||||
|
||||
|
||||
def is_valid_uuid(value: str) -> bool:
|
||||
"""Check if a string is a valid UUID v4."""
|
||||
return isinstance(value, str) and UUID_REGEX.match(value) is not None
|
||||
|
||||
|
||||
def _compact_schema(schema: dict) -> dict[str, str]:
|
||||
"""Extract compact type info from a JSON schema properties dict.
|
||||
|
||||
Returns a dict of {field_name: type_string} for essential info only.
|
||||
"""
|
||||
props = schema.get("properties", {})
|
||||
result = {}
|
||||
|
||||
for name, prop in props.items():
|
||||
# Skip internal/complex fields
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
# Get type string
|
||||
type_str = prop.get("type", "any")
|
||||
|
||||
# Handle anyOf/oneOf (optional types)
|
||||
if "anyOf" in prop:
|
||||
types = [t.get("type", "?") for t in prop["anyOf"] if t.get("type")]
|
||||
type_str = "|".join(types) if types else "any"
|
||||
elif "allOf" in prop:
|
||||
type_str = "object"
|
||||
|
||||
# Add array item type if present
|
||||
if type_str == "array" and "items" in prop:
|
||||
items = prop["items"]
|
||||
if isinstance(items, dict):
|
||||
item_type = items.get("type", "any")
|
||||
type_str = f"array[{item_type}]"
|
||||
|
||||
result[name] = type_str
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_block_summaries(include_schemas: bool = True) -> str:
|
||||
"""Generate compact block summaries for prompts.
|
||||
|
||||
Args:
|
||||
include_schemas: Whether to include input/output type info
|
||||
|
||||
Returns:
|
||||
Formatted string of block summaries (compact format)
|
||||
"""
|
||||
blocks = get_blocks()
|
||||
summaries = []
|
||||
|
||||
for block_id, block_cls in blocks.items():
|
||||
block = block_cls()
|
||||
name = block.name
|
||||
desc = getattr(block, "description", "") or ""
|
||||
|
||||
# Truncate description
|
||||
if len(desc) > 150:
|
||||
desc = desc[:147] + "..."
|
||||
|
||||
if not include_schemas:
|
||||
summaries.append(f"- {name} (id: {block_id}): {desc}")
|
||||
else:
|
||||
# Compact format with type info only
|
||||
inputs = {}
|
||||
outputs = {}
|
||||
required = []
|
||||
|
||||
if hasattr(block, "input_schema"):
|
||||
try:
|
||||
schema = block.input_schema.jsonschema()
|
||||
inputs = _compact_schema(schema)
|
||||
required = schema.get("required", [])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if hasattr(block, "output_schema"):
|
||||
try:
|
||||
schema = block.output_schema.jsonschema()
|
||||
outputs = _compact_schema(schema)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Build compact line format
|
||||
# Format: NAME (id): desc | in: {field:type, ...} [required] | out: {field:type}
|
||||
in_str = ", ".join(f"{k}:{v}" for k, v in inputs.items())
|
||||
out_str = ", ".join(f"{k}:{v}" for k, v in outputs.items())
|
||||
req_str = f" req=[{','.join(required)}]" if required else ""
|
||||
|
||||
static = " [static]" if getattr(block, "static_output", False) else ""
|
||||
|
||||
line = f"- {name} (id: {block_id}): {desc}"
|
||||
if in_str:
|
||||
line += f"\n in: {{{in_str}}}{req_str}"
|
||||
if out_str:
|
||||
line += f"\n out: {{{out_str}}}{static}"
|
||||
|
||||
summaries.append(line)
|
||||
|
||||
return "\n".join(summaries)
|
||||
|
||||
|
||||
def get_blocks_info() -> list[dict[str, Any]]:
|
||||
"""Get block information with schemas for validation and fixing."""
|
||||
blocks = get_blocks()
|
||||
blocks_info = []
|
||||
for block_id, block_cls in blocks.items():
|
||||
block = block_cls()
|
||||
blocks_info.append(
|
||||
{
|
||||
"id": block_id,
|
||||
"name": block.name,
|
||||
"description": getattr(block, "description", ""),
|
||||
"categories": getattr(block, "categories", []),
|
||||
"staticOutput": getattr(block, "static_output", False),
|
||||
"inputSchema": (
|
||||
block.input_schema.jsonschema()
|
||||
if hasattr(block, "input_schema")
|
||||
else {}
|
||||
),
|
||||
"outputSchema": (
|
||||
block.output_schema.jsonschema()
|
||||
if hasattr(block, "output_schema")
|
||||
else {}
|
||||
),
|
||||
}
|
||||
)
|
||||
return blocks_info
|
||||
|
||||
|
||||
def parse_json_from_llm(text: str) -> dict[str, Any] | None:
|
||||
"""Extract JSON from LLM response (handles markdown code blocks)."""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
# Try fenced code block
|
||||
match = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(match.group(1).strip())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try raw text
|
||||
try:
|
||||
return json.loads(text.strip())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try finding {...} span
|
||||
start = text.find("{")
|
||||
end = text.rfind("}")
|
||||
if start != -1 and end > start:
|
||||
try:
|
||||
return json.loads(text[start : end + 1])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try finding [...] span
|
||||
start = text.find("[")
|
||||
end = text.rfind("]")
|
||||
if start != -1 and end > start:
|
||||
try:
|
||||
return json.loads(text[start : end + 1])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return None
|
||||
@@ -1,279 +0,0 @@
|
||||
"""Agent validator - Validates agent structure and connections."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from .utils import get_blocks_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentValidator:
|
||||
"""Validator for AutoGPT agents with detailed error reporting."""
|
||||
|
||||
def __init__(self):
|
||||
self.errors: list[str] = []
|
||||
|
||||
def add_error(self, error: str) -> None:
|
||||
"""Add an error message."""
|
||||
self.errors.append(error)
|
||||
|
||||
def validate_block_existence(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate all block IDs exist in the blocks library."""
|
||||
valid = True
|
||||
valid_block_ids = {b.get("id") for b in blocks_info if b.get("id")}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
node_id = node.get("id")
|
||||
|
||||
if not block_id:
|
||||
self.add_error(f"Node '{node_id}' is missing 'block_id' field.")
|
||||
valid = False
|
||||
continue
|
||||
|
||||
if block_id not in valid_block_ids:
|
||||
self.add_error(
|
||||
f"Node '{node_id}' references block_id '{block_id}' which does not exist."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_link_node_references(self, agent: dict[str, Any]) -> bool:
|
||||
"""Validate all node IDs referenced in links exist."""
|
||||
valid = True
|
||||
valid_node_ids = {n.get("id") for n in agent.get("nodes", []) if n.get("id")}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
link_id = link.get("id", "Unknown")
|
||||
source_id = link.get("source_id")
|
||||
sink_id = link.get("sink_id")
|
||||
|
||||
if not source_id:
|
||||
self.add_error(f"Link '{link_id}' is missing 'source_id'.")
|
||||
valid = False
|
||||
elif source_id not in valid_node_ids:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' references non-existent source_id '{source_id}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
if not sink_id:
|
||||
self.add_error(f"Link '{link_id}' is missing 'sink_id'.")
|
||||
valid = False
|
||||
elif sink_id not in valid_node_ids:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' references non-existent sink_id '{sink_id}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_required_inputs(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate required inputs are provided."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
block = block_map.get(block_id)
|
||||
|
||||
if not block:
|
||||
continue
|
||||
|
||||
required_inputs = block.get("inputSchema", {}).get("required", [])
|
||||
input_defaults = node.get("input_default", {})
|
||||
node_id = node.get("id")
|
||||
|
||||
# Get linked inputs
|
||||
linked_inputs = {
|
||||
link["sink_name"]
|
||||
for link in agent.get("links", [])
|
||||
if link.get("sink_id") == node_id
|
||||
}
|
||||
|
||||
for req_input in required_inputs:
|
||||
if (
|
||||
req_input not in input_defaults
|
||||
and req_input not in linked_inputs
|
||||
and req_input != "credentials"
|
||||
):
|
||||
block_name = block.get("name", "Unknown Block")
|
||||
self.add_error(
|
||||
f"Node '{node_id}' ({block_name}) is missing required input '{req_input}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_data_type_compatibility(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate linked data types are compatible."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
def get_type(schema: dict, name: str) -> str | None:
|
||||
if "_#_" in name:
|
||||
parent, child = name.split("_#_", 1)
|
||||
parent_schema = schema.get(parent, {})
|
||||
if "properties" in parent_schema:
|
||||
return parent_schema["properties"].get(child, {}).get("type")
|
||||
return None
|
||||
return schema.get(name, {}).get("type")
|
||||
|
||||
def are_compatible(src: str, sink: str) -> bool:
|
||||
if {src, sink} <= {"integer", "number"}:
|
||||
return True
|
||||
return src == sink
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
|
||||
if not source_node or not sink_node:
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
sink_block = block_map.get(sink_node.get("block_id"))
|
||||
|
||||
if not source_block or not sink_block:
|
||||
continue
|
||||
|
||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
||||
|
||||
source_type = get_type(source_outputs, link.get("source_name", ""))
|
||||
sink_type = get_type(sink_inputs, link.get("sink_name", ""))
|
||||
|
||||
if source_type and sink_type and not are_compatible(source_type, sink_type):
|
||||
self.add_error(
|
||||
f"Type mismatch: {source_block.get('name')} output '{link['source_name']}' "
|
||||
f"({source_type}) -> {sink_block.get('name')} input '{link['sink_name']}' ({sink_type})."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_nested_sink_links(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate nested sink links (with _#_ notation)."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
sink_name = link.get("sink_name", "")
|
||||
|
||||
if "_#_" in sink_name:
|
||||
parent, child = sink_name.split("_#_", 1)
|
||||
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
if not sink_node:
|
||||
continue
|
||||
|
||||
block = block_map.get(sink_node.get("block_id"))
|
||||
if not block:
|
||||
continue
|
||||
|
||||
input_props = block.get("inputSchema", {}).get("properties", {})
|
||||
parent_schema = input_props.get(parent)
|
||||
|
||||
if not parent_schema:
|
||||
self.add_error(
|
||||
f"Invalid nested link '{sink_name}': parent '{parent}' not found."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
if not parent_schema.get("additionalProperties"):
|
||||
if not (
|
||||
isinstance(parent_schema, dict)
|
||||
and "properties" in parent_schema
|
||||
and child in parent_schema.get("properties", {})
|
||||
):
|
||||
self.add_error(
|
||||
f"Invalid nested link '{sink_name}': child '{child}' not found in '{parent}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_prompt_spaces(self, agent: dict[str, Any]) -> bool:
|
||||
"""Validate prompts don't have spaces in template variables."""
|
||||
valid = True
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
input_default = node.get("input_default", {})
|
||||
prompt = input_default.get("prompt", "")
|
||||
|
||||
if not isinstance(prompt, str):
|
||||
continue
|
||||
|
||||
# Find {{...}} with spaces
|
||||
matches = re.finditer(r"\{\{([^}]+)\}\}", prompt)
|
||||
for match in matches:
|
||||
content = match.group(1)
|
||||
if " " in content:
|
||||
self.add_error(
|
||||
f"Node '{node.get('id')}' has spaces in template variable: "
|
||||
f"'{{{{{content}}}}}' should be '{{{{{content.replace(' ', '_')}}}}}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Run all validations.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
self.errors = []
|
||||
|
||||
if blocks_info is None:
|
||||
blocks_info = get_blocks_info()
|
||||
|
||||
checks = [
|
||||
self.validate_block_existence(agent, blocks_info),
|
||||
self.validate_link_node_references(agent),
|
||||
self.validate_required_inputs(agent, blocks_info),
|
||||
self.validate_data_type_compatibility(agent, blocks_info),
|
||||
self.validate_nested_sink_links(agent, blocks_info),
|
||||
self.validate_prompt_spaces(agent),
|
||||
]
|
||||
|
||||
all_passed = all(checks)
|
||||
|
||||
if all_passed:
|
||||
logger.info("Agent validation successful")
|
||||
return True, None
|
||||
|
||||
error_message = "Agent validation failed:\n"
|
||||
for i, error in enumerate(self.errors, 1):
|
||||
error_message += f"{i}. {error}\n"
|
||||
|
||||
logger.warning(f"Agent validation failed with {len(self.errors)} errors")
|
||||
return False, error_message
|
||||
|
||||
|
||||
def validate_agent(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Convenience function to validate an agent.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
validator = AgentValidator()
|
||||
return validator.validate(agent, blocks_info)
|
||||
@@ -1,455 +0,0 @@
|
||||
"""Tool for retrieving agent execution outputs from user's library."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentOutputResponse,
|
||||
ErrorResponse,
|
||||
ExecutionOutputInfo,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .utils import fetch_graph_from_store_slug
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentOutputInput(BaseModel):
|
||||
"""Input parameters for the agent_output tool."""
|
||||
|
||||
agent_name: str = ""
|
||||
library_agent_id: str = ""
|
||||
store_slug: str = ""
|
||||
execution_id: str = ""
|
||||
run_time: str = "latest"
|
||||
|
||||
@field_validator(
|
||||
"agent_name",
|
||||
"library_agent_id",
|
||||
"store_slug",
|
||||
"execution_id",
|
||||
"run_time",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def strip_strings(cls, v: Any) -> Any:
|
||||
"""Strip whitespace from string fields."""
|
||||
return v.strip() if isinstance(v, str) else v
|
||||
|
||||
|
||||
def parse_time_expression(
|
||||
time_expr: str | None,
|
||||
) -> tuple[datetime | None, datetime | None]:
|
||||
"""
|
||||
Parse time expression into datetime range (start, end).
|
||||
|
||||
Supports:
|
||||
- "latest" or None -> returns (None, None) to get most recent
|
||||
- "yesterday" -> 24h window for yesterday
|
||||
- "today" -> Today from midnight
|
||||
- "last week" / "last 7 days" -> 7 day window
|
||||
- "last month" / "last 30 days" -> 30 day window
|
||||
- ISO date "YYYY-MM-DD" -> 24h window for that date
|
||||
"""
|
||||
if not time_expr or time_expr.lower() == "latest":
|
||||
return None, None
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
expr = time_expr.lower().strip()
|
||||
|
||||
# Relative expressions
|
||||
if expr == "yesterday":
|
||||
end = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
start = end - timedelta(days=1)
|
||||
return start, end
|
||||
|
||||
if expr in ("last week", "last 7 days"):
|
||||
return now - timedelta(days=7), now
|
||||
|
||||
if expr in ("last month", "last 30 days"):
|
||||
return now - timedelta(days=30), now
|
||||
|
||||
if expr == "today":
|
||||
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
return start, now
|
||||
|
||||
# Try ISO date format (YYYY-MM-DD)
|
||||
date_match = re.match(r"^(\d{4})-(\d{2})-(\d{2})$", expr)
|
||||
if date_match:
|
||||
year, month, day = map(int, date_match.groups())
|
||||
start = datetime(year, month, day, 0, 0, 0, tzinfo=timezone.utc)
|
||||
end = start + timedelta(days=1)
|
||||
return start, end
|
||||
|
||||
# Try ISO datetime
|
||||
try:
|
||||
parsed = datetime.fromisoformat(expr.replace("Z", "+00:00"))
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
# Return +/- 1 hour window around the specified time
|
||||
return parsed - timedelta(hours=1), parsed + timedelta(hours=1)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Fallback: treat as "latest"
|
||||
return None, None
|
||||
|
||||
|
||||
class AgentOutputTool(BaseTool):
|
||||
"""Tool for retrieving execution outputs from user's library agents."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "agent_output"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Retrieve execution outputs from agents in the user's library.
|
||||
|
||||
Identify the agent using one of:
|
||||
- agent_name: Fuzzy search in user's library
|
||||
- library_agent_id: Exact library agent ID
|
||||
- store_slug: Marketplace format 'username/agent-name'
|
||||
|
||||
Select which run to retrieve using:
|
||||
- execution_id: Specific execution ID
|
||||
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
|
||||
"""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_name": {
|
||||
"type": "string",
|
||||
"description": "Agent name to search for in user's library (fuzzy match)",
|
||||
},
|
||||
"library_agent_id": {
|
||||
"type": "string",
|
||||
"description": "Exact library agent ID",
|
||||
},
|
||||
"store_slug": {
|
||||
"type": "string",
|
||||
"description": "Marketplace identifier: 'username/agent-slug'",
|
||||
},
|
||||
"execution_id": {
|
||||
"type": "string",
|
||||
"description": "Specific execution ID to retrieve",
|
||||
},
|
||||
"run_time": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _resolve_agent(
|
||||
self,
|
||||
user_id: str,
|
||||
agent_name: str | None,
|
||||
library_agent_id: str | None,
|
||||
store_slug: str | None,
|
||||
) -> tuple[LibraryAgent | None, str | None]:
|
||||
"""
|
||||
Resolve agent from provided identifiers.
|
||||
Returns (library_agent, error_message).
|
||||
"""
|
||||
# Priority 1: Exact library agent ID
|
||||
if library_agent_id:
|
||||
try:
|
||||
agent = await library_db.get_library_agent(library_agent_id, user_id)
|
||||
return agent, None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get library agent by ID: {e}")
|
||||
return None, f"Library agent '{library_agent_id}' not found"
|
||||
|
||||
# Priority 2: Store slug (username/agent-name)
|
||||
if store_slug and "/" in store_slug:
|
||||
username, agent_slug = store_slug.split("/", 1)
|
||||
graph, _ = await fetch_graph_from_store_slug(username, agent_slug)
|
||||
if not graph:
|
||||
return None, f"Agent '{store_slug}' not found in marketplace"
|
||||
|
||||
# Find in user's library by graph_id
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
|
||||
if not agent:
|
||||
return (
|
||||
None,
|
||||
f"Agent '{store_slug}' is not in your library. "
|
||||
"Add it first to see outputs.",
|
||||
)
|
||||
return agent, None
|
||||
|
||||
# Priority 3: Fuzzy name search in library
|
||||
if agent_name:
|
||||
try:
|
||||
response = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=agent_name,
|
||||
page_size=5,
|
||||
)
|
||||
if not response.agents:
|
||||
return (
|
||||
None,
|
||||
f"No agents matching '{agent_name}' found in your library",
|
||||
)
|
||||
|
||||
# Return best match (first result from search)
|
||||
return response.agents[0], None
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching library agents: {e}")
|
||||
return None, f"Error searching for agent: {e}"
|
||||
|
||||
return (
|
||||
None,
|
||||
"Please specify an agent name, library_agent_id, or store_slug",
|
||||
)
|
||||
|
||||
async def _get_execution(
|
||||
self,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
execution_id: str | None,
|
||||
time_start: datetime | None,
|
||||
time_end: datetime | None,
|
||||
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
|
||||
"""
|
||||
Fetch execution(s) based on filters.
|
||||
Returns (single_execution, available_executions_meta, error_message).
|
||||
"""
|
||||
# If specific execution_id provided, fetch it directly
|
||||
if execution_id:
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
if not execution:
|
||||
return None, [], f"Execution '{execution_id}' not found"
|
||||
return execution, [], None
|
||||
|
||||
# Get completed executions with time filters
|
||||
executions = await execution_db.get_graph_executions(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
statuses=[ExecutionStatus.COMPLETED],
|
||||
created_time_gte=time_start,
|
||||
created_time_lte=time_end,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
if not executions:
|
||||
return None, [], None # No error, just no executions
|
||||
|
||||
# If only one execution, fetch full details
|
||||
if len(executions) == 1:
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
return full_execution, [], None
|
||||
|
||||
# Multiple executions - return latest with full details, plus list of available
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
return full_execution, executions, None
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
agent: LibraryAgent,
|
||||
execution: GraphExecution | None,
|
||||
available_executions: list[GraphExecutionMeta],
|
||||
session_id: str | None,
|
||||
) -> AgentOutputResponse:
|
||||
"""Build the response based on execution data."""
|
||||
library_agent_link = f"/library/agents/{agent.id}"
|
||||
|
||||
if not execution:
|
||||
return AgentOutputResponse(
|
||||
message=f"No completed executions found for agent '{agent.name}'",
|
||||
session_id=session_id,
|
||||
agent_name=agent.name,
|
||||
agent_id=agent.graph_id,
|
||||
library_agent_id=agent.id,
|
||||
library_agent_link=library_agent_link,
|
||||
total_executions=0,
|
||||
)
|
||||
|
||||
execution_info = ExecutionOutputInfo(
|
||||
execution_id=execution.id,
|
||||
status=execution.status.value,
|
||||
started_at=execution.started_at,
|
||||
ended_at=execution.ended_at,
|
||||
outputs=dict(execution.outputs),
|
||||
inputs_summary=execution.inputs if execution.inputs else None,
|
||||
)
|
||||
|
||||
available_list = None
|
||||
if len(available_executions) > 1:
|
||||
available_list = [
|
||||
{
|
||||
"id": e.id,
|
||||
"status": e.status.value,
|
||||
"started_at": e.started_at.isoformat() if e.started_at else None,
|
||||
}
|
||||
for e in available_executions[:5]
|
||||
]
|
||||
|
||||
message = f"Found execution outputs for agent '{agent.name}'"
|
||||
if len(available_executions) > 1:
|
||||
message += (
|
||||
f". Showing latest of {len(available_executions)} matching executions."
|
||||
)
|
||||
|
||||
return AgentOutputResponse(
|
||||
message=message,
|
||||
session_id=session_id,
|
||||
agent_name=agent.name,
|
||||
agent_id=agent.graph_id,
|
||||
library_agent_id=agent.id,
|
||||
library_agent_link=library_agent_link,
|
||||
execution=execution_info,
|
||||
available_executions=available_list,
|
||||
total_executions=len(available_executions) if available_executions else 1,
|
||||
)
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the agent_output tool."""
|
||||
session_id = session.session_id
|
||||
|
||||
# Parse and validate input
|
||||
try:
|
||||
input_data = AgentOutputInput(**kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid input: {e}")
|
||||
return ErrorResponse(
|
||||
message="Invalid input parameters",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Ensure user_id is present (should be guaranteed by requires_auth)
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if at least one identifier is provided
|
||||
if not any(
|
||||
[
|
||||
input_data.agent_name,
|
||||
input_data.library_agent_id,
|
||||
input_data.store_slug,
|
||||
input_data.execution_id,
|
||||
]
|
||||
):
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Please specify at least one of: agent_name, "
|
||||
"library_agent_id, store_slug, or execution_id"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# If only execution_id provided, we need to find the agent differently
|
||||
if (
|
||||
input_data.execution_id
|
||||
and not input_data.agent_name
|
||||
and not input_data.library_agent_id
|
||||
and not input_data.store_slug
|
||||
):
|
||||
# Fetch execution directly to get graph_id
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=input_data.execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
if not execution:
|
||||
return ErrorResponse(
|
||||
message=f"Execution '{input_data.execution_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Find library agent by graph_id
|
||||
agent = await library_db.get_library_agent_by_graph_id(
|
||||
user_id, execution.graph_id
|
||||
)
|
||||
if not agent:
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"Execution found but agent not in your library. "
|
||||
f"Graph ID: {execution.graph_id}"
|
||||
),
|
||||
session_id=session_id,
|
||||
suggestions=["Add the agent to your library to see more details"],
|
||||
)
|
||||
|
||||
return self._build_response(agent, execution, [], session_id)
|
||||
|
||||
# Resolve agent from identifiers
|
||||
agent, error = await self._resolve_agent(
|
||||
user_id=user_id,
|
||||
agent_name=input_data.agent_name or None,
|
||||
library_agent_id=input_data.library_agent_id or None,
|
||||
store_slug=input_data.store_slug or None,
|
||||
)
|
||||
|
||||
if error or not agent:
|
||||
return NoResultsResponse(
|
||||
message=error or "Agent not found",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Check the agent name or ID",
|
||||
"Make sure the agent is in your library",
|
||||
],
|
||||
)
|
||||
|
||||
# Parse time expression
|
||||
time_start, time_end = parse_time_expression(input_data.run_time)
|
||||
|
||||
# Fetch execution(s)
|
||||
execution, available_executions, exec_error = await self._get_execution(
|
||||
user_id=user_id,
|
||||
graph_id=agent.graph_id,
|
||||
execution_id=input_data.execution_id or None,
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
)
|
||||
|
||||
if exec_error:
|
||||
return ErrorResponse(
|
||||
message=exec_error,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return self._build_response(agent, execution, available_executions, session_id)
|
||||
File diff suppressed because one or more lines are too long
@@ -1,279 +0,0 @@
|
||||
"""CreateAgentTool - Creates agents from natural language descriptions."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
apply_all_fixes,
|
||||
decompose_goal,
|
||||
generate_agent,
|
||||
get_blocks_info,
|
||||
save_agent_to_library,
|
||||
validate_agent,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum retries for agent generation with validation feedback
|
||||
MAX_GENERATION_RETRIES = 2
|
||||
|
||||
|
||||
class CreateAgentTool(BaseTool):
|
||||
"""Tool for creating agents from natural language descriptions."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "create_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Create a new agent workflow from a natural language description. "
|
||||
"First generates a preview, then saves to library if save=true."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Natural language description of what the agent should do. "
|
||||
"Be specific about inputs, outputs, and the workflow steps."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions. "
|
||||
"Include any preferences or constraints mentioned by the user."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to save the agent to the user's library. "
|
||||
"Default is true. Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["description"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the create_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Decompose the description into steps (may return clarifying questions)
|
||||
2. Generate agent JSON from the steps
|
||||
3. Apply fixes to correct common LLM errors
|
||||
4. Preview or save based on the save parameter
|
||||
"""
|
||||
description = kwargs.get("description", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not description:
|
||||
return ErrorResponse(
|
||||
message="Please provide a description of what the agent should do.",
|
||||
error="Missing description parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 1: Decompose goal into steps
|
||||
try:
|
||||
decomposition_result = await decompose_goal(description, context)
|
||||
except ValueError as e:
|
||||
# Handle missing API key or configuration errors
|
||||
return ErrorResponse(
|
||||
message=f"Agent generation is not configured: {str(e)}",
|
||||
error="configuration_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to analyze the goal. Please try rephrasing.",
|
||||
error="Decomposition failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if LLM returned clarifying questions
|
||||
if decomposition_result.get("type") == "clarifying_questions":
|
||||
questions = decomposition_result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information to create this agent. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check for unachievable/vague goals
|
||||
if decomposition_result.get("type") == "unachievable_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
reason = decomposition_result.get("reason", "")
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"This goal cannot be accomplished with the available blocks. "
|
||||
f"{reason} "
|
||||
f"Suggestion: {suggested}"
|
||||
),
|
||||
error="unachievable_goal",
|
||||
details={"suggested_goal": suggested, "reason": reason},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result.get("type") == "vague_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"The goal is too vague to create a specific workflow. "
|
||||
f"Suggestion: {suggested}"
|
||||
),
|
||||
error="vague_goal",
|
||||
details={"suggested_goal": suggested},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 2: Generate agent JSON with retry on validation failure
|
||||
blocks_info = get_blocks_info()
|
||||
agent_json = None
|
||||
validation_errors = None
|
||||
|
||||
for attempt in range(MAX_GENERATION_RETRIES + 1):
|
||||
# Generate agent (include validation errors from previous attempt)
|
||||
if attempt == 0:
|
||||
agent_json = await generate_agent(decomposition_result)
|
||||
else:
|
||||
# Retry with validation error feedback
|
||||
logger.info(
|
||||
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
|
||||
)
|
||||
retry_instructions = {
|
||||
**decomposition_result,
|
||||
"previous_errors": validation_errors,
|
||||
"retry_instructions": (
|
||||
"The previous generation had validation errors. "
|
||||
"Please fix these issues in the new generation:\n"
|
||||
f"{validation_errors}"
|
||||
),
|
||||
}
|
||||
agent_json = await generate_agent(retry_instructions)
|
||||
|
||||
if agent_json is None:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate the agent. Please try again.",
|
||||
error="Generation failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# Step 3: Apply fixes to correct common errors
|
||||
agent_json = apply_all_fixes(agent_json, blocks_info)
|
||||
|
||||
# Step 4: Validate the agent
|
||||
is_valid, validation_errors = validate_agent(agent_json, blocks_info)
|
||||
|
||||
if is_valid:
|
||||
logger.info(f"Agent generated successfully on attempt {attempt + 1}")
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
|
||||
)
|
||||
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
# Return error with validation details
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Generated agent has validation errors after {MAX_GENERATION_RETRIES + 1} attempts. "
|
||||
f"Please try rephrasing your request or simplify the workflow."
|
||||
),
|
||||
error="validation_failed",
|
||||
details={"validation_errors": validation_errors},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", "Generated Agent")
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
link_count = len(agent_json.get("links", []))
|
||||
|
||||
# Step 4: Preview or save
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
|
||||
f"Review it and call create_agent with save=true to save it to your library."
|
||||
),
|
||||
agent_json=agent_json,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Save to library
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=f"Agent '{created_graph.name}' has been saved to your library!",
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the agent: {str(e)}",
|
||||
error="save_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
File diff suppressed because one or more lines are too long
@@ -1,294 +0,0 @@
|
||||
"""EditAgentTool - Edits existing agents using natural language."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
apply_agent_patch,
|
||||
apply_all_fixes,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_blocks_info,
|
||||
save_agent_to_library,
|
||||
validate_agent,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum retries for patch generation with validation feedback
|
||||
MAX_GENERATION_RETRIES = 2
|
||||
|
||||
|
||||
class EditAgentTool(BaseTool):
|
||||
"""Tool for editing existing agents using natural language."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "edit_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit an existing agent from the user's library using natural language. "
|
||||
"Generates a patch to update the agent while preserving unchanged parts."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The ID of the agent to edit. "
|
||||
"Can be a graph ID or library agent ID."
|
||||
),
|
||||
},
|
||||
"changes": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Natural language description of what changes to make. "
|
||||
"Be specific about what to add, remove, or modify."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to save the changes. "
|
||||
"Default is true. Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["agent_id", "changes"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the edit_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Fetch the current agent
|
||||
2. Generate a patch based on the requested changes
|
||||
3. Apply the patch to create an updated agent
|
||||
4. Preview or save based on the save parameter
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
changes = kwargs.get("changes", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the agent ID to edit.",
|
||||
error="Missing agent_id parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not changes:
|
||||
return ErrorResponse(
|
||||
message="Please describe what changes you want to make.",
|
||||
error="Missing changes parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 1: Fetch current agent
|
||||
current_agent = await get_agent_as_json(agent_id, user_id)
|
||||
|
||||
if current_agent is None:
|
||||
return ErrorResponse(
|
||||
message=f"Could not find agent with ID '{agent_id}' in your library.",
|
||||
error="agent_not_found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Build the update request with context
|
||||
update_request = changes
|
||||
if context:
|
||||
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
||||
|
||||
# Step 2: Generate patch with retry on validation failure
|
||||
blocks_info = get_blocks_info()
|
||||
updated_agent = None
|
||||
validation_errors = None
|
||||
intent = "Applied requested changes"
|
||||
|
||||
for attempt in range(MAX_GENERATION_RETRIES + 1):
|
||||
# Generate patch (include validation errors from previous attempt)
|
||||
try:
|
||||
if attempt == 0:
|
||||
patch_result = await generate_agent_patch(
|
||||
update_request, current_agent
|
||||
)
|
||||
else:
|
||||
# Retry with validation error feedback
|
||||
logger.info(
|
||||
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
|
||||
)
|
||||
retry_request = (
|
||||
f"{update_request}\n\n"
|
||||
f"IMPORTANT: The previous edit had validation errors. "
|
||||
f"Please fix these issues:\n{validation_errors}"
|
||||
)
|
||||
patch_result = await generate_agent_patch(
|
||||
retry_request, current_agent
|
||||
)
|
||||
except ValueError as e:
|
||||
# Handle missing API key or configuration errors
|
||||
return ErrorResponse(
|
||||
message=f"Agent generation is not configured: {str(e)}",
|
||||
error="configuration_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if patch_result is None:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate changes. Please try rephrasing.",
|
||||
error="Patch generation failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# Check if LLM returned clarifying questions
|
||||
if patch_result.get("type") == "clarifying_questions":
|
||||
questions = patch_result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information about the changes. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 3: Apply patch and fixes
|
||||
try:
|
||||
updated_agent = apply_agent_patch(current_agent, patch_result)
|
||||
updated_agent = apply_all_fixes(updated_agent, blocks_info)
|
||||
except Exception as e:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to apply changes: {str(e)}",
|
||||
error="patch_apply_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
validation_errors = str(e)
|
||||
continue
|
||||
|
||||
# Step 4: Validate the updated agent
|
||||
is_valid, validation_errors = validate_agent(updated_agent, blocks_info)
|
||||
|
||||
if is_valid:
|
||||
logger.info(f"Agent edited successfully on attempt {attempt + 1}")
|
||||
intent = patch_result.get("intent", "Applied requested changes")
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
|
||||
)
|
||||
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
# Return error with validation details
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Updated agent has validation errors after "
|
||||
f"{MAX_GENERATION_RETRIES + 1} attempts. "
|
||||
f"Please try rephrasing your request or simplify the changes."
|
||||
),
|
||||
error="validation_failed",
|
||||
details={"validation_errors": validation_errors},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# At this point, updated_agent is guaranteed to be set (we return on all failure paths)
|
||||
assert updated_agent is not None
|
||||
|
||||
agent_name = updated_agent.get("name", "Updated Agent")
|
||||
agent_description = updated_agent.get("description", "")
|
||||
node_count = len(updated_agent.get("nodes", []))
|
||||
link_count = len(updated_agent.get("links", []))
|
||||
|
||||
# Step 5: Preview or save
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've updated the agent. Changes: {intent}. "
|
||||
f"The agent now has {node_count} blocks. "
|
||||
f"Review it and call edit_agent with save=true to save the changes."
|
||||
),
|
||||
agent_json=updated_agent,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Save to library (creates a new version)
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
updated_agent, user_id, is_update=True
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=(
|
||||
f"Updated agent '{created_graph.name}' has been saved to your library! "
|
||||
f"Changes: {intent}"
|
||||
),
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the updated agent: {str(e)}",
|
||||
error="save_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -1,253 +0,0 @@
|
||||
"""Tool for searching available blocks using hybrid search."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.blocks import load_all_blocks
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
BlockInfoSummary,
|
||||
BlockListResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .search_blocks import get_block_search_index
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FindBlockTool(BaseTool):
|
||||
"""Tool for searching available blocks."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "find_block"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search for available blocks by name or description. "
|
||||
"Blocks are reusable components that perform specific tasks like "
|
||||
"sending emails, making API calls, processing text, etc. "
|
||||
"Use this to find blocks that can be executed directly."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query to find blocks by name or description. "
|
||||
"Use keywords like 'email', 'http', 'text', 'ai', etc."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
def _matches_query(self, block, query: str) -> tuple[int, bool]:
|
||||
"""
|
||||
Check if a block matches the query and return a priority score.
|
||||
|
||||
Returns (priority, matches) where:
|
||||
- priority 0: exact name match
|
||||
- priority 1: name contains query
|
||||
- priority 2: description contains query
|
||||
- priority 3: category contains query
|
||||
"""
|
||||
query_lower = query.lower()
|
||||
name_lower = block.name.lower()
|
||||
desc_lower = block.description.lower()
|
||||
|
||||
# Exact name match
|
||||
if query_lower == name_lower:
|
||||
return 0, True
|
||||
|
||||
# Name contains query
|
||||
if query_lower in name_lower:
|
||||
return 1, True
|
||||
|
||||
# Description contains query
|
||||
if query_lower in desc_lower:
|
||||
return 2, True
|
||||
|
||||
# Category contains query
|
||||
for category in block.categories:
|
||||
if query_lower in category.name.lower():
|
||||
return 3, True
|
||||
|
||||
return 4, False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search for blocks matching the query.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
BlockListResponse: List of matching blocks
|
||||
NoResultsResponse: No blocks found
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Try hybrid search first
|
||||
search_results = self._hybrid_search(query)
|
||||
|
||||
if search_results is not None:
|
||||
# Hybrid search succeeded
|
||||
if not search_results:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found matching '{query}'",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Search by category: ai, text, social, search, etc.",
|
||||
"Check block names like 'SendEmail', 'HttpRequest', etc.",
|
||||
],
|
||||
)
|
||||
|
||||
# Get full block info for each result
|
||||
all_blocks = load_all_blocks()
|
||||
blocks = []
|
||||
for result in search_results:
|
||||
block_cls = all_blocks.get(result.block_id)
|
||||
if block_cls:
|
||||
block = block_cls()
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
description=block.description,
|
||||
categories=[cat.name for cat in block.categories],
|
||||
input_schema=block.input_schema.jsonschema(),
|
||||
output_schema=block.output_schema.jsonschema(),
|
||||
)
|
||||
)
|
||||
|
||||
return BlockListResponse(
|
||||
message=(
|
||||
f"Found {len(blocks)} block{'s' if len(blocks) != 1 else ''} "
|
||||
f"matching '{query}'. Use run_block to execute a block with "
|
||||
"the required inputs."
|
||||
),
|
||||
blocks=blocks,
|
||||
count=len(blocks),
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Fallback to simple search if hybrid search failed
|
||||
return self._simple_search(query, session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching blocks: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search blocks. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _hybrid_search(self, query: str) -> list | None:
|
||||
"""
|
||||
Perform hybrid search using embeddings and BM25.
|
||||
|
||||
Returns:
|
||||
List of BlockSearchResult if successful, None if index not available
|
||||
"""
|
||||
try:
|
||||
index = get_block_search_index()
|
||||
if not index.load():
|
||||
logger.info(
|
||||
"Block search index not available, falling back to simple search"
|
||||
)
|
||||
return None
|
||||
|
||||
results = index.search(query, top_k=10)
|
||||
logger.info(f"Hybrid search found {len(results)} blocks for: {query}")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Hybrid search failed, falling back to simple: {e}")
|
||||
return None
|
||||
|
||||
def _simple_search(self, query: str, session_id: str) -> ToolResponseBase:
|
||||
"""Fallback simple search using substring matching."""
|
||||
all_blocks = load_all_blocks()
|
||||
logger.info(f"Simple searching {len(all_blocks)} blocks for: {query}")
|
||||
|
||||
# Find matching blocks with priority scores
|
||||
matches: list[tuple[int, Any]] = []
|
||||
for block_id, block_cls in all_blocks.items():
|
||||
block = block_cls()
|
||||
priority, is_match = self._matches_query(block, query)
|
||||
if is_match:
|
||||
matches.append((priority, block))
|
||||
|
||||
# Sort by priority (lower is better)
|
||||
matches.sort(key=lambda x: x[0])
|
||||
|
||||
# Take top 10 results
|
||||
top_matches = [block for _, block in matches[:10]]
|
||||
|
||||
if not top_matches:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found matching '{query}'",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Search by category: ai, text, social, search, etc.",
|
||||
"Check block names like 'SendEmail', 'HttpRequest', etc.",
|
||||
],
|
||||
)
|
||||
|
||||
# Build response
|
||||
blocks = []
|
||||
for block in top_matches:
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
description=block.description,
|
||||
categories=[cat.name for cat in block.categories],
|
||||
input_schema=block.input_schema.jsonschema(),
|
||||
output_schema=block.output_schema.jsonschema(),
|
||||
)
|
||||
)
|
||||
|
||||
return BlockListResponse(
|
||||
message=(
|
||||
f"Found {len(blocks)} block{'s' if len(blocks) != 1 else ''} "
|
||||
f"matching '{query}'. Use run_block to execute a block with "
|
||||
"the required inputs."
|
||||
),
|
||||
blocks=blocks,
|
||||
count=len(blocks),
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -1,157 +0,0 @@
|
||||
"""Tool for searching agents in the user's library."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.util.exceptions import DatabaseError
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentCarouselResponse,
|
||||
AgentInfo,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FindLibraryAgentTool(BaseTool):
|
||||
"""Tool for searching agents in the user's library."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "find_library_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search for agents in the user's library. Use this to find agents "
|
||||
"the user has already added to their library, including agents they "
|
||||
"created or added from the marketplace."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query to find agents by name or description. "
|
||||
"Use keywords for best results."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search for agents in the user's library.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
AgentCarouselResponse: List of agents found in the library
|
||||
NoResultsResponse: No agents found
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required to search library",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agents = []
|
||||
try:
|
||||
logger.info(f"Searching user library for: {query}")
|
||||
library_results = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=query,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Find library agents tool found {len(library_results.agents)} agents"
|
||||
)
|
||||
|
||||
for agent in library_results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
),
|
||||
)
|
||||
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Error searching library agents: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search library. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agents:
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"No agents found matching '{query}' in your library. "
|
||||
"Try different keywords or use find_agent to search the marketplace."
|
||||
),
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Use find_agent to search the marketplace",
|
||||
"Check your library at /library",
|
||||
],
|
||||
)
|
||||
|
||||
title = (
|
||||
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
|
||||
f"in your library for '{query}'"
|
||||
)
|
||||
|
||||
return AgentCarouselResponse(
|
||||
message=(
|
||||
"Found agents in the user's library. You can provide a link to "
|
||||
"view an agent at: /library/agents/{agent_id}. "
|
||||
"Use agent_output to get execution results, or run_agent to execute."
|
||||
),
|
||||
title=title,
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -1,483 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Block Indexer for Hybrid Search
|
||||
|
||||
Creates a hybrid search index from blocks:
|
||||
- OpenAI embeddings (text-embedding-3-small)
|
||||
- BM25 index for lexical search
|
||||
- Name index for title matching boost
|
||||
|
||||
Supports incremental updates by tracking content hashes.
|
||||
|
||||
Usage:
|
||||
python -m backend.server.v2.chat.tools.index_blocks [--force]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Check for OpenAI availability
|
||||
try:
|
||||
import openai # noqa: F401
|
||||
|
||||
HAS_OPENAI = True
|
||||
except ImportError:
|
||||
HAS_OPENAI = False
|
||||
print("Warning: openai not installed. Run: pip install openai")
|
||||
|
||||
# Default embedding model (OpenAI)
|
||||
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
DEFAULT_EMBEDDING_DIM = 1536
|
||||
|
||||
# Output path (relative to this file)
|
||||
INDEX_PATH = Path(__file__).parent / "blocks_index.json"
|
||||
|
||||
# Stopwords for tokenization
|
||||
STOPWORDS = {
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"is",
|
||||
"are",
|
||||
"was",
|
||||
"were",
|
||||
"be",
|
||||
"been",
|
||||
"being",
|
||||
"have",
|
||||
"has",
|
||||
"had",
|
||||
"do",
|
||||
"does",
|
||||
"did",
|
||||
"will",
|
||||
"would",
|
||||
"could",
|
||||
"should",
|
||||
"may",
|
||||
"might",
|
||||
"must",
|
||||
"shall",
|
||||
"can",
|
||||
"need",
|
||||
"dare",
|
||||
"ought",
|
||||
"used",
|
||||
"to",
|
||||
"of",
|
||||
"in",
|
||||
"for",
|
||||
"on",
|
||||
"with",
|
||||
"at",
|
||||
"by",
|
||||
"from",
|
||||
"as",
|
||||
"into",
|
||||
"through",
|
||||
"during",
|
||||
"before",
|
||||
"after",
|
||||
"above",
|
||||
"below",
|
||||
"between",
|
||||
"under",
|
||||
"again",
|
||||
"further",
|
||||
"then",
|
||||
"once",
|
||||
"and",
|
||||
"but",
|
||||
"or",
|
||||
"nor",
|
||||
"so",
|
||||
"yet",
|
||||
"both",
|
||||
"either",
|
||||
"neither",
|
||||
"not",
|
||||
"only",
|
||||
"own",
|
||||
"same",
|
||||
"than",
|
||||
"too",
|
||||
"very",
|
||||
"just",
|
||||
"also",
|
||||
"now",
|
||||
"here",
|
||||
"there",
|
||||
"when",
|
||||
"where",
|
||||
"why",
|
||||
"how",
|
||||
"all",
|
||||
"each",
|
||||
"every",
|
||||
"few",
|
||||
"more",
|
||||
"most",
|
||||
"other",
|
||||
"some",
|
||||
"such",
|
||||
"no",
|
||||
"any",
|
||||
"this",
|
||||
"that",
|
||||
"these",
|
||||
"those",
|
||||
"it",
|
||||
"its",
|
||||
"block", # Too common in block context
|
||||
}
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Simple tokenizer for BM25."""
|
||||
text = text.lower()
|
||||
# Remove code blocks if any
|
||||
text = re.sub(r"```[\s\S]*?```", "", text)
|
||||
text = re.sub(r"`[^`]+`", "", text)
|
||||
# Extract words (including camelCase split)
|
||||
# First, split camelCase
|
||||
text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
|
||||
# Extract words
|
||||
words = re.findall(r"\b[a-z][a-z0-9_-]*\b", text)
|
||||
# Remove very short words and stopwords
|
||||
return [w for w in words if len(w) > 2 and w not in STOPWORDS]
|
||||
|
||||
|
||||
def build_searchable_text(block: Any) -> str:
|
||||
"""Build searchable text from block attributes."""
|
||||
parts = []
|
||||
|
||||
# Block name (split camelCase for better tokenization)
|
||||
name = block.name
|
||||
# Split camelCase: GetCurrentTimeBlock -> Get Current Time Block
|
||||
name_split = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
|
||||
parts.append(name_split)
|
||||
|
||||
# Description
|
||||
if block.description:
|
||||
parts.append(block.description)
|
||||
|
||||
# Categories
|
||||
for category in block.categories:
|
||||
parts.append(category.name)
|
||||
|
||||
# Input schema field names and descriptions
|
||||
try:
|
||||
input_schema = block.input_schema.jsonschema()
|
||||
if "properties" in input_schema:
|
||||
for field_name, field_info in input_schema["properties"].items():
|
||||
parts.append(field_name)
|
||||
if "description" in field_info:
|
||||
parts.append(field_info["description"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Output schema field names
|
||||
try:
|
||||
output_schema = block.output_schema.jsonschema()
|
||||
if "properties" in output_schema:
|
||||
for field_name in output_schema["properties"]:
|
||||
parts.append(field_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def compute_content_hash(text: str) -> str:
|
||||
"""Compute MD5 hash of text for change detection."""
|
||||
return hashlib.md5(text.encode()).hexdigest()
|
||||
|
||||
|
||||
def load_existing_index(index_path: Path) -> dict[str, Any] | None:
|
||||
"""Load existing index if present."""
|
||||
if not index_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(index_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load existing index: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def create_embeddings(
|
||||
texts: list[str],
|
||||
model_name: str = DEFAULT_EMBEDDING_MODEL,
|
||||
batch_size: int = 100,
|
||||
) -> np.ndarray:
|
||||
"""Create embeddings using OpenAI API."""
|
||||
if not HAS_OPENAI:
|
||||
raise RuntimeError("openai not installed. Run: pip install openai")
|
||||
|
||||
# Import here to satisfy type checker
|
||||
from openai import OpenAI
|
||||
|
||||
# Check for API key
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
client = OpenAI(api_key=api_key)
|
||||
embeddings = []
|
||||
|
||||
print(f"Creating embeddings for {len(texts)} texts using {model_name}...")
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
# Truncate texts to max token limit (8191 tokens for text-embedding-3-small)
|
||||
# Roughly 4 chars per token, so ~32000 chars max
|
||||
batch = [text[:32000] for text in batch]
|
||||
|
||||
response = client.embeddings.create(
|
||||
model=model_name,
|
||||
input=batch,
|
||||
)
|
||||
|
||||
for embedding_data in response.data:
|
||||
embeddings.append(embedding_data.embedding)
|
||||
|
||||
print(f" Processed {min(i + batch_size, len(texts))}/{len(texts)} texts")
|
||||
|
||||
return np.array(embeddings, dtype=np.float32)
|
||||
|
||||
|
||||
def build_bm25_data(
|
||||
blocks_data: list[dict[str, Any]],
|
||||
) -> dict[str, Any]:
|
||||
"""Build BM25 metadata from block data."""
|
||||
# Tokenize all searchable texts
|
||||
tokenized_docs = []
|
||||
for block in blocks_data:
|
||||
tokens = tokenize(block["searchable_text"])
|
||||
tokenized_docs.append(tokens)
|
||||
|
||||
# Calculate document frequencies
|
||||
doc_freq: dict[str, int] = {}
|
||||
for tokens in tokenized_docs:
|
||||
seen = set()
|
||||
for token in tokens:
|
||||
if token not in seen:
|
||||
doc_freq[token] = doc_freq.get(token, 0) + 1
|
||||
seen.add(token)
|
||||
|
||||
n_docs = len(tokenized_docs)
|
||||
doc_lens = [len(d) for d in tokenized_docs]
|
||||
avgdl = sum(doc_lens) / max(n_docs, 1)
|
||||
|
||||
return {
|
||||
"n_docs": n_docs,
|
||||
"avgdl": avgdl,
|
||||
"df": doc_freq,
|
||||
"doc_lens": doc_lens,
|
||||
}
|
||||
|
||||
|
||||
def build_name_index(
|
||||
blocks_data: list[dict[str, Any]],
|
||||
) -> dict[str, list[list[int | float]]]:
|
||||
"""Build inverted index for name search boost."""
|
||||
index: dict[str, list[list[int | float]]] = defaultdict(list)
|
||||
|
||||
for idx, block in enumerate(blocks_data):
|
||||
# Tokenize block name
|
||||
name_tokens = tokenize(block["name"])
|
||||
seen = set()
|
||||
|
||||
for i, token in enumerate(name_tokens):
|
||||
if token in seen:
|
||||
continue
|
||||
seen.add(token)
|
||||
|
||||
# Score: first token gets higher weight
|
||||
score = 1.5 if i == 0 else 1.0
|
||||
index[token].append([idx, score])
|
||||
|
||||
return dict(index)
|
||||
|
||||
|
||||
def build_block_index(
|
||||
force_rebuild: bool = False,
|
||||
output_path: Path = INDEX_PATH,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Build the block search index.
|
||||
|
||||
Args:
|
||||
force_rebuild: If True, rebuild all embeddings even if unchanged
|
||||
output_path: Path to save the index
|
||||
|
||||
Returns:
|
||||
The generated index dictionary
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from backend.blocks import load_all_blocks
|
||||
|
||||
print("Loading all blocks...")
|
||||
all_blocks = load_all_blocks()
|
||||
print(f"Found {len(all_blocks)} blocks")
|
||||
|
||||
# Load existing index for incremental updates
|
||||
existing_index = None if force_rebuild else load_existing_index(output_path)
|
||||
existing_blocks: dict[str, dict[str, Any]] = {}
|
||||
|
||||
if existing_index:
|
||||
print(
|
||||
f"Loaded existing index with {len(existing_index.get('blocks', []))} blocks"
|
||||
)
|
||||
for block in existing_index.get("blocks", []):
|
||||
existing_blocks[block["id"]] = block
|
||||
|
||||
# Process each block
|
||||
blocks_data: list[dict[str, Any]] = []
|
||||
blocks_needing_embedding: list[tuple[int, str]] = [] # (index, searchable_text)
|
||||
|
||||
for block_id, block_cls in all_blocks.items():
|
||||
try:
|
||||
block = block_cls()
|
||||
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
searchable_text = build_searchable_text(block)
|
||||
content_hash = compute_content_hash(searchable_text)
|
||||
|
||||
block_data = {
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"description": block.description,
|
||||
"categories": [cat.name for cat in block.categories],
|
||||
"searchable_text": searchable_text,
|
||||
"content_hash": content_hash,
|
||||
"emb": None, # Will be filled later
|
||||
}
|
||||
|
||||
# Check if we can reuse existing embedding
|
||||
if (
|
||||
block.id in existing_blocks
|
||||
and existing_blocks[block.id].get("content_hash") == content_hash
|
||||
and existing_blocks[block.id].get("emb")
|
||||
):
|
||||
# Reuse existing embedding
|
||||
block_data["emb"] = existing_blocks[block.id]["emb"]
|
||||
else:
|
||||
# Need new embedding
|
||||
blocks_needing_embedding.append((len(blocks_data), searchable_text))
|
||||
|
||||
blocks_data.append(block_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process block {block_id}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Processed {len(blocks_data)} blocks")
|
||||
print(f"Blocks needing new embeddings: {len(blocks_needing_embedding)}")
|
||||
|
||||
# Create embeddings for new/changed blocks
|
||||
if blocks_needing_embedding and HAS_OPENAI:
|
||||
texts_to_embed = [text for _, text in blocks_needing_embedding]
|
||||
try:
|
||||
embeddings = create_embeddings(texts_to_embed)
|
||||
|
||||
# Assign embeddings to blocks
|
||||
for i, (block_idx, _) in enumerate(blocks_needing_embedding):
|
||||
emb = embeddings[i].astype(np.float32)
|
||||
# Encode as base64
|
||||
blocks_data[block_idx]["emb"] = base64.b64encode(emb.tobytes()).decode(
|
||||
"ascii"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to create embeddings: {e}")
|
||||
elif blocks_needing_embedding:
|
||||
print(
|
||||
"Warning: Cannot create embeddings (openai not installed or OPENAI_API_KEY not set)"
|
||||
)
|
||||
|
||||
# Build BM25 data
|
||||
print("Building BM25 index...")
|
||||
bm25_data = build_bm25_data(blocks_data)
|
||||
|
||||
# Build name index
|
||||
print("Building name index...")
|
||||
name_index = build_name_index(blocks_data)
|
||||
|
||||
# Build final index
|
||||
index = {
|
||||
"version": "1.0.0",
|
||||
"embedding_model": DEFAULT_EMBEDDING_MODEL,
|
||||
"embedding_dim": DEFAULT_EMBEDDING_DIM,
|
||||
"generated_at": datetime.now(timezone.utc).isoformat(),
|
||||
"blocks": blocks_data,
|
||||
"bm25": bm25_data,
|
||||
"name_index": name_index,
|
||||
}
|
||||
|
||||
# Save index
|
||||
print(f"Saving index to {output_path}...")
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, separators=(",", ":"))
|
||||
|
||||
size_kb = output_path.stat().st_size / 1024
|
||||
print(f"Index saved ({size_kb:.1f} KB)")
|
||||
|
||||
# Print statistics
|
||||
print("\nIndex Statistics:")
|
||||
print(f" Blocks indexed: {len(blocks_data)}")
|
||||
print(f" BM25 vocabulary size: {len(bm25_data['df'])}")
|
||||
print(f" Name index terms: {len(name_index)}")
|
||||
print(f" Embeddings: {'Yes' if any(b.get('emb') for b in blocks_data) else 'No'}")
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Build hybrid search index for blocks")
|
||||
parser.add_argument(
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Force rebuild all embeddings even if unchanged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
default=INDEX_PATH,
|
||||
help=f"Output index file path (default: {INDEX_PATH})",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
build_block_index(
|
||||
force_rebuild=args.force,
|
||||
output_path=args.output,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error building index: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,287 +0,0 @@
|
||||
"""Tool for executing blocks directly."""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.block import get_block
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunBlockTool(BaseTool):
|
||||
"""Tool for executing a block and returning its outputs."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "run_block"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Execute a specific block with the provided input data. "
|
||||
"Use find_block to discover available blocks and their input schemas. "
|
||||
"The block will run and return its outputs once complete."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"block_id": {
|
||||
"type": "string",
|
||||
"description": "The UUID of the block to execute",
|
||||
},
|
||||
"input_data": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"Input values for the block. Must match the block's input schema. "
|
||||
"Check the block's input_schema from find_block for required fields."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["block_id", "input_data"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _check_block_credentials(
|
||||
self,
|
||||
user_id: str,
|
||||
block: Any,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Check if user has required credentials for a block.
|
||||
|
||||
Returns:
|
||||
tuple[matched_credentials, missing_credentials]
|
||||
"""
|
||||
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
||||
missing_credentials: list[CredentialsMetaInput] = []
|
||||
|
||||
# Get credential field info from block's input schema
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
|
||||
if not credentials_fields_info:
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
# Get user's available credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
# field_info.provider is a frozenset of acceptable providers
|
||||
# field_info.supported_types is a frozenset of acceptable types
|
||||
matching_cred = next(
|
||||
(
|
||||
cred
|
||||
for cred in available_creds
|
||||
if cred.provider in field_info.provider
|
||||
and cred.type in field_info.supported_types
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if matching_cred:
|
||||
matched_credentials[field_name] = CredentialsMetaInput(
|
||||
id=matching_cred.id,
|
||||
provider=matching_cred.provider, # type: ignore
|
||||
type=matching_cred.type,
|
||||
title=matching_cred.title,
|
||||
)
|
||||
else:
|
||||
# Create a placeholder for the missing credential
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||
missing_credentials.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
provider=provider, # type: ignore
|
||||
type=cred_type, # type: ignore
|
||||
title=field_name.replace("_", " ").title(),
|
||||
)
|
||||
)
|
||||
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute a block with the given input data.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
block_id: Block UUID to execute
|
||||
input_data: Input values for the block
|
||||
|
||||
Returns:
|
||||
BlockOutputResponse: Block execution outputs
|
||||
SetupRequirementsResponse: Missing credentials
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
block_id = kwargs.get("block_id", "").strip()
|
||||
input_data = kwargs.get("input_data", {})
|
||||
session_id = session.session_id
|
||||
|
||||
if not block_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a block_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not isinstance(input_data, dict):
|
||||
return ErrorResponse(
|
||||
message="input_data must be an object",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Get the block
|
||||
block = get_block(block_id)
|
||||
if not block:
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||
|
||||
# Check credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
matched_credentials, missing_credentials = await self._check_block_credentials(
|
||||
user_id, block
|
||||
)
|
||||
|
||||
if missing_credentials:
|
||||
# Return setup requirements response with missing credentials
|
||||
missing_creds_dict = {c.id: c.model_dump() for c in missing_credentials}
|
||||
|
||||
return SetupRequirementsResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' requires credentials that are not configured. "
|
||||
"Please set up the required credentials before running this block."
|
||||
),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=block_id,
|
||||
agent_name=block.name,
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_creds_dict,
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": [c.model_dump() for c in missing_credentials],
|
||||
"inputs": self._get_inputs_list(block),
|
||||
"execution_modes": ["immediate"],
|
||||
},
|
||||
),
|
||||
graph_id=None,
|
||||
graph_version=None,
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch actual credentials and prepare kwargs for block execution
|
||||
exec_kwargs: dict[str, Any] = {"user_id": user_id}
|
||||
|
||||
for field_name, cred_meta in matched_credentials.items():
|
||||
# Inject metadata into input_data (for validation)
|
||||
if field_name not in input_data:
|
||||
input_data[field_name] = cred_meta.model_dump()
|
||||
|
||||
# Fetch actual credentials and pass as kwargs (for execution)
|
||||
actual_credentials = await creds_manager.get(
|
||||
user_id, cred_meta.id, lock=False
|
||||
)
|
||||
if actual_credentials:
|
||||
exec_kwargs[field_name] = actual_credentials
|
||||
else:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to retrieve credentials for {field_name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
input_data,
|
||||
**exec_kwargs,
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
outputs=dict(outputs),
|
||||
success=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning(f"Block execution failed: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Block execution failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
|
||||
"""Extract non-credential inputs from block schema."""
|
||||
inputs_list = []
|
||||
schema = block.input_schema.jsonschema()
|
||||
properties = schema.get("properties", {})
|
||||
required_fields = set(schema.get("required", []))
|
||||
|
||||
# Get credential field names to exclude
|
||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
# Skip credential fields
|
||||
if field_name in credentials_fields:
|
||||
continue
|
||||
|
||||
inputs_list.append(
|
||||
{
|
||||
"name": field_name,
|
||||
"title": field_schema.get("title", field_name),
|
||||
"type": field_schema.get("type", "string"),
|
||||
"description": field_schema.get("description", ""),
|
||||
"required": field_name in required_fields,
|
||||
}
|
||||
)
|
||||
|
||||
return inputs_list
|
||||
@@ -1,460 +0,0 @@
|
||||
"""
|
||||
Block Hybrid Search
|
||||
|
||||
Combines multiple ranking signals for block search:
|
||||
- Semantic search (OpenAI embeddings + cosine similarity)
|
||||
- Lexical search (BM25)
|
||||
- Name matching (boost for block name matches)
|
||||
- Category matching (boost for category matches)
|
||||
|
||||
Based on the docs search implementation.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OpenAI embedding model
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
|
||||
# Path to the JSON index file
|
||||
INDEX_PATH = Path(__file__).parent / "blocks_index.json"
|
||||
|
||||
# Stopwords for tokenization (same as index_blocks.py)
|
||||
STOPWORDS = {
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"is",
|
||||
"are",
|
||||
"was",
|
||||
"were",
|
||||
"be",
|
||||
"been",
|
||||
"being",
|
||||
"have",
|
||||
"has",
|
||||
"had",
|
||||
"do",
|
||||
"does",
|
||||
"did",
|
||||
"will",
|
||||
"would",
|
||||
"could",
|
||||
"should",
|
||||
"may",
|
||||
"might",
|
||||
"must",
|
||||
"shall",
|
||||
"can",
|
||||
"need",
|
||||
"dare",
|
||||
"ought",
|
||||
"used",
|
||||
"to",
|
||||
"of",
|
||||
"in",
|
||||
"for",
|
||||
"on",
|
||||
"with",
|
||||
"at",
|
||||
"by",
|
||||
"from",
|
||||
"as",
|
||||
"into",
|
||||
"through",
|
||||
"during",
|
||||
"before",
|
||||
"after",
|
||||
"above",
|
||||
"below",
|
||||
"between",
|
||||
"under",
|
||||
"again",
|
||||
"further",
|
||||
"then",
|
||||
"once",
|
||||
"and",
|
||||
"but",
|
||||
"or",
|
||||
"nor",
|
||||
"so",
|
||||
"yet",
|
||||
"both",
|
||||
"either",
|
||||
"neither",
|
||||
"not",
|
||||
"only",
|
||||
"own",
|
||||
"same",
|
||||
"than",
|
||||
"too",
|
||||
"very",
|
||||
"just",
|
||||
"also",
|
||||
"now",
|
||||
"here",
|
||||
"there",
|
||||
"when",
|
||||
"where",
|
||||
"why",
|
||||
"how",
|
||||
"all",
|
||||
"each",
|
||||
"every",
|
||||
"few",
|
||||
"more",
|
||||
"most",
|
||||
"other",
|
||||
"some",
|
||||
"such",
|
||||
"no",
|
||||
"any",
|
||||
"this",
|
||||
"that",
|
||||
"these",
|
||||
"those",
|
||||
"it",
|
||||
"its",
|
||||
"block",
|
||||
}
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Simple tokenizer for search."""
|
||||
text = text.lower()
|
||||
# Remove code blocks if any
|
||||
text = re.sub(r"```[\s\S]*?```", "", text)
|
||||
text = re.sub(r"`[^`]+`", "", text)
|
||||
# Split camelCase
|
||||
text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
|
||||
# Extract words
|
||||
words = re.findall(r"\b[a-z][a-z0-9_-]*\b", text)
|
||||
# Remove very short words and stopwords
|
||||
return [w for w in words if len(w) > 2 and w not in STOPWORDS]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchWeights:
|
||||
"""Configuration for hybrid search signal weights."""
|
||||
|
||||
semantic: float = 0.40 # Embedding similarity
|
||||
bm25: float = 0.25 # Lexical matching
|
||||
name_match: float = 0.25 # Block name matches
|
||||
category_match: float = 0.10 # Category matches
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockSearchResult:
|
||||
"""A single block search result."""
|
||||
|
||||
block_id: str
|
||||
name: str
|
||||
description: str
|
||||
categories: list[str]
|
||||
score: float
|
||||
|
||||
# Individual signal scores (for debugging)
|
||||
semantic_score: float = 0.0
|
||||
bm25_score: float = 0.0
|
||||
name_score: float = 0.0
|
||||
category_score: float = 0.0
|
||||
|
||||
|
||||
class BlockSearchIndex:
|
||||
"""Hybrid search index for blocks combining BM25 + embeddings."""
|
||||
|
||||
def __init__(self, index_path: Path = INDEX_PATH):
|
||||
self.blocks: list[dict[str, Any]] = []
|
||||
self.bm25_data: dict[str, Any] = {}
|
||||
self.name_index: dict[str, list[list[int | float]]] = {}
|
||||
self.embeddings: Optional[np.ndarray] = None
|
||||
self.normalized_embeddings: Optional[np.ndarray] = None
|
||||
self._loaded = False
|
||||
self._index_path = index_path
|
||||
self._embedding_model: Any = None
|
||||
|
||||
def load(self) -> bool:
|
||||
"""Load the index from JSON file."""
|
||||
if self._loaded:
|
||||
return True
|
||||
|
||||
if not self._index_path.exists():
|
||||
logger.warning(f"Block index not found at {self._index_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(self._index_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
self.blocks = data.get("blocks", [])
|
||||
self.bm25_data = data.get("bm25", {})
|
||||
self.name_index = data.get("name_index", {})
|
||||
|
||||
# Decode embeddings from base64
|
||||
embeddings_list = []
|
||||
for block in self.blocks:
|
||||
if block.get("emb"):
|
||||
emb_bytes = base64.b64decode(block["emb"])
|
||||
emb = np.frombuffer(emb_bytes, dtype=np.float32)
|
||||
embeddings_list.append(emb)
|
||||
else:
|
||||
# No embedding, use zeros
|
||||
dim = data.get("embedding_dim", 384)
|
||||
embeddings_list.append(np.zeros(dim, dtype=np.float32))
|
||||
|
||||
if embeddings_list:
|
||||
self.embeddings = np.stack(embeddings_list)
|
||||
# Precompute normalized embeddings for cosine similarity
|
||||
norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True)
|
||||
self.normalized_embeddings = self.embeddings / (norms + 1e-10)
|
||||
|
||||
self._loaded = True
|
||||
logger.info(f"Loaded block index with {len(self.blocks)} blocks")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load block index: {e}")
|
||||
return False
|
||||
|
||||
def _get_openai_client(self) -> Any:
|
||||
"""Get OpenAI client for query embedding."""
|
||||
if self._embedding_model is None:
|
||||
try:
|
||||
from openai import OpenAI
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
logger.warning("OPENAI_API_KEY not set")
|
||||
return None
|
||||
self._embedding_model = OpenAI(api_key=api_key)
|
||||
except ImportError:
|
||||
logger.warning("openai not installed")
|
||||
return None
|
||||
return self._embedding_model
|
||||
|
||||
def _embed_query(self, query: str) -> Optional[np.ndarray]:
|
||||
"""Embed the search query using OpenAI."""
|
||||
client = self._get_openai_client()
|
||||
if client is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
response = client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=query,
|
||||
)
|
||||
embedding = response.data[0].embedding
|
||||
return np.array(embedding, dtype=np.float32)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to embed query: {e}")
|
||||
return None
|
||||
|
||||
def _compute_semantic_scores(self, query_embedding: np.ndarray) -> np.ndarray:
|
||||
"""Compute cosine similarity between query and all blocks."""
|
||||
if self.normalized_embeddings is None:
|
||||
return np.zeros(len(self.blocks))
|
||||
|
||||
# Normalize query embedding
|
||||
query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)
|
||||
|
||||
# Cosine similarity via dot product
|
||||
similarities = self.normalized_embeddings @ query_norm
|
||||
|
||||
# Scale to [0, 1] (cosine ranges from -1 to 1)
|
||||
return (similarities + 1) / 2
|
||||
|
||||
def _compute_bm25_scores(self, query_tokens: list[str]) -> np.ndarray:
|
||||
"""Compute BM25 scores for all blocks."""
|
||||
scores = np.zeros(len(self.blocks))
|
||||
|
||||
if not self.bm25_data or not query_tokens:
|
||||
return scores
|
||||
|
||||
# BM25 parameters
|
||||
k1 = 1.5
|
||||
b = 0.75
|
||||
n_docs = self.bm25_data.get("n_docs", len(self.blocks))
|
||||
avgdl = self.bm25_data.get("avgdl", 100)
|
||||
df = self.bm25_data.get("df", {})
|
||||
doc_lens = self.bm25_data.get("doc_lens", [100] * len(self.blocks))
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
# Tokenize block's searchable text
|
||||
block_tokens = tokenize(block.get("searchable_text", ""))
|
||||
doc_len = doc_lens[i] if i < len(doc_lens) else len(block_tokens)
|
||||
|
||||
# Calculate BM25 score
|
||||
score = 0.0
|
||||
for token in query_tokens:
|
||||
if token not in df:
|
||||
continue
|
||||
|
||||
# Term frequency in this document
|
||||
tf = block_tokens.count(token)
|
||||
if tf == 0:
|
||||
continue
|
||||
|
||||
# IDF
|
||||
doc_freq = df.get(token, 0)
|
||||
idf = math.log((n_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
|
||||
|
||||
# BM25 score component
|
||||
numerator = tf * (k1 + 1)
|
||||
denominator = tf + k1 * (1 - b + b * doc_len / avgdl)
|
||||
score += idf * numerator / denominator
|
||||
|
||||
scores[i] = score
|
||||
|
||||
# Normalize to [0, 1]
|
||||
max_score = scores.max()
|
||||
if max_score > 0:
|
||||
scores = scores / max_score
|
||||
|
||||
return scores
|
||||
|
||||
def _compute_name_scores(self, query_tokens: list[str]) -> np.ndarray:
|
||||
"""Compute name match scores using the name index."""
|
||||
scores = np.zeros(len(self.blocks))
|
||||
|
||||
if not self.name_index or not query_tokens:
|
||||
return scores
|
||||
|
||||
for token in query_tokens:
|
||||
if token in self.name_index:
|
||||
for block_idx, weight in self.name_index[token]:
|
||||
if block_idx < len(scores):
|
||||
scores[int(block_idx)] += weight
|
||||
|
||||
# Also check for partial matches in block names
|
||||
for i, block in enumerate(self.blocks):
|
||||
name_lower = block.get("name", "").lower()
|
||||
for token in query_tokens:
|
||||
if token in name_lower:
|
||||
scores[i] += 0.5
|
||||
|
||||
# Normalize to [0, 1]
|
||||
max_score = scores.max()
|
||||
if max_score > 0:
|
||||
scores = scores / max_score
|
||||
|
||||
return scores
|
||||
|
||||
def _compute_category_scores(self, query_tokens: list[str]) -> np.ndarray:
|
||||
"""Compute category match scores."""
|
||||
scores = np.zeros(len(self.blocks))
|
||||
|
||||
if not query_tokens:
|
||||
return scores
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
categories = block.get("categories", [])
|
||||
category_text = " ".join(categories).lower()
|
||||
|
||||
for token in query_tokens:
|
||||
if token in category_text:
|
||||
scores[i] += 1.0
|
||||
|
||||
# Normalize to [0, 1]
|
||||
max_score = scores.max()
|
||||
if max_score > 0:
|
||||
scores = scores / max_score
|
||||
|
||||
return scores
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 10,
|
||||
weights: Optional[SearchWeights] = None,
|
||||
) -> list[BlockSearchResult]:
|
||||
"""
|
||||
Perform hybrid search combining multiple signals.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
top_k: Number of results to return
|
||||
weights: Optional custom weights for signals
|
||||
|
||||
Returns:
|
||||
List of BlockSearchResult sorted by score
|
||||
"""
|
||||
if not self._loaded and not self.load():
|
||||
return []
|
||||
|
||||
if weights is None:
|
||||
weights = SearchWeights()
|
||||
|
||||
# Tokenize query
|
||||
query_tokens = tokenize(query)
|
||||
if not query_tokens:
|
||||
# Fallback: try raw query words
|
||||
query_tokens = query.lower().split()
|
||||
|
||||
# Compute semantic scores
|
||||
semantic_scores = np.zeros(len(self.blocks))
|
||||
if self.normalized_embeddings is not None:
|
||||
query_embedding = self._embed_query(query)
|
||||
if query_embedding is not None:
|
||||
semantic_scores = self._compute_semantic_scores(query_embedding)
|
||||
|
||||
# Compute other scores
|
||||
bm25_scores = self._compute_bm25_scores(query_tokens)
|
||||
name_scores = self._compute_name_scores(query_tokens)
|
||||
category_scores = self._compute_category_scores(query_tokens)
|
||||
|
||||
# Combine scores using weights
|
||||
combined_scores = (
|
||||
weights.semantic * semantic_scores
|
||||
+ weights.bm25 * bm25_scores
|
||||
+ weights.name_match * name_scores
|
||||
+ weights.category_match * category_scores
|
||||
)
|
||||
|
||||
# Get top-k indices
|
||||
top_indices = np.argsort(combined_scores)[::-1][:top_k]
|
||||
|
||||
# Build results
|
||||
results = []
|
||||
for idx in top_indices:
|
||||
if combined_scores[idx] <= 0:
|
||||
continue
|
||||
|
||||
block = self.blocks[idx]
|
||||
results.append(
|
||||
BlockSearchResult(
|
||||
block_id=block["id"],
|
||||
name=block["name"],
|
||||
description=block["description"],
|
||||
categories=block.get("categories", []),
|
||||
score=float(combined_scores[idx]),
|
||||
semantic_score=float(semantic_scores[idx]),
|
||||
bm25_score=float(bm25_scores[idx]),
|
||||
name_score=float(name_scores[idx]),
|
||||
category_score=float(category_scores[idx]),
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# Global index instance (lazy loaded)
|
||||
_block_search_index: Optional[BlockSearchIndex] = None
|
||||
|
||||
|
||||
def get_block_search_index() -> BlockSearchIndex:
|
||||
"""Get or create the block search index singleton."""
|
||||
global _block_search_index
|
||||
if _block_search_index is None:
|
||||
_block_search_index = BlockSearchIndex(INDEX_PATH)
|
||||
return _block_search_index
|
||||
@@ -1,386 +0,0 @@
|
||||
"""Tool for searching platform documentation."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
DocSearchResult,
|
||||
DocSearchResultsResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Documentation base URL
|
||||
DOCS_BASE_URL = "https://docs.agpt.co/platform"
|
||||
|
||||
# Path to the JSON index file (relative to this file)
|
||||
INDEX_PATH = Path(__file__).parent / "docs_index.json"
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Simple tokenizer for BM25."""
|
||||
text = text.lower()
|
||||
# Remove code blocks
|
||||
text = re.sub(r"```[\s\S]*?```", "", text)
|
||||
text = re.sub(r"`[^`]+`", "", text)
|
||||
# Extract words
|
||||
words = re.findall(r"\b[a-z][a-z0-9_-]*\b", text)
|
||||
# Remove very short words and stopwords
|
||||
stopwords = {
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"is",
|
||||
"are",
|
||||
"was",
|
||||
"were",
|
||||
"be",
|
||||
"been",
|
||||
"being",
|
||||
"have",
|
||||
"has",
|
||||
"had",
|
||||
"do",
|
||||
"does",
|
||||
"did",
|
||||
"will",
|
||||
"would",
|
||||
"could",
|
||||
"should",
|
||||
"may",
|
||||
"might",
|
||||
"must",
|
||||
"shall",
|
||||
"can",
|
||||
"need",
|
||||
"dare",
|
||||
"ought",
|
||||
"used",
|
||||
"to",
|
||||
"of",
|
||||
"in",
|
||||
"for",
|
||||
"on",
|
||||
"with",
|
||||
"at",
|
||||
"by",
|
||||
"from",
|
||||
"as",
|
||||
"into",
|
||||
"through",
|
||||
"during",
|
||||
"before",
|
||||
"after",
|
||||
"above",
|
||||
"below",
|
||||
"between",
|
||||
"under",
|
||||
"again",
|
||||
"further",
|
||||
"then",
|
||||
"once",
|
||||
"and",
|
||||
"but",
|
||||
"or",
|
||||
"nor",
|
||||
"so",
|
||||
"yet",
|
||||
"both",
|
||||
"either",
|
||||
"neither",
|
||||
"not",
|
||||
"only",
|
||||
"own",
|
||||
"same",
|
||||
"than",
|
||||
"too",
|
||||
"very",
|
||||
"just",
|
||||
"also",
|
||||
"now",
|
||||
"here",
|
||||
"there",
|
||||
"when",
|
||||
"where",
|
||||
"why",
|
||||
"how",
|
||||
"all",
|
||||
"each",
|
||||
"every",
|
||||
"both",
|
||||
"few",
|
||||
"more",
|
||||
"most",
|
||||
"other",
|
||||
"some",
|
||||
"such",
|
||||
"no",
|
||||
"any",
|
||||
"this",
|
||||
"that",
|
||||
"these",
|
||||
"those",
|
||||
"it",
|
||||
"its",
|
||||
}
|
||||
return [w for w in words if len(w) > 2 and w not in stopwords]
|
||||
|
||||
|
||||
class DocSearchIndex:
|
||||
"""Lightweight documentation search index using BM25."""
|
||||
|
||||
def __init__(self, index_path: Path):
|
||||
self.chunks: list[dict] = []
|
||||
self.bm25_data: dict = {}
|
||||
self._loaded = False
|
||||
self._index_path = index_path
|
||||
|
||||
def load(self) -> bool:
|
||||
"""Load the index from JSON file."""
|
||||
if self._loaded:
|
||||
return True
|
||||
|
||||
if not self._index_path.exists():
|
||||
logger.warning(f"Documentation index not found at {self._index_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(self._index_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
self.chunks = data.get("chunks", [])
|
||||
self.bm25_data = data.get("bm25", {})
|
||||
self._loaded = True
|
||||
logger.info(f"Loaded documentation index with {len(self.chunks)} chunks")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load documentation index: {e}")
|
||||
return False
|
||||
|
||||
def search(self, query: str, top_k: int = 5) -> list[dict]:
|
||||
"""Search the index using BM25."""
|
||||
if not self._loaded and not self.load():
|
||||
return []
|
||||
|
||||
query_tokens = tokenize(query)
|
||||
if not query_tokens:
|
||||
return []
|
||||
|
||||
# BM25 parameters
|
||||
k1 = 1.5
|
||||
b = 0.75
|
||||
n_docs = self.bm25_data.get("n_docs", len(self.chunks))
|
||||
avgdl = self.bm25_data.get("avgdl", 100)
|
||||
df = self.bm25_data.get("df", {})
|
||||
doc_lens = self.bm25_data.get("doc_lens", [100] * len(self.chunks))
|
||||
|
||||
scores = []
|
||||
for i, chunk in enumerate(self.chunks):
|
||||
# Tokenize chunk text
|
||||
chunk_tokens = tokenize(chunk.get("text", ""))
|
||||
doc_len = doc_lens[i] if i < len(doc_lens) else len(chunk_tokens)
|
||||
|
||||
# Calculate BM25 score
|
||||
score = 0.0
|
||||
for token in query_tokens:
|
||||
if token not in df:
|
||||
continue
|
||||
|
||||
# Term frequency in this document
|
||||
tf = chunk_tokens.count(token)
|
||||
if tf == 0:
|
||||
continue
|
||||
|
||||
# IDF
|
||||
doc_freq = df.get(token, 0)
|
||||
idf = math.log((n_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
|
||||
|
||||
# BM25 score component
|
||||
numerator = tf * (k1 + 1)
|
||||
denominator = tf + k1 * (1 - b + b * doc_len / avgdl)
|
||||
score += idf * numerator / denominator
|
||||
|
||||
# Boost for title/heading matches
|
||||
title = chunk.get("title", "").lower()
|
||||
heading = chunk.get("heading", "").lower()
|
||||
for token in query_tokens:
|
||||
if token in title:
|
||||
score *= 1.5
|
||||
if token in heading:
|
||||
score *= 1.2
|
||||
|
||||
scores.append((i, score))
|
||||
|
||||
# Sort by score and return top_k
|
||||
scores.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
results = []
|
||||
seen_sections = set()
|
||||
for idx, score in scores:
|
||||
if score <= 0:
|
||||
continue
|
||||
|
||||
chunk = self.chunks[idx]
|
||||
section_key = (chunk.get("doc", ""), chunk.get("heading", ""))
|
||||
|
||||
# Deduplicate by section
|
||||
if section_key in seen_sections:
|
||||
continue
|
||||
seen_sections.add(section_key)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"title": chunk.get("title", ""),
|
||||
"path": chunk.get("doc", ""),
|
||||
"heading": chunk.get("heading", ""),
|
||||
"text": chunk.get("text", ""), # Full text for LLM comprehension
|
||||
"score": score,
|
||||
}
|
||||
)
|
||||
|
||||
if len(results) >= top_k:
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# Global index instance (lazy loaded)
|
||||
_search_index: DocSearchIndex | None = None
|
||||
|
||||
|
||||
def get_search_index() -> DocSearchIndex:
|
||||
"""Get or create the search index singleton."""
|
||||
global _search_index
|
||||
if _search_index is None:
|
||||
_search_index = DocSearchIndex(INDEX_PATH)
|
||||
return _search_index
|
||||
|
||||
|
||||
class SearchDocsTool(BaseTool):
|
||||
"""Tool for searching AutoGPT platform documentation."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "search_platform_docs"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search the AutoGPT platform documentation and support Q&A for information about "
|
||||
"how to use the platform, create agents, configure blocks, "
|
||||
"set up integrations, troubleshoot issues, and more. Use this when users ask "
|
||||
"support questions or want to learn how to do something with AutoGPT."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query describing what the user wants to learn about. "
|
||||
"Use keywords like 'blocks', 'agents', 'credentials', 'API', etc."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search documentation for the query.
|
||||
|
||||
Args:
|
||||
user_id: User ID (may be anonymous)
|
||||
session: Chat session
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
DocSearchResultsResponse: List of matching documentation sections
|
||||
NoResultsResponse: No results found
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
index = get_search_index()
|
||||
results = index.search(query, top_k=5)
|
||||
|
||||
if not results:
|
||||
return NoResultsResponse(
|
||||
message=f"No documentation found for '{query}'. Try different keywords.",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms like 'blocks', 'agents', 'setup'",
|
||||
"Check the documentation at docs.agpt.co",
|
||||
],
|
||||
)
|
||||
|
||||
# Convert to response format
|
||||
doc_results = []
|
||||
for r in results:
|
||||
# Build documentation URL
|
||||
path = r["path"]
|
||||
if path.endswith(".md"):
|
||||
path = path[:-3] # Remove .md extension
|
||||
doc_url = f"{DOCS_BASE_URL}/{path}"
|
||||
|
||||
full_text = r["text"]
|
||||
doc_results.append(
|
||||
DocSearchResult(
|
||||
title=r["title"],
|
||||
path=r["path"],
|
||||
section=r["heading"],
|
||||
snippet=(
|
||||
full_text[:300] + "..."
|
||||
if len(full_text) > 300
|
||||
else full_text
|
||||
),
|
||||
content=full_text, # Full text for LLM to read and understand
|
||||
score=round(r["score"], 3),
|
||||
doc_url=doc_url,
|
||||
)
|
||||
)
|
||||
|
||||
return DocSearchResultsResponse(
|
||||
message=(
|
||||
f"Found {len(doc_results)} relevant documentation sections. "
|
||||
"Use these to help answer the user's question. "
|
||||
"Include links to the documentation when helpful."
|
||||
),
|
||||
results=doc_results,
|
||||
count=len(doc_results),
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching documentation: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search documentation. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -1,72 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
CLI script to backfill embeddings for store agents.
|
||||
|
||||
Usage:
|
||||
poetry run python -m backend.server.v2.store.backfill_embeddings [--batch-size N]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import prisma
|
||||
|
||||
|
||||
async def main(batch_size: int = 100) -> int:
|
||||
"""Run the backfill process."""
|
||||
# Initialize Prisma client
|
||||
client = prisma.Prisma()
|
||||
await client.connect()
|
||||
prisma.register(client)
|
||||
|
||||
try:
|
||||
from backend.api.features.store.embeddings import (
|
||||
backfill_missing_embeddings,
|
||||
get_embedding_stats,
|
||||
)
|
||||
|
||||
# Get current stats
|
||||
print("Current embedding stats:")
|
||||
stats = await get_embedding_stats()
|
||||
print(f" Total approved: {stats['total_approved']}")
|
||||
print(f" With embeddings: {stats['with_embeddings']}")
|
||||
print(f" Without embeddings: {stats['without_embeddings']}")
|
||||
print(f" Coverage: {stats['coverage_percent']}%")
|
||||
|
||||
if stats["without_embeddings"] == 0:
|
||||
print("\nAll agents already have embeddings. Nothing to do.")
|
||||
return 0
|
||||
|
||||
# Run backfill
|
||||
print(f"\nBackfilling up to {batch_size} embeddings...")
|
||||
result = await backfill_missing_embeddings(batch_size=batch_size)
|
||||
print(f" Processed: {result['processed']}")
|
||||
print(f" Success: {result['success']}")
|
||||
print(f" Failed: {result['failed']}")
|
||||
|
||||
# Get final stats
|
||||
print("\nFinal embedding stats:")
|
||||
stats = await get_embedding_stats()
|
||||
print(f" Total approved: {stats['total_approved']}")
|
||||
print(f" With embeddings: {stats['with_embeddings']}")
|
||||
print(f" Without embeddings: {stats['without_embeddings']}")
|
||||
print(f" Coverage: {stats['coverage_percent']}%")
|
||||
|
||||
return 0 if result["failed"] == 0 else 1
|
||||
|
||||
finally:
|
||||
await client.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Backfill embeddings for store agents")
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of embeddings to generate (default: 100)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sys.exit(asyncio.run(main(batch_size=args.batch_size)))
|
||||
@@ -1,408 +0,0 @@
|
||||
"""
|
||||
Store Listing Embeddings Service
|
||||
|
||||
Handles generation and storage of OpenAI embeddings for store listings
|
||||
to enable semantic/hybrid search.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import prisma
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OpenAI embedding model configuration
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
EMBEDDING_DIM = 1536
|
||||
|
||||
|
||||
def build_searchable_text(
|
||||
name: str,
|
||||
description: str,
|
||||
sub_heading: str,
|
||||
categories: list[str],
|
||||
) -> str:
|
||||
"""
|
||||
Build searchable text from listing version fields.
|
||||
|
||||
Combines relevant fields into a single string for embedding.
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# Name is important - include it
|
||||
if name:
|
||||
parts.append(name)
|
||||
|
||||
# Sub-heading provides context
|
||||
if sub_heading:
|
||||
parts.append(sub_heading)
|
||||
|
||||
# Description is the main content
|
||||
if description:
|
||||
parts.append(description)
|
||||
|
||||
# Categories help with semantic matching
|
||||
if categories:
|
||||
parts.append(" ".join(categories))
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def compute_content_hash(text: str) -> str:
|
||||
"""Compute MD5 hash of text for change detection."""
|
||||
return hashlib.md5(text.encode()).hexdigest()
|
||||
|
||||
|
||||
async def generate_embedding(text: str) -> list[float] | None:
|
||||
"""
|
||||
Generate embedding for text using OpenAI API.
|
||||
|
||||
Returns None if embedding generation fails.
|
||||
"""
|
||||
try:
|
||||
from openai import OpenAI
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
logger.warning("OPENAI_API_KEY not set, cannot generate embedding")
|
||||
return None
|
||||
|
||||
client = OpenAI(api_key=api_key)
|
||||
|
||||
# Truncate text to avoid token limits (~32k chars for safety)
|
||||
truncated_text = text[:32000]
|
||||
|
||||
response = client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=truncated_text,
|
||||
)
|
||||
|
||||
embedding = response.data[0].embedding
|
||||
logger.debug(f"Generated embedding with {len(embedding)} dimensions")
|
||||
return embedding
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate embedding: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def store_embedding(
|
||||
version_id: str,
|
||||
embedding: list[float],
|
||||
searchable_text: str,
|
||||
content_hash: str,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Store embedding in the database.
|
||||
|
||||
Uses raw SQL since Prisma doesn't natively support pgvector.
|
||||
"""
|
||||
try:
|
||||
client = tx if tx else prisma.get_client()
|
||||
|
||||
# Convert embedding to PostgreSQL vector format
|
||||
embedding_str = "[" + ",".join(str(x) for x in embedding) + "]"
|
||||
|
||||
# Upsert the embedding
|
||||
# Set search_path to include public for vector type visibility
|
||||
await client.execute_raw(
|
||||
"""
|
||||
SET LOCAL search_path TO platform, public;
|
||||
INSERT INTO platform."StoreListingEmbedding" (
|
||||
"id", "storeListingVersionId", "embedding",
|
||||
"searchableText", "contentHash", "createdAt", "updatedAt"
|
||||
)
|
||||
VALUES (
|
||||
gen_random_uuid(), $1, $2::vector,
|
||||
$3, $4, NOW(), NOW()
|
||||
)
|
||||
ON CONFLICT ("storeListingVersionId")
|
||||
DO UPDATE SET
|
||||
"embedding" = $2::vector,
|
||||
"searchableText" = $3,
|
||||
"contentHash" = $4,
|
||||
"updatedAt" = NOW()
|
||||
""",
|
||||
version_id,
|
||||
embedding_str,
|
||||
searchable_text,
|
||||
content_hash,
|
||||
)
|
||||
|
||||
logger.info(f"Stored embedding for version {version_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store embedding for version {version_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_embedding(version_id: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Retrieve embedding record for a listing version.
|
||||
|
||||
Returns dict with embedding, searchableText, contentHash or None if not found.
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
result = await client.query_raw(
|
||||
"""
|
||||
SELECT
|
||||
"id",
|
||||
"storeListingVersionId",
|
||||
"embedding"::text as "embedding",
|
||||
"searchableText",
|
||||
"contentHash",
|
||||
"createdAt",
|
||||
"updatedAt"
|
||||
FROM platform."StoreListingEmbedding"
|
||||
WHERE "storeListingVersionId" = $1
|
||||
""",
|
||||
version_id,
|
||||
)
|
||||
|
||||
if result and len(result) > 0:
|
||||
return result[0]
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embedding for version {version_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def ensure_embedding(
|
||||
version_id: str,
|
||||
name: str,
|
||||
description: str,
|
||||
sub_heading: str,
|
||||
categories: list[str],
|
||||
force: bool = False,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Ensure an embedding exists for the listing version.
|
||||
|
||||
Creates embedding if missing or if content has changed.
|
||||
Skips if content hash matches existing embedding.
|
||||
|
||||
Args:
|
||||
version_id: The StoreListingVersion ID
|
||||
name: Agent name
|
||||
description: Agent description
|
||||
sub_heading: Agent sub-heading
|
||||
categories: Agent categories
|
||||
force: Force regeneration even if hash matches
|
||||
tx: Optional transaction client
|
||||
|
||||
Returns:
|
||||
True if embedding exists/was created, False on failure
|
||||
"""
|
||||
try:
|
||||
# Build searchable text and compute hash
|
||||
searchable_text = build_searchable_text(
|
||||
name, description, sub_heading, categories
|
||||
)
|
||||
content_hash = compute_content_hash(searchable_text)
|
||||
|
||||
# Check if embedding already exists with same hash
|
||||
if not force:
|
||||
existing = await get_embedding(version_id)
|
||||
if existing and existing.get("contentHash") == content_hash:
|
||||
logger.debug(
|
||||
f"Embedding for version {version_id} is up to date (hash match)"
|
||||
)
|
||||
return True
|
||||
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
if embedding is None:
|
||||
logger.warning(f"Could not generate embedding for version {version_id}")
|
||||
return False
|
||||
|
||||
# Store the embedding
|
||||
return await store_embedding(
|
||||
version_id=version_id,
|
||||
embedding=embedding,
|
||||
searchable_text=searchable_text,
|
||||
content_hash=content_hash,
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def delete_embedding(version_id: str) -> bool:
|
||||
"""
|
||||
Delete embedding for a listing version.
|
||||
|
||||
Note: This is usually handled automatically by CASCADE delete,
|
||||
but provided for manual cleanup if needed.
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
await client.execute_raw(
|
||||
"""
|
||||
DELETE FROM platform."StoreListingEmbedding"
|
||||
WHERE "storeListingVersionId" = $1
|
||||
""",
|
||||
version_id,
|
||||
)
|
||||
|
||||
logger.info(f"Deleted embedding for version {version_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete embedding for version {version_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_embedding_stats() -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics about embedding coverage.
|
||||
|
||||
Returns counts of:
|
||||
- Total approved listing versions
|
||||
- Versions with embeddings
|
||||
- Versions without embeddings
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
# Count approved versions
|
||||
approved_result = await client.query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM platform."StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
AND "isDeleted" = false
|
||||
"""
|
||||
)
|
||||
total_approved = approved_result[0]["count"] if approved_result else 0
|
||||
|
||||
# Count versions with embeddings
|
||||
embedded_result = await client.query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM platform."StoreListingVersion" slv
|
||||
JOIN platform."StoreListingEmbedding" sle ON slv.id = sle."storeListingVersionId"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
"""
|
||||
)
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total_approved": total_approved,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_approved - with_embeddings,
|
||||
"coverage_percent": (
|
||||
round(with_embeddings / total_approved * 100, 1)
|
||||
if total_approved > 0
|
||||
else 0
|
||||
),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embedding stats: {e}")
|
||||
return {
|
||||
"total_approved": 0,
|
||||
"with_embeddings": 0,
|
||||
"without_embeddings": 0,
|
||||
"coverage_percent": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]:
|
||||
"""
|
||||
Generate embeddings for approved listings that don't have them.
|
||||
|
||||
Args:
|
||||
batch_size: Number of embeddings to generate in one call
|
||||
|
||||
Returns:
|
||||
Dict with success/failure counts
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
# Find approved versions without embeddings
|
||||
missing = await client.query_raw(
|
||||
"""
|
||||
SELECT
|
||||
slv.id,
|
||||
slv.name,
|
||||
slv.description,
|
||||
slv."subHeading",
|
||||
slv.categories
|
||||
FROM platform."StoreListingVersion" slv
|
||||
LEFT JOIN platform."StoreListingEmbedding" sle ON slv.id = sle."storeListingVersionId"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
AND sle.id IS NULL
|
||||
LIMIT $1
|
||||
""",
|
||||
batch_size,
|
||||
)
|
||||
|
||||
if not missing:
|
||||
return {
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"message": "No missing embeddings",
|
||||
}
|
||||
|
||||
success = 0
|
||||
failed = 0
|
||||
|
||||
for row in missing:
|
||||
result = await ensure_embedding(
|
||||
version_id=row["id"],
|
||||
name=row["name"],
|
||||
description=row["description"],
|
||||
sub_heading=row["subHeading"],
|
||||
categories=row["categories"] or [],
|
||||
)
|
||||
if result:
|
||||
success += 1
|
||||
else:
|
||||
failed += 1
|
||||
|
||||
return {
|
||||
"processed": len(missing),
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"message": f"Backfilled {success} embeddings, {failed} failed",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to backfill embeddings: {e}")
|
||||
return {
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
async def embed_query(query: str) -> list[float] | None:
|
||||
"""
|
||||
Generate embedding for a search query.
|
||||
|
||||
Same as generate_embedding but with clearer intent.
|
||||
"""
|
||||
return await generate_embedding(query)
|
||||
|
||||
|
||||
def embedding_to_vector_string(embedding: list[float]) -> str:
|
||||
"""Convert embedding list to PostgreSQL vector string format."""
|
||||
return "[" + ",".join(str(x) for x in embedding) + "]"
|
||||
@@ -1,440 +0,0 @@
|
||||
"""
|
||||
Hybrid Search for Store Agents
|
||||
|
||||
Combines semantic (embedding) search with lexical (tsvector) search
|
||||
for improved relevance in marketplace agent discovery.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
import prisma
|
||||
|
||||
from backend.api.features.store.embeddings import (
|
||||
embed_query,
|
||||
embedding_to_vector_string,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HybridSearchWeights:
|
||||
"""Weights for combining search signals."""
|
||||
|
||||
semantic: float = 0.35 # Embedding cosine similarity
|
||||
lexical: float = 0.35 # tsvector ts_rank_cd score
|
||||
category: float = 0.20 # Category match boost
|
||||
recency: float = 0.10 # Newer agents ranked higher
|
||||
|
||||
|
||||
DEFAULT_WEIGHTS = HybridSearchWeights()
|
||||
|
||||
# Minimum relevance score threshold - agents below this are filtered out
|
||||
# With weights (0.35 semantic + 0.35 lexical + 0.20 category + 0.10 recency):
|
||||
# - 0.20 means at least ~50% semantic match OR strong lexical match required
|
||||
# - Ensures only genuinely relevant results are returned
|
||||
# - Recency alone (0.10 max) won't pass the threshold
|
||||
DEFAULT_MIN_SCORE = 0.20
|
||||
|
||||
|
||||
@dataclass
|
||||
class HybridSearchResult:
|
||||
"""A single search result with score breakdown."""
|
||||
|
||||
slug: str
|
||||
agent_name: str
|
||||
agent_image: str
|
||||
creator_username: str
|
||||
creator_avatar: str
|
||||
sub_heading: str
|
||||
description: str
|
||||
runs: int
|
||||
rating: float
|
||||
categories: list[str]
|
||||
featured: bool
|
||||
is_available: bool
|
||||
updated_at: datetime
|
||||
|
||||
# Score breakdown (for debugging/tuning)
|
||||
combined_score: float
|
||||
semantic_score: float = 0.0
|
||||
lexical_score: float = 0.0
|
||||
category_score: float = 0.0
|
||||
recency_score: float = 0.0
|
||||
|
||||
|
||||
async def hybrid_search(
|
||||
query: str,
|
||||
featured: bool = False,
|
||||
creators: list[str] | None = None,
|
||||
category: str | None = None,
|
||||
sorted_by: (
|
||||
Literal["relevance", "rating", "runs", "name", "updated_at"] | None
|
||||
) = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
weights: HybridSearchWeights | None = None,
|
||||
min_score: float | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Perform hybrid search combining semantic and lexical signals.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
featured: Filter for featured agents only
|
||||
creators: Filter by creator usernames
|
||||
category: Filter by category
|
||||
sorted_by: Sort order (relevance uses hybrid scoring)
|
||||
page: Page number (1-indexed)
|
||||
page_size: Results per page
|
||||
weights: Custom weights for search signals
|
||||
min_score: Minimum relevance score threshold (0-1). Results below
|
||||
this score are filtered out. Defaults to DEFAULT_MIN_SCORE.
|
||||
|
||||
Returns:
|
||||
Tuple of (results list, total count). Returns empty list if no
|
||||
results meet the minimum relevance threshold.
|
||||
"""
|
||||
if weights is None:
|
||||
weights = DEFAULT_WEIGHTS
|
||||
if min_score is None:
|
||||
min_score = DEFAULT_MIN_SCORE
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
client = prisma.get_client()
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await embed_query(query)
|
||||
|
||||
# Build WHERE clause conditions
|
||||
where_parts: list[str] = ["sa.is_available = true"]
|
||||
params: list[Any] = []
|
||||
param_index = 1
|
||||
|
||||
# Add search query for lexical matching
|
||||
params.append(query)
|
||||
query_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
if featured:
|
||||
where_parts.append("sa.featured = true")
|
||||
|
||||
if creators:
|
||||
where_parts.append(f"sa.creator_username = ANY(${param_index})")
|
||||
params.append(creators)
|
||||
param_index += 1
|
||||
|
||||
if category:
|
||||
where_parts.append(f"${param_index} = ANY(sa.categories)")
|
||||
params.append(category)
|
||||
param_index += 1
|
||||
|
||||
where_clause = " AND ".join(where_parts)
|
||||
|
||||
# Determine if we can use hybrid search (have query embedding)
|
||||
use_hybrid = query_embedding is not None
|
||||
|
||||
if use_hybrid:
|
||||
# Add embedding parameter
|
||||
embedding_str = embedding_to_vector_string(query_embedding)
|
||||
params.append(embedding_str)
|
||||
embedding_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
# Build hybrid search query with weighted scoring
|
||||
# The semantic score is (1 - cosine_distance), normalized to [0,1]
|
||||
# The lexical score is ts_rank_cd, normalized by max value
|
||||
# Set search_path to include public for vector type visibility
|
||||
sql_query = f"""
|
||||
SET LOCAL search_path TO platform, public;
|
||||
WITH search_scores AS (
|
||||
SELECT
|
||||
sa.*,
|
||||
-- Semantic score: cosine similarity (1 - distance)
|
||||
COALESCE(1 - (sle.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
||||
-- Lexical score: ts_rank_cd normalized
|
||||
COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||
-- Category match: 1 if query term appears in categories, else 0
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM unnest(sa.categories) cat
|
||||
WHERE LOWER(cat) LIKE '%' || LOWER({query_param}) || '%'
|
||||
) THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
-- Recency score: exponential decay over 90 days
|
||||
EXP(-EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score
|
||||
FROM platform."StoreAgent" sa
|
||||
LEFT JOIN platform."StoreListing" sl ON sa.slug = sl.slug
|
||||
LEFT JOIN platform."StoreListingVersion" slv ON sl."activeVersionId" = slv.id
|
||||
LEFT JOIN platform."StoreListingEmbedding" sle ON slv.id = sle."storeListingVersionId"
|
||||
WHERE {where_clause}
|
||||
AND (
|
||||
sa.search @@ plainto_tsquery('english', {query_param})
|
||||
OR sle.embedding IS NOT NULL
|
||||
)
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
*,
|
||||
-- Normalize lexical score by max in result set
|
||||
CASE
|
||||
WHEN MAX(lexical_raw) OVER () > 0
|
||||
THEN lexical_raw / MAX(lexical_raw) OVER ()
|
||||
ELSE 0
|
||||
END as lexical_score
|
||||
FROM search_scores
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
category_score,
|
||||
recency_score,
|
||||
(
|
||||
{weights.semantic} * semantic_score +
|
||||
{weights.lexical} * lexical_score +
|
||||
{weights.category} * category_score +
|
||||
{weights.recency} * recency_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
)
|
||||
SELECT * FROM scored
|
||||
WHERE combined_score >= {min_score}
|
||||
ORDER BY combined_score DESC
|
||||
LIMIT ${param_index} OFFSET ${param_index + 1}
|
||||
"""
|
||||
|
||||
# Add pagination params
|
||||
params.extend([page_size, offset])
|
||||
|
||||
# Count query - must also filter by min_score
|
||||
count_query = f"""
|
||||
SET LOCAL search_path TO platform, public;
|
||||
WITH search_scores AS (
|
||||
SELECT
|
||||
sa.slug,
|
||||
COALESCE(1 - (sle.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
||||
COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM unnest(sa.categories) cat
|
||||
WHERE LOWER(cat) LIKE '%' || LOWER({query_param}) || '%'
|
||||
) THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
EXP(-EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score
|
||||
FROM platform."StoreAgent" sa
|
||||
LEFT JOIN platform."StoreListing" sl ON sa.slug = sl.slug
|
||||
LEFT JOIN platform."StoreListingVersion" slv ON sl."activeVersionId" = slv.id
|
||||
LEFT JOIN platform."StoreListingEmbedding" sle ON slv.id = sle."storeListingVersionId"
|
||||
WHERE {where_clause}
|
||||
AND (
|
||||
sa.search @@ plainto_tsquery('english', {query_param})
|
||||
OR sle.embedding IS NOT NULL
|
||||
)
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
slug,
|
||||
semantic_score,
|
||||
category_score,
|
||||
recency_score,
|
||||
CASE
|
||||
WHEN MAX(lexical_raw) OVER () > 0
|
||||
THEN lexical_raw / MAX(lexical_raw) OVER ()
|
||||
ELSE 0
|
||||
END as lexical_score
|
||||
FROM search_scores
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
slug,
|
||||
(
|
||||
{weights.semantic} * semantic_score +
|
||||
{weights.lexical} * lexical_score +
|
||||
{weights.category} * category_score +
|
||||
{weights.recency} * recency_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
)
|
||||
SELECT COUNT(*) as count FROM scored
|
||||
WHERE combined_score >= {min_score}
|
||||
"""
|
||||
|
||||
else:
|
||||
# Fallback to lexical-only search (existing behavior)
|
||||
# Note: For lexical-only, we still require tsvector match but don't
|
||||
# apply min_score since ts_rank_cd isn't normalized to [0,1]
|
||||
logger.warning("Falling back to lexical-only search (no query embedding)")
|
||||
|
||||
sql_query = f"""
|
||||
WITH lexical_scores AS (
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
0.0 as semantic_score,
|
||||
ts_rank_cd(search, plainto_tsquery('english', {query_param})) as lexical_raw,
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM unnest(categories) cat
|
||||
WHERE LOWER(cat) LIKE '%' || LOWER({query_param}) || '%'
|
||||
) THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
EXP(-EXTRACT(EPOCH FROM (NOW() - updated_at)) / (90 * 24 * 3600)) as recency_score
|
||||
FROM platform."StoreAgent" sa
|
||||
WHERE {where_clause}
|
||||
AND search @@ plainto_tsquery('english', {query_param})
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
*,
|
||||
CASE
|
||||
WHEN MAX(lexical_raw) OVER () > 0
|
||||
THEN lexical_raw / MAX(lexical_raw) OVER ()
|
||||
ELSE 0
|
||||
END as lexical_score
|
||||
FROM lexical_scores
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
category_score,
|
||||
recency_score,
|
||||
(
|
||||
{weights.lexical} * lexical_score +
|
||||
{weights.category} * category_score +
|
||||
{weights.recency} * recency_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
)
|
||||
SELECT * FROM scored
|
||||
WHERE combined_score >= {min_score}
|
||||
ORDER BY combined_score DESC
|
||||
LIMIT ${param_index} OFFSET ${param_index + 1}
|
||||
"""
|
||||
|
||||
params.extend([page_size, offset])
|
||||
|
||||
count_query = f"""
|
||||
WITH lexical_scores AS (
|
||||
SELECT
|
||||
slug,
|
||||
ts_rank_cd(search, plainto_tsquery('english', {query_param})) as lexical_raw,
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM unnest(categories) cat
|
||||
WHERE LOWER(cat) LIKE '%' || LOWER({query_param}) || '%'
|
||||
) THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
EXP(-EXTRACT(EPOCH FROM (NOW() - updated_at)) / (90 * 24 * 3600)) as recency_score
|
||||
FROM platform."StoreAgent" sa
|
||||
WHERE {where_clause}
|
||||
AND search @@ plainto_tsquery('english', {query_param})
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
slug,
|
||||
category_score,
|
||||
recency_score,
|
||||
CASE
|
||||
WHEN MAX(lexical_raw) OVER () > 0
|
||||
THEN lexical_raw / MAX(lexical_raw) OVER ()
|
||||
ELSE 0
|
||||
END as lexical_score
|
||||
FROM lexical_scores
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
slug,
|
||||
(
|
||||
{weights.lexical} * lexical_score +
|
||||
{weights.category} * category_score +
|
||||
{weights.recency} * recency_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
)
|
||||
SELECT COUNT(*) as count FROM scored
|
||||
WHERE combined_score >= {min_score}
|
||||
"""
|
||||
|
||||
try:
|
||||
# Execute search query
|
||||
# Dynamic SQL is safe here - all user inputs are parameterized ($1, $2, etc.)
|
||||
results = await client.query_raw(sql_query, *params) # type: ignore[arg-type]
|
||||
|
||||
# Execute count query (without pagination params)
|
||||
count_params = params[:-2] # Remove LIMIT and OFFSET params
|
||||
count_result = await client.query_raw(count_query, *count_params) # type: ignore[arg-type]
|
||||
total = count_result[0]["count"] if count_result else 0
|
||||
|
||||
logger.info(
|
||||
f"Hybrid search for '{query}': {len(results)} results, {total} total "
|
||||
f"(hybrid={use_hybrid})"
|
||||
)
|
||||
|
||||
return results, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Hybrid search failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def hybrid_search_simple(
|
||||
query: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Simplified hybrid search for common use cases.
|
||||
|
||||
Uses default weights and no filters.
|
||||
"""
|
||||
return await hybrid_search(
|
||||
query=query,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
@@ -1,41 +0,0 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
def sort_openapi(app: FastAPI) -> None:
|
||||
"""
|
||||
Patch a FastAPI instance's `openapi()` method to sort the endpoints,
|
||||
schemas, and responses.
|
||||
"""
|
||||
wrapped_openapi = app.openapi
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = wrapped_openapi()
|
||||
|
||||
# Sort endpoints
|
||||
openapi_schema["paths"] = dict(sorted(openapi_schema["paths"].items()))
|
||||
|
||||
# Sort endpoints -> methods
|
||||
for p in openapi_schema["paths"].keys():
|
||||
openapi_schema["paths"][p] = dict(
|
||||
sorted(openapi_schema["paths"][p].items())
|
||||
)
|
||||
|
||||
# Sort endpoints -> methods -> responses
|
||||
for m in openapi_schema["paths"][p].keys():
|
||||
openapi_schema["paths"][p][m]["responses"] = dict(
|
||||
sorted(openapi_schema["paths"][p][m]["responses"].items())
|
||||
)
|
||||
|
||||
# Sort schemas and responses as well
|
||||
for k in openapi_schema["components"].keys():
|
||||
openapi_schema["components"][k] = dict(
|
||||
sorted(openapi_schema["components"][k].items())
|
||||
)
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return openapi_schema
|
||||
|
||||
app.openapi = custom_openapi
|
||||
@@ -36,10 +36,10 @@ def main(**kwargs):
|
||||
Run all the processes required for the AutoGPT-server (REST and WebSocket APIs).
|
||||
"""
|
||||
|
||||
from backend.api.rest_api import AgentServer
|
||||
from backend.api.ws_api import WebsocketServer
|
||||
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
|
||||
from backend.notifications import NotificationManager
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.server.ws_api import WebsocketServer
|
||||
|
||||
run_processes(
|
||||
DatabaseManager().set_log_level("warning"),
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks.llm import (
|
||||
DEFAULT_LLM_MODEL,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
AIBlockBase,
|
||||
@@ -50,7 +49,7 @@ class AIConditionBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for evaluating the condition.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -82,7 +81,7 @@ class AIConditionBlock(AIBlockBase):
|
||||
"condition": "the input is an email address",
|
||||
"yes_value": "Valid email",
|
||||
"no_value": "Not an email",
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
|
||||
@@ -182,10 +182,13 @@ class DataForSeoRelatedKeywordsBlock(Block):
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
# Handle missing key, null value, or valid list value
|
||||
if isinstance(first_result, dict):
|
||||
items = first_result.get("items") or []
|
||||
else:
|
||||
items = (
|
||||
first_result.get("items", [])
|
||||
if isinstance(first_result, dict)
|
||||
else []
|
||||
)
|
||||
# Ensure items is never None
|
||||
if items is None:
|
||||
items = []
|
||||
for item in items:
|
||||
# Extract keyword_data from the item
|
||||
|
||||
@@ -319,7 +319,7 @@ class CostDollars(BaseModel):
|
||||
|
||||
# Helper functions for payload processing
|
||||
def process_text_field(
|
||||
text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None]
|
||||
text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None],
|
||||
) -> Optional[Union[bool, Dict[str, Any]]]:
|
||||
"""Process text field for API payload."""
|
||||
if text is None:
|
||||
@@ -400,7 +400,7 @@ def process_contents_settings(contents: Optional[ContentSettings]) -> Dict[str,
|
||||
|
||||
|
||||
def process_context_field(
|
||||
context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None]
|
||||
context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None],
|
||||
) -> Optional[Union[bool, Dict[str, int]]]:
|
||||
"""Process context field for API payload."""
|
||||
if context is None:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -92,9 +92,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
O1 = "o1"
|
||||
O1_MINI = "o1-mini"
|
||||
# GPT-5 models
|
||||
GPT5_2 = "gpt-5.2-2025-12-11"
|
||||
GPT5_1 = "gpt-5.1-2025-11-13"
|
||||
GPT5 = "gpt-5-2025-08-07"
|
||||
GPT5_1 = "gpt-5.1-2025-11-13"
|
||||
GPT5_MINI = "gpt-5-mini-2025-08-07"
|
||||
GPT5_NANO = "gpt-5-nano-2025-08-07"
|
||||
GPT5_CHAT = "gpt-5-chat-latest"
|
||||
@@ -195,9 +194,8 @@ MODEL_METADATA = {
|
||||
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
|
||||
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5_2: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
|
||||
@@ -305,8 +303,6 @@ MODEL_METADATA = {
|
||||
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
|
||||
}
|
||||
|
||||
DEFAULT_LLM_MODEL = LlmModel.GPT5_2
|
||||
|
||||
for model in LlmModel:
|
||||
if model not in MODEL_METADATA:
|
||||
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
|
||||
@@ -794,7 +790,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -859,7 +855,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
input_schema=AIStructuredResponseGeneratorBlock.Input,
|
||||
output_schema=AIStructuredResponseGeneratorBlock.Output,
|
||||
test_input={
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expected_format": {
|
||||
"key1": "value1",
|
||||
@@ -1225,7 +1221,7 @@ class AITextGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -1321,7 +1317,7 @@ class AITextSummarizerBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for summarizing the text.",
|
||||
)
|
||||
focus: str = SchemaField(
|
||||
@@ -1538,7 +1534,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for the conversation.",
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
@@ -1576,7 +1572,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
},
|
||||
{"role": "user", "content": "Where was it played?"},
|
||||
],
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
@@ -1639,7 +1635,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for generating the list.",
|
||||
advanced=True,
|
||||
)
|
||||
@@ -1696,7 +1692,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
|
||||
"fictional worlds."
|
||||
),
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"max_retries": 3,
|
||||
"force_json_output": False,
|
||||
|
||||
@@ -226,7 +226,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
)
|
||||
model: llm.LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=llm.DEFAULT_LLM_MODEL,
|
||||
default=llm.LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
@@ -196,15 +196,6 @@ class TestXMLParserBlockSecurity:
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml=large_xml)):
|
||||
pass
|
||||
|
||||
async def test_rejects_text_outside_root(self):
|
||||
"""Ensure parser surfaces readable errors for invalid root text."""
|
||||
block = XMLParserBlock()
|
||||
invalid_xml = "<root><child>value</child></root> trailing"
|
||||
|
||||
with pytest.raises(ValueError, match="text outside the root element"):
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml=invalid_xml)):
|
||||
pass
|
||||
|
||||
|
||||
class TestStoreMediaFileSecurity:
|
||||
"""Test file storage security limits."""
|
||||
|
||||
@@ -28,7 +28,7 @@ class TestLLMStatsTracking:
|
||||
|
||||
response = await llm.llm_call(
|
||||
credentials=llm.TEST_CREDENTIALS,
|
||||
llm_model=llm.DEFAULT_LLM_MODEL,
|
||||
llm_model=llm.LlmModel.GPT4O,
|
||||
prompt=[{"role": "user", "content": "Hello"}],
|
||||
max_tokens=100,
|
||||
)
|
||||
@@ -65,7 +65,7 @@ class TestLLMStatsTracking:
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
expected_format={"key1": "desc1", "key2": "desc2"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore # type: ignore
|
||||
)
|
||||
|
||||
@@ -109,7 +109,7 @@ class TestLLMStatsTracking:
|
||||
# Run the block
|
||||
input_data = llm.AITextGeneratorBlock.Input(
|
||||
prompt="Generate text",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
@@ -170,7 +170,7 @@ class TestLLMStatsTracking:
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
expected_format={"key1": "desc1", "key2": "desc2"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2,
|
||||
)
|
||||
@@ -228,7 +228,7 @@ class TestLLMStatsTracking:
|
||||
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text=long_text,
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
max_tokens=100, # Small chunks
|
||||
chunk_overlap=10,
|
||||
@@ -299,7 +299,7 @@ class TestLLMStatsTracking:
|
||||
# Test with very short text (should only need 1 chunk + 1 final summary)
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="This is a short text.",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
max_tokens=1000, # Large enough to avoid chunking
|
||||
)
|
||||
@@ -346,7 +346,7 @@ class TestLLMStatsTracking:
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
],
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
@@ -387,7 +387,7 @@ class TestLLMStatsTracking:
|
||||
# Run the block
|
||||
input_data = llm.AIListGeneratorBlock.Input(
|
||||
focus="test items",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
max_retries=3,
|
||||
)
|
||||
@@ -469,7 +469,7 @@ class TestLLMStatsTracking:
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"result": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
@@ -513,7 +513,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
style=llm.SummaryStyle.BULLET_POINTS,
|
||||
)
|
||||
@@ -558,7 +558,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
style=llm.SummaryStyle.BULLET_POINTS,
|
||||
max_tokens=1000,
|
||||
@@ -593,7 +593,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
@@ -623,7 +623,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
max_tokens=1000,
|
||||
)
|
||||
@@ -654,7 +654,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@@ -5,10 +5,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.model import CreateGraph
|
||||
from backend.api.rest_api import AgentServer
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import ProviderName, User
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.test import SpinTestServer, wait_execution
|
||||
|
||||
@@ -233,7 +233,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
# Create test input
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Should I continue with this task?",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -335,7 +335,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2, # Set retry to 2 for testing
|
||||
agent_mode_max_iterations=0,
|
||||
@@ -402,7 +402,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -462,7 +462,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -526,7 +526,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -648,7 +648,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2,
|
||||
agent_mode_max_iterations=0,
|
||||
@@ -722,7 +722,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Simple prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -778,7 +778,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Another test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -931,7 +931,7 @@ async def test_smart_decision_maker_agent_mode():
|
||||
# Test agent mode with max_iterations = 3
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Complete this task using tools",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=3, # Enable agent mode with 3 max iterations
|
||||
)
|
||||
@@ -1020,7 +1020,7 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
# Test default behavior (traditional mode)
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0, # Traditional mode
|
||||
)
|
||||
|
||||
@@ -373,7 +373,7 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
input_data = block.input_schema(
|
||||
prompt="Create a user dictionary",
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT,
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
agent_mode_max_iterations=0, # Use traditional mode to test output yielding
|
||||
)
|
||||
|
||||
@@ -594,7 +594,7 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
input_data = block.input_schema(
|
||||
prompt="Test prompt",
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT,
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
retry=3, # Allow retries
|
||||
agent_mode_max_iterations=1,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from gravitasml.parser import Parser
|
||||
from gravitasml.token import Token, tokenize
|
||||
from gravitasml.token import tokenize
|
||||
|
||||
from backend.data.block import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
|
||||
from backend.data.model import SchemaField
|
||||
@@ -25,38 +25,6 @@ class XMLParserBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_tokens(tokens: list[Token]) -> None:
|
||||
"""Ensure the XML has a single root element and no stray text."""
|
||||
if not tokens:
|
||||
raise ValueError("XML input is empty.")
|
||||
|
||||
depth = 0
|
||||
root_seen = False
|
||||
|
||||
for token in tokens:
|
||||
if token.type == "TAG_OPEN":
|
||||
if depth == 0 and root_seen:
|
||||
raise ValueError("XML must have a single root element.")
|
||||
depth += 1
|
||||
if depth == 1:
|
||||
root_seen = True
|
||||
elif token.type == "TAG_CLOSE":
|
||||
depth -= 1
|
||||
if depth < 0:
|
||||
raise SyntaxError("Unexpected closing tag in XML input.")
|
||||
elif token.type in {"TEXT", "ESCAPE"}:
|
||||
if depth == 0 and token.value:
|
||||
raise ValueError(
|
||||
"XML contains text outside the root element; "
|
||||
"wrap content in a single root tag."
|
||||
)
|
||||
|
||||
if depth != 0:
|
||||
raise SyntaxError("Unclosed tag detected in XML input.")
|
||||
if not root_seen:
|
||||
raise ValueError("XML must include a root element.")
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Security fix: Add size limits to prevent XML bomb attacks
|
||||
MAX_XML_SIZE = 10 * 1024 * 1024 # 10MB limit for XML input
|
||||
@@ -67,9 +35,7 @@ class XMLParserBlock(Block):
|
||||
)
|
||||
|
||||
try:
|
||||
tokens = list(tokenize(input_data.input_xml))
|
||||
self._validate_tokens(tokens)
|
||||
|
||||
tokens = tokenize(input_data.input_xml)
|
||||
parser = Parser(tokens)
|
||||
parsed_result = parser.parse()
|
||||
yield "parsed_xml", parsed_result
|
||||
|
||||
@@ -111,8 +111,6 @@ class TranscribeYoutubeVideoBlock(Block):
|
||||
return parsed_url.path.split("/")[2]
|
||||
if parsed_url.path[:3] == "/v/":
|
||||
return parsed_url.path.split("/")[2]
|
||||
if parsed_url.path.startswith("/shorts/"):
|
||||
return parsed_url.path.split("/")[2]
|
||||
raise ValueError(f"Invalid YouTube URL: {url}")
|
||||
|
||||
def get_transcript(
|
||||
|
||||
@@ -244,7 +244,11 @@ def websocket(server_address: str, graph_exec_id: str):
|
||||
|
||||
import websockets.asyncio.client
|
||||
|
||||
from backend.api.ws_api import WSMessage, WSMethod, WSSubscribeGraphExecutionRequest
|
||||
from backend.server.ws_api import (
|
||||
WSMessage,
|
||||
WSMethod,
|
||||
WSSubscribeGraphExecutionRequest,
|
||||
)
|
||||
|
||||
async def send_message(server_address: str):
|
||||
uri = f"ws://{server_address}"
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"""
|
||||
Script to generate OpenAPI JSON specification for the FastAPI app.
|
||||
|
||||
This script imports the FastAPI app from backend.api.rest_api and outputs
|
||||
This script imports the FastAPI app from backend.server.rest_api and outputs
|
||||
the OpenAPI specification as JSON to stdout or a specified file.
|
||||
|
||||
Usage:
|
||||
@@ -46,7 +46,7 @@ def main(output: Path, pretty: bool):
|
||||
|
||||
def get_openapi_schema():
|
||||
"""Get the OpenAPI schema from the FastAPI app"""
|
||||
from backend.api.rest_api import app
|
||||
from backend.server.rest_api import app
|
||||
|
||||
return app.openapi()
|
||||
|
||||
|
||||
@@ -36,12 +36,13 @@ import secrets
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import click
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from prisma.enums import APIKeyPermission
|
||||
from prisma.types import OAuthApplicationCreateInput
|
||||
|
||||
keysmith = APIKeySmith()
|
||||
|
||||
@@ -834,19 +835,22 @@ async def create_test_app_in_db(
|
||||
|
||||
# Insert into database
|
||||
app = await OAuthApplication.prisma().create(
|
||||
data={
|
||||
"id": creds["id"],
|
||||
"name": creds["name"],
|
||||
"description": creds["description"],
|
||||
"clientId": creds["client_id"],
|
||||
"clientSecret": creds["client_secret_hash"],
|
||||
"clientSecretSalt": creds["client_secret_salt"],
|
||||
"redirectUris": creds["redirect_uris"],
|
||||
"grantTypes": creds["grant_types"],
|
||||
"scopes": creds["scopes"],
|
||||
"ownerId": owner_id,
|
||||
"isActive": True,
|
||||
}
|
||||
data=cast(
|
||||
OAuthApplicationCreateInput,
|
||||
{
|
||||
"id": creds["id"],
|
||||
"name": creds["name"],
|
||||
"description": creds["description"],
|
||||
"clientId": creds["client_id"],
|
||||
"clientSecret": creds["client_secret_hash"],
|
||||
"clientSecretSalt": creds["client_secret_salt"],
|
||||
"redirectUris": creds["redirect_uris"],
|
||||
"grantTypes": creds["grant_types"],
|
||||
"scopes": creds["scopes"],
|
||||
"ownerId": owner_id,
|
||||
"isActive": True,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
click.echo(f"✓ Created test OAuth application: {app.clientId}")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from backend.api.features.library.model import LibraryAgentPreset
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
from .graph import NodeModel
|
||||
from .integrations import Webhook # noqa: F401
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, cast
|
||||
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from prisma.enums import APIKeyPermission, APIKeyStatus
|
||||
from prisma.models import APIKey as PrismaAPIKey
|
||||
from prisma.types import APIKeyWhereUniqueInput
|
||||
from prisma.types import APIKeyCreateInput, APIKeyWhereUniqueInput
|
||||
from pydantic import Field
|
||||
|
||||
from backend.data.includes import MAX_USER_API_KEYS_FETCH
|
||||
@@ -82,17 +82,20 @@ async def create_api_key(
|
||||
generated_key = keysmith.generate_key()
|
||||
|
||||
saved_key_obj = await PrismaAPIKey.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"head": generated_key.head,
|
||||
"tail": generated_key.tail,
|
||||
"hash": generated_key.hash,
|
||||
"salt": generated_key.salt,
|
||||
"permissions": [p for p in permissions],
|
||||
"description": description,
|
||||
"userId": user_id,
|
||||
}
|
||||
data=cast(
|
||||
APIKeyCreateInput,
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"head": generated_key.head,
|
||||
"tail": generated_key.tail,
|
||||
"hash": generated_key.hash,
|
||||
"salt": generated_key.salt,
|
||||
"permissions": [p for p in permissions],
|
||||
"description": description,
|
||||
"userId": user_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return APIKeyInfo.from_db(saved_key_obj), generated_key.key
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, cast
|
||||
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from prisma.enums import APIKeyPermission as APIPermission
|
||||
@@ -22,7 +22,12 @@ from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
|
||||
from prisma.models import OAuthApplication as PrismaOAuthApplication
|
||||
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
|
||||
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
|
||||
from prisma.types import OAuthApplicationUpdateInput
|
||||
from prisma.types import (
|
||||
OAuthAccessTokenCreateInput,
|
||||
OAuthApplicationUpdateInput,
|
||||
OAuthAuthorizationCodeCreateInput,
|
||||
OAuthRefreshTokenCreateInput,
|
||||
)
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from .base import APIAuthorizationInfo
|
||||
@@ -359,17 +364,20 @@ async def create_authorization_code(
|
||||
expires_at = now + AUTHORIZATION_CODE_TTL
|
||||
|
||||
saved_code = await PrismaOAuthAuthorizationCode.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"code": code,
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
"redirectUri": redirect_uri,
|
||||
"codeChallenge": code_challenge,
|
||||
"codeChallengeMethod": code_challenge_method,
|
||||
}
|
||||
data=cast(
|
||||
OAuthAuthorizationCodeCreateInput,
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"code": code,
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
"redirectUri": redirect_uri,
|
||||
"codeChallenge": code_challenge,
|
||||
"codeChallengeMethod": code_challenge_method,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return OAuthAuthorizationCodeInfo.from_db(saved_code)
|
||||
@@ -490,14 +498,17 @@ async def create_access_token(
|
||||
expires_at = now + ACCESS_TOKEN_TTL
|
||||
|
||||
saved_token = await PrismaOAuthAccessToken.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": token_hash, # SHA256 hash for direct lookup
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
}
|
||||
data=cast(
|
||||
OAuthAccessTokenCreateInput,
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": token_hash, # SHA256 hash for direct lookup
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return OAuthAccessToken.from_db(saved_token, plaintext_token=plaintext_token)
|
||||
@@ -607,14 +618,17 @@ async def create_refresh_token(
|
||||
expires_at = now + REFRESH_TOKEN_TTL
|
||||
|
||||
saved_token = await PrismaOAuthRefreshToken.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": token_hash, # SHA256 hash for direct lookup
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
}
|
||||
data=cast(
|
||||
OAuthRefreshTokenCreateInput,
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": token_hash, # SHA256 hash for direct lookup
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return OAuthRefreshToken.from_db(saved_token, plaintext_token=plaintext_token)
|
||||
|
||||
@@ -59,13 +59,12 @@ from backend.integrations.credentials_store import (
|
||||
|
||||
MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.O3: 4,
|
||||
LlmModel.O3_MINI: 2,
|
||||
LlmModel.O1: 16,
|
||||
LlmModel.O3_MINI: 2, # $1.10 / $4.40
|
||||
LlmModel.O1: 16, # $15 / $60
|
||||
LlmModel.O1_MINI: 4,
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5_2: 6,
|
||||
LlmModel.GPT5_1: 5,
|
||||
LlmModel.GPT5: 2,
|
||||
LlmModel.GPT5_1: 5,
|
||||
LlmModel.GPT5_MINI: 1,
|
||||
LlmModel.GPT5_NANO: 1,
|
||||
LlmModel.GPT5_CHAT: 5,
|
||||
@@ -88,7 +87,7 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.AIML_API_LLAMA3_3_70B: 1,
|
||||
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
|
||||
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
|
||||
LlmModel.LLAMA3_3_70B: 1,
|
||||
LlmModel.LLAMA3_3_70B: 1, # $0.59 / $0.79
|
||||
LlmModel.LLAMA3_1_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_3: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_2: 1,
|
||||
|
||||
@@ -16,7 +16,6 @@ from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBala
|
||||
from prisma.types import CreditRefundRequestCreateInput, CreditTransactionWhereInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.admin.model import UserHistoryResponse
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.data.includes import MAX_CREDIT_REFUND_REQUESTS_FETCH
|
||||
@@ -30,6 +29,7 @@ from backend.data.model import (
|
||||
from backend.data.notifications import NotificationEventModel, RefundRequestData
|
||||
from backend.data.user import get_user_by_id, get_user_email_by_id
|
||||
from backend.notifications.notifications import queue_notification_async
|
||||
from backend.server.v2.admin.model import UserHistoryResponse
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
from backend.util.json import SafeJson, dumps
|
||||
@@ -341,19 +341,6 @@ class UserCreditBase(ABC):
|
||||
|
||||
if result:
|
||||
# UserBalance is already updated by the CTE
|
||||
|
||||
# Clear insufficient funds notification flags when credits are added
|
||||
# so user can receive alerts again if they run out in the future.
|
||||
if transaction.amount > 0 and transaction.type in [
|
||||
CreditTransactionType.GRANT,
|
||||
CreditTransactionType.TOP_UP,
|
||||
]:
|
||||
from backend.executor.manager import (
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
|
||||
await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
return result[0]["balance"]
|
||||
|
||||
async def _add_transaction(
|
||||
@@ -543,22 +530,6 @@ class UserCreditBase(ABC):
|
||||
if result:
|
||||
new_balance, tx_key = result[0]["balance"], result[0]["transactionKey"]
|
||||
# UserBalance is already updated by the CTE
|
||||
|
||||
# Clear insufficient funds notification flags when credits are added
|
||||
# so user can receive alerts again if they run out in the future.
|
||||
if (
|
||||
amount > 0
|
||||
and is_active
|
||||
and transaction_type
|
||||
in [CreditTransactionType.GRANT, CreditTransactionType.TOP_UP]
|
||||
):
|
||||
# Lazy import to avoid circular dependency with executor.manager
|
||||
from backend.executor.manager import (
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
|
||||
await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
return new_balance, tx_key
|
||||
|
||||
# If no result, either user doesn't exist or insufficient balance
|
||||
|
||||
@@ -5,12 +5,14 @@ This test was added to cover a previously untested code path that could lead to
|
||||
incorrect balance capping behavior.
|
||||
"""
|
||||
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserBalanceUpsertInput, UserCreateInput
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
@@ -21,11 +23,14 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for ceiling tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
},
|
||||
)
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -33,7 +38,10 @@ async def create_test_user(user_id: str) -> None:
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ without race conditions, deadlocks, or inconsistent state.
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import prisma.enums
|
||||
@@ -14,6 +15,7 @@ import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserBalanceUpsertInput, UserCreateInput
|
||||
|
||||
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
@@ -28,11 +30,14 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user with initial balance."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
},
|
||||
)
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -41,7 +46,10 @@ async def create_test_user(user_id: str) -> None:
|
||||
# Ensure UserBalance record exists
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -342,10 +350,13 @@ async def test_integer_overflow_protection(server: SpinTestServer):
|
||||
# First, set balance near max
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": max_int - 100},
|
||||
"update": {"balance": max_int - 100},
|
||||
},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": max_int - 100},
|
||||
"update": {"balance": max_int - 100},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Try to add more than possible - should clamp to POSTGRES_INT_MAX
|
||||
|
||||
@@ -5,9 +5,12 @@ These tests run actual database operations to ensure SQL queries work correctly,
|
||||
which would have caught the CreditTransactionType enum casting bug.
|
||||
"""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserCreateInput
|
||||
|
||||
from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
@@ -29,12 +32,15 @@ async def cleanup_test_user():
|
||||
# Create the user first
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"topUpConfig": SafeJson({}),
|
||||
"timezone": "UTC",
|
||||
}
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"topUpConfig": SafeJson({}),
|
||||
"timezone": "UTC",
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# User might already exist, that's fine
|
||||
|
||||
@@ -6,12 +6,19 @@ are atomic and maintain data consistency.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import stripe
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
|
||||
from prisma.types import (
|
||||
CreditRefundRequestCreateInput,
|
||||
CreditTransactionCreateInput,
|
||||
UserBalanceCreateInput,
|
||||
UserCreateInput,
|
||||
)
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
@@ -35,32 +42,41 @@ async def setup_test_user_with_topup():
|
||||
|
||||
# Create user
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": REFUND_TEST_USER_ID,
|
||||
"email": f"{REFUND_TEST_USER_ID}@example.com",
|
||||
"name": "Refund Test User",
|
||||
}
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": REFUND_TEST_USER_ID,
|
||||
"email": f"{REFUND_TEST_USER_ID}@example.com",
|
||||
"name": "Refund Test User",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Create user balance
|
||||
await UserBalance.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"balance": 1000, # $10
|
||||
}
|
||||
data=cast(
|
||||
UserBalanceCreateInput,
|
||||
{
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"balance": 1000, # $10
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Create a top-up transaction that can be refunded
|
||||
topup_tx = await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 1000,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"transactionKey": "pi_test_12345",
|
||||
"runningBalance": 1000,
|
||||
"isActive": True,
|
||||
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
|
||||
}
|
||||
data=cast(
|
||||
CreditTransactionCreateInput,
|
||||
{
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 1000,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"transactionKey": "pi_test_12345",
|
||||
"runningBalance": 1000,
|
||||
"isActive": True,
|
||||
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return topup_tx
|
||||
@@ -93,12 +109,15 @@ async def test_deduct_credits_atomic(server: SpinTestServer):
|
||||
|
||||
# Create refund request record (simulating webhook flow)
|
||||
await CreditRefundRequest.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 500,
|
||||
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
|
||||
"reason": "Test refund",
|
||||
}
|
||||
data=cast(
|
||||
CreditRefundRequestCreateInput,
|
||||
{
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 500,
|
||||
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
|
||||
"reason": "Test refund",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Call deduct_credits
|
||||
@@ -286,12 +305,15 @@ async def test_concurrent_refunds(server: SpinTestServer):
|
||||
refund_requests = []
|
||||
for i in range(5):
|
||||
req = await CreditRefundRequest.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 100, # $1 each
|
||||
"transactionKey": topup_tx.transactionKey,
|
||||
"reason": f"Test refund {i}",
|
||||
}
|
||||
data=cast(
|
||||
CreditRefundRequestCreateInput,
|
||||
{
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 100, # $1 each
|
||||
"transactionKey": topup_tx.transactionKey,
|
||||
"reason": f"Test refund {i}",
|
||||
},
|
||||
)
|
||||
)
|
||||
refund_requests.append(req)
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditTransaction, UserBalance
|
||||
from prisma.types import CreditTransactionCreateInput, UserBalanceUpsertInput
|
||||
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.data.block import get_block
|
||||
@@ -23,10 +25,13 @@ async def disable_test_user_transactions():
|
||||
old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": DEFAULT_USER_ID},
|
||||
data={
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
|
||||
"update": {"balance": 0, "updatedAt": old_date},
|
||||
},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
|
||||
"update": {"balance": 0, "updatedAt": old_date},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -140,23 +145,29 @@ async def test_block_credit_reset(server: SpinTestServer):
|
||||
|
||||
# Manually create a transaction with month 1 timestamp to establish history
|
||||
await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": 100,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"runningBalance": 1100,
|
||||
"isActive": True,
|
||||
"createdAt": month1, # Set specific timestamp
|
||||
}
|
||||
data=cast(
|
||||
CreditTransactionCreateInput,
|
||||
{
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": 100,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"runningBalance": 1100,
|
||||
"isActive": True,
|
||||
"createdAt": month1, # Set specific timestamp
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Update user balance to match
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": DEFAULT_USER_ID},
|
||||
data={
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
|
||||
"update": {"balance": 1100},
|
||||
},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
|
||||
"update": {"balance": 1100},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Now test month 2 behavior
|
||||
@@ -175,14 +186,17 @@ async def test_block_credit_reset(server: SpinTestServer):
|
||||
|
||||
# Create a month 2 transaction to update the last transaction time
|
||||
await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": -700, # Spent 700 to get to 400
|
||||
"type": CreditTransactionType.USAGE,
|
||||
"runningBalance": 400,
|
||||
"isActive": True,
|
||||
"createdAt": month2,
|
||||
}
|
||||
data=cast(
|
||||
CreditTransactionCreateInput,
|
||||
{
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": -700, # Spent 700 to get to 400
|
||||
"type": CreditTransactionType.USAGE,
|
||||
"runningBalance": 400,
|
||||
"isActive": True,
|
||||
"createdAt": month2,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Move to month 3
|
||||
|
||||
@@ -6,12 +6,14 @@ doesn't underflow below POSTGRES_INT_MIN, which could cause integer wraparound i
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserBalanceUpsertInput, UserCreateInput
|
||||
|
||||
from backend.data.credit import POSTGRES_INT_MIN, UserCredit
|
||||
from backend.util.test import SpinTestServer
|
||||
@@ -21,11 +23,14 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for underflow tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
},
|
||||
)
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -33,7 +38,10 @@ async def create_test_user(user_id: str) -> None:
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -70,10 +78,13 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance_target},
|
||||
"update": {"balance": initial_balance_target},
|
||||
},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": initial_balance_target},
|
||||
"update": {"balance": initial_balance_target},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
current_balance = await credit_system.get_credits(user_id)
|
||||
@@ -110,10 +121,13 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
|
||||
# Set balance to exactly POSTGRES_INT_MIN
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
|
||||
"update": {"balance": POSTGRES_INT_MIN},
|
||||
},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
|
||||
"update": {"balance": POSTGRES_INT_MIN},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
edge_balance = await credit_system.get_credits(user_id)
|
||||
@@ -152,10 +166,13 @@ async def test_underflow_protection_large_refunds(server: SpinTestServer):
|
||||
test_balance = POSTGRES_INT_MIN + 1000
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": test_balance},
|
||||
"update": {"balance": test_balance},
|
||||
},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": test_balance},
|
||||
"update": {"balance": test_balance},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
current_balance = await credit_system.get_credits(user_id)
|
||||
@@ -217,10 +234,13 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
|
||||
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Apply multiple refunds that would cumulatively underflow
|
||||
@@ -295,10 +315,13 @@ async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
|
||||
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
async def large_refund(amount: int, label: str):
|
||||
|
||||
@@ -9,11 +9,13 @@ This test ensures that:
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserBalanceCreateInput, UserCreateInput
|
||||
|
||||
from backend.data.credit import UsageTransactionMetadata, UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
@@ -24,11 +26,14 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for migration tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
},
|
||||
)
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -121,7 +126,9 @@ async def test_detect_stale_user_balance_queries(server: SpinTestServer):
|
||||
try:
|
||||
# Create UserBalance with specific value
|
||||
await UserBalance.prisma().create(
|
||||
data={"userId": user_id, "balance": 5000} # $50
|
||||
data=cast(
|
||||
UserBalanceCreateInput, {"userId": user_id, "balance": 5000}
|
||||
) # $50
|
||||
)
|
||||
|
||||
# Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
|
||||
@@ -160,7 +167,9 @@ async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer
|
||||
|
||||
try:
|
||||
# Set initial balance in UserBalance
|
||||
await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})
|
||||
await UserBalance.prisma().create(
|
||||
data=cast(UserBalanceCreateInput, {"userId": user_id, "balance": 1000})
|
||||
)
|
||||
|
||||
# Run concurrent operations to ensure they all use UserBalance atomic operations
|
||||
async def concurrent_spend(amount: int, label: str):
|
||||
|
||||
@@ -111,7 +111,7 @@ def get_database_schema() -> str:
|
||||
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
|
||||
"""Execute raw SQL query with proper schema handling."""
|
||||
schema = get_database_schema()
|
||||
schema_prefix = f'"{schema}".' if schema != "public" else ""
|
||||
schema_prefix = f"{schema}." if schema != "public" else ""
|
||||
formatted_query = query_template.format(schema_prefix=schema_prefix)
|
||||
|
||||
import prisma as prisma_module
|
||||
|
||||
@@ -28,6 +28,7 @@ from prisma.models import (
|
||||
AgentNodeExecutionKeyValueData,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionCreateInput,
|
||||
AgentGraphExecutionUpdateManyMutationInput,
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionCreateInput,
|
||||
@@ -35,7 +36,6 @@ from prisma.types import (
|
||||
AgentNodeExecutionKeyValueDataCreateInput,
|
||||
AgentNodeExecutionUpdateInput,
|
||||
AgentNodeExecutionWhereInput,
|
||||
AgentNodeExecutionWhereUniqueInput,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
|
||||
from pydantic.fields import Field
|
||||
@@ -709,37 +709,40 @@ async def create_graph_execution(
|
||||
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
|
||||
"""
|
||||
result = await AgentGraphExecution.prisma().create(
|
||||
data={
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": graph_version,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"inputs": SafeJson(inputs),
|
||||
"credentialInputs": (
|
||||
SafeJson(credential_inputs) if credential_inputs else Json({})
|
||||
),
|
||||
"nodesInputMasks": (
|
||||
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
|
||||
),
|
||||
"NodeExecutions": {
|
||||
"create": [
|
||||
AgentNodeExecutionCreateInput(
|
||||
agentNodeId=node_id,
|
||||
executionStatus=ExecutionStatus.QUEUED,
|
||||
queuedTime=datetime.now(tz=timezone.utc),
|
||||
Input={
|
||||
"create": [
|
||||
{"name": name, "data": SafeJson(data)}
|
||||
for name, data in node_input.items()
|
||||
]
|
||||
},
|
||||
)
|
||||
for node_id, node_input in starting_nodes_input
|
||||
]
|
||||
data=cast(
|
||||
AgentGraphExecutionCreateInput,
|
||||
{
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": graph_version,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"inputs": SafeJson(inputs),
|
||||
"credentialInputs": (
|
||||
SafeJson(credential_inputs) if credential_inputs else Json({})
|
||||
),
|
||||
"nodesInputMasks": (
|
||||
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
|
||||
),
|
||||
"NodeExecutions": {
|
||||
"create": [
|
||||
AgentNodeExecutionCreateInput(
|
||||
agentNodeId=node_id,
|
||||
executionStatus=ExecutionStatus.QUEUED,
|
||||
queuedTime=datetime.now(tz=timezone.utc),
|
||||
Input={
|
||||
"create": [
|
||||
{"name": name, "data": SafeJson(data)}
|
||||
for name, data in node_input.items()
|
||||
]
|
||||
},
|
||||
)
|
||||
for node_id, node_input in starting_nodes_input
|
||||
]
|
||||
},
|
||||
"userId": user_id,
|
||||
"agentPresetId": preset_id,
|
||||
"parentGraphExecutionId": parent_graph_exec_id,
|
||||
},
|
||||
"userId": user_id,
|
||||
"agentPresetId": preset_id,
|
||||
"parentGraphExecutionId": parent_graph_exec_id,
|
||||
},
|
||||
),
|
||||
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
)
|
||||
|
||||
@@ -831,10 +834,13 @@ async def upsert_execution_output(
|
||||
"""
|
||||
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
|
||||
"""
|
||||
data: AgentNodeExecutionInputOutputCreateInput = {
|
||||
"name": output_name,
|
||||
"referencedByOutputExecId": node_exec_id,
|
||||
}
|
||||
data: AgentNodeExecutionInputOutputCreateInput = cast(
|
||||
AgentNodeExecutionInputOutputCreateInput,
|
||||
{
|
||||
"name": output_name,
|
||||
"referencedByOutputExecId": node_exec_id,
|
||||
},
|
||||
)
|
||||
if output_data is not None:
|
||||
data["data"] = SafeJson(output_data)
|
||||
await AgentNodeExecutionInputOutput.prisma().create(data=data)
|
||||
@@ -974,25 +980,30 @@ async def update_node_execution_status(
|
||||
f"Invalid status transition: {status} has no valid source statuses"
|
||||
)
|
||||
|
||||
if res := await AgentNodeExecution.prisma().update(
|
||||
where=cast(
|
||||
AgentNodeExecutionWhereUniqueInput,
|
||||
{
|
||||
"id": node_exec_id,
|
||||
"executionStatus": {"in": [s.value for s in allowed_from]},
|
||||
},
|
||||
),
|
||||
# First verify the current status allows this transition
|
||||
current_exec = await AgentNodeExecution.prisma().find_unique(
|
||||
where={"id": node_exec_id}, include=EXECUTION_RESULT_INCLUDE
|
||||
)
|
||||
|
||||
if not current_exec:
|
||||
raise ValueError(f"Execution {node_exec_id} not found.")
|
||||
|
||||
# Check if current status allows the requested transition
|
||||
if current_exec.executionStatus not in allowed_from:
|
||||
# Status transition not allowed, return current state without updating
|
||||
return NodeExecutionResult.from_db(current_exec)
|
||||
|
||||
# Status transition is valid, perform the update
|
||||
updated_exec = await AgentNodeExecution.prisma().update(
|
||||
where={"id": node_exec_id},
|
||||
data=_get_update_status_data(status, execution_data, stats),
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
):
|
||||
return NodeExecutionResult.from_db(res)
|
||||
)
|
||||
|
||||
if res := await AgentNodeExecution.prisma().find_unique(
|
||||
where={"id": node_exec_id}, include=EXECUTION_RESULT_INCLUDE
|
||||
):
|
||||
return NodeExecutionResult.from_db(res)
|
||||
if not updated_exec:
|
||||
raise ValueError(f"Failed to update execution {node_exec_id}.")
|
||||
|
||||
raise ValueError(f"Execution {node_exec_id} not found.")
|
||||
return NodeExecutionResult.from_db(updated_exec)
|
||||
|
||||
|
||||
def _get_update_status_data(
|
||||
|
||||
@@ -6,14 +6,14 @@ import fastapi.exceptions
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.api.features.store.model as store
|
||||
from backend.api.model import CreateGraph
|
||||
import backend.server.v2.store.model as store
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.data.block import BlockSchema, BlockSchemaInput
|
||||
from backend.data.graph import Graph, Link, Node
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.usecases.sample import create_test_user
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
@@ -6,14 +6,14 @@ Handles all database operations for pending human reviews.
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
from prisma.enums import ReviewStatus
|
||||
from prisma.models import PendingHumanReview
|
||||
from prisma.types import PendingHumanReviewUpdateInput
|
||||
from prisma.types import PendingHumanReviewUpdateInput, PendingHumanReviewUpsertInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.executions.review.model import (
|
||||
from backend.server.v2.executions.review.model import (
|
||||
PendingHumanReviewModel,
|
||||
SafeJsonData,
|
||||
)
|
||||
@@ -66,20 +66,23 @@ async def get_or_create_human_review(
|
||||
# Upsert - get existing or create new review
|
||||
review = await PendingHumanReview.prisma().upsert(
|
||||
where={"nodeExecId": node_exec_id},
|
||||
data={
|
||||
"create": {
|
||||
"userId": user_id,
|
||||
"nodeExecId": node_exec_id,
|
||||
"graphExecId": graph_exec_id,
|
||||
"graphId": graph_id,
|
||||
"graphVersion": graph_version,
|
||||
"payload": SafeJson(input_data),
|
||||
"instructions": message,
|
||||
"editable": editable,
|
||||
"status": ReviewStatus.WAITING,
|
||||
data=cast(
|
||||
PendingHumanReviewUpsertInput,
|
||||
{
|
||||
"create": {
|
||||
"userId": user_id,
|
||||
"nodeExecId": node_exec_id,
|
||||
"graphExecId": graph_exec_id,
|
||||
"graphId": graph_id,
|
||||
"graphVersion": graph_version,
|
||||
"payload": SafeJson(input_data),
|
||||
"instructions": message,
|
||||
"editable": editable,
|
||||
"status": ReviewStatus.WAITING,
|
||||
},
|
||||
"update": {}, # Do nothing on update - keep existing review as is
|
||||
},
|
||||
"update": {}, # Do nothing on update - keep existing review as is
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -23,7 +23,7 @@ from backend.util.exceptions import NotFoundError
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.library.model import LibraryAgentPreset
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
from .db import BaseDbModel
|
||||
from .graph import NodeModel
|
||||
@@ -79,7 +79,7 @@ class WebhookWithRelations(Webhook):
|
||||
# integrations.py → library/model.py → integrations.py (for Webhook)
|
||||
# Runtime import is used in WebhookWithRelations.from_db() method instead
|
||||
# Import at runtime to avoid circular dependency
|
||||
from backend.api.features.library.model import LibraryAgentPreset
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
return WebhookWithRelations(
|
||||
**Webhook.from_db(webhook).model_dump(),
|
||||
@@ -285,8 +285,8 @@ async def unlink_webhook_from_graph(
|
||||
user_id: The ID of the user (for authorization)
|
||||
"""
|
||||
# Avoid circular imports
|
||||
from backend.api.features.library.db import set_preset_webhook
|
||||
from backend.data.graph import set_node_webhook
|
||||
from backend.server.v2.library.db import set_preset_webhook
|
||||
|
||||
# Find all nodes in this graph that use this webhook
|
||||
nodes = await AgentNode.prisma().find_many(
|
||||
|
||||
@@ -4,8 +4,8 @@ from typing import AsyncGenerator
|
||||
|
||||
from pydantic import BaseModel, field_serializer
|
||||
|
||||
from backend.api.model import NotificationPayload
|
||||
from backend.data.event_bus import AsyncRedisEventBus
|
||||
from backend.server.model import NotificationPayload
|
||||
from backend.util.settings import Settings
|
||||
|
||||
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal, Optional, cast
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import prisma
|
||||
import pydantic
|
||||
from prisma.enums import OnboardingStep
|
||||
from prisma.models import UserOnboarding
|
||||
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
|
||||
from prisma.types import (
|
||||
UserOnboardingCreateInput,
|
||||
UserOnboardingUpdateInput,
|
||||
UserOnboardingUpsertInput,
|
||||
)
|
||||
|
||||
from backend.api.features.store.model import StoreAgentDetails
|
||||
from backend.api.model import OnboardingNotificationPayload
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.notification_bus import (
|
||||
@@ -18,6 +20,8 @@ from backend.data.notification_bus import (
|
||||
NotificationEvent,
|
||||
)
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.server.model import OnboardingNotificationPayload
|
||||
from backend.server.v2.store.model import StoreAgentDetails
|
||||
from backend.util.cache import cached
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.timezone_utils import get_user_timezone_or_utc
|
||||
@@ -112,10 +116,13 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
||||
|
||||
return await UserOnboarding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, **update},
|
||||
"update": update,
|
||||
},
|
||||
data=cast(
|
||||
UserOnboardingUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, **update},
|
||||
"update": update,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -442,8 +449,6 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
versions=agent.versions,
|
||||
agentGraphVersions=agent.agentGraphVersions,
|
||||
agentGraphId=agent.agentGraphId,
|
||||
last_updated=agent.updated_at,
|
||||
)
|
||||
for agent in recommended_agents
|
||||
|
||||
@@ -1,429 +0,0 @@
|
||||
"""Data models and access layer for user business understanding."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import pydantic
|
||||
from prisma.models import UserBusinessUnderstanding
|
||||
from prisma.types import (
|
||||
UserBusinessUnderstandingCreateInput,
|
||||
UserBusinessUnderstandingUpdateInput,
|
||||
)
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache configuration
|
||||
CACHE_KEY_PREFIX = "understanding"
|
||||
CACHE_TTL_SECONDS = 48 * 60 * 60 # 48 hours
|
||||
|
||||
|
||||
def _cache_key(user_id: str) -> str:
|
||||
"""Generate cache key for user business understanding."""
|
||||
return f"{CACHE_KEY_PREFIX}:{user_id}"
|
||||
|
||||
|
||||
def _json_to_list(value: Any) -> list[str]:
|
||||
"""Convert Json field to list[str], handling None."""
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return cast(list[str], value)
|
||||
return []
|
||||
|
||||
|
||||
class BusinessUnderstandingInput(pydantic.BaseModel):
|
||||
"""Input model for updating business understanding - all fields optional for incremental updates."""
|
||||
|
||||
# User info
|
||||
user_name: Optional[str] = pydantic.Field(None, description="The user's name")
|
||||
job_title: Optional[str] = pydantic.Field(None, description="The user's job title")
|
||||
|
||||
# Business basics
|
||||
business_name: Optional[str] = pydantic.Field(
|
||||
None, description="Name of the user's business"
|
||||
)
|
||||
industry: Optional[str] = pydantic.Field(None, description="Industry or sector")
|
||||
business_size: Optional[str] = pydantic.Field(
|
||||
None, description="Company size (e.g., '1-10', '11-50')"
|
||||
)
|
||||
user_role: Optional[str] = pydantic.Field(
|
||||
None,
|
||||
description="User's role in the organization (e.g., 'decision maker', 'implementer')",
|
||||
)
|
||||
|
||||
# Processes & activities
|
||||
key_workflows: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Key business workflows"
|
||||
)
|
||||
daily_activities: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Daily activities performed"
|
||||
)
|
||||
|
||||
# Pain points & goals
|
||||
pain_points: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Current pain points"
|
||||
)
|
||||
bottlenecks: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Process bottlenecks"
|
||||
)
|
||||
manual_tasks: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Manual/repetitive tasks"
|
||||
)
|
||||
automation_goals: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Desired automation goals"
|
||||
)
|
||||
|
||||
# Current tools
|
||||
current_software: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Software/tools currently used"
|
||||
)
|
||||
existing_automation: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Existing automations"
|
||||
)
|
||||
|
||||
# Additional context
|
||||
additional_notes: Optional[str] = pydantic.Field(
|
||||
None, description="Any additional context"
|
||||
)
|
||||
|
||||
|
||||
class BusinessUnderstanding(pydantic.BaseModel):
|
||||
"""Full business understanding model returned from database."""
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# User info
|
||||
user_name: Optional[str] = None
|
||||
job_title: Optional[str] = None
|
||||
|
||||
# Business basics
|
||||
business_name: Optional[str] = None
|
||||
industry: Optional[str] = None
|
||||
business_size: Optional[str] = None
|
||||
user_role: Optional[str] = None
|
||||
|
||||
# Processes & activities
|
||||
key_workflows: list[str] = pydantic.Field(default_factory=list)
|
||||
daily_activities: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
# Pain points & goals
|
||||
pain_points: list[str] = pydantic.Field(default_factory=list)
|
||||
bottlenecks: list[str] = pydantic.Field(default_factory=list)
|
||||
manual_tasks: list[str] = pydantic.Field(default_factory=list)
|
||||
automation_goals: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
# Current tools
|
||||
current_software: list[str] = pydantic.Field(default_factory=list)
|
||||
existing_automation: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
# Additional context
|
||||
additional_notes: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_record: UserBusinessUnderstanding) -> "BusinessUnderstanding":
|
||||
"""Convert database record to Pydantic model."""
|
||||
return cls(
|
||||
id=db_record.id,
|
||||
user_id=db_record.userId,
|
||||
created_at=db_record.createdAt,
|
||||
updated_at=db_record.updatedAt,
|
||||
user_name=db_record.userName,
|
||||
job_title=db_record.jobTitle,
|
||||
business_name=db_record.businessName,
|
||||
industry=db_record.industry,
|
||||
business_size=db_record.businessSize,
|
||||
user_role=db_record.userRole,
|
||||
key_workflows=_json_to_list(db_record.keyWorkflows),
|
||||
daily_activities=_json_to_list(db_record.dailyActivities),
|
||||
pain_points=_json_to_list(db_record.painPoints),
|
||||
bottlenecks=_json_to_list(db_record.bottlenecks),
|
||||
manual_tasks=_json_to_list(db_record.manualTasks),
|
||||
automation_goals=_json_to_list(db_record.automationGoals),
|
||||
current_software=_json_to_list(db_record.currentSoftware),
|
||||
existing_automation=_json_to_list(db_record.existingAutomation),
|
||||
additional_notes=db_record.additionalNotes,
|
||||
)
|
||||
|
||||
|
||||
def _merge_lists(existing: list | None, new: list | None) -> list | None:
|
||||
"""Merge two lists, removing duplicates while preserving order."""
|
||||
if new is None:
|
||||
return existing
|
||||
if existing is None:
|
||||
return new
|
||||
# Preserve order, add new items that don't exist
|
||||
merged = list(existing)
|
||||
for item in new:
|
||||
if item not in merged:
|
||||
merged.append(item)
|
||||
return merged
|
||||
|
||||
|
||||
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
|
||||
"""Get business understanding from Redis cache."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
cached_data = await redis.get(_cache_key(user_id))
|
||||
if cached_data:
|
||||
return BusinessUnderstanding.model_validate_json(cached_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get understanding from cache: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _set_cache(user_id: str, understanding: BusinessUnderstanding) -> None:
|
||||
"""Set business understanding in Redis cache with TTL."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.setex(
|
||||
_cache_key(user_id),
|
||||
CACHE_TTL_SECONDS,
|
||||
understanding.model_dump_json(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to set understanding in cache: {e}")
|
||||
|
||||
|
||||
async def _delete_cache(user_id: str) -> None:
|
||||
"""Delete business understanding from Redis cache."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(_cache_key(user_id))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete understanding from cache: {e}")
|
||||
|
||||
|
||||
async def get_business_understanding(
|
||||
user_id: str,
|
||||
) -> Optional[BusinessUnderstanding]:
|
||||
"""Get the business understanding for a user.
|
||||
|
||||
Checks cache first, falls back to database if not cached.
|
||||
Results are cached for 48 hours.
|
||||
"""
|
||||
# Try cache first
|
||||
cached = await _get_from_cache(user_id)
|
||||
if cached:
|
||||
logger.debug(f"Business understanding cache hit for user {user_id}")
|
||||
return cached
|
||||
|
||||
# Cache miss - load from database
|
||||
logger.debug(f"Business understanding cache miss for user {user_id}")
|
||||
record = await UserBusinessUnderstanding.prisma().find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
if record is None:
|
||||
return None
|
||||
|
||||
understanding = BusinessUnderstanding.from_db(record)
|
||||
|
||||
# Store in cache for next time
|
||||
await _set_cache(user_id, understanding)
|
||||
|
||||
return understanding
|
||||
|
||||
|
||||
async def upsert_business_understanding(
|
||||
user_id: str,
|
||||
data: BusinessUnderstandingInput,
|
||||
) -> BusinessUnderstanding:
|
||||
"""
|
||||
Create or update business understanding with incremental merge strategy.
|
||||
|
||||
- String fields: new value overwrites if provided (not None)
|
||||
- List fields: new items are appended to existing (deduplicated)
|
||||
"""
|
||||
# Get existing record for merge
|
||||
existing = await UserBusinessUnderstanding.prisma().find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
|
||||
# Build update data with merge strategy
|
||||
update_data: UserBusinessUnderstandingUpdateInput = {}
|
||||
create_data: dict[str, Any] = {"userId": user_id}
|
||||
|
||||
# String fields - overwrite if provided
|
||||
if data.user_name is not None:
|
||||
update_data["userName"] = data.user_name
|
||||
create_data["userName"] = data.user_name
|
||||
if data.job_title is not None:
|
||||
update_data["jobTitle"] = data.job_title
|
||||
create_data["jobTitle"] = data.job_title
|
||||
if data.business_name is not None:
|
||||
update_data["businessName"] = data.business_name
|
||||
create_data["businessName"] = data.business_name
|
||||
if data.industry is not None:
|
||||
update_data["industry"] = data.industry
|
||||
create_data["industry"] = data.industry
|
||||
if data.business_size is not None:
|
||||
update_data["businessSize"] = data.business_size
|
||||
create_data["businessSize"] = data.business_size
|
||||
if data.user_role is not None:
|
||||
update_data["userRole"] = data.user_role
|
||||
create_data["userRole"] = data.user_role
|
||||
if data.additional_notes is not None:
|
||||
update_data["additionalNotes"] = data.additional_notes
|
||||
create_data["additionalNotes"] = data.additional_notes
|
||||
|
||||
# List fields - merge with existing
|
||||
if data.key_workflows is not None:
|
||||
existing_list = _json_to_list(existing.keyWorkflows) if existing else None
|
||||
merged = _merge_lists(existing_list, data.key_workflows)
|
||||
update_data["keyWorkflows"] = SafeJson(merged)
|
||||
create_data["keyWorkflows"] = SafeJson(merged)
|
||||
|
||||
if data.daily_activities is not None:
|
||||
existing_list = _json_to_list(existing.dailyActivities) if existing else None
|
||||
merged = _merge_lists(existing_list, data.daily_activities)
|
||||
update_data["dailyActivities"] = SafeJson(merged)
|
||||
create_data["dailyActivities"] = SafeJson(merged)
|
||||
|
||||
if data.pain_points is not None:
|
||||
existing_list = _json_to_list(existing.painPoints) if existing else None
|
||||
merged = _merge_lists(existing_list, data.pain_points)
|
||||
update_data["painPoints"] = SafeJson(merged)
|
||||
create_data["painPoints"] = SafeJson(merged)
|
||||
|
||||
if data.bottlenecks is not None:
|
||||
existing_list = _json_to_list(existing.bottlenecks) if existing else None
|
||||
merged = _merge_lists(existing_list, data.bottlenecks)
|
||||
update_data["bottlenecks"] = SafeJson(merged)
|
||||
create_data["bottlenecks"] = SafeJson(merged)
|
||||
|
||||
if data.manual_tasks is not None:
|
||||
existing_list = _json_to_list(existing.manualTasks) if existing else None
|
||||
merged = _merge_lists(existing_list, data.manual_tasks)
|
||||
update_data["manualTasks"] = SafeJson(merged)
|
||||
create_data["manualTasks"] = SafeJson(merged)
|
||||
|
||||
if data.automation_goals is not None:
|
||||
existing_list = _json_to_list(existing.automationGoals) if existing else None
|
||||
merged = _merge_lists(existing_list, data.automation_goals)
|
||||
update_data["automationGoals"] = SafeJson(merged)
|
||||
create_data["automationGoals"] = SafeJson(merged)
|
||||
|
||||
if data.current_software is not None:
|
||||
existing_list = _json_to_list(existing.currentSoftware) if existing else None
|
||||
merged = _merge_lists(existing_list, data.current_software)
|
||||
update_data["currentSoftware"] = SafeJson(merged)
|
||||
create_data["currentSoftware"] = SafeJson(merged)
|
||||
|
||||
if data.existing_automation is not None:
|
||||
existing_list = _json_to_list(existing.existingAutomation) if existing else None
|
||||
merged = _merge_lists(existing_list, data.existing_automation)
|
||||
update_data["existingAutomation"] = SafeJson(merged)
|
||||
create_data["existingAutomation"] = SafeJson(merged)
|
||||
|
||||
# Upsert
|
||||
record = await UserBusinessUnderstanding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": UserBusinessUnderstandingCreateInput(**create_data),
|
||||
"update": update_data,
|
||||
},
|
||||
)
|
||||
|
||||
understanding = BusinessUnderstanding.from_db(record)
|
||||
|
||||
# Update cache with new understanding
|
||||
await _set_cache(user_id, understanding)
|
||||
|
||||
return understanding
|
||||
|
||||
|
||||
async def clear_business_understanding(user_id: str) -> bool:
|
||||
"""Clear/delete business understanding for a user from both DB and cache."""
|
||||
# Delete from cache first
|
||||
await _delete_cache(user_id)
|
||||
|
||||
try:
|
||||
await UserBusinessUnderstanding.prisma().delete(where={"userId": user_id})
|
||||
return True
|
||||
except Exception:
|
||||
# Record might not exist
|
||||
return False
|
||||
|
||||
|
||||
def format_understanding_for_prompt(understanding: BusinessUnderstanding) -> str:
|
||||
"""Format business understanding as text for system prompt injection."""
|
||||
sections = []
|
||||
|
||||
# User info section
|
||||
user_info = []
|
||||
if understanding.user_name:
|
||||
user_info.append(f"Name: {understanding.user_name}")
|
||||
if understanding.job_title:
|
||||
user_info.append(f"Job Title: {understanding.job_title}")
|
||||
if user_info:
|
||||
sections.append("## User\n" + "\n".join(user_info))
|
||||
|
||||
# Business section
|
||||
business_info = []
|
||||
if understanding.business_name:
|
||||
business_info.append(f"Company: {understanding.business_name}")
|
||||
if understanding.industry:
|
||||
business_info.append(f"Industry: {understanding.industry}")
|
||||
if understanding.business_size:
|
||||
business_info.append(f"Size: {understanding.business_size}")
|
||||
if understanding.user_role:
|
||||
business_info.append(f"Role Context: {understanding.user_role}")
|
||||
if business_info:
|
||||
sections.append("## Business\n" + "\n".join(business_info))
|
||||
|
||||
# Processes section
|
||||
processes = []
|
||||
if understanding.key_workflows:
|
||||
processes.append(f"Key Workflows: {', '.join(understanding.key_workflows)}")
|
||||
if understanding.daily_activities:
|
||||
processes.append(
|
||||
f"Daily Activities: {', '.join(understanding.daily_activities)}"
|
||||
)
|
||||
if processes:
|
||||
sections.append("## Processes\n" + "\n".join(processes))
|
||||
|
||||
# Pain points section
|
||||
pain_points = []
|
||||
if understanding.pain_points:
|
||||
pain_points.append(f"Pain Points: {', '.join(understanding.pain_points)}")
|
||||
if understanding.bottlenecks:
|
||||
pain_points.append(f"Bottlenecks: {', '.join(understanding.bottlenecks)}")
|
||||
if understanding.manual_tasks:
|
||||
pain_points.append(f"Manual Tasks: {', '.join(understanding.manual_tasks)}")
|
||||
if pain_points:
|
||||
sections.append("## Pain Points\n" + "\n".join(pain_points))
|
||||
|
||||
# Goals section
|
||||
if understanding.automation_goals:
|
||||
sections.append(
|
||||
"## Automation Goals\n"
|
||||
+ "\n".join(f"- {goal}" for goal in understanding.automation_goals)
|
||||
)
|
||||
|
||||
# Current tools section
|
||||
tools_info = []
|
||||
if understanding.current_software:
|
||||
tools_info.append(
|
||||
f"Current Software: {', '.join(understanding.current_software)}"
|
||||
)
|
||||
if understanding.existing_automation:
|
||||
tools_info.append(
|
||||
f"Existing Automation: {', '.join(understanding.existing_automation)}"
|
||||
)
|
||||
if tools_info:
|
||||
sections.append("## Current Tools\n" + "\n".join(tools_info))
|
||||
|
||||
# Additional notes
|
||||
if understanding.additional_notes:
|
||||
sections.append(f"## Additional Context\n{understanding.additional_notes}")
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
return "# User Business Context\n\n" + "\n\n".join(sections)
|
||||
@@ -2,11 +2,6 @@ import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cast
|
||||
|
||||
from backend.api.features.library.db import (
|
||||
add_store_agent_to_library,
|
||||
list_library_agents,
|
||||
)
|
||||
from backend.api.features.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.data import db
|
||||
from backend.data.analytics import (
|
||||
get_accuracy_trends_and_alerts,
|
||||
@@ -66,6 +61,8 @@ from backend.data.user import (
|
||||
get_user_notification_preference,
|
||||
update_user_integrations,
|
||||
)
|
||||
from backend.server.v2.library.db import add_store_agent_to_library, list_library_agents
|
||||
from backend.server.v2.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
|
||||
@@ -48,8 +48,27 @@ from backend.data.notifications import (
|
||||
ZeroBalanceData,
|
||||
)
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.activity_status_generator import (
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
from backend.executor.utils import (
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_EXCHANGE,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_ROUTING_KEY,
|
||||
CancelExecutionEvent,
|
||||
ExecutionOutputEntry,
|
||||
LogMetadata,
|
||||
NodeExecutionProgress,
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
execution_usage_cost,
|
||||
validate_exec,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.server.v2.AutoMod.manager import automod_manager
|
||||
from backend.util import json
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
@@ -76,24 +95,7 @@ from backend.util.retry import (
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .activity_status_generator import generate_activity_status_for_execution
|
||||
from .automod.manager import automod_manager
|
||||
from .cluster_lock import ClusterLock
|
||||
from .utils import (
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_EXCHANGE,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_ROUTING_KEY,
|
||||
CancelExecutionEvent,
|
||||
ExecutionOutputEntry,
|
||||
LogMetadata,
|
||||
NodeExecutionProgress,
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
execution_usage_cost,
|
||||
validate_exec,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
@@ -114,40 +116,6 @@ utilization_gauge = Gauge(
|
||||
"Ratio of active graph runs to max graph workers",
|
||||
)
|
||||
|
||||
# Redis key prefix for tracking insufficient funds Discord notifications.
|
||||
# We only send one notification per user per agent until they top up credits.
|
||||
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
|
||||
# TTL for the notification flag (30 days) - acts as a fallback cleanup
|
||||
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60
|
||||
|
||||
|
||||
async def clear_insufficient_funds_notifications(user_id: str) -> int:
|
||||
"""
|
||||
Clear all insufficient funds notification flags for a user.
|
||||
|
||||
This should be called when a user tops up their credits, allowing
|
||||
Discord notifications to be sent again if they run out of funds.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to clear notifications for.
|
||||
|
||||
Returns:
|
||||
The number of keys that were deleted.
|
||||
"""
|
||||
try:
|
||||
redis_client = await redis.get_redis_async()
|
||||
pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
keys = [key async for key in redis_client.scan_iter(match=pattern)]
|
||||
if keys:
|
||||
return await redis_client.delete(*keys)
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to clear insufficient funds notification flags for user "
|
||||
f"{user_id}: {e}"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
# Thread-local storage for ExecutionProcessor instances
|
||||
_tls = threading.local()
|
||||
@@ -1295,40 +1263,12 @@ class ExecutionProcessor:
|
||||
graph_id: str,
|
||||
e: InsufficientBalanceError,
|
||||
):
|
||||
# Check if we've already sent a notification for this user+agent combo.
|
||||
# We only send one notification per user per agent until they top up credits.
|
||||
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
|
||||
try:
|
||||
redis_client = redis.get_redis()
|
||||
# SET NX returns True only if the key was newly set (didn't exist)
|
||||
is_new_notification = redis_client.set(
|
||||
redis_key,
|
||||
"1",
|
||||
nx=True,
|
||||
ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
|
||||
)
|
||||
if not is_new_notification:
|
||||
# Already notified for this user+agent, skip all notifications
|
||||
logger.debug(
|
||||
f"Skipping duplicate insufficient funds notification for "
|
||||
f"user={user_id}, graph={graph_id}"
|
||||
)
|
||||
return
|
||||
except Exception as redis_error:
|
||||
# If Redis fails, log and continue to send the notification
|
||||
# (better to occasionally duplicate than to never notify)
|
||||
logger.warning(
|
||||
f"Failed to check/set insufficient funds notification flag in Redis: "
|
||||
f"{redis_error}"
|
||||
)
|
||||
|
||||
shortfall = abs(e.amount) - e.balance
|
||||
metadata = db_client.get_graph_metadata(graph_id)
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
|
||||
# Queue user email notification
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=user_id,
|
||||
@@ -1342,7 +1282,6 @@ class ExecutionProcessor:
|
||||
)
|
||||
)
|
||||
|
||||
# Send Discord system alert
|
||||
try:
|
||||
user_email = db_client.get_user_email_by_id(user_id)
|
||||
|
||||
|
||||
@@ -1,560 +0,0 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import NotificationType
|
||||
|
||||
from backend.data.notifications import ZeroBalanceData
|
||||
from backend.executor.manager import (
|
||||
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX,
|
||||
ExecutionProcessor,
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
async def async_iter(items):
|
||||
"""Helper to create an async iterator from a list."""
|
||||
for item in items:
|
||||
yield item
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_sends_discord_alert_first_time(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that the first insufficient funds notification sends a Discord alert."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72, # $0.72
|
||||
amount=-714, # Attempting to spend $7.14
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to simulate first-time notification (set returns True)
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.return_value = True # Key was newly set
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify notification was queued
|
||||
mock_queue_notif.assert_called_once()
|
||||
notification_call = mock_queue_notif.call_args[0][0]
|
||||
assert notification_call.type == NotificationType.ZERO_BALANCE
|
||||
assert notification_call.user_id == user_id
|
||||
assert isinstance(notification_call.data, ZeroBalanceData)
|
||||
assert notification_call.data.current_balance == 72
|
||||
|
||||
# Verify Redis was checked with correct key pattern
|
||||
expected_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
|
||||
mock_redis_client.set.assert_called_once()
|
||||
call_args = mock_redis_client.set.call_args
|
||||
assert call_args[0][0] == expected_key
|
||||
assert call_args[1]["nx"] is True
|
||||
|
||||
# Verify Discord alert was sent
|
||||
mock_client.discord_system_alert.assert_called_once()
|
||||
discord_message = mock_client.discord_system_alert.call_args[0][0]
|
||||
assert "Insufficient Funds Alert" in discord_message
|
||||
assert "test@example.com" in discord_message
|
||||
assert "Test Agent" in discord_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_skips_duplicate_notifications(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that duplicate insufficient funds notifications skip both email and Discord."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to simulate duplicate notification (set returns False/None)
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.return_value = None # Key already existed
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
mock_db_client.get_graph_metadata.return_value = MagicMock(name="Test Agent")
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify email notification was NOT queued (deduplication worked)
|
||||
mock_queue_notif.assert_not_called()
|
||||
|
||||
# Verify Discord alert was NOT sent (deduplication worked)
|
||||
mock_client.discord_system_alert.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that different agents for the same user get separate Discord alerts."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id_1 = "test-graph-111"
|
||||
graph_id_2 = "test-graph-222"
|
||||
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch("backend.executor.manager.queue_notification"), patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
# Both calls return True (first time for each agent)
|
||||
mock_redis_client.set.return_value = True
|
||||
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# First agent notification
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id_1,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Second agent notification
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id_2,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify Discord alerts were sent for both agents
|
||||
assert mock_client.discord_system_alert.call_count == 2
|
||||
|
||||
# Verify Redis was called with different keys
|
||||
assert mock_redis_client.set.call_count == 2
|
||||
calls = mock_redis_client.set.call_args_list
|
||||
assert (
|
||||
calls[0][0][0]
|
||||
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_1}"
|
||||
)
|
||||
assert (
|
||||
calls[1][0][0]
|
||||
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_2}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications(server: SpinTestServer):
|
||||
"""Test that clearing notifications removes all keys for a user."""
|
||||
|
||||
user_id = "test-user-123"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
# get_redis_async is an async function, so we need AsyncMock for it
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
# Mock scan_iter to return some keys as an async iterator
|
||||
mock_keys = [
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1",
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-2",
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-3",
|
||||
]
|
||||
mock_redis_client.scan_iter.return_value = async_iter(mock_keys)
|
||||
# delete is awaited, so use AsyncMock
|
||||
mock_redis_client.delete = AsyncMock(return_value=3)
|
||||
|
||||
# Clear notifications
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify correct pattern was used
|
||||
expected_pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
mock_redis_client.scan_iter.assert_called_once_with(match=expected_pattern)
|
||||
|
||||
# Verify delete was called with all keys
|
||||
mock_redis_client.delete.assert_called_once_with(*mock_keys)
|
||||
|
||||
# Verify return value
|
||||
assert result == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications_no_keys(server: SpinTestServer):
|
||||
"""Test clearing notifications when there are no keys to clear."""
|
||||
|
||||
user_id = "test-user-no-notifications"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
# get_redis_async is an async function, so we need AsyncMock for it
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
# Mock scan_iter to return no keys as an async iterator
|
||||
mock_redis_client.scan_iter.return_value = async_iter([])
|
||||
|
||||
# Clear notifications
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify delete was not called
|
||||
mock_redis_client.delete.assert_not_called()
|
||||
|
||||
# Verify return value
|
||||
assert result == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications_handles_redis_error(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that clearing notifications handles Redis errors gracefully."""
|
||||
|
||||
user_id = "test-user-redis-error"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
# Mock get_redis_async to raise an error
|
||||
mock_redis_module.get_redis_async = AsyncMock(
|
||||
side_effect=Exception("Redis connection failed")
|
||||
)
|
||||
|
||||
# Clear notifications should not raise, just return 0
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify it returned 0 (graceful failure)
|
||||
assert result == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_continues_on_redis_error(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that both email and Discord notifications are still sent when Redis fails."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to raise an error
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.side_effect = Exception("Redis connection error")
|
||||
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify email notification was still queued despite Redis error
|
||||
mock_queue_notif.assert_called_once()
|
||||
|
||||
# Verify Discord alert was still sent despite Redis error
|
||||
mock_client.discord_system_alert.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_clears_notifications_on_grant(server: SpinTestServer):
|
||||
"""Test that _add_transaction clears notification flags when adding GRANT credits."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-grant-clear"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 1000, "transactionKey": "test-tx-key"}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter(
|
||||
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
|
||||
)
|
||||
mock_redis_client.delete = AsyncMock(return_value=1)
|
||||
|
||||
# Create a concrete instance
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with GRANT type (should clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=500, # Positive amount
|
||||
transaction_type=CreditTransactionType.GRANT,
|
||||
is_active=True, # Active transaction
|
||||
)
|
||||
|
||||
# Verify notification clearing was called
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
mock_redis_client.scan_iter.assert_called_once_with(
|
||||
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_clears_notifications_on_top_up(server: SpinTestServer):
|
||||
"""Test that _add_transaction clears notification flags when adding TOP_UP credits."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-topup-clear"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 2000, "transactionKey": "test-tx-key-2"}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter([])
|
||||
mock_redis_client.delete = AsyncMock(return_value=0)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with TOP_UP type (should clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=1000, # Positive amount
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Verify notification clearing was attempted
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_skips_clearing_for_inactive_transaction(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that _add_transaction does NOT clear notifications for inactive transactions."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-inactive"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 500, "transactionKey": "test-tx-key-3"}]
|
||||
|
||||
# Mock async Redis
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with is_active=False (should NOT clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=500,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
is_active=False, # Inactive - pending Stripe payment
|
||||
)
|
||||
|
||||
# Verify notification clearing was NOT called
|
||||
mock_redis_module.get_redis_async.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_skips_clearing_for_usage_transaction(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that _add_transaction does NOT clear notifications for USAGE transactions."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-usage"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 400, "transactionKey": "test-tx-key-4"}]
|
||||
|
||||
# Mock async Redis
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with USAGE type (spending, should NOT clear)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=-100, # Negative - spending credits
|
||||
transaction_type=CreditTransactionType.USAGE,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Verify notification clearing was NOT called
|
||||
mock_redis_module.get_redis_async.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_enable_transaction_clears_notifications(server: SpinTestServer):
|
||||
"""Test that _enable_transaction clears notification flags when enabling a TOP_UP."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-enable"
|
||||
|
||||
with patch("backend.data.credit.CreditTransaction") as mock_credit_tx, patch(
|
||||
"backend.data.credit.query_raw_with_schema"
|
||||
) as mock_query, patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
# Mock finding the pending transaction
|
||||
mock_transaction = MagicMock()
|
||||
mock_transaction.amount = 1000
|
||||
mock_transaction.type = CreditTransactionType.TOP_UP
|
||||
mock_credit_tx.prisma.return_value.find_first = AsyncMock(
|
||||
return_value=mock_transaction
|
||||
)
|
||||
|
||||
# Mock the query to return updated balance
|
||||
mock_query.return_value = [{"balance": 1500}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter(
|
||||
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
|
||||
)
|
||||
mock_redis_client.delete = AsyncMock(return_value=1)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _enable_transaction (simulates Stripe checkout completion)
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
await credit_model._enable_transaction(
|
||||
transaction_key="cs_test_123",
|
||||
user_id=user_id,
|
||||
metadata=SafeJson({"payment": "completed"}),
|
||||
)
|
||||
|
||||
# Verify notification clearing was called
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
mock_redis_client.scan_iter.assert_called_once_with(
|
||||
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
)
|
||||
@@ -3,16 +3,16 @@ import logging
|
||||
import fastapi.responses
|
||||
import pytest
|
||||
|
||||
import backend.api.features.library.model
|
||||
import backend.api.features.store.model
|
||||
from backend.api.model import CreateGraph
|
||||
from backend.api.rest_api import AgentServer
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.store.model
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.data_manipulation import FindInDictionaryBlock
|
||||
from backend.blocks.io import AgentInputBlock
|
||||
from backend.blocks.maths import CalculatorBlock, Operation
|
||||
from backend.data import execution, graph
|
||||
from backend.data.model import User
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.test import SpinTestServer, wait_execution
|
||||
|
||||
@@ -356,7 +356,7 @@ async def test_execute_preset(server: SpinTestServer):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
# Create preset with initial values
|
||||
preset = backend.api.features.library.model.LibraryAgentPresetCreatable(
|
||||
preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
|
||||
name="Test Preset With Clash",
|
||||
description="Test preset with clashing input values",
|
||||
graph_id=test_graph.id,
|
||||
@@ -444,7 +444,7 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
# Create preset with initial values
|
||||
preset = backend.api.features.library.model.LibraryAgentPresetCreatable(
|
||||
preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
|
||||
name="Test Preset With Clash",
|
||||
description="Test preset with clashing input values",
|
||||
graph_id=test_graph.id,
|
||||
@@ -485,7 +485,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
|
||||
store_submission_request = backend.api.features.store.model.StoreSubmissionRequest(
|
||||
store_submission_request = backend.server.v2.store.model.StoreSubmissionRequest(
|
||||
agent_id=test_graph.id,
|
||||
agent_version=test_graph.version,
|
||||
slug=test_graph.id,
|
||||
@@ -514,7 +514,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
|
||||
admin_user = await create_test_user(alt_user=True)
|
||||
await server.agent_server.test_review_store_listing(
|
||||
backend.api.features.store.model.ReviewSubmissionRequest(
|
||||
backend.server.v2.store.model.ReviewSubmissionRequest(
|
||||
store_listing_version_id=slv_id,
|
||||
is_approved=True,
|
||||
comments="Test comments",
|
||||
@@ -523,7 +523,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
)
|
||||
|
||||
# Add the approved store listing to the admin user's library so they can execute it
|
||||
from backend.api.features.library.db import add_store_agent_to_library
|
||||
from backend.server.v2.library.db import add_store_agent_to_library
|
||||
|
||||
await add_store_agent_to_library(
|
||||
store_listing_version_id=slv_id, user_id=admin_user.id
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user