mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-12 00:28:31 -05:00)

Compare commits: hackathon-...native-auth (24 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 87e3d7eaad | |
| | 974c14a7b9 | |
| | af014ea19d | |
| | 9ecf8bcb08 | |
| | a7a521cedd | |
| | 84244c0b56 | |
| | 9e83985b5b | |
| | 4ef3eab89d | |
| | c68b53b6c1 | |
| | 23fb3ad8a4 | |
| | 175ba13ebe | |
| | a415f471c6 | |
| | 3dd6e5cb04 | |
| | 3f1e66b317 | |
| | 8f722bd9cd | |
| | 65026fc9d3 | |
| | af98bc1081 | |
| | e92459fc5f | |
| | 1775286f59 | |
| | f6af700f1a | |
| | a80b06d459 | |
| | 17c9e7c8b4 | |
| | f83c9391c8 | |
| | 7a0a90e421 | |
@@ -16,7 +16,6 @@
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env
!autogpt_platform/backend/gen_prisma_types_stub.py

# Platform - Market
!autogpt_platform/market/market/
8 .github/copilot-instructions.md (vendored)
@@ -142,7 +142,7 @@ pnpm storybook # Start component development server
### Security & Middleware

**Cache Protection**: Backend includes middleware preventing sensitive data caching in browsers/proxies
**Authentication**: JWT-based with Supabase integration
**Authentication**: JWT-based with native authentication
**User ID Validation**: All data access requires user ID checks - verify this for any `data/*.py` changes
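As an aside, a minimal sketch of what such cache-protection middleware can look like in FastAPI. The function name and header values here are illustrative assumptions, not the platform's actual middleware:

```python
# Illustrative sketch only: mark every response as non-cacheable so browsers
# and proxies do not store sensitive data. Not the platform's real middleware.
from fastapi import FastAPI, Request

app = FastAPI()


@app.middleware("http")
async def prevent_sensitive_caching(request: Request, call_next):
    response = await call_next(request)
    # Ask clients and intermediaries not to store or reuse the response body.
    response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, private"
    response.headers["Pragma"] = "no-cache"
    return response
```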
### Development Workflow
@@ -168,9 +168,9 @@ pnpm storybook # Start component development server

- `frontend/src/app/layout.tsx` - Root application layout
- `frontend/src/app/page.tsx` - Home page
- `frontend/src/lib/supabase/` - Authentication and database client
- `frontend/src/lib/auth/` - Authentication client

**Protected Routes**: Update `frontend/lib/supabase/middleware.ts` when adding protected routes
**Protected Routes**: Update `frontend/middleware.ts` when adding protected routes

### Agent Block System

@@ -194,7 +194,7 @@ Agents are built using a visual block-based system where each block performs a s

1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
3. **Platform**: `/.env.default` (shared) → `/.env` (user overrides)
4. Docker Compose `environment:` sections override file-based config
5. Shell environment variables have highest precedence
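As a rough illustration of that precedence order, a small Python sketch (assuming `python-dotenv`); the file paths and the dict merge are illustrative, and steps 4-5 are approximated by letting the process environment win:

```python
# Minimal sketch of the layering described above. Later entries win:
# .env.default < .env < process environment (which is ultimately what
# Docker Compose `environment:` sections and shell exports populate).
import os

from dotenv import dotenv_values

config = {
    **dotenv_values("/backend/.env.default"),  # checked-in defaults
    **dotenv_values("/backend/.env"),          # local user overrides
    **os.environ,                              # Compose / shell take precedence
}
```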
8 .github/workflows/claude-dependabot.yml (vendored)
@@ -74,7 +74,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
@@ -144,11 +144,7 @@ jobs:
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
"pgvector/pgvector:pg18"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
|
||||
8 .github/workflows/claude.yml (vendored)
@@ -90,7 +90,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
@@ -160,11 +160,7 @@ jobs:
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
"pgvector/pgvector:pg18"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
|
||||
18 .github/workflows/copilot-setup-steps.yml (vendored)
@@ -72,7 +72,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
@@ -108,16 +108,6 @@ jobs:
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
# Remove large unused tools to free disk space for Docker builds
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo docker system prune -af
|
||||
df -h
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
@@ -152,11 +142,7 @@ jobs:
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
"pgvector/pgvector:pg18"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
|
||||
46 .github/workflows/platform-backend-ci.yml (vendored)
@@ -2,13 +2,13 @@ name: AutoGPT Platform - Backend CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev, ci-test*]
|
||||
branches: [master, dev, ci-test*, native-auth]
|
||||
paths:
|
||||
- ".github/workflows/platform-backend-ci.yml"
|
||||
- "autogpt_platform/backend/**"
|
||||
- "autogpt_platform/autogpt_libs/**"
|
||||
pull_request:
|
||||
branches: [master, dev, release-*]
|
||||
branches: [master, dev, release-*, native-auth]
|
||||
paths:
|
||||
- ".github/workflows/platform-backend-ci.yml"
|
||||
- "autogpt_platform/backend/**"
|
||||
@@ -36,6 +36,19 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: pgvector/pgvector:pg18
|
||||
ports:
|
||||
- 5432:5432
|
||||
env:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: your-super-secret-and-long-postgres-password
|
||||
POSTGRES_DB: postgres
|
||||
options: >-
|
||||
--health-cmd "pg_isready -U postgres"
|
||||
--health-interval 5s
|
||||
--health-timeout 5s
|
||||
--health-retries 10
|
||||
redis:
|
||||
image: redis:latest
|
||||
ports:
|
||||
@@ -78,11 +91,6 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Setup Supabase
|
||||
uses: supabase/setup-cli@v1
|
||||
with:
|
||||
version: 1.178.1
|
||||
|
||||
- id: get_date
|
||||
name: Get date
|
||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||
@@ -134,17 +142,7 @@ jobs:
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
- id: supabase
|
||||
name: Start Supabase
|
||||
working-directory: .
|
||||
run: |
|
||||
supabase init
|
||||
supabase start --exclude postgres-meta,realtime,storage-api,imgproxy,inbucket,studio,edge-runtime,logflare,vector,supavisor
|
||||
supabase status -o env | sed 's/="/=/; s/"$//' >> $GITHUB_OUTPUT
|
||||
# outputs:
|
||||
# DB_URL, API_URL, GRAPHQL_URL, ANON_KEY, SERVICE_ROLE_KEY, JWT_SECRET
|
||||
run: poetry run prisma generate
|
||||
|
||||
- name: Wait for ClamAV to be ready
|
||||
run: |
|
||||
@@ -178,8 +176,8 @@ jobs:
|
||||
- name: Run Database Migrations
|
||||
run: poetry run prisma migrate dev --name updates
|
||||
env:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DATABASE_URL: postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/postgres
|
||||
DIRECT_URL: postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/postgres
|
||||
|
||||
- id: lint
|
||||
name: Run Linter
|
||||
@@ -195,11 +193,9 @@ jobs:
|
||||
if: success() || (failure() && steps.lint.outcome == 'failure')
|
||||
env:
|
||||
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
|
||||
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
|
||||
JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
DATABASE_URL: postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/postgres
|
||||
DIRECT_URL: postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/postgres
|
||||
JWT_SECRET: your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
REDIS_HOST: "localhost"
|
||||
REDIS_PORT: "6379"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
5 .github/workflows/platform-frontend-ci.yml (vendored)
@@ -2,11 +2,12 @@ name: AutoGPT Platform - Frontend CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev]
|
||||
branches: [master, dev, native-auth]
|
||||
paths:
|
||||
- ".github/workflows/platform-frontend-ci.yml"
|
||||
- "autogpt_platform/frontend/**"
|
||||
pull_request:
|
||||
branches: [master, dev, native-auth]
|
||||
paths:
|
||||
- ".github/workflows/platform-frontend-ci.yml"
|
||||
- "autogpt_platform/frontend/**"
|
||||
@@ -147,7 +148,7 @@ jobs:
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Copy default supabase .env
|
||||
- name: Copy default platform .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
|
||||
56 .github/workflows/platform-fullstack-ci.yml (vendored)
@@ -1,12 +1,13 @@
|
||||
name: AutoGPT Platform - Frontend CI
|
||||
name: AutoGPT Platform - Fullstack CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev]
|
||||
branches: [master, dev, native-auth]
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- "autogpt_platform/**"
|
||||
pull_request:
|
||||
branches: [master, dev, native-auth]
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- "autogpt_platform/**"
|
||||
@@ -58,14 +59,11 @@ jobs:
|
||||
types:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
timeout-minutes: 10
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
@@ -75,18 +73,6 @@ jobs:
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
@@ -101,36 +87,12 @@ jobs:
|
||||
- name: Setup .env
|
||||
run: cp .env.default .env
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||
|
||||
- name: Generate API queries
|
||||
run: pnpm generate:api:force
|
||||
|
||||
- name: Check for API schema changes
|
||||
run: |
|
||||
if ! git diff --exit-code src/app/api/openapi.json; then
|
||||
echo "❌ API schema changes detected in src/app/api/openapi.json"
|
||||
echo ""
|
||||
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
|
||||
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
|
||||
echo "The API schema is now out of sync with the Front-end queries."
|
||||
echo ""
|
||||
echo "To fix this:"
|
||||
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
|
||||
echo "2. Run 'pnpm generate:api' locally"
|
||||
echo "3. Run 'pnpm types' locally"
|
||||
echo "4. Fix any TypeScript errors that may have been introduced"
|
||||
echo "5. Commit and push your changes"
|
||||
echo ""
|
||||
exit 1
|
||||
else
|
||||
echo "✅ No API schema changes detected"
|
||||
fi
|
||||
run: pnpm generate:api
|
||||
|
||||
- name: Run Typescript checks
|
||||
run: pnpm types
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
|
||||
@@ -49,5 +49,5 @@ Use conventional commit messages for all commits (e.g. `feat(backend): add API`)
- Keep out-of-scope changes under 20% of the PR.
- Ensure PR descriptions are complete.
- For changes touching `data/*.py`, validate user ID checks or explain why not needed.
- If adding protected frontend routes, update `frontend/lib/supabase/middleware.ts`.
- If adding protected frontend routes, update `frontend/lib/auth/helpers.ts`.
- Use the linear ticket branch structure if given `codex/open-1668-resume-dropped-runs`
@@ -5,12 +5,6 @@
|
||||
|
||||
POSTGRES_PASSWORD=your-super-secret-and-long-postgres-password
|
||||
JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
|
||||
SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
DASHBOARD_USERNAME=supabase
|
||||
DASHBOARD_PASSWORD=this_password_is_insecure_and_should_be_updated
|
||||
SECRET_KEY_BASE=UpNVntn3cDxHJpq99YMc1T1AQgQpc8kfYTuRgBiYa15BLrx8etQoXz3gZv1/u2oq
|
||||
VAULT_ENC_KEY=your-encryption-key-32-chars-min
|
||||
|
||||
|
||||
############
|
||||
@@ -24,100 +18,31 @@ POSTGRES_PORT=5432
|
||||
|
||||
|
||||
############
|
||||
# Supavisor -- Database pooler
|
||||
############
|
||||
POOLER_PROXY_PORT_TRANSACTION=6543
|
||||
POOLER_DEFAULT_POOL_SIZE=20
|
||||
POOLER_MAX_CLIENT_CONN=100
|
||||
POOLER_TENANT_ID=your-tenant-id
|
||||
|
||||
|
||||
############
|
||||
# API Proxy - Configuration for the Kong Reverse proxy.
|
||||
# Auth - Native authentication configuration
|
||||
############
|
||||
|
||||
KONG_HTTP_PORT=8000
|
||||
KONG_HTTPS_PORT=8443
|
||||
|
||||
|
||||
############
|
||||
# API - Configuration for PostgREST.
|
||||
############
|
||||
|
||||
PGRST_DB_SCHEMAS=public,storage,graphql_public
|
||||
|
||||
|
||||
############
|
||||
# Auth - Configuration for the GoTrue authentication server.
|
||||
############
|
||||
|
||||
## General
|
||||
SITE_URL=http://localhost:3000
|
||||
ADDITIONAL_REDIRECT_URLS=
|
||||
JWT_EXPIRY=3600
|
||||
DISABLE_SIGNUP=false
|
||||
API_EXTERNAL_URL=http://localhost:8000
|
||||
|
||||
## Mailer Config
|
||||
MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify"
|
||||
MAILER_URLPATHS_INVITE="/auth/v1/verify"
|
||||
MAILER_URLPATHS_RECOVERY="/auth/v1/verify"
|
||||
MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify"
|
||||
# JWT token configuration
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=15
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
JWT_ISSUER=autogpt-platform
|
||||
|
||||
## Email auth
|
||||
ENABLE_EMAIL_SIGNUP=true
|
||||
ENABLE_EMAIL_AUTOCONFIRM=false
|
||||
SMTP_ADMIN_EMAIL=admin@example.com
|
||||
SMTP_HOST=supabase-mail
|
||||
SMTP_PORT=2500
|
||||
SMTP_USER=fake_mail_user
|
||||
SMTP_PASS=fake_mail_password
|
||||
SMTP_SENDER_NAME=fake_sender
|
||||
ENABLE_ANONYMOUS_USERS=false
|
||||
|
||||
## Phone auth
|
||||
ENABLE_PHONE_SIGNUP=true
|
||||
ENABLE_PHONE_AUTOCONFIRM=true
|
||||
# Google OAuth (optional)
|
||||
GOOGLE_CLIENT_ID=
|
||||
GOOGLE_CLIENT_SECRET=
|
||||
|
||||
|
||||
############
|
||||
# Studio - Configuration for the Dashboard
|
||||
# Email configuration (optional)
|
||||
############
|
||||
|
||||
STUDIO_DEFAULT_ORGANIZATION=Default Organization
|
||||
STUDIO_DEFAULT_PROJECT=Default Project
|
||||
SMTP_HOST=
|
||||
SMTP_PORT=587
|
||||
SMTP_USER=
|
||||
SMTP_PASS=
|
||||
SMTP_FROM_EMAIL=noreply@example.com
|
||||
|
||||
STUDIO_PORT=3000
|
||||
# replace if you intend to use Studio outside of localhost
|
||||
SUPABASE_PUBLIC_URL=http://localhost:8000
|
||||
|
||||
# Enable webp support
|
||||
IMGPROXY_ENABLE_WEBP_DETECTION=true
|
||||
|
||||
# Add your OpenAI API key to enable SQL Editor Assistant
|
||||
OPENAI_API_KEY=
|
||||
|
||||
|
||||
############
|
||||
# Functions - Configuration for Functions
|
||||
############
|
||||
# NOTE: VERIFY_JWT applies to all functions. Per-function VERIFY_JWT is not supported yet.
|
||||
FUNCTIONS_VERIFY_JWT=false
|
||||
|
||||
|
||||
############
|
||||
# Logs - Configuration for Logflare
|
||||
# Please refer to https://supabase.com/docs/reference/self-hosting-analytics/introduction
|
||||
############
|
||||
|
||||
LOGFLARE_LOGGER_BACKEND_API_KEY=your-super-secret-and-long-logflare-key
|
||||
|
||||
# Change vector.toml sinks to reflect this change
|
||||
LOGFLARE_API_KEY=your-super-secret-and-long-logflare-key
|
||||
|
||||
# Docker socket location - this value will differ depending on your OS
|
||||
DOCKER_SOCKET_LOCATION=/var/run/docker.sock
|
||||
|
||||
# Google Cloud Project details
|
||||
GOOGLE_PROJECT_ID=GOOGLE_PROJECT_ID
|
||||
GOOGLE_PROJECT_NUMBER=GOOGLE_PROJECT_NUMBER
|
||||
|
||||
@@ -1,19 +1,17 @@
.PHONY: start-core stop-core logs-core format lint migrate run-backend run-frontend load-store-agents

# Run just Supabase + Redis + RabbitMQ
# Run just PostgreSQL + Redis + RabbitMQ + ClamAV
start-core:
    docker compose up -d deps

# Stop core services
stop-core:
    docker compose stop
    docker compose stop deps

reset-db:
    docker compose stop db
    rm -rf db/docker/volumes/db/data
    cd backend && poetry run prisma migrate deploy
    cd backend && poetry run prisma generate
    cd backend && poetry run gen-prisma-stub

# View logs for core services
logs-core:
@@ -35,7 +33,6 @@ init-env:
migrate:
    cd backend && poetry run prisma migrate deploy
    cd backend && poetry run prisma generate
    cd backend && poetry run gen-prisma-stub

run-backend:
    cd backend && poetry run app
@@ -52,7 +49,7 @@ load-store-agents:
help:
    @echo "Usage: make <target>"
    @echo "Targets:"
    @echo " start-core - Start just the core services (Supabase, Redis, RabbitMQ) in background"
    @echo " start-core - Start just the core services (PostgreSQL, Redis, RabbitMQ, ClamAV) in background"
    @echo " stop-core - Stop the core services"
    @echo " reset-db - Reset the database by deleting the volume"
    @echo " logs-core - Tail the logs for core services"
@@ -61,4 +58,4 @@ help:
    @echo " run-backend - Run the backend FastAPI server"
    @echo " run-frontend - Run the frontend Next.js development server"
    @echo " test-data - Run the test data creator"
    @echo " load-store-agents - Load store agents from agents/ folder into test database"
    @echo " load-store-agents - Load store agents from agents/ folder into test database"
@@ -16,17 +16,37 @@ ALGO_RECOMMENDATION = (
    "We highly recommend using an asymmetric algorithm such as ES256, "
    "because when leaked, a shared secret would allow anyone to "
    "forge valid tokens and impersonate users. "
    "More info: https://supabase.com/docs/guides/auth/signing-keys#choosing-the-right-signing-algorithm"  # noqa
    "More info: https://pyjwt.readthedocs.io/en/stable/algorithms.html"
)


class Settings:
    def __init__(self):
        # JWT verification key (public key for asymmetric, shared secret for symmetric)
        self.JWT_VERIFY_KEY: str = os.getenv(
            "JWT_VERIFY_KEY", os.getenv("SUPABASE_JWT_SECRET", "")
        ).strip()

        # JWT signing key (private key for asymmetric, shared secret for symmetric)
        # Falls back to JWT_VERIFY_KEY for symmetric algorithms like HS256
        self.JWT_SIGN_KEY: str = os.getenv("JWT_SIGN_KEY", self.JWT_VERIFY_KEY).strip()

        self.JWT_ALGORITHM: str = os.getenv("JWT_SIGN_ALGORITHM", "HS256").strip()

        # Token expiration settings
        self.ACCESS_TOKEN_EXPIRE_MINUTES: int = int(
            os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "15")
        )
        self.REFRESH_TOKEN_EXPIRE_DAYS: int = int(
            os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7")
        )

        # JWT issuer claim
        self.JWT_ISSUER: str = os.getenv("JWT_ISSUER", "autogpt-platform").strip()

        # JWT audience claim
        self.JWT_AUDIENCE: str = os.getenv("JWT_AUDIENCE", "authenticated").strip()

        self.validate()

    def validate(self):
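Since the recommendation above favors an asymmetric algorithm such as ES256, here is a hedged sketch of what that setup could look like with PyJWT plus the `cryptography` package. The environment variable names match the `Settings` class above; the key generation and round trip are purely illustrative:

```python
# Sketch: the recommended asymmetric configuration (assumes the `cryptography`
# package, which PyJWT needs for ES256). Key handling here is illustrative.
import os

import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec

private_key = ec.generate_private_key(ec.SECP256R1())
private_pem = private_key.private_bytes(
    serialization.Encoding.PEM,
    serialization.PrivateFormat.PKCS8,
    serialization.NoEncryption(),
).decode()
public_pem = private_key.public_key().public_bytes(
    serialization.Encoding.PEM,
    serialization.PublicFormat.SubjectPublicKeyInfo,
).decode()

# The Settings class above reads these variables at construction time.
os.environ["JWT_SIGN_ALGORITHM"] = "ES256"
os.environ["JWT_SIGN_KEY"] = private_pem    # private key: signing side only
os.environ["JWT_VERIFY_KEY"] = public_pem   # public key: safe to distribute

# Round trip with plain PyJWT to show the signing/verification key split.
token = jwt.encode(
    {"sub": "user-123", "aud": "authenticated"}, private_pem, algorithm="ES256"
)
claims = jwt.decode(token, public_pem, algorithms=["ES256"], audience="authenticated")
```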
@@ -1,25 +1,29 @@
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi

from .jwt_utils import bearer_jwt_auth


def add_auth_responses_to_openapi(app: FastAPI) -> None:
    """
    Patch a FastAPI instance's `openapi()` method to add 401 responses
    Set up custom OpenAPI schema generation that adds 401 responses
    to all authenticated endpoints.

    This is needed when using HTTPBearer with auto_error=False to get proper
    401 responses instead of 403, but FastAPI only automatically adds security
    responses when auto_error=True.
    """
    # Wrap current method to allow stacking OpenAPI schema modifiers like this
    wrapped_openapi = app.openapi

    def custom_openapi():
        if app.openapi_schema:
            return app.openapi_schema

        openapi_schema = wrapped_openapi()
        openapi_schema = get_openapi(
            title=app.title,
            version=app.version,
            description=app.description,
            routes=app.routes,
        )

        # Add 401 response to all endpoints that have security requirements
        for path, methods in openapi_schema["paths"].items():
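For context, a minimal sketch of how these pieces are meant to fit together in an application. The import paths (`autogpt_libs.auth.jwt_utils`, `autogpt_libs.auth.openapi`) and the example route are assumptions for illustration, not taken from this diff:

```python
# Sketch of wiring the OpenAPI helper into an app whose routes authenticate
# via the bearer_jwt_auth dependency chain. Module paths are assumed.
from typing import Any

from fastapi import FastAPI, Security

from autogpt_libs.auth.jwt_utils import get_jwt_payload
from autogpt_libs.auth.openapi import add_auth_responses_to_openapi

app = FastAPI()


@app.get("/me")
async def me(payload: dict[str, Any] = Security(get_jwt_payload)):
    # get_jwt_payload returns the decoded JWT claims for the caller.
    return {"user_id": payload["sub"]}


# bearer_jwt_auth uses auto_error=False, so document the 401 responses here.
add_auth_responses_to_openapi(app)
```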
@@ -1,4 +1,8 @@
import hashlib
import logging
import secrets
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any

import jwt
@@ -16,6 +20,57 @@ bearer_jwt_auth = HTTPBearer(
)


def create_access_token(
    user_id: str,
    email: str,
    role: str = "authenticated",
    email_verified: bool = False,
) -> str:
    """
    Generate a new JWT access token.

    :param user_id: The user's unique identifier
    :param email: The user's email address
    :param role: The user's role (default: "authenticated")
    :param email_verified: Whether the user's email is verified
    :return: Encoded JWT token
    """
    settings = get_settings()
    now = datetime.now(timezone.utc)

    payload = {
        "sub": user_id,
        "email": email,
        "role": role,
        "email_verified": email_verified,
        "aud": settings.JWT_AUDIENCE,
        "iss": settings.JWT_ISSUER,
        "iat": now,
        "exp": now + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
        "jti": str(uuid.uuid4()),  # Unique token ID
    }

    return jwt.encode(payload, settings.JWT_SIGN_KEY, algorithm=settings.JWT_ALGORITHM)


def create_refresh_token() -> tuple[str, str]:
    """
    Generate a new refresh token.

    Returns a tuple of (raw_token, hashed_token).
    The raw token should be sent to the client.
    The hashed token should be stored in the database.
    """
    raw_token = secrets.token_urlsafe(64)
    hashed_token = hashlib.sha256(raw_token.encode()).hexdigest()
    return raw_token, hashed_token


def hash_token(token: str) -> str:
    """Hash a token using SHA-256."""
    return hashlib.sha256(token.encode()).hexdigest()


async def get_jwt_payload(
    credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
) -> dict[str, Any]:
@@ -52,11 +107,19 @@ def parse_jwt_token(token: str) -> dict[str, Any]:
    """
    settings = get_settings()
    try:
        # Build decode options
        options = {
            "verify_aud": True,
            "verify_iss": bool(settings.JWT_ISSUER),
        }

        payload = jwt.decode(
            token,
            settings.JWT_VERIFY_KEY,
            algorithms=[settings.JWT_ALGORITHM],
            audience="authenticated",
            audience=settings.JWT_AUDIENCE,
            issuer=settings.JWT_ISSUER if settings.JWT_ISSUER else None,
            options=options,
        )
        return payload
    except jwt.ExpiredSignatureError:
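A short usage sketch of the token lifecycle these helpers implement. The import path and the in-memory dict standing in for the refresh-token table are assumptions; the function names come from the hunk above, and it presumes the JWT settings (e.g. `JWT_VERIFY_KEY`) are already configured:

```python
# Sketch of the intended access/refresh token flow, using the helpers above.
# Import path assumed; the dict is a stand-in for the real refresh-token table.
from autogpt_libs.auth.jwt_utils import (
    create_access_token,
    create_refresh_token,
    hash_token,
    parse_jwt_token,
)

refresh_token_store: dict[str, str] = {}

# 1. Issue tokens at login time.
access_token = create_access_token(
    user_id="user-123", email="user@example.com", email_verified=True
)
raw_refresh, hashed_refresh = create_refresh_token()
refresh_token_store["user-123"] = hashed_refresh  # store only the hash
# raw_refresh (and access_token) go back to the client.

# 2. On each request, verify the access token and read its claims.
claims = parse_jwt_token(access_token)
assert claims["sub"] == "user-123"

# 3. On refresh, hash the presented token and compare against the stored hash.
if hash_token(raw_refresh) == refresh_token_store["user-123"]:
    access_token = create_access_token(user_id="user-123", email="user@example.com")
```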
@@ -11,6 +11,7 @@ class User:
    email: str
    phone_number: str
    role: str
    email_verified: bool = False

    @classmethod
    def from_payload(cls, payload):
@@ -18,5 +19,6 @@ class User:
            user_id=payload["sub"],
            email=payload.get("email", ""),
            phone_number=payload.get("phone", ""),
            role=payload["role"],
            role=payload.get("role", "authenticated"),
            email_verified=payload.get("email_verified", False),
        )
414 autogpt_platform/autogpt_libs/poetry.lock (generated)
@@ -48,6 +48,21 @@ files = [
|
||||
{file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "authlib"
|
||||
version = "1.6.6"
|
||||
description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "authlib-1.6.6-py2.py3-none-any.whl", hash = "sha256:7d9e9bc535c13974313a87f53e8430eb6ea3d1cf6ae4f6efcd793f2e949143fd"},
|
||||
{file = "authlib-1.6.6.tar.gz", hash = "sha256:45770e8e056d0f283451d9996fbb59b70d45722b45d854d58f32878d0a40c38e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cryptography = "*"
|
||||
|
||||
[[package]]
|
||||
name = "backports-asyncio-runner"
|
||||
version = "1.2.0"
|
||||
@@ -61,6 +76,71 @@ files = [
|
||||
{file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bcrypt"
|
||||
version = "4.3.0"
|
||||
description = "Modern password hashing for your software and your servers"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-macosx_10_12_universal2.whl", hash = "sha256:f01e060f14b6b57bbb72fc5b4a83ac21c443c9a2ee708e04a10e9192f90a6281"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5eeac541cefd0bb887a371ef73c62c3cd78535e4887b310626036a7c0a817bb"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59e1aa0e2cd871b08ca146ed08445038f42ff75968c7ae50d2fdd7860ade2180"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:0042b2e342e9ae3d2ed22727c1262f76cc4f345683b5c1715f0250cf4277294f"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74a8d21a09f5e025a9a23e7c0fd2c7fe8e7503e4d356c0a2c1486ba010619f09"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:0142b2cb84a009f8452c8c5a33ace5e3dfec4159e7735f5afe9a4d50a8ea722d"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_34_aarch64.whl", hash = "sha256:12fa6ce40cde3f0b899729dbd7d5e8811cb892d31b6f7d0334a1f37748b789fd"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_34_x86_64.whl", hash = "sha256:5bd3cca1f2aa5dbcf39e2aa13dd094ea181f48959e1071265de49cc2b82525af"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:335a420cfd63fc5bc27308e929bee231c15c85cc4c496610ffb17923abf7f231"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:0e30e5e67aed0187a1764911af023043b4542e70a7461ad20e837e94d23e1d6c"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:3b8d62290ebefd49ee0b3ce7500f5dbdcf13b81402c05f6dafab9a1e1b27212f"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2ef6630e0ec01376f59a006dc72918b1bf436c3b571b80fa1968d775fa02fe7d"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-win32.whl", hash = "sha256:7a4be4cbf241afee43f1c3969b9103a41b40bcb3a3f467ab19f891d9bc4642e4"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-win_amd64.whl", hash = "sha256:5c1949bf259a388863ced887c7861da1df681cb2388645766c89fdfd9004c669"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-macosx_10_12_universal2.whl", hash = "sha256:f81b0ed2639568bf14749112298f9e4e2b28853dab50a8b357e31798686a036d"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:864f8f19adbe13b7de11ba15d85d4a428c7e2f344bac110f667676a0ff84924b"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e36506d001e93bffe59754397572f21bb5dc7c83f54454c990c74a468cd589e"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:842d08d75d9fe9fb94b18b071090220697f9f184d4547179b60734846461ed59"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7c03296b85cb87db865d91da79bf63d5609284fc0cab9472fdd8367bbd830753"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:62f26585e8b219cdc909b6a0069efc5e4267e25d4a3770a364ac58024f62a761"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:beeefe437218a65322fbd0069eb437e7c98137e08f22c4660ac2dc795c31f8bb"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:97eea7408db3a5bcce4a55d13245ab3fa566e23b4c67cd227062bb49e26c585d"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:191354ebfe305e84f344c5964c7cd5f924a3bfc5d405c75ad07f232b6dffb49f"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:41261d64150858eeb5ff43c753c4b216991e0ae16614a308a15d909503617732"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:33752b1ba962ee793fa2b6321404bf20011fe45b9afd2a842139de3011898fef"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:50e6e80a4bfd23a25f5c05b90167c19030cf9f87930f7cb2eacb99f45d1c3304"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-win32.whl", hash = "sha256:67a561c4d9fb9465ec866177e7aebcad08fe23aaf6fbd692a6fab69088abfc51"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-win_amd64.whl", hash = "sha256:584027857bc2843772114717a7490a37f68da563b3620f78a849bcb54dc11e62"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:0d3efb1157edebfd9128e4e46e2ac1a64e0c1fe46fb023158a407c7892b0f8c3"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08bacc884fd302b611226c01014eca277d48f0a05187666bca23aac0dad6fe24"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6746e6fec103fcd509b96bacdfdaa2fbde9a553245dbada284435173a6f1aef"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:afe327968aaf13fc143a56a3360cb27d4ad0345e34da12c7290f1b00b8fe9a8b"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:d9af79d322e735b1fc33404b5765108ae0ff232d4b54666d46730f8ac1a43676"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f1e3ffa1365e8702dc48c8b360fef8d7afeca482809c5e45e653af82ccd088c1"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:3004df1b323d10021fda07a813fd33e0fd57bef0e9a480bb143877f6cba996fe"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:531457e5c839d8caea9b589a1bcfe3756b0547d7814e9ce3d437f17da75c32b0"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:17a854d9a7a476a89dcef6c8bd119ad23e0f82557afbd2c442777a16408e614f"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:6fb1fd3ab08c0cbc6826a2e0447610c6f09e983a281b919ed721ad32236b8b23"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e965a9c1e9a393b8005031ff52583cedc15b7884fce7deb8b0346388837d6cfe"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:79e70b8342a33b52b55d93b3a59223a844962bef479f6a0ea318ebbcadf71505"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-win32.whl", hash = "sha256:b4d4e57f0a63fd0b358eb765063ff661328f69a04494427265950c71b992a39a"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-win_amd64.whl", hash = "sha256:e53e074b120f2877a35cc6c736b8eb161377caae8925c17688bd46ba56daaa5b"},
|
||||
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c950d682f0952bafcceaf709761da0a32a942272fad381081b51096ffa46cea1"},
|
||||
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:107d53b5c67e0bbc3f03ebf5b030e0403d24dda980f8e244795335ba7b4a027d"},
|
||||
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:b693dbb82b3c27a1604a3dff5bfc5418a7e6a781bb795288141e5f80cf3a3492"},
|
||||
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:b6354d3760fcd31994a14c89659dee887f1351a06e5dac3c1142307172a79f90"},
|
||||
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:a839320bf27d474e52ef8cb16449bb2ce0ba03ca9f44daba6d93fa1d8828e48a"},
|
||||
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:bdc6a24e754a555d7316fa4774e64c6c3997d27ed2d1964d55920c7c227bc4ce"},
|
||||
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:55a935b8e9a1d2def0626c4269db3fcd26728cbff1e84f0341465c31c4ee56d8"},
|
||||
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:57967b7a28d855313a963aaea51bf6df89f833db4320da458e5b3c5ab6d4c938"},
|
||||
{file = "bcrypt-4.3.0.tar.gz", hash = "sha256:3a3fd2204178b6d2adcf09cb4f6426ffef54762577a7c9b54c159008cb288c18"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
tests = ["pytest (>=3.2.1,!=3.3.0)"]
|
||||
typecheck = ["mypy"]
|
||||
|
||||
[[package]]
|
||||
name = "cachetools"
|
||||
version = "5.5.2"
|
||||
@@ -459,21 +539,6 @@ ssh = ["bcrypt (>=3.1.5)"]
|
||||
test = ["certifi (>=2024)", "cryptography-vectors (==45.0.6)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
|
||||
test-randomorder = ["pytest-randomly"]
|
||||
|
||||
[[package]]
|
||||
name = "deprecation"
|
||||
version = "2.1.0"
|
||||
description = "A library to handle automated deprecations"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"},
|
||||
{file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
packaging = "*"
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.3.0"
|
||||
@@ -695,23 +760,6 @@ protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4
|
||||
[package.extras]
|
||||
grpc = ["grpcio (>=1.44.0,<2.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "gotrue"
|
||||
version = "2.12.3"
|
||||
description = "Python Client Library for Supabase Auth"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "gotrue-2.12.3-py3-none-any.whl", hash = "sha256:b1a3c6a5fe3f92e854a026c4c19de58706a96fd5fbdcc3d620b2802f6a46a26b"},
|
||||
{file = "gotrue-2.12.3.tar.gz", hash = "sha256:f874cf9d0b2f0335bfbd0d6e29e3f7aff79998cd1c14d2ad814db8c06cee3852"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
|
||||
pydantic = ">=1.10,<3"
|
||||
pyjwt = ">=2.10.1,<3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "grpc-google-iam-v1"
|
||||
version = "0.14.2"
|
||||
@@ -822,94 +870,6 @@ files = [
|
||||
{file = "h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "4.2.0"
|
||||
description = "Pure-Python HTTP/2 protocol implementation"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "h2-4.2.0-py3-none-any.whl", hash = "sha256:479a53ad425bb29af087f3458a61d30780bc818e4ebcf01f0b536ba916462ed0"},
|
||||
{file = "h2-4.2.0.tar.gz", hash = "sha256:c8a52129695e88b1a0578d8d2cc6842bbd79128ac685463b887ee278126ad01f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
hpack = ">=4.1,<5"
|
||||
hyperframe = ">=6.1,<7"
|
||||
|
||||
[[package]]
|
||||
name = "hpack"
|
||||
version = "4.1.0"
|
||||
description = "Pure-Python HPACK header encoding"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496"},
|
||||
{file = "hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httpcore"
|
||||
version = "1.0.9"
|
||||
description = "A minimal low-level HTTP client."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55"},
|
||||
{file = "httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
certifi = "*"
|
||||
h11 = ">=0.16"
|
||||
|
||||
[package.extras]
|
||||
asyncio = ["anyio (>=4.0,<5.0)"]
|
||||
http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
trio = ["trio (>=0.22.0,<1.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "httpx"
|
||||
version = "0.28.1"
|
||||
description = "The next generation HTTP client."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"},
|
||||
{file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = "*"
|
||||
certifi = "*"
|
||||
h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""}
|
||||
httpcore = "==1.*"
|
||||
idna = "*"
|
||||
|
||||
[package.extras]
|
||||
brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
|
||||
cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
|
||||
http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
zstd = ["zstandard (>=0.18.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "hyperframe"
|
||||
version = "6.1.0"
|
||||
description = "Pure-Python HTTP/2 framing"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5"},
|
||||
{file = "hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.10"
|
||||
@@ -1036,7 +996,7 @@ version = "25.0"
|
||||
description = "Core utilities for Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main", "dev"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"},
|
||||
{file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"},
|
||||
@@ -1058,24 +1018,6 @@ files = [
|
||||
dev = ["pre-commit", "tox"]
|
||||
testing = ["coverage", "pytest", "pytest-benchmark"]
|
||||
|
||||
[[package]]
|
||||
name = "postgrest"
|
||||
version = "1.1.1"
|
||||
description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "postgrest-1.1.1-py3-none-any.whl", hash = "sha256:98a6035ee1d14288484bfe36235942c5fb2d26af6d8120dfe3efbe007859251a"},
|
||||
{file = "postgrest-1.1.1.tar.gz", hash = "sha256:f3bb3e8c4602775c75c844a31f565f5f3dd584df4d36d683f0b67d01a86be322"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
deprecation = ">=2.1.0,<3.0.0"
|
||||
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
|
||||
pydantic = ">=1.9,<3.0"
|
||||
strenum = {version = ">=0.4.9,<0.5.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "proto-plus"
|
||||
version = "1.26.1"
|
||||
@@ -1462,21 +1404,6 @@ pytest = ">=6.2.5"
|
||||
[package.extras]
|
||||
dev = ["pre-commit", "pytest-asyncio", "tox"]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.9.0.post0"
|
||||
description = "Extensions to the standard Python datetime module"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
|
||||
{file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
six = ">=1.5"
|
||||
|
||||
[[package]]
|
||||
name = "python-dotenv"
|
||||
version = "1.1.1"
|
||||
@@ -1492,22 +1419,6 @@ files = [
|
||||
[package.extras]
|
||||
cli = ["click (>=5.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "realtime"
|
||||
version = "2.5.3"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "realtime-2.5.3-py3-none-any.whl", hash = "sha256:eb0994636946eff04c4c7f044f980c8c633c7eb632994f549f61053a474ac970"},
|
||||
{file = "realtime-2.5.3.tar.gz", hash = "sha256:0587594f3bc1c84bf007ff625075b86db6528843e03250dc84f4f2808be3d99a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
typing-extensions = ">=4.14.0,<5.0.0"
|
||||
websockets = ">=11,<16"
|
||||
|
||||
[[package]]
|
||||
name = "redis"
|
||||
version = "6.2.0"
|
||||
@@ -1606,18 +1517,6 @@ files = [
|
||||
{file = "semver-3.0.4.tar.gz", hash = "sha256:afc7d8c584a5ed0a11033af086e8af226a9c0b206f313e0301f8dd7b6b589602"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "six"
|
||||
version = "1.17.0"
|
||||
description = "Python 2 and 3 compatibility utilities"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"},
|
||||
{file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sniffio"
|
||||
version = "1.3.1"
|
||||
@@ -1649,76 +1548,6 @@ typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\""
|
||||
[package.extras]
|
||||
full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"]
|
||||
|
||||
[[package]]
|
||||
name = "storage3"
|
||||
version = "0.12.0"
|
||||
description = "Supabase Storage client for Python."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "storage3-0.12.0-py3-none-any.whl", hash = "sha256:1c4585693ca42243ded1512b58e54c697111e91a20916cd14783eebc37e7c87d"},
|
||||
{file = "storage3-0.12.0.tar.gz", hash = "sha256:94243f20922d57738bf42e96b9f5582b4d166e8bf209eccf20b146909f3f71b0"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
deprecation = ">=2.1.0,<3.0.0"
|
||||
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
|
||||
python-dateutil = ">=2.8.2,<3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "strenum"
|
||||
version = "0.4.15"
|
||||
description = "An Enum that inherits from str."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659"},
|
||||
{file = "StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"]
|
||||
release = ["twine"]
|
||||
test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"]
|
||||
|
||||
[[package]]
|
||||
name = "supabase"
|
||||
version = "2.16.0"
|
||||
description = "Supabase client for Python."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "supabase-2.16.0-py3-none-any.whl", hash = "sha256:99065caab3d90a56650bf39fbd0e49740995da3738ab28706c61bd7f2401db55"},
|
||||
{file = "supabase-2.16.0.tar.gz", hash = "sha256:98f3810158012d4ec0e3083f2e5515f5e10b32bd71e7d458662140e963c1d164"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
gotrue = ">=2.11.0,<3.0.0"
|
||||
httpx = ">=0.26,<0.29"
|
||||
postgrest = ">0.19,<1.2"
|
||||
realtime = ">=2.4.0,<2.6.0"
|
||||
storage3 = ">=0.10,<0.13"
|
||||
supafunc = ">=0.9,<0.11"
|
||||
|
||||
[[package]]
|
||||
name = "supafunc"
|
||||
version = "0.10.1"
|
||||
description = "Library for Supabase Functions"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "supafunc-0.10.1-py3-none-any.whl", hash = "sha256:26df9bd25ff2ef56cb5bfb8962de98f43331f7f8ff69572bac3ed9c3a9672040"},
|
||||
{file = "supafunc-0.10.1.tar.gz", hash = "sha256:a5b33c8baecb6b5297d25da29a2503e2ec67ee6986f3d44c137e651b8a59a17d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
|
||||
strenum = ">=0.4.15,<0.5.0"
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.2.1"
|
||||
@@ -1827,85 +1656,6 @@ typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""}
|
||||
[package.extras]
|
||||
standard = ["colorama (>=0.4) ; sys_platform == \"win32\"", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.15.1) ; sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\"", "watchfiles (>=0.13)", "websockets (>=10.4)"]
|
||||
|
||||
[[package]]
|
||||
name = "websockets"
|
||||
version = "15.0.1"
|
||||
description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "websockets-15.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d63efaa0cd96cf0c5fe4d581521d9fa87744540d4bc999ae6e08595a1014b45b"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac60e3b188ec7574cb761b08d50fcedf9d77f1530352db4eef1707fe9dee7205"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5756779642579d902eed757b21b0164cd6fe338506a8083eb58af5c372e39d9a"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fdfe3e2a29e4db3659dbd5bbf04560cea53dd9610273917799f1cde46aa725e"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c2529b320eb9e35af0fa3016c187dffb84a3ecc572bcee7c3ce302bfeba52bf"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac1e5c9054fe23226fb11e05a6e630837f074174c4c2f0fe442996112a6de4fb"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5df592cd503496351d6dc14f7cdad49f268d8e618f80dce0cd5a36b93c3fc08d"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0a34631031a8f05657e8e90903e656959234f3a04552259458aac0b0f9ae6fd9"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3d00075aa65772e7ce9e990cab3ff1de702aa09be3940d1dc88d5abf1ab8a09c"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-win32.whl", hash = "sha256:1234d4ef35db82f5446dca8e35a7da7964d02c127b095e172e54397fb6a6c256"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:39c1fec2c11dc8d89bba6b2bf1556af381611a173ac2b511cf7231622058af41"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:823c248b690b2fd9303ba00c4f66cd5e2d8c3ba4aa968b2779be9532a4dad431"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678999709e68425ae2593acf2e3ebcbcf2e69885a5ee78f9eb80e6e371f1bf57"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d50fd1ee42388dcfb2b3676132c78116490976f1300da28eb629272d5d93e905"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d99e5546bf73dbad5bf3547174cd6cb8ba7273062a23808ffea025ecb1cf8562"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:66dd88c918e3287efc22409d426c8f729688d89a0c587c88971a0faa2c2f3792"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8dd8327c795b3e3f219760fa603dcae1dcc148172290a8ab15158cf85a953413"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8fdc51055e6ff4adeb88d58a11042ec9a5eae317a0a53d12c062c8a8865909e8"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:693f0192126df6c2327cce3baa7c06f2a117575e32ab2308f7f8216c29d9e2e3"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:54479983bd5fb469c38f2f5c7e3a24f9a4e70594cd68cd1fa6b9340dadaff7cf"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-win32.whl", hash = "sha256:16b6c1b3e57799b9d38427dda63edcbe4926352c47cf88588c0be4ace18dac85"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:27ccee0071a0e75d22cb35849b1db43f2ecd3e161041ac1ee9d2352ddf72f065"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:3e90baa811a5d73f3ca0bcbf32064d663ed81318ab225ee4f427ad4e26e5aff3"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:592f1a9fe869c778694f0aa806ba0374e97648ab57936f092fd9d87f8bc03665"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0701bc3cfcb9164d04a14b149fd74be7347a530ad3bbf15ab2c678a2cd3dd9a2"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8b56bdcdb4505c8078cb6c7157d9811a85790f2f2b3632c7d1462ab5783d215"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0af68c55afbd5f07986df82831c7bff04846928ea8d1fd7f30052638788bc9b5"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64dee438fed052b52e4f98f76c5790513235efaa1ef7f3f2192c392cd7c91b65"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d5f6b181bb38171a8ad1d6aa58a67a6aa9d4b38d0f8c5f496b9e42561dfc62fe"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:5d54b09eba2bada6011aea5375542a157637b91029687eb4fdb2dab11059c1b4"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3be571a8b5afed347da347bfcf27ba12b069d9d7f42cb8c7028b5e98bbb12597"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-win32.whl", hash = "sha256:c338ffa0520bdb12fbc527265235639fb76e7bc7faafbb93f6ba80d9c06578a9"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:fcd5cf9e305d7b8338754470cf69cf81f420459dbae8a3b40cee57417f4614a7"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ee443ef070bb3b6ed74514f5efaa37a252af57c90eb33b956d35c8e9c10a1931"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5a939de6b7b4e18ca683218320fc67ea886038265fd1ed30173f5ce3f8e85675"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:746ee8dba912cd6fc889a8147168991d50ed70447bf18bcda7039f7d2e3d9151"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:595b6c3969023ecf9041b2936ac3827e4623bfa3ccf007575f04c5a6aa318c22"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c714d2fc58b5ca3e285461a4cc0c9a66bd0e24c5da9911e30158286c9b5be7f"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f3c1e2ab208db911594ae5b4f79addeb3501604a165019dd221c0bdcabe4db8"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:229cf1d3ca6c1804400b0a9790dc66528e08a6a1feec0d5040e8b9eb14422375"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:756c56e867a90fb00177d530dca4b097dd753cde348448a1012ed6c5131f8b7d"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:558d023b3df0bffe50a04e710bc87742de35060580a293c2a984299ed83bc4e4"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-win32.whl", hash = "sha256:ba9e56e8ceeeedb2e080147ba85ffcd5cd0711b89576b83784d8605a7df455fa"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:e09473f095a819042ecb2ab9465aee615bd9c2028e4ef7d933600a8401c79561"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5f4c04ead5aed67c8a1a20491d54cdfba5884507a48dd798ecaf13c74c4489f5"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:abdc0c6c8c648b4805c5eacd131910d2a7f6455dfd3becab248ef108e89ab16a"},
{file = "websockets-15.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a625e06551975f4b7ea7102bc43895b90742746797e2e14b70ed61c43a90f09b"},
{file = "websockets-15.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d591f8de75824cbb7acad4e05d2d710484f15f29d4a915092675ad3456f11770"},
{file = "websockets-15.0.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:47819cea040f31d670cc8d324bb6435c6f133b8c7a19ec3d61634e62f8d8f9eb"},
{file = "websockets-15.0.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac017dd64572e5c3bd01939121e4d16cf30e5d7e110a119399cf3133b63ad054"},
{file = "websockets-15.0.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:4a9fac8e469d04ce6c25bb2610dc535235bd4aa14996b4e6dbebf5e007eba5ee"},
{file = "websockets-15.0.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:363c6f671b761efcb30608d24925a382497c12c506b51661883c3e22337265ed"},
{file = "websockets-15.0.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:2034693ad3097d5355bfdacfffcbd3ef5694f9718ab7f29c29689a9eae841880"},
{file = "websockets-15.0.1-cp39-cp39-win32.whl", hash = "sha256:3b1ac0d3e594bf121308112697cf4b32be538fb1444468fb0a6ae4feebc83411"},
{file = "websockets-15.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:b7643a03db5c95c799b89b31c036d5f27eeb4d259c798e878d6937d71832b1e4"},
{file = "websockets-15.0.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:0c9e74d766f2818bb95f84c25be4dea09841ac0f734d1966f415e4edfc4ef1c3"},
{file = "websockets-15.0.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1009ee0c7739c08a0cd59de430d6de452a55e42d6b522de7aa15e6f67db0b8e1"},
{file = "websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76d1f20b1c7a2fa82367e04982e708723ba0e7b8d43aa643d3dcd404d74f1475"},
{file = "websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f29d80eb9a9263b8d109135351caf568cc3f80b9928bccde535c235de55c22d9"},
{file = "websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b359ed09954d7c18bbc1680f380c7301f92c60bf924171629c5db97febb12f04"},
{file = "websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122"},
{file = "websockets-15.0.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7f493881579c90fc262d9cdbaa05a6b54b3811c2f300766748db79f098db9940"},
{file = "websockets-15.0.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:47b099e1f4fbc95b701b6e85768e1fcdaf1630f3cbe4765fa216596f12310e2e"},
{file = "websockets-15.0.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:67f2b6de947f8c757db2db9c71527933ad0019737ec374a8a6be9a956786aaf9"},
{file = "websockets-15.0.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d08eb4c2b7d6c41da6ca0600c077e93f5adcfd979cd777d747e9ee624556da4b"},
{file = "websockets-15.0.1-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b826973a4a2ae47ba357e4e82fa44a463b8f168e1ca775ac64521442b19e87f"},
{file = "websockets-15.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:21c1fa28a6a7e3cbdc171c694398b6df4744613ce9b36b1a498e816787e28123"},
{file = "websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f"},
{file = "websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee"},
]

[[package]]
name = "zipp"
version = "3.23.0"
@@ -1929,4 +1679,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
content-hash = "0c40b63c3c921846cf05ccfb4e685d4959854b29c2c302245f9832e20aac6954"
content-hash = "de209c97aa0feb29d669a20e4422d51bdf3a0872ec37e85ce9b88ce726fcee7a"

@@ -18,7 +18,8 @@ pydantic = "^2.11.7"
pydantic-settings = "^2.10.1"
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
redis = "^6.2.0"
supabase = "^2.16.0"
bcrypt = "^4.1.0"
authlib = "^1.3.0"
uvicorn = "^0.35.0"

[tool.poetry.group.dev.dependencies]
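
The dependency change above swaps the Supabase client for native-auth libraries (pyjwt with the crypto extra, bcrypt, authlib). As a hedged illustration, this is roughly how bcrypt is typically used for password hashing in a native-auth setup; the function names are illustrative and not taken from this diff:

```python
import bcrypt


def hash_password(password: str) -> str:
    # bcrypt works on bytes and embeds the generated salt in the resulting hash
    return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")


def verify_password(password: str, hashed: str) -> bool:
    # Constant-time comparison of the candidate password against the stored hash
    return bcrypt.checkpw(password.encode("utf-8"), hashed.encode("utf-8"))
```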
@@ -27,10 +27,15 @@ REDIS_PORT=6379
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7

# Supabase Authentication
SUPABASE_URL=http://localhost:8000
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
# JWT Authentication
# Generate a secure random key: python -c "import secrets; print(secrets.token_urlsafe(32))"
JWT_SIGN_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
JWT_SIGN_ALGORITHM=HS256
ACCESS_TOKEN_EXPIRE_MINUTES=15
REFRESH_TOKEN_EXPIRE_DAYS=7
JWT_ISSUER=autogpt-platform
JWT_AUDIENCE=authenticated

## ===== REQUIRED SECURITY KEYS ===== ##
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()
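
The JWT_* settings above replace the Supabase URL and service-role key. Below is a minimal sketch of how such settings could be used with the pyjwt dependency added in pyproject.toml; how the variables are loaded and which claims are included are assumptions, not taken from this diff:

```python
import os
from datetime import datetime, timedelta, timezone

import jwt  # PyJWT

SIGN_KEY = os.environ["JWT_SIGN_KEY"]  # assumption: read straight from the environment
ALGORITHM = os.environ.get("JWT_SIGN_ALGORITHM", "HS256")
ISSUER = os.environ.get("JWT_ISSUER", "autogpt-platform")
AUDIENCE = os.environ.get("JWT_AUDIENCE", "authenticated")
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES", "15"))


def issue_access_token(user_id: str) -> str:
    # Build a short-lived access token with standard registered claims
    now = datetime.now(timezone.utc)
    payload = {
        "sub": user_id,
        "iss": ISSUER,
        "aud": AUDIENCE,
        "iat": now,
        "exp": now + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    }
    return jwt.encode(payload, SIGN_KEY, algorithm=ALGORITHM)


def verify_access_token(token: str) -> dict:
    # Raises jwt.InvalidTokenError (e.g. ExpiredSignatureError) on failure
    return jwt.decode(
        token, SIGN_KEY, algorithms=[ALGORITHM], issuer=ISSUER, audience=AUDIENCE
    )
```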
3
autogpt_platform/backend/.gitignore
vendored
@@ -18,3 +18,6 @@ load-tests/results/
load-tests/*.json
load-tests/*.log
load-tests/node_modules/*

# Migration backups (contain user data)
migration_backups/

@@ -48,8 +48,7 @@ RUN poetry install --no-ansi --no-root
# Generate Prisma client
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
RUN poetry run prisma generate && poetry run gen-prisma-stub
RUN poetry run prisma generate

FROM debian:13-slim AS server_dependencies

@@ -108,7 +108,7 @@ import fastapi.testclient
import pytest
from pytest_snapshot.plugin import Snapshot

from backend.api.features.myroute import router
from backend.server.v2.myroute import router

app = fastapi.FastAPI()
app.include_router(router)
@@ -149,7 +149,7 @@ These provide the easiest way to set up authentication mocking in test modules:
import fastapi
import fastapi.testclient
import pytest
from backend.api.features.myroute import router
from backend.server.v2.myroute import router

app = fastapi.FastAPI()
app.include_router(router)

@@ -1,25 +0,0 @@
from fastapi import FastAPI

from backend.api.middleware.security import SecurityHeadersMiddleware
from backend.monitoring.instrumentation import instrument_fastapi

from .v1.routes import v1_router

external_api = FastAPI(
    title="AutoGPT External API",
    description="External API for AutoGPT integrations",
    docs_url="/docs",
    version="1.0",
)

external_api.add_middleware(SecurityHeadersMiddleware)
external_api.include_router(v1_router, prefix="/v1")

# Add Prometheus instrumentation
instrument_fastapi(
    external_api,
    service_name="external-api",
    expose_endpoint=True,
    endpoint="/metrics",
    include_in_schema=True,
)
@@ -1,340 +0,0 @@
|
||||
"""Tests for analytics API endpoints."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from .analytics import router as analytics_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(analytics_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module."""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# /log_raw_metric endpoint tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_log_raw_metric_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw metric logging."""
|
||||
mock_result = Mock(id="metric-123-uuid")
|
||||
mock_log_metric = mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "page_load_time",
|
||||
"metric_value": 2.5,
|
||||
"data_string": "/dashboard",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||
assert response.json() == "metric-123-uuid"
|
||||
|
||||
mock_log_metric.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
metric_name="page_load_time",
|
||||
metric_value=2.5,
|
||||
data_string="/dashboard",
|
||||
)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"metric_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_metric_success",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"metric_value,metric_name,data_string,test_id",
|
||||
[
|
||||
(100, "api_calls_count", "external_api", "integer_value"),
|
||||
(0, "error_count", "no_errors", "zero_value"),
|
||||
(-5.2, "temperature_delta", "cooling", "negative_value"),
|
||||
(1.23456789, "precision_test", "float_precision", "float_precision"),
|
||||
(999999999, "large_number", "max_value", "large_number"),
|
||||
(0.0000001, "tiny_number", "min_value", "tiny_number"),
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_various_values(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
metric_value: float,
|
||||
metric_name: str,
|
||||
data_string: str,
|
||||
test_id: str,
|
||||
) -> None:
|
||||
"""Test raw metric logging with various metric values."""
|
||||
mock_result = Mock(id=f"metric-{test_id}-uuid")
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": metric_name,
|
||||
"metric_value": metric_value,
|
||||
"data_string": data_string,
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Failed for {test_id}: {response.text}"
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{"metric_id": response.json(), "test_case": test_id},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
f"analytics_metric_{test_id}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,expected_error",
|
||||
[
|
||||
({}, "Field required"),
|
||||
({"metric_name": "test"}, "Field required"),
|
||||
(
|
||||
{"metric_name": "test", "metric_value": "not_a_number", "data_string": "x"},
|
||||
"Input should be a valid number",
|
||||
),
|
||||
(
|
||||
{"metric_name": "", "metric_value": 1.0, "data_string": "test"},
|
||||
"String should have at least 1 character",
|
||||
),
|
||||
(
|
||||
{"metric_name": "test", "metric_value": 1.0, "data_string": ""},
|
||||
"String should have at least 1 character",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"empty_request",
|
||||
"missing_metric_value_and_data_string",
|
||||
"invalid_metric_value_type",
|
||||
"empty_metric_name",
|
||||
"empty_data_string",
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_validation_errors(
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test validation errors for invalid metric requests."""
|
||||
response = client.post("/log_raw_metric", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()
|
||||
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||
|
||||
error_text = json.dumps(error_detail)
|
||||
assert (
|
||||
expected_error in error_text
|
||||
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||
|
||||
|
||||
def test_log_raw_metric_service_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error handling when analytics service fails."""
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Database connection failed"),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "test_metric",
|
||||
"metric_value": 1.0,
|
||||
"data_string": "test",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
error_detail = response.json()["detail"]
|
||||
assert "Database connection failed" in error_detail["message"]
|
||||
assert "hint" in error_detail
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# /log_raw_analytics endpoint tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_log_raw_analytics_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw analytics logging."""
|
||||
mock_result = Mock(id="analytics-789-uuid")
|
||||
mock_log_analytics = mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "user_action",
|
||||
"data": {
|
||||
"action": "button_click",
|
||||
"button_id": "submit_form",
|
||||
"timestamp": "2023-01-01T00:00:00Z",
|
||||
"metadata": {"form_type": "registration", "fields_filled": 5},
|
||||
},
|
||||
"data_index": "button_click_submit_form",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||
assert response.json() == "analytics-789-uuid"
|
||||
|
||||
mock_log_analytics.assert_called_once_with(
|
||||
test_user_id,
|
||||
"user_action",
|
||||
request_data["data"],
|
||||
"button_click_submit_form",
|
||||
)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"analytics_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_analytics_success",
|
||||
)
|
||||
|
||||
|
||||
def test_log_raw_analytics_complex_data(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test raw analytics logging with complex nested data structures."""
|
||||
mock_result = Mock(id="analytics-complex-uuid")
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "agent_execution",
|
||||
"data": {
|
||||
"agent_id": "agent_123",
|
||||
"execution_id": "exec_456",
|
||||
"status": "completed",
|
||||
"duration_ms": 3500,
|
||||
"nodes_executed": 15,
|
||||
"blocks_used": [
|
||||
{"block_id": "llm_block", "count": 3},
|
||||
{"block_id": "http_block", "count": 5},
|
||||
{"block_id": "code_block", "count": 2},
|
||||
],
|
||||
"errors": [],
|
||||
"metadata": {
|
||||
"trigger": "manual",
|
||||
"user_tier": "premium",
|
||||
"environment": "production",
|
||||
},
|
||||
},
|
||||
"data_index": "agent_123_exec_456",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{"analytics_id": response.json(), "logged_data": request_data["data"]},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
"analytics_log_analytics_complex_data",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,expected_error",
|
||||
[
|
||||
({}, "Field required"),
|
||||
({"type": "test"}, "Field required"),
|
||||
(
|
||||
{"type": "test", "data": "not_a_dict", "data_index": "test"},
|
||||
"Input should be a valid dictionary",
|
||||
),
|
||||
({"type": "test", "data": {"key": "value"}}, "Field required"),
|
||||
],
|
||||
ids=[
|
||||
"empty_request",
|
||||
"missing_data_and_data_index",
|
||||
"invalid_data_type",
|
||||
"missing_data_index",
|
||||
],
|
||||
)
|
||||
def test_log_raw_analytics_validation_errors(
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test validation errors for invalid analytics requests."""
|
||||
response = client.post("/log_raw_analytics", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()
|
||||
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||
|
||||
error_text = json.dumps(error_detail)
|
||||
assert (
|
||||
expected_error in error_text
|
||||
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||
|
||||
|
||||
def test_log_raw_analytics_service_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error handling when analytics service fails."""
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Analytics DB unreachable"),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "test_event",
|
||||
"data": {"key": "value"},
|
||||
"data_index": "test_index",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
error_detail = response.json()["detail"]
|
||||
assert "Analytics DB unreachable" in error_detail["message"]
|
||||
assert "hint" in error_detail
|
||||
@@ -1,215 +0,0 @@
|
||||
"""Database operations for chat sessions."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import (
|
||||
ChatMessageCreateInput,
|
||||
ChatSessionCreateInput,
|
||||
ChatSessionUpdateInput,
|
||||
)
|
||||
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
|
||||
"""Get a chat session by ID from the database."""
|
||||
session = await PrismaChatSession.prisma().find_unique(
|
||||
where={"id": session_id},
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
# Sort messages by sequence in Python since Prisma doesn't support order_by in include
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> PrismaChatSession:
|
||||
"""Create a new chat session in the database."""
|
||||
data = ChatSessionCreateInput(
|
||||
id=session_id,
|
||||
userId=user_id,
|
||||
credentials=SafeJson({}),
|
||||
successfulAgentRuns=SafeJson({}),
|
||||
successfulAgentSchedules=SafeJson({}),
|
||||
)
|
||||
return await PrismaChatSession.prisma().create(
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
)
|
||||
|
||||
|
||||
async def update_chat_session(
|
||||
session_id: str,
|
||||
credentials: dict[str, Any] | None = None,
|
||||
successful_agent_runs: dict[str, Any] | None = None,
|
||||
successful_agent_schedules: dict[str, Any] | None = None,
|
||||
total_prompt_tokens: int | None = None,
|
||||
total_completion_tokens: int | None = None,
|
||||
title: str | None = None,
|
||||
) -> PrismaChatSession | None:
|
||||
"""Update a chat session's metadata."""
|
||||
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
|
||||
|
||||
if credentials is not None:
|
||||
data["credentials"] = SafeJson(credentials)
|
||||
if successful_agent_runs is not None:
|
||||
data["successfulAgentRuns"] = SafeJson(successful_agent_runs)
|
||||
if successful_agent_schedules is not None:
|
||||
data["successfulAgentSchedules"] = SafeJson(successful_agent_schedules)
|
||||
if total_prompt_tokens is not None:
|
||||
data["totalPromptTokens"] = total_prompt_tokens
|
||||
if total_completion_tokens is not None:
|
||||
data["totalCompletionTokens"] = total_completion_tokens
|
||||
if title is not None:
|
||||
data["title"] = title
|
||||
|
||||
session = await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
|
||||
async def add_chat_message(
|
||||
session_id: str,
|
||||
role: str,
|
||||
sequence: int,
|
||||
content: str | None = None,
|
||||
name: str | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
refusal: str | None = None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
function_call: dict[str, Any] | None = None,
|
||||
) -> PrismaChatMessage:
|
||||
"""Add a message to a chat session."""
|
||||
# Build the input dict dynamically - only include optional fields when they
|
||||
# have values, as Prisma TypedDict validation fails when optional fields
|
||||
# are explicitly set to None
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": role,
|
||||
"sequence": sequence,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if content is not None:
|
||||
data["content"] = content
|
||||
if name is not None:
|
||||
data["name"] = name
|
||||
if tool_call_id is not None:
|
||||
data["toolCallId"] = tool_call_id
|
||||
if refusal is not None:
|
||||
data["refusal"] = refusal
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if tool_calls is not None:
|
||||
data["toolCalls"] = SafeJson(tool_calls)
|
||||
if function_call is not None:
|
||||
data["functionCall"] = SafeJson(function_call)
|
||||
|
||||
# Update session's updatedAt timestamp
|
||||
await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
|
||||
return await PrismaChatMessage.prisma().create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
)
|
||||
|
||||
|
||||
async def add_chat_messages_batch(
|
||||
session_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
start_sequence: int,
|
||||
) -> list[PrismaChatMessage]:
|
||||
"""Add multiple messages to a chat session in a batch."""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
created_messages = []
|
||||
for i, msg in enumerate(messages):
|
||||
# Build the input dict dynamically - only include optional JSON fields
|
||||
# when they have values, as Prisma TypedDict validation fails when
|
||||
# optional fields are explicitly set to None
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
created = await PrismaChatMessage.prisma().create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
)
|
||||
created_messages.append(created)
|
||||
|
||||
# Update session's updatedAt timestamp
|
||||
await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
|
||||
return created_messages
|
||||
|
||||
|
||||
async def get_user_chat_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[PrismaChatSession]:
|
||||
"""Get chat sessions for a user, ordered by most recent."""
|
||||
return await PrismaChatSession.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
order={"updatedAt": "desc"},
|
||||
take=limit,
|
||||
skip=offset,
|
||||
)
|
||||
|
||||
|
||||
async def get_user_session_count(user_id: str) -> int:
|
||||
"""Get the total number of chat sessions for a user."""
|
||||
return await PrismaChatSession.prisma().count(where={"userId": user_id})
|
||||
|
||||
|
||||
async def delete_chat_session(session_id: str) -> bool:
|
||||
"""Delete a chat session and all its messages."""
|
||||
try:
|
||||
await PrismaChatSession.prisma().delete(where={"id": session_id})
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete chat session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_chat_session_message_count(session_id: str) -> int:
|
||||
"""Get the number of messages in a chat session."""
|
||||
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
||||
return count
|
||||
@@ -1,473 +0,0 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionDeveloperMessageParam,
|
||||
ChatCompletionFunctionMessageParam,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_assistant_message_param import FunctionCall
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
ChatCompletionMessageToolCallParam,
|
||||
Function,
|
||||
)
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util import json
|
||||
from backend.util.exceptions import RedisError
|
||||
|
||||
from . import db as chat_db
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str | None = None
|
||||
name: str | None = None
|
||||
tool_call_id: str | None = None
|
||||
refusal: str | None = None
|
||||
tool_calls: list[dict] | None = None
|
||||
function_call: dict | None = None
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class ChatSession(BaseModel):
|
||||
session_id: str
|
||||
user_id: str | None
|
||||
title: str | None = None
|
||||
messages: list[ChatMessage]
|
||||
usage: list[Usage]
|
||||
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
|
||||
started_at: datetime
|
||||
updated_at: datetime
|
||||
successful_agent_runs: dict[str, int] = {}
|
||||
successful_agent_schedules: dict[str, int] = {}
|
||||
|
||||
@staticmethod
|
||||
def new(user_id: str | None) -> "ChatSession":
|
||||
return ChatSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
title=None,
|
||||
messages=[],
|
||||
usage=[],
|
||||
credentials={},
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_prisma(
|
||||
prisma_session: PrismaChatSession,
|
||||
prisma_messages: list[PrismaChatMessage] | None = None,
|
||||
) -> "ChatSession":
|
||||
"""Convert Prisma models to Pydantic ChatSession."""
|
||||
messages = []
|
||||
if prisma_messages:
|
||||
for msg in prisma_messages:
|
||||
tool_calls = None
|
||||
if msg.toolCalls:
|
||||
tool_calls = (
|
||||
json.loads(msg.toolCalls)
|
||||
if isinstance(msg.toolCalls, str)
|
||||
else msg.toolCalls
|
||||
)
|
||||
|
||||
function_call = None
|
||||
if msg.functionCall:
|
||||
function_call = (
|
||||
json.loads(msg.functionCall)
|
||||
if isinstance(msg.functionCall, str)
|
||||
else msg.functionCall
|
||||
)
|
||||
|
||||
messages.append(
|
||||
ChatMessage(
|
||||
role=msg.role,
|
||||
content=msg.content,
|
||||
name=msg.name,
|
||||
tool_call_id=msg.toolCallId,
|
||||
refusal=msg.refusal,
|
||||
tool_calls=tool_calls,
|
||||
function_call=function_call,
|
||||
)
|
||||
)
|
||||
|
||||
# Parse JSON fields from Prisma
|
||||
credentials = (
|
||||
json.loads(prisma_session.credentials)
|
||||
if isinstance(prisma_session.credentials, str)
|
||||
else prisma_session.credentials or {}
|
||||
)
|
||||
successful_agent_runs = (
|
||||
json.loads(prisma_session.successfulAgentRuns)
|
||||
if isinstance(prisma_session.successfulAgentRuns, str)
|
||||
else prisma_session.successfulAgentRuns or {}
|
||||
)
|
||||
successful_agent_schedules = (
|
||||
json.loads(prisma_session.successfulAgentSchedules)
|
||||
if isinstance(prisma_session.successfulAgentSchedules, str)
|
||||
else prisma_session.successfulAgentSchedules or {}
|
||||
)
|
||||
|
||||
# Calculate usage from token counts
|
||||
usage = []
|
||||
if prisma_session.totalPromptTokens or prisma_session.totalCompletionTokens:
|
||||
usage.append(
|
||||
Usage(
|
||||
prompt_tokens=prisma_session.totalPromptTokens or 0,
|
||||
completion_tokens=prisma_session.totalCompletionTokens or 0,
|
||||
total_tokens=(prisma_session.totalPromptTokens or 0)
|
||||
+ (prisma_session.totalCompletionTokens or 0),
|
||||
)
|
||||
)
|
||||
|
||||
return ChatSession(
|
||||
session_id=prisma_session.id,
|
||||
user_id=prisma_session.userId,
|
||||
title=prisma_session.title,
|
||||
messages=messages,
|
||||
usage=usage,
|
||||
credentials=credentials,
|
||||
started_at=prisma_session.createdAt,
|
||||
updated_at=prisma_session.updatedAt,
|
||||
successful_agent_runs=successful_agent_runs,
|
||||
successful_agent_schedules=successful_agent_schedules,
|
||||
)
|
||||
|
||||
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
|
||||
messages = []
|
||||
for message in self.messages:
|
||||
if message.role == "developer":
|
||||
m = ChatCompletionDeveloperMessageParam(
|
||||
role="developer",
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
messages.append(m)
|
||||
elif message.role == "system":
|
||||
m = ChatCompletionSystemMessageParam(
|
||||
role="system",
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
messages.append(m)
|
||||
elif message.role == "user":
|
||||
m = ChatCompletionUserMessageParam(
|
||||
role="user",
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
messages.append(m)
|
||||
elif message.role == "assistant":
|
||||
m = ChatCompletionAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.function_call:
|
||||
m["function_call"] = FunctionCall(
|
||||
arguments=message.function_call["arguments"],
|
||||
name=message.function_call["name"],
|
||||
)
|
||||
if message.refusal:
|
||||
m["refusal"] = message.refusal
|
||||
if message.tool_calls:
|
||||
t: list[ChatCompletionMessageToolCallParam] = []
|
||||
for tool_call in message.tool_calls:
|
||||
# Tool calls are stored with nested structure: {id, type, function: {name, arguments}}
|
||||
function_data = tool_call.get("function", {})
|
||||
|
||||
# Skip tool calls that are missing required fields
|
||||
if "id" not in tool_call or "name" not in function_data:
|
||||
logger.warning(
|
||||
f"Skipping invalid tool call: missing required fields. "
|
||||
f"Got: {tool_call.keys()}, function keys: {function_data.keys()}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Arguments are stored as a JSON string
|
||||
arguments_str = function_data.get("arguments", "{}")
|
||||
|
||||
t.append(
|
||||
ChatCompletionMessageToolCallParam(
|
||||
id=tool_call["id"],
|
||||
type="function",
|
||||
function=Function(
|
||||
arguments=arguments_str,
|
||||
name=function_data["name"],
|
||||
),
|
||||
)
|
||||
)
|
||||
m["tool_calls"] = t
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
messages.append(m)
|
||||
elif message.role == "tool":
|
||||
messages.append(
|
||||
ChatCompletionToolMessageParam(
|
||||
role="tool",
|
||||
content=message.content or "",
|
||||
tool_call_id=message.tool_call_id or "",
|
||||
)
|
||||
)
|
||||
elif message.role == "function":
|
||||
messages.append(
|
||||
ChatCompletionFunctionMessageParam(
|
||||
role="function",
|
||||
content=message.content,
|
||||
name=message.name or "",
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from Redis cache."""
|
||||
redis_key = f"chat:session:{session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
raw_session: bytes | None = await async_redis.get(redis_key)
|
||||
|
||||
if raw_session is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
logger.info(
|
||||
f"Loading session {session_id} from cache: "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
)
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
|
||||
raise RedisError(f"Corrupted session data for {session_id}") from e
|
||||
|
||||
|
||||
async def _cache_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session in Redis."""
|
||||
redis_key = f"chat:session:{session.session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
prisma_session = await chat_db.get_chat_session(session_id)
|
||||
if not prisma_session:
|
||||
return None
|
||||
|
||||
messages = prisma_session.Messages
|
||||
logger.info(
|
||||
f"Loading session {session_id} from DB: "
|
||||
f"has_messages={messages is not None}, "
|
||||
f"message_count={len(messages) if messages else 0}, "
|
||||
f"roles={[m.role for m in messages] if messages else []}"
|
||||
)
|
||||
|
||||
return ChatSession.from_prisma(prisma_session, messages)
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
session: ChatSession, existing_message_count: int
|
||||
) -> None:
|
||||
"""Save or update a chat session in the database."""
|
||||
# Check if session exists in DB
|
||||
existing = await chat_db.get_chat_session(session.session_id)
|
||||
|
||||
if not existing:
|
||||
# Create new session
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
# Calculate total tokens from usage
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
# Update session metadata
|
||||
await chat_db.update_chat_session(
|
||||
session_id=session.session_id,
|
||||
credentials=session.credentials,
|
||||
successful_agent_runs=session.successful_agent_runs,
|
||||
successful_agent_schedules=session.successful_agent_schedules,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
# Add new messages (only those after existing count)
|
||||
new_messages = session.messages[existing_message_count:]
|
||||
if new_messages:
|
||||
messages_data = []
|
||||
for msg in new_messages:
|
||||
messages_data.append(
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"name": msg.name,
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
"refusal": msg.refusal,
|
||||
"tool_calls": msg.tool_calls,
|
||||
"function_call": msg.function_call,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||
f"roles={[m['role'] for m in messages_data]}, "
|
||||
f"start_sequence={existing_message_count}"
|
||||
)
|
||||
await chat_db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
messages=messages_data,
|
||||
start_sequence=existing_message_count,
|
||||
)
|
||||
|
||||
|
||||
async def get_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> ChatSession | None:
|
||||
"""Get a chat session by ID.
|
||||
|
||||
Checks Redis cache first, falls back to database if not found.
|
||||
Caches database results back to Redis.
|
||||
"""
|
||||
# Try cache first
|
||||
try:
|
||||
session = await _get_session_from_cache(session_id)
|
||||
if session:
|
||||
# Verify user ownership
|
||||
if session.user_id is not None and session.user_id != user_id:
|
||||
logger.warning(
|
||||
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
||||
)
|
||||
return None
|
||||
return session
|
||||
except RedisError:
|
||||
logger.warning(f"Cache error for session {session_id}, trying database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
||||
|
||||
# Fall back to database
|
||||
logger.info(f"Session {session_id} not in cache, checking database")
|
||||
session = await _get_session_from_db(session_id)
|
||||
|
||||
if session is None:
|
||||
logger.warning(f"Session {session_id} not found in cache or database")
|
||||
return None
|
||||
|
||||
# Verify user ownership
|
||||
if session.user_id is not None and session.user_id != user_id:
|
||||
logger.warning(
|
||||
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Cache the session from DB
|
||||
try:
|
||||
await _cache_session(session)
|
||||
logger.info(f"Cached session {session_id} from database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache session {session_id}: {e}")
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def upsert_chat_session(
|
||||
session: ChatSession,
|
||||
) -> ChatSession:
|
||||
"""Update a chat session in both cache and database."""
|
||||
# Get existing message count from DB for incremental saves
|
||||
existing_message_count = await chat_db.get_chat_session_message_count(
|
||||
session.session_id
|
||||
)
|
||||
|
||||
# Save to database
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save session {session.session_id} to database: {e}")
|
||||
# Continue to cache even if DB fails
|
||||
|
||||
# Save to cache
|
||||
try:
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
raise RedisError(
|
||||
f"Failed to persist chat session {session.session_id} to Redis: {e}"
|
||||
) from e
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(user_id: str | None) -> ChatSession:
|
||||
"""Create a new chat session and persist it."""
|
||||
session = ChatSession.new(user_id)
|
||||
|
||||
# Create in database first
|
||||
try:
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create session in database: {e}")
|
||||
# Continue even if DB fails - cache will still work
|
||||
|
||||
# Cache the session
|
||||
try:
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache new session: {e}")
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def get_user_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[ChatSession]:
|
||||
"""Get all chat sessions for a user from the database."""
|
||||
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
|
||||
|
||||
sessions = []
|
||||
for prisma_session in prisma_sessions:
|
||||
# Convert without messages for listing (lighter weight)
|
||||
sessions.append(ChatSession.from_prisma(prisma_session, None))
|
||||
|
||||
return sessions
|
||||
|
||||
|
||||
async def delete_chat_session(session_id: str) -> bool:
|
||||
"""Delete a chat session from both cache and database."""
|
||||
# Delete from cache
|
||||
try:
|
||||
redis_key = f"chat:session:{session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
|
||||
|
||||
# Delete from database
|
||||
return await chat_db.delete_chat_session(session_id)
|
||||
@@ -1,117 +0,0 @@
import pytest

from .model import (
    ChatMessage,
    ChatSession,
    Usage,
    get_chat_session,
    upsert_chat_session,
)

messages = [
    ChatMessage(content="Hello, how are you?", role="user"),
    ChatMessage(
        content="I'm fine, thank you!",
        role="assistant",
        tool_calls=[
            {
                "id": "t123",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": '{"city": "New York"}',
                },
            }
        ],
    ),
    ChatMessage(
        content="I'm using the tool to get the weather",
        role="tool",
        tool_call_id="t123",
    ),
]


@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_serialization_deserialization():
    s = ChatSession.new(user_id="abc123")
    s.messages = messages
    s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
    serialized = s.model_dump_json()
    s2 = ChatSession.model_validate_json(serialized)
    assert s2.model_dump() == s.model_dump()


@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage():

    s = ChatSession.new(user_id=None)
    s.messages = messages

    s = await upsert_chat_session(s)

    s2 = await get_chat_session(
        session_id=s.session_id,
        user_id=s.user_id,
    )

    assert s2 == s


@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage_user_id_mismatch():

    s = ChatSession.new(user_id="abc123")
    s.messages = messages
    s = await upsert_chat_session(s)

    s2 = await get_chat_session(s.session_id, None)

    assert s2 is None


@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_db_storage():
    """Test that messages are correctly saved to and loaded from DB (not cache)."""
    from backend.data.redis_client import get_redis_async

    # Create session with messages including assistant message
    s = ChatSession.new(user_id=None)
    s.messages = messages  # Contains user, assistant, and tool messages
    assert s.session_id is not None, "Session id is not set"
    # Upsert to save to both cache and DB
    s = await upsert_chat_session(s)

    # Clear the Redis cache to force DB load
    redis_key = f"chat:session:{s.session_id}"
    async_redis = await get_redis_async()
    await async_redis.delete(redis_key)

    # Load from DB (cache was cleared)
    s2 = await get_chat_session(
        session_id=s.session_id,
        user_id=s.user_id,
    )

    assert s2 is not None, "Session not found after loading from DB"
    assert len(s2.messages) == len(
        s.messages
    ), f"Message count mismatch: expected {len(s.messages)}, got {len(s2.messages)}"

    # Verify all roles are present
    roles = [m.role for m in s2.messages]
    assert "user" in roles, f"User message missing. Roles found: {roles}"
    assert "assistant" in roles, f"Assistant message missing. Roles found: {roles}"
    assert "tool" in roles, f"Tool message missing. Roles found: {roles}"

    # Verify message content
    for orig, loaded in zip(s.messages, s2.messages):
        assert orig.role == loaded.role, f"Role mismatch: {orig.role} != {loaded.role}"
        assert (
            orig.content == loaded.content
        ), f"Content mismatch for {orig.role}: {orig.content} != {loaded.content}"
        if orig.tool_calls:
            assert (
                loaded.tool_calls is not None
            ), f"Tool calls missing for {orig.role} message"
            assert len(orig.tool_calls) == len(loaded.tool_calls)
@@ -1,192 +0,0 @@
|
||||
You are Otto, an AI Co-Pilot and Forward Deployed Engineer for AutoGPT, an AI Business Automation tool. Your mission is to help users quickly find, create, and set up AutoGPT agents to solve their business problems.
|
||||
|
||||
Here are the functions available to you:
|
||||
|
||||
<functions>
|
||||
**Understanding & Discovery:**
|
||||
1. **add_understanding** - Save information about the user's business context (use this as you learn about them)
|
||||
2. **find_agent** - Search the marketplace for pre-built agents that solve the user's problem
|
||||
3. **find_library_agent** - Search the user's personal library of saved agents
|
||||
4. **find_block** - Search for individual blocks (building components for agents)
|
||||
5. **search_platform_docs** - Search AutoGPT documentation for help
|
||||
|
||||
**Agent Creation & Editing:**
|
||||
6. **create_agent** - Create a new custom agent from scratch based on user requirements
|
||||
7. **edit_agent** - Modify an existing agent (add/remove blocks, change configuration)
|
||||
|
||||
**Execution & Output:**
|
||||
8. **run_agent** - Run or schedule an agent (automatically handles setup)
|
||||
9. **run_block** - Run a single block directly without creating an agent
|
||||
10. **agent_output** - Get the output/results from a running or completed agent execution
|
||||
</functions>
|
||||
|
||||
## ALWAYS GET THE USER'S NAME
|
||||
|
||||
**This is critical:** If you don't know the user's name, ask for it in your first response. Use a friendly, natural approach:
|
||||
- "Hi! I'm Otto. What's your name?"
|
||||
- "Hey there! Before we dive in, what should I call you?"
|
||||
|
||||
Once you have their name, immediately save it with `add_understanding(user_name="...")` and use it throughout the conversation.
|
||||
|
||||
## BUILDING USER UNDERSTANDING
|
||||
|
||||
**If no User Business Context is provided below**, gather information naturally during conversation - don't interrogate them.
|
||||
|
||||
**Key information to gather (in priority order):**
|
||||
1. Their name (ALWAYS first if unknown)
|
||||
2. Their job title and role
|
||||
3. Their business/company and industry
|
||||
4. Pain points and what they want to automate
|
||||
5. Tools they currently use
|
||||
|
||||
**How to gather this information:**
|
||||
- Ask naturally as part of helping them (e.g., "What's your role?" or "What industry are you in?")
|
||||
- When they share information, immediately save it using `add_understanding`
|
||||
- Don't ask all questions at once - spread them across the conversation
|
||||
- Prioritize understanding their immediate problem first
|
||||
|
||||
**Example:**
|
||||
```
|
||||
User: "I need help automating my social media"
|
||||
Otto: I can help with that! I'm Otto - what's your name?
|
||||
User: "I'm Sarah"
|
||||
Otto: [calls add_understanding with user_name="Sarah"]
|
||||
Nice to meet you, Sarah! What's your role - are you a social media manager or business owner?
|
||||
User: "I'm the marketing director at a fintech startup"
|
||||
Otto: [calls add_understanding with job_title="Marketing Director", industry="fintech", business_size="startup"]
|
||||
Great! Let me find social media automation agents for you.
|
||||
[calls find_agent with query="social media automation marketing"]
|
||||
```
|
||||
|
||||
## WHEN TO USE WHICH TOOL
|
||||
|
||||
**Finding existing agents:**
|
||||
- `find_agent` - Search the marketplace for pre-built agents others have created
|
||||
- `find_library_agent` - Search agents the user has already saved to their library
|
||||
|
||||
**Creating/editing agents:**
|
||||
- `create_agent` - When user wants a custom agent that doesn't exist, or has specific requirements
|
||||
- `edit_agent` - When user wants to modify an existing agent (change inputs, add blocks, etc.)
|
||||
|
||||
**Running agents:**
|
||||
- `run_agent` - To execute an agent (handles credentials and inputs automatically)
|
||||
- `agent_output` - To check the results of a running or completed agent execution
|
||||
|
||||
**Direct execution:**
|
||||
- `run_block` - Run a single block directly without needing a full agent
|
||||
|
||||
## HOW run_agent WORKS
|
||||
|
||||
The `run_agent` tool automatically handles the entire setup flow:
|
||||
|
||||
1. **First call** (no inputs) → Returns available inputs so user can decide what values to use
|
||||
2. **Credentials check** → If missing, UI automatically prompts user to add them (you don't need to mention this)
|
||||
3. **Execution** → Runs when you provide `inputs` OR set `use_defaults=true`
|
||||
|
||||
Parameters:
|
||||
- `username_agent_slug` (required): Agent identifier like "creator/agent-name"
|
||||
- `inputs`: Object with input values for the agent
|
||||
- `use_defaults`: Set to `true` to run with default values (only after user confirms)
|
||||
- `schedule_name` + `cron`: For scheduled execution
|
||||
|
||||
## HOW create_agent WORKS
|
||||
|
||||
Use `create_agent` when the user wants to build a custom automation:
|
||||
- Describe what the agent should do
|
||||
- The tool will create the agent structure with appropriate blocks
|
||||
- Returns the agent ID for further editing or running
|
||||
|
||||
## HOW agent_output WORKS
|
||||
|
||||
Use `agent_output` to get results from agent executions:
|
||||
- Pass the execution_id from a run_agent response
|
||||
- Returns the current status and any outputs produced
|
||||
- Useful for checking if an agent has completed and what it produced
|
||||
|
||||
## WORKFLOW
|
||||
|
||||
1. **Get their name** - If unknown, ask for it first
|
||||
2. **Understand context** - Ask 1-2 questions about their problem while helping
|
||||
3. **Find or create** - Use find_agent for existing solutions, create_agent for custom needs
|
||||
4. **Set up and run** - Use run_agent to execute, agent_output to get results
|
||||
|
||||
## YOUR APPROACH
|
||||
|
||||
**Step 1: Greet and Identify**
|
||||
- If you don't know their name, ask for it
|
||||
- Be friendly and conversational
|
||||
|
||||
**Step 2: Understand the Problem**
|
||||
- Ask maximum 1-2 targeted questions
|
||||
- Focus on: What business problem are they solving?
|
||||
- If they want to create/edit an agent, understand what it should do
|
||||
|
||||
**Step 3: Find or Create**
|
||||
- For existing solutions: Use `find_agent` with relevant keywords
|
||||
- For custom needs: Use `create_agent` with their requirements
|
||||
- For modifications: Use `edit_agent` on an existing agent
|
||||
|
||||
**Step 4: Execute**
|
||||
- Call `run_agent` without inputs first to see what's available
|
||||
- Ask user what values they want or if defaults are okay
|
||||
- Call `run_agent` again with inputs or `use_defaults=true`
|
||||
- Use `agent_output` to check results when needed
|
||||
|
||||
## USING add_understanding
|
||||
|
||||
Call `add_understanding` whenever you learn something about the user:
|
||||
|
||||
**User info:** `user_name`, `job_title`
|
||||
**Business:** `business_name`, `industry`, `business_size` (1-10, 11-50, 51-200, 201-1000, 1000+), `user_role` (decision maker, implementer, end user)
|
||||
**Processes:** `key_workflows` (array), `daily_activities` (array)
|
||||
**Pain points:** `pain_points` (array), `bottlenecks` (array), `manual_tasks` (array), `automation_goals` (array)
|
||||
**Tools:** `current_software` (array), `existing_automation` (array)
|
||||
**Other:** `additional_notes`
|
||||
|
||||
Example: `add_understanding(user_name="Sarah", job_title="Marketing Director", industry="fintech")`
|
||||
|
||||
## KEY RULES
|
||||
|
||||
**What You DON'T Do:**
|
||||
- Don't help with login (frontend handles this)
|
||||
- Don't mention or explain credentials to the user (frontend handles this automatically)
|
||||
- Don't run agents without first showing available inputs to the user
|
||||
- Don't use `use_defaults=true` without user explicitly confirming
|
||||
- Don't write responses longer than 3 sentences
|
||||
- Don't interrogate users with many questions - gather info naturally
|
||||
|
||||
**What You DO:**
|
||||
- ALWAYS ask for user's name if you don't have it
|
||||
- Save user information with `add_understanding` as you learn it
|
||||
- Use their name when addressing them
|
||||
- Always call run_agent first without inputs to see what's available
|
||||
- Ask user what values they want OR if they want to use defaults
|
||||
- Keep all responses to maximum 3 sentences
|
||||
- Include the agent link in your response after successful execution
|
||||
|
||||
**Error Handling:**
|
||||
- Authentication needed → "Please sign in via the interface"
|
||||
- Credentials missing → The UI handles this automatically. Focus on asking the user about input values instead.
|
||||
|
||||
## RESPONSE STRUCTURE
|
||||
|
||||
Before responding, wrap your analysis in <thinking> tags to systematically plan your approach:
|
||||
- Check if you know the user's name - if not, ask for it
|
||||
- Check if you have user context - if not, plan to gather some naturally
|
||||
- Extract the key business problem or request from the user's message
|
||||
- Determine what function call (if any) you need to make next
|
||||
- Plan your response to stay under the 3-sentence maximum
|
||||
|
||||
Example interaction:
|
||||
```
|
||||
User: "Hi, I want to build an agent that monitors my competitors"
|
||||
Otto: <thinking>I don't know this user's name. I should ask for it while acknowledging their request.</thinking>
|
||||
Hi! I'm Otto and I'd love to help you build a competitor monitoring agent. What's your name?
|
||||
User: "I'm Mike"
|
||||
Otto: [calls add_understanding with user_name="Mike"]
|
||||
<thinking>Now I know Mike wants competitor monitoring. I should search for existing agents first.</thinking>
|
||||
Great to meet you, Mike! Let me search for competitor monitoring agents.
|
||||
[calls find_agent with query="competitor monitoring analysis"]
|
||||
```
|
||||
|
||||
KEEP ANSWERS TO 3 SENTENCES
|
||||
@@ -1,155 +0,0 @@
|
||||
You are Otto, an AI Co-Pilot helping new users get started with AutoGPT, an AI Business Automation platform. Your mission is to welcome them, learn about their needs, and help them run their first successful agent.
|
||||
|
||||
Here are the functions available to you:
|
||||
|
||||
<functions>
|
||||
**Understanding & Discovery:**
|
||||
1. **add_understanding** - Save information about the user's business context (use this as you learn about them)
|
||||
2. **find_agent** - Search the marketplace for pre-built agents that solve the user's problem
|
||||
3. **find_library_agent** - Search the user's personal library of saved agents
|
||||
4. **find_block** - Search for individual blocks (building components for agents)
|
||||
5. **search_platform_docs** - Search AutoGPT documentation for help
|
||||
|
||||
**Agent Creation & Editing:**
|
||||
6. **create_agent** - Create a new custom agent from scratch based on user requirements
|
||||
7. **edit_agent** - Modify an existing agent (add/remove blocks, change configuration)
|
||||
|
||||
**Execution & Output:**
|
||||
8. **run_agent** - Run or schedule an agent (automatically handles setup)
|
||||
9. **run_block** - Run a single block directly without creating an agent
|
||||
10. **agent_output** - Get the output/results from a running or completed agent execution
|
||||
</functions>
|
||||
|
||||
## YOUR ONBOARDING MISSION
|
||||
|
||||
You are guiding a new user through their first experience with AutoGPT. Your goal is to:
|
||||
1. Welcome them warmly and get their name
|
||||
2. Learn about them and their business
|
||||
3. Find or create an agent that solves a real problem for them
|
||||
4. Get that agent running successfully
|
||||
5. Celebrate their success and point them to next steps
|
||||
|
||||
## PHASE 1: WELCOME & INTRODUCTION
|
||||
|
||||
**Start every conversation by:**
|
||||
- Giving a warm, friendly greeting
|
||||
- Introducing yourself as Otto, their AI assistant
|
||||
- Asking for their name immediately
|
||||
|
||||
**Example opening:**
|
||||
```
|
||||
Hi! I'm Otto, your AI assistant. Welcome to AutoGPT! I'm here to help you set up your first automation. What's your name?
|
||||
```
|
||||
|
||||
Once you have their name, save it immediately with `add_understanding(user_name="...")` and use it throughout.
|
||||
|
||||
## PHASE 2: DISCOVERY
|
||||
|
||||
**After getting their name, learn about them:**
|
||||
- What's their role/job title?
|
||||
- What industry/business are they in?
|
||||
- What's one thing they'd love to automate?
|
||||
|
||||
**Keep it conversational - don't interrogate. Example:**
|
||||
```
|
||||
Nice to meet you, Sarah! What do you do for work, and what's one task you wish you could automate?
|
||||
```
|
||||
|
||||
Save everything you learn with `add_understanding`.
|
||||
|
||||
## PHASE 3: FIND OR CREATE AN AGENT
|
||||
|
||||
**Once you understand their need:**
|
||||
- Search for existing agents with `find_agent`
|
||||
- Present the best match and explain how it helps them
|
||||
- If nothing fits, offer to create a custom agent with `create_agent`
|
||||
|
||||
**Be enthusiastic about the solution:**
|
||||
```
|
||||
I found a great agent for you! The "Social Media Scheduler" can automatically post to your accounts on a schedule. Want to try it?
|
||||
```
|
||||
|
||||
## PHASE 4: SETUP & RUN
|
||||
|
||||
**Guide them through running the agent:**
|
||||
1. Call `run_agent` without inputs first to see what's needed
|
||||
2. Explain each input in simple terms
|
||||
3. Ask what values they want to use
|
||||
4. Run the agent with their inputs or defaults
|
||||
|
||||
**Don't mention credentials** - the UI handles that automatically.
|
||||
|
||||
## PHASE 5: CELEBRATE & HANDOFF
|
||||
|
||||
**After successful execution:**
|
||||
- Congratulate them on their first automation!
|
||||
- Tell them where to find this agent (their Library)
|
||||
- Mention they can explore more agents in the Marketplace
|
||||
- Offer to help with anything else
|
||||
|
||||
**Example:**
|
||||
```
|
||||
You did it! Your first agent is running. You can find it anytime in your Library. Ready to explore more automations?
|
||||
```
|
||||
|
||||
## KEY RULES
|
||||
|
||||
**What You DON'T Do:**
|
||||
- Don't help with login (frontend handles this)
|
||||
- Don't mention credentials (UI handles automatically)
|
||||
- Don't run agents without showing inputs first
|
||||
- Don't use `use_defaults=true` without explicit confirmation
|
||||
- Don't write responses longer than 3 sentences
|
||||
- Don't overwhelm with too many questions at once
|
||||
|
||||
**What You DO:**
|
||||
- ALWAYS get the user's name first
|
||||
- Be warm, encouraging, and celebratory
|
||||
- Save info with `add_understanding` as you learn it
|
||||
- Use their name when addressing them
|
||||
- Keep responses to maximum 3 sentences
|
||||
- Make them feel successful at each step
|
||||
|
||||
## USING add_understanding
|
||||
|
||||
Save information as you learn it:
|
||||
|
||||
**User info:** `user_name`, `job_title`
|
||||
**Business:** `business_name`, `industry`, `business_size`, `user_role`
|
||||
**Pain points:** `pain_points`, `manual_tasks`, `automation_goals`
|
||||
**Tools:** `current_software`
|
||||
|
||||
Example: `add_understanding(user_name="Sarah", job_title="Marketing Manager", automation_goals=["social media scheduling"])`
|
||||
|
||||
## HOW run_agent WORKS
|
||||
|
||||
1. **First call** (no inputs) → Shows available inputs
|
||||
2. **Credentials** → UI handles automatically (don't mention)
|
||||
3. **Execution** → Run with `inputs={...}` or `use_defaults=true`
|
||||
|
||||
## RESPONSE STRUCTURE
|
||||
|
||||
Before responding, plan your approach in <thinking> tags:
|
||||
- What phase am I in? (Welcome/Discovery/Find/Setup/Celebrate)
|
||||
- Do I know their name? If not, ask for it
|
||||
- What's the next step to move them forward?
|
||||
- Keep response under 3 sentences
|
||||
|
||||
**Example flow:**
|
||||
```
|
||||
User: "Hi"
|
||||
Otto: <thinking>Phase 1 - I need to welcome them and get their name.</thinking>
|
||||
Hi! I'm Otto, welcome to AutoGPT! I'm here to help you set up your first automation - what's your name?
|
||||
|
||||
User: "I'm Alex"
|
||||
Otto: [calls add_understanding with user_name="Alex"]
|
||||
<thinking>Got their name. Phase 2 - learn about them.</thinking>
|
||||
Great to meet you, Alex! What do you do for work, and what's one task you'd love to automate?
|
||||
|
||||
User: "I run an e-commerce store and spend hours on customer support emails"
|
||||
Otto: [calls add_understanding with industry="e-commerce", pain_points=["customer support emails"]]
|
||||
<thinking>Phase 3 - search for agents.</thinking>
|
||||
[calls find_agent with query="customer support email automation"]
|
||||
```
|
||||
|
||||
KEEP ANSWERS TO 3 SENTENCES - Be warm, helpful, and focused on their success!
|
||||
@@ -1,472 +0,0 @@
|
||||
"""Chat API routes for chat session management and streaming via SSE."""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, Query, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from .config import ChatConfig
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
tags=["chat"],
|
||||
)
|
||||
|
||||
# ========== Request/Response Models ==========
|
||||
|
||||
|
||||
class StreamChatRequest(BaseModel):
|
||||
"""Request model for streaming chat with optional context."""
|
||||
|
||||
message: str
|
||||
is_user_message: bool = True
|
||||
context: dict[str, str] | None = None # {url: str, content: str}
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
"""Response model containing information on a newly created chat session."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
user_id: str | None
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
|
||||
"""Response model providing complete details for a chat session, including messages."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
"""Response model for a session summary (without messages)."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
title: str | None = None
|
||||
|
||||
|
||||
class ListSessionsResponse(BaseModel):
|
||||
"""Response model for listing chat sessions."""
|
||||
|
||||
sessions: list[SessionSummaryResponse]
|
||||
total: int
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def list_sessions(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
limit: int = Query(default=50, ge=1, le=100),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
) -> ListSessionsResponse:
|
||||
"""
|
||||
List chat sessions for the authenticated user.
|
||||
|
||||
Returns a paginated list of chat sessions belonging to the current user,
|
||||
ordered by most recently updated.
|
||||
|
||||
Args:
|
||||
user_id: The authenticated user's ID.
|
||||
limit: Maximum number of sessions to return (1-100).
|
||||
offset: Number of sessions to skip for pagination.
|
||||
|
||||
Returns:
|
||||
ListSessionsResponse: List of session summaries and total count.
|
||||
"""
|
||||
sessions = await chat_service.get_user_sessions(user_id, limit, offset)
|
||||
|
||||
return ListSessionsResponse(
|
||||
sessions=[
|
||||
SessionSummaryResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
title=None, # TODO: Add title support
|
||||
)
|
||||
for session in sessions
|
||||
],
|
||||
total=len(sessions),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions",
|
||||
)
|
||||
async def create_session(
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new chat session.
|
||||
|
||||
Initiates a new chat session for either an authenticated or anonymous user.
|
||||
|
||||
Args:
|
||||
user_id: The optional authenticated user ID parsed from the JWT. If missing, creates an anonymous session.
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: Details of the created session.
|
||||
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating session with user_id: "
|
||||
f"...{user_id[-8:] if user_id and len(user_id) > 8 else '<redacted>'}"
|
||||
)
|
||||
|
||||
session = await chat_service.create_chat_session(user_id)
|
||||
|
||||
return CreateSessionResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}",
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.
|
||||
|
||||
"""
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
logger.info(
|
||||
f"Returning session {session_id}: "
|
||||
f"message_count={len(messages)}, "
|
||||
f"roles={[m.get('role') for m in messages]}"
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_chat_post(
|
||||
session_id: str,
|
||||
request: StreamChatRequest,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session (POST with context support).
|
||||
|
||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||
- Text fragments as they are generated
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
# Validate session exists before starting the stream
|
||||
# This prevents errors after the response has already started
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found. ")
|
||||
if session.user_id is None and user_id is not None:
|
||||
session = await chat_service.assign_user_to_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_chat_get(
|
||||
session_id: str,
|
||||
message: Annotated[str, Query(min_length=1, max_length=10000)],
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
is_user_message: bool = Query(default=True),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session (GET - legacy endpoint).
|
||||
|
||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||
- Text fragments as they are generated
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
message: The user's new message to process.
|
||||
user_id: Optional authenticated user ID.
|
||||
is_user_message: Whether the message is a user message.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
# Validate session exists before starting the stream
|
||||
# This prevents errors after the response has already started
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found. ")
|
||||
if session.user_id is None and user_id is not None:
|
||||
session = await chat_service.assign_user_to_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
message,
|
||||
is_user_message=is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/assign-user",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=200,
|
||||
)
|
||||
async def session_assign_user(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> dict:
|
||||
"""
|
||||
Assign an authenticated user to a chat session.
|
||||
|
||||
Used (typically post-login) to claim an existing anonymous session as the current authenticated user.
|
||||
|
||||
Args:
|
||||
session_id: The identifier for the (previously anonymous) session.
|
||||
user_id: The authenticated user's ID to associate with the session.
|
||||
|
||||
Returns:
|
||||
dict: Status of the assignment.
|
||||
|
||||
"""
|
||||
await chat_service.assign_user_to_session(session_id, user_id)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Onboarding Routes ==========
|
||||
# These routes use a specialized onboarding system prompt
|
||||
|
||||
|
||||
@router.post(
|
||||
"/onboarding/sessions",
|
||||
)
|
||||
async def create_onboarding_session(
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new onboarding chat session.
|
||||
|
||||
Initiates a new chat session specifically for user onboarding,
|
||||
using a specialized prompt that guides users through their first
|
||||
experience with AutoGPT.
|
||||
|
||||
Args:
|
||||
user_id: The optional authenticated user ID parsed from the JWT.
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: Details of the created onboarding session.
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating onboarding session with user_id: "
|
||||
f"...{user_id[-8:] if user_id and len(user_id) > 8 else '<redacted>'}"
|
||||
)
|
||||
|
||||
session = await chat_service.create_chat_session(user_id)
|
||||
|
||||
return CreateSessionResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/onboarding/sessions/{session_id}",
|
||||
)
|
||||
async def get_onboarding_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of an onboarding chat session.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the onboarding session.
|
||||
user_id: The optional authenticated user ID.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session.
|
||||
"""
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
logger.info(
|
||||
f"Returning onboarding session {session_id}: "
|
||||
f"message_count={len(messages)}, "
|
||||
f"roles={[m.get('role') for m in messages]}"
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/onboarding/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_onboarding_chat(
|
||||
session_id: str,
|
||||
request: StreamChatRequest,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Stream onboarding chat responses for a session.
|
||||
|
||||
Uses the specialized onboarding system prompt to guide new users
|
||||
through their first experience with AutoGPT. Streams AI responses
|
||||
in real time over Server-Sent Events (SSE).
|
||||
|
||||
Args:
|
||||
session_id: The onboarding session identifier.
|
||||
request: Request body containing message and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
"""
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
if session.user_id is None and user_id is not None:
|
||||
session = await chat_service.assign_user_to_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
context=request.context,
|
||||
prompt_type="onboarding", # Use onboarding system prompt
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ========== Health Check ==========
|
||||
|
||||
|
||||
@router.get("/health", status_code=200)
|
||||
async def health_check() -> dict:
|
||||
"""
|
||||
Health check endpoint for the chat service.
|
||||
|
||||
Performs a full cycle test of session creation, assignment, and retrieval. Should always return healthy
|
||||
if the service and data layer are operational.
|
||||
|
||||
Returns:
|
||||
dict: A status dictionary indicating health, service name, and API version.
|
||||
|
||||
"""
|
||||
session = await chat_service.create_chat_session(None)
|
||||
await chat_service.assign_user_to_session(session.session_id, "test_user")
|
||||
await chat_service.get_session(session.session_id, "test_user")
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "chat",
|
||||
"version": "0.1.0",
|
||||
}
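# --- Usage sketch (not part of the routes above) -----------------------------
# A minimal client-side example of consuming the SSE stream endpoint. The base
# URL, port, and mount path are assumptions for illustration only; the request
# body mirrors StreamChatRequest as defined above.
import httpx


async def stream_chat_example(session_id: str, message: str) -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000/api/chat") as client:
        async with client.stream(
            "POST",
            f"/sessions/{session_id}/stream",
            json={"message": message, "is_user_message": True},
            timeout=None,  # the SSE stream stays open until the completion finishes
        ) as response:
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    print(line.removeprefix("data: "))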
|
||||
@@ -1,202 +0,0 @@
|
||||
"""Tool for capturing user business understanding incrementally."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AddUnderstandingTool(BaseTool):
|
||||
"""Tool for capturing user's business understanding incrementally."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "add_understanding"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Capture and store information about the user's business context,
|
||||
workflows, pain points, and automation goals. Call this tool whenever the user
|
||||
shares information about their business. Each call incrementally adds to the
|
||||
existing understanding - you don't need to provide all fields at once.
|
||||
|
||||
Use this to build a comprehensive profile that helps recommend better agents
|
||||
and automations for the user's specific needs."""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user_name": {
|
||||
"type": "string",
|
||||
"description": "The user's name",
|
||||
},
|
||||
"job_title": {
|
||||
"type": "string",
|
||||
"description": "The user's job title (e.g., 'Marketing Manager', 'CEO', 'Software Engineer')",
|
||||
},
|
||||
"business_name": {
|
||||
"type": "string",
|
||||
"description": "Name of the user's business or organization",
|
||||
},
|
||||
"industry": {
|
||||
"type": "string",
|
||||
"description": "Industry or sector (e.g., 'e-commerce', 'healthcare', 'finance')",
|
||||
},
|
||||
"business_size": {
|
||||
"type": "string",
|
||||
"description": "Company size: '1-10', '11-50', '51-200', '201-1000', or '1000+'",
|
||||
},
|
||||
"user_role": {
|
||||
"type": "string",
|
||||
"description": "User's role in organization context (e.g., 'decision maker', 'implementer', 'end user')",
|
||||
},
|
||||
"key_workflows": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Key business workflows (e.g., 'lead qualification', 'content publishing')",
|
||||
},
|
||||
"daily_activities": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Regular daily activities the user performs",
|
||||
},
|
||||
"pain_points": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Current pain points or challenges",
|
||||
},
|
||||
"bottlenecks": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Process bottlenecks slowing things down",
|
||||
},
|
||||
"manual_tasks": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Manual or repetitive tasks that could be automated",
|
||||
},
|
||||
"automation_goals": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Desired automation outcomes or goals",
|
||||
},
|
||||
"current_software": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Software and tools currently in use",
|
||||
},
|
||||
"existing_automation": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Any existing automations or integrations",
|
||||
},
|
||||
"additional_notes": {
|
||||
"type": "string",
|
||||
"description": "Any other relevant context or notes",
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
"""Requires authentication to store user-specific data."""
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""
|
||||
Capture and store business understanding incrementally.
|
||||
|
||||
Each call merges new data with existing understanding:
|
||||
- String fields are overwritten if provided
|
||||
- List fields are appended (with deduplication)
|
||||
"""
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required to save business understanding.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if any data was provided
|
||||
if not any(v is not None for v in kwargs.values()):
|
||||
return ErrorResponse(
|
||||
message="Please provide at least one field to update.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Build input model
|
||||
input_data = BusinessUnderstandingInput(
|
||||
user_name=kwargs.get("user_name"),
|
||||
job_title=kwargs.get("job_title"),
|
||||
business_name=kwargs.get("business_name"),
|
||||
industry=kwargs.get("industry"),
|
||||
business_size=kwargs.get("business_size"),
|
||||
user_role=kwargs.get("user_role"),
|
||||
key_workflows=kwargs.get("key_workflows"),
|
||||
daily_activities=kwargs.get("daily_activities"),
|
||||
pain_points=kwargs.get("pain_points"),
|
||||
bottlenecks=kwargs.get("bottlenecks"),
|
||||
manual_tasks=kwargs.get("manual_tasks"),
|
||||
automation_goals=kwargs.get("automation_goals"),
|
||||
current_software=kwargs.get("current_software"),
|
||||
existing_automation=kwargs.get("existing_automation"),
|
||||
additional_notes=kwargs.get("additional_notes"),
|
||||
)
|
||||
|
||||
# Track which fields were updated
|
||||
updated_fields = [k for k, v in kwargs.items() if v is not None]
|
||||
|
||||
# Upsert with merge
|
||||
understanding = await upsert_business_understanding(user_id, input_data)
|
||||
|
||||
# Build current understanding summary for the response
|
||||
current_understanding = {
|
||||
"user_name": understanding.user_name,
|
||||
"job_title": understanding.job_title,
|
||||
"business_name": understanding.business_name,
|
||||
"industry": understanding.industry,
|
||||
"business_size": understanding.business_size,
|
||||
"user_role": understanding.user_role,
|
||||
"key_workflows": understanding.key_workflows,
|
||||
"daily_activities": understanding.daily_activities,
|
||||
"pain_points": understanding.pain_points,
|
||||
"bottlenecks": understanding.bottlenecks,
|
||||
"manual_tasks": understanding.manual_tasks,
|
||||
"automation_goals": understanding.automation_goals,
|
||||
"current_software": understanding.current_software,
|
||||
"existing_automation": understanding.existing_automation,
|
||||
"additional_notes": understanding.additional_notes,
|
||||
}
|
||||
|
||||
# Filter out empty values for cleaner response
|
||||
current_understanding = {
|
||||
k: v
|
||||
for k, v in current_understanding.items()
|
||||
if v is not None and v != [] and v != ""
|
||||
}
|
||||
|
||||
return UnderstandingUpdatedResponse(
|
||||
message=f"Updated understanding with: {', '.join(updated_fields)}. "
|
||||
"I now have a better picture of your business context.",
|
||||
session_id=session_id,
|
||||
updated_fields=updated_fields,
|
||||
current_understanding=current_understanding,
|
||||
)
|
||||
@@ -1,455 +0,0 @@
|
||||
"""Tool for retrieving agent execution outputs from user's library."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentOutputResponse,
|
||||
ErrorResponse,
|
||||
ExecutionOutputInfo,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .utils import fetch_graph_from_store_slug
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentOutputInput(BaseModel):
|
||||
"""Input parameters for the agent_output tool."""
|
||||
|
||||
agent_name: str = ""
|
||||
library_agent_id: str = ""
|
||||
store_slug: str = ""
|
||||
execution_id: str = ""
|
||||
run_time: str = "latest"
|
||||
|
||||
@field_validator(
|
||||
"agent_name",
|
||||
"library_agent_id",
|
||||
"store_slug",
|
||||
"execution_id",
|
||||
"run_time",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def strip_strings(cls, v: Any) -> Any:
|
||||
"""Strip whitespace from string fields."""
|
||||
return v.strip() if isinstance(v, str) else v
|
||||
|
||||
|
||||
def parse_time_expression(
|
||||
time_expr: str | None,
|
||||
) -> tuple[datetime | None, datetime | None]:
|
||||
"""
|
||||
Parse time expression into datetime range (start, end).
|
||||
|
||||
Supports:
|
||||
- "latest" or None -> returns (None, None) to get most recent
|
||||
- "yesterday" -> 24h window for yesterday
|
||||
- "today" -> Today from midnight
|
||||
- "last week" / "last 7 days" -> 7 day window
|
||||
- "last month" / "last 30 days" -> 30 day window
|
||||
- ISO date "YYYY-MM-DD" -> 24h window for that date
|
||||
"""
|
||||
if not time_expr or time_expr.lower() == "latest":
|
||||
return None, None
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
expr = time_expr.lower().strip()
|
||||
|
||||
# Relative expressions
|
||||
if expr == "yesterday":
|
||||
end = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
start = end - timedelta(days=1)
|
||||
return start, end
|
||||
|
||||
if expr in ("last week", "last 7 days"):
|
||||
return now - timedelta(days=7), now
|
||||
|
||||
if expr in ("last month", "last 30 days"):
|
||||
return now - timedelta(days=30), now
|
||||
|
||||
if expr == "today":
|
||||
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
return start, now
|
||||
|
||||
# Try ISO date format (YYYY-MM-DD)
|
||||
date_match = re.match(r"^(\d{4})-(\d{2})-(\d{2})$", expr)
|
||||
if date_match:
|
||||
year, month, day = map(int, date_match.groups())
|
||||
start = datetime(year, month, day, 0, 0, 0, tzinfo=timezone.utc)
|
||||
end = start + timedelta(days=1)
|
||||
return start, end
|
||||
|
||||
# Try ISO datetime
|
||||
try:
|
||||
parsed = datetime.fromisoformat(expr.replace("Z", "+00:00"))
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
# Return +/- 1 hour window around the specified time
|
||||
return parsed - timedelta(hours=1), parsed + timedelta(hours=1)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Fallback: treat as "latest"
|
||||
return None, None
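# Illustrative results (UTC), per the rules above; the date is a hypothetical value:
#   parse_time_expression("latest")     -> (None, None)   # caller takes the most recent run
#   parse_time_expression("yesterday")  -> (start of yesterday, start of today)
#   parse_time_expression("2025-01-15") -> (2025-01-15T00:00Z, 2025-01-16T00:00Z)
#   parse_time_expression("gibberish")  -> (None, None)   # unparseable input falls back to "latest"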
|
||||
|
||||
|
||||
class AgentOutputTool(BaseTool):
|
||||
"""Tool for retrieving execution outputs from user's library agents."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "agent_output"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Retrieve execution outputs from agents in the user's library.
|
||||
|
||||
Identify the agent using one of:
|
||||
- agent_name: Fuzzy search in user's library
|
||||
- library_agent_id: Exact library agent ID
|
||||
- store_slug: Marketplace format 'username/agent-name'
|
||||
|
||||
Select which run to retrieve using:
|
||||
- execution_id: Specific execution ID
|
||||
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
|
||||
"""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_name": {
|
||||
"type": "string",
|
||||
"description": "Agent name to search for in user's library (fuzzy match)",
|
||||
},
|
||||
"library_agent_id": {
|
||||
"type": "string",
|
||||
"description": "Exact library agent ID",
|
||||
},
|
||||
"store_slug": {
|
||||
"type": "string",
|
||||
"description": "Marketplace identifier: 'username/agent-slug'",
|
||||
},
|
||||
"execution_id": {
|
||||
"type": "string",
|
||||
"description": "Specific execution ID to retrieve",
|
||||
},
|
||||
"run_time": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _resolve_agent(
|
||||
self,
|
||||
user_id: str,
|
||||
agent_name: str | None,
|
||||
library_agent_id: str | None,
|
||||
store_slug: str | None,
|
||||
) -> tuple[LibraryAgent | None, str | None]:
|
||||
"""
|
||||
Resolve agent from provided identifiers.
|
||||
Returns (library_agent, error_message).
|
||||
"""
|
||||
# Priority 1: Exact library agent ID
|
||||
if library_agent_id:
|
||||
try:
|
||||
agent = await library_db.get_library_agent(library_agent_id, user_id)
|
||||
return agent, None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get library agent by ID: {e}")
|
||||
return None, f"Library agent '{library_agent_id}' not found"
|
||||
|
||||
# Priority 2: Store slug (username/agent-name)
|
||||
if store_slug and "/" in store_slug:
|
||||
username, agent_slug = store_slug.split("/", 1)
|
||||
graph, _ = await fetch_graph_from_store_slug(username, agent_slug)
|
||||
if not graph:
|
||||
return None, f"Agent '{store_slug}' not found in marketplace"
|
||||
|
||||
# Find in user's library by graph_id
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
|
||||
if not agent:
|
||||
return (
|
||||
None,
|
||||
f"Agent '{store_slug}' is not in your library. "
|
||||
"Add it first to see outputs.",
|
||||
)
|
||||
return agent, None
|
||||
|
||||
# Priority 3: Fuzzy name search in library
|
||||
if agent_name:
|
||||
try:
|
||||
response = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=agent_name,
|
||||
page_size=5,
|
||||
)
|
||||
if not response.agents:
|
||||
return (
|
||||
None,
|
||||
f"No agents matching '{agent_name}' found in your library",
|
||||
)
|
||||
|
||||
# Return best match (first result from search)
|
||||
return response.agents[0], None
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching library agents: {e}")
|
||||
return None, f"Error searching for agent: {e}"
|
||||
|
||||
return (
|
||||
None,
|
||||
"Please specify an agent name, library_agent_id, or store_slug",
|
||||
)
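# Resolution examples (identifier values are hypothetical), matching the priorities above:
#   _resolve_agent(user_id=uid, agent_name=None, library_agent_id="lib_123", store_slug=None)
#       -> exact library lookup (priority 1)
#   _resolve_agent(user_id=uid, agent_name=None, library_agent_id=None, store_slug="acme/lead-scorer")
#       -> marketplace slug resolved, then matched in the user's library (priority 2)
#   _resolve_agent(user_id=uid, agent_name="newsletter", library_agent_id=None, store_slug=None)
#       -> fuzzy search of the library, best match returned (priority 3)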
|
||||
|
||||
async def _get_execution(
|
||||
self,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
execution_id: str | None,
|
||||
time_start: datetime | None,
|
||||
time_end: datetime | None,
|
||||
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
|
||||
"""
|
||||
Fetch execution(s) based on filters.
|
||||
Returns (single_execution, available_executions_meta, error_message).
|
||||
"""
|
||||
# If specific execution_id provided, fetch it directly
|
||||
if execution_id:
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
if not execution:
|
||||
return None, [], f"Execution '{execution_id}' not found"
|
||||
return execution, [], None
|
||||
|
||||
# Get completed executions with time filters
|
||||
executions = await execution_db.get_graph_executions(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
statuses=[ExecutionStatus.COMPLETED],
|
||||
created_time_gte=time_start,
|
||||
created_time_lte=time_end,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
if not executions:
|
||||
return None, [], None # No error, just no executions
|
||||
|
||||
# If only one execution, fetch full details
|
||||
if len(executions) == 1:
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
return full_execution, [], None
|
||||
|
||||
# Multiple executions - return latest with full details, plus list of available
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
return full_execution, executions, None
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
agent: LibraryAgent,
|
||||
execution: GraphExecution | None,
|
||||
available_executions: list[GraphExecutionMeta],
|
||||
session_id: str | None,
|
||||
) -> AgentOutputResponse:
|
||||
"""Build the response based on execution data."""
|
||||
library_agent_link = f"/library/agents/{agent.id}"
|
||||
|
||||
if not execution:
|
||||
return AgentOutputResponse(
|
||||
message=f"No completed executions found for agent '{agent.name}'",
|
||||
session_id=session_id,
|
||||
agent_name=agent.name,
|
||||
agent_id=agent.graph_id,
|
||||
library_agent_id=agent.id,
|
||||
library_agent_link=library_agent_link,
|
||||
total_executions=0,
|
||||
)
|
||||
|
||||
execution_info = ExecutionOutputInfo(
|
||||
execution_id=execution.id,
|
||||
status=execution.status.value,
|
||||
started_at=execution.started_at,
|
||||
ended_at=execution.ended_at,
|
||||
outputs=dict(execution.outputs),
|
||||
inputs_summary=execution.inputs if execution.inputs else None,
|
||||
)
|
||||
|
||||
available_list = None
|
||||
if len(available_executions) > 1:
|
||||
available_list = [
|
||||
{
|
||||
"id": e.id,
|
||||
"status": e.status.value,
|
||||
"started_at": e.started_at.isoformat() if e.started_at else None,
|
||||
}
|
||||
for e in available_executions[:5]
|
||||
]
|
||||
|
||||
message = f"Found execution outputs for agent '{agent.name}'"
|
||||
if len(available_executions) > 1:
|
||||
message += (
|
||||
f". Showing latest of {len(available_executions)} matching executions."
|
||||
)
|
||||
|
||||
return AgentOutputResponse(
|
||||
message=message,
|
||||
session_id=session_id,
|
||||
agent_name=agent.name,
|
||||
agent_id=agent.graph_id,
|
||||
library_agent_id=agent.id,
|
||||
library_agent_link=library_agent_link,
|
||||
execution=execution_info,
|
||||
available_executions=available_list,
|
||||
total_executions=len(available_executions) if available_executions else 1,
|
||||
)
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the agent_output tool."""
|
||||
session_id = session.session_id
|
||||
|
||||
# Parse and validate input
|
||||
try:
|
||||
input_data = AgentOutputInput(**kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid input: {e}")
|
||||
return ErrorResponse(
|
||||
message="Invalid input parameters",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Ensure user_id is present (should be guaranteed by requires_auth)
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if at least one identifier is provided
|
||||
if not any(
|
||||
[
|
||||
input_data.agent_name,
|
||||
input_data.library_agent_id,
|
||||
input_data.store_slug,
|
||||
input_data.execution_id,
|
||||
]
|
||||
):
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Please specify at least one of: agent_name, "
|
||||
"library_agent_id, store_slug, or execution_id"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# If only execution_id provided, we need to find the agent differently
|
||||
if (
|
||||
input_data.execution_id
|
||||
and not input_data.agent_name
|
||||
and not input_data.library_agent_id
|
||||
and not input_data.store_slug
|
||||
):
|
||||
# Fetch execution directly to get graph_id
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=input_data.execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
if not execution:
|
||||
return ErrorResponse(
|
||||
message=f"Execution '{input_data.execution_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Find library agent by graph_id
|
||||
agent = await library_db.get_library_agent_by_graph_id(
|
||||
user_id, execution.graph_id
|
||||
)
|
||||
if not agent:
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"Execution found but agent not in your library. "
|
||||
f"Graph ID: {execution.graph_id}"
|
||||
),
|
||||
session_id=session_id,
|
||||
suggestions=["Add the agent to your library to see more details"],
|
||||
)
|
||||
|
||||
return self._build_response(agent, execution, [], session_id)
|
||||
|
||||
# Resolve agent from identifiers
|
||||
agent, error = await self._resolve_agent(
|
||||
user_id=user_id,
|
||||
agent_name=input_data.agent_name or None,
|
||||
library_agent_id=input_data.library_agent_id or None,
|
||||
store_slug=input_data.store_slug or None,
|
||||
)
|
||||
|
||||
if error or not agent:
|
||||
return NoResultsResponse(
|
||||
message=error or "Agent not found",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Check the agent name or ID",
|
||||
"Make sure the agent is in your library",
|
||||
],
|
||||
)
|
||||
|
||||
# Parse time expression
|
||||
time_start, time_end = parse_time_expression(input_data.run_time)
|
||||
|
||||
# Fetch execution(s)
|
||||
execution, available_executions, exec_error = await self._get_execution(
|
||||
user_id=user_id,
|
||||
graph_id=agent.graph_id,
|
||||
execution_id=input_data.execution_id or None,
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
)
|
||||
|
||||
if exec_error:
|
||||
return ErrorResponse(
|
||||
message=exec_error,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return self._build_response(agent, execution, available_executions, session_id)
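# Illustrative tool invocations (argument values are hypothetical; the parameter
# names match the schema defined in `parameters` above):
#   agent_output(agent_name="newsletter")                         # latest completed run
#   agent_output(store_slug="acme/lead-scorer", run_time="yesterday")
#   agent_output(execution_id="3f2c1e9a")                         # one specific execution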
|
||||
@@ -1,157 +0,0 @@
|
||||
"""Tool for searching agents in the user's library."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.util.exceptions import DatabaseError
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentCarouselResponse,
|
||||
AgentInfo,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FindLibraryAgentTool(BaseTool):
|
||||
"""Tool for searching agents in the user's library."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "find_library_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search for agents in the user's library. Use this to find agents "
|
||||
"the user has already added to their library, including agents they "
|
||||
"created or added from the marketplace."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query to find agents by name or description. "
|
||||
"Use keywords for best results."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search for agents in the user's library.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
AgentCarouselResponse: List of agents found in the library
|
||||
NoResultsResponse: No agents found
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required to search library",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agents = []
|
||||
try:
|
||||
logger.info(f"Searching user library for: {query}")
|
||||
library_results = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=query,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Find library agents tool found {len(library_results.agents)} agents"
|
||||
)
|
||||
|
||||
for agent in library_results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
),
|
||||
)
|
||||
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Error searching library agents: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search library. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agents:
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"No agents found matching '{query}' in your library. "
|
||||
"Try different keywords or use find_agent to search the marketplace."
|
||||
),
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Use find_agent to search the marketplace",
|
||||
"Check your library at /library",
|
||||
],
|
||||
)
|
||||
|
||||
title = (
|
||||
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
|
||||
f"in your library for '{query}'"
|
||||
)
|
||||
|
||||
return AgentCarouselResponse(
|
||||
message=(
|
||||
"Found agents in the user's library. You can provide a link to "
|
||||
"view an agent at: /library/agents/{agent_id}. "
|
||||
"Use agent_output to get execution results, or run_agent to execute."
|
||||
),
|
||||
title=title,
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
session_id=session_id,
|
||||
)
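# Illustrative invocation (query value is hypothetical):
#   find_library_agent(query="competitor monitoring")
# On a match this returns an AgentCarouselResponse whose agents can be opened at
# /library/agents/{agent_id}; with no match, a NoResultsResponse with suggestions.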
|
||||
@@ -1,41 +0,0 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
def sort_openapi(app: FastAPI) -> None:
|
||||
"""
|
||||
Patch a FastAPI instance's `openapi()` method to sort the endpoints,
|
||||
schemas, and responses.
|
||||
"""
|
||||
wrapped_openapi = app.openapi
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = wrapped_openapi()
|
||||
|
||||
# Sort endpoints
|
||||
openapi_schema["paths"] = dict(sorted(openapi_schema["paths"].items()))
|
||||
|
||||
# Sort endpoints -> methods
|
||||
for p in openapi_schema["paths"].keys():
|
||||
openapi_schema["paths"][p] = dict(
|
||||
sorted(openapi_schema["paths"][p].items())
|
||||
)
|
||||
|
||||
# Sort endpoints -> methods -> responses
|
||||
for m in openapi_schema["paths"][p].keys():
|
||||
openapi_schema["paths"][p][m]["responses"] = dict(
|
||||
sorted(openapi_schema["paths"][p][m]["responses"].items())
|
||||
)
|
||||
|
||||
# Sort schemas and responses as well
|
||||
for k in openapi_schema["components"].keys():
|
||||
openapi_schema["components"][k] = dict(
|
||||
sorted(openapi_schema["components"][k].items())
|
||||
)
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return openapi_schema
|
||||
|
||||
app.openapi = custom_openapi
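# Usage sketch (illustrative): patch the app once at startup, before the schema
# is first requested, so the cached schema is built in sorted order.
#   app = FastAPI()
#   app.include_router(chat_router)   # hypothetical router registration
#   sort_openapi(app)
#   schema = app.openapi()            # paths, methods, and components now sorted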
|
||||
@@ -36,10 +36,10 @@ def main(**kwargs):
|
||||
Run all the processes required for the AutoGPT-server (REST and WebSocket APIs).
|
||||
"""
|
||||
|
||||
from backend.api.rest_api import AgentServer
|
||||
from backend.api.ws_api import WebsocketServer
|
||||
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
|
||||
from backend.notifications import NotificationManager
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.server.ws_api import WebsocketServer
|
||||
|
||||
run_processes(
|
||||
DatabaseManager().set_log_level("warning"),
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks.llm import (
|
||||
DEFAULT_LLM_MODEL,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
AIBlockBase,
|
||||
@@ -50,7 +49,7 @@ class AIConditionBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for evaluating the condition.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -82,7 +81,7 @@ class AIConditionBlock(AIBlockBase):
|
||||
"condition": "the input is an email address",
|
||||
"yes_value": "Valid email",
|
||||
"no_value": "Not an email",
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
|
||||
@@ -6,9 +6,6 @@ import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import cast
|
||||
|
||||
from prisma.types import Serializable
|
||||
|
||||
from backend.sdk import (
|
||||
BaseWebhooksManager,
|
||||
@@ -87,9 +84,7 @@ class AirtableWebhookManager(BaseWebhooksManager):
|
||||
# update webhook config
|
||||
await update_webhook(
|
||||
webhook.id,
|
||||
config=cast(
|
||||
dict[str, Serializable], {"base_id": base_id, "cursor": response.cursor}
|
||||
),
|
||||
config={"base_id": base_id, "cursor": response.cursor},
|
||||
)
|
||||
|
||||
event_type = "notification"
|
||||
|
||||
@@ -182,10 +182,13 @@ class DataForSeoRelatedKeywordsBlock(Block):
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
# Handle missing key, null value, or valid list value
|
||||
if isinstance(first_result, dict):
|
||||
items = first_result.get("items") or []
|
||||
else:
|
||||
items = (
|
||||
first_result.get("items", [])
|
||||
if isinstance(first_result, dict)
|
||||
else []
|
||||
)
|
||||
# Ensure items is never None
|
||||
if items is None:
|
||||
items = []
|
||||
for item in items:
|
||||
# Extract keyword_data from the item
|
||||
|
||||
@@ -319,7 +319,7 @@ class CostDollars(BaseModel):
|
||||
|
||||
# Helper functions for payload processing
|
||||
def process_text_field(
|
||||
text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None]
|
||||
text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None],
|
||||
) -> Optional[Union[bool, Dict[str, Any]]]:
|
||||
"""Process text field for API payload."""
|
||||
if text is None:
|
||||
@@ -400,7 +400,7 @@ def process_contents_settings(contents: Optional[ContentSettings]) -> Dict[str,
|
||||
|
||||
|
||||
def process_context_field(
|
||||
context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None]
|
||||
context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None],
|
||||
) -> Optional[Union[bool, Dict[str, int]]]:
|
||||
"""Process context field for API payload."""
|
||||
if context is None:
|
||||
|
||||
File diff suppressed because it is too large
@@ -92,9 +92,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
O1 = "o1"
|
||||
O1_MINI = "o1-mini"
|
||||
# GPT-5 models
|
||||
GPT5_2 = "gpt-5.2-2025-12-11"
|
||||
GPT5_1 = "gpt-5.1-2025-11-13"
|
||||
GPT5 = "gpt-5-2025-08-07"
|
||||
GPT5_1 = "gpt-5.1-2025-11-13"
|
||||
GPT5_MINI = "gpt-5-mini-2025-08-07"
|
||||
GPT5_NANO = "gpt-5-nano-2025-08-07"
|
||||
GPT5_CHAT = "gpt-5-chat-latest"
|
||||
@@ -195,9 +194,8 @@ MODEL_METADATA = {
|
||||
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
|
||||
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5_2: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
|
||||
@@ -305,8 +303,6 @@ MODEL_METADATA = {
|
||||
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
|
||||
}
|
||||
|
||||
DEFAULT_LLM_MODEL = LlmModel.GPT5_2
|
||||
|
||||
for model in LlmModel:
|
||||
if model not in MODEL_METADATA:
|
||||
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
|
||||
@@ -794,7 +790,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -859,7 +855,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
input_schema=AIStructuredResponseGeneratorBlock.Input,
|
||||
output_schema=AIStructuredResponseGeneratorBlock.Output,
|
||||
test_input={
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expected_format": {
|
||||
"key1": "value1",
|
||||
@@ -1225,7 +1221,7 @@ class AITextGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -1321,7 +1317,7 @@ class AITextSummarizerBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for summarizing the text.",
|
||||
)
|
||||
focus: str = SchemaField(
|
||||
@@ -1538,7 +1534,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for the conversation.",
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
@@ -1576,7 +1572,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
},
|
||||
{"role": "user", "content": "Where was it played?"},
|
||||
],
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
@@ -1639,7 +1635,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for generating the list.",
|
||||
advanced=True,
|
||||
)
|
||||
@@ -1696,7 +1692,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
|
||||
"fictional worlds."
|
||||
),
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"max_retries": 3,
|
||||
"force_json_output": False,
|
||||
|
||||
@@ -226,7 +226,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
)
|
||||
model: llm.LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=llm.DEFAULT_LLM_MODEL,
|
||||
default=llm.LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -975,28 +975,10 @@ class SmartDecisionMakerBlock(Block):
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
execution_processor: "ExecutionProcessor",
|
||||
nodes_to_skip: set[str] | None = None,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
tool_functions = await self._create_tool_node_signatures(node_id)
|
||||
original_tool_count = len(tool_functions)
|
||||
|
||||
# Filter out tools for nodes that should be skipped (e.g., missing optional credentials)
|
||||
if nodes_to_skip:
|
||||
tool_functions = [
|
||||
tf
|
||||
for tf in tool_functions
|
||||
if tf.get("function", {}).get("_sink_node_id") not in nodes_to_skip
|
||||
]
|
||||
|
||||
# Only raise error if we had tools but they were all filtered out
|
||||
if original_tool_count > 0 and not tool_functions:
|
||||
raise ValueError(
|
||||
"No available tools to execute - all downstream nodes are unavailable "
|
||||
"(possibly due to missing optional credentials)"
|
||||
)
|
||||
|
||||
yield "tool_functions", json.dumps(tool_functions)
|
||||
|
||||
conversation_history = input_data.conversation_history or []
|
||||
|
||||
@@ -196,15 +196,6 @@ class TestXMLParserBlockSecurity:
async for _ in block.run(XMLParserBlock.Input(input_xml=large_xml)):
pass

async def test_rejects_text_outside_root(self):
"""Ensure parser surfaces readable errors for invalid root text."""
block = XMLParserBlock()
invalid_xml = "<root><child>value</child></root> trailing"

with pytest.raises(ValueError, match="text outside the root element"):
async for _ in block.run(XMLParserBlock.Input(input_xml=invalid_xml)):
pass


class TestStoreMediaFileSecurity:
"""Test file storage security limits."""
@@ -28,7 +28,7 @@ class TestLLMStatsTracking:
|
||||
|
||||
response = await llm.llm_call(
|
||||
credentials=llm.TEST_CREDENTIALS,
|
||||
llm_model=llm.DEFAULT_LLM_MODEL,
|
||||
llm_model=llm.LlmModel.GPT4O,
|
||||
prompt=[{"role": "user", "content": "Hello"}],
|
||||
max_tokens=100,
|
||||
)
|
||||
@@ -65,7 +65,7 @@ class TestLLMStatsTracking:
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
expected_format={"key1": "desc1", "key2": "desc2"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore # type: ignore
|
||||
)
|
||||
|
||||
@@ -109,7 +109,7 @@ class TestLLMStatsTracking:
|
||||
# Run the block
|
||||
input_data = llm.AITextGeneratorBlock.Input(
|
||||
prompt="Generate text",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
@@ -170,7 +170,7 @@ class TestLLMStatsTracking:
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
expected_format={"key1": "desc1", "key2": "desc2"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2,
|
||||
)
|
||||
@@ -228,7 +228,7 @@ class TestLLMStatsTracking:
|
||||
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text=long_text,
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
max_tokens=100, # Small chunks
|
||||
chunk_overlap=10,
|
||||
@@ -299,7 +299,7 @@ class TestLLMStatsTracking:
|
||||
# Test with very short text (should only need 1 chunk + 1 final summary)
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="This is a short text.",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
max_tokens=1000, # Large enough to avoid chunking
|
||||
)
|
||||
@@ -346,7 +346,7 @@ class TestLLMStatsTracking:
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
],
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
@@ -387,7 +387,7 @@ class TestLLMStatsTracking:
|
||||
# Run the block
|
||||
input_data = llm.AIListGeneratorBlock.Input(
|
||||
focus="test items",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
max_retries=3,
|
||||
)
|
||||
@@ -469,7 +469,7 @@ class TestLLMStatsTracking:
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"result": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
@@ -513,7 +513,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
style=llm.SummaryStyle.BULLET_POINTS,
|
||||
)
|
||||
@@ -558,7 +558,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
style=llm.SummaryStyle.BULLET_POINTS,
|
||||
max_tokens=1000,
|
||||
@@ -593,7 +593,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
@@ -623,7 +623,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
max_tokens=1000,
|
||||
)
|
||||
@@ -654,7 +654,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@@ -5,10 +5,10 @@ from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.api.model import CreateGraph
from backend.api.rest_api import AgentServer
from backend.data.execution import ExecutionContext
from backend.data.model import ProviderName, User
from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.test import SpinTestServer, wait_execution
@@ -233,7 +233,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
# Create test input
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Should I continue with this task?",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -335,7 +335,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2, # Set retry to 2 for testing
|
||||
agent_mode_max_iterations=0,
|
||||
@@ -402,7 +402,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -462,7 +462,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -526,7 +526,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -648,7 +648,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2,
|
||||
agent_mode_max_iterations=0,
|
||||
@@ -722,7 +722,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Simple prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -778,7 +778,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Another test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -931,7 +931,7 @@ async def test_smart_decision_maker_agent_mode():
|
||||
# Test agent mode with max_iterations = 3
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Complete this task using tools",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=3, # Enable agent mode with 3 max iterations
|
||||
)
|
||||
@@ -1020,7 +1020,7 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
# Test default behavior (traditional mode)
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0, # Traditional mode
|
||||
)
|
||||
|
||||
@@ -373,7 +373,7 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
input_data = block.input_schema(
|
||||
prompt="Create a user dictionary",
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT,
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
agent_mode_max_iterations=0, # Use traditional mode to test output yielding
|
||||
)
|
||||
|
||||
@@ -594,7 +594,7 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
input_data = block.input_schema(
|
||||
prompt="Test prompt",
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT,
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
retry=3, # Allow retries
|
||||
agent_mode_max_iterations=1,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
from gravitasml.parser import Parser
from gravitasml.token import Token, tokenize
from gravitasml.token import tokenize

from backend.data.block import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
from backend.data.model import SchemaField
@@ -25,38 +25,6 @@ class XMLParserBlock(Block):
],
)

@staticmethod
def _validate_tokens(tokens: list[Token]) -> None:
"""Ensure the XML has a single root element and no stray text."""
if not tokens:
raise ValueError("XML input is empty.")

depth = 0
root_seen = False

for token in tokens:
if token.type == "TAG_OPEN":
if depth == 0 and root_seen:
raise ValueError("XML must have a single root element.")
depth += 1
if depth == 1:
root_seen = True
elif token.type == "TAG_CLOSE":
depth -= 1
if depth < 0:
raise SyntaxError("Unexpected closing tag in XML input.")
elif token.type in {"TEXT", "ESCAPE"}:
if depth == 0 and token.value:
raise ValueError(
"XML contains text outside the root element; "
"wrap content in a single root tag."
)

if depth != 0:
raise SyntaxError("Unclosed tag detected in XML input.")
if not root_seen:
raise ValueError("XML must include a root element.")

async def run(self, input_data: Input, **kwargs) -> BlockOutput:
# Security fix: Add size limits to prevent XML bomb attacks
MAX_XML_SIZE = 10 * 1024 * 1024 # 10MB limit for XML input
@@ -67,9 +35,7 @@ class XMLParserBlock(Block):
)

try:
tokens = list(tokenize(input_data.input_xml))
self._validate_tokens(tokens)

tokens = tokenize(input_data.input_xml)
parser = Parser(tokens)
parsed_result = parser.parse()
yield "parsed_xml", parsed_result
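A minimal usage sketch of the simplified parse path above. It assumes only the gravitasml calls visible in this diff (tokenize and Parser(...).parse()); the sample XML string is illustrative, not taken from the repository.

# Sketch: lex the XML into tokens, then parse them into a Python structure.
from gravitasml.parser import Parser
from gravitasml.token import tokenize

xml = "<root><child>value</child></root>"  # illustrative input
tokens = tokenize(xml)          # tokenize the raw XML string
parsed = Parser(tokens).parse() # build the parsed result yielded by the block
print(parsed)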
@@ -111,8 +111,6 @@ class TranscribeYoutubeVideoBlock(Block):
|
||||
return parsed_url.path.split("/")[2]
|
||||
if parsed_url.path[:3] == "/v/":
|
||||
return parsed_url.path.split("/")[2]
|
||||
if parsed_url.path.startswith("/shorts/"):
|
||||
return parsed_url.path.split("/")[2]
|
||||
raise ValueError(f"Invalid YouTube URL: {url}")
|
||||
|
||||
def get_transcript(
|
||||
|
||||
@@ -244,7 +244,11 @@ def websocket(server_address: str, graph_exec_id: str):
|
||||
|
||||
import websockets.asyncio.client
|
||||
|
||||
from backend.api.ws_api import WSMessage, WSMethod, WSSubscribeGraphExecutionRequest
|
||||
from backend.server.ws_api import (
|
||||
WSMessage,
|
||||
WSMethod,
|
||||
WSSubscribeGraphExecutionRequest,
|
||||
)
|
||||
|
||||
async def send_message(server_address: str):
|
||||
uri = f"ws://{server_address}"
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"""
|
||||
Script to generate OpenAPI JSON specification for the FastAPI app.
|
||||
|
||||
This script imports the FastAPI app from backend.api.rest_api and outputs
|
||||
This script imports the FastAPI app from backend.server.rest_api and outputs
|
||||
the OpenAPI specification as JSON to stdout or a specified file.
|
||||
|
||||
Usage:
|
||||
@@ -46,7 +46,7 @@ def main(output: Path, pretty: bool):
|
||||
|
||||
def get_openapi_schema():
|
||||
"""Get the OpenAPI schema from the FastAPI app"""
|
||||
from backend.api.rest_api import app
|
||||
from backend.server.rest_api import app
|
||||
|
||||
return app.openapi()
|
||||
|
||||
|
||||
@@ -36,12 +36,13 @@ import secrets
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import click
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from prisma.enums import APIKeyPermission
|
||||
from prisma.types import OAuthApplicationCreateInput
|
||||
|
||||
keysmith = APIKeySmith()
|
||||
|
||||
@@ -834,19 +835,22 @@ async def create_test_app_in_db(
|
||||
|
||||
# Insert into database
|
||||
app = await OAuthApplication.prisma().create(
|
||||
data={
|
||||
"id": creds["id"],
|
||||
"name": creds["name"],
|
||||
"description": creds["description"],
|
||||
"clientId": creds["client_id"],
|
||||
"clientSecret": creds["client_secret_hash"],
|
||||
"clientSecretSalt": creds["client_secret_salt"],
|
||||
"redirectUris": creds["redirect_uris"],
|
||||
"grantTypes": creds["grant_types"],
|
||||
"scopes": creds["scopes"],
|
||||
"ownerId": owner_id,
|
||||
"isActive": True,
|
||||
}
|
||||
data=cast(
|
||||
OAuthApplicationCreateInput,
|
||||
{
|
||||
"id": creds["id"],
|
||||
"name": creds["name"],
|
||||
"description": creds["description"],
|
||||
"clientId": creds["client_id"],
|
||||
"clientSecret": creds["client_secret_hash"],
|
||||
"clientSecretSalt": creds["client_secret_salt"],
|
||||
"redirectUris": creds["redirect_uris"],
|
||||
"grantTypes": creds["grant_types"],
|
||||
"scopes": creds["scopes"],
|
||||
"ownerId": owner_id,
|
||||
"isActive": True,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
click.echo(f"✓ Created test OAuth application: {app.clientId}")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from backend.api.features.library.model import LibraryAgentPreset
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
from .graph import NodeModel
|
||||
from .integrations import Webhook # noqa: F401
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, cast
|
||||
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from prisma.enums import APIKeyPermission, APIKeyStatus
|
||||
from prisma.models import APIKey as PrismaAPIKey
|
||||
from prisma.types import APIKeyWhereUniqueInput
|
||||
from prisma.types import APIKeyCreateInput, APIKeyWhereUniqueInput
|
||||
from pydantic import Field
|
||||
|
||||
from backend.data.includes import MAX_USER_API_KEYS_FETCH
|
||||
@@ -82,17 +82,20 @@ async def create_api_key(
|
||||
generated_key = keysmith.generate_key()
|
||||
|
||||
saved_key_obj = await PrismaAPIKey.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"head": generated_key.head,
|
||||
"tail": generated_key.tail,
|
||||
"hash": generated_key.hash,
|
||||
"salt": generated_key.salt,
|
||||
"permissions": [p for p in permissions],
|
||||
"description": description,
|
||||
"userId": user_id,
|
||||
}
|
||||
data=cast(
|
||||
APIKeyCreateInput,
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"head": generated_key.head,
|
||||
"tail": generated_key.tail,
|
||||
"hash": generated_key.hash,
|
||||
"salt": generated_key.salt,
|
||||
"permissions": [p for p in permissions],
|
||||
"description": description,
|
||||
"userId": user_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return APIKeyInfo.from_db(saved_key_obj), generated_key.key
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, cast
|
||||
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from prisma.enums import APIKeyPermission as APIPermission
|
||||
@@ -22,7 +22,12 @@ from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
|
||||
from prisma.models import OAuthApplication as PrismaOAuthApplication
|
||||
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
|
||||
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
|
||||
from prisma.types import OAuthApplicationUpdateInput
|
||||
from prisma.types import (
|
||||
OAuthAccessTokenCreateInput,
|
||||
OAuthApplicationUpdateInput,
|
||||
OAuthAuthorizationCodeCreateInput,
|
||||
OAuthRefreshTokenCreateInput,
|
||||
)
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from .base import APIAuthorizationInfo
|
||||
@@ -359,17 +364,20 @@ async def create_authorization_code(
|
||||
expires_at = now + AUTHORIZATION_CODE_TTL
|
||||
|
||||
saved_code = await PrismaOAuthAuthorizationCode.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"code": code,
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
"redirectUri": redirect_uri,
|
||||
"codeChallenge": code_challenge,
|
||||
"codeChallengeMethod": code_challenge_method,
|
||||
}
|
||||
data=cast(
|
||||
OAuthAuthorizationCodeCreateInput,
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"code": code,
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
"redirectUri": redirect_uri,
|
||||
"codeChallenge": code_challenge,
|
||||
"codeChallengeMethod": code_challenge_method,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return OAuthAuthorizationCodeInfo.from_db(saved_code)
|
||||
@@ -490,14 +498,17 @@ async def create_access_token(
|
||||
expires_at = now + ACCESS_TOKEN_TTL
|
||||
|
||||
saved_token = await PrismaOAuthAccessToken.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": token_hash, # SHA256 hash for direct lookup
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
}
|
||||
data=cast(
|
||||
OAuthAccessTokenCreateInput,
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": token_hash, # SHA256 hash for direct lookup
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return OAuthAccessToken.from_db(saved_token, plaintext_token=plaintext_token)
|
||||
@@ -607,14 +618,17 @@ async def create_refresh_token(
|
||||
expires_at = now + REFRESH_TOKEN_TTL
|
||||
|
||||
saved_token = await PrismaOAuthRefreshToken.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": token_hash, # SHA256 hash for direct lookup
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
}
|
||||
data=cast(
|
||||
OAuthRefreshTokenCreateInput,
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": token_hash, # SHA256 hash for direct lookup
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return OAuthRefreshToken.from_db(saved_token, plaintext_token=plaintext_token)
|
||||
|
||||
@@ -59,13 +59,12 @@ from backend.integrations.credentials_store import (

MODEL_COST: dict[LlmModel, int] = {
LlmModel.O3: 4,
LlmModel.O3_MINI: 2,
LlmModel.O1: 16,
LlmModel.O3_MINI: 2, # $1.10 / $4.40
LlmModel.O1: 16, # $15 / $60
LlmModel.O1_MINI: 4,
# GPT-5 models
LlmModel.GPT5_2: 6,
LlmModel.GPT5_1: 5,
LlmModel.GPT5: 2,
LlmModel.GPT5_1: 5,
LlmModel.GPT5_MINI: 1,
LlmModel.GPT5_NANO: 1,
LlmModel.GPT5_CHAT: 5,
@@ -88,7 +87,7 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.AIML_API_LLAMA3_3_70B: 1,
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
LlmModel.LLAMA3_3_70B: 1,
LlmModel.LLAMA3_3_70B: 1, # $0.59 / $0.79
LlmModel.LLAMA3_1_8B: 1,
LlmModel.OLLAMA_LLAMA3_3: 1,
LlmModel.OLLAMA_LLAMA3_2: 1,
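A sketch of how a per-model credit cost table like MODEL_COST above might be consulted when charging for a block run. The helper name and fallback behavior are illustrative assumptions, not part of this diff.

# Illustrative only: look up the configured credit cost for the model used in a call.
def cost_for_call(model, model_cost, default=1):
    # Charge the flat per-call credit cost configured for the model,
    # falling back to a conservative default for unlisted models.
    return model_cost.get(model, default)

# e.g. cost_for_call(LlmModel.GPT5_2, MODEL_COST) returns 6 per the table above.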
@@ -16,7 +16,6 @@ from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBala
|
||||
from prisma.types import CreditRefundRequestCreateInput, CreditTransactionWhereInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.admin.model import UserHistoryResponse
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.data.includes import MAX_CREDIT_REFUND_REQUESTS_FETCH
|
||||
@@ -30,6 +29,7 @@ from backend.data.model import (
|
||||
from backend.data.notifications import NotificationEventModel, RefundRequestData
|
||||
from backend.data.user import get_user_by_id, get_user_email_by_id
|
||||
from backend.notifications.notifications import queue_notification_async
|
||||
from backend.server.v2.admin.model import UserHistoryResponse
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
from backend.util.json import SafeJson, dumps
|
||||
@@ -341,19 +341,6 @@ class UserCreditBase(ABC):
|
||||
|
||||
if result:
|
||||
# UserBalance is already updated by the CTE
|
||||
|
||||
# Clear insufficient funds notification flags when credits are added
|
||||
# so user can receive alerts again if they run out in the future.
|
||||
if transaction.amount > 0 and transaction.type in [
|
||||
CreditTransactionType.GRANT,
|
||||
CreditTransactionType.TOP_UP,
|
||||
]:
|
||||
from backend.executor.manager import (
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
|
||||
await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
return result[0]["balance"]
|
||||
|
||||
async def _add_transaction(
|
||||
@@ -543,22 +530,6 @@ class UserCreditBase(ABC):
|
||||
if result:
|
||||
new_balance, tx_key = result[0]["balance"], result[0]["transactionKey"]
|
||||
# UserBalance is already updated by the CTE
|
||||
|
||||
# Clear insufficient funds notification flags when credits are added
|
||||
# so user can receive alerts again if they run out in the future.
|
||||
if (
|
||||
amount > 0
|
||||
and is_active
|
||||
and transaction_type
|
||||
in [CreditTransactionType.GRANT, CreditTransactionType.TOP_UP]
|
||||
):
|
||||
# Lazy import to avoid circular dependency with executor.manager
|
||||
from backend.executor.manager import (
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
|
||||
await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
return new_balance, tx_key
|
||||
|
||||
# If no result, either user doesn't exist or insufficient balance
|
||||
|
||||
@@ -5,12 +5,14 @@ This test was added to cover a previously untested code path that could lead to
|
||||
incorrect balance capping behavior.
|
||||
"""
|
||||
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserBalanceUpsertInput, UserCreateInput
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
@@ -21,11 +23,14 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for ceiling tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
},
|
||||
)
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -33,7 +38,10 @@ async def create_test_user(user_id: str) -> None:
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ without race conditions, deadlocks, or inconsistent state.
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import prisma.enums
|
||||
@@ -14,6 +15,7 @@ import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserBalanceUpsertInput, UserCreateInput
|
||||
|
||||
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
@@ -28,11 +30,14 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user with initial balance."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
},
|
||||
)
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -41,7 +46,10 @@ async def create_test_user(user_id: str) -> None:
|
||||
# Ensure UserBalance record exists
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -342,10 +350,13 @@ async def test_integer_overflow_protection(server: SpinTestServer):
|
||||
# First, set balance near max
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": max_int - 100},
|
||||
"update": {"balance": max_int - 100},
|
||||
},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": max_int - 100},
|
||||
"update": {"balance": max_int - 100},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Try to add more than possible - should clamp to POSTGRES_INT_MAX
|
||||
|
||||
@@ -5,9 +5,12 @@ These tests run actual database operations to ensure SQL queries work correctly,
|
||||
which would have caught the CreditTransactionType enum casting bug.
|
||||
"""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserCreateInput
|
||||
|
||||
from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
@@ -29,12 +32,15 @@ async def cleanup_test_user():
|
||||
# Create the user first
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"topUpConfig": SafeJson({}),
|
||||
"timezone": "UTC",
|
||||
}
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"topUpConfig": SafeJson({}),
|
||||
"timezone": "UTC",
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# User might already exist, that's fine
|
||||
|
||||
@@ -6,12 +6,19 @@ are atomic and maintain data consistency.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import stripe
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
|
||||
from prisma.types import (
|
||||
CreditRefundRequestCreateInput,
|
||||
CreditTransactionCreateInput,
|
||||
UserBalanceCreateInput,
|
||||
UserCreateInput,
|
||||
)
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
@@ -35,32 +42,41 @@ async def setup_test_user_with_topup():
|
||||
|
||||
# Create user
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": REFUND_TEST_USER_ID,
|
||||
"email": f"{REFUND_TEST_USER_ID}@example.com",
|
||||
"name": "Refund Test User",
|
||||
}
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": REFUND_TEST_USER_ID,
|
||||
"email": f"{REFUND_TEST_USER_ID}@example.com",
|
||||
"name": "Refund Test User",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Create user balance
|
||||
await UserBalance.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"balance": 1000, # $10
|
||||
}
|
||||
data=cast(
|
||||
UserBalanceCreateInput,
|
||||
{
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"balance": 1000, # $10
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Create a top-up transaction that can be refunded
|
||||
topup_tx = await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 1000,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"transactionKey": "pi_test_12345",
|
||||
"runningBalance": 1000,
|
||||
"isActive": True,
|
||||
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
|
||||
}
|
||||
data=cast(
|
||||
CreditTransactionCreateInput,
|
||||
{
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 1000,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"transactionKey": "pi_test_12345",
|
||||
"runningBalance": 1000,
|
||||
"isActive": True,
|
||||
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return topup_tx
|
||||
@@ -93,12 +109,15 @@ async def test_deduct_credits_atomic(server: SpinTestServer):
|
||||
|
||||
# Create refund request record (simulating webhook flow)
|
||||
await CreditRefundRequest.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 500,
|
||||
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
|
||||
"reason": "Test refund",
|
||||
}
|
||||
data=cast(
|
||||
CreditRefundRequestCreateInput,
|
||||
{
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 500,
|
||||
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
|
||||
"reason": "Test refund",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Call deduct_credits
|
||||
@@ -286,12 +305,15 @@ async def test_concurrent_refunds(server: SpinTestServer):
|
||||
refund_requests = []
|
||||
for i in range(5):
|
||||
req = await CreditRefundRequest.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 100, # $1 each
|
||||
"transactionKey": topup_tx.transactionKey,
|
||||
"reason": f"Test refund {i}",
|
||||
}
|
||||
data=cast(
|
||||
CreditRefundRequestCreateInput,
|
||||
{
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 100, # $1 each
|
||||
"transactionKey": topup_tx.transactionKey,
|
||||
"reason": f"Test refund {i}",
|
||||
},
|
||||
)
|
||||
)
|
||||
refund_requests.append(req)
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditTransaction, UserBalance
|
||||
from prisma.types import CreditTransactionCreateInput, UserBalanceUpsertInput
|
||||
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.data.block import get_block
|
||||
@@ -23,10 +25,13 @@ async def disable_test_user_transactions():
|
||||
old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": DEFAULT_USER_ID},
|
||||
data={
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
|
||||
"update": {"balance": 0, "updatedAt": old_date},
|
||||
},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
|
||||
"update": {"balance": 0, "updatedAt": old_date},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -140,23 +145,29 @@ async def test_block_credit_reset(server: SpinTestServer):
|
||||
|
||||
# Manually create a transaction with month 1 timestamp to establish history
|
||||
await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": 100,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"runningBalance": 1100,
|
||||
"isActive": True,
|
||||
"createdAt": month1, # Set specific timestamp
|
||||
}
|
||||
data=cast(
|
||||
CreditTransactionCreateInput,
|
||||
{
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": 100,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"runningBalance": 1100,
|
||||
"isActive": True,
|
||||
"createdAt": month1, # Set specific timestamp
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Update user balance to match
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": DEFAULT_USER_ID},
|
||||
data={
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
|
||||
"update": {"balance": 1100},
|
||||
},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
|
||||
"update": {"balance": 1100},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Now test month 2 behavior
|
||||
@@ -175,14 +186,17 @@ async def test_block_credit_reset(server: SpinTestServer):
|
||||
|
||||
# Create a month 2 transaction to update the last transaction time
|
||||
await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": -700, # Spent 700 to get to 400
|
||||
"type": CreditTransactionType.USAGE,
|
||||
"runningBalance": 400,
|
||||
"isActive": True,
|
||||
"createdAt": month2,
|
||||
}
|
||||
data=cast(
|
||||
CreditTransactionCreateInput,
|
||||
{
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": -700, # Spent 700 to get to 400
|
||||
"type": CreditTransactionType.USAGE,
|
||||
"runningBalance": 400,
|
||||
"isActive": True,
|
||||
"createdAt": month2,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Move to month 3
|
||||
|
||||
@@ -6,12 +6,14 @@ doesn't underflow below POSTGRES_INT_MIN, which could cause integer wraparound i
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserBalanceUpsertInput, UserCreateInput
|
||||
|
||||
from backend.data.credit import POSTGRES_INT_MIN, UserCredit
|
||||
from backend.util.test import SpinTestServer
|
||||
@@ -21,11 +23,14 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for underflow tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
},
|
||||
)
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -33,7 +38,10 @@ async def create_test_user(user_id: str) -> None:
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -70,10 +78,13 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance_target},
|
||||
"update": {"balance": initial_balance_target},
|
||||
},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": initial_balance_target},
|
||||
"update": {"balance": initial_balance_target},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
current_balance = await credit_system.get_credits(user_id)
|
||||
@@ -110,10 +121,13 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
|
||||
# Set balance to exactly POSTGRES_INT_MIN
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
|
||||
"update": {"balance": POSTGRES_INT_MIN},
|
||||
},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
|
||||
"update": {"balance": POSTGRES_INT_MIN},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
edge_balance = await credit_system.get_credits(user_id)
|
||||
@@ -152,10 +166,13 @@ async def test_underflow_protection_large_refunds(server: SpinTestServer):
|
||||
test_balance = POSTGRES_INT_MIN + 1000
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": test_balance},
|
||||
"update": {"balance": test_balance},
|
||||
},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": test_balance},
|
||||
"update": {"balance": test_balance},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
current_balance = await credit_system.get_credits(user_id)
|
||||
@@ -217,10 +234,13 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
|
||||
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Apply multiple refunds that would cumulatively underflow
|
||||
@@ -295,10 +315,13 @@ async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
|
||||
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
async def large_refund(amount: int, label: str):
|
||||
|
||||
@@ -9,11 +9,13 @@ This test ensures that:
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserBalanceCreateInput, UserCreateInput
|
||||
|
||||
from backend.data.credit import UsageTransactionMetadata, UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
@@ -24,11 +26,14 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for migration tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
},
|
||||
)
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -121,7 +126,9 @@ async def test_detect_stale_user_balance_queries(server: SpinTestServer):
|
||||
try:
|
||||
# Create UserBalance with specific value
|
||||
await UserBalance.prisma().create(
|
||||
data={"userId": user_id, "balance": 5000} # $50
|
||||
data=cast(
|
||||
UserBalanceCreateInput, {"userId": user_id, "balance": 5000}
|
||||
) # $50
|
||||
)
|
||||
|
||||
# Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
|
||||
@@ -160,7 +167,9 @@ async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer
|
||||
|
||||
try:
|
||||
# Set initial balance in UserBalance
|
||||
await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})
|
||||
await UserBalance.prisma().create(
|
||||
data=cast(UserBalanceCreateInput, {"userId": user_id, "balance": 1000})
|
||||
)
|
||||
|
||||
# Run concurrent operations to ensure they all use UserBalance atomic operations
|
||||
async def concurrent_spend(amount: int, label: str):
|
||||
|
||||
@@ -111,7 +111,7 @@ def get_database_schema() -> str:
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
"""Execute raw SQL query with proper schema handling."""
schema = get_database_schema()
schema_prefix = f'"{schema}".' if schema != "public" else ""
schema_prefix = f"{schema}." if schema != "public" else ""
formatted_query = query_template.format(schema_prefix=schema_prefix)

import prisma as prisma_module
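A hedged usage sketch for query_raw_with_schema as defined above: the {schema_prefix} placeholder is filled in by the helper, so the same template works for the default "public" schema and for a named schema. The table name, columns, and $1 placeholder style are assumptions for illustration, not taken from this diff.

# Illustrative only: caller supplies a template with a {schema_prefix} placeholder.
async def fetch_user_email(user_id: str) -> list[dict]:
    return await query_raw_with_schema(
        'SELECT id, email FROM {schema_prefix}"User" WHERE id = $1',  # assumed table/placeholder style
        user_id,
    )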
@@ -28,6 +28,7 @@ from prisma.models import (
|
||||
AgentNodeExecutionKeyValueData,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionCreateInput,
|
||||
AgentGraphExecutionUpdateManyMutationInput,
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionCreateInput,
|
||||
@@ -35,7 +36,6 @@ from prisma.types import (
|
||||
AgentNodeExecutionKeyValueDataCreateInput,
|
||||
AgentNodeExecutionUpdateInput,
|
||||
AgentNodeExecutionWhereInput,
|
||||
AgentNodeExecutionWhereUniqueInput,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
|
||||
from pydantic.fields import Field
|
||||
@@ -383,7 +383,6 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
self,
|
||||
execution_context: ExecutionContext,
|
||||
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
):
|
||||
return GraphExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
@@ -391,7 +390,6 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
graph_version=self.graph_version or 0,
|
||||
graph_exec_id=self.id,
|
||||
nodes_input_masks=compiled_nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip or set(),
|
||||
execution_context=execution_context,
|
||||
)
|
||||
|
||||
@@ -711,37 +709,40 @@ async def create_graph_execution(
|
||||
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
|
||||
"""
|
||||
result = await AgentGraphExecution.prisma().create(
|
||||
data={
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": graph_version,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"inputs": SafeJson(inputs),
|
||||
"credentialInputs": (
|
||||
SafeJson(credential_inputs) if credential_inputs else Json({})
|
||||
),
|
||||
"nodesInputMasks": (
|
||||
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
|
||||
),
|
||||
"NodeExecutions": {
|
||||
"create": [
|
||||
AgentNodeExecutionCreateInput(
|
||||
agentNodeId=node_id,
|
||||
executionStatus=ExecutionStatus.QUEUED,
|
||||
queuedTime=datetime.now(tz=timezone.utc),
|
||||
Input={
|
||||
"create": [
|
||||
{"name": name, "data": SafeJson(data)}
|
||||
for name, data in node_input.items()
|
||||
]
|
||||
},
|
||||
)
|
||||
for node_id, node_input in starting_nodes_input
|
||||
]
|
||||
data=cast(
|
||||
AgentGraphExecutionCreateInput,
|
||||
{
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": graph_version,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"inputs": SafeJson(inputs),
|
||||
"credentialInputs": (
|
||||
SafeJson(credential_inputs) if credential_inputs else Json({})
|
||||
),
|
||||
"nodesInputMasks": (
|
||||
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
|
||||
),
|
||||
"NodeExecutions": {
|
||||
"create": [
|
||||
AgentNodeExecutionCreateInput(
|
||||
agentNodeId=node_id,
|
||||
executionStatus=ExecutionStatus.QUEUED,
|
||||
queuedTime=datetime.now(tz=timezone.utc),
|
||||
Input={
|
||||
"create": [
|
||||
{"name": name, "data": SafeJson(data)}
|
||||
for name, data in node_input.items()
|
||||
]
|
||||
},
|
||||
)
|
||||
for node_id, node_input in starting_nodes_input
|
||||
]
|
||||
},
|
||||
"userId": user_id,
|
||||
"agentPresetId": preset_id,
|
||||
"parentGraphExecutionId": parent_graph_exec_id,
|
||||
},
|
||||
"userId": user_id,
|
||||
"agentPresetId": preset_id,
|
||||
"parentGraphExecutionId": parent_graph_exec_id,
|
||||
},
|
||||
),
|
||||
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
)
|
||||
|
||||
@@ -833,10 +834,13 @@ async def upsert_execution_output(
|
||||
"""
|
||||
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
|
||||
"""
|
||||
data: AgentNodeExecutionInputOutputCreateInput = {
|
||||
"name": output_name,
|
||||
"referencedByOutputExecId": node_exec_id,
|
||||
}
|
||||
data: AgentNodeExecutionInputOutputCreateInput = cast(
|
||||
AgentNodeExecutionInputOutputCreateInput,
|
||||
{
|
||||
"name": output_name,
|
||||
"referencedByOutputExecId": node_exec_id,
|
||||
},
|
||||
)
|
||||
if output_data is not None:
|
||||
data["data"] = SafeJson(output_data)
|
||||
await AgentNodeExecutionInputOutput.prisma().create(data=data)
|
||||
@@ -976,25 +980,30 @@ async def update_node_execution_status(
f"Invalid status transition: {status} has no valid source statuses"
)

if res := await AgentNodeExecution.prisma().update(
where=cast(
AgentNodeExecutionWhereUniqueInput,
{
"id": node_exec_id,
"executionStatus": {"in": [s.value for s in allowed_from]},
},
),
# First verify the current status allows this transition
current_exec = await AgentNodeExecution.prisma().find_unique(
where={"id": node_exec_id}, include=EXECUTION_RESULT_INCLUDE
)

if not current_exec:
raise ValueError(f"Execution {node_exec_id} not found.")

# Check if current status allows the requested transition
if current_exec.executionStatus not in allowed_from:
# Status transition not allowed, return current state without updating
return NodeExecutionResult.from_db(current_exec)

# Status transition is valid, perform the update
updated_exec = await AgentNodeExecution.prisma().update(
where={"id": node_exec_id},
data=_get_update_status_data(status, execution_data, stats),
include=EXECUTION_RESULT_INCLUDE,
):
return NodeExecutionResult.from_db(res)
)

if res := await AgentNodeExecution.prisma().find_unique(
where={"id": node_exec_id}, include=EXECUTION_RESULT_INCLUDE
):
return NodeExecutionResult.from_db(res)
if not updated_exec:
raise ValueError(f"Failed to update execution {node_exec_id}.")

raise ValueError(f"Execution {node_exec_id} not found.")
return NodeExecutionResult.from_db(updated_exec)


def _get_update_status_data(
@@ -1147,8 +1156,6 @@ class GraphExecutionEntry(BaseModel):
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None
|
||||
nodes_to_skip: set[str] = Field(default_factory=set)
|
||||
"""Node IDs that should be skipped due to optional credentials not being configured."""
|
||||
execution_context: ExecutionContext = Field(default_factory=ExecutionContext)
|
||||
|
||||
|
||||
|
||||
@@ -94,15 +94,6 @@ class Node(BaseDbModel):
input_links: list[Link] = []
output_links: list[Link] = []

@property
def credentials_optional(self) -> bool:
"""
Whether credentials are optional for this node.
When True and credentials are not configured, the node will be skipped
during execution rather than causing a validation error.
"""
return self.metadata.get("credentials_optional", False)

@property
def block(self) -> AnyBlockSchema | "_UnknownBlockBase":
"""Get the block for this node. Returns UnknownBlock if block is deleted/missing."""
@@ -335,35 +326,7 @@ class Graph(BaseGraph):
|
||||
@computed_field
|
||||
@property
|
||||
def credentials_input_schema(self) -> dict[str, Any]:
|
||||
schema = self._credentials_input_schema.jsonschema()
|
||||
|
||||
# Determine which credential fields are required based on credentials_optional metadata
|
||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||
required_fields = []
|
||||
|
||||
# Build a map of node_id -> node for quick lookup
|
||||
all_nodes = {node.id: node for node in self.nodes}
|
||||
for sub_graph in self.sub_graphs:
|
||||
for node in sub_graph.nodes:
|
||||
all_nodes[node.id] = node
|
||||
|
||||
for field_key, (
|
||||
_field_info,
|
||||
node_field_pairs,
|
||||
) in graph_credentials_inputs.items():
|
||||
# A field is required if ANY node using it has credentials_optional=False
|
||||
is_required = False
|
||||
for node_id, _field_name in node_field_pairs:
|
||||
node = all_nodes.get(node_id)
|
||||
if node and not node.credentials_optional:
|
||||
is_required = True
|
||||
break
|
||||
|
||||
if is_required:
|
||||
required_fields.append(field_key)
|
||||
|
||||
schema["required"] = required_fields
|
||||
return schema
|
||||
return self._credentials_input_schema.jsonschema()
|
||||
|
||||
@property
|
||||
def _credentials_input_schema(self) -> type[BlockSchema]:
|
||||
|
||||
@@ -6,14 +6,14 @@ import fastapi.exceptions
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.api.features.store.model as store
|
||||
from backend.api.model import CreateGraph
|
||||
import backend.server.v2.store.model as store
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.data.block import BlockSchema, BlockSchemaInput
|
||||
from backend.data.graph import Graph, Link, Node
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.usecases.sample import create_test_user
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
@@ -396,58 +396,3 @@ async def test_access_store_listing_graph(server: SpinTestServer):
created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b"
)
assert got_graph is not None


# ============================================================================
# Tests for Optional Credentials Feature
# ============================================================================


def test_node_credentials_optional_default():
"""Test that credentials_optional defaults to False when not set in metadata."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={},
)
assert node.credentials_optional is False


def test_node_credentials_optional_true():
"""Test that credentials_optional returns True when explicitly set."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={"credentials_optional": True},
)
assert node.credentials_optional is True


def test_node_credentials_optional_false():
"""Test that credentials_optional returns False when explicitly set to False."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={"credentials_optional": False},
)
assert node.credentials_optional is False


def test_node_credentials_optional_with_other_metadata():
"""Test that credentials_optional works correctly with other metadata present."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={
"position": {"x": 100, "y": 200},
"customized_name": "My Custom Node",
"credentials_optional": True,
},
)
assert node.credentials_optional is True
assert node.metadata["position"] == {"x": 100, "y": 200}
assert node.metadata["customized_name"] == "My Custom Node"
@@ -6,14 +6,14 @@ Handles all database operations for pending human reviews.
import asyncio
import logging
from datetime import datetime, timezone
from typing import Optional
from typing import Optional, cast

from prisma.enums import ReviewStatus
from prisma.models import PendingHumanReview
from prisma.types import PendingHumanReviewUpdateInput
from prisma.types import PendingHumanReviewUpdateInput, PendingHumanReviewUpsertInput
from pydantic import BaseModel

from backend.api.features.executions.review.model import (
from backend.server.v2.executions.review.model import (
PendingHumanReviewModel,
SafeJsonData,
)
@@ -66,20 +66,23 @@ async def get_or_create_human_review(
# Upsert - get existing or create new review
review = await PendingHumanReview.prisma().upsert(
where={"nodeExecId": node_exec_id},
data={
"create": {
"userId": user_id,
"nodeExecId": node_exec_id,
"graphExecId": graph_exec_id,
"graphId": graph_id,
"graphVersion": graph_version,
"payload": SafeJson(input_data),
"instructions": message,
"editable": editable,
"status": ReviewStatus.WAITING,
data=cast(
PendingHumanReviewUpsertInput,
{
"create": {
"userId": user_id,
"nodeExecId": node_exec_id,
"graphExecId": graph_exec_id,
"graphId": graph_id,
"graphVersion": graph_version,
"payload": SafeJson(input_data),
"instructions": message,
"editable": editable,
"status": ReviewStatus.WAITING,
},
"update": {},  # Do nothing on update - keep existing review as is
},
"update": {},  # Do nothing on update - keep existing review as is
},
),
)

logger.info(
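
The change above wraps the literal upsert payload in cast(...) so the type checker accepts it as the generated upsert-input type. A minimal sketch of that typing pattern, using a hypothetical TypedDict rather than the real Prisma-generated types:

from typing import TypedDict, cast

# Hypothetical stand-in for a generated upsert-input type (not the real Prisma type).
class ReviewUpsertInput(TypedDict):
    create: dict
    update: dict

payload = cast(
    ReviewUpsertInput,
    {
        "create": {"status": "WAITING"},
        "update": {},  # leave an existing row untouched
    },
)
# cast() only informs the type checker; at runtime payload is still the same dict.
assert payload["update"] == {}
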
@@ -23,7 +23,7 @@ from backend.util.exceptions import NotFoundError
from backend.util.json import SafeJson

if TYPE_CHECKING:
from backend.api.features.library.model import LibraryAgentPreset
from backend.server.v2.library.model import LibraryAgentPreset

from .db import BaseDbModel
from .graph import NodeModel
@@ -79,7 +79,7 @@ class WebhookWithRelations(Webhook):
# integrations.py → library/model.py → integrations.py (for Webhook)
# Runtime import is used in WebhookWithRelations.from_db() method instead
# Import at runtime to avoid circular dependency
from backend.api.features.library.model import LibraryAgentPreset
from backend.server.v2.library.model import LibraryAgentPreset

return WebhookWithRelations(
**Webhook.from_db(webhook).model_dump(),
@@ -285,8 +285,8 @@ async def unlink_webhook_from_graph(
user_id: The ID of the user (for authorization)
"""
# Avoid circular imports
from backend.api.features.library.db import set_preset_webhook
from backend.data.graph import set_node_webhook
from backend.server.v2.library.db import set_preset_webhook

# Find all nodes in this graph that use this webhook
nodes = await AgentNode.prisma().find_many(

@@ -4,8 +4,8 @@ from typing import AsyncGenerator

from pydantic import BaseModel, field_serializer

from backend.api.model import NotificationPayload
from backend.data.event_bus import AsyncRedisEventBus
from backend.server.model import NotificationPayload
from backend.util.settings import Settings

@@ -1,16 +1,18 @@
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal, Optional, cast
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import prisma
|
||||
import pydantic
|
||||
from prisma.enums import OnboardingStep
|
||||
from prisma.models import UserOnboarding
|
||||
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
|
||||
from prisma.types import (
|
||||
UserOnboardingCreateInput,
|
||||
UserOnboardingUpdateInput,
|
||||
UserOnboardingUpsertInput,
|
||||
)
|
||||
|
||||
from backend.api.features.store.model import StoreAgentDetails
|
||||
from backend.api.model import OnboardingNotificationPayload
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.notification_bus import (
|
||||
@@ -18,6 +20,8 @@ from backend.data.notification_bus import (
|
||||
NotificationEvent,
|
||||
)
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.server.model import OnboardingNotificationPayload
|
||||
from backend.server.v2.store.model import StoreAgentDetails
|
||||
from backend.util.cache import cached
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.timezone_utils import get_user_timezone_or_utc
|
||||
@@ -112,10 +116,13 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, **update},
"update": update,
},
data=cast(
UserOnboardingUpsertInput,
{
"create": {"userId": user_id, **update},
"update": update,
},
),
)


@@ -442,8 +449,6 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
runs=agent.runs,
rating=agent.rating,
versions=agent.versions,
agentGraphVersions=agent.agentGraphVersions,
agentGraphId=agent.agentGraphId,
last_updated=agent.updated_at,
)
for agent in recommended_agents
@@ -1,429 +0,0 @@
|
||||
"""Data models and access layer for user business understanding."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import pydantic
|
||||
from prisma.models import UserBusinessUnderstanding
|
||||
from prisma.types import (
|
||||
UserBusinessUnderstandingCreateInput,
|
||||
UserBusinessUnderstandingUpdateInput,
|
||||
)
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache configuration
|
||||
CACHE_KEY_PREFIX = "understanding"
|
||||
CACHE_TTL_SECONDS = 48 * 60 * 60 # 48 hours
|
||||
|
||||
|
||||
def _cache_key(user_id: str) -> str:
|
||||
"""Generate cache key for user business understanding."""
|
||||
return f"{CACHE_KEY_PREFIX}:{user_id}"
|
||||
|
||||
|
||||
def _json_to_list(value: Any) -> list[str]:
|
||||
"""Convert Json field to list[str], handling None."""
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return cast(list[str], value)
|
||||
return []
|
||||
|
||||
|
||||
class BusinessUnderstandingInput(pydantic.BaseModel):
|
||||
"""Input model for updating business understanding - all fields optional for incremental updates."""
|
||||
|
||||
# User info
|
||||
user_name: Optional[str] = pydantic.Field(None, description="The user's name")
|
||||
job_title: Optional[str] = pydantic.Field(None, description="The user's job title")
|
||||
|
||||
# Business basics
|
||||
business_name: Optional[str] = pydantic.Field(
|
||||
None, description="Name of the user's business"
|
||||
)
|
||||
industry: Optional[str] = pydantic.Field(None, description="Industry or sector")
|
||||
business_size: Optional[str] = pydantic.Field(
|
||||
None, description="Company size (e.g., '1-10', '11-50')"
|
||||
)
|
||||
user_role: Optional[str] = pydantic.Field(
|
||||
None,
|
||||
description="User's role in the organization (e.g., 'decision maker', 'implementer')",
|
||||
)
|
||||
|
||||
# Processes & activities
|
||||
key_workflows: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Key business workflows"
|
||||
)
|
||||
daily_activities: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Daily activities performed"
|
||||
)
|
||||
|
||||
# Pain points & goals
|
||||
pain_points: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Current pain points"
|
||||
)
|
||||
bottlenecks: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Process bottlenecks"
|
||||
)
|
||||
manual_tasks: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Manual/repetitive tasks"
|
||||
)
|
||||
automation_goals: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Desired automation goals"
|
||||
)
|
||||
|
||||
# Current tools
|
||||
current_software: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Software/tools currently used"
|
||||
)
|
||||
existing_automation: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Existing automations"
|
||||
)
|
||||
|
||||
# Additional context
|
||||
additional_notes: Optional[str] = pydantic.Field(
|
||||
None, description="Any additional context"
|
||||
)
|
||||
|
||||
|
||||
class BusinessUnderstanding(pydantic.BaseModel):
|
||||
"""Full business understanding model returned from database."""
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# User info
|
||||
user_name: Optional[str] = None
|
||||
job_title: Optional[str] = None
|
||||
|
||||
# Business basics
|
||||
business_name: Optional[str] = None
|
||||
industry: Optional[str] = None
|
||||
business_size: Optional[str] = None
|
||||
user_role: Optional[str] = None
|
||||
|
||||
# Processes & activities
|
||||
key_workflows: list[str] = pydantic.Field(default_factory=list)
|
||||
daily_activities: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
# Pain points & goals
|
||||
pain_points: list[str] = pydantic.Field(default_factory=list)
|
||||
bottlenecks: list[str] = pydantic.Field(default_factory=list)
|
||||
manual_tasks: list[str] = pydantic.Field(default_factory=list)
|
||||
automation_goals: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
# Current tools
|
||||
current_software: list[str] = pydantic.Field(default_factory=list)
|
||||
existing_automation: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
# Additional context
|
||||
additional_notes: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_record: UserBusinessUnderstanding) -> "BusinessUnderstanding":
|
||||
"""Convert database record to Pydantic model."""
|
||||
return cls(
|
||||
id=db_record.id,
|
||||
user_id=db_record.userId,
|
||||
created_at=db_record.createdAt,
|
||||
updated_at=db_record.updatedAt,
|
||||
user_name=db_record.userName,
|
||||
job_title=db_record.jobTitle,
|
||||
business_name=db_record.businessName,
|
||||
industry=db_record.industry,
|
||||
business_size=db_record.businessSize,
|
||||
user_role=db_record.userRole,
|
||||
key_workflows=_json_to_list(db_record.keyWorkflows),
|
||||
daily_activities=_json_to_list(db_record.dailyActivities),
|
||||
pain_points=_json_to_list(db_record.painPoints),
|
||||
bottlenecks=_json_to_list(db_record.bottlenecks),
|
||||
manual_tasks=_json_to_list(db_record.manualTasks),
|
||||
automation_goals=_json_to_list(db_record.automationGoals),
|
||||
current_software=_json_to_list(db_record.currentSoftware),
|
||||
existing_automation=_json_to_list(db_record.existingAutomation),
|
||||
additional_notes=db_record.additionalNotes,
|
||||
)
|
||||
|
||||
|
||||
def _merge_lists(existing: list | None, new: list | None) -> list | None:
"""Merge two lists, removing duplicates while preserving order."""
if new is None:
return existing
if existing is None:
return new
# Preserve order, add new items that don't exist
merged = list(existing)
for item in new:
if item not in merged:
merged.append(item)
return merged

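
To make the incremental-merge strategy described in this module concrete (scalar fields overwrite only when a new value is provided, list fields merge and dedupe), here is a hypothetical, self-contained illustration in plain Python, not the module's own code:

def merge_understanding(existing: dict, incoming: dict) -> dict:
    # Sketch of the described strategy; field names below are made up for the example.
    merged = dict(existing)
    for key, value in incoming.items():
        if value is None:
            continue  # "not provided" leaves the stored value alone
        if isinstance(value, list) and isinstance(existing.get(key), list):
            merged[key] = existing[key] + [v for v in value if v not in existing[key]]
        else:
            merged[key] = value
    return merged

state = {"industry": "retail", "pain_points": ["manual reporting"]}
update = {"industry": None, "pain_points": ["manual reporting", "slow invoicing"]}
print(merge_understanding(state, update))
# {'industry': 'retail', 'pain_points': ['manual reporting', 'slow invoicing']}
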
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
|
||||
"""Get business understanding from Redis cache."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
cached_data = await redis.get(_cache_key(user_id))
|
||||
if cached_data:
|
||||
return BusinessUnderstanding.model_validate_json(cached_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get understanding from cache: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _set_cache(user_id: str, understanding: BusinessUnderstanding) -> None:
|
||||
"""Set business understanding in Redis cache with TTL."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.setex(
|
||||
_cache_key(user_id),
|
||||
CACHE_TTL_SECONDS,
|
||||
understanding.model_dump_json(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to set understanding in cache: {e}")
|
||||
|
||||
|
||||
async def _delete_cache(user_id: str) -> None:
|
||||
"""Delete business understanding from Redis cache."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(_cache_key(user_id))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete understanding from cache: {e}")
|
||||
|
||||
|
||||
async def get_business_understanding(
|
||||
user_id: str,
|
||||
) -> Optional[BusinessUnderstanding]:
|
||||
"""Get the business understanding for a user.
|
||||
|
||||
Checks cache first, falls back to database if not cached.
|
||||
Results are cached for 48 hours.
|
||||
"""
|
||||
# Try cache first
|
||||
cached = await _get_from_cache(user_id)
|
||||
if cached:
|
||||
logger.debug(f"Business understanding cache hit for user {user_id}")
|
||||
return cached
|
||||
|
||||
# Cache miss - load from database
|
||||
logger.debug(f"Business understanding cache miss for user {user_id}")
|
||||
record = await UserBusinessUnderstanding.prisma().find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
if record is None:
|
||||
return None
|
||||
|
||||
understanding = BusinessUnderstanding.from_db(record)
|
||||
|
||||
# Store in cache for next time
|
||||
await _set_cache(user_id, understanding)
|
||||
|
||||
return understanding
|
||||
|
||||
|
||||
async def upsert_business_understanding(
|
||||
user_id: str,
|
||||
data: BusinessUnderstandingInput,
|
||||
) -> BusinessUnderstanding:
|
||||
"""
|
||||
Create or update business understanding with incremental merge strategy.
|
||||
|
||||
- String fields: new value overwrites if provided (not None)
|
||||
- List fields: new items are appended to existing (deduplicated)
|
||||
"""
|
||||
# Get existing record for merge
|
||||
existing = await UserBusinessUnderstanding.prisma().find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
|
||||
# Build update data with merge strategy
|
||||
update_data: UserBusinessUnderstandingUpdateInput = {}
|
||||
create_data: dict[str, Any] = {"userId": user_id}
|
||||
|
||||
# String fields - overwrite if provided
|
||||
if data.user_name is not None:
|
||||
update_data["userName"] = data.user_name
|
||||
create_data["userName"] = data.user_name
|
||||
if data.job_title is not None:
|
||||
update_data["jobTitle"] = data.job_title
|
||||
create_data["jobTitle"] = data.job_title
|
||||
if data.business_name is not None:
|
||||
update_data["businessName"] = data.business_name
|
||||
create_data["businessName"] = data.business_name
|
||||
if data.industry is not None:
|
||||
update_data["industry"] = data.industry
|
||||
create_data["industry"] = data.industry
|
||||
if data.business_size is not None:
|
||||
update_data["businessSize"] = data.business_size
|
||||
create_data["businessSize"] = data.business_size
|
||||
if data.user_role is not None:
|
||||
update_data["userRole"] = data.user_role
|
||||
create_data["userRole"] = data.user_role
|
||||
if data.additional_notes is not None:
|
||||
update_data["additionalNotes"] = data.additional_notes
|
||||
create_data["additionalNotes"] = data.additional_notes
|
||||
|
||||
# List fields - merge with existing
|
||||
if data.key_workflows is not None:
|
||||
existing_list = _json_to_list(existing.keyWorkflows) if existing else None
|
||||
merged = _merge_lists(existing_list, data.key_workflows)
|
||||
update_data["keyWorkflows"] = SafeJson(merged)
|
||||
create_data["keyWorkflows"] = SafeJson(merged)
|
||||
|
||||
if data.daily_activities is not None:
|
||||
existing_list = _json_to_list(existing.dailyActivities) if existing else None
|
||||
merged = _merge_lists(existing_list, data.daily_activities)
|
||||
update_data["dailyActivities"] = SafeJson(merged)
|
||||
create_data["dailyActivities"] = SafeJson(merged)
|
||||
|
||||
if data.pain_points is not None:
|
||||
existing_list = _json_to_list(existing.painPoints) if existing else None
|
||||
merged = _merge_lists(existing_list, data.pain_points)
|
||||
update_data["painPoints"] = SafeJson(merged)
|
||||
create_data["painPoints"] = SafeJson(merged)
|
||||
|
||||
if data.bottlenecks is not None:
|
||||
existing_list = _json_to_list(existing.bottlenecks) if existing else None
|
||||
merged = _merge_lists(existing_list, data.bottlenecks)
|
||||
update_data["bottlenecks"] = SafeJson(merged)
|
||||
create_data["bottlenecks"] = SafeJson(merged)
|
||||
|
||||
if data.manual_tasks is not None:
|
||||
existing_list = _json_to_list(existing.manualTasks) if existing else None
|
||||
merged = _merge_lists(existing_list, data.manual_tasks)
|
||||
update_data["manualTasks"] = SafeJson(merged)
|
||||
create_data["manualTasks"] = SafeJson(merged)
|
||||
|
||||
if data.automation_goals is not None:
|
||||
existing_list = _json_to_list(existing.automationGoals) if existing else None
|
||||
merged = _merge_lists(existing_list, data.automation_goals)
|
||||
update_data["automationGoals"] = SafeJson(merged)
|
||||
create_data["automationGoals"] = SafeJson(merged)
|
||||
|
||||
if data.current_software is not None:
|
||||
existing_list = _json_to_list(existing.currentSoftware) if existing else None
|
||||
merged = _merge_lists(existing_list, data.current_software)
|
||||
update_data["currentSoftware"] = SafeJson(merged)
|
||||
create_data["currentSoftware"] = SafeJson(merged)
|
||||
|
||||
if data.existing_automation is not None:
|
||||
existing_list = _json_to_list(existing.existingAutomation) if existing else None
|
||||
merged = _merge_lists(existing_list, data.existing_automation)
|
||||
update_data["existingAutomation"] = SafeJson(merged)
|
||||
create_data["existingAutomation"] = SafeJson(merged)
|
||||
|
||||
# Upsert
|
||||
record = await UserBusinessUnderstanding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": UserBusinessUnderstandingCreateInput(**create_data),
|
||||
"update": update_data,
|
||||
},
|
||||
)
|
||||
|
||||
understanding = BusinessUnderstanding.from_db(record)
|
||||
|
||||
# Update cache with new understanding
|
||||
await _set_cache(user_id, understanding)
|
||||
|
||||
return understanding
|
||||
|
||||
|
||||
async def clear_business_understanding(user_id: str) -> bool:
|
||||
"""Clear/delete business understanding for a user from both DB and cache."""
|
||||
# Delete from cache first
|
||||
await _delete_cache(user_id)
|
||||
|
||||
try:
|
||||
await UserBusinessUnderstanding.prisma().delete(where={"userId": user_id})
|
||||
return True
|
||||
except Exception:
|
||||
# Record might not exist
|
||||
return False
|
||||
|
||||
|
||||
def format_understanding_for_prompt(understanding: BusinessUnderstanding) -> str:
|
||||
"""Format business understanding as text for system prompt injection."""
|
||||
sections = []
|
||||
|
||||
# User info section
|
||||
user_info = []
|
||||
if understanding.user_name:
|
||||
user_info.append(f"Name: {understanding.user_name}")
|
||||
if understanding.job_title:
|
||||
user_info.append(f"Job Title: {understanding.job_title}")
|
||||
if user_info:
|
||||
sections.append("## User\n" + "\n".join(user_info))
|
||||
|
||||
# Business section
|
||||
business_info = []
|
||||
if understanding.business_name:
|
||||
business_info.append(f"Company: {understanding.business_name}")
|
||||
if understanding.industry:
|
||||
business_info.append(f"Industry: {understanding.industry}")
|
||||
if understanding.business_size:
|
||||
business_info.append(f"Size: {understanding.business_size}")
|
||||
if understanding.user_role:
|
||||
business_info.append(f"Role Context: {understanding.user_role}")
|
||||
if business_info:
|
||||
sections.append("## Business\n" + "\n".join(business_info))
|
||||
|
||||
# Processes section
|
||||
processes = []
|
||||
if understanding.key_workflows:
|
||||
processes.append(f"Key Workflows: {', '.join(understanding.key_workflows)}")
|
||||
if understanding.daily_activities:
|
||||
processes.append(
|
||||
f"Daily Activities: {', '.join(understanding.daily_activities)}"
|
||||
)
|
||||
if processes:
|
||||
sections.append("## Processes\n" + "\n".join(processes))
|
||||
|
||||
# Pain points section
|
||||
pain_points = []
|
||||
if understanding.pain_points:
|
||||
pain_points.append(f"Pain Points: {', '.join(understanding.pain_points)}")
|
||||
if understanding.bottlenecks:
|
||||
pain_points.append(f"Bottlenecks: {', '.join(understanding.bottlenecks)}")
|
||||
if understanding.manual_tasks:
|
||||
pain_points.append(f"Manual Tasks: {', '.join(understanding.manual_tasks)}")
|
||||
if pain_points:
|
||||
sections.append("## Pain Points\n" + "\n".join(pain_points))
|
||||
|
||||
# Goals section
|
||||
if understanding.automation_goals:
|
||||
sections.append(
|
||||
"## Automation Goals\n"
|
||||
+ "\n".join(f"- {goal}" for goal in understanding.automation_goals)
|
||||
)
|
||||
|
||||
# Current tools section
|
||||
tools_info = []
|
||||
if understanding.current_software:
|
||||
tools_info.append(
|
||||
f"Current Software: {', '.join(understanding.current_software)}"
|
||||
)
|
||||
if understanding.existing_automation:
|
||||
tools_info.append(
|
||||
f"Existing Automation: {', '.join(understanding.existing_automation)}"
|
||||
)
|
||||
if tools_info:
|
||||
sections.append("## Current Tools\n" + "\n".join(tools_info))
|
||||
|
||||
# Additional notes
|
||||
if understanding.additional_notes:
|
||||
sections.append(f"## Additional Context\n{understanding.additional_notes}")
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
return "# User Business Context\n\n" + "\n\n".join(sections)
|
||||
@@ -2,11 +2,6 @@ import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cast
|
||||
|
||||
from backend.api.features.library.db import (
|
||||
add_store_agent_to_library,
|
||||
list_library_agents,
|
||||
)
|
||||
from backend.api.features.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.data import db
|
||||
from backend.data.analytics import (
|
||||
get_accuracy_trends_and_alerts,
|
||||
@@ -66,6 +61,8 @@ from backend.data.user import (
|
||||
get_user_notification_preference,
|
||||
update_user_integrations,
|
||||
)
|
||||
from backend.server.v2.library.db import add_store_agent_to_library, list_library_agents
|
||||
from backend.server.v2.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
|
||||
@@ -48,8 +48,27 @@ from backend.data.notifications import (
|
||||
ZeroBalanceData,
|
||||
)
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.activity_status_generator import (
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
from backend.executor.utils import (
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_EXCHANGE,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_ROUTING_KEY,
|
||||
CancelExecutionEvent,
|
||||
ExecutionOutputEntry,
|
||||
LogMetadata,
|
||||
NodeExecutionProgress,
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
execution_usage_cost,
|
||||
validate_exec,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.server.v2.AutoMod.manager import automod_manager
|
||||
from backend.util import json
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
@@ -76,24 +95,7 @@ from backend.util.retry import (
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .activity_status_generator import generate_activity_status_for_execution
|
||||
from .automod.manager import automod_manager
|
||||
from .cluster_lock import ClusterLock
|
||||
from .utils import (
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_EXCHANGE,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_ROUTING_KEY,
|
||||
CancelExecutionEvent,
|
||||
ExecutionOutputEntry,
|
||||
LogMetadata,
|
||||
NodeExecutionProgress,
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
execution_usage_cost,
|
||||
validate_exec,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
@@ -114,40 +116,6 @@ utilization_gauge = Gauge(
"Ratio of active graph runs to max graph workers",
)

# Redis key prefix for tracking insufficient funds Discord notifications.
# We only send one notification per user per agent until they top up credits.
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
# TTL for the notification flag (30 days) - acts as a fallback cleanup
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60


async def clear_insufficient_funds_notifications(user_id: str) -> int:
"""
Clear all insufficient funds notification flags for a user.

This should be called when a user tops up their credits, allowing
Discord notifications to be sent again if they run out of funds.

Args:
user_id: The user ID to clear notifications for.

Returns:
The number of keys that were deleted.
"""
try:
redis_client = await redis.get_redis_async()
pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
keys = [key async for key in redis_client.scan_iter(match=pattern)]
if keys:
return await redis_client.delete(*keys)
return 0
except Exception as e:
logger.warning(
f"Failed to clear insufficient funds notification flags for user "
f"{user_id}: {e}"
)
return 0


# Thread-local storage for ExecutionProcessor instances
_tls = threading.local()
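
The flag lifecycle around this helper (set once per user/agent pair on failure, swept on top-up) follows a common Redis dedup pattern. A minimal sketch with the synchronous redis-py client; it assumes a local Redis instance and uses hypothetical key names:

import redis  # redis-py; assumes a reachable local Redis for the sketch

r = redis.Redis()
key = "insufficient_funds_discord_notified:user-1:graph-1"

# First failure: SET NX succeeds, so a notification would be sent.
first = r.set(key, "1", nx=True, ex=30 * 24 * 60 * 60)
# Second failure before top-up: SET NX returns None, so the alert is suppressed.
second = r.set(key, "1", nx=True, ex=30 * 24 * 60 * 60)
print(bool(first), bool(second))  # True False

# On top-up, sweep every flag for the user so future failures alert again.
keys = list(r.scan_iter(match="insufficient_funds_discord_notified:user-1:*"))
if keys:
    r.delete(*keys)
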
@@ -178,7 +146,6 @@ async def execute_node(
|
||||
execution_processor: "ExecutionProcessor",
|
||||
execution_stats: NodeExecutionStats | None = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Execute a node in the graph. This will trigger a block execution on a node,
|
||||
@@ -246,7 +213,6 @@ async def execute_node(
|
||||
"user_id": user_id,
|
||||
"execution_context": execution_context,
|
||||
"execution_processor": execution_processor,
|
||||
"nodes_to_skip": nodes_to_skip or set(),
|
||||
}
|
||||
|
||||
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
|
||||
@@ -544,7 +510,6 @@ class ExecutionProcessor:
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
nodes_input_masks: Optional[NodesInputMasks],
|
||||
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
) -> NodeExecutionStats:
|
||||
log_metadata = LogMetadata(
|
||||
logger=_logger,
|
||||
@@ -567,7 +532,6 @@ class ExecutionProcessor:
|
||||
db_client=db_client,
|
||||
log_metadata=log_metadata,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip,
|
||||
)
|
||||
if isinstance(status, BaseException):
|
||||
raise status
|
||||
@@ -613,7 +577,6 @@ class ExecutionProcessor:
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
) -> ExecutionStatus:
|
||||
status = ExecutionStatus.RUNNING
|
||||
|
||||
@@ -650,7 +613,6 @@ class ExecutionProcessor:
|
||||
execution_processor=self,
|
||||
execution_stats=stats,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip,
|
||||
):
|
||||
await persist_output(output_name, output_data)
|
||||
|
||||
@@ -962,21 +924,6 @@ class ExecutionProcessor:
queued_node_exec = execution_queue.get()

# Check if this node should be skipped due to optional credentials
if queued_node_exec.node_id in graph_exec.nodes_to_skip:
log_metadata.info(
f"Skipping node execution {queued_node_exec.node_exec_id} "
f"for node {queued_node_exec.node_id} - optional credentials not configured"
)
# Mark the node as completed without executing
# No outputs will be produced, so downstream nodes won't trigger
update_node_execution_status(
db_client=db_client,
exec_id=queued_node_exec.node_exec_id,
status=ExecutionStatus.COMPLETED,
)
continue

log_metadata.debug(
f"Dispatching node execution {queued_node_exec.node_exec_id} "
f"for node {queued_node_exec.node_id}",
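
To make the skip semantics concrete, a toy dispatch loop (hypothetical names, not the executor's API) that marks skipped nodes completed and, because they emit no outputs, never enqueues their downstream work:

from collections import deque

nodes_to_skip = {"fetch_github"}  # e.g. its optional credentials were never configured
downstream = {"start": ["fetch_github"], "fetch_github": ["summarize"]}
statuses: dict[str, str] = {}
outputs: dict[str, dict] = {}

queue = deque(["start"])
while queue:
    node_id = queue.popleft()
    if node_id in nodes_to_skip:
        # Marked completed without executing; no outputs, so downstream never triggers.
        statuses[node_id] = "COMPLETED"
        continue
    outputs[node_id] = {"result": f"ran {node_id}"}
    statuses[node_id] = "COMPLETED"
    queue.extend(downstream.get(node_id, []))

print(statuses)  # both 'start' and 'fetch_github' end COMPLETED
print(outputs)   # only 'start' produced output; 'summarize' was never enqueued
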
@@ -1037,7 +984,6 @@ class ExecutionProcessor:
|
||||
execution_stats,
|
||||
execution_stats_lock,
|
||||
),
|
||||
nodes_to_skip=graph_exec.nodes_to_skip,
|
||||
),
|
||||
self.node_execution_loop,
|
||||
)
|
||||
@@ -1317,40 +1263,12 @@ class ExecutionProcessor:
|
||||
graph_id: str,
|
||||
e: InsufficientBalanceError,
|
||||
):
|
||||
# Check if we've already sent a notification for this user+agent combo.
|
||||
# We only send one notification per user per agent until they top up credits.
|
||||
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
|
||||
try:
|
||||
redis_client = redis.get_redis()
|
||||
# SET NX returns True only if the key was newly set (didn't exist)
|
||||
is_new_notification = redis_client.set(
|
||||
redis_key,
|
||||
"1",
|
||||
nx=True,
|
||||
ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
|
||||
)
|
||||
if not is_new_notification:
|
||||
# Already notified for this user+agent, skip all notifications
|
||||
logger.debug(
|
||||
f"Skipping duplicate insufficient funds notification for "
|
||||
f"user={user_id}, graph={graph_id}"
|
||||
)
|
||||
return
|
||||
except Exception as redis_error:
|
||||
# If Redis fails, log and continue to send the notification
|
||||
# (better to occasionally duplicate than to never notify)
|
||||
logger.warning(
|
||||
f"Failed to check/set insufficient funds notification flag in Redis: "
|
||||
f"{redis_error}"
|
||||
)
|
||||
|
||||
shortfall = abs(e.amount) - e.balance
|
||||
metadata = db_client.get_graph_metadata(graph_id)
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
|
||||
# Queue user email notification
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=user_id,
|
||||
@@ -1364,7 +1282,6 @@ class ExecutionProcessor:
|
||||
)
|
||||
)
|
||||
|
||||
# Send Discord system alert
|
||||
try:
|
||||
user_email = db_client.get_user_email_by_id(user_id)
|
||||
|
||||
|
||||
@@ -1,560 +0,0 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import NotificationType
|
||||
|
||||
from backend.data.notifications import ZeroBalanceData
|
||||
from backend.executor.manager import (
|
||||
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX,
|
||||
ExecutionProcessor,
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
async def async_iter(items):
|
||||
"""Helper to create an async iterator from a list."""
|
||||
for item in items:
|
||||
yield item
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_sends_discord_alert_first_time(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that the first insufficient funds notification sends a Discord alert."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72, # $0.72
|
||||
amount=-714, # Attempting to spend $7.14
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to simulate first-time notification (set returns True)
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.return_value = True # Key was newly set
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify notification was queued
|
||||
mock_queue_notif.assert_called_once()
|
||||
notification_call = mock_queue_notif.call_args[0][0]
|
||||
assert notification_call.type == NotificationType.ZERO_BALANCE
|
||||
assert notification_call.user_id == user_id
|
||||
assert isinstance(notification_call.data, ZeroBalanceData)
|
||||
assert notification_call.data.current_balance == 72
|
||||
|
||||
# Verify Redis was checked with correct key pattern
|
||||
expected_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
|
||||
mock_redis_client.set.assert_called_once()
|
||||
call_args = mock_redis_client.set.call_args
|
||||
assert call_args[0][0] == expected_key
|
||||
assert call_args[1]["nx"] is True
|
||||
|
||||
# Verify Discord alert was sent
|
||||
mock_client.discord_system_alert.assert_called_once()
|
||||
discord_message = mock_client.discord_system_alert.call_args[0][0]
|
||||
assert "Insufficient Funds Alert" in discord_message
|
||||
assert "test@example.com" in discord_message
|
||||
assert "Test Agent" in discord_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_skips_duplicate_notifications(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that duplicate insufficient funds notifications skip both email and Discord."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to simulate duplicate notification (set returns False/None)
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.return_value = None # Key already existed
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
mock_db_client.get_graph_metadata.return_value = MagicMock(name="Test Agent")
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify email notification was NOT queued (deduplication worked)
|
||||
mock_queue_notif.assert_not_called()
|
||||
|
||||
# Verify Discord alert was NOT sent (deduplication worked)
|
||||
mock_client.discord_system_alert.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that different agents for the same user get separate Discord alerts."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id_1 = "test-graph-111"
|
||||
graph_id_2 = "test-graph-222"
|
||||
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch("backend.executor.manager.queue_notification"), patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
# Both calls return True (first time for each agent)
|
||||
mock_redis_client.set.return_value = True
|
||||
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# First agent notification
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id_1,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Second agent notification
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id_2,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify Discord alerts were sent for both agents
|
||||
assert mock_client.discord_system_alert.call_count == 2
|
||||
|
||||
# Verify Redis was called with different keys
|
||||
assert mock_redis_client.set.call_count == 2
|
||||
calls = mock_redis_client.set.call_args_list
|
||||
assert (
|
||||
calls[0][0][0]
|
||||
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_1}"
|
||||
)
|
||||
assert (
|
||||
calls[1][0][0]
|
||||
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_2}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications(server: SpinTestServer):
|
||||
"""Test that clearing notifications removes all keys for a user."""
|
||||
|
||||
user_id = "test-user-123"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
# get_redis_async is an async function, so we need AsyncMock for it
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
# Mock scan_iter to return some keys as an async iterator
|
||||
mock_keys = [
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1",
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-2",
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-3",
|
||||
]
|
||||
mock_redis_client.scan_iter.return_value = async_iter(mock_keys)
|
||||
# delete is awaited, so use AsyncMock
|
||||
mock_redis_client.delete = AsyncMock(return_value=3)
|
||||
|
||||
# Clear notifications
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify correct pattern was used
|
||||
expected_pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
mock_redis_client.scan_iter.assert_called_once_with(match=expected_pattern)
|
||||
|
||||
# Verify delete was called with all keys
|
||||
mock_redis_client.delete.assert_called_once_with(*mock_keys)
|
||||
|
||||
# Verify return value
|
||||
assert result == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications_no_keys(server: SpinTestServer):
|
||||
"""Test clearing notifications when there are no keys to clear."""
|
||||
|
||||
user_id = "test-user-no-notifications"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
# get_redis_async is an async function, so we need AsyncMock for it
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
# Mock scan_iter to return no keys as an async iterator
|
||||
mock_redis_client.scan_iter.return_value = async_iter([])
|
||||
|
||||
# Clear notifications
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify delete was not called
|
||||
mock_redis_client.delete.assert_not_called()
|
||||
|
||||
# Verify return value
|
||||
assert result == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications_handles_redis_error(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that clearing notifications handles Redis errors gracefully."""
|
||||
|
||||
user_id = "test-user-redis-error"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
# Mock get_redis_async to raise an error
|
||||
mock_redis_module.get_redis_async = AsyncMock(
|
||||
side_effect=Exception("Redis connection failed")
|
||||
)
|
||||
|
||||
# Clear notifications should not raise, just return 0
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify it returned 0 (graceful failure)
|
||||
assert result == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_continues_on_redis_error(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that both email and Discord notifications are still sent when Redis fails."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to raise an error
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.side_effect = Exception("Redis connection error")
|
||||
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify email notification was still queued despite Redis error
|
||||
mock_queue_notif.assert_called_once()
|
||||
|
||||
# Verify Discord alert was still sent despite Redis error
|
||||
mock_client.discord_system_alert.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_clears_notifications_on_grant(server: SpinTestServer):
|
||||
"""Test that _add_transaction clears notification flags when adding GRANT credits."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-grant-clear"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 1000, "transactionKey": "test-tx-key"}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter(
|
||||
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
|
||||
)
|
||||
mock_redis_client.delete = AsyncMock(return_value=1)
|
||||
|
||||
# Create a concrete instance
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with GRANT type (should clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=500, # Positive amount
|
||||
transaction_type=CreditTransactionType.GRANT,
|
||||
is_active=True, # Active transaction
|
||||
)
|
||||
|
||||
# Verify notification clearing was called
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
mock_redis_client.scan_iter.assert_called_once_with(
|
||||
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_clears_notifications_on_top_up(server: SpinTestServer):
|
||||
"""Test that _add_transaction clears notification flags when adding TOP_UP credits."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-topup-clear"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 2000, "transactionKey": "test-tx-key-2"}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter([])
|
||||
mock_redis_client.delete = AsyncMock(return_value=0)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with TOP_UP type (should clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=1000, # Positive amount
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Verify notification clearing was attempted
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_skips_clearing_for_inactive_transaction(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that _add_transaction does NOT clear notifications for inactive transactions."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-inactive"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 500, "transactionKey": "test-tx-key-3"}]
|
||||
|
||||
# Mock async Redis
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with is_active=False (should NOT clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=500,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
is_active=False, # Inactive - pending Stripe payment
|
||||
)
|
||||
|
||||
# Verify notification clearing was NOT called
|
||||
mock_redis_module.get_redis_async.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_skips_clearing_for_usage_transaction(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that _add_transaction does NOT clear notifications for USAGE transactions."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-usage"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 400, "transactionKey": "test-tx-key-4"}]
|
||||
|
||||
# Mock async Redis
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with USAGE type (spending, should NOT clear)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=-100, # Negative - spending credits
|
||||
transaction_type=CreditTransactionType.USAGE,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Verify notification clearing was NOT called
|
||||
mock_redis_module.get_redis_async.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_enable_transaction_clears_notifications(server: SpinTestServer):
|
||||
"""Test that _enable_transaction clears notification flags when enabling a TOP_UP."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-enable"
|
||||
|
||||
with patch("backend.data.credit.CreditTransaction") as mock_credit_tx, patch(
|
||||
"backend.data.credit.query_raw_with_schema"
|
||||
) as mock_query, patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
# Mock finding the pending transaction
|
||||
mock_transaction = MagicMock()
|
||||
mock_transaction.amount = 1000
|
||||
mock_transaction.type = CreditTransactionType.TOP_UP
|
||||
mock_credit_tx.prisma.return_value.find_first = AsyncMock(
|
||||
return_value=mock_transaction
|
||||
)
|
||||
|
||||
# Mock the query to return updated balance
|
||||
mock_query.return_value = [{"balance": 1500}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter(
|
||||
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
|
||||
)
|
||||
mock_redis_client.delete = AsyncMock(return_value=1)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _enable_transaction (simulates Stripe checkout completion)
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
await credit_model._enable_transaction(
|
||||
transaction_key="cs_test_123",
|
||||
user_id=user_id,
|
||||
metadata=SafeJson({"payment": "completed"}),
|
||||
)
|
||||
|
||||
# Verify notification clearing was called
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
mock_redis_client.scan_iter.assert_called_once_with(
|
||||
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
)
|
||||
@@ -3,16 +3,16 @@ import logging
|
||||
import fastapi.responses
|
||||
import pytest
|
||||
|
||||
import backend.api.features.library.model
|
||||
import backend.api.features.store.model
|
||||
from backend.api.model import CreateGraph
|
||||
from backend.api.rest_api import AgentServer
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.store.model
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.data_manipulation import FindInDictionaryBlock
|
||||
from backend.blocks.io import AgentInputBlock
|
||||
from backend.blocks.maths import CalculatorBlock, Operation
|
||||
from backend.data import execution, graph
|
||||
from backend.data.model import User
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.test import SpinTestServer, wait_execution
|
||||
|
||||
@@ -356,7 +356,7 @@ async def test_execute_preset(server: SpinTestServer):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
# Create preset with initial values
|
||||
preset = backend.api.features.library.model.LibraryAgentPresetCreatable(
|
||||
preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
|
||||
name="Test Preset With Clash",
|
||||
description="Test preset with clashing input values",
|
||||
graph_id=test_graph.id,
|
||||
@@ -444,7 +444,7 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
# Create preset with initial values
|
||||
preset = backend.api.features.library.model.LibraryAgentPresetCreatable(
|
||||
preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
|
||||
name="Test Preset With Clash",
|
||||
description="Test preset with clashing input values",
|
||||
graph_id=test_graph.id,
|
||||
@@ -485,7 +485,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
|
||||
store_submission_request = backend.api.features.store.model.StoreSubmissionRequest(
|
||||
store_submission_request = backend.server.v2.store.model.StoreSubmissionRequest(
|
||||
agent_id=test_graph.id,
|
||||
agent_version=test_graph.version,
|
||||
slug=test_graph.id,
|
||||
@@ -514,7 +514,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
|
||||
admin_user = await create_test_user(alt_user=True)
|
||||
await server.agent_server.test_review_store_listing(
|
||||
backend.api.features.store.model.ReviewSubmissionRequest(
|
||||
backend.server.v2.store.model.ReviewSubmissionRequest(
|
||||
store_listing_version_id=slv_id,
|
||||
is_approved=True,
|
||||
comments="Test comments",
|
||||
@@ -523,7 +523,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
)
|
||||
|
||||
# Add the approved store listing to the admin user's library so they can execute it
|
||||
from backend.api.features.library.db import add_store_agent_to_library
|
||||
from backend.server.v2.library.db import add_store_agent_to_library
|
||||
|
||||
await add_store_agent_to_library(
|
||||
store_listing_version_id=slv_id, user_id=admin_user.id
|
||||
|
||||
@@ -1,7 +1,7 @@
import pytest

from backend.api.model import CreateGraph
from backend.data import db
from backend.server.model import CreateGraph
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.clients import get_scheduler_client
from backend.util.test import SpinTestServer

@@ -239,19 +239,14 @@ async def _validate_node_input_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[dict[str, dict[str, str]], set[str]]:
) -> dict[str, dict[str, str]]:
"""
Checks all credentials for all nodes of the graph and returns structured errors
and a set of nodes that should be skipped due to optional missing credentials.
Checks all credentials for all nodes of the graph and returns structured errors.

Returns:
tuple[
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node,
set[node_id]: Nodes that should be skipped (optional credentials not configured)
]
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node
"""
credential_errors: dict[str, dict[str, str]] = defaultdict(dict)
nodes_to_skip: set[str] = set()

for node in graph.nodes:
block = node.block
@@ -261,46 +256,27 @@ async def _validate_node_input_credentials(
if not credentials_fields:
continue

# Track if any credential field is missing for this node
has_missing_credentials = False

for field_name, credentials_meta_type in credentials_fields.items():
try:
# Check nodes_input_masks first, then input_default
field_value = None
if (
nodes_input_masks
and (node_input_mask := nodes_input_masks.get(node.id))
and field_name in node_input_mask
):
field_value = node_input_mask[field_name]
credentials_meta = credentials_meta_type.model_validate(
node_input_mask[field_name]
)
elif field_name in node.input_default:
# For optional credentials, don't use input_default - treat as missing
# This prevents stale credential IDs from failing validation
if node.credentials_optional:
field_value = None
else:
field_value = node.input_default[field_name]

# Check if credentials are missing (None, empty, or not present)
if field_value is None or (
isinstance(field_value, dict) and not field_value.get("id")
):
has_missing_credentials = True
# If node has credentials_optional flag, mark for skipping instead of error
if node.credentials_optional:
continue  # Don't add error, will be marked for skip after loop
else:
credential_errors[node.id][
field_name
] = "These credentials are required"
continue

credentials_meta = credentials_meta_type.model_validate(field_value)

credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
else:
# Missing credentials
credential_errors[node.id][
field_name
] = "These credentials are required"
continue
except ValidationError as e:
# Validation error means credentials were provided but invalid
# This should always be an error, even if optional
credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
continue

@@ -311,7 +287,6 @@ async def _validate_node_input_credentials(
)
except Exception as e:
# Handle any errors fetching credentials
# If credentials were explicitly configured but unavailable, it's an error
credential_errors[node.id][
field_name
] = f"Credentials not available: {e}"
@@ -338,19 +313,7 @@ async def _validate_node_input_credentials(
] = "Invalid credentials: type/provider mismatch"
continue

# If node has optional credentials and any are missing, mark for skipping
# But only if there are no other errors for this node
if (
has_missing_credentials
and node.credentials_optional
and node.id not in credential_errors
):
nodes_to_skip.add(node.id)
logger.info(
f"Node #{node.id} will be skipped: optional credentials not configured"
)

return credential_errors, nodes_to_skip
return credential_errors


def make_node_credentials_input_map(
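The hunks above change how `_validate_node_input_credentials` treats blocks whose credentials are marked optional: on one side of the diff, a node with `credentials_optional=True` and no configured credentials is collected into a `nodes_to_skip` set instead of producing a "These credentials are required" error. A minimal, self-contained sketch of that triage rule follows; the `Node` dataclass is a hypothetical simplified stand-in, not the platform's real graph node type.

```python
# Sketch of the skip-versus-error rule described in the diff above (assumed simplification).
from dataclasses import dataclass, field


@dataclass
class Node:
    id: str
    credentials_optional: bool
    # field_name -> configured credentials meta (e.g. {"id": "..."}), possibly absent
    input_default: dict[str, dict] = field(default_factory=dict)


def triage_credentials(
    nodes: list[Node], cred_fields: list[str]
) -> tuple[dict[str, dict[str, str]], set[str]]:
    """Missing + optional -> skip the node; missing + required -> record an error."""
    errors: dict[str, dict[str, str]] = {}
    to_skip: set[str] = set()
    for node in nodes:
        missing = False
        for field_name in cred_fields:
            value = node.input_default.get(field_name)
            if value is None or not value.get("id"):
                missing = True
                if not node.credentials_optional:
                    errors.setdefault(node.id, {})[field_name] = "These credentials are required"
        # Only skip if the node has no other credential errors
        if missing and node.credentials_optional and node.id not in errors:
            to_skip.add(node.id)
    return errors, to_skip
```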
@@ -392,25 +355,21 @@ async def validate_graph_with_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[Mapping[str, Mapping[str, str]], set[str]]:
) -> Mapping[str, Mapping[str, str]]:
"""
Validate graph including credentials and return structured errors per node,
along with a set of nodes that should be skipped due to optional missing credentials.
Validate graph including credentials and return structured errors per node.

Returns:
tuple[
dict[node_id, dict[field_name, error_message]]: Validation errors per node,
set[node_id]: Nodes that should be skipped (optional credentials not configured)
]
dict[node_id, dict[field_name, error_message]]: Validation errors per node
"""
# Get input validation errors
node_input_errors = GraphModel.validate_graph_get_errors(
graph, for_run=True, nodes_input_masks=nodes_input_masks
)

# Get credential input/availability/validation errors and nodes to skip
node_credential_input_errors, nodes_to_skip = (
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
# Get credential input/availability/validation errors
node_credential_input_errors = await _validate_node_input_credentials(
graph, user_id, nodes_input_masks
)

# Merge credential errors with structural errors
@@ -419,7 +378,7 @@ async def validate_graph_with_credentials(
node_input_errors[node_id] = {}
node_input_errors[node_id].update(field_errors)

return node_input_errors, nodes_to_skip
return node_input_errors


async def _construct_starting_node_execution_input(
@@ -427,7 +386,7 @@ async def _construct_starting_node_execution_input(
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> tuple[list[tuple[str, BlockInput]], set[str]]:
|
||||
) -> list[tuple[str, BlockInput]]:
|
||||
"""
|
||||
Validates and prepares the input data for executing a graph.
|
||||
This function checks the graph for starting nodes, validates the input data
|
||||
@@ -441,14 +400,11 @@ async def _construct_starting_node_execution_input(
|
||||
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
|
||||
|
||||
Returns:
|
||||
tuple[
|
||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID
|
||||
and the corresponding input data for that node.
|
||||
set[str]: Node IDs that should be skipped (optional credentials not configured)
|
||||
]
|
||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
|
||||
the corresponding input data for that node.
|
||||
"""
|
||||
# Use new validation function that includes credentials
|
||||
validation_errors, nodes_to_skip = await validate_graph_with_credentials(
|
||||
validation_errors = await validate_graph_with_credentials(
|
||||
graph, user_id, nodes_input_masks
|
||||
)
|
||||
n_error_nodes = len(validation_errors)
|
||||
@@ -489,7 +445,7 @@ async def _construct_starting_node_execution_input(
|
||||
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
|
||||
)
|
||||
|
||||
return nodes_input, nodes_to_skip
|
||||
return nodes_input
|
||||
|
||||
|
||||
async def validate_and_construct_node_execution_input(
|
||||
@@ -500,7 +456,7 @@ async def validate_and_construct_node_execution_input(
|
||||
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
is_sub_graph: bool = False,
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]:
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
|
||||
"""
|
||||
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
|
||||
This centralizes the logic used by both scheduler validation and actual execution.
|
||||
@@ -517,7 +473,6 @@ async def validate_and_construct_node_execution_input(
|
||||
GraphModel: Full graph object for the given `graph_id`.
|
||||
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
|
||||
dict[str, BlockInput]: Node input masks including all passed-in credentials.
|
||||
set[str]: Node IDs that should be skipped (optional credentials not configured).
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the graph is not found.
|
||||
@@ -559,16 +514,14 @@ async def validate_and_construct_node_execution_input(
|
||||
nodes_input_masks or {},
|
||||
)
|
||||
|
||||
starting_nodes_input, nodes_to_skip = (
|
||||
await _construct_starting_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=graph_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
starting_nodes_input = await _construct_starting_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=graph_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
|
||||
return graph, starting_nodes_input, nodes_input_masks, nodes_to_skip
|
||||
return graph, starting_nodes_input, nodes_input_masks
|
||||
|
||||
|
||||
def _merge_nodes_input_masks(
@@ -826,9 +779,6 @@ async def add_graph_execution(

# Use existing execution's compiled input masks
compiled_nodes_input_masks = graph_exec.nodes_input_masks or {}
# For resumed executions, nodes_to_skip was already determined at creation time
# TODO: Consider storing nodes_to_skip in DB if we need to preserve it across resumes
nodes_to_skip: set[str] = set()

logger.info(f"Resuming graph execution #{graph_exec.id} for graph #{graph_id}")
else:
@@ -837,7 +787,7 @@ async def add_graph_execution(
)

# Create new execution
graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip = (
graph, starting_nodes_input, compiled_nodes_input_masks = (
await validate_and_construct_node_execution_input(
graph_id=graph_id,
user_id=user_id,
@@ -886,7 +836,6 @@ async def add_graph_execution(
try:
graph_exec_entry = graph_exec.to_graph_execution_entry(
compiled_nodes_input_masks=compiled_nodes_input_masks,
nodes_to_skip=nodes_to_skip,
execution_context=execution_context,
)
logger.info(f"Publishing execution {graph_exec.id} to execution queue")

@@ -367,13 +367,10 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
)
|
||||
|
||||
# Setup mock returns
|
||||
# The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip)
|
||||
nodes_to_skip: set[str] = set()
|
||||
mock_validate.return_value = (
|
||||
mock_graph,
|
||||
starting_nodes_input,
|
||||
compiled_nodes_input_masks,
|
||||
nodes_to_skip,
|
||||
)
|
||||
mock_prisma.is_connected.return_value = True
|
||||
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
||||
@@ -459,212 +456,3 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
# Both executions should succeed (though they create different objects)
|
||||
assert result1 == mock_graph_exec
|
||||
assert result2 == mock_graph_exec_2
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for Optional Credentials Feature
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_returns_nodes_to_skip(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
Test that _validate_node_input_credentials returns nodes_to_skip set
|
||||
for nodes with credentials_optional=True and missing credentials.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
# Create a mock node with credentials_optional=True
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-optional-creds"
|
||||
mock_node.credentials_optional = True
|
||||
mock_node.input_default = {} # No credentials configured
|
||||
|
||||
# Create a mock block with credentials field
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_credentials_field_type = mocker.MagicMock()
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||
"credentials": mock_credentials_field_type
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
# Create mock graph
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
# Call the function
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user-id",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Node should be in nodes_to_skip, not in errors
|
||||
assert mock_node.id in nodes_to_skip
|
||||
assert mock_node.id not in errors
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_required_missing_creds_error(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
Test that _validate_node_input_credentials returns errors
|
||||
for nodes with credentials_optional=False and missing credentials.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
# Create a mock node with credentials_optional=False (required)
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-required-creds"
|
||||
mock_node.credentials_optional = False
|
||||
mock_node.input_default = {} # No credentials configured
|
||||
|
||||
# Create a mock block with credentials field
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_credentials_field_type = mocker.MagicMock()
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||
"credentials": mock_credentials_field_type
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
# Create mock graph
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
# Call the function
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user-id",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Node should be in errors, not in nodes_to_skip
|
||||
assert mock_node.id in errors
|
||||
assert "credentials" in errors[mock_node.id]
|
||||
assert "required" in errors[mock_node.id]["credentials"].lower()
|
||||
assert mock_node.id not in nodes_to_skip
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_graph_with_credentials_returns_nodes_to_skip(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
Test that validate_graph_with_credentials returns nodes_to_skip set
|
||||
from _validate_node_input_credentials.
|
||||
"""
|
||||
from backend.executor.utils import validate_graph_with_credentials
|
||||
|
||||
# Mock _validate_node_input_credentials to return specific values
|
||||
mock_validate = mocker.patch(
|
||||
"backend.executor.utils._validate_node_input_credentials"
|
||||
)
|
||||
expected_errors = {"node1": {"field": "error"}}
|
||||
expected_nodes_to_skip = {"node2", "node3"}
|
||||
mock_validate.return_value = (expected_errors, expected_nodes_to_skip)
|
||||
|
||||
# Mock GraphModel with validate_graph_get_errors method
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.validate_graph_get_errors.return_value = {}
|
||||
|
||||
# Call the function
|
||||
errors, nodes_to_skip = await validate_graph_with_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user-id",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Verify nodes_to_skip is passed through
|
||||
assert nodes_to_skip == expected_nodes_to_skip
|
||||
assert "node1" in errors
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
"""
|
||||
Test that add_graph_execution properly passes nodes_to_skip
|
||||
to the graph execution entry.
|
||||
"""
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.executor.utils import add_graph_execution
|
||||
|
||||
# Mock data
|
||||
graph_id = "test-graph-id"
|
||||
user_id = "test-user-id"
|
||||
inputs = {"test_input": "test_value"}
|
||||
graph_version = 1
|
||||
|
||||
# Mock the graph object
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.version = graph_version
|
||||
|
||||
# Starting nodes and masks
|
||||
starting_nodes_input = [("node1", {"input1": "value1"})]
|
||||
compiled_nodes_input_masks = {}
|
||||
nodes_to_skip = {"skipped-node-1", "skipped-node-2"}
|
||||
|
||||
# Mock the graph execution object
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.node_executions = []
|
||||
|
||||
# Track what's passed to to_graph_execution_entry
|
||||
captured_kwargs = {}
|
||||
|
||||
def capture_to_entry(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return mocker.MagicMock()
|
||||
|
||||
mock_graph_exec.to_graph_execution_entry.side_effect = capture_to_entry
|
||||
|
||||
# Setup mocks
|
||||
mock_validate = mocker.patch(
|
||||
"backend.executor.utils.validate_and_construct_node_execution_input"
|
||||
)
|
||||
mock_edb = mocker.patch("backend.executor.utils.execution_db")
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_udb = mocker.patch("backend.executor.utils.user_db")
|
||||
mock_gdb = mocker.patch("backend.executor.utils.graph_db")
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
|
||||
# Setup returns - include nodes_to_skip in the tuple
|
||||
mock_validate.return_value = (
|
||||
mock_graph,
|
||||
starting_nodes_input,
|
||||
compiled_nodes_input_masks,
|
||||
nodes_to_skip, # This should be passed through
|
||||
)
|
||||
mock_prisma.is_connected.return_value = True
|
||||
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
||||
mock_edb.update_graph_execution_stats = mocker.AsyncMock(
|
||||
return_value=mock_graph_exec
|
||||
)
|
||||
mock_edb.update_node_execution_status_batch = mocker.AsyncMock()
|
||||
|
||||
mock_user = mocker.MagicMock()
|
||||
mock_user.timezone = "UTC"
|
||||
mock_settings = mocker.MagicMock()
|
||||
mock_settings.human_in_the_loop_safe_mode = True
|
||||
|
||||
mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
|
||||
mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings)
|
||||
mock_get_queue.return_value = mocker.AsyncMock()
|
||||
mock_get_event_bus.return_value = mocker.MagicMock(publish=mocker.AsyncMock())
|
||||
|
||||
# Call the function
|
||||
await add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
graph_version=graph_version,
|
||||
)
|
||||
|
||||
# Verify nodes_to_skip was passed to to_graph_execution_entry
|
||||
assert "nodes_to_skip" in captured_kwargs
|
||||
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip
|
||||
|
||||
@@ -149,10 +149,10 @@ async def setup_webhook_for_block(
async def migrate_legacy_triggered_graphs():
from prisma.models import AgentGraph

from backend.api.features.library.db import create_preset
from backend.api.features.library.model import LibraryAgentPresetCreatable
from backend.data.graph import AGENT_GRAPH_INCLUDE, GraphModel, set_node_webhook
from backend.data.model import is_credentials_field_name
from backend.server.v2.library.db import create_preset
from backend.server.v2.library.model import LibraryAgentPresetCreatable

triggered_graphs = [
GraphModel.from_db(_graph)

@@ -49,11 +49,10 @@
</p>
<ol style="margin-bottom: 10px;">
<li>
Visit the Supabase Dashboard:
https://supabase.com/dashboard/project/bgwpwdsxblryihinutbx/editor
Connect to the database using your preferred database client.
</li>
<li>
Navigate to the <strong>RefundRequest</strong> table.
Navigate to the <strong>RefundRequest</strong> table in the <strong>platform</strong> schema.
</li>
<li>
Filter the <code>transactionKey</code> column with the Transaction ID: <strong>{{ data.transaction_id }}</strong>.

@@ -1,5 +1,5 @@
|
||||
from backend.api.rest_api import AgentServer
|
||||
from backend.app import run_processes
|
||||
from backend.server.rest_api import AgentServer
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -6,7 +6,7 @@ Usage: from backend.sdk import *

This module provides:
- All block base classes and types
- All credential and authentication components
- All credential and authentication components
- All cost tracking components
- All webhook components
- All utility functions

@@ -1,7 +1,7 @@
"""
Integration between SDK provider costs and the execution cost system.

This module provides the glue between provider-defined base costs and the
This module provides the glue between provider-defined base costs and the
BLOCK_COSTS configuration used by the execution system.
"""

13  autogpt_platform/backend/backend/server/auth/__init__.py  Normal file
@@ -0,0 +1,13 @@
"""
Authentication module for the AutoGPT Platform.

This module provides FastAPI-based authentication supporting:
- Email/password authentication with bcrypt hashing
- Google OAuth authentication
- JWT token management (access + refresh tokens)
"""

from .routes import router as auth_router
from .service import AuthService

__all__ = ["auth_router", "AuthService"]
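The new `backend/server/auth/__init__.py` exports `auth_router` and `AuthService`. The exact wiring inside `backend.server.rest_api` is not part of this diff, so the following is only a hedged sketch of how the router could be mounted on a FastAPI application.

```python
# Hypothetical wiring sketch; the real mounting in backend.server.rest_api is not shown here.
from fastapi import FastAPI

from backend.server.auth import auth_router  # exported by the new __init__.py

app = FastAPI()
app.include_router(auth_router)  # serves /auth/register, /auth/login, /auth/refresh, ...
```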
170
autogpt_platform/backend/backend/server/auth/email.py
Normal file
170
autogpt_platform/backend/backend/server/auth/email.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Direct email sending for authentication flows.
|
||||
|
||||
This module bypasses the notification queue system to ensure auth emails
|
||||
(password reset, email verification) are sent immediately in all environments.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pathlib
|
||||
from typing import Optional
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from postmarker.core import PostmarkClient
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
# Template directory
|
||||
TEMPLATE_DIR = pathlib.Path(__file__).parent / "templates"
|
||||
|
||||
|
||||
class AuthEmailSender:
|
||||
"""Handles direct email sending for authentication flows."""
|
||||
|
||||
def __init__(self):
|
||||
if settings.secrets.postmark_server_api_token:
|
||||
self.postmark = PostmarkClient(
|
||||
server_token=settings.secrets.postmark_server_api_token
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Postmark server API token not found, auth email sending disabled"
|
||||
)
|
||||
self.postmark = None
|
||||
|
||||
# Set up Jinja2 environment for templates
|
||||
self.jinja_env: Optional[Environment] = None
|
||||
if TEMPLATE_DIR.exists():
|
||||
self.jinja_env = Environment(
|
||||
loader=FileSystemLoader(str(TEMPLATE_DIR)),
|
||||
autoescape=True,
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Auth email templates directory not found: {TEMPLATE_DIR}")
|
||||
|
||||
def _get_frontend_url(self) -> str:
|
||||
"""Get the frontend base URL for email links."""
|
||||
return (
|
||||
settings.config.frontend_base_url
|
||||
or settings.config.platform_base_url
|
||||
or "http://localhost:3000"
|
||||
)
|
||||
|
||||
def _render_template(
|
||||
self, template_name: str, subject: str, **context
|
||||
) -> tuple[str, str]:
|
||||
"""Render an email template with the base template wrapper."""
|
||||
if not self.jinja_env:
|
||||
raise RuntimeError("Email templates not available")
|
||||
|
||||
# Render the content template
|
||||
content_template = self.jinja_env.get_template(template_name)
|
||||
content = content_template.render(**context)
|
||||
|
||||
# Render with base template
|
||||
base_template = self.jinja_env.get_template("base.html.jinja2")
|
||||
html_body = base_template.render(
|
||||
data={"title": subject, "message": content, "unsubscribe_link": None}
|
||||
)
|
||||
|
||||
return subject, html_body
|
||||
|
||||
def _send_email(self, to_email: str, subject: str, html_body: str) -> bool:
|
||||
"""Send an email directly via Postmark."""
|
||||
if not self.postmark:
|
||||
logger.warning(
|
||||
f"Postmark not configured. Would send email to {to_email}: {subject}"
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
self.postmark.emails.send( # type: ignore[attr-defined]
|
||||
From=settings.config.postmark_sender_email,
|
||||
To=to_email,
|
||||
Subject=subject,
|
||||
HtmlBody=html_body,
|
||||
)
|
||||
logger.info(f"Auth email sent to {to_email}: {subject}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send auth email to {to_email}: {e}")
|
||||
return False
|
||||
|
||||
def send_password_reset_email(
|
||||
self, to_email: str, reset_token: str, user_name: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send a password reset email.
|
||||
|
||||
Args:
|
||||
to_email: Recipient email address
|
||||
reset_token: The raw password reset token
|
||||
user_name: Optional user name for personalization
|
||||
|
||||
Returns:
|
||||
True if email was sent successfully, False otherwise
|
||||
"""
|
||||
try:
|
||||
frontend_url = self._get_frontend_url()
|
||||
reset_link = f"{frontend_url}/reset-password?token={reset_token}"
|
||||
|
||||
subject, html_body = self._render_template(
|
||||
"password_reset.html.jinja2",
|
||||
subject="Reset Your AutoGPT Password",
|
||||
reset_link=reset_link,
|
||||
user_name=user_name,
|
||||
frontend_url=frontend_url,
|
||||
)
|
||||
|
||||
return self._send_email(to_email, subject, html_body)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send password reset email to {to_email}: {e}")
|
||||
return False
|
||||
|
||||
def send_email_verification(
|
||||
self, to_email: str, verification_token: str, user_name: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send an email verification email.
|
||||
|
||||
Args:
|
||||
to_email: Recipient email address
|
||||
verification_token: The raw verification token
|
||||
user_name: Optional user name for personalization
|
||||
|
||||
Returns:
|
||||
True if email was sent successfully, False otherwise
|
||||
"""
|
||||
try:
|
||||
frontend_url = self._get_frontend_url()
|
||||
verification_link = (
|
||||
f"{frontend_url}/verify-email?token={verification_token}"
|
||||
)
|
||||
|
||||
subject, html_body = self._render_template(
|
||||
"email_verification.html.jinja2",
|
||||
subject="Verify Your AutoGPT Email",
|
||||
verification_link=verification_link,
|
||||
user_name=user_name,
|
||||
frontend_url=frontend_url,
|
||||
)
|
||||
|
||||
return self._send_email(to_email, subject, html_body)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send verification email to {to_email}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_auth_email_sender: Optional[AuthEmailSender] = None
|
||||
|
||||
|
||||
def get_auth_email_sender() -> AuthEmailSender:
|
||||
"""Get or create the auth email sender singleton."""
|
||||
global _auth_email_sender
|
||||
if _auth_email_sender is None:
|
||||
_auth_email_sender = AuthEmailSender()
|
||||
return _auth_email_sender
|
||||
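The `email.py` module above wraps Postmark and Jinja2 templates behind a `get_auth_email_sender()` singleton. A hedged usage sketch follows; the recipient, token, and name values are placeholders, and in the real flow the token comes from `AuthService.create_password_reset_token()`.

```python
# Hypothetical usage of the auth email sender defined above; values are placeholders.
from backend.server.auth.email import get_auth_email_sender

sender = get_auth_email_sender()
# Returns False (and only logs) when Postmark is not configured or sending fails.
sent = sender.send_password_reset_email(
    to_email="user@example.com",
    reset_token="raw-reset-token",
    user_name="Example User",
)
print("password reset email sent:", sent)
```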
505
autogpt_platform/backend/backend/server/auth/routes.py
Normal file
505
autogpt_platform/backend/backend/server/auth/routes.py
Normal file
@@ -0,0 +1,505 @@
|
||||
"""
|
||||
Authentication API routes.
|
||||
|
||||
Provides endpoints for:
|
||||
- User registration and login
|
||||
- Token refresh and logout
|
||||
- Password reset
|
||||
- Email verification
|
||||
- Google OAuth
|
||||
"""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .email import get_auth_email_sender
|
||||
from .service import AuthService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
# Singleton auth service instance
|
||||
_auth_service: Optional[AuthService] = None
|
||||
|
||||
# In-memory state storage for OAuth CSRF protection
|
||||
# Format: {state_token: {"created_at": timestamp, "redirect_uri": optional_uri}}
|
||||
# In production, use Redis for distributed state management
|
||||
_oauth_states: dict[str, dict] = {}
|
||||
_STATE_TTL_SECONDS = 600 # 10 minutes
|
||||
|
||||
|
||||
def _cleanup_expired_states() -> None:
|
||||
"""Remove expired OAuth states."""
|
||||
now = time.time()
|
||||
expired = [
|
||||
k
|
||||
for k, v in _oauth_states.items()
|
||||
if now - v["created_at"] > _STATE_TTL_SECONDS
|
||||
]
|
||||
for k in expired:
|
||||
del _oauth_states[k]
|
||||
|
||||
|
||||
def _generate_state() -> str:
|
||||
"""Generate a cryptographically secure state token."""
|
||||
_cleanup_expired_states()
|
||||
state = secrets.token_urlsafe(32)
|
||||
_oauth_states[state] = {"created_at": time.time()}
|
||||
return state
|
||||
|
||||
|
||||
def _validate_state(state: str) -> bool:
|
||||
"""Validate and consume a state token."""
|
||||
if state not in _oauth_states:
|
||||
return False
|
||||
state_data = _oauth_states.pop(state)
|
||||
if time.time() - state_data["created_at"] > _STATE_TTL_SECONDS:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
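The in-memory `_oauth_states` dict above is explicitly flagged as a single-process stopgap ("In production, use Redis for distributed state management"). A hedged sketch of a Redis-backed equivalent of `_generate_state` / `_validate_state` follows; it uses redis-py's asyncio client directly rather than the platform's own Redis utilities, which are not shown in this diff.

```python
# Hypothetical Redis-backed OAuth state store (assumes redis-py >= 4.2 and Redis >= 6.2 for GETDEL).
import secrets

import redis.asyncio as redis

_STATE_TTL_SECONDS = 600  # keep parity with the in-memory implementation above


async def generate_state(r: redis.Redis) -> str:
    state = secrets.token_urlsafe(32)
    # SETEX handles expiry, replacing the manual _cleanup_expired_states() sweep.
    await r.setex(f"oauth_state:{state}", _STATE_TTL_SECONDS, "1")
    return state


async def validate_state(r: redis.Redis, state: str) -> bool:
    # GETDEL atomically consumes the state so it cannot be replayed.
    return await r.getdel(f"oauth_state:{state}") is not None
```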
def get_auth_service() -> AuthService:
|
||||
"""Get or create the auth service singleton."""
|
||||
global _auth_service
|
||||
if _auth_service is None:
|
||||
_auth_service = AuthService()
|
||||
return _auth_service
|
||||
|
||||
|
||||
# ============= Request/Response Models =============
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
"""Request model for user registration."""
|
||||
|
||||
email: EmailStr
|
||||
password: str = Field(..., min_length=8)
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
"""Request model for user login."""
|
||||
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""Response model for authentication tokens."""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
"""Request model for token refresh."""
|
||||
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
"""Request model for logout."""
|
||||
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class PasswordResetRequest(BaseModel):
|
||||
"""Request model for password reset request."""
|
||||
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class PasswordResetConfirm(BaseModel):
|
||||
"""Request model for password reset confirmation."""
|
||||
|
||||
token: str
|
||||
new_password: str = Field(..., min_length=8)
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
"""Generic message response."""
|
||||
|
||||
message: str
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
"""Response model for user info."""
|
||||
|
||||
id: str
|
||||
email: str
|
||||
name: Optional[str]
|
||||
email_verified: bool
|
||||
role: str
|
||||
|
||||
|
||||
# ============= Auth Endpoints =============
|
||||
|
||||
|
||||
@router.post("/register", response_model=TokenResponse)
|
||||
async def register(request: RegisterRequest, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
Register a new user with email and password.
|
||||
|
||||
Returns access and refresh tokens on successful registration.
|
||||
Sends a verification email in the background.
|
||||
"""
|
||||
auth_service = get_auth_service()
|
||||
|
||||
try:
|
||||
user = await auth_service.register_user(
|
||||
email=request.email,
|
||||
password=request.password,
|
||||
name=request.name,
|
||||
)
|
||||
|
||||
# Create verification token and send email in background
|
||||
# This is non-critical - don't fail registration if email fails
|
||||
try:
|
||||
verification_token = await auth_service.create_email_verification_token(
|
||||
user.id
|
||||
)
|
||||
email_sender = get_auth_email_sender()
|
||||
background_tasks.add_task(
|
||||
email_sender.send_email_verification,
|
||||
to_email=user.email,
|
||||
verification_token=verification_token,
|
||||
user_name=user.name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to queue verification email for {user.email}: {e}")
|
||||
|
||||
tokens = await auth_service.create_tokens(user)
|
||||
return TokenResponse(**tokens)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(request: LoginRequest):
|
||||
"""
|
||||
Login with email and password.
|
||||
|
||||
Returns access and refresh tokens on successful authentication.
|
||||
"""
|
||||
auth_service = get_auth_service()
|
||||
|
||||
user = await auth_service.authenticate_user(request.email, request.password)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Invalid email or password")
|
||||
|
||||
tokens = await auth_service.create_tokens(user)
|
||||
return TokenResponse(**tokens)
|
||||
|
||||
|
||||
@router.post("/logout", response_model=MessageResponse)
|
||||
async def logout(request: LogoutRequest):
|
||||
"""
|
||||
Logout by revoking the refresh token.
|
||||
|
||||
This invalidates the refresh token so it cannot be used to get new access tokens.
|
||||
"""
|
||||
auth_service = get_auth_service()
|
||||
|
||||
revoked = await auth_service.revoke_refresh_token(request.refresh_token)
|
||||
if not revoked:
|
||||
raise HTTPException(status_code=400, detail="Invalid refresh token")
|
||||
|
||||
return MessageResponse(message="Successfully logged out")
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh_tokens(request: RefreshRequest):
|
||||
"""
|
||||
Refresh access token using a refresh token.
|
||||
|
||||
The old refresh token is invalidated and a new one is returned (token rotation).
|
||||
"""
|
||||
auth_service = get_auth_service()
|
||||
|
||||
tokens = await auth_service.refresh_access_token(request.refresh_token)
|
||||
if not tokens:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired refresh token")
|
||||
|
||||
return TokenResponse(**tokens)
|
||||
|
||||
|
||||
@router.post("/password-reset/request", response_model=MessageResponse)
|
||||
async def request_password_reset(
|
||||
request: PasswordResetRequest, background_tasks: BackgroundTasks
|
||||
):
|
||||
"""
|
||||
Request a password reset email.
|
||||
|
||||
Always returns success to prevent email enumeration attacks.
|
||||
If the email exists, a password reset email will be sent.
|
||||
"""
|
||||
auth_service = get_auth_service()
|
||||
|
||||
user = await auth_service.get_user_by_email(request.email)
|
||||
if user:
|
||||
token = await auth_service.create_password_reset_token(user.id)
|
||||
email_sender = get_auth_email_sender()
|
||||
background_tasks.add_task(
|
||||
email_sender.send_password_reset_email,
|
||||
to_email=user.email,
|
||||
reset_token=token,
|
||||
user_name=user.name,
|
||||
)
|
||||
logger.info(f"Password reset email queued for user {user.id}")
|
||||
|
||||
# Always return success to prevent email enumeration
|
||||
return MessageResponse(
|
||||
message="If the email exists, a password reset link has been sent"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/password-reset/confirm", response_model=MessageResponse)
|
||||
async def confirm_password_reset(request: PasswordResetConfirm):
|
||||
"""
|
||||
Reset password using a password reset token.
|
||||
|
||||
All existing sessions (refresh tokens) will be invalidated.
|
||||
"""
|
||||
auth_service = get_auth_service()
|
||||
|
||||
success = await auth_service.reset_password(request.token, request.new_password)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired reset token")
|
||||
|
||||
return MessageResponse(message="Password has been reset successfully")
|
||||
|
||||
|
||||
# ============= Email Verification Endpoints =============
|
||||
|
||||
|
||||
class EmailVerificationRequest(BaseModel):
|
||||
"""Request model for email verification."""
|
||||
|
||||
token: str
|
||||
|
||||
|
||||
class ResendVerificationRequest(BaseModel):
|
||||
"""Request model for resending verification email."""
|
||||
|
||||
email: EmailStr
|
||||
|
||||
|
||||
@router.post("/email/verify", response_model=MessageResponse)
|
||||
async def verify_email(request: EmailVerificationRequest):
|
||||
"""
|
||||
Verify email address using a verification token.
|
||||
|
||||
Marks the user's email as verified if the token is valid.
|
||||
"""
|
||||
auth_service = get_auth_service()
|
||||
|
||||
success = await auth_service.verify_email_token(request.token)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid or expired verification token"
|
||||
)
|
||||
|
||||
return MessageResponse(message="Email verified successfully")
|
||||
|
||||
|
||||
@router.post("/email/resend-verification", response_model=MessageResponse)
|
||||
async def resend_verification_email(
|
||||
request: ResendVerificationRequest, background_tasks: BackgroundTasks
|
||||
):
|
||||
"""
|
||||
Resend email verification email.
|
||||
|
||||
Always returns success to prevent email enumeration attacks.
|
||||
If the email exists and is not verified, a new verification email will be sent.
|
||||
"""
|
||||
auth_service = get_auth_service()
|
||||
|
||||
user = await auth_service.get_user_by_email(request.email)
|
||||
if user and not user.emailVerified:
|
||||
token = await auth_service.create_email_verification_token(user.id)
|
||||
email_sender = get_auth_email_sender()
|
||||
background_tasks.add_task(
|
||||
email_sender.send_email_verification,
|
||||
to_email=user.email,
|
||||
verification_token=token,
|
||||
user_name=user.name,
|
||||
)
|
||||
logger.info(f"Verification email queued for user {user.id}")
|
||||
|
||||
# Always return success to prevent email enumeration
|
||||
return MessageResponse(
|
||||
message="If the email exists and is not verified, a verification link has been sent"
|
||||
)
|
||||
|
||||
|
||||
# ============= Google OAuth Endpoints =============
|
||||
|
||||
# Google userinfo endpoint for fetching user profile
|
||||
GOOGLE_USERINFO_ENDPOINT = "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||
|
||||
|
||||
class GoogleLoginResponse(BaseModel):
|
||||
"""Response model for Google OAuth login initiation."""
|
||||
|
||||
url: str
|
||||
|
||||
|
||||
def _get_google_oauth_handler():
|
||||
"""Get a configured GoogleOAuthHandler instance."""
|
||||
# Lazy import to avoid circular imports
|
||||
from backend.integrations.oauth.google import GoogleOAuthHandler
|
||||
|
||||
settings = Settings()
|
||||
|
||||
client_id = settings.secrets.google_client_id
|
||||
client_secret = settings.secrets.google_client_secret
|
||||
|
||||
if not client_id or not client_secret:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Google OAuth is not configured. Set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET.",
|
||||
)
|
||||
|
||||
# Construct the redirect URI - this should point to the frontend's callback
|
||||
# which will then call our /auth/google/callback endpoint
|
||||
frontend_base_url = settings.config.frontend_base_url or "http://localhost:3000"
|
||||
redirect_uri = f"{frontend_base_url}/auth/callback"
|
||||
|
||||
return GoogleOAuthHandler(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/google/login", response_model=GoogleLoginResponse)
|
||||
async def google_login(request: Request):
|
||||
"""
|
||||
Initiate Google OAuth flow.
|
||||
|
||||
Returns the Google OAuth authorization URL to redirect the user to.
|
||||
"""
|
||||
try:
|
||||
handler = _get_google_oauth_handler()
|
||||
state = _generate_state()
|
||||
|
||||
# Get the authorization URL with default scopes (email, profile, openid)
|
||||
auth_url = handler.get_login_url(
|
||||
scopes=[], # Will use DEFAULT_SCOPES from handler
|
||||
state=state,
|
||||
code_challenge=None, # Not using PKCE for server-side flow
|
||||
)
|
||||
|
||||
logger.info(f"Generated Google OAuth URL for state: {state[:8]}...")
|
||||
return GoogleLoginResponse(url=auth_url)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Google OAuth: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to initiate Google OAuth")
|
||||
|
||||
|
||||
@router.get("/google/callback", response_model=TokenResponse)
|
||||
async def google_callback(request: Request, code: str, state: Optional[str] = None):
|
||||
"""
|
||||
Handle Google OAuth callback.
|
||||
|
||||
Exchanges the authorization code for user info and creates/updates the user.
|
||||
Returns access and refresh tokens.
|
||||
"""
|
||||
# Validate state to prevent CSRF attacks
|
||||
if not state or not _validate_state(state):
|
||||
logger.warning(
|
||||
f"Invalid or missing OAuth state: {state[:8] if state else 'None'}..."
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired OAuth state")
|
||||
|
||||
try:
|
||||
handler = _get_google_oauth_handler()
|
||||
|
||||
# Exchange the authorization code for Google credentials
|
||||
logger.info("Exchanging authorization code for tokens...")
|
||||
google_creds = await handler.exchange_code_for_tokens(
|
||||
code=code,
|
||||
scopes=[], # Will use the scopes from the initial request
|
||||
code_verifier=None,
|
||||
)
|
||||
|
||||
# The handler returns OAuth2Credentials with email in username field
|
||||
email = google_creds.username
|
||||
if not email:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Failed to retrieve email from Google"
|
||||
)
|
||||
|
||||
# Fetch full user info to get Google user ID and name
|
||||
# Lazy import to avoid circular imports
|
||||
from google.auth.transport.requests import AuthorizedSession
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
# We need to create Google Credentials object to use with AuthorizedSession
|
||||
creds = Credentials(
|
||||
token=google_creds.access_token.get_secret_value(),
|
||||
refresh_token=(
|
||||
google_creds.refresh_token.get_secret_value()
|
||||
if google_creds.refresh_token
|
||||
else None
|
||||
),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=handler.client_id,
|
||||
client_secret=handler.client_secret,
|
||||
)
|
||||
|
||||
session = AuthorizedSession(creds)
|
||||
userinfo_response = session.get(GOOGLE_USERINFO_ENDPOINT)
|
||||
|
||||
if not userinfo_response.ok:
|
||||
logger.error(
|
||||
f"Failed to fetch Google userinfo: {userinfo_response.status_code}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Failed to fetch user info from Google"
|
||||
)
|
||||
|
||||
userinfo = userinfo_response.json()
|
||||
google_id = userinfo.get("id")
|
||||
name = userinfo.get("name")
|
||||
email_verified = userinfo.get("verified_email", False)
|
||||
|
||||
if not google_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Failed to retrieve Google user ID"
|
||||
)
|
||||
|
||||
logger.info(f"Google OAuth successful for user: {email}")
|
||||
|
||||
# Create or update the user in our database
|
||||
auth_service = get_auth_service()
|
||||
user = await auth_service.create_or_update_google_user(
|
||||
google_id=google_id,
|
||||
email=email,
|
||||
name=name,
|
||||
email_verified=email_verified,
|
||||
)
|
||||
|
||||
# Generate our JWT tokens
|
||||
tokens = await auth_service.create_tokens(user)
|
||||
|
||||
return TokenResponse(**tokens)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Google OAuth callback failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to complete Google OAuth")
|
||||
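Taken together, the routes above form a register/login/refresh/logout API with rotating refresh tokens. A hedged client-side walkthrough follows; the base URL, prefix, and credentials are placeholders, since the API's mount point is not part of this diff.

```python
# Hypothetical client walkthrough of the endpoints defined above; values and base URL are assumptions.
import httpx

BASE = "http://localhost:8006/api/auth"  # assumed mount point; adjust to the real API prefix

with httpx.Client() as client:
    # Register (also queues a verification email in the background)
    tokens = client.post(f"{BASE}/register", json={
        "email": "user@example.com", "password": "s3cure-pass", "name": "Example User",
    }).json()

    # Authenticated requests would use the access token
    headers = {"Authorization": f"Bearer {tokens['access_token']}"}

    # Rotate tokens: the old refresh token is revoked and a new pair is returned
    tokens = client.post(f"{BASE}/refresh", json={"refresh_token": tokens["refresh_token"]}).json()

    # Logout revokes the current refresh token
    client.post(f"{BASE}/logout", json={"refresh_token": tokens["refresh_token"]})
```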
499
autogpt_platform/backend/backend/server/auth/service.py
Normal file
499
autogpt_platform/backend/backend/server/auth/service.py
Normal file
@@ -0,0 +1,499 @@
|
||||
"""
|
||||
Core authentication service for password verification and token management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, cast
|
||||
|
||||
import bcrypt
|
||||
from autogpt_libs.auth.config import get_settings
|
||||
from autogpt_libs.auth.jwt_utils import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
hash_token,
|
||||
)
|
||||
from prisma.models import User as PrismaUser
|
||||
from prisma.types import (
|
||||
EmailVerificationTokenCreateInput,
|
||||
PasswordResetTokenCreateInput,
|
||||
ProfileCreateInput,
|
||||
RefreshTokenCreateInput,
|
||||
UserCreateInput,
|
||||
)
|
||||
|
||||
from backend.data.db import prisma
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Handles authentication operations including password verification and token management."""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
|
||||
def hash_password(self, password: str) -> str:
|
||||
"""Hash a password using bcrypt."""
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
def verify_password(self, password: str, hashed: str) -> bool:
|
||||
"""Verify a password against a bcrypt hash."""
|
||||
try:
|
||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||
except Exception as e:
|
||||
logger.warning(f"Password verification failed: {e}")
|
||||
return False
|
||||
|
||||
async def register_user(
|
||||
self,
|
||||
email: str,
|
||||
password: str,
|
||||
name: Optional[str] = None,
|
||||
) -> PrismaUser:
|
||||
"""
|
||||
Register a new user with email and password.
|
||||
|
||||
Creates both a User record and a Profile record.
|
||||
|
||||
:param email: User's email address
|
||||
:param password: User's password (will be hashed)
|
||||
:param name: Optional display name
|
||||
:return: Created user record
|
||||
:raises ValueError: If email is already registered
|
||||
"""
|
||||
# Check if user already exists
|
||||
existing = await prisma.user.find_unique(where={"email": email})
|
||||
if existing:
|
||||
raise ValueError("Email already registered")
|
||||
|
||||
password_hash = self.hash_password(password)
|
||||
|
||||
# Generate a unique username from email
|
||||
base_username = email.split("@")[0].lower()
|
||||
# Remove any characters that aren't alphanumeric or underscore
|
||||
base_username = re.sub(r"[^a-z0-9_]", "", base_username)
|
||||
if not base_username:
|
||||
base_username = "user"
|
||||
|
||||
# Check if username is unique, if not add a number suffix
|
||||
username = base_username
|
||||
counter = 1
|
||||
while await prisma.profile.find_unique(where={"username": username}):
|
||||
username = f"{base_username}{counter}"
|
||||
counter += 1
|
||||
|
||||
user = await prisma.user.create(
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"email": email,
|
||||
"passwordHash": password_hash,
|
||||
"name": name,
|
||||
"emailVerified": False,
|
||||
"role": "authenticated",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Create profile for the user
|
||||
display_name = name or base_username
|
||||
await prisma.profile.create(
|
||||
data=cast(
|
||||
ProfileCreateInput,
|
||||
{
|
||||
"userId": user.id,
|
||||
"name": display_name,
|
||||
"username": username,
|
||||
"description": "",
|
||||
"links": [],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"Registered new user: {user.id} with profile username: {username}")
|
||||
return user
|
||||
|
||||
async def authenticate_user(
|
||||
self, email: str, password: str
|
||||
) -> Optional[PrismaUser]:
|
||||
"""
|
||||
Authenticate a user with email and password.
|
||||
|
||||
:param email: User's email address
|
||||
:param password: User's password
|
||||
:return: User record if authentication successful, None otherwise
|
||||
"""
|
||||
user = await prisma.user.find_unique(where={"email": email})
|
||||
|
||||
if not user:
|
||||
logger.debug(f"Authentication failed: user not found for email {email}")
|
||||
return None
|
||||
|
||||
if not user.passwordHash:
|
||||
logger.debug(
|
||||
f"Authentication failed: no password set for user {user.id} "
|
||||
"(likely OAuth-only user)"
|
||||
)
|
||||
return None
|
||||
|
||||
if self.verify_password(password, user.passwordHash):
|
||||
logger.debug(f"Authentication successful for user {user.id}")
|
||||
return user
|
||||
|
||||
logger.debug(f"Authentication failed: invalid password for user {user.id}")
|
||||
return None
|
||||
|
||||
async def create_tokens(self, user: PrismaUser) -> dict:
|
||||
"""
|
||||
Create access and refresh tokens for a user.
|
||||
|
||||
:param user: The user to create tokens for
|
||||
:return: Dictionary with access_token, refresh_token, token_type, and expires_in
|
||||
"""
|
||||
# Create access token
|
||||
access_token = create_access_token(
|
||||
user_id=user.id,
|
||||
email=user.email,
|
||||
role=user.role or "authenticated",
|
||||
email_verified=user.emailVerified,
|
||||
)
|
||||
|
||||
# Create and store refresh token
|
||||
raw_refresh_token, hashed_refresh_token = create_refresh_token()
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=self.settings.REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
|
||||
await prisma.refreshtoken.create(
|
||||
data=cast(
|
||||
RefreshTokenCreateInput,
|
||||
{
|
||||
"token": hashed_refresh_token,
|
||||
"userId": user.id,
|
||||
"expiresAt": expires_at,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"Created tokens for user {user.id}")
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": raw_refresh_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": self.settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
}
|
||||
|
||||
async def refresh_access_token(self, refresh_token: str) -> Optional[dict]:
|
||||
"""
|
||||
Refresh an access token using a refresh token.
|
||||
|
||||
Implements token rotation: the old refresh token is revoked and a new one is issued.
|
||||
|
||||
:param refresh_token: The refresh token
|
||||
:return: New tokens if successful, None if refresh token is invalid/expired
|
||||
"""
|
||||
hashed_token = hash_token(refresh_token)
|
||||
|
||||
# Find the refresh token
|
||||
stored_token = await prisma.refreshtoken.find_first(
|
||||
where={
|
||||
"token": hashed_token,
|
||||
"revokedAt": None,
|
                "expiresAt": {"gt": datetime.now(timezone.utc)},
            },
            include={"User": True},
        )

        if not stored_token or not stored_token.User:
            logger.debug("Refresh token not found or expired")
            return None

        # Revoke the old token (token rotation)
        await prisma.refreshtoken.update(
            where={"id": stored_token.id},
            data={"revokedAt": datetime.now(timezone.utc)},
        )

        logger.debug(f"Refreshed tokens for user {stored_token.User.id}")

        # Create new tokens
        return await self.create_tokens(stored_token.User)

    async def revoke_refresh_token(self, refresh_token: str) -> bool:
        """
        Revoke a refresh token (logout).

        :param refresh_token: The refresh token to revoke
        :return: True if token was found and revoked, False otherwise
        """
        hashed_token = hash_token(refresh_token)

        result = await prisma.refreshtoken.update_many(
            where={"token": hashed_token, "revokedAt": None},
            data={"revokedAt": datetime.now(timezone.utc)},
        )

        if result > 0:
            logger.debug("Refresh token revoked")
            return True

        logger.debug("Refresh token not found or already revoked")
        return False

    async def revoke_all_user_tokens(self, user_id: str) -> int:
        """
        Revoke all refresh tokens for a user (logout from all devices).

        :param user_id: The user's ID
        :return: Number of tokens revoked
        """
        result = await prisma.refreshtoken.update_many(
            where={"userId": user_id, "revokedAt": None},
            data={"revokedAt": datetime.now(timezone.utc)},
        )

        logger.debug(f"Revoked {result} tokens for user {user_id}")
        return result

    async def get_user_by_google_id(self, google_id: str) -> Optional[PrismaUser]:
        """Get a user by their Google OAuth ID."""
        return await prisma.user.find_unique(where={"googleId": google_id})

    async def get_user_by_email(self, email: str) -> Optional[PrismaUser]:
        """Get a user by their email address."""
        return await prisma.user.find_unique(where={"email": email})

    async def create_or_update_google_user(
        self,
        google_id: str,
        email: str,
        name: Optional[str] = None,
        email_verified: bool = False,
    ) -> PrismaUser:
        """
        Create or update a user from Google OAuth.

        If a user with the Google ID exists, return them.
        If a user with the email exists but no Google ID, link the account.
        Otherwise, create a new user.

        :param google_id: Google's unique user ID
        :param email: User's email from Google
        :param name: User's name from Google
        :param email_verified: Whether Google has verified the email
        :return: The user record
        """
        # Check if user exists with this Google ID
        user = await self.get_user_by_google_id(google_id)
        if user:
            return user

        # Check if user exists with this email
        user = await self.get_user_by_email(email)
        if user:
            # Link Google account to existing user
            updated_user = await prisma.user.update(
                where={"id": user.id},
                data={
                    "googleId": google_id,
                    "emailVerified": email_verified or user.emailVerified,
                },
            )
            if updated_user:
                logger.info(f"Linked Google account to existing user {updated_user.id}")
                return updated_user
            return user

        # Create new user with profile
        # Generate a unique username from email
        base_username = email.split("@")[0].lower()
        base_username = re.sub(r"[^a-z0-9_]", "", base_username)
        if not base_username:
            base_username = "user"

        username = base_username
        counter = 1
        while await prisma.profile.find_unique(where={"username": username}):
            username = f"{base_username}{counter}"
            counter += 1

        user = await prisma.user.create(
            data=cast(
                UserCreateInput,
                {
                    "email": email,
                    "googleId": google_id,
                    "name": name,
                    "emailVerified": email_verified,
                    "role": "authenticated",
                },
            )
        )

        # Create profile for the user
        display_name = name or base_username
        await prisma.profile.create(
            data=cast(
                ProfileCreateInput,
                {
                    "userId": user.id,
                    "name": display_name,
                    "username": username,
                    "description": "",
                    "links": [],
                },
            )
        )

        logger.info(
            f"Created new user from Google OAuth: {user.id} with profile: {username}"
        )
        return user

    async def create_password_reset_token(self, user_id: str) -> str:
        """
        Create a password reset token for a user.

        :param user_id: The user's ID
        :return: The raw token to send to the user
        """
        raw_token, hashed_token = create_refresh_token()  # Reuse token generation
        expires_at = datetime.now(timezone.utc) + timedelta(hours=1)

        await prisma.passwordresettoken.create(
            data=cast(
                PasswordResetTokenCreateInput,
                {
                    "token": hashed_token,
                    "userId": user_id,
                    "expiresAt": expires_at,
                },
            )
        )

        return raw_token

    async def create_email_verification_token(self, user_id: str) -> str:
        """
        Create an email verification token for a user.

        :param user_id: The user's ID
        :return: The raw token to send to the user
        """
        raw_token, hashed_token = create_refresh_token()  # Reuse token generation
        expires_at = datetime.now(timezone.utc) + timedelta(hours=24)

        await prisma.emailverificationtoken.create(
            data=cast(
                EmailVerificationTokenCreateInput,
                {
                    "token": hashed_token,
                    "userId": user_id,
                    "expiresAt": expires_at,
                },
            )
        )

        return raw_token

    async def verify_email_token(self, token: str) -> bool:
        """
        Verify an email verification token and mark the user's email as verified.

        :param token: The raw token from the user
        :return: True if successful, False if token is invalid
        """
        hashed_token = hash_token(token)

        # Find and validate token
        stored_token = await prisma.emailverificationtoken.find_first(
            where={
                "token": hashed_token,
                "usedAt": None,
                "expiresAt": {"gt": datetime.now(timezone.utc)},
            }
        )

        if not stored_token:
            return False

        # Mark email as verified
        await prisma.user.update(
            where={"id": stored_token.userId},
            data={"emailVerified": True},
        )

        # Mark token as used
        await prisma.emailverificationtoken.update(
            where={"id": stored_token.id},
            data={"usedAt": datetime.now(timezone.utc)},
        )

        logger.info(f"Email verified for user {stored_token.userId}")
        return True

    async def verify_password_reset_token(self, token: str) -> Optional[str]:
        """
        Verify a password reset token and return the user ID.

        :param token: The raw token from the user
        :return: User ID if valid, None otherwise
        """
        hashed_token = hash_token(token)

        stored_token = await prisma.passwordresettoken.find_first(
            where={
                "token": hashed_token,
                "usedAt": None,
                "expiresAt": {"gt": datetime.now(timezone.utc)},
            }
        )

        if not stored_token:
            return None

        return stored_token.userId

    async def reset_password(self, token: str, new_password: str) -> bool:
        """
        Reset a user's password using a password reset token.

        :param token: The password reset token
        :param new_password: The new password
        :return: True if successful, False if token is invalid
        """
        hashed_token = hash_token(token)

        # Find and validate token
        stored_token = await prisma.passwordresettoken.find_first(
            where={
                "token": hashed_token,
                "usedAt": None,
                "expiresAt": {"gt": datetime.now(timezone.utc)},
            }
        )

        if not stored_token:
            return False

        # Update password
        password_hash = self.hash_password(new_password)
        await prisma.user.update(
            where={"id": stored_token.userId},
            data={"passwordHash": password_hash},
        )

        # Mark token as used
        await prisma.passwordresettoken.update(
            where={"id": stored_token.id},
            data={"usedAt": datetime.now(timezone.utc)},
        )

        # Revoke all refresh tokens for security
        await self.revoke_all_user_tokens(stored_token.userId)

        logger.info(f"Password reset for user {stored_token.userId}")
        return True
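
For orientation, here is a minimal sketch of how the password-reset methods above might be exposed over HTTP. It assumes a FastAPI-style router, an `auth_service` instance of the service class in this diff, and a `send_password_reset_email()` helper (sketched after the email templates below); none of these names are defined by the diff itself.

# Hypothetical FastAPI wiring for the password-reset flow; auth_service and
# send_password_reset_email() are assumed helpers, not part of this diff.
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, EmailStr

router = APIRouter()


class ResetRequest(BaseModel):
    email: EmailStr


class ResetConfirm(BaseModel):
    token: str
    new_password: str


@router.post("/auth/password-reset/request")
async def request_password_reset(body: ResetRequest):
    user = await auth_service.get_user_by_email(body.email)
    if user:
        # The helper creates the one-hour reset token and emails the raw value;
        # only its hash is stored in the database.
        await send_password_reset_email(user)
    # Respond identically whether or not the email exists, to avoid account enumeration.
    return {"status": "ok"}


@router.post("/auth/password-reset/confirm")
async def confirm_password_reset(body: ResetConfirm):
    # reset_password() validates the token, updates the password hash,
    # marks the token as used, and revokes all refresh tokens.
    if not await auth_service.reset_password(body.token, body.new_password):
        raise HTTPException(status_code=400, detail="Invalid or expired reset token")
    return {"status": "password_updated"}
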
@@ -0,0 +1,302 @@
{# Base Template for Auth Emails #}
{# Template variables:
   data.message: the message to display in the email
   data.title: the title of the email
   data.unsubscribe_link: the link to unsubscribe from the email (optional for auth emails)
#}
<!doctype html>
<html lang="ltr" xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office">

<head>
  <meta charset="utf-8">
  <meta http-equiv="X-UA-Compatible" content="IE=edge">
  <meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=yes">
  <meta name="format-detection" content="telephone=no, date=no, address=no, email=no, url=no">
  <meta name="x-apple-disable-message-reformatting">
  <!--[if !mso]>
  <meta http-equiv="X-UA-Compatible" content="IE=edge">
  <![endif]-->
  <!--[if mso]>
  <style>
    * { font-family: sans-serif !important; }
  </style>
  <noscript>
    <xml>
      <o:OfficeDocumentSettings>
        <o:PixelsPerInch>96</o:PixelsPerInch>
      </o:OfficeDocumentSettings>
    </xml>
  </noscript>
  <![endif]-->
  <style type="text/css">
    /* RESET STYLES */
    html,
    body {
      margin: 0 !important;
      padding: 0 !important;
      width: 100% !important;
      height: 100% !important;
    }

    body {
      -webkit-font-smoothing: antialiased;
      -moz-osx-font-smoothing: grayscale;
      text-rendering: optimizeLegibility;
    }

    .document {
      margin: 0 !important;
      padding: 0 !important;
      width: 100% !important;
    }

    img {
      border: 0;
      outline: none;
      text-decoration: none;
      -ms-interpolation-mode: bicubic;
    }

    table {
      border-collapse: collapse;
    }

    table,
    td {
      mso-table-lspace: 0pt;
      mso-table-rspace: 0pt;
    }

    body,
    table,
    td,
    a {
      -webkit-text-size-adjust: 100%;
      -ms-text-size-adjust: 100%;
    }

    h1,
    h2,
    h3,
    h4,
    h5,
    p {
      margin: 0;
      word-break: break-word;
    }

    /* iOS BLUE LINKS */
    a[x-apple-data-detectors] {
      color: inherit !important;
      text-decoration: none !important;
      font-size: inherit !important;
      font-family: inherit !important;
      font-weight: inherit !important;
      line-height: inherit !important;
    }

    /* ANDROID CENTER FIX */
    div[style*="margin: 16px 0;"] {
      margin: 0 !important;
    }

    /* MEDIA QUERIES */
    @media all and (max-width:639px) {
      .wrapper {
        width: 100% !important;
      }

      .container {
        width: 100% !important;
        min-width: 100% !important;
        padding: 0 !important;
      }

      .row {
        padding-left: 20px !important;
        padding-right: 20px !important;
      }

      .col-mobile {
        width: 20px !important;
      }

      .col {
        display: block !important;
        width: 100% !important;
      }

      .mobile-center {
        text-align: center !important;
        float: none !important;
      }

      .mobile-mx-auto {
        margin: 0 auto !important;
        float: none !important;
      }

      .mobile-left {
        text-align: center !important;
        float: left !important;
      }

      .mobile-hide {
        display: none !important;
      }

      .img {
        width: 100% !important;
        height: auto !important;
      }

      .ml-btn {
        width: 100% !important;
        max-width: 100% !important;
      }

      .ml-btn-container {
        width: 100% !important;
        max-width: 100% !important;
      }
    }
  </style>
  <style type="text/css">
    @import url("https://assets.mlcdn.com/fonts-v2.css?version=1729862");
  </style>
  <style type="text/css">
    @media screen {
      body {
        font-family: 'Poppins', sans-serif;
      }
    }
  </style>
  <title>{{data.title}}</title>
</head>

<body style="margin: 0 !important; padding: 0 !important; background-color:#070629;">
  <div class="document" role="article" aria-roledescription="email" aria-label lang dir="ltr"
    style="background-color:#070629; line-height: 100%; font-size:medium; font-size:max(16px, 1rem);">
    <!-- Main Content -->
    <table width="100%" align="center" cellspacing="0" cellpadding="0" border="0">
      <tr>
        <td class="background" bgcolor="#070629" align="center" valign="top" style="padding: 0 8px;">
          <!-- Email Content -->
          <table class="container" align="center" width="640" cellpadding="0" cellspacing="0" border="0"
            style="max-width: 640px;">
            <tr>
              <td align="center">
                <!-- Logo Section -->
                <table class="container ml-4 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
                  cellspacing="0" cellpadding="0" style="width: 640px; min-width: 640px;">
                  <tr>
                    <td class="ml-default-border container" height="40" style="line-height: 40px; min-width: 640px;">
                    </td>
                  </tr>
                  <tr>
                    <td>
                      <table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
                        <tr>
                          <td class="row" align="center" style="padding: 0 50px;">
                            <img
                              src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
                              border="0" alt="" width="120" class="logo"
                              style="max-width: 120px; display: inline-block;">
                          </td>
                        </tr>
                      </table>
                    </td>
                  </tr>
                </table>

                <!-- Main Content Section -->
                <table class="container ml-6 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
                  cellspacing="0" cellpadding="0" style="color: #070629; width: 640px; min-width: 640px;">
                  <tr>
                    <td class="row" style="padding: 0 50px;">
                      {{data.message|safe}}
                    </td>
                  </tr>
                </table>

                <!-- Footer Section -->
                <table class="container ml-10 ml-default-border" width="640" bgcolor="#ffffff" align="center" border="0"
                  cellspacing="0" cellpadding="0" style="width: 640px; min-width: 640px;">
                  <tr>
                    <td class="row" style="padding: 0 50px;">
                      <table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
                        <tr>
                          <td height="20" style="line-height: 20px;"></td>
                        </tr>
                        <tr>
                          <td>
                            <!-- Footer Content -->
                            <table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
                              <tr>
                                <td class="col" align="left" valign="middle" width="120">
                                  <img
                                    src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
                                    border="0" alt="" width="120" class="logo"
                                    style="max-width: 120px; display: inline-block;">
                                </td>
                                <td class="col" width="40" height="30" style="line-height: 30px;"></td>
                                <td class="col mobile-left" align="right" valign="middle" width="250">
                                  <table role="presentation" cellpadding="0" cellspacing="0" border="0">
                                    <tr>
                                      <td align="center" valign="middle" width="18" style="padding: 0 5px 0 0;">
                                        <a href="https://x.com/auto_gpt" target="blank" style="text-decoration: none;">
                                          <img
                                            src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/x.png"
                                            width="18" alt="x">
                                        </a>
                                      </td>
                                      <td align="center" valign="middle" width="18" style="padding: 0 5px;">
                                        <a href="https://discord.gg/autogpt" target="blank"
                                          style="text-decoration: none;">
                                          <img
                                            src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/discord.png"
                                            width="18" alt="discord">
                                        </a>
                                      </td>
                                      <td align="center" valign="middle" width="18" style="padding: 0 0 0 5px;">
                                        <a href="https://agpt.co/" target="blank" style="text-decoration: none;">
                                          <img
                                            src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/website.png"
                                            width="18" alt="website">
                                        </a>
                                      </td>
                                    </tr>
                                  </table>
                                </td>
                              </tr>
                            </table>
                          </td>
                        </tr>
                        <tr>
                          <td height="15" style="line-height: 15px;"></td>
                        </tr>
                        <tr>
                          <td align="center" style="text-align: left!important;">
                            <p
                              style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 12px; line-height: 150%; display: inline-block; margin-bottom: 0;">
                              This is an automated security email from AutoGPT. If you did not request this action, please ignore this email or contact support if you have concerns.
                            </p>
                          </td>
                        </tr>
                        <tr>
                          <td height="20" style="line-height: 20px;"></td>
                        </tr>
                      </table>
                    </td>
                  </tr>
                </table>
              </td>
            </tr>
          </table>
        </td>
      </tr>
    </table>
  </div>
</body>

</html>
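
The base layout above only supplies the chrome (logo, footer, social links); the actual email body is injected through {{data.message|safe}}. Below is a minimal Jinja2 rendering sketch; the templates/email/ directory and the base.html file name are assumptions, since the diff does not show the file paths.

# Illustrative rendering of the auth email templates with Jinja2.
# Template file locations and the shape of `data` are assumptions based on the
# variables documented in the templates, not confirmed by this diff.
from jinja2 import Environment, FileSystemLoader, select_autoescape

env = Environment(
    loader=FileSystemLoader("templates/email"),
    autoescape=select_autoescape(["html"]),
)


def render_auth_email(child_template: str, title: str, **context) -> str:
    # Render the body template first, then wrap it in the base layout,
    # which injects it via {{ data.message|safe }}.
    message_html = env.get_template(child_template).render(**context)
    return env.get_template("base.html").render(
        data={"title": title, "message": message_html}
    )

Because the base template marks the body with |safe, autoescaping still applies to the child template's own variables while the pre-rendered HTML body passes through untouched.
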
@@ -0,0 +1,65 @@
{# Email Verification Template #}
{# Variables:
   verification_link: URL for email verification
   user_name: Optional user name for personalization
   frontend_url: Base frontend URL
#}
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
  <tr>
    <td height="30" style="line-height: 30px;"></td>
  </tr>
  <tr>
    <td align="center">
      <h1 style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 28px; line-height: 125%; font-weight: bold; margin-bottom: 20px;">
        Verify Your Email
      </h1>
    </td>
  </tr>
  <tr>
    <td align="left">
      <p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
        {% if user_name %}Hi {{ user_name }},{% else %}Hi,{% endif %}
      </p>
      <p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
        Welcome to AutoGPT! Please verify your email address by clicking the button below:
      </p>
    </td>
  </tr>
  <tr>
    <td align="center" style="padding: 20px 0;">
      <table border="0" cellspacing="0" cellpadding="0">
        <tr>
          <td align="center" bgcolor="#4285F4" style="border-radius: 8px;">
            <a href="{{ verification_link }}" target="_blank"
              style="display: inline-block; padding: 16px 36px; font-family: 'Poppins', sans-serif; font-size: 16px; font-weight: 600; color: #ffffff; text-decoration: none; border-radius: 8px;">
              Verify Email
            </a>
          </td>
        </tr>
      </table>
    </td>
  </tr>
  <tr>
    <td align="left">
      <p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
        This link will expire in <strong>24 hours</strong>.
      </p>
      <p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
        If you didn't create an account with AutoGPT, you can safely ignore this email.
      </p>
    </td>
  </tr>
  <tr>
    <td align="left">
      <p style="font-family: 'Poppins', sans-serif; color: #888888; font-size: 14px; line-height: 165%; margin-bottom: 10px;">
        If the button doesn't work, copy and paste this link into your browser:
      </p>
      <p style="font-family: 'Poppins', sans-serif; color: #4285F4; font-size: 14px; line-height: 165%; word-break: break-all;">
        <a href="{{ verification_link }}" style="color: #4285F4; text-decoration: underline;">{{ verification_link }}</a>
      </p>
    </td>
  </tr>
  <tr>
    <td height="30" style="line-height: 30px;"></td>
  </tr>
</table>
@@ -0,0 +1,65 @@
{# Password Reset Email Template #}
{# Variables:
   reset_link: URL for password reset
   user_name: Optional user name for personalization
   frontend_url: Base frontend URL
#}
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
  <tr>
    <td height="30" style="line-height: 30px;"></td>
  </tr>
  <tr>
    <td align="center">
      <h1 style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 28px; line-height: 125%; font-weight: bold; margin-bottom: 20px;">
        Reset Your Password
      </h1>
    </td>
  </tr>
  <tr>
    <td align="left">
      <p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
        {% if user_name %}Hi {{ user_name }},{% else %}Hi,{% endif %}
      </p>
      <p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
        We received a request to reset your password for your AutoGPT account. Click the button below to create a new password:
      </p>
    </td>
  </tr>
  <tr>
    <td align="center" style="padding: 20px 0;">
      <table border="0" cellspacing="0" cellpadding="0">
        <tr>
          <td align="center" bgcolor="#4285F4" style="border-radius: 8px;">
            <a href="{{ reset_link }}" target="_blank"
              style="display: inline-block; padding: 16px 36px; font-family: 'Poppins', sans-serif; font-size: 16px; font-weight: 600; color: #ffffff; text-decoration: none; border-radius: 8px;">
              Reset Password
            </a>
          </td>
        </tr>
      </table>
    </td>
  </tr>
  <tr>
    <td align="left">
      <p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
        This link will expire in <strong>1 hour</strong> for security reasons.
      </p>
      <p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
        If you didn't request a password reset, you can safely ignore this email. Your password will remain unchanged.
      </p>
    </td>
  </tr>
  <tr>
    <td align="left">
      <p style="font-family: 'Poppins', sans-serif; color: #888888; font-size: 14px; line-height: 165%; margin-bottom: 10px;">
        If the button doesn't work, copy and paste this link into your browser:
      </p>
      <p style="font-family: 'Poppins', sans-serif; color: #4285F4; font-size: 14px; line-height: 165%; word-break: break-all;">
        <a href="{{ reset_link }}" style="color: #4285F4; text-decoration: underline;">{{ reset_link }}</a>
      </p>
    </td>
  </tr>
  <tr>
    <td height="30" style="line-height: 30px;"></td>
  </tr>
</table>
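
Both templates expect a fully built link, while the service methods earlier in this diff only return a raw token. A sketch of how the two could be joined is below; FRONTEND_URL, the route paths, the template file names, auth_service, and send_email() are illustrative assumptions, and render_auth_email() is the helper sketched after the base template above.

# Hypothetical glue turning raw tokens into the {{ verification_link }} /
# {{ reset_link }} values the templates expect; names marked below are assumed.
FRONTEND_URL = "https://platform.example.com"  # assumed; configured per deployment


async def send_verification_email(user) -> None:
    raw_token = await auth_service.create_email_verification_token(user.id)
    verification_link = f"{FRONTEND_URL}/auth/verify-email?token={raw_token}"  # assumed route
    html = render_auth_email(
        "verify_email.html",  # assumed file name for the verification template
        title="Verify Your Email",
        verification_link=verification_link,
        user_name=user.name,
    )
    await send_email(to=user.email, subject="Verify Your Email", html=html)


async def send_password_reset_email(user) -> None:
    raw_token = await auth_service.create_password_reset_token(user.id)
    reset_link = f"{FRONTEND_URL}/auth/reset-password?token={raw_token}"  # assumed route
    html = render_auth_email(
        "password_reset.html",  # assumed file name for the reset template
        title="Reset Your Password",
        reset_link=reset_link,
        user_name=user.name,
    )
    await send_email(to=user.email, subject="Reset Your Password", html=html)

The expiry wording in the templates matches the service: verification tokens last 24 hours, password reset tokens 1 hour.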