mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-12 08:38:09 -05:00
Compare commits
2 Commits
native-aut
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cfdccf966b | ||
|
|
8eadfb8f3a |
8
.github/copilot-instructions.md
vendored
8
.github/copilot-instructions.md
vendored
@@ -142,7 +142,7 @@ pnpm storybook # Start component development server
|
||||
### Security & Middleware
|
||||
|
||||
**Cache Protection**: Backend includes middleware preventing sensitive data caching in browsers/proxies
|
||||
**Authentication**: JWT-based with native authentication
|
||||
**Authentication**: JWT-based with Supabase integration
|
||||
**User ID Validation**: All data access requires user ID checks - verify this for any `data/*.py` changes
|
||||
|
||||
### Development Workflow
|
||||
@@ -168,9 +168,9 @@ pnpm storybook # Start component development server
|
||||
|
||||
- `frontend/src/app/layout.tsx` - Root application layout
|
||||
- `frontend/src/app/page.tsx` - Home page
|
||||
- `frontend/src/lib/auth/` - Authentication client
|
||||
- `frontend/src/lib/supabase/` - Authentication and database client
|
||||
|
||||
**Protected Routes**: Update `frontend/middleware.ts` when adding protected routes
|
||||
**Protected Routes**: Update `frontend/lib/supabase/middleware.ts` when adding protected routes
|
||||
|
||||
### Agent Block System
|
||||
|
||||
@@ -194,7 +194,7 @@ Agents are built using a visual block-based system where each block performs a s
|
||||
|
||||
1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
|
||||
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
|
||||
3. **Platform**: `/.env.default` (shared) → `/.env` (user overrides)
|
||||
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
|
||||
4. Docker Compose `environment:` sections override file-based config
|
||||
5. Shell environment variables have highest precedence
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ jobs:
|
||||
|
||||
- name: Get CI failure details
|
||||
id: failure_details
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
const run = await github.rest.actions.getWorkflowRun({
|
||||
|
||||
6
.github/workflows/claude-dependabot.yml
vendored
6
.github/workflows/claude-dependabot.yml
vendored
@@ -144,7 +144,11 @@ jobs:
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"pgvector/pgvector:pg18"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
|
||||
6
.github/workflows/claude.yml
vendored
6
.github/workflows/claude.yml
vendored
@@ -160,7 +160,11 @@ jobs:
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"pgvector/pgvector:pg18"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
|
||||
6
.github/workflows/copilot-setup-steps.yml
vendored
6
.github/workflows/copilot-setup-steps.yml
vendored
@@ -142,7 +142,11 @@ jobs:
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"pgvector/pgvector:pg18"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
|
||||
44
.github/workflows/platform-backend-ci.yml
vendored
44
.github/workflows/platform-backend-ci.yml
vendored
@@ -2,13 +2,13 @@ name: AutoGPT Platform - Backend CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev, ci-test*, native-auth]
|
||||
branches: [master, dev, ci-test*]
|
||||
paths:
|
||||
- ".github/workflows/platform-backend-ci.yml"
|
||||
- "autogpt_platform/backend/**"
|
||||
- "autogpt_platform/autogpt_libs/**"
|
||||
pull_request:
|
||||
branches: [master, dev, release-*, native-auth]
|
||||
branches: [master, dev, release-*]
|
||||
paths:
|
||||
- ".github/workflows/platform-backend-ci.yml"
|
||||
- "autogpt_platform/backend/**"
|
||||
@@ -36,19 +36,6 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: pgvector/pgvector:pg18
|
||||
ports:
|
||||
- 5432:5432
|
||||
env:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: your-super-secret-and-long-postgres-password
|
||||
POSTGRES_DB: postgres
|
||||
options: >-
|
||||
--health-cmd "pg_isready -U postgres"
|
||||
--health-interval 5s
|
||||
--health-timeout 5s
|
||||
--health-retries 10
|
||||
redis:
|
||||
image: redis:latest
|
||||
ports:
|
||||
@@ -91,6 +78,11 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Setup Supabase
|
||||
uses: supabase/setup-cli@v1
|
||||
with:
|
||||
version: 1.178.1
|
||||
|
||||
- id: get_date
|
||||
name: Get date
|
||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||
@@ -144,6 +136,16 @@ jobs:
|
||||
- name: Generate Prisma Client
|
||||
run: poetry run prisma generate
|
||||
|
||||
- id: supabase
|
||||
name: Start Supabase
|
||||
working-directory: .
|
||||
run: |
|
||||
supabase init
|
||||
supabase start --exclude postgres-meta,realtime,storage-api,imgproxy,inbucket,studio,edge-runtime,logflare,vector,supavisor
|
||||
supabase status -o env | sed 's/="/=/; s/"$//' >> $GITHUB_OUTPUT
|
||||
# outputs:
|
||||
# DB_URL, API_URL, GRAPHQL_URL, ANON_KEY, SERVICE_ROLE_KEY, JWT_SECRET
|
||||
|
||||
- name: Wait for ClamAV to be ready
|
||||
run: |
|
||||
echo "Waiting for ClamAV daemon to start..."
|
||||
@@ -176,8 +178,8 @@ jobs:
|
||||
- name: Run Database Migrations
|
||||
run: poetry run prisma migrate dev --name updates
|
||||
env:
|
||||
DATABASE_URL: postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/postgres
|
||||
DIRECT_URL: postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/postgres
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
- id: lint
|
||||
name: Run Linter
|
||||
@@ -193,9 +195,11 @@ jobs:
|
||||
if: success() || (failure() && steps.lint.outcome == 'failure')
|
||||
env:
|
||||
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
|
||||
DATABASE_URL: postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/postgres
|
||||
DIRECT_URL: postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/postgres
|
||||
JWT_SECRET: your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
|
||||
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
|
||||
JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
REDIS_HOST: "localhost"
|
||||
REDIS_PORT: "6379"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
@@ -17,7 +17,7 @@ jobs:
|
||||
- name: Check comment permissions and deployment status
|
||||
id: check_status
|
||||
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
const commentBody = context.payload.comment.body.trim();
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
|
||||
- name: Post permission denied comment
|
||||
if: steps.check_status.outputs.permission_denied == 'true'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
- name: Get PR details for deployment
|
||||
id: pr_details
|
||||
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
const pr = await github.rest.pulls.get({
|
||||
@@ -98,7 +98,7 @@ jobs:
|
||||
|
||||
- name: Post deploy success comment
|
||||
if: steps.check_status.outputs.should_deploy == 'true'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
@@ -126,7 +126,7 @@ jobs:
|
||||
|
||||
- name: Post undeploy success comment
|
||||
if: steps.check_status.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
@@ -139,7 +139,7 @@ jobs:
|
||||
- name: Check deployment status on PR close
|
||||
id: check_pr_close
|
||||
if: github.event_name == 'pull_request' && github.event.action == 'closed'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
const comments = await github.rest.issues.listComments({
|
||||
@@ -187,7 +187,7 @@ jobs:
|
||||
github.event_name == 'pull_request' &&
|
||||
github.event.action == 'closed' &&
|
||||
steps.check_pr_close.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
|
||||
5
.github/workflows/platform-frontend-ci.yml
vendored
5
.github/workflows/platform-frontend-ci.yml
vendored
@@ -2,12 +2,11 @@ name: AutoGPT Platform - Frontend CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev, native-auth]
|
||||
branches: [master, dev]
|
||||
paths:
|
||||
- ".github/workflows/platform-frontend-ci.yml"
|
||||
- "autogpt_platform/frontend/**"
|
||||
pull_request:
|
||||
branches: [master, dev, native-auth]
|
||||
paths:
|
||||
- ".github/workflows/platform-frontend-ci.yml"
|
||||
- "autogpt_platform/frontend/**"
|
||||
@@ -148,7 +147,7 @@ jobs:
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Copy default platform .env
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
|
||||
56
.github/workflows/platform-fullstack-ci.yml
vendored
56
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -1,13 +1,12 @@
|
||||
name: AutoGPT Platform - Fullstack CI
|
||||
name: AutoGPT Platform - Frontend CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev, native-auth]
|
||||
branches: [master, dev]
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- "autogpt_platform/**"
|
||||
pull_request:
|
||||
branches: [master, dev, native-auth]
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- "autogpt_platform/**"
|
||||
@@ -59,11 +58,14 @@ jobs:
|
||||
types:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
@@ -73,6 +75,18 @@ jobs:
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
@@ -87,12 +101,36 @@ jobs:
|
||||
- name: Setup .env
|
||||
run: cp .env.default .env
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||
|
||||
- name: Generate API queries
|
||||
run: pnpm generate:api
|
||||
run: pnpm generate:api:force
|
||||
|
||||
- name: Check for API schema changes
|
||||
run: |
|
||||
if ! git diff --exit-code src/app/api/openapi.json; then
|
||||
echo "❌ API schema changes detected in src/app/api/openapi.json"
|
||||
echo ""
|
||||
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
|
||||
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
|
||||
echo "The API schema is now out of sync with the Front-end queries."
|
||||
echo ""
|
||||
echo "To fix this:"
|
||||
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
|
||||
echo "2. Run 'pnpm generate:api' locally"
|
||||
echo "3. Run 'pnpm types' locally"
|
||||
echo "4. Fix any TypeScript errors that may have been introduced"
|
||||
echo "5. Commit and push your changes"
|
||||
echo ""
|
||||
exit 1
|
||||
else
|
||||
echo "✅ No API schema changes detected"
|
||||
fi
|
||||
|
||||
- name: Run Typescript checks
|
||||
run: pnpm types
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
|
||||
@@ -11,7 +11,7 @@ jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@v10
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
# operations-per-run: 5000
|
||||
stale-issue-message: >
|
||||
|
||||
2
.github/workflows/repo-pr-label.yml
vendored
2
.github/workflows/repo-pr-label.yml
vendored
@@ -61,6 +61,6 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v6
|
||||
- uses: actions/labeler@v5
|
||||
with:
|
||||
sync-labels: true
|
||||
|
||||
@@ -49,5 +49,5 @@ Use conventional commit messages for all commits (e.g. `feat(backend): add API`)
|
||||
- Keep out-of-scope changes under 20% of the PR.
|
||||
- Ensure PR descriptions are complete.
|
||||
- For changes touching `data/*.py`, validate user ID checks or explain why not needed.
|
||||
- If adding protected frontend routes, update `frontend/lib/auth/helpers.ts`.
|
||||
- If adding protected frontend routes, update `frontend/lib/supabase/middleware.ts`.
|
||||
- Use the linear ticket branch structure if given codex/open-1668-resume-dropped-runs
|
||||
|
||||
@@ -5,6 +5,12 @@
|
||||
|
||||
POSTGRES_PASSWORD=your-super-secret-and-long-postgres-password
|
||||
JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
|
||||
SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
DASHBOARD_USERNAME=supabase
|
||||
DASHBOARD_PASSWORD=this_password_is_insecure_and_should_be_updated
|
||||
SECRET_KEY_BASE=UpNVntn3cDxHJpq99YMc1T1AQgQpc8kfYTuRgBiYa15BLrx8etQoXz3gZv1/u2oq
|
||||
VAULT_ENC_KEY=your-encryption-key-32-chars-min
|
||||
|
||||
|
||||
############
|
||||
@@ -18,31 +24,100 @@ POSTGRES_PORT=5432
|
||||
|
||||
|
||||
############
|
||||
# Auth - Native authentication configuration
|
||||
# Supavisor -- Database pooler
|
||||
############
|
||||
POOLER_PROXY_PORT_TRANSACTION=6543
|
||||
POOLER_DEFAULT_POOL_SIZE=20
|
||||
POOLER_MAX_CLIENT_CONN=100
|
||||
POOLER_TENANT_ID=your-tenant-id
|
||||
|
||||
|
||||
############
|
||||
# API Proxy - Configuration for the Kong Reverse proxy.
|
||||
############
|
||||
|
||||
KONG_HTTP_PORT=8000
|
||||
KONG_HTTPS_PORT=8443
|
||||
|
||||
|
||||
############
|
||||
# API - Configuration for PostgREST.
|
||||
############
|
||||
|
||||
PGRST_DB_SCHEMAS=public,storage,graphql_public
|
||||
|
||||
|
||||
############
|
||||
# Auth - Configuration for the GoTrue authentication server.
|
||||
############
|
||||
|
||||
## General
|
||||
SITE_URL=http://localhost:3000
|
||||
ADDITIONAL_REDIRECT_URLS=
|
||||
JWT_EXPIRY=3600
|
||||
DISABLE_SIGNUP=false
|
||||
API_EXTERNAL_URL=http://localhost:8000
|
||||
|
||||
# JWT token configuration
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=15
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
JWT_ISSUER=autogpt-platform
|
||||
## Mailer Config
|
||||
MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify"
|
||||
MAILER_URLPATHS_INVITE="/auth/v1/verify"
|
||||
MAILER_URLPATHS_RECOVERY="/auth/v1/verify"
|
||||
MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify"
|
||||
|
||||
# Google OAuth (optional)
|
||||
GOOGLE_CLIENT_ID=
|
||||
GOOGLE_CLIENT_SECRET=
|
||||
## Email auth
|
||||
ENABLE_EMAIL_SIGNUP=true
|
||||
ENABLE_EMAIL_AUTOCONFIRM=false
|
||||
SMTP_ADMIN_EMAIL=admin@example.com
|
||||
SMTP_HOST=supabase-mail
|
||||
SMTP_PORT=2500
|
||||
SMTP_USER=fake_mail_user
|
||||
SMTP_PASS=fake_mail_password
|
||||
SMTP_SENDER_NAME=fake_sender
|
||||
ENABLE_ANONYMOUS_USERS=false
|
||||
|
||||
## Phone auth
|
||||
ENABLE_PHONE_SIGNUP=true
|
||||
ENABLE_PHONE_AUTOCONFIRM=true
|
||||
|
||||
|
||||
############
|
||||
# Email configuration (optional)
|
||||
# Studio - Configuration for the Dashboard
|
||||
############
|
||||
|
||||
SMTP_HOST=
|
||||
SMTP_PORT=587
|
||||
SMTP_USER=
|
||||
SMTP_PASS=
|
||||
SMTP_FROM_EMAIL=noreply@example.com
|
||||
STUDIO_DEFAULT_ORGANIZATION=Default Organization
|
||||
STUDIO_DEFAULT_PROJECT=Default Project
|
||||
|
||||
STUDIO_PORT=3000
|
||||
# replace if you intend to use Studio outside of localhost
|
||||
SUPABASE_PUBLIC_URL=http://localhost:8000
|
||||
|
||||
# Enable webp support
|
||||
IMGPROXY_ENABLE_WEBP_DETECTION=true
|
||||
|
||||
# Add your OpenAI API key to enable SQL Editor Assistant
|
||||
OPENAI_API_KEY=
|
||||
|
||||
|
||||
############
|
||||
# Functions - Configuration for Functions
|
||||
############
|
||||
# NOTE: VERIFY_JWT applies to all functions. Per-function VERIFY_JWT is not supported yet.
|
||||
FUNCTIONS_VERIFY_JWT=false
|
||||
|
||||
|
||||
############
|
||||
# Logs - Configuration for Logflare
|
||||
# Please refer to https://supabase.com/docs/reference/self-hosting-analytics/introduction
|
||||
############
|
||||
|
||||
LOGFLARE_LOGGER_BACKEND_API_KEY=your-super-secret-and-long-logflare-key
|
||||
|
||||
# Change vector.toml sinks to reflect this change
|
||||
LOGFLARE_API_KEY=your-super-secret-and-long-logflare-key
|
||||
|
||||
# Docker socket location - this value will differ depending on your OS
|
||||
DOCKER_SOCKET_LOCATION=/var/run/docker.sock
|
||||
|
||||
# Google Cloud Project details
|
||||
GOOGLE_PROJECT_ID=GOOGLE_PROJECT_ID
|
||||
GOOGLE_PROJECT_NUMBER=GOOGLE_PROJECT_NUMBER
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.PHONY: start-core stop-core logs-core format lint migrate run-backend run-frontend load-store-agents
|
||||
|
||||
# Run just PostgreSQL + Redis + RabbitMQ + ClamAV
|
||||
# Run just Supabase + Redis + RabbitMQ
|
||||
start-core:
|
||||
docker compose up -d deps
|
||||
|
||||
@@ -49,7 +49,7 @@ load-store-agents:
|
||||
help:
|
||||
@echo "Usage: make <target>"
|
||||
@echo "Targets:"
|
||||
@echo " start-core - Start just the core services (PostgreSQL, Redis, RabbitMQ, ClamAV) in background"
|
||||
@echo " start-core - Start just the core services (Supabase, Redis, RabbitMQ) in background"
|
||||
@echo " stop-core - Stop the core services"
|
||||
@echo " reset-db - Reset the database by deleting the volume"
|
||||
@echo " logs-core - Tail the logs for core services"
|
||||
|
||||
@@ -57,9 +57,6 @@ class APIKeySmith:
|
||||
|
||||
def hash_key(self, raw_key: str) -> tuple[str, str]:
|
||||
"""Migrate a legacy hash to secure hash format."""
|
||||
if not raw_key.startswith(self.PREFIX):
|
||||
raise ValueError("Key without 'agpt_' prefix would fail validation")
|
||||
|
||||
salt = self._generate_salt()
|
||||
hash = self._hash_key_with_salt(raw_key, salt)
|
||||
return hash, salt.hex()
|
||||
|
||||
@@ -16,37 +16,17 @@ ALGO_RECOMMENDATION = (
|
||||
"We highly recommend using an asymmetric algorithm such as ES256, "
|
||||
"because when leaked, a shared secret would allow anyone to "
|
||||
"forge valid tokens and impersonate users. "
|
||||
"More info: https://pyjwt.readthedocs.io/en/stable/algorithms.html"
|
||||
"More info: https://supabase.com/docs/guides/auth/signing-keys#choosing-the-right-signing-algorithm" # noqa
|
||||
)
|
||||
|
||||
|
||||
class Settings:
|
||||
def __init__(self):
|
||||
# JWT verification key (public key for asymmetric, shared secret for symmetric)
|
||||
self.JWT_VERIFY_KEY: str = os.getenv(
|
||||
"JWT_VERIFY_KEY", os.getenv("SUPABASE_JWT_SECRET", "")
|
||||
).strip()
|
||||
|
||||
# JWT signing key (private key for asymmetric, shared secret for symmetric)
|
||||
# Falls back to JWT_VERIFY_KEY for symmetric algorithms like HS256
|
||||
self.JWT_SIGN_KEY: str = os.getenv("JWT_SIGN_KEY", self.JWT_VERIFY_KEY).strip()
|
||||
|
||||
self.JWT_ALGORITHM: str = os.getenv("JWT_SIGN_ALGORITHM", "HS256").strip()
|
||||
|
||||
# Token expiration settings
|
||||
self.ACCESS_TOKEN_EXPIRE_MINUTES: int = int(
|
||||
os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "15")
|
||||
)
|
||||
self.REFRESH_TOKEN_EXPIRE_DAYS: int = int(
|
||||
os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7")
|
||||
)
|
||||
|
||||
# JWT issuer claim
|
||||
self.JWT_ISSUER: str = os.getenv("JWT_ISSUER", "autogpt-platform").strip()
|
||||
|
||||
# JWT audience claim
|
||||
self.JWT_AUDIENCE: str = os.getenv("JWT_AUDIENCE", "authenticated").strip()
|
||||
|
||||
self.validate()
|
||||
|
||||
def validate(self):
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
@@ -20,57 +16,6 @@ bearer_jwt_auth = HTTPBearer(
|
||||
)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
user_id: str,
|
||||
email: str,
|
||||
role: str = "authenticated",
|
||||
email_verified: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a new JWT access token.
|
||||
|
||||
:param user_id: The user's unique identifier
|
||||
:param email: The user's email address
|
||||
:param role: The user's role (default: "authenticated")
|
||||
:param email_verified: Whether the user's email is verified
|
||||
:return: Encoded JWT token
|
||||
"""
|
||||
settings = get_settings()
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"email": email,
|
||||
"role": role,
|
||||
"email_verified": email_verified,
|
||||
"aud": settings.JWT_AUDIENCE,
|
||||
"iss": settings.JWT_ISSUER,
|
||||
"iat": now,
|
||||
"exp": now + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
"jti": str(uuid.uuid4()), # Unique token ID
|
||||
}
|
||||
|
||||
return jwt.encode(payload, settings.JWT_SIGN_KEY, algorithm=settings.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token() -> tuple[str, str]:
|
||||
"""
|
||||
Generate a new refresh token.
|
||||
|
||||
Returns a tuple of (raw_token, hashed_token).
|
||||
The raw token should be sent to the client.
|
||||
The hashed token should be stored in the database.
|
||||
"""
|
||||
raw_token = secrets.token_urlsafe(64)
|
||||
hashed_token = hashlib.sha256(raw_token.encode()).hexdigest()
|
||||
return raw_token, hashed_token
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""Hash a token using SHA-256."""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
async def get_jwt_payload(
|
||||
credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
|
||||
) -> dict[str, Any]:
|
||||
@@ -107,19 +52,11 @@ def parse_jwt_token(token: str) -> dict[str, Any]:
|
||||
"""
|
||||
settings = get_settings()
|
||||
try:
|
||||
# Build decode options
|
||||
options = {
|
||||
"verify_aud": True,
|
||||
"verify_iss": bool(settings.JWT_ISSUER),
|
||||
}
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.JWT_VERIFY_KEY,
|
||||
algorithms=[settings.JWT_ALGORITHM],
|
||||
audience=settings.JWT_AUDIENCE,
|
||||
issuer=settings.JWT_ISSUER if settings.JWT_ISSUER else None,
|
||||
options=options,
|
||||
audience="authenticated",
|
||||
)
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
|
||||
@@ -11,7 +11,6 @@ class User:
|
||||
email: str
|
||||
phone_number: str
|
||||
role: str
|
||||
email_verified: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_payload(cls, payload):
|
||||
@@ -19,6 +18,5 @@ class User:
|
||||
user_id=payload["sub"],
|
||||
email=payload.get("email", ""),
|
||||
phone_number=payload.get("phone", ""),
|
||||
role=payload.get("role", "authenticated"),
|
||||
email_verified=payload.get("email_verified", False),
|
||||
role=payload["role"],
|
||||
)
|
||||
|
||||
414
autogpt_platform/autogpt_libs/poetry.lock
generated
414
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -48,21 +48,6 @@ files = [
|
||||
{file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "authlib"
|
||||
version = "1.6.6"
|
||||
description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "authlib-1.6.6-py2.py3-none-any.whl", hash = "sha256:7d9e9bc535c13974313a87f53e8430eb6ea3d1cf6ae4f6efcd793f2e949143fd"},
|
||||
{file = "authlib-1.6.6.tar.gz", hash = "sha256:45770e8e056d0f283451d9996fbb59b70d45722b45d854d58f32878d0a40c38e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cryptography = "*"
|
||||
|
||||
[[package]]
|
||||
name = "backports-asyncio-runner"
|
||||
version = "1.2.0"
|
||||
@@ -76,71 +61,6 @@ files = [
|
||||
{file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bcrypt"
|
||||
version = "4.3.0"
|
||||
description = "Modern password hashing for your software and your servers"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-macosx_10_12_universal2.whl", hash = "sha256:f01e060f14b6b57bbb72fc5b4a83ac21c443c9a2ee708e04a10e9192f90a6281"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5eeac541cefd0bb887a371ef73c62c3cd78535e4887b310626036a7c0a817bb"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59e1aa0e2cd871b08ca146ed08445038f42ff75968c7ae50d2fdd7860ade2180"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:0042b2e342e9ae3d2ed22727c1262f76cc4f345683b5c1715f0250cf4277294f"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74a8d21a09f5e025a9a23e7c0fd2c7fe8e7503e4d356c0a2c1486ba010619f09"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:0142b2cb84a009f8452c8c5a33ace5e3dfec4159e7735f5afe9a4d50a8ea722d"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_34_aarch64.whl", hash = "sha256:12fa6ce40cde3f0b899729dbd7d5e8811cb892d31b6f7d0334a1f37748b789fd"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_34_x86_64.whl", hash = "sha256:5bd3cca1f2aa5dbcf39e2aa13dd094ea181f48959e1071265de49cc2b82525af"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:335a420cfd63fc5bc27308e929bee231c15c85cc4c496610ffb17923abf7f231"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:0e30e5e67aed0187a1764911af023043b4542e70a7461ad20e837e94d23e1d6c"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:3b8d62290ebefd49ee0b3ce7500f5dbdcf13b81402c05f6dafab9a1e1b27212f"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2ef6630e0ec01376f59a006dc72918b1bf436c3b571b80fa1968d775fa02fe7d"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-win32.whl", hash = "sha256:7a4be4cbf241afee43f1c3969b9103a41b40bcb3a3f467ab19f891d9bc4642e4"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-win_amd64.whl", hash = "sha256:5c1949bf259a388863ced887c7861da1df681cb2388645766c89fdfd9004c669"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-macosx_10_12_universal2.whl", hash = "sha256:f81b0ed2639568bf14749112298f9e4e2b28853dab50a8b357e31798686a036d"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:864f8f19adbe13b7de11ba15d85d4a428c7e2f344bac110f667676a0ff84924b"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e36506d001e93bffe59754397572f21bb5dc7c83f54454c990c74a468cd589e"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:842d08d75d9fe9fb94b18b071090220697f9f184d4547179b60734846461ed59"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7c03296b85cb87db865d91da79bf63d5609284fc0cab9472fdd8367bbd830753"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:62f26585e8b219cdc909b6a0069efc5e4267e25d4a3770a364ac58024f62a761"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:beeefe437218a65322fbd0069eb437e7c98137e08f22c4660ac2dc795c31f8bb"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:97eea7408db3a5bcce4a55d13245ab3fa566e23b4c67cd227062bb49e26c585d"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:191354ebfe305e84f344c5964c7cd5f924a3bfc5d405c75ad07f232b6dffb49f"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:41261d64150858eeb5ff43c753c4b216991e0ae16614a308a15d909503617732"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:33752b1ba962ee793fa2b6321404bf20011fe45b9afd2a842139de3011898fef"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:50e6e80a4bfd23a25f5c05b90167c19030cf9f87930f7cb2eacb99f45d1c3304"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-win32.whl", hash = "sha256:67a561c4d9fb9465ec866177e7aebcad08fe23aaf6fbd692a6fab69088abfc51"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-win_amd64.whl", hash = "sha256:584027857bc2843772114717a7490a37f68da563b3620f78a849bcb54dc11e62"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:0d3efb1157edebfd9128e4e46e2ac1a64e0c1fe46fb023158a407c7892b0f8c3"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08bacc884fd302b611226c01014eca277d48f0a05187666bca23aac0dad6fe24"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6746e6fec103fcd509b96bacdfdaa2fbde9a553245dbada284435173a6f1aef"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:afe327968aaf13fc143a56a3360cb27d4ad0345e34da12c7290f1b00b8fe9a8b"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:d9af79d322e735b1fc33404b5765108ae0ff232d4b54666d46730f8ac1a43676"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f1e3ffa1365e8702dc48c8b360fef8d7afeca482809c5e45e653af82ccd088c1"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:3004df1b323d10021fda07a813fd33e0fd57bef0e9a480bb143877f6cba996fe"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:531457e5c839d8caea9b589a1bcfe3756b0547d7814e9ce3d437f17da75c32b0"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:17a854d9a7a476a89dcef6c8bd119ad23e0f82557afbd2c442777a16408e614f"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:6fb1fd3ab08c0cbc6826a2e0447610c6f09e983a281b919ed721ad32236b8b23"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e965a9c1e9a393b8005031ff52583cedc15b7884fce7deb8b0346388837d6cfe"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:79e70b8342a33b52b55d93b3a59223a844962bef479f6a0ea318ebbcadf71505"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-win32.whl", hash = "sha256:b4d4e57f0a63fd0b358eb765063ff661328f69a04494427265950c71b992a39a"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-win_amd64.whl", hash = "sha256:e53e074b120f2877a35cc6c736b8eb161377caae8925c17688bd46ba56daaa5b"},
|
||||
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c950d682f0952bafcceaf709761da0a32a942272fad381081b51096ffa46cea1"},
|
||||
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:107d53b5c67e0bbc3f03ebf5b030e0403d24dda980f8e244795335ba7b4a027d"},
|
||||
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:b693dbb82b3c27a1604a3dff5bfc5418a7e6a781bb795288141e5f80cf3a3492"},
|
||||
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:b6354d3760fcd31994a14c89659dee887f1351a06e5dac3c1142307172a79f90"},
|
||||
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:a839320bf27d474e52ef8cb16449bb2ce0ba03ca9f44daba6d93fa1d8828e48a"},
|
||||
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:bdc6a24e754a555d7316fa4774e64c6c3997d27ed2d1964d55920c7c227bc4ce"},
|
||||
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:55a935b8e9a1d2def0626c4269db3fcd26728cbff1e84f0341465c31c4ee56d8"},
|
||||
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:57967b7a28d855313a963aaea51bf6df89f833db4320da458e5b3c5ab6d4c938"},
|
||||
{file = "bcrypt-4.3.0.tar.gz", hash = "sha256:3a3fd2204178b6d2adcf09cb4f6426ffef54762577a7c9b54c159008cb288c18"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
tests = ["pytest (>=3.2.1,!=3.3.0)"]
|
||||
typecheck = ["mypy"]
|
||||
|
||||
[[package]]
|
||||
name = "cachetools"
|
||||
version = "5.5.2"
|
||||
@@ -539,6 +459,21 @@ ssh = ["bcrypt (>=3.1.5)"]
|
||||
test = ["certifi (>=2024)", "cryptography-vectors (==45.0.6)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
|
||||
test-randomorder = ["pytest-randomly"]
|
||||
|
||||
[[package]]
|
||||
name = "deprecation"
|
||||
version = "2.1.0"
|
||||
description = "A library to handle automated deprecations"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"},
|
||||
{file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
packaging = "*"
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.3.0"
|
||||
@@ -760,6 +695,23 @@ protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4
|
||||
[package.extras]
|
||||
grpc = ["grpcio (>=1.44.0,<2.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "gotrue"
|
||||
version = "2.12.3"
|
||||
description = "Python Client Library for Supabase Auth"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "gotrue-2.12.3-py3-none-any.whl", hash = "sha256:b1a3c6a5fe3f92e854a026c4c19de58706a96fd5fbdcc3d620b2802f6a46a26b"},
|
||||
{file = "gotrue-2.12.3.tar.gz", hash = "sha256:f874cf9d0b2f0335bfbd0d6e29e3f7aff79998cd1c14d2ad814db8c06cee3852"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
|
||||
pydantic = ">=1.10,<3"
|
||||
pyjwt = ">=2.10.1,<3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "grpc-google-iam-v1"
|
||||
version = "0.14.2"
|
||||
@@ -870,6 +822,94 @@ files = [
|
||||
{file = "h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "4.2.0"
|
||||
description = "Pure-Python HTTP/2 protocol implementation"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "h2-4.2.0-py3-none-any.whl", hash = "sha256:479a53ad425bb29af087f3458a61d30780bc818e4ebcf01f0b536ba916462ed0"},
|
||||
{file = "h2-4.2.0.tar.gz", hash = "sha256:c8a52129695e88b1a0578d8d2cc6842bbd79128ac685463b887ee278126ad01f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
hpack = ">=4.1,<5"
|
||||
hyperframe = ">=6.1,<7"
|
||||
|
||||
[[package]]
|
||||
name = "hpack"
|
||||
version = "4.1.0"
|
||||
description = "Pure-Python HPACK header encoding"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496"},
|
||||
{file = "hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httpcore"
|
||||
version = "1.0.9"
|
||||
description = "A minimal low-level HTTP client."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55"},
|
||||
{file = "httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
certifi = "*"
|
||||
h11 = ">=0.16"
|
||||
|
||||
[package.extras]
|
||||
asyncio = ["anyio (>=4.0,<5.0)"]
|
||||
http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
trio = ["trio (>=0.22.0,<1.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "httpx"
|
||||
version = "0.28.1"
|
||||
description = "The next generation HTTP client."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"},
|
||||
{file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = "*"
|
||||
certifi = "*"
|
||||
h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""}
|
||||
httpcore = "==1.*"
|
||||
idna = "*"
|
||||
|
||||
[package.extras]
|
||||
brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
|
||||
cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
|
||||
http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
zstd = ["zstandard (>=0.18.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "hyperframe"
|
||||
version = "6.1.0"
|
||||
description = "Pure-Python HTTP/2 framing"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5"},
|
||||
{file = "hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.10"
|
||||
@@ -996,7 +1036,7 @@ version = "25.0"
|
||||
description = "Core utilities for Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["dev"]
|
||||
groups = ["main", "dev"]
|
||||
files = [
|
||||
{file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"},
|
||||
{file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"},
|
||||
@@ -1018,6 +1058,24 @@ files = [
|
||||
dev = ["pre-commit", "tox"]
|
||||
testing = ["coverage", "pytest", "pytest-benchmark"]
|
||||
|
||||
[[package]]
|
||||
name = "postgrest"
|
||||
version = "1.1.1"
|
||||
description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "postgrest-1.1.1-py3-none-any.whl", hash = "sha256:98a6035ee1d14288484bfe36235942c5fb2d26af6d8120dfe3efbe007859251a"},
|
||||
{file = "postgrest-1.1.1.tar.gz", hash = "sha256:f3bb3e8c4602775c75c844a31f565f5f3dd584df4d36d683f0b67d01a86be322"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
deprecation = ">=2.1.0,<3.0.0"
|
||||
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
|
||||
pydantic = ">=1.9,<3.0"
|
||||
strenum = {version = ">=0.4.9,<0.5.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "proto-plus"
|
||||
version = "1.26.1"
|
||||
@@ -1404,6 +1462,21 @@ pytest = ">=6.2.5"
|
||||
[package.extras]
|
||||
dev = ["pre-commit", "pytest-asyncio", "tox"]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.9.0.post0"
|
||||
description = "Extensions to the standard Python datetime module"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
|
||||
{file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
six = ">=1.5"
|
||||
|
||||
[[package]]
|
||||
name = "python-dotenv"
|
||||
version = "1.1.1"
|
||||
@@ -1419,6 +1492,22 @@ files = [
|
||||
[package.extras]
|
||||
cli = ["click (>=5.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "realtime"
|
||||
version = "2.5.3"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "realtime-2.5.3-py3-none-any.whl", hash = "sha256:eb0994636946eff04c4c7f044f980c8c633c7eb632994f549f61053a474ac970"},
|
||||
{file = "realtime-2.5.3.tar.gz", hash = "sha256:0587594f3bc1c84bf007ff625075b86db6528843e03250dc84f4f2808be3d99a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
typing-extensions = ">=4.14.0,<5.0.0"
|
||||
websockets = ">=11,<16"
|
||||
|
||||
[[package]]
|
||||
name = "redis"
|
||||
version = "6.2.0"
|
||||
@@ -1517,6 +1606,18 @@ files = [
|
||||
{file = "semver-3.0.4.tar.gz", hash = "sha256:afc7d8c584a5ed0a11033af086e8af226a9c0b206f313e0301f8dd7b6b589602"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "six"
|
||||
version = "1.17.0"
|
||||
description = "Python 2 and 3 compatibility utilities"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"},
|
||||
{file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sniffio"
|
||||
version = "1.3.1"
|
||||
@@ -1548,6 +1649,76 @@ typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\""
|
||||
[package.extras]
|
||||
full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"]
|
||||
|
||||
[[package]]
|
||||
name = "storage3"
|
||||
version = "0.12.0"
|
||||
description = "Supabase Storage client for Python."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "storage3-0.12.0-py3-none-any.whl", hash = "sha256:1c4585693ca42243ded1512b58e54c697111e91a20916cd14783eebc37e7c87d"},
|
||||
{file = "storage3-0.12.0.tar.gz", hash = "sha256:94243f20922d57738bf42e96b9f5582b4d166e8bf209eccf20b146909f3f71b0"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
deprecation = ">=2.1.0,<3.0.0"
|
||||
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
|
||||
python-dateutil = ">=2.8.2,<3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "strenum"
|
||||
version = "0.4.15"
|
||||
description = "An Enum that inherits from str."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659"},
|
||||
{file = "StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"]
|
||||
release = ["twine"]
|
||||
test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"]
|
||||
|
||||
[[package]]
|
||||
name = "supabase"
|
||||
version = "2.16.0"
|
||||
description = "Supabase client for Python."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "supabase-2.16.0-py3-none-any.whl", hash = "sha256:99065caab3d90a56650bf39fbd0e49740995da3738ab28706c61bd7f2401db55"},
|
||||
{file = "supabase-2.16.0.tar.gz", hash = "sha256:98f3810158012d4ec0e3083f2e5515f5e10b32bd71e7d458662140e963c1d164"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
gotrue = ">=2.11.0,<3.0.0"
|
||||
httpx = ">=0.26,<0.29"
|
||||
postgrest = ">0.19,<1.2"
|
||||
realtime = ">=2.4.0,<2.6.0"
|
||||
storage3 = ">=0.10,<0.13"
|
||||
supafunc = ">=0.9,<0.11"
|
||||
|
||||
[[package]]
|
||||
name = "supafunc"
|
||||
version = "0.10.1"
|
||||
description = "Library for Supabase Functions"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "supafunc-0.10.1-py3-none-any.whl", hash = "sha256:26df9bd25ff2ef56cb5bfb8962de98f43331f7f8ff69572bac3ed9c3a9672040"},
|
||||
{file = "supafunc-0.10.1.tar.gz", hash = "sha256:a5b33c8baecb6b5297d25da29a2503e2ec67ee6986f3d44c137e651b8a59a17d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
|
||||
strenum = ">=0.4.15,<0.5.0"
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.2.1"
|
||||
@@ -1656,6 +1827,85 @@ typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""}
|
||||
[package.extras]
|
||||
standard = ["colorama (>=0.4) ; sys_platform == \"win32\"", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.15.1) ; sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\"", "watchfiles (>=0.13)", "websockets (>=10.4)"]
|
||||
|
||||
[[package]]
|
||||
name = "websockets"
|
||||
version = "15.0.1"
|
||||
description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "websockets-15.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d63efaa0cd96cf0c5fe4d581521d9fa87744540d4bc999ae6e08595a1014b45b"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac60e3b188ec7574cb761b08d50fcedf9d77f1530352db4eef1707fe9dee7205"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5756779642579d902eed757b21b0164cd6fe338506a8083eb58af5c372e39d9a"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fdfe3e2a29e4db3659dbd5bbf04560cea53dd9610273917799f1cde46aa725e"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c2529b320eb9e35af0fa3016c187dffb84a3ecc572bcee7c3ce302bfeba52bf"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac1e5c9054fe23226fb11e05a6e630837f074174c4c2f0fe442996112a6de4fb"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5df592cd503496351d6dc14f7cdad49f268d8e618f80dce0cd5a36b93c3fc08d"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0a34631031a8f05657e8e90903e656959234f3a04552259458aac0b0f9ae6fd9"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3d00075aa65772e7ce9e990cab3ff1de702aa09be3940d1dc88d5abf1ab8a09c"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-win32.whl", hash = "sha256:1234d4ef35db82f5446dca8e35a7da7964d02c127b095e172e54397fb6a6c256"},
|
||||
{file = "websockets-15.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:39c1fec2c11dc8d89bba6b2bf1556af381611a173ac2b511cf7231622058af41"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:823c248b690b2fd9303ba00c4f66cd5e2d8c3ba4aa968b2779be9532a4dad431"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678999709e68425ae2593acf2e3ebcbcf2e69885a5ee78f9eb80e6e371f1bf57"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d50fd1ee42388dcfb2b3676132c78116490976f1300da28eb629272d5d93e905"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d99e5546bf73dbad5bf3547174cd6cb8ba7273062a23808ffea025ecb1cf8562"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:66dd88c918e3287efc22409d426c8f729688d89a0c587c88971a0faa2c2f3792"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8dd8327c795b3e3f219760fa603dcae1dcc148172290a8ab15158cf85a953413"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8fdc51055e6ff4adeb88d58a11042ec9a5eae317a0a53d12c062c8a8865909e8"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:693f0192126df6c2327cce3baa7c06f2a117575e32ab2308f7f8216c29d9e2e3"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:54479983bd5fb469c38f2f5c7e3a24f9a4e70594cd68cd1fa6b9340dadaff7cf"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-win32.whl", hash = "sha256:16b6c1b3e57799b9d38427dda63edcbe4926352c47cf88588c0be4ace18dac85"},
|
||||
{file = "websockets-15.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:27ccee0071a0e75d22cb35849b1db43f2ecd3e161041ac1ee9d2352ddf72f065"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:3e90baa811a5d73f3ca0bcbf32064d663ed81318ab225ee4f427ad4e26e5aff3"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:592f1a9fe869c778694f0aa806ba0374e97648ab57936f092fd9d87f8bc03665"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0701bc3cfcb9164d04a14b149fd74be7347a530ad3bbf15ab2c678a2cd3dd9a2"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8b56bdcdb4505c8078cb6c7157d9811a85790f2f2b3632c7d1462ab5783d215"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0af68c55afbd5f07986df82831c7bff04846928ea8d1fd7f30052638788bc9b5"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64dee438fed052b52e4f98f76c5790513235efaa1ef7f3f2192c392cd7c91b65"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d5f6b181bb38171a8ad1d6aa58a67a6aa9d4b38d0f8c5f496b9e42561dfc62fe"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:5d54b09eba2bada6011aea5375542a157637b91029687eb4fdb2dab11059c1b4"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3be571a8b5afed347da347bfcf27ba12b069d9d7f42cb8c7028b5e98bbb12597"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-win32.whl", hash = "sha256:c338ffa0520bdb12fbc527265235639fb76e7bc7faafbb93f6ba80d9c06578a9"},
|
||||
{file = "websockets-15.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:fcd5cf9e305d7b8338754470cf69cf81f420459dbae8a3b40cee57417f4614a7"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ee443ef070bb3b6ed74514f5efaa37a252af57c90eb33b956d35c8e9c10a1931"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5a939de6b7b4e18ca683218320fc67ea886038265fd1ed30173f5ce3f8e85675"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:746ee8dba912cd6fc889a8147168991d50ed70447bf18bcda7039f7d2e3d9151"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:595b6c3969023ecf9041b2936ac3827e4623bfa3ccf007575f04c5a6aa318c22"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c714d2fc58b5ca3e285461a4cc0c9a66bd0e24c5da9911e30158286c9b5be7f"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f3c1e2ab208db911594ae5b4f79addeb3501604a165019dd221c0bdcabe4db8"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:229cf1d3ca6c1804400b0a9790dc66528e08a6a1feec0d5040e8b9eb14422375"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:756c56e867a90fb00177d530dca4b097dd753cde348448a1012ed6c5131f8b7d"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:558d023b3df0bffe50a04e710bc87742de35060580a293c2a984299ed83bc4e4"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-win32.whl", hash = "sha256:ba9e56e8ceeeedb2e080147ba85ffcd5cd0711b89576b83784d8605a7df455fa"},
|
||||
{file = "websockets-15.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:e09473f095a819042ecb2ab9465aee615bd9c2028e4ef7d933600a8401c79561"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5f4c04ead5aed67c8a1a20491d54cdfba5884507a48dd798ecaf13c74c4489f5"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:abdc0c6c8c648b4805c5eacd131910d2a7f6455dfd3becab248ef108e89ab16a"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a625e06551975f4b7ea7102bc43895b90742746797e2e14b70ed61c43a90f09b"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d591f8de75824cbb7acad4e05d2d710484f15f29d4a915092675ad3456f11770"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:47819cea040f31d670cc8d324bb6435c6f133b8c7a19ec3d61634e62f8d8f9eb"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac017dd64572e5c3bd01939121e4d16cf30e5d7e110a119399cf3133b63ad054"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:4a9fac8e469d04ce6c25bb2610dc535235bd4aa14996b4e6dbebf5e007eba5ee"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:363c6f671b761efcb30608d24925a382497c12c506b51661883c3e22337265ed"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:2034693ad3097d5355bfdacfffcbd3ef5694f9718ab7f29c29689a9eae841880"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-win32.whl", hash = "sha256:3b1ac0d3e594bf121308112697cf4b32be538fb1444468fb0a6ae4feebc83411"},
|
||||
{file = "websockets-15.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:b7643a03db5c95c799b89b31c036d5f27eeb4d259c798e878d6937d71832b1e4"},
|
||||
{file = "websockets-15.0.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:0c9e74d766f2818bb95f84c25be4dea09841ac0f734d1966f415e4edfc4ef1c3"},
|
||||
{file = "websockets-15.0.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1009ee0c7739c08a0cd59de430d6de452a55e42d6b522de7aa15e6f67db0b8e1"},
|
||||
{file = "websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76d1f20b1c7a2fa82367e04982e708723ba0e7b8d43aa643d3dcd404d74f1475"},
|
||||
{file = "websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f29d80eb9a9263b8d109135351caf568cc3f80b9928bccde535c235de55c22d9"},
|
||||
{file = "websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b359ed09954d7c18bbc1680f380c7301f92c60bf924171629c5db97febb12f04"},
|
||||
{file = "websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122"},
|
||||
{file = "websockets-15.0.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7f493881579c90fc262d9cdbaa05a6b54b3811c2f300766748db79f098db9940"},
|
||||
{file = "websockets-15.0.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:47b099e1f4fbc95b701b6e85768e1fcdaf1630f3cbe4765fa216596f12310e2e"},
|
||||
{file = "websockets-15.0.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:67f2b6de947f8c757db2db9c71527933ad0019737ec374a8a6be9a956786aaf9"},
|
||||
{file = "websockets-15.0.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d08eb4c2b7d6c41da6ca0600c077e93f5adcfd979cd777d747e9ee624556da4b"},
|
||||
{file = "websockets-15.0.1-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b826973a4a2ae47ba357e4e82fa44a463b8f168e1ca775ac64521442b19e87f"},
|
||||
{file = "websockets-15.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:21c1fa28a6a7e3cbdc171c694398b6df4744613ce9b36b1a498e816787e28123"},
|
||||
{file = "websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f"},
|
||||
{file = "websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zipp"
|
||||
version = "3.23.0"
|
||||
@@ -1679,4 +1929,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "de209c97aa0feb29d669a20e4422d51bdf3a0872ec37e85ce9b88ce726fcee7a"
|
||||
content-hash = "0c40b63c3c921846cf05ccfb4e685d4959854b29c2c302245f9832e20aac6954"
|
||||
|
||||
@@ -18,8 +18,7 @@ pydantic = "^2.11.7"
|
||||
pydantic-settings = "^2.10.1"
|
||||
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
|
||||
redis = "^6.2.0"
|
||||
bcrypt = "^4.1.0"
|
||||
authlib = "^1.3.0"
|
||||
supabase = "^2.16.0"
|
||||
uvicorn = "^0.35.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
@@ -27,15 +27,10 @@ REDIS_PORT=6379
|
||||
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
|
||||
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
|
||||
|
||||
# JWT Authentication
|
||||
# Generate a secure random key: python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
JWT_SIGN_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
# Supabase Authentication
|
||||
SUPABASE_URL=http://localhost:8000
|
||||
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
JWT_SIGN_ALGORITHM=HS256
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=15
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
JWT_ISSUER=autogpt-platform
|
||||
JWT_AUDIENCE=authenticated
|
||||
|
||||
## ===== REQUIRED SECURITY KEYS ===== ##
|
||||
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()
|
||||
|
||||
3
autogpt_platform/backend/.gitignore
vendored
3
autogpt_platform/backend/.gitignore
vendored
@@ -18,6 +18,3 @@ load-tests/results/
|
||||
load-tests/*.json
|
||||
load-tests/*.log
|
||||
load-tests/node_modules/*
|
||||
|
||||
# Migration backups (contain user data)
|
||||
migration_backups/
|
||||
|
||||
@@ -20,7 +20,6 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.request import Requests
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
@@ -247,11 +246,7 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise BlockExecutionError(
|
||||
message="Video creation timed out",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -427,11 +422,7 @@ class AIAdMakerVideoCreatorBlock(Block):
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise BlockExecutionError(
|
||||
message="Video creation timed out",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -608,11 +599,7 @@ class AIScreenshotToVideoAdBlock(Block):
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise BlockExecutionError(
|
||||
message="Video creation timed out",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -106,10 +106,7 @@ class ConditionBlock(Block):
|
||||
ComparisonOperator.LESS_THAN_OR_EQUAL: lambda a, b: a <= b,
|
||||
}
|
||||
|
||||
try:
|
||||
result = comparison_funcs[operator](value1, value2)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Comparison failed: {e}") from e
|
||||
result = comparison_funcs[operator](value1, value2)
|
||||
|
||||
yield "result", result
|
||||
|
||||
|
||||
@@ -319,7 +319,7 @@ class CostDollars(BaseModel):
|
||||
|
||||
# Helper functions for payload processing
|
||||
def process_text_field(
|
||||
text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None],
|
||||
text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None]
|
||||
) -> Optional[Union[bool, Dict[str, Any]]]:
|
||||
"""Process text field for API payload."""
|
||||
if text is None:
|
||||
@@ -400,7 +400,7 @@ def process_contents_settings(contents: Optional[ContentSettings]) -> Dict[str,
|
||||
|
||||
|
||||
def process_context_field(
|
||||
context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None],
|
||||
context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None]
|
||||
) -> Optional[Union[bool, Dict[str, int]]]:
|
||||
"""Process context field for API payload."""
|
||||
if context is None:
|
||||
|
||||
@@ -15,7 +15,6 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
from ._config import firecrawl
|
||||
|
||||
@@ -60,18 +59,11 @@ class FirecrawlExtractBlock(Block):
|
||||
) -> BlockOutput:
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
try:
|
||||
extract_result = app.extract(
|
||||
urls=input_data.urls,
|
||||
prompt=input_data.prompt,
|
||||
schema=input_data.output_schema,
|
||||
enable_web_search=input_data.enable_web_search,
|
||||
)
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Extract failed: {e}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
) from e
|
||||
extract_result = app.extract(
|
||||
urls=input_data.urls,
|
||||
prompt=input_data.prompt,
|
||||
schema=input_data.output_schema,
|
||||
enable_web_search=input_data.enable_web_search,
|
||||
)
|
||||
|
||||
yield "data", extract_result.data
|
||||
|
||||
@@ -19,7 +19,6 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.exceptions import ModerationError
|
||||
from backend.util.file import MediaFileType, store_media_file
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
@@ -154,8 +153,6 @@ class AIImageEditorBlock(Block):
|
||||
),
|
||||
aspect_ratio=input_data.aspect_ratio.value,
|
||||
seed=input_data.seed,
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
yield "output_image", result
|
||||
|
||||
@@ -167,8 +164,6 @@ class AIImageEditorBlock(Block):
|
||||
input_image_b64: Optional[str],
|
||||
aspect_ratio: str,
|
||||
seed: Optional[int],
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
) -> MediaFileType:
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
input_params = {
|
||||
@@ -178,21 +173,11 @@ class AIImageEditorBlock(Block):
|
||||
**({"seed": seed} if seed is not None else {}),
|
||||
}
|
||||
|
||||
try:
|
||||
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
|
||||
model_name,
|
||||
input=input_params,
|
||||
wait=False,
|
||||
)
|
||||
except Exception as e:
|
||||
if "flagged as sensitive" in str(e).lower():
|
||||
raise ModerationError(
|
||||
message="Content was flagged as sensitive by the model provider",
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
moderation_type="model_provider",
|
||||
)
|
||||
raise ValueError(f"Model execution failed: {e}") from e
|
||||
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
|
||||
model_name,
|
||||
input=input_params,
|
||||
wait=False,
|
||||
)
|
||||
|
||||
if isinstance(output, list) and output:
|
||||
output = output[0]
|
||||
|
||||
@@ -2,6 +2,7 @@ from enum import Enum
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
from requests.exceptions import RequestException
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
@@ -331,8 +332,8 @@ class IdeogramModelBlock(Block):
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=data)
|
||||
return response.json()["data"][0]["url"]
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch image with V3 endpoint: {e}") from e
|
||||
except RequestException as e:
|
||||
raise Exception(f"Failed to fetch image with V3 endpoint: {str(e)}")
|
||||
|
||||
async def _run_model_legacy(
|
||||
self,
|
||||
@@ -384,8 +385,8 @@ class IdeogramModelBlock(Block):
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=data)
|
||||
return response.json()["data"][0]["url"]
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch image with legacy endpoint: {e}") from e
|
||||
except RequestException as e:
|
||||
raise Exception(f"Failed to fetch image with legacy endpoint: {str(e)}")
|
||||
|
||||
async def upscale_image(self, api_key: SecretStr, image_url: str):
|
||||
url = "https://api.ideogram.ai/upscale"
|
||||
@@ -412,5 +413,5 @@ class IdeogramModelBlock(Block):
|
||||
|
||||
return (response.json())["data"][0]["url"]
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to upscale image: {e}") from e
|
||||
except RequestException as e:
|
||||
raise Exception(f"Failed to upscale image: {str(e)}")
|
||||
|
||||
@@ -16,7 +16,6 @@ from backend.data.block import (
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
|
||||
class SearchTheWebBlock(Block, GetRequest):
|
||||
@@ -57,17 +56,7 @@ class SearchTheWebBlock(Block, GetRequest):
|
||||
|
||||
# Prepend the Jina Search URL to the encoded query
|
||||
jina_search_url = f"https://s.jina.ai/{encoded_query}"
|
||||
|
||||
try:
|
||||
results = await self.get_request(
|
||||
jina_search_url, headers=headers, json=False
|
||||
)
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Search failed: {e}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
) from e
|
||||
results = await self.get_request(jina_search_url, headers=headers, json=False)
|
||||
|
||||
# Output the search results
|
||||
yield "results", results
|
||||
|
||||
@@ -18,7 +18,6 @@ from backend.data.block import (
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError, BlockInputError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -112,27 +111,9 @@ class ReplicateModelBlock(Block):
|
||||
yield "status", "succeeded"
|
||||
yield "model_name", input_data.model_name
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"Error running Replicate model: {error_msg}")
|
||||
|
||||
# Input validation errors (422, 400) → BlockInputError
|
||||
if (
|
||||
"422" in error_msg
|
||||
or "Input validation failed" in error_msg
|
||||
or "400" in error_msg
|
||||
):
|
||||
raise BlockInputError(
|
||||
message=f"Invalid model inputs: {error_msg}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
) from e
|
||||
# Everything else → BlockExecutionError
|
||||
else:
|
||||
raise BlockExecutionError(
|
||||
message=f"Replicate model error: {error_msg}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
) from e
|
||||
error_msg = f"Unexpected error running Replicate model: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
async def run_model(self, model_ref: str, model_inputs: dict, api_key: SecretStr):
|
||||
"""
|
||||
|
||||
@@ -45,16 +45,10 @@ class GetWikipediaSummaryBlock(Block, GetRequest):
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
topic = input_data.topic
|
||||
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
|
||||
|
||||
# Note: User-Agent is now automatically set by the request library
|
||||
# to comply with Wikimedia's robot policy (https://w.wiki/4wJS)
|
||||
try:
|
||||
response = await self.get_request(url, json=True)
|
||||
if "extract" not in response:
|
||||
raise ValueError(f"Unable to parse Wikipedia response: {response}")
|
||||
yield "summary", response["extract"]
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch Wikipedia summary: {e}") from e
|
||||
response = await self.get_request(url, json=True)
|
||||
if "extract" not in response:
|
||||
raise RuntimeError(f"Unable to parse Wikipedia response: {response}")
|
||||
yield "summary", response["extract"]
|
||||
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""CLI utilities for backend development & administration"""
|
||||
@@ -1,57 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to generate OpenAPI JSON specification for the FastAPI app.
|
||||
|
||||
This script imports the FastAPI app from backend.server.rest_api and outputs
|
||||
the OpenAPI specification as JSON to stdout or a specified file.
|
||||
|
||||
Usage:
|
||||
`poetry run python generate_openapi_json.py`
|
||||
`poetry run python generate_openapi_json.py --output openapi.json`
|
||||
`poetry run python generate_openapi_json.py --indent 4 --output openapi.json`
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--output",
|
||||
type=click.Path(dir_okay=False, path_type=Path),
|
||||
help="Output file path (default: stdout)",
|
||||
)
|
||||
@click.option(
|
||||
"--pretty",
|
||||
type=click.BOOL,
|
||||
default=False,
|
||||
help="Pretty-print JSON output (indented 2 spaces)",
|
||||
)
|
||||
def main(output: Path, pretty: bool):
|
||||
"""Generate and output the OpenAPI JSON specification."""
|
||||
openapi_schema = get_openapi_schema()
|
||||
|
||||
json_output = json.dumps(openapi_schema, indent=2 if pretty else None)
|
||||
|
||||
if output:
|
||||
output.write_text(json_output)
|
||||
click.echo(f"✅ OpenAPI specification written to {output}\n\nPreview:")
|
||||
click.echo(f"\n{json_output[:500]} ...")
|
||||
else:
|
||||
print(json_output)
|
||||
|
||||
|
||||
def get_openapi_schema():
|
||||
"""Get the OpenAPI schema from the FastAPI app"""
|
||||
from backend.server.rest_api import app
|
||||
|
||||
return app.openapi()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ["LOG_LEVEL"] = "ERROR" # disable stdout log output
|
||||
|
||||
main()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,24 +1,22 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional, cast
|
||||
from typing import Optional
|
||||
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from prisma.enums import APIKeyPermission, APIKeyStatus
|
||||
from prisma.models import APIKey as PrismaAPIKey
|
||||
from prisma.types import APIKeyCreateInput, APIKeyWhereUniqueInput
|
||||
from pydantic import Field
|
||||
from prisma.types import APIKeyWhereUniqueInput
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.includes import MAX_USER_API_KEYS_FETCH
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
|
||||
from .base import APIAuthorizationInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
keysmith = APIKeySmith()
|
||||
|
||||
|
||||
class APIKeyInfo(APIAuthorizationInfo):
|
||||
class APIKeyInfo(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
head: str = Field(
|
||||
@@ -28,9 +26,12 @@ class APIKeyInfo(APIAuthorizationInfo):
|
||||
description=f"The last {APIKeySmith.TAIL_LENGTH} characters of the key"
|
||||
)
|
||||
status: APIKeyStatus
|
||||
permissions: list[APIKeyPermission]
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime] = None
|
||||
revoked_at: Optional[datetime] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
type: Literal["api_key"] = "api_key" # type: ignore
|
||||
user_id: str
|
||||
|
||||
@staticmethod
|
||||
def from_db(api_key: PrismaAPIKey):
|
||||
@@ -40,7 +41,7 @@ class APIKeyInfo(APIAuthorizationInfo):
|
||||
head=api_key.head,
|
||||
tail=api_key.tail,
|
||||
status=APIKeyStatus(api_key.status),
|
||||
scopes=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
permissions=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
created_at=api_key.createdAt,
|
||||
last_used_at=api_key.lastUsedAt,
|
||||
revoked_at=api_key.revokedAt,
|
||||
@@ -82,20 +83,17 @@ async def create_api_key(
|
||||
generated_key = keysmith.generate_key()
|
||||
|
||||
saved_key_obj = await PrismaAPIKey.prisma().create(
|
||||
data=cast(
|
||||
APIKeyCreateInput,
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"head": generated_key.head,
|
||||
"tail": generated_key.tail,
|
||||
"hash": generated_key.hash,
|
||||
"salt": generated_key.salt,
|
||||
"permissions": [p for p in permissions],
|
||||
"description": description,
|
||||
"userId": user_id,
|
||||
},
|
||||
)
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"head": generated_key.head,
|
||||
"tail": generated_key.tail,
|
||||
"hash": generated_key.hash,
|
||||
"salt": generated_key.salt,
|
||||
"permissions": [p for p in permissions],
|
||||
"description": description,
|
||||
"userId": user_id,
|
||||
}
|
||||
)
|
||||
|
||||
return APIKeyInfo.from_db(saved_key_obj), generated_key.key
|
||||
@@ -213,7 +211,7 @@ async def suspend_api_key(key_id: str, user_id: str) -> APIKeyInfo:
|
||||
|
||||
|
||||
def has_permission(api_key: APIKeyInfo, required_permission: APIKeyPermission) -> bool:
|
||||
return required_permission in api_key.scopes
|
||||
return required_permission in api_key.permissions
|
||||
|
||||
|
||||
async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyInfo]:
|
||||
@@ -1,15 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional
|
||||
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class APIAuthorizationInfo(BaseModel):
|
||||
user_id: str
|
||||
scopes: list[APIKeyPermission]
|
||||
type: Literal["oauth", "api_key"]
|
||||
created_at: datetime
|
||||
expires_at: Optional[datetime] = None
|
||||
last_used_at: Optional[datetime] = None
|
||||
revoked_at: Optional[datetime] = None
|
||||
@@ -1,886 +0,0 @@
|
||||
"""
|
||||
OAuth 2.0 Provider Data Layer
|
||||
|
||||
Handles management of OAuth applications, authorization codes,
|
||||
access tokens, and refresh tokens.
|
||||
|
||||
Hashing strategy:
|
||||
- Access tokens & Refresh tokens: SHA256 (deterministic, allows direct lookup by hash)
|
||||
- Client secrets: Scrypt with salt (lookup by client_id, then verify with salt)
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Literal, Optional, cast
|
||||
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from prisma.enums import APIKeyPermission as APIPermission
|
||||
from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
|
||||
from prisma.models import OAuthApplication as PrismaOAuthApplication
|
||||
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
|
||||
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
|
||||
from prisma.types import (
|
||||
OAuthAccessTokenCreateInput,
|
||||
OAuthApplicationUpdateInput,
|
||||
OAuthAuthorizationCodeCreateInput,
|
||||
OAuthRefreshTokenCreateInput,
|
||||
)
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from .base import APIAuthorizationInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
keysmith = APIKeySmith() # Only used for client secret hashing (Scrypt)
|
||||
|
||||
|
||||
def _generate_token() -> str:
|
||||
"""Generate a cryptographically secure random token."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def _hash_token(token: str) -> str:
|
||||
"""Hash a token using SHA256 (deterministic, for direct lookup)."""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
# Token TTLs
|
||||
AUTHORIZATION_CODE_TTL = timedelta(minutes=10)
|
||||
ACCESS_TOKEN_TTL = timedelta(hours=1)
|
||||
REFRESH_TOKEN_TTL = timedelta(days=30)
|
||||
|
||||
ACCESS_TOKEN_PREFIX = "agpt_xt_"
|
||||
REFRESH_TOKEN_PREFIX = "agpt_rt_"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Exception Classes
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthError(Exception):
|
||||
"""Base OAuth error"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidClientError(OAuthError):
|
||||
"""Invalid client_id or client_secret"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidGrantError(OAuthError):
|
||||
"""Invalid or expired authorization code/refresh token"""
|
||||
|
||||
def __init__(self, reason: str):
|
||||
self.reason = reason
|
||||
super().__init__(f"Invalid grant: {reason}")
|
||||
|
||||
|
||||
class InvalidTokenError(OAuthError):
|
||||
"""Invalid, expired, or revoked token"""
|
||||
|
||||
def __init__(self, reason: str):
|
||||
self.reason = reason
|
||||
super().__init__(f"Invalid token: {reason}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Data Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthApplicationInfo(BaseModel):
|
||||
"""OAuth application information (without client secret hash)"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
client_id: str
|
||||
redirect_uris: list[str]
|
||||
grant_types: list[str]
|
||||
scopes: list[APIPermission]
|
||||
owner_id: str
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@staticmethod
|
||||
def from_db(app: PrismaOAuthApplication):
|
||||
return OAuthApplicationInfo(
|
||||
id=app.id,
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
logo_url=app.logoUrl,
|
||||
client_id=app.clientId,
|
||||
redirect_uris=app.redirectUris,
|
||||
grant_types=app.grantTypes,
|
||||
scopes=[APIPermission(s) for s in app.scopes],
|
||||
owner_id=app.ownerId,
|
||||
is_active=app.isActive,
|
||||
created_at=app.createdAt,
|
||||
updated_at=app.updatedAt,
|
||||
)
|
||||
|
||||
|
||||
class OAuthApplicationInfoWithSecret(OAuthApplicationInfo):
|
||||
"""OAuth application with client secret hash (for validation)"""
|
||||
|
||||
client_secret_hash: str
|
||||
client_secret_salt: str
|
||||
|
||||
@staticmethod
|
||||
def from_db(app: PrismaOAuthApplication):
|
||||
return OAuthApplicationInfoWithSecret(
|
||||
**OAuthApplicationInfo.from_db(app).model_dump(),
|
||||
client_secret_hash=app.clientSecret,
|
||||
client_secret_salt=app.clientSecretSalt,
|
||||
)
|
||||
|
||||
def verify_secret(self, plaintext_secret: str) -> bool:
|
||||
"""Verify a plaintext client secret against the stored hash"""
|
||||
# Use keysmith.verify_key() with stored salt
|
||||
return keysmith.verify_key(
|
||||
plaintext_secret, self.client_secret_hash, self.client_secret_salt
|
||||
)
|
||||
|
||||
|
||||
class OAuthAuthorizationCodeInfo(BaseModel):
|
||||
"""Authorization code information"""
|
||||
|
||||
id: str
|
||||
code: str
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
application_id: str
|
||||
user_id: str
|
||||
scopes: list[APIPermission]
|
||||
redirect_uri: str
|
||||
code_challenge: Optional[str] = None
|
||||
code_challenge_method: Optional[str] = None
|
||||
used_at: Optional[datetime] = None
|
||||
|
||||
@property
|
||||
def is_used(self) -> bool:
|
||||
return self.used_at is not None
|
||||
|
||||
@staticmethod
|
||||
def from_db(code: PrismaOAuthAuthorizationCode):
|
||||
return OAuthAuthorizationCodeInfo(
|
||||
id=code.id,
|
||||
code=code.code,
|
||||
created_at=code.createdAt,
|
||||
expires_at=code.expiresAt,
|
||||
application_id=code.applicationId,
|
||||
user_id=code.userId,
|
||||
scopes=[APIPermission(s) for s in code.scopes],
|
||||
redirect_uri=code.redirectUri,
|
||||
code_challenge=code.codeChallenge,
|
||||
code_challenge_method=code.codeChallengeMethod,
|
||||
used_at=code.usedAt,
|
||||
)
|
||||
|
||||
|
||||
class OAuthAccessTokenInfo(APIAuthorizationInfo):
|
||||
"""Access token information"""
|
||||
|
||||
id: str
|
||||
expires_at: datetime # type: ignore
|
||||
application_id: str
|
||||
|
||||
type: Literal["oauth"] = "oauth" # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def from_db(token: PrismaOAuthAccessToken):
|
||||
return OAuthAccessTokenInfo(
|
||||
id=token.id,
|
||||
user_id=token.userId,
|
||||
scopes=[APIPermission(s) for s in token.scopes],
|
||||
created_at=token.createdAt,
|
||||
expires_at=token.expiresAt,
|
||||
last_used_at=None,
|
||||
revoked_at=token.revokedAt,
|
||||
application_id=token.applicationId,
|
||||
)
|
||||
|
||||
|
||||
class OAuthAccessToken(OAuthAccessTokenInfo):
|
||||
"""Access token with plaintext token included (sensitive)"""
|
||||
|
||||
token: SecretStr = Field(description="Plaintext token (sensitive)")
|
||||
|
||||
@staticmethod
|
||||
def from_db(token: PrismaOAuthAccessToken, plaintext_token: str): # type: ignore
|
||||
return OAuthAccessToken(
|
||||
**OAuthAccessTokenInfo.from_db(token).model_dump(),
|
||||
token=SecretStr(plaintext_token),
|
||||
)
|
||||
|
||||
|
||||
class OAuthRefreshTokenInfo(BaseModel):
|
||||
"""Refresh token information"""
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
scopes: list[APIPermission]
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
application_id: str
|
||||
revoked_at: Optional[datetime] = None
|
||||
|
||||
@property
|
||||
def is_revoked(self) -> bool:
|
||||
return self.revoked_at is not None
|
||||
|
||||
@staticmethod
|
||||
def from_db(token: PrismaOAuthRefreshToken):
|
||||
return OAuthRefreshTokenInfo(
|
||||
id=token.id,
|
||||
user_id=token.userId,
|
||||
scopes=[APIPermission(s) for s in token.scopes],
|
||||
created_at=token.createdAt,
|
||||
expires_at=token.expiresAt,
|
||||
application_id=token.applicationId,
|
||||
revoked_at=token.revokedAt,
|
||||
)
|
||||
|
||||
|
||||
class OAuthRefreshToken(OAuthRefreshTokenInfo):
|
||||
"""Refresh token with plaintext token included (sensitive)"""
|
||||
|
||||
token: SecretStr = Field(description="Plaintext token (sensitive)")
|
||||
|
||||
@staticmethod
|
||||
def from_db(token: PrismaOAuthRefreshToken, plaintext_token: str): # type: ignore
|
||||
return OAuthRefreshToken(
|
||||
**OAuthRefreshTokenInfo.from_db(token).model_dump(),
|
||||
token=SecretStr(plaintext_token),
|
||||
)
|
||||
|
||||
|
||||
class TokenIntrospectionResult(BaseModel):
|
||||
"""Result of token introspection (RFC 7662)"""
|
||||
|
||||
active: bool
|
||||
scopes: Optional[list[str]] = None
|
||||
client_id: Optional[str] = None
|
||||
user_id: Optional[str] = None
|
||||
exp: Optional[int] = None # Unix timestamp
|
||||
token_type: Optional[Literal["access_token", "refresh_token"]] = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Application Management
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def get_oauth_application(client_id: str) -> Optional[OAuthApplicationInfo]:
|
||||
"""Get OAuth application by client ID (without secret)"""
|
||||
app = await PrismaOAuthApplication.prisma().find_unique(
|
||||
where={"clientId": client_id}
|
||||
)
|
||||
if not app:
|
||||
return None
|
||||
return OAuthApplicationInfo.from_db(app)
|
||||
|
||||
|
||||
async def get_oauth_application_with_secret(
|
||||
client_id: str,
|
||||
) -> Optional[OAuthApplicationInfoWithSecret]:
|
||||
"""Get OAuth application by client ID (with secret hash for validation)"""
|
||||
app = await PrismaOAuthApplication.prisma().find_unique(
|
||||
where={"clientId": client_id}
|
||||
)
|
||||
if not app:
|
||||
return None
|
||||
return OAuthApplicationInfoWithSecret.from_db(app)
|
||||
|
||||
|
||||
async def validate_client_credentials(
|
||||
client_id: str, client_secret: str
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Validate client credentials and return application info.
|
||||
|
||||
Raises:
|
||||
InvalidClientError: If client_id or client_secret is invalid, or app is inactive
|
||||
"""
|
||||
app = await get_oauth_application_with_secret(client_id)
|
||||
if not app:
|
||||
raise InvalidClientError("Invalid client_id")
|
||||
|
||||
if not app.is_active:
|
||||
raise InvalidClientError("Application is not active")
|
||||
|
||||
# Verify client secret
|
||||
if not app.verify_secret(client_secret):
|
||||
raise InvalidClientError("Invalid client_secret")
|
||||
|
||||
# Return without secret hash
|
||||
return OAuthApplicationInfo(**app.model_dump(exclude={"client_secret_hash"}))
|
||||
|
||||
|
||||
def validate_redirect_uri(app: OAuthApplicationInfo, redirect_uri: str) -> bool:
|
||||
"""Validate that redirect URI is registered for the application"""
|
||||
return redirect_uri in app.redirect_uris
|
||||
|
||||
|
||||
def validate_scopes(
|
||||
app: OAuthApplicationInfo, requested_scopes: list[APIPermission]
|
||||
) -> bool:
|
||||
"""Validate that all requested scopes are allowed for the application"""
|
||||
return all(scope in app.scopes for scope in requested_scopes)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authorization Code Flow
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _generate_authorization_code() -> str:
|
||||
"""Generate a cryptographically secure authorization code"""
|
||||
# 32 bytes = 256 bits of entropy
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
async def create_authorization_code(
|
||||
application_id: str,
|
||||
user_id: str,
|
||||
scopes: list[APIPermission],
|
||||
redirect_uri: str,
|
||||
code_challenge: Optional[str] = None,
|
||||
code_challenge_method: Optional[Literal["S256", "plain"]] = None,
|
||||
) -> OAuthAuthorizationCodeInfo:
|
||||
"""
|
||||
Create a new authorization code.
|
||||
Expires in 10 minutes and can only be used once.
|
||||
"""
|
||||
code = _generate_authorization_code()
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + AUTHORIZATION_CODE_TTL
|
||||
|
||||
saved_code = await PrismaOAuthAuthorizationCode.prisma().create(
|
||||
data=cast(
|
||||
OAuthAuthorizationCodeCreateInput,
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"code": code,
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
"redirectUri": redirect_uri,
|
||||
"codeChallenge": code_challenge,
|
||||
"codeChallengeMethod": code_challenge_method,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return OAuthAuthorizationCodeInfo.from_db(saved_code)
|
||||
|
||||
|
||||
async def consume_authorization_code(
|
||||
code: str,
|
||||
application_id: str,
|
||||
redirect_uri: str,
|
||||
code_verifier: Optional[str] = None,
|
||||
) -> tuple[str, list[APIPermission]]:
|
||||
"""
|
||||
Consume an authorization code and return (user_id, scopes).
|
||||
|
||||
This marks the code as used and validates:
|
||||
- Code exists and matches application
|
||||
- Code is not expired
|
||||
- Code has not been used
|
||||
- Redirect URI matches
|
||||
- PKCE code verifier matches (if code challenge was provided)
|
||||
|
||||
Raises:
|
||||
InvalidGrantError: If code is invalid, expired, used, or PKCE fails
|
||||
"""
|
||||
auth_code = await PrismaOAuthAuthorizationCode.prisma().find_unique(
|
||||
where={"code": code}
|
||||
)
|
||||
|
||||
if not auth_code:
|
||||
raise InvalidGrantError("authorization code not found")
|
||||
|
||||
# Validate application
|
||||
if auth_code.applicationId != application_id:
|
||||
raise InvalidGrantError(
|
||||
"authorization code does not belong to this application"
|
||||
)
|
||||
|
||||
# Check if already used
|
||||
if auth_code.usedAt is not None:
|
||||
raise InvalidGrantError(
|
||||
f"authorization code already used at {auth_code.usedAt}"
|
||||
)
|
||||
|
||||
# Check expiration
|
||||
now = datetime.now(timezone.utc)
|
||||
if auth_code.expiresAt < now:
|
||||
raise InvalidGrantError("authorization code expired")
|
||||
|
||||
# Validate redirect URI
|
||||
if auth_code.redirectUri != redirect_uri:
|
||||
raise InvalidGrantError("redirect_uri mismatch")
|
||||
|
||||
# Validate PKCE if code challenge was provided
|
||||
if auth_code.codeChallenge:
|
||||
if not code_verifier:
|
||||
raise InvalidGrantError("code_verifier required but not provided")
|
||||
|
||||
if not _verify_pkce(
|
||||
code_verifier, auth_code.codeChallenge, auth_code.codeChallengeMethod
|
||||
):
|
||||
raise InvalidGrantError("PKCE verification failed")
|
||||
|
||||
# Mark code as used
|
||||
await PrismaOAuthAuthorizationCode.prisma().update(
|
||||
where={"code": code},
|
||||
data={"usedAt": now},
|
||||
)
|
||||
|
||||
return auth_code.userId, [APIPermission(s) for s in auth_code.scopes]
|
||||
|
||||
|
||||
def _verify_pkce(
|
||||
code_verifier: str, code_challenge: str, code_challenge_method: Optional[str]
|
||||
) -> bool:
|
||||
"""
|
||||
Verify PKCE code verifier against code challenge.
|
||||
|
||||
Supports:
|
||||
- S256: SHA256(code_verifier) == code_challenge
|
||||
- plain: code_verifier == code_challenge
|
||||
"""
|
||||
if code_challenge_method == "S256":
|
||||
# Hash the verifier with SHA256 and base64url encode
|
||||
hashed = hashlib.sha256(code_verifier.encode("ascii")).digest()
|
||||
computed_challenge = (
|
||||
secrets.token_urlsafe(len(hashed)).encode("ascii").decode("ascii")
|
||||
)
|
||||
# For proper base64url encoding
|
||||
import base64
|
||||
|
||||
computed_challenge = (
|
||||
base64.urlsafe_b64encode(hashed).decode("ascii").rstrip("=")
|
||||
)
|
||||
return secrets.compare_digest(computed_challenge, code_challenge)
|
||||
elif code_challenge_method == "plain" or code_challenge_method is None:
|
||||
# Plain comparison
|
||||
return secrets.compare_digest(code_verifier, code_challenge)
|
||||
else:
|
||||
logger.warning(f"Unsupported code challenge method: {code_challenge_method}")
|
||||
return False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Access Token Management
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def create_access_token(
|
||||
application_id: str, user_id: str, scopes: list[APIPermission]
|
||||
) -> OAuthAccessToken:
|
||||
"""
|
||||
Create a new access token.
|
||||
Returns OAuthAccessToken (with plaintext token).
|
||||
"""
|
||||
plaintext_token = ACCESS_TOKEN_PREFIX + _generate_token()
|
||||
token_hash = _hash_token(plaintext_token)
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + ACCESS_TOKEN_TTL
|
||||
|
||||
saved_token = await PrismaOAuthAccessToken.prisma().create(
|
||||
data=cast(
|
||||
OAuthAccessTokenCreateInput,
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": token_hash, # SHA256 hash for direct lookup
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return OAuthAccessToken.from_db(saved_token, plaintext_token=plaintext_token)
|
||||
|
||||
|
||||
async def validate_access_token(
|
||||
token: str,
|
||||
) -> tuple[OAuthAccessTokenInfo, OAuthApplicationInfo]:
|
||||
"""
|
||||
Validate an access token and return token info.
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If token is invalid, expired, or revoked
|
||||
InvalidClientError: If the client application is not marked as active
|
||||
"""
|
||||
token_hash = _hash_token(token)
|
||||
|
||||
# Direct lookup by hash
|
||||
access_token = await PrismaOAuthAccessToken.prisma().find_unique(
|
||||
where={"token": token_hash}, include={"Application": True}
|
||||
)
|
||||
|
||||
if not access_token:
|
||||
raise InvalidTokenError("access token not found")
|
||||
|
||||
if not access_token.Application: # should be impossible
|
||||
raise InvalidClientError("Client application not found")
|
||||
|
||||
if not access_token.Application.isActive:
|
||||
raise InvalidClientError("Client application is disabled")
|
||||
|
||||
if access_token.revokedAt is not None:
|
||||
raise InvalidTokenError("access token has been revoked")
|
||||
|
||||
# Check expiration
|
||||
now = datetime.now(timezone.utc)
|
||||
if access_token.expiresAt < now:
|
||||
raise InvalidTokenError("access token expired")
|
||||
|
||||
return (
|
||||
OAuthAccessTokenInfo.from_db(access_token),
|
||||
OAuthApplicationInfo.from_db(access_token.Application),
|
||||
)
|
||||
|
||||
|
||||
async def revoke_access_token(
|
||||
token: str, application_id: str
|
||||
) -> OAuthAccessTokenInfo | None:
|
||||
"""
|
||||
Revoke an access token.
|
||||
|
||||
Args:
|
||||
token: The plaintext access token to revoke
|
||||
application_id: The application ID making the revocation request.
|
||||
Only tokens belonging to this application will be revoked.
|
||||
|
||||
Returns:
|
||||
OAuthAccessTokenInfo if token was found and revoked, None otherwise.
|
||||
|
||||
Note:
|
||||
Always performs exactly 2 DB queries regardless of outcome to prevent
|
||||
timing side-channel attacks that could reveal token existence.
|
||||
"""
|
||||
try:
|
||||
token_hash = _hash_token(token)
|
||||
|
||||
# Use update_many to filter by both token and applicationId
|
||||
updated_count = await PrismaOAuthAccessToken.prisma().update_many(
|
||||
where={
|
||||
"token": token_hash,
|
||||
"applicationId": application_id,
|
||||
"revokedAt": None,
|
||||
},
|
||||
data={"revokedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
# Always perform second query to ensure constant time
|
||||
result = await PrismaOAuthAccessToken.prisma().find_unique(
|
||||
where={"token": token_hash}
|
||||
)
|
||||
|
||||
# Only return result if we actually revoked something
|
||||
if updated_count == 0:
|
||||
return None
|
||||
|
||||
return OAuthAccessTokenInfo.from_db(result) if result else None
|
||||
except Exception as e:
|
||||
logger.exception(f"Error revoking access token: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Refresh Token Management
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def create_refresh_token(
|
||||
application_id: str, user_id: str, scopes: list[APIPermission]
|
||||
) -> OAuthRefreshToken:
|
||||
"""
|
||||
Create a new refresh token.
|
||||
Returns OAuthRefreshToken (with plaintext token).
|
||||
"""
|
||||
plaintext_token = REFRESH_TOKEN_PREFIX + _generate_token()
|
||||
token_hash = _hash_token(plaintext_token)
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + REFRESH_TOKEN_TTL
|
||||
|
||||
saved_token = await PrismaOAuthRefreshToken.prisma().create(
|
||||
data=cast(
|
||||
OAuthRefreshTokenCreateInput,
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": token_hash, # SHA256 hash for direct lookup
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return OAuthRefreshToken.from_db(saved_token, plaintext_token=plaintext_token)
|
||||
|
||||
|
||||
async def refresh_tokens(
|
||||
refresh_token: str, application_id: str
|
||||
) -> tuple[OAuthAccessToken, OAuthRefreshToken]:
|
||||
"""
|
||||
Use a refresh token to create new access and refresh tokens.
|
||||
Returns (new_access_token, new_refresh_token) both with plaintext tokens included.
|
||||
|
||||
Raises:
|
||||
InvalidGrantError: If refresh token is invalid, expired, or revoked
|
||||
"""
|
||||
token_hash = _hash_token(refresh_token)
|
||||
|
||||
# Direct lookup by hash
|
||||
rt = await PrismaOAuthRefreshToken.prisma().find_unique(where={"token": token_hash})
|
||||
|
||||
if not rt:
|
||||
raise InvalidGrantError("refresh token not found")
|
||||
|
||||
# NOTE: no need to check Application.isActive, this is checked by the token endpoint
|
||||
|
||||
if rt.revokedAt is not None:
|
||||
raise InvalidGrantError("refresh token has been revoked")
|
||||
|
||||
# Validate application
|
||||
if rt.applicationId != application_id:
|
||||
raise InvalidGrantError("refresh token does not belong to this application")
|
||||
|
||||
# Check expiration
|
||||
now = datetime.now(timezone.utc)
|
||||
if rt.expiresAt < now:
|
||||
raise InvalidGrantError("refresh token expired")
|
||||
|
||||
# Revoke old refresh token
|
||||
await PrismaOAuthRefreshToken.prisma().update(
|
||||
where={"token": token_hash},
|
||||
data={"revokedAt": now},
|
||||
)
|
||||
|
||||
# Create new access and refresh tokens with same scopes
|
||||
scopes = [APIPermission(s) for s in rt.scopes]
|
||||
new_access_token = await create_access_token(
|
||||
rt.applicationId,
|
||||
rt.userId,
|
||||
scopes,
|
||||
)
|
||||
new_refresh_token = await create_refresh_token(
|
||||
rt.applicationId,
|
||||
rt.userId,
|
||||
scopes,
|
||||
)
|
||||
|
||||
return new_access_token, new_refresh_token
|
||||
|
||||
|
||||
async def revoke_refresh_token(
|
||||
token: str, application_id: str
|
||||
) -> OAuthRefreshTokenInfo | None:
|
||||
"""
|
||||
Revoke a refresh token.
|
||||
|
||||
Args:
|
||||
token: The plaintext refresh token to revoke
|
||||
application_id: The application ID making the revocation request.
|
||||
Only tokens belonging to this application will be revoked.
|
||||
|
||||
Returns:
|
||||
OAuthRefreshTokenInfo if token was found and revoked, None otherwise.
|
||||
|
||||
Note:
|
||||
Always performs exactly 2 DB queries regardless of outcome to prevent
|
||||
timing side-channel attacks that could reveal token existence.
|
||||
"""
|
||||
try:
|
||||
token_hash = _hash_token(token)
|
||||
|
||||
# Use update_many to filter by both token and applicationId
|
||||
updated_count = await PrismaOAuthRefreshToken.prisma().update_many(
|
||||
where={
|
||||
"token": token_hash,
|
||||
"applicationId": application_id,
|
||||
"revokedAt": None,
|
||||
},
|
||||
data={"revokedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
# Always perform second query to ensure constant time
|
||||
result = await PrismaOAuthRefreshToken.prisma().find_unique(
|
||||
where={"token": token_hash}
|
||||
)
|
||||
|
||||
# Only return result if we actually revoked something
|
||||
if updated_count == 0:
|
||||
return None
|
||||
|
||||
return OAuthRefreshTokenInfo.from_db(result) if result else None
|
||||
except Exception as e:
|
||||
logger.exception(f"Error revoking refresh token: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Introspection
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def introspect_token(
|
||||
token: str,
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = None,
|
||||
) -> TokenIntrospectionResult:
|
||||
"""
|
||||
Introspect a token and return its metadata (RFC 7662).
|
||||
|
||||
Returns TokenIntrospectionResult with active=True and metadata if valid,
|
||||
or active=False if the token is invalid/expired/revoked.
|
||||
"""
|
||||
# Try as access token first (or if hint says "access_token")
|
||||
if token_type_hint != "refresh_token":
|
||||
try:
|
||||
token_info, app = await validate_access_token(token)
|
||||
return TokenIntrospectionResult(
|
||||
active=True,
|
||||
scopes=list(s.value for s in token_info.scopes),
|
||||
client_id=app.client_id if app else None,
|
||||
user_id=token_info.user_id,
|
||||
exp=int(token_info.expires_at.timestamp()),
|
||||
token_type="access_token",
|
||||
)
|
||||
except InvalidTokenError:
|
||||
pass # Try as refresh token
|
||||
|
||||
# Try as refresh token
|
||||
token_hash = _hash_token(token)
|
||||
refresh_token = await PrismaOAuthRefreshToken.prisma().find_unique(
|
||||
where={"token": token_hash}
|
||||
)
|
||||
|
||||
if refresh_token and refresh_token.revokedAt is None:
|
||||
# Check if valid (not expired)
|
||||
now = datetime.now(timezone.utc)
|
||||
if refresh_token.expiresAt > now:
|
||||
app = await get_oauth_application_by_id(refresh_token.applicationId)
|
||||
return TokenIntrospectionResult(
|
||||
active=True,
|
||||
scopes=list(s for s in refresh_token.scopes),
|
||||
client_id=app.client_id if app else None,
|
||||
user_id=refresh_token.userId,
|
||||
exp=int(refresh_token.expiresAt.timestamp()),
|
||||
token_type="refresh_token",
|
||||
)
|
||||
|
||||
# Token not found or inactive
|
||||
return TokenIntrospectionResult(active=False)
|
||||
|
||||
|
||||
async def get_oauth_application_by_id(app_id: str) -> Optional[OAuthApplicationInfo]:
|
||||
"""Get OAuth application by ID"""
|
||||
app = await PrismaOAuthApplication.prisma().find_unique(where={"id": app_id})
|
||||
if not app:
|
||||
return None
|
||||
return OAuthApplicationInfo.from_db(app)
|
||||
|
||||
|
||||
async def list_user_oauth_applications(user_id: str) -> list[OAuthApplicationInfo]:
|
||||
"""Get all OAuth applications owned by a user"""
|
||||
apps = await PrismaOAuthApplication.prisma().find_many(
|
||||
where={"ownerId": user_id},
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
return [OAuthApplicationInfo.from_db(app) for app in apps]
|
||||
|
||||
|
||||
async def update_oauth_application(
|
||||
app_id: str,
|
||||
*,
|
||||
owner_id: str,
|
||||
is_active: Optional[bool] = None,
|
||||
logo_url: Optional[str] = None,
|
||||
) -> Optional[OAuthApplicationInfo]:
|
||||
"""
|
||||
Update OAuth application active status.
|
||||
Only the owner can update their app's status.
|
||||
|
||||
Returns the updated app info, or None if app not found or not owned by user.
|
||||
"""
|
||||
# First verify ownership
|
||||
app = await PrismaOAuthApplication.prisma().find_first(
|
||||
where={"id": app_id, "ownerId": owner_id}
|
||||
)
|
||||
if not app:
|
||||
return None
|
||||
|
||||
patch: OAuthApplicationUpdateInput = {}
|
||||
if is_active is not None:
|
||||
patch["isActive"] = is_active
|
||||
if logo_url:
|
||||
patch["logoUrl"] = logo_url
|
||||
if not patch:
|
||||
return OAuthApplicationInfo.from_db(app) # return unchanged
|
||||
|
||||
updated_app = await PrismaOAuthApplication.prisma().update(
|
||||
where={"id": app_id},
|
||||
data=patch,
|
||||
)
|
||||
return OAuthApplicationInfo.from_db(updated_app) if updated_app else None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Cleanup
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def cleanup_expired_oauth_tokens() -> dict[str, int]:
|
||||
"""
|
||||
Delete expired OAuth tokens from the database.
|
||||
|
||||
This removes:
|
||||
- Expired authorization codes (10 min TTL)
|
||||
- Expired access tokens (1 hour TTL)
|
||||
- Expired refresh tokens (30 day TTL)
|
||||
|
||||
Returns a dict with counts of deleted tokens by type.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Delete expired authorization codes
|
||||
codes_result = await PrismaOAuthAuthorizationCode.prisma().delete_many(
|
||||
where={"expiresAt": {"lt": now}}
|
||||
)
|
||||
|
||||
# Delete expired access tokens
|
||||
access_result = await PrismaOAuthAccessToken.prisma().delete_many(
|
||||
where={"expiresAt": {"lt": now}}
|
||||
)
|
||||
|
||||
# Delete expired refresh tokens
|
||||
refresh_result = await PrismaOAuthRefreshToken.prisma().delete_many(
|
||||
where={"expiresAt": {"lt": now}}
|
||||
)
|
||||
|
||||
deleted = {
|
||||
"authorization_codes": codes_result,
|
||||
"access_tokens": access_result,
|
||||
"refresh_tokens": refresh_result,
|
||||
}
|
||||
|
||||
total = sum(deleted.values())
|
||||
if total > 0:
|
||||
logger.info(f"Cleaned up {total} expired OAuth tokens: {deleted}")
|
||||
|
||||
return deleted
|
||||
@@ -5,14 +5,12 @@ This test was added to cover a previously untested code path that could lead to
|
||||
incorrect balance capping behavior.
|
||||
"""
|
||||
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserBalanceUpsertInput, UserCreateInput
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
@@ -23,14 +21,11 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for ceiling tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
},
|
||||
)
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -38,10 +33,7 @@ async def create_test_user(user_id: str) -> None:
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
),
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ without race conditions, deadlocks, or inconsistent state.
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import prisma.enums
|
||||
@@ -15,7 +14,6 @@ import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserBalanceUpsertInput, UserCreateInput
|
||||
|
||||
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
@@ -30,14 +28,11 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user with initial balance."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
},
|
||||
)
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -46,10 +41,7 @@ async def create_test_user(user_id: str) -> None:
|
||||
# Ensure UserBalance record exists
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
),
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
)
|
||||
|
||||
|
||||
@@ -350,13 +342,10 @@ async def test_integer_overflow_protection(server: SpinTestServer):
|
||||
# First, set balance near max
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": max_int - 100},
|
||||
"update": {"balance": max_int - 100},
|
||||
},
|
||||
),
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": max_int - 100},
|
||||
"update": {"balance": max_int - 100},
|
||||
},
|
||||
)
|
||||
|
||||
# Try to add more than possible - should clamp to POSTGRES_INT_MAX
|
||||
|
||||
@@ -5,12 +5,9 @@ These tests run actual database operations to ensure SQL queries work correctly,
|
||||
which would have caught the CreditTransactionType enum casting bug.
|
||||
"""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserCreateInput
|
||||
|
||||
from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
@@ -32,15 +29,12 @@ async def cleanup_test_user():
|
||||
# Create the user first
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"topUpConfig": SafeJson({}),
|
||||
"timezone": "UTC",
|
||||
},
|
||||
)
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"topUpConfig": SafeJson({}),
|
||||
"timezone": "UTC",
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# User might already exist, that's fine
|
||||
|
||||
@@ -6,19 +6,12 @@ are atomic and maintain data consistency.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import stripe
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
|
||||
from prisma.types import (
|
||||
CreditRefundRequestCreateInput,
|
||||
CreditTransactionCreateInput,
|
||||
UserBalanceCreateInput,
|
||||
UserCreateInput,
|
||||
)
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
@@ -42,41 +35,32 @@ async def setup_test_user_with_topup():
|
||||
|
||||
# Create user
|
||||
await User.prisma().create(
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": REFUND_TEST_USER_ID,
|
||||
"email": f"{REFUND_TEST_USER_ID}@example.com",
|
||||
"name": "Refund Test User",
|
||||
},
|
||||
)
|
||||
data={
|
||||
"id": REFUND_TEST_USER_ID,
|
||||
"email": f"{REFUND_TEST_USER_ID}@example.com",
|
||||
"name": "Refund Test User",
|
||||
}
|
||||
)
|
||||
|
||||
# Create user balance
|
||||
await UserBalance.prisma().create(
|
||||
data=cast(
|
||||
UserBalanceCreateInput,
|
||||
{
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"balance": 1000, # $10
|
||||
},
|
||||
)
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"balance": 1000, # $10
|
||||
}
|
||||
)
|
||||
|
||||
# Create a top-up transaction that can be refunded
|
||||
topup_tx = await CreditTransaction.prisma().create(
|
||||
data=cast(
|
||||
CreditTransactionCreateInput,
|
||||
{
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 1000,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"transactionKey": "pi_test_12345",
|
||||
"runningBalance": 1000,
|
||||
"isActive": True,
|
||||
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
|
||||
},
|
||||
)
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 1000,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"transactionKey": "pi_test_12345",
|
||||
"runningBalance": 1000,
|
||||
"isActive": True,
|
||||
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
|
||||
}
|
||||
)
|
||||
|
||||
return topup_tx
|
||||
@@ -109,15 +93,12 @@ async def test_deduct_credits_atomic(server: SpinTestServer):
|
||||
|
||||
# Create refund request record (simulating webhook flow)
|
||||
await CreditRefundRequest.prisma().create(
|
||||
data=cast(
|
||||
CreditRefundRequestCreateInput,
|
||||
{
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 500,
|
||||
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
|
||||
"reason": "Test refund",
|
||||
},
|
||||
)
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 500,
|
||||
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
|
||||
"reason": "Test refund",
|
||||
}
|
||||
)
|
||||
|
||||
# Call deduct_credits
|
||||
@@ -305,15 +286,12 @@ async def test_concurrent_refunds(server: SpinTestServer):
|
||||
refund_requests = []
|
||||
for i in range(5):
|
||||
req = await CreditRefundRequest.prisma().create(
|
||||
data=cast(
|
||||
CreditRefundRequestCreateInput,
|
||||
{
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 100, # $1 each
|
||||
"transactionKey": topup_tx.transactionKey,
|
||||
"reason": f"Test refund {i}",
|
||||
},
|
||||
)
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 100, # $1 each
|
||||
"transactionKey": topup_tx.transactionKey,
|
||||
"reason": f"Test refund {i}",
|
||||
}
|
||||
)
|
||||
refund_requests.append(req)
|
||||
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditTransaction, UserBalance
|
||||
from prisma.types import CreditTransactionCreateInput, UserBalanceUpsertInput
|
||||
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.data.block import get_block
|
||||
@@ -25,13 +23,10 @@ async def disable_test_user_transactions():
|
||||
old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": DEFAULT_USER_ID},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
|
||||
"update": {"balance": 0, "updatedAt": old_date},
|
||||
},
|
||||
),
|
||||
data={
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
|
||||
"update": {"balance": 0, "updatedAt": old_date},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -145,29 +140,23 @@ async def test_block_credit_reset(server: SpinTestServer):
|
||||
|
||||
# Manually create a transaction with month 1 timestamp to establish history
|
||||
await CreditTransaction.prisma().create(
|
||||
data=cast(
|
||||
CreditTransactionCreateInput,
|
||||
{
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": 100,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"runningBalance": 1100,
|
||||
"isActive": True,
|
||||
"createdAt": month1, # Set specific timestamp
|
||||
},
|
||||
)
|
||||
data={
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": 100,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"runningBalance": 1100,
|
||||
"isActive": True,
|
||||
"createdAt": month1, # Set specific timestamp
|
||||
}
|
||||
)
|
||||
|
||||
# Update user balance to match
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": DEFAULT_USER_ID},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
|
||||
"update": {"balance": 1100},
|
||||
},
|
||||
),
|
||||
data={
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
|
||||
"update": {"balance": 1100},
|
||||
},
|
||||
)
|
||||
|
||||
# Now test month 2 behavior
|
||||
@@ -186,17 +175,14 @@ async def test_block_credit_reset(server: SpinTestServer):
|
||||
|
||||
# Create a month 2 transaction to update the last transaction time
|
||||
await CreditTransaction.prisma().create(
|
||||
data=cast(
|
||||
CreditTransactionCreateInput,
|
||||
{
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": -700, # Spent 700 to get to 400
|
||||
"type": CreditTransactionType.USAGE,
|
||||
"runningBalance": 400,
|
||||
"isActive": True,
|
||||
"createdAt": month2,
|
||||
},
|
||||
)
|
||||
data={
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": -700, # Spent 700 to get to 400
|
||||
"type": CreditTransactionType.USAGE,
|
||||
"runningBalance": 400,
|
||||
"isActive": True,
|
||||
"createdAt": month2,
|
||||
}
|
||||
)
|
||||
|
||||
# Move to month 3
|
||||
|
||||
@@ -6,14 +6,12 @@ doesn't underflow below POSTGRES_INT_MIN, which could cause integer wraparound i
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserBalanceUpsertInput, UserCreateInput
|
||||
|
||||
from backend.data.credit import POSTGRES_INT_MIN, UserCredit
|
||||
from backend.util.test import SpinTestServer
|
||||
@@ -23,14 +21,11 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for underflow tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
},
|
||||
)
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -38,10 +33,7 @@ async def create_test_user(user_id: str) -> None:
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
),
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
)
|
||||
|
||||
|
||||
@@ -78,13 +70,10 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": initial_balance_target},
|
||||
"update": {"balance": initial_balance_target},
|
||||
},
|
||||
),
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance_target},
|
||||
"update": {"balance": initial_balance_target},
|
||||
},
|
||||
)
|
||||
|
||||
current_balance = await credit_system.get_credits(user_id)
|
||||
@@ -121,13 +110,10 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
|
||||
# Set balance to exactly POSTGRES_INT_MIN
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
|
||||
"update": {"balance": POSTGRES_INT_MIN},
|
||||
},
|
||||
),
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
|
||||
"update": {"balance": POSTGRES_INT_MIN},
|
||||
},
|
||||
)
|
||||
|
||||
edge_balance = await credit_system.get_credits(user_id)
|
||||
@@ -166,13 +152,10 @@ async def test_underflow_protection_large_refunds(server: SpinTestServer):
|
||||
test_balance = POSTGRES_INT_MIN + 1000
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": test_balance},
|
||||
"update": {"balance": test_balance},
|
||||
},
|
||||
),
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": test_balance},
|
||||
"update": {"balance": test_balance},
|
||||
},
|
||||
)
|
||||
|
||||
current_balance = await credit_system.get_credits(user_id)
|
||||
@@ -234,13 +217,10 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
|
||||
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
),
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
)
|
||||
|
||||
# Apply multiple refunds that would cumulatively underflow
|
||||
@@ -315,13 +295,10 @@ async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
|
||||
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data=cast(
|
||||
UserBalanceUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
),
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
)
|
||||
|
||||
async def large_refund(amount: int, label: str):
|
||||
|
||||
@@ -9,13 +9,11 @@ This test ensures that:
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserBalanceCreateInput, UserCreateInput
|
||||
|
||||
from backend.data.credit import UsageTransactionMetadata, UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
@@ -26,14 +24,11 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for migration tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
},
|
||||
)
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -126,9 +121,7 @@ async def test_detect_stale_user_balance_queries(server: SpinTestServer):
|
||||
try:
|
||||
# Create UserBalance with specific value
|
||||
await UserBalance.prisma().create(
|
||||
data=cast(
|
||||
UserBalanceCreateInput, {"userId": user_id, "balance": 5000}
|
||||
) # $50
|
||||
data={"userId": user_id, "balance": 5000} # $50
|
||||
)
|
||||
|
||||
# Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
|
||||
@@ -167,9 +160,7 @@ async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer
|
||||
|
||||
try:
|
||||
# Set initial balance in UserBalance
|
||||
await UserBalance.prisma().create(
|
||||
data=cast(UserBalanceCreateInput, {"userId": user_id, "balance": 1000})
|
||||
)
|
||||
await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})
|
||||
|
||||
# Run concurrent operations to ensure they all use UserBalance atomic operations
|
||||
async def concurrent_spend(amount: int, label: str):
|
||||
|
||||
@@ -28,7 +28,6 @@ from prisma.models import (
|
||||
AgentNodeExecutionKeyValueData,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionCreateInput,
|
||||
AgentGraphExecutionUpdateManyMutationInput,
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionCreateInput,
|
||||
@@ -36,6 +35,7 @@ from prisma.types import (
|
||||
AgentNodeExecutionKeyValueDataCreateInput,
|
||||
AgentNodeExecutionUpdateInput,
|
||||
AgentNodeExecutionWhereInput,
|
||||
AgentNodeExecutionWhereUniqueInput,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
|
||||
from pydantic.fields import Field
|
||||
@@ -709,40 +709,37 @@ async def create_graph_execution(
|
||||
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
|
||||
"""
|
||||
result = await AgentGraphExecution.prisma().create(
|
||||
data=cast(
|
||||
AgentGraphExecutionCreateInput,
|
||||
{
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": graph_version,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"inputs": SafeJson(inputs),
|
||||
"credentialInputs": (
|
||||
SafeJson(credential_inputs) if credential_inputs else Json({})
|
||||
),
|
||||
"nodesInputMasks": (
|
||||
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
|
||||
),
|
||||
"NodeExecutions": {
|
||||
"create": [
|
||||
AgentNodeExecutionCreateInput(
|
||||
agentNodeId=node_id,
|
||||
executionStatus=ExecutionStatus.QUEUED,
|
||||
queuedTime=datetime.now(tz=timezone.utc),
|
||||
Input={
|
||||
"create": [
|
||||
{"name": name, "data": SafeJson(data)}
|
||||
for name, data in node_input.items()
|
||||
]
|
||||
},
|
||||
)
|
||||
for node_id, node_input in starting_nodes_input
|
||||
]
|
||||
},
|
||||
"userId": user_id,
|
||||
"agentPresetId": preset_id,
|
||||
"parentGraphExecutionId": parent_graph_exec_id,
|
||||
data={
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": graph_version,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"inputs": SafeJson(inputs),
|
||||
"credentialInputs": (
|
||||
SafeJson(credential_inputs) if credential_inputs else Json({})
|
||||
),
|
||||
"nodesInputMasks": (
|
||||
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
|
||||
),
|
||||
"NodeExecutions": {
|
||||
"create": [
|
||||
AgentNodeExecutionCreateInput(
|
||||
agentNodeId=node_id,
|
||||
executionStatus=ExecutionStatus.QUEUED,
|
||||
queuedTime=datetime.now(tz=timezone.utc),
|
||||
Input={
|
||||
"create": [
|
||||
{"name": name, "data": SafeJson(data)}
|
||||
for name, data in node_input.items()
|
||||
]
|
||||
},
|
||||
)
|
||||
for node_id, node_input in starting_nodes_input
|
||||
]
|
||||
},
|
||||
),
|
||||
"userId": user_id,
|
||||
"agentPresetId": preset_id,
|
||||
"parentGraphExecutionId": parent_graph_exec_id,
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
)
|
||||
|
||||
@@ -834,13 +831,10 @@ async def upsert_execution_output(
|
||||
"""
|
||||
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
|
||||
"""
|
||||
data: AgentNodeExecutionInputOutputCreateInput = cast(
|
||||
AgentNodeExecutionInputOutputCreateInput,
|
||||
{
|
||||
"name": output_name,
|
||||
"referencedByOutputExecId": node_exec_id,
|
||||
},
|
||||
)
|
||||
data: AgentNodeExecutionInputOutputCreateInput = {
|
||||
"name": output_name,
|
||||
"referencedByOutputExecId": node_exec_id,
|
||||
}
|
||||
if output_data is not None:
|
||||
data["data"] = SafeJson(output_data)
|
||||
await AgentNodeExecutionInputOutput.prisma().create(data=data)
|
||||
@@ -980,30 +974,25 @@ async def update_node_execution_status(
|
||||
f"Invalid status transition: {status} has no valid source statuses"
|
||||
)
|
||||
|
||||
# First verify the current status allows this transition
|
||||
current_exec = await AgentNodeExecution.prisma().find_unique(
|
||||
where={"id": node_exec_id}, include=EXECUTION_RESULT_INCLUDE
|
||||
)
|
||||
|
||||
if not current_exec:
|
||||
raise ValueError(f"Execution {node_exec_id} not found.")
|
||||
|
||||
# Check if current status allows the requested transition
|
||||
if current_exec.executionStatus not in allowed_from:
|
||||
# Status transition not allowed, return current state without updating
|
||||
return NodeExecutionResult.from_db(current_exec)
|
||||
|
||||
# Status transition is valid, perform the update
|
||||
updated_exec = await AgentNodeExecution.prisma().update(
|
||||
where={"id": node_exec_id},
|
||||
if res := await AgentNodeExecution.prisma().update(
|
||||
where=cast(
|
||||
AgentNodeExecutionWhereUniqueInput,
|
||||
{
|
||||
"id": node_exec_id,
|
||||
"executionStatus": {"in": [s.value for s in allowed_from]},
|
||||
},
|
||||
),
|
||||
data=_get_update_status_data(status, execution_data, stats),
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
):
|
||||
return NodeExecutionResult.from_db(res)
|
||||
|
||||
if not updated_exec:
|
||||
raise ValueError(f"Failed to update execution {node_exec_id}.")
|
||||
if res := await AgentNodeExecution.prisma().find_unique(
|
||||
where={"id": node_exec_id}, include=EXECUTION_RESULT_INCLUDE
|
||||
):
|
||||
return NodeExecutionResult.from_db(res)
|
||||
|
||||
return NodeExecutionResult.from_db(updated_exec)
|
||||
raise ValueError(f"Execution {node_exec_id} not found.")
|
||||
|
||||
|
||||
def _get_update_status_data(
|
||||
|
||||
@@ -6,11 +6,11 @@ Handles all database operations for pending human reviews.
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, cast
|
||||
from typing import Optional
|
||||
|
||||
from prisma.enums import ReviewStatus
|
||||
from prisma.models import PendingHumanReview
|
||||
from prisma.types import PendingHumanReviewUpdateInput, PendingHumanReviewUpsertInput
|
||||
from prisma.types import PendingHumanReviewUpdateInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.server.v2.executions.review.model import (
|
||||
@@ -66,23 +66,20 @@ async def get_or_create_human_review(
|
||||
# Upsert - get existing or create new review
|
||||
review = await PendingHumanReview.prisma().upsert(
|
||||
where={"nodeExecId": node_exec_id},
|
||||
data=cast(
|
||||
PendingHumanReviewUpsertInput,
|
||||
{
|
||||
"create": {
|
||||
"userId": user_id,
|
||||
"nodeExecId": node_exec_id,
|
||||
"graphExecId": graph_exec_id,
|
||||
"graphId": graph_id,
|
||||
"graphVersion": graph_version,
|
||||
"payload": SafeJson(input_data),
|
||||
"instructions": message,
|
||||
"editable": editable,
|
||||
"status": ReviewStatus.WAITING,
|
||||
},
|
||||
"update": {}, # Do nothing on update - keep existing review as is
|
||||
data={
|
||||
"create": {
|
||||
"userId": user_id,
|
||||
"nodeExecId": node_exec_id,
|
||||
"graphExecId": graph_exec_id,
|
||||
"graphId": graph_id,
|
||||
"graphVersion": graph_version,
|
||||
"payload": SafeJson(input_data),
|
||||
"instructions": message,
|
||||
"editable": editable,
|
||||
"status": ReviewStatus.WAITING,
|
||||
},
|
||||
),
|
||||
"update": {}, # Do nothing on update - keep existing review as is
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Literal, Optional, cast
|
||||
from typing import Any, Literal, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import prisma
|
||||
import pydantic
|
||||
from prisma.enums import OnboardingStep
|
||||
from prisma.models import UserOnboarding
|
||||
from prisma.types import (
|
||||
UserOnboardingCreateInput,
|
||||
UserOnboardingUpdateInput,
|
||||
UserOnboardingUpsertInput,
|
||||
)
|
||||
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.credit import get_user_credit_model
|
||||
@@ -116,13 +112,10 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
||||
|
||||
return await UserOnboarding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data=cast(
|
||||
UserOnboardingUpsertInput,
|
||||
{
|
||||
"create": {"userId": user_id, **update},
|
||||
"update": update,
|
||||
},
|
||||
),
|
||||
data={
|
||||
"create": {"userId": user_id, **update},
|
||||
"update": update,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
|
||||
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
@@ -243,12 +242,6 @@ def cleanup_expired_files():
|
||||
run_async(cleanup_expired_files_async())
|
||||
|
||||
|
||||
def cleanup_oauth_tokens():
|
||||
"""Clean up expired OAuth tokens from the database."""
|
||||
# Wait for completion
|
||||
run_async(cleanup_expired_oauth_tokens())
|
||||
|
||||
|
||||
def execution_accuracy_alerts():
|
||||
"""Check execution accuracy and send alerts if drops are detected."""
|
||||
return report_execution_accuracy_alerts()
|
||||
@@ -453,17 +446,6 @@ class Scheduler(AppService):
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
# OAuth Token Cleanup - configurable interval
|
||||
self.scheduler.add_job(
|
||||
cleanup_oauth_tokens,
|
||||
id="cleanup_oauth_tokens",
|
||||
trigger="interval",
|
||||
replace_existing=True,
|
||||
seconds=config.oauth_token_cleanup_interval_hours
|
||||
* 3600, # Convert hours to seconds
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
# Execution Accuracy Monitoring - configurable interval
|
||||
self.scheduler.add_job(
|
||||
execution_accuracy_alerts,
|
||||
@@ -622,11 +604,6 @@ class Scheduler(AppService):
|
||||
"""Manually trigger cleanup of expired cloud storage files."""
|
||||
return cleanup_expired_files()
|
||||
|
||||
@expose
|
||||
def execute_cleanup_oauth_tokens(self):
|
||||
"""Manually trigger cleanup of expired OAuth tokens."""
|
||||
return cleanup_oauth_tokens()
|
||||
|
||||
@expose
|
||||
def execute_report_execution_accuracy_alerts(self):
|
||||
"""Manually trigger execution accuracy alert checking."""
|
||||
|
||||
@@ -49,10 +49,11 @@
|
||||
</p>
|
||||
<ol style="margin-bottom: 10px;">
|
||||
<li>
|
||||
Connect to the database using your preferred database client.
|
||||
Visit the Supabase Dashboard:
|
||||
https://supabase.com/dashboard/project/bgwpwdsxblryihinutbx/editor
|
||||
</li>
|
||||
<li>
|
||||
Navigate to the <strong>RefundRequest</strong> table in the <strong>platform</strong> schema.
|
||||
Navigate to the <strong>RefundRequest</strong> table.
|
||||
</li>
|
||||
<li>
|
||||
Filter the <code>transactionKey</code> column with the Transaction ID: <strong>{{ data.transaction_id }}</strong>.
|
||||
|
||||
@@ -6,7 +6,7 @@ Usage: from backend.sdk import *
|
||||
|
||||
This module provides:
|
||||
- All block base classes and types
|
||||
- All credential and authentication components
|
||||
- All credential and authentication components
|
||||
- All cost tracking components
|
||||
- All webhook components
|
||||
- All utility functions
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
Integration between SDK provider costs and the execution cost system.
|
||||
|
||||
This module provides the glue between provider-defined base costs and the
|
||||
This module provides the glue between provider-defined base costs and the
|
||||
BLOCK_COSTS configuration used by the execution system.
|
||||
"""
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
"""
|
||||
Authentication module for the AutoGPT Platform.
|
||||
|
||||
This module provides FastAPI-based authentication supporting:
|
||||
- Email/password authentication with bcrypt hashing
|
||||
- Google OAuth authentication
|
||||
- JWT token management (access + refresh tokens)
|
||||
"""
|
||||
|
||||
from .routes import router as auth_router
|
||||
from .service import AuthService
|
||||
|
||||
__all__ = ["auth_router", "AuthService"]
|
||||
@@ -1,170 +0,0 @@
|
||||
"""
|
||||
Direct email sending for authentication flows.
|
||||
|
||||
This module bypasses the notification queue system to ensure auth emails
|
||||
(password reset, email verification) are sent immediately in all environments.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pathlib
|
||||
from typing import Optional
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from postmarker.core import PostmarkClient
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
# Template directory
|
||||
TEMPLATE_DIR = pathlib.Path(__file__).parent / "templates"
|
||||
|
||||
|
||||
class AuthEmailSender:
|
||||
"""Handles direct email sending for authentication flows."""
|
||||
|
||||
def __init__(self):
|
||||
if settings.secrets.postmark_server_api_token:
|
||||
self.postmark = PostmarkClient(
|
||||
server_token=settings.secrets.postmark_server_api_token
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Postmark server API token not found, auth email sending disabled"
|
||||
)
|
||||
self.postmark = None
|
||||
|
||||
# Set up Jinja2 environment for templates
|
||||
self.jinja_env: Optional[Environment] = None
|
||||
if TEMPLATE_DIR.exists():
|
||||
self.jinja_env = Environment(
|
||||
loader=FileSystemLoader(str(TEMPLATE_DIR)),
|
||||
autoescape=True,
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Auth email templates directory not found: {TEMPLATE_DIR}")
|
||||
|
||||
def _get_frontend_url(self) -> str:
|
||||
"""Get the frontend base URL for email links."""
|
||||
return (
|
||||
settings.config.frontend_base_url
|
||||
or settings.config.platform_base_url
|
||||
or "http://localhost:3000"
|
||||
)
|
||||
|
||||
def _render_template(
|
||||
self, template_name: str, subject: str, **context
|
||||
) -> tuple[str, str]:
|
||||
"""Render an email template with the base template wrapper."""
|
||||
if not self.jinja_env:
|
||||
raise RuntimeError("Email templates not available")
|
||||
|
||||
# Render the content template
|
||||
content_template = self.jinja_env.get_template(template_name)
|
||||
content = content_template.render(**context)
|
||||
|
||||
# Render with base template
|
||||
base_template = self.jinja_env.get_template("base.html.jinja2")
|
||||
html_body = base_template.render(
|
||||
data={"title": subject, "message": content, "unsubscribe_link": None}
|
||||
)
|
||||
|
||||
return subject, html_body
|
||||
|
||||
def _send_email(self, to_email: str, subject: str, html_body: str) -> bool:
|
||||
"""Send an email directly via Postmark."""
|
||||
if not self.postmark:
|
||||
logger.warning(
|
||||
f"Postmark not configured. Would send email to {to_email}: {subject}"
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
self.postmark.emails.send( # type: ignore[attr-defined]
|
||||
From=settings.config.postmark_sender_email,
|
||||
To=to_email,
|
||||
Subject=subject,
|
||||
HtmlBody=html_body,
|
||||
)
|
||||
logger.info(f"Auth email sent to {to_email}: {subject}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send auth email to {to_email}: {e}")
|
||||
return False
|
||||
|
||||
def send_password_reset_email(
|
||||
self, to_email: str, reset_token: str, user_name: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send a password reset email.
|
||||
|
||||
Args:
|
||||
to_email: Recipient email address
|
||||
reset_token: The raw password reset token
|
||||
user_name: Optional user name for personalization
|
||||
|
||||
Returns:
|
||||
True if email was sent successfully, False otherwise
|
||||
"""
|
||||
try:
|
||||
frontend_url = self._get_frontend_url()
|
||||
reset_link = f"{frontend_url}/reset-password?token={reset_token}"
|
||||
|
||||
subject, html_body = self._render_template(
|
||||
"password_reset.html.jinja2",
|
||||
subject="Reset Your AutoGPT Password",
|
||||
reset_link=reset_link,
|
||||
user_name=user_name,
|
||||
frontend_url=frontend_url,
|
||||
)
|
||||
|
||||
return self._send_email(to_email, subject, html_body)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send password reset email to {to_email}: {e}")
|
||||
return False
|
||||
|
||||
def send_email_verification(
|
||||
self, to_email: str, verification_token: str, user_name: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send an email verification email.
|
||||
|
||||
Args:
|
||||
to_email: Recipient email address
|
||||
verification_token: The raw verification token
|
||||
user_name: Optional user name for personalization
|
||||
|
||||
Returns:
|
||||
True if email was sent successfully, False otherwise
|
||||
"""
|
||||
try:
|
||||
frontend_url = self._get_frontend_url()
|
||||
verification_link = (
|
||||
f"{frontend_url}/verify-email?token={verification_token}"
|
||||
)
|
||||
|
||||
subject, html_body = self._render_template(
|
||||
"email_verification.html.jinja2",
|
||||
subject="Verify Your AutoGPT Email",
|
||||
verification_link=verification_link,
|
||||
user_name=user_name,
|
||||
frontend_url=frontend_url,
|
||||
)
|
||||
|
||||
return self._send_email(to_email, subject, html_body)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send verification email to {to_email}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_auth_email_sender: Optional[AuthEmailSender] = None
|
||||
|
||||
|
||||
def get_auth_email_sender() -> AuthEmailSender:
|
||||
"""Get or create the auth email sender singleton."""
|
||||
global _auth_email_sender
|
||||
if _auth_email_sender is None:
|
||||
_auth_email_sender = AuthEmailSender()
|
||||
return _auth_email_sender
|
||||
@@ -1,505 +0,0 @@
|
||||
"""
|
||||
Authentication API routes.
|
||||
|
||||
Provides endpoints for:
|
||||
- User registration and login
|
||||
- Token refresh and logout
|
||||
- Password reset
|
||||
- Email verification
|
||||
- Google OAuth
|
||||
"""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .email import get_auth_email_sender
|
||||
from .service import AuthService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
# Singleton auth service instance
|
||||
_auth_service: Optional[AuthService] = None
|
||||
|
||||
# In-memory state storage for OAuth CSRF protection
|
||||
# Format: {state_token: {"created_at": timestamp, "redirect_uri": optional_uri}}
|
||||
# In production, use Redis for distributed state management
|
||||
_oauth_states: dict[str, dict] = {}
|
||||
_STATE_TTL_SECONDS = 600 # 10 minutes
|
||||
|
||||
|
||||
def _cleanup_expired_states() -> None:
|
||||
"""Remove expired OAuth states."""
|
||||
now = time.time()
|
||||
expired = [
|
||||
k
|
||||
for k, v in _oauth_states.items()
|
||||
if now - v["created_at"] > _STATE_TTL_SECONDS
|
||||
]
|
||||
for k in expired:
|
||||
del _oauth_states[k]
|
||||
|
||||
|
||||
def _generate_state() -> str:
|
||||
"""Generate a cryptographically secure state token."""
|
||||
_cleanup_expired_states()
|
||||
state = secrets.token_urlsafe(32)
|
||||
_oauth_states[state] = {"created_at": time.time()}
|
||||
return state
|
||||
|
||||
|
||||
def _validate_state(state: str) -> bool:
|
||||
"""Validate and consume a state token."""
|
||||
if state not in _oauth_states:
|
||||
return False
|
||||
state_data = _oauth_states.pop(state)
|
||||
if time.time() - state_data["created_at"] > _STATE_TTL_SECONDS:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_auth_service() -> AuthService:
|
||||
"""Get or create the auth service singleton."""
|
||||
global _auth_service
|
||||
if _auth_service is None:
|
||||
_auth_service = AuthService()
|
||||
return _auth_service
|
||||
|
||||
|
||||
# ============= Request/Response Models =============
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
"""Request model for user registration."""
|
||||
|
||||
email: EmailStr
|
||||
password: str = Field(..., min_length=8)
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
"""Request model for user login."""
|
||||
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""Response model for authentication tokens."""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
"""Request model for token refresh."""
|
||||
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
"""Request model for logout."""
|
||||
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class PasswordResetRequest(BaseModel):
|
||||
"""Request model for password reset request."""
|
||||
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class PasswordResetConfirm(BaseModel):
|
||||
"""Request model for password reset confirmation."""
|
||||
|
||||
token: str
|
||||
new_password: str = Field(..., min_length=8)
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
"""Generic message response."""
|
||||
|
||||
message: str
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
"""Response model for user info."""
|
||||
|
||||
id: str
|
||||
email: str
|
||||
name: Optional[str]
|
||||
email_verified: bool
|
||||
role: str
|
||||
|
||||
|
||||
# ============= Auth Endpoints =============
|
||||
|
||||
|
||||
@router.post("/register", response_model=TokenResponse)
|
||||
async def register(request: RegisterRequest, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
Register a new user with email and password.
|
||||
|
||||
Returns access and refresh tokens on successful registration.
|
||||
Sends a verification email in the background.
|
||||
"""
|
||||
auth_service = get_auth_service()
|
||||
|
||||
try:
|
||||
user = await auth_service.register_user(
|
||||
email=request.email,
|
||||
password=request.password,
|
||||
name=request.name,
|
||||
)
|
||||
|
||||
# Create verification token and send email in background
|
||||
# This is non-critical - don't fail registration if email fails
|
||||
try:
|
||||
verification_token = await auth_service.create_email_verification_token(
|
||||
user.id
|
||||
)
|
||||
email_sender = get_auth_email_sender()
|
||||
background_tasks.add_task(
|
||||
email_sender.send_email_verification,
|
||||
to_email=user.email,
|
||||
verification_token=verification_token,
|
||||
user_name=user.name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to queue verification email for {user.email}: {e}")
|
||||
|
||||
tokens = await auth_service.create_tokens(user)
|
||||
return TokenResponse(**tokens)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(request: LoginRequest):
|
||||
"""
|
||||
Login with email and password.
|
||||
|
||||
Returns access and refresh tokens on successful authentication.
|
||||
"""
|
||||
auth_service = get_auth_service()
|
||||
|
||||
user = await auth_service.authenticate_user(request.email, request.password)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Invalid email or password")
|
||||
|
||||
tokens = await auth_service.create_tokens(user)
|
||||
return TokenResponse(**tokens)
|
||||
|
||||
|
||||
@router.post("/logout", response_model=MessageResponse)
|
||||
async def logout(request: LogoutRequest):
|
||||
"""
|
||||
Logout by revoking the refresh token.
|
||||
|
||||
This invalidates the refresh token so it cannot be used to get new access tokens.
|
||||
"""
|
||||
auth_service = get_auth_service()
|
||||
|
||||
revoked = await auth_service.revoke_refresh_token(request.refresh_token)
|
||||
if not revoked:
|
||||
raise HTTPException(status_code=400, detail="Invalid refresh token")
|
||||
|
||||
return MessageResponse(message="Successfully logged out")
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh_tokens(request: RefreshRequest):
|
||||
"""
|
||||
Refresh access token using a refresh token.
|
||||
|
||||
The old refresh token is invalidated and a new one is returned (token rotation).
|
||||
"""
|
||||
auth_service = get_auth_service()
|
||||
|
||||
tokens = await auth_service.refresh_access_token(request.refresh_token)
|
||||
if not tokens:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired refresh token")
|
||||
|
||||
return TokenResponse(**tokens)
|
||||
|
||||
|
||||
@router.post("/password-reset/request", response_model=MessageResponse)
|
||||
async def request_password_reset(
|
||||
request: PasswordResetRequest, background_tasks: BackgroundTasks
|
||||
):
|
||||
"""
|
||||
Request a password reset email.
|
||||
|
||||
Always returns success to prevent email enumeration attacks.
|
||||
If the email exists, a password reset email will be sent.
|
||||
"""
|
||||
auth_service = get_auth_service()
|
||||
|
||||
user = await auth_service.get_user_by_email(request.email)
|
||||
if user:
|
||||
token = await auth_service.create_password_reset_token(user.id)
|
||||
email_sender = get_auth_email_sender()
|
||||
background_tasks.add_task(
|
||||
email_sender.send_password_reset_email,
|
||||
to_email=user.email,
|
||||
reset_token=token,
|
||||
user_name=user.name,
|
||||
)
|
||||
logger.info(f"Password reset email queued for user {user.id}")
|
||||
|
||||
# Always return success to prevent email enumeration
|
||||
return MessageResponse(
|
||||
message="If the email exists, a password reset link has been sent"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/password-reset/confirm", response_model=MessageResponse)
|
||||
async def confirm_password_reset(request: PasswordResetConfirm):
|
||||
"""
|
||||
Reset password using a password reset token.
|
||||
|
||||
All existing sessions (refresh tokens) will be invalidated.
|
||||
"""
|
||||
auth_service = get_auth_service()
|
||||
|
||||
success = await auth_service.reset_password(request.token, request.new_password)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired reset token")
|
||||
|
||||
return MessageResponse(message="Password has been reset successfully")
|
||||
|
||||
|
||||
# ============= Email Verification Endpoints =============
|
||||
|
||||
|
||||
class EmailVerificationRequest(BaseModel):
|
||||
"""Request model for email verification."""
|
||||
|
||||
token: str
|
||||
|
||||
|
||||
class ResendVerificationRequest(BaseModel):
|
||||
"""Request model for resending verification email."""
|
||||
|
||||
email: EmailStr
|
||||
|
||||
|
||||
@router.post("/email/verify", response_model=MessageResponse)
|
||||
async def verify_email(request: EmailVerificationRequest):
|
||||
"""
|
||||
Verify email address using a verification token.
|
||||
|
||||
Marks the user's email as verified if the token is valid.
|
||||
"""
|
||||
auth_service = get_auth_service()
|
||||
|
||||
success = await auth_service.verify_email_token(request.token)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid or expired verification token"
|
||||
)
|
||||
|
||||
return MessageResponse(message="Email verified successfully")
|
||||
|
||||
|
||||
@router.post("/email/resend-verification", response_model=MessageResponse)
|
||||
async def resend_verification_email(
|
||||
request: ResendVerificationRequest, background_tasks: BackgroundTasks
|
||||
):
|
||||
"""
|
||||
Resend email verification email.
|
||||
|
||||
Always returns success to prevent email enumeration attacks.
|
||||
If the email exists and is not verified, a new verification email will be sent.
|
||||
"""
|
||||
auth_service = get_auth_service()
|
||||
|
||||
user = await auth_service.get_user_by_email(request.email)
|
||||
if user and not user.emailVerified:
|
||||
token = await auth_service.create_email_verification_token(user.id)
|
||||
email_sender = get_auth_email_sender()
|
||||
background_tasks.add_task(
|
||||
email_sender.send_email_verification,
|
||||
to_email=user.email,
|
||||
verification_token=token,
|
||||
user_name=user.name,
|
||||
)
|
||||
logger.info(f"Verification email queued for user {user.id}")
|
||||
|
||||
# Always return success to prevent email enumeration
|
||||
return MessageResponse(
|
||||
message="If the email exists and is not verified, a verification link has been sent"
|
||||
)
|
||||
|
||||
|
||||
# ============= Google OAuth Endpoints =============
|
||||
|
||||
# Google userinfo endpoint for fetching user profile
|
||||
GOOGLE_USERINFO_ENDPOINT = "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||
|
||||
|
||||
class GoogleLoginResponse(BaseModel):
|
||||
"""Response model for Google OAuth login initiation."""
|
||||
|
||||
url: str
|
||||
|
||||
|
||||
def _get_google_oauth_handler():
|
||||
"""Get a configured GoogleOAuthHandler instance."""
|
||||
# Lazy import to avoid circular imports
|
||||
from backend.integrations.oauth.google import GoogleOAuthHandler
|
||||
|
||||
settings = Settings()
|
||||
|
||||
client_id = settings.secrets.google_client_id
|
||||
client_secret = settings.secrets.google_client_secret
|
||||
|
||||
if not client_id or not client_secret:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Google OAuth is not configured. Set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET.",
|
||||
)
|
||||
|
||||
# Construct the redirect URI - this should point to the frontend's callback
|
||||
# which will then call our /auth/google/callback endpoint
|
||||
frontend_base_url = settings.config.frontend_base_url or "http://localhost:3000"
|
||||
redirect_uri = f"{frontend_base_url}/auth/callback"
|
||||
|
||||
return GoogleOAuthHandler(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/google/login", response_model=GoogleLoginResponse)
|
||||
async def google_login(request: Request):
|
||||
"""
|
||||
Initiate Google OAuth flow.
|
||||
|
||||
Returns the Google OAuth authorization URL to redirect the user to.
|
||||
"""
|
||||
try:
|
||||
handler = _get_google_oauth_handler()
|
||||
state = _generate_state()
|
||||
|
||||
# Get the authorization URL with default scopes (email, profile, openid)
|
||||
auth_url = handler.get_login_url(
|
||||
scopes=[], # Will use DEFAULT_SCOPES from handler
|
||||
state=state,
|
||||
code_challenge=None, # Not using PKCE for server-side flow
|
||||
)
|
||||
|
||||
logger.info(f"Generated Google OAuth URL for state: {state[:8]}...")
|
||||
return GoogleLoginResponse(url=auth_url)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Google OAuth: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to initiate Google OAuth")
|
||||
|
||||
|
||||
@router.get("/google/callback", response_model=TokenResponse)
|
||||
async def google_callback(request: Request, code: str, state: Optional[str] = None):
|
||||
"""
|
||||
Handle Google OAuth callback.
|
||||
|
||||
Exchanges the authorization code for user info and creates/updates the user.
|
||||
Returns access and refresh tokens.
|
||||
"""
|
||||
# Validate state to prevent CSRF attacks
|
||||
if not state or not _validate_state(state):
|
||||
logger.warning(
|
||||
f"Invalid or missing OAuth state: {state[:8] if state else 'None'}..."
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired OAuth state")
|
||||
|
||||
try:
|
||||
handler = _get_google_oauth_handler()
|
||||
|
||||
# Exchange the authorization code for Google credentials
|
||||
logger.info("Exchanging authorization code for tokens...")
|
||||
google_creds = await handler.exchange_code_for_tokens(
|
||||
code=code,
|
||||
scopes=[], # Will use the scopes from the initial request
|
||||
code_verifier=None,
|
||||
)
|
||||
|
||||
# The handler returns OAuth2Credentials with email in username field
|
||||
email = google_creds.username
|
||||
if not email:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Failed to retrieve email from Google"
|
||||
)
|
||||
|
||||
# Fetch full user info to get Google user ID and name
|
||||
# Lazy import to avoid circular imports
|
||||
from google.auth.transport.requests import AuthorizedSession
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
# We need to create Google Credentials object to use with AuthorizedSession
|
||||
creds = Credentials(
|
||||
token=google_creds.access_token.get_secret_value(),
|
||||
refresh_token=(
|
||||
google_creds.refresh_token.get_secret_value()
|
||||
if google_creds.refresh_token
|
||||
else None
|
||||
),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=handler.client_id,
|
||||
client_secret=handler.client_secret,
|
||||
)
|
||||
|
||||
session = AuthorizedSession(creds)
|
||||
userinfo_response = session.get(GOOGLE_USERINFO_ENDPOINT)
|
||||
|
||||
if not userinfo_response.ok:
|
||||
logger.error(
|
||||
f"Failed to fetch Google userinfo: {userinfo_response.status_code}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Failed to fetch user info from Google"
|
||||
)
|
||||
|
||||
userinfo = userinfo_response.json()
|
||||
google_id = userinfo.get("id")
|
||||
name = userinfo.get("name")
|
||||
email_verified = userinfo.get("verified_email", False)
|
||||
|
||||
if not google_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Failed to retrieve Google user ID"
|
||||
)
|
||||
|
||||
logger.info(f"Google OAuth successful for user: {email}")
|
||||
|
||||
# Create or update the user in our database
|
||||
auth_service = get_auth_service()
|
||||
user = await auth_service.create_or_update_google_user(
|
||||
google_id=google_id,
|
||||
email=email,
|
||||
name=name,
|
||||
email_verified=email_verified,
|
||||
)
|
||||
|
||||
# Generate our JWT tokens
|
||||
tokens = await auth_service.create_tokens(user)
|
||||
|
||||
return TokenResponse(**tokens)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Google OAuth callback failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to complete Google OAuth")
|
||||
@@ -1,499 +0,0 @@
|
||||
"""
|
||||
Core authentication service for password verification and token management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, cast
|
||||
|
||||
import bcrypt
|
||||
from autogpt_libs.auth.config import get_settings
|
||||
from autogpt_libs.auth.jwt_utils import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
hash_token,
|
||||
)
|
||||
from prisma.models import User as PrismaUser
|
||||
from prisma.types import (
|
||||
EmailVerificationTokenCreateInput,
|
||||
PasswordResetTokenCreateInput,
|
||||
ProfileCreateInput,
|
||||
RefreshTokenCreateInput,
|
||||
UserCreateInput,
|
||||
)
|
||||
|
||||
from backend.data.db import prisma
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Handles authentication operations including password verification and token management."""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
|
||||
def hash_password(self, password: str) -> str:
|
||||
"""Hash a password using bcrypt."""
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
def verify_password(self, password: str, hashed: str) -> bool:
|
||||
"""Verify a password against a bcrypt hash."""
|
||||
try:
|
||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||
except Exception as e:
|
||||
logger.warning(f"Password verification failed: {e}")
|
||||
return False
|
||||
|
||||
async def register_user(
|
||||
self,
|
||||
email: str,
|
||||
password: str,
|
||||
name: Optional[str] = None,
|
||||
) -> PrismaUser:
|
||||
"""
|
||||
Register a new user with email and password.
|
||||
|
||||
Creates both a User record and a Profile record.
|
||||
|
||||
:param email: User's email address
|
||||
:param password: User's password (will be hashed)
|
||||
:param name: Optional display name
|
||||
:return: Created user record
|
||||
:raises ValueError: If email is already registered
|
||||
"""
|
||||
# Check if user already exists
|
||||
existing = await prisma.user.find_unique(where={"email": email})
|
||||
if existing:
|
||||
raise ValueError("Email already registered")
|
||||
|
||||
password_hash = self.hash_password(password)
|
||||
|
||||
# Generate a unique username from email
|
||||
base_username = email.split("@")[0].lower()
|
||||
# Remove any characters that aren't alphanumeric or underscore
|
||||
base_username = re.sub(r"[^a-z0-9_]", "", base_username)
|
||||
if not base_username:
|
||||
base_username = "user"
|
||||
|
||||
# Check if username is unique, if not add a number suffix
|
||||
username = base_username
|
||||
counter = 1
|
||||
while await prisma.profile.find_unique(where={"username": username}):
|
||||
username = f"{base_username}{counter}"
|
||||
counter += 1
|
||||
|
||||
user = await prisma.user.create(
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"email": email,
|
||||
"passwordHash": password_hash,
|
||||
"name": name,
|
||||
"emailVerified": False,
|
||||
"role": "authenticated",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Create profile for the user
|
||||
display_name = name or base_username
|
||||
await prisma.profile.create(
|
||||
data=cast(
|
||||
ProfileCreateInput,
|
||||
{
|
||||
"userId": user.id,
|
||||
"name": display_name,
|
||||
"username": username,
|
||||
"description": "",
|
||||
"links": [],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"Registered new user: {user.id} with profile username: {username}")
|
||||
return user
|
||||
|
||||
async def authenticate_user(
|
||||
self, email: str, password: str
|
||||
) -> Optional[PrismaUser]:
|
||||
"""
|
||||
Authenticate a user with email and password.
|
||||
|
||||
:param email: User's email address
|
||||
:param password: User's password
|
||||
:return: User record if authentication successful, None otherwise
|
||||
"""
|
||||
user = await prisma.user.find_unique(where={"email": email})
|
||||
|
||||
if not user:
|
||||
logger.debug(f"Authentication failed: user not found for email {email}")
|
||||
return None
|
||||
|
||||
if not user.passwordHash:
|
||||
logger.debug(
|
||||
f"Authentication failed: no password set for user {user.id} "
|
||||
"(likely OAuth-only user)"
|
||||
)
|
||||
return None
|
||||
|
||||
if self.verify_password(password, user.passwordHash):
|
||||
logger.debug(f"Authentication successful for user {user.id}")
|
||||
return user
|
||||
|
||||
logger.debug(f"Authentication failed: invalid password for user {user.id}")
|
||||
return None
|
||||
|
||||
async def create_tokens(self, user: PrismaUser) -> dict:
|
||||
"""
|
||||
Create access and refresh tokens for a user.
|
||||
|
||||
:param user: The user to create tokens for
|
||||
:return: Dictionary with access_token, refresh_token, token_type, and expires_in
|
||||
"""
|
||||
# Create access token
|
||||
access_token = create_access_token(
|
||||
user_id=user.id,
|
||||
email=user.email,
|
||||
role=user.role or "authenticated",
|
||||
email_verified=user.emailVerified,
|
||||
)
|
||||
|
||||
# Create and store refresh token
|
||||
raw_refresh_token, hashed_refresh_token = create_refresh_token()
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=self.settings.REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
|
||||
await prisma.refreshtoken.create(
|
||||
data=cast(
|
||||
RefreshTokenCreateInput,
|
||||
{
|
||||
"token": hashed_refresh_token,
|
||||
"userId": user.id,
|
||||
"expiresAt": expires_at,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"Created tokens for user {user.id}")
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": raw_refresh_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": self.settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
}
|
||||
|
||||
async def refresh_access_token(self, refresh_token: str) -> Optional[dict]:
|
||||
"""
|
||||
Refresh an access token using a refresh token.
|
||||
|
||||
Implements token rotation: the old refresh token is revoked and a new one is issued.
|
||||
|
||||
:param refresh_token: The refresh token
|
||||
:return: New tokens if successful, None if refresh token is invalid/expired
|
||||
"""
|
||||
hashed_token = hash_token(refresh_token)
|
||||
|
||||
# Find the refresh token
|
||||
stored_token = await prisma.refreshtoken.find_first(
|
||||
where={
|
||||
"token": hashed_token,
|
||||
"revokedAt": None,
|
||||
"expiresAt": {"gt": datetime.now(timezone.utc)},
|
||||
},
|
||||
include={"User": True},
|
||||
)
|
||||
|
||||
if not stored_token or not stored_token.User:
|
||||
logger.debug("Refresh token not found or expired")
|
||||
return None
|
||||
|
||||
# Revoke the old token (token rotation)
|
||||
await prisma.refreshtoken.update(
|
||||
where={"id": stored_token.id},
|
||||
data={"revokedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
logger.debug(f"Refreshed tokens for user {stored_token.User.id}")
|
||||
|
||||
# Create new tokens
|
||||
return await self.create_tokens(stored_token.User)
|
||||
|
||||
async def revoke_refresh_token(self, refresh_token: str) -> bool:
|
||||
"""
|
||||
Revoke a refresh token (logout).
|
||||
|
||||
:param refresh_token: The refresh token to revoke
|
||||
:return: True if token was found and revoked, False otherwise
|
||||
"""
|
||||
hashed_token = hash_token(refresh_token)
|
||||
|
||||
result = await prisma.refreshtoken.update_many(
|
||||
where={"token": hashed_token, "revokedAt": None},
|
||||
data={"revokedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
if result > 0:
|
||||
logger.debug("Refresh token revoked")
|
||||
return True
|
||||
|
||||
logger.debug("Refresh token not found or already revoked")
|
||||
return False
|
||||
|
||||
async def revoke_all_user_tokens(self, user_id: str) -> int:
|
||||
"""
|
||||
Revoke all refresh tokens for a user (logout from all devices).
|
||||
|
||||
:param user_id: The user's ID
|
||||
:return: Number of tokens revoked
|
||||
"""
|
||||
result = await prisma.refreshtoken.update_many(
|
||||
where={"userId": user_id, "revokedAt": None},
|
||||
data={"revokedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
logger.debug(f"Revoked {result} tokens for user {user_id}")
|
||||
return result
|
||||
|
||||
async def get_user_by_google_id(self, google_id: str) -> Optional[PrismaUser]:
|
||||
"""Get a user by their Google OAuth ID."""
|
||||
return await prisma.user.find_unique(where={"googleId": google_id})
|
||||
|
||||
async def get_user_by_email(self, email: str) -> Optional[PrismaUser]:
|
||||
"""Get a user by their email address."""
|
||||
return await prisma.user.find_unique(where={"email": email})
|
||||
|
||||
async def create_or_update_google_user(
|
||||
self,
|
||||
google_id: str,
|
||||
email: str,
|
||||
name: Optional[str] = None,
|
||||
email_verified: bool = False,
|
||||
) -> PrismaUser:
|
||||
"""
|
||||
Create or update a user from Google OAuth.
|
||||
|
||||
If a user with the Google ID exists, return them.
|
||||
If a user with the email exists but no Google ID, link the account.
|
||||
Otherwise, create a new user.
|
||||
|
||||
:param google_id: Google's unique user ID
|
||||
:param email: User's email from Google
|
||||
:param name: User's name from Google
|
||||
:param email_verified: Whether Google has verified the email
|
||||
:return: The user record
|
||||
"""
|
||||
# Check if user exists with this Google ID
|
||||
user = await self.get_user_by_google_id(google_id)
|
||||
if user:
|
||||
return user
|
||||
|
||||
# Check if user exists with this email
|
||||
user = await self.get_user_by_email(email)
|
||||
if user:
|
||||
# Link Google account to existing user
|
||||
updated_user = await prisma.user.update(
|
||||
where={"id": user.id},
|
||||
data={
|
||||
"googleId": google_id,
|
||||
"emailVerified": email_verified or user.emailVerified,
|
||||
},
|
||||
)
|
||||
if updated_user:
|
||||
logger.info(f"Linked Google account to existing user {updated_user.id}")
|
||||
return updated_user
|
||||
return user
|
||||
|
||||
# Create new user with profile
|
||||
# Generate a unique username from email
|
||||
base_username = email.split("@")[0].lower()
|
||||
base_username = re.sub(r"[^a-z0-9_]", "", base_username)
|
||||
if not base_username:
|
||||
base_username = "user"
|
||||
|
||||
username = base_username
|
||||
counter = 1
|
||||
while await prisma.profile.find_unique(where={"username": username}):
|
||||
username = f"{base_username}{counter}"
|
||||
counter += 1
|
||||
|
||||
user = await prisma.user.create(
|
||||
data=cast(
|
||||
UserCreateInput,
|
||||
{
|
||||
"email": email,
|
||||
"googleId": google_id,
|
||||
"name": name,
|
||||
"emailVerified": email_verified,
|
||||
"role": "authenticated",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Create profile for the user
|
||||
display_name = name or base_username
|
||||
await prisma.profile.create(
|
||||
data=cast(
|
||||
ProfileCreateInput,
|
||||
{
|
||||
"userId": user.id,
|
||||
"name": display_name,
|
||||
"username": username,
|
||||
"description": "",
|
||||
"links": [],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created new user from Google OAuth: {user.id} with profile: {username}"
|
||||
)
|
||||
return user
|
||||
|
||||
async def create_password_reset_token(self, user_id: str) -> str:
|
||||
"""
|
||||
Create a password reset token for a user.
|
||||
|
||||
:param user_id: The user's ID
|
||||
:return: The raw token to send to the user
|
||||
"""
|
||||
raw_token, hashed_token = create_refresh_token() # Reuse token generation
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
|
||||
await prisma.passwordresettoken.create(
|
||||
data=cast(
|
||||
PasswordResetTokenCreateInput,
|
||||
{
|
||||
"token": hashed_token,
|
||||
"userId": user_id,
|
||||
"expiresAt": expires_at,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return raw_token
|
||||
|
||||
async def create_email_verification_token(self, user_id: str) -> str:
|
||||
"""
|
||||
Create an email verification token for a user.
|
||||
|
||||
:param user_id: The user's ID
|
||||
:return: The raw token to send to the user
|
||||
"""
|
||||
raw_token, hashed_token = create_refresh_token() # Reuse token generation
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
|
||||
await prisma.emailverificationtoken.create(
|
||||
data=cast(
|
||||
EmailVerificationTokenCreateInput,
|
||||
{
|
||||
"token": hashed_token,
|
||||
"userId": user_id,
|
||||
"expiresAt": expires_at,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return raw_token
|
||||
|
||||
async def verify_email_token(self, token: str) -> bool:
|
||||
"""
|
||||
Verify an email verification token and mark the user's email as verified.
|
||||
|
||||
:param token: The raw token from the user
|
||||
:return: True if successful, False if token is invalid
|
||||
"""
|
||||
hashed_token = hash_token(token)
|
||||
|
||||
# Find and validate token
|
||||
stored_token = await prisma.emailverificationtoken.find_first(
|
||||
where={
|
||||
"token": hashed_token,
|
||||
"usedAt": None,
|
||||
"expiresAt": {"gt": datetime.now(timezone.utc)},
|
||||
}
|
||||
)
|
||||
|
||||
if not stored_token:
|
||||
return False
|
||||
|
||||
# Mark email as verified
|
||||
await prisma.user.update(
|
||||
where={"id": stored_token.userId},
|
||||
data={"emailVerified": True},
|
||||
)
|
||||
|
||||
# Mark token as used
|
||||
await prisma.emailverificationtoken.update(
|
||||
where={"id": stored_token.id},
|
||||
data={"usedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
logger.info(f"Email verified for user {stored_token.userId}")
|
||||
return True
|
||||
|
||||
async def verify_password_reset_token(self, token: str) -> Optional[str]:
|
||||
"""
|
||||
Verify a password reset token and return the user ID.
|
||||
|
||||
:param token: The raw token from the user
|
||||
:return: User ID if valid, None otherwise
|
||||
"""
|
||||
hashed_token = hash_token(token)
|
||||
|
||||
stored_token = await prisma.passwordresettoken.find_first(
|
||||
where={
|
||||
"token": hashed_token,
|
||||
"usedAt": None,
|
||||
"expiresAt": {"gt": datetime.now(timezone.utc)},
|
||||
}
|
||||
)
|
||||
|
||||
if not stored_token:
|
||||
return None
|
||||
|
||||
return stored_token.userId
|
||||
|
||||
async def reset_password(self, token: str, new_password: str) -> bool:
|
||||
"""
|
||||
Reset a user's password using a password reset token.
|
||||
|
||||
:param token: The password reset token
|
||||
:param new_password: The new password
|
||||
:return: True if successful, False if token is invalid
|
||||
"""
|
||||
hashed_token = hash_token(token)
|
||||
|
||||
# Find and validate token
|
||||
stored_token = await prisma.passwordresettoken.find_first(
|
||||
where={
|
||||
"token": hashed_token,
|
||||
"usedAt": None,
|
||||
"expiresAt": {"gt": datetime.now(timezone.utc)},
|
||||
}
|
||||
)
|
||||
|
||||
if not stored_token:
|
||||
return False
|
||||
|
||||
# Update password
|
||||
password_hash = self.hash_password(new_password)
|
||||
await prisma.user.update(
|
||||
where={"id": stored_token.userId},
|
||||
data={"passwordHash": password_hash},
|
||||
)
|
||||
|
||||
# Mark token as used
|
||||
await prisma.passwordresettoken.update(
|
||||
where={"id": stored_token.id},
|
||||
data={"usedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
# Revoke all refresh tokens for security
|
||||
await self.revoke_all_user_tokens(stored_token.userId)
|
||||
|
||||
logger.info(f"Password reset for user {stored_token.userId}")
|
||||
return True
|
||||
@@ -1,302 +0,0 @@
|
||||
{# Base Template for Auth Emails #}
|
||||
{# Template variables:
|
||||
data.message: the message to display in the email
|
||||
data.title: the title of the email
|
||||
data.unsubscribe_link: the link to unsubscribe from the email (optional for auth emails)
|
||||
#}
|
||||
<!doctype html>
|
||||
<html lang="ltr" xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office">
|
||||
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=yes">
|
||||
<meta name="format-detection" content="telephone=no, date=no, address=no, email=no, url=no">
|
||||
<meta name="x-apple-disable-message-reformatting">
|
||||
<!--[if !mso]>
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
||||
<![endif]-->
|
||||
<!--[if mso]>
|
||||
<style>
|
||||
* { font-family: sans-serif !important; }
|
||||
</style>
|
||||
<noscript>
|
||||
<xml>
|
||||
<o:OfficeDocumentSettings>
|
||||
<o:PixelsPerInch>96</o:PixelsPerInch>
|
||||
</o:OfficeDocumentSettings>
|
||||
</xml>
|
||||
</noscript>
|
||||
<![endif]-->
|
||||
<style type="text/css">
|
||||
/* RESET STYLES */
|
||||
html,
|
||||
body {
|
||||
margin: 0 !important;
|
||||
padding: 0 !important;
|
||||
width: 100% !important;
|
||||
height: 100% !important;
|
||||
}
|
||||
|
||||
body {
|
||||
-webkit-font-smoothing: antialiased;
|
||||
-moz-osx-font-smoothing: grayscale;
|
||||
text-rendering: optimizeLegibility;
|
||||
}
|
||||
|
||||
.document {
|
||||
margin: 0 !important;
|
||||
padding: 0 !important;
|
||||
width: 100% !important;
|
||||
}
|
||||
|
||||
img {
|
||||
border: 0;
|
||||
outline: none;
|
||||
text-decoration: none;
|
||||
-ms-interpolation-mode: bicubic;
|
||||
}
|
||||
|
||||
table {
|
||||
border-collapse: collapse;
|
||||
}
|
||||
|
||||
table,
|
||||
td {
|
||||
mso-table-lspace: 0pt;
|
||||
mso-table-rspace: 0pt;
|
||||
}
|
||||
|
||||
body,
|
||||
table,
|
||||
td,
|
||||
a {
|
||||
-webkit-text-size-adjust: 100%;
|
||||
-ms-text-size-adjust: 100%;
|
||||
}
|
||||
|
||||
h1,
|
||||
h2,
|
||||
h3,
|
||||
h4,
|
||||
h5,
|
||||
p {
|
||||
margin: 0;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
/* iOS BLUE LINKS */
|
||||
a[x-apple-data-detectors] {
|
||||
color: inherit !important;
|
||||
text-decoration: none !important;
|
||||
font-size: inherit !important;
|
||||
font-family: inherit !important;
|
||||
font-weight: inherit !important;
|
||||
line-height: inherit !important;
|
||||
}
|
||||
|
||||
/* ANDROID CENTER FIX */
|
||||
div[style*="margin: 16px 0;"] {
|
||||
margin: 0 !important;
|
||||
}
|
||||
|
||||
/* MEDIA QUERIES */
|
||||
@media all and (max-width:639px) {
|
||||
.wrapper {
|
||||
width: 100% !important;
|
||||
}
|
||||
|
||||
.container {
|
||||
width: 100% !important;
|
||||
min-width: 100% !important;
|
||||
padding: 0 !important;
|
||||
}
|
||||
|
||||
.row {
|
||||
padding-left: 20px !important;
|
||||
padding-right: 20px !important;
|
||||
}
|
||||
|
||||
.col-mobile {
|
||||
width: 20px !important;
|
||||
}
|
||||
|
||||
.col {
|
||||
display: block !important;
|
||||
width: 100% !important;
|
||||
}
|
||||
|
||||
.mobile-center {
|
||||
text-align: center !important;
|
||||
float: none !important;
|
||||
}
|
||||
|
||||
.mobile-mx-auto {
|
||||
margin: 0 auto !important;
|
||||
float: none !important;
|
||||
}
|
||||
|
||||
.mobile-left {
|
||||
text-align: center !important;
|
||||
float: left !important;
|
||||
}
|
||||
|
||||
.mobile-hide {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
.img {
|
||||
width: 100% !important;
|
||||
height: auto !important;
|
||||
}
|
||||
|
||||
.ml-btn {
|
||||
width: 100% !important;
|
||||
max-width: 100% !important;
|
||||
}
|
||||
|
||||
.ml-btn-container {
|
||||
width: 100% !important;
|
||||
max-width: 100% !important;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
<style type="text/css">
|
||||
@import url("https://assets.mlcdn.com/fonts-v2.css?version=1729862");
|
||||
</style>
|
||||
<style type="text/css">
|
||||
@media screen {
|
||||
body {
|
||||
font-family: 'Poppins', sans-serif;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
<title>{{data.title}}</title>
|
||||
</head>
|
||||
|
||||
<body style="margin: 0 !important; padding: 0 !important; background-color:#070629;">
|
||||
<div class="document" role="article" aria-roledescription="email" aria-label lang dir="ltr"
|
||||
style="background-color:#070629; line-height: 100%; font-size:medium; font-size:max(16px, 1rem);">
|
||||
<!-- Main Content -->
|
||||
<table width="100%" align="center" cellspacing="0" cellpadding="0" border="0">
|
||||
<tr>
|
||||
<td class="background" bgcolor="#070629" align="center" valign="top" style="padding: 0 8px;">
|
||||
<!-- Email Content -->
|
||||
<table class="container" align="center" width="640" cellpadding="0" cellspacing="0" border="0"
|
||||
style="max-width: 640px;">
|
||||
<tr>
|
||||
<td align="center">
|
||||
<!-- Logo Section -->
|
||||
<table class="container ml-4 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
|
||||
cellspacing="0" cellpadding="0" style="width: 640px; min-width: 640px;">
|
||||
<tr>
|
||||
<td class="ml-default-border container" height="40" style="line-height: 40px; min-width: 640px;">
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
|
||||
<tr>
|
||||
<td class="row" align="center" style="padding: 0 50px;">
|
||||
<img
|
||||
src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
|
||||
border="0" alt="" width="120" class="logo"
|
||||
style="max-width: 120px; display: inline-block;">
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- Main Content Section -->
|
||||
<table class="container ml-6 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
|
||||
cellspacing="0" cellpadding="0" style="color: #070629; width: 640px; min-width: 640px;">
|
||||
<tr>
|
||||
<td class="row" style="padding: 0 50px;">
|
||||
{{data.message|safe}}
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- Footer Section -->
|
||||
<table class="container ml-10 ml-default-border" width="640" bgcolor="#ffffff" align="center" border="0"
|
||||
cellspacing="0" cellpadding="0" style="width: 640px; min-width: 640px;">
|
||||
<tr>
|
||||
<td class="row" style="padding: 0 50px;">
|
||||
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
|
||||
<tr>
|
||||
<td height="20" style="line-height: 20px;"></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<!-- Footer Content -->
|
||||
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
|
||||
<tr>
|
||||
<td class="col" align="left" valign="middle" width="120">
|
||||
<img
|
||||
src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
|
||||
border="0" alt="" width="120" class="logo"
|
||||
style="max-width: 120px; display: inline-block;">
|
||||
</td>
|
||||
<td class="col" width="40" height="30" style="line-height: 30px;"></td>
|
||||
<td class="col mobile-left" align="right" valign="middle" width="250">
|
||||
<table role="presentation" cellpadding="0" cellspacing="0" border="0">
|
||||
<tr>
|
||||
<td align="center" valign="middle" width="18" style="padding: 0 5px 0 0;">
|
||||
<a href="https://x.com/auto_gpt" target="blank" style="text-decoration: none;">
|
||||
<img
|
||||
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/x.png"
|
||||
width="18" alt="x">
|
||||
</a>
|
||||
</td>
|
||||
<td align="center" valign="middle" width="18" style="padding: 0 5px;">
|
||||
<a href="https://discord.gg/autogpt" target="blank"
|
||||
style="text-decoration: none;">
|
||||
<img
|
||||
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/discord.png"
|
||||
width="18" alt="discord">
|
||||
</a>
|
||||
</td>
|
||||
<td align="center" valign="middle" width="18" style="padding: 0 0 0 5px;">
|
||||
<a href="https://agpt.co/" target="blank" style="text-decoration: none;">
|
||||
<img
|
||||
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/website.png"
|
||||
width="18" alt="website">
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td height="15" style="line-height: 15px;"></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center" style="text-align: left!important;">
|
||||
<p
|
||||
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 12px; line-height: 150%; display: inline-block; margin-bottom: 0;">
|
||||
This is an automated security email from AutoGPT. If you did not request this action, please ignore this email or contact support if you have concerns.
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td height="20" style="line-height: 20px;"></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
@@ -1,65 +0,0 @@
|
||||
{# Email Verification Template #}
|
||||
{# Variables:
|
||||
verification_link: URL for email verification
|
||||
user_name: Optional user name for personalization
|
||||
frontend_url: Base frontend URL
|
||||
#}
|
||||
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
|
||||
<tr>
|
||||
<td height="30" style="line-height: 30px;"></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">
|
||||
<h1 style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 28px; line-height: 125%; font-weight: bold; margin-bottom: 20px;">
|
||||
Verify Your Email
|
||||
</h1>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left">
|
||||
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
|
||||
{% if user_name %}Hi {{ user_name }},{% else %}Hi,{% endif %}
|
||||
</p>
|
||||
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
|
||||
Welcome to AutoGPT! Please verify your email address by clicking the button below:
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center" style="padding: 20px 0;">
|
||||
<table border="0" cellspacing="0" cellpadding="0">
|
||||
<tr>
|
||||
<td align="center" bgcolor="#4285F4" style="border-radius: 8px;">
|
||||
<a href="{{ verification_link }}" target="_blank"
|
||||
style="display: inline-block; padding: 16px 36px; font-family: 'Poppins', sans-serif; font-size: 16px; font-weight: 600; color: #ffffff; text-decoration: none; border-radius: 8px;">
|
||||
Verify Email
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left">
|
||||
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
|
||||
This link will expire in <strong>24 hours</strong>.
|
||||
</p>
|
||||
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
|
||||
If you didn't create an account with AutoGPT, you can safely ignore this email.
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left">
|
||||
<p style="font-family: 'Poppins', sans-serif; color: #888888; font-size: 14px; line-height: 165%; margin-bottom: 10px;">
|
||||
If the button doesn't work, copy and paste this link into your browser:
|
||||
</p>
|
||||
<p style="font-family: 'Poppins', sans-serif; color: #4285F4; font-size: 14px; line-height: 165%; word-break: break-all;">
|
||||
<a href="{{ verification_link }}" style="color: #4285F4; text-decoration: underline;">{{ verification_link }}</a>
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td height="30" style="line-height: 30px;"></td>
|
||||
</tr>
|
||||
</table>
|
||||
@@ -1,65 +0,0 @@
|
||||
{# Password Reset Email Template #}
|
||||
{# Variables:
|
||||
reset_link: URL for password reset
|
||||
user_name: Optional user name for personalization
|
||||
frontend_url: Base frontend URL
|
||||
#}
|
||||
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
|
||||
<tr>
|
||||
<td height="30" style="line-height: 30px;"></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">
|
||||
<h1 style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 28px; line-height: 125%; font-weight: bold; margin-bottom: 20px;">
|
||||
Reset Your Password
|
||||
</h1>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left">
|
||||
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
|
||||
{% if user_name %}Hi {{ user_name }},{% else %}Hi,{% endif %}
|
||||
</p>
|
||||
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
|
||||
We received a request to reset your password for your AutoGPT account. Click the button below to create a new password:
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center" style="padding: 20px 0;">
|
||||
<table border="0" cellspacing="0" cellpadding="0">
|
||||
<tr>
|
||||
<td align="center" bgcolor="#4285F4" style="border-radius: 8px;">
|
||||
<a href="{{ reset_link }}" target="_blank"
|
||||
style="display: inline-block; padding: 16px 36px; font-family: 'Poppins', sans-serif; font-size: 16px; font-weight: 600; color: #ffffff; text-decoration: none; border-radius: 8px;">
|
||||
Reset Password
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left">
|
||||
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
|
||||
This link will expire in <strong>1 hour</strong> for security reasons.
|
||||
</p>
|
||||
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
|
||||
If you didn't request a password reset, you can safely ignore this email. Your password will remain unchanged.
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left">
|
||||
<p style="font-family: 'Poppins', sans-serif; color: #888888; font-size: 14px; line-height: 165%; margin-bottom: 10px;">
|
||||
If the button doesn't work, copy and paste this link into your browser:
|
||||
</p>
|
||||
<p style="font-family: 'Poppins', sans-serif; color: #4285F4; font-size: 14px; line-height: 165%; word-break: break-all;">
|
||||
<a href="{{ reset_link }}" style="color: #4285F4; text-decoration: underline;">{{ reset_link }}</a>
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td height="30" style="line-height: 30px;"></td>
|
||||
</tr>
|
||||
</table>
|
||||
@@ -1,107 +1,36 @@
|
||||
from fastapi import HTTPException, Security, status
|
||||
from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi.security import APIKeyHeader
|
||||
from prisma.enums import APIKeyPermission
|
||||
|
||||
from backend.data.auth.api_key import APIKeyInfo, validate_api_key
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.auth.oauth import (
|
||||
InvalidClientError,
|
||||
InvalidTokenError,
|
||||
OAuthAccessTokenInfo,
|
||||
validate_access_token,
|
||||
)
|
||||
from backend.data.api_key import APIKeyInfo, has_permission, validate_api_key
|
||||
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
bearer_auth = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
|
||||
"""Middleware for API key authentication only"""
|
||||
"""Base middleware for API key authentication"""
|
||||
if api_key is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing API key"
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="Missing API key")
|
||||
|
||||
api_key_obj = await validate_api_key(api_key)
|
||||
|
||||
if not api_key_obj:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
return api_key_obj
|
||||
|
||||
|
||||
async def require_access_token(
|
||||
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
|
||||
) -> OAuthAccessTokenInfo:
|
||||
"""Middleware for OAuth access token authentication only"""
|
||||
if bearer is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing Authorization header",
|
||||
)
|
||||
|
||||
try:
|
||||
token_info, _ = await validate_access_token(bearer.credentials)
|
||||
except (InvalidClientError, InvalidTokenError) as e:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
|
||||
|
||||
return token_info
|
||||
|
||||
|
||||
async def require_auth(
|
||||
api_key: str | None = Security(api_key_header),
|
||||
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
|
||||
) -> APIAuthorizationInfo:
|
||||
"""
|
||||
Unified authentication middleware supporting both API keys and OAuth tokens.
|
||||
|
||||
Supports two authentication methods, which are checked in order:
|
||||
1. X-API-Key header (existing API key authentication)
|
||||
2. Authorization: Bearer <token> header (OAuth access token)
|
||||
|
||||
Returns:
|
||||
APIAuthorizationInfo: base class of both APIKeyInfo and OAuthAccessTokenInfo.
|
||||
"""
|
||||
# Try API key first
|
||||
if api_key is not None:
|
||||
api_key_info = await validate_api_key(api_key)
|
||||
if api_key_info:
|
||||
return api_key_info
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
|
||||
)
|
||||
|
||||
# Try OAuth bearer token
|
||||
if bearer is not None:
|
||||
try:
|
||||
token_info, _ = await validate_access_token(bearer.credentials)
|
||||
return token_info
|
||||
except (InvalidClientError, InvalidTokenError) as e:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
|
||||
|
||||
# No credentials provided
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing authentication. Provide API key or access token.",
|
||||
)
|
||||
|
||||
|
||||
def require_permission(permission: APIKeyPermission):
|
||||
"""
|
||||
Dependency function for checking specific permissions
|
||||
(works with API keys and OAuth tokens)
|
||||
"""
|
||||
"""Dependency function for checking specific permissions"""
|
||||
|
||||
async def check_permission(
|
||||
auth: APIAuthorizationInfo = Security(require_auth),
|
||||
) -> APIAuthorizationInfo:
|
||||
if permission not in auth.scopes:
|
||||
api_key: APIKeyInfo = Security(require_api_key),
|
||||
) -> APIKeyInfo:
|
||||
if not has_permission(api_key, permission):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Missing required permission: {permission.value}",
|
||||
status_code=403,
|
||||
detail=f"API key lacks the required permission '{permission}'",
|
||||
)
|
||||
return auth
|
||||
return api_key
|
||||
|
||||
return check_permission
|
||||
|
||||
@@ -16,7 +16,7 @@ from fastapi import APIRouter, Body, HTTPException, Path, Security, status
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
@@ -255,7 +255,7 @@ def _get_oauth_handler_for_external(
|
||||
|
||||
@integrations_router.get("/providers", response_model=list[ProviderInfo])
|
||||
async def list_providers(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
api_key: APIKeyInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[ProviderInfo]:
|
||||
@@ -319,7 +319,7 @@ async def list_providers(
|
||||
async def initiate_oauth(
|
||||
provider: Annotated[str, Path(title="The OAuth provider")],
|
||||
request: OAuthInitiateRequest,
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
api_key: APIKeyInfo = Security(
|
||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||
),
|
||||
) -> OAuthInitiateResponse:
|
||||
@@ -337,10 +337,7 @@ async def initiate_oauth(
|
||||
if not validate_callback_url(request.callback_url):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
f"Callback URL origin is not allowed. "
|
||||
f"Allowed origins: {settings.config.external_oauth_callback_origins}",
|
||||
),
|
||||
detail=f"Callback URL origin is not allowed. Allowed origins: {settings.config.external_oauth_callback_origins}",
|
||||
)
|
||||
|
||||
# Validate provider
|
||||
@@ -362,15 +359,13 @@ async def initiate_oauth(
|
||||
)
|
||||
|
||||
# Store state token with external flow metadata
|
||||
# Note: initiated_by_api_key_id is only available for API key auth, not OAuth
|
||||
api_key_id = getattr(auth, "id", None) if auth.type == "api_key" else None
|
||||
state_token, code_challenge = await creds_manager.store.store_state_token(
|
||||
user_id=auth.user_id,
|
||||
user_id=api_key.user_id,
|
||||
provider=provider if isinstance(provider_name, str) else provider_name.value,
|
||||
scopes=request.scopes,
|
||||
callback_url=request.callback_url,
|
||||
state_metadata=request.state_metadata,
|
||||
initiated_by_api_key_id=api_key_id,
|
||||
initiated_by_api_key_id=api_key.id,
|
||||
)
|
||||
|
||||
# Build login URL
|
||||
@@ -398,7 +393,7 @@ async def initiate_oauth(
|
||||
async def complete_oauth(
|
||||
provider: Annotated[str, Path(title="The OAuth provider")],
|
||||
request: OAuthCompleteRequest,
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
api_key: APIKeyInfo = Security(
|
||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||
),
|
||||
) -> OAuthCompleteResponse:
|
||||
@@ -411,7 +406,7 @@ async def complete_oauth(
|
||||
"""
|
||||
# Verify state token
|
||||
valid_state = await creds_manager.store.verify_state_token(
|
||||
auth.user_id, request.state_token, provider
|
||||
api_key.user_id, request.state_token, provider
|
||||
)
|
||||
|
||||
if not valid_state:
|
||||
@@ -458,7 +453,7 @@ async def complete_oauth(
|
||||
)
|
||||
|
||||
# Store credentials
|
||||
await creds_manager.create(auth.user_id, credentials)
|
||||
await creds_manager.create(api_key.user_id, credentials)
|
||||
|
||||
logger.info(f"Successfully completed external OAuth for provider {provider}")
|
||||
|
||||
@@ -475,7 +470,7 @@ async def complete_oauth(
|
||||
|
||||
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
|
||||
async def list_credentials(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
api_key: APIKeyInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
@@ -484,7 +479,7 @@ async def list_credentials(
|
||||
|
||||
Returns metadata about each credential without exposing sensitive tokens.
|
||||
"""
|
||||
credentials = await creds_manager.store.get_all_creds(auth.user_id)
|
||||
credentials = await creds_manager.store.get_all_creds(api_key.user_id)
|
||||
return [
|
||||
CredentialSummary(
|
||||
id=cred.id,
|
||||
@@ -504,7 +499,7 @@ async def list_credentials(
|
||||
)
|
||||
async def list_credentials_by_provider(
|
||||
provider: Annotated[str, Path(title="The provider to list credentials for")],
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
api_key: APIKeyInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
@@ -512,7 +507,7 @@ async def list_credentials_by_provider(
|
||||
List credentials for a specific provider.
|
||||
"""
|
||||
credentials = await creds_manager.store.get_creds_by_provider(
|
||||
auth.user_id, provider
|
||||
api_key.user_id, provider
|
||||
)
|
||||
return [
|
||||
CredentialSummary(
|
||||
@@ -541,7 +536,7 @@ async def create_credential(
|
||||
CreateUserPasswordCredentialRequest,
|
||||
CreateHostScopedCredentialRequest,
|
||||
] = Body(..., discriminator="type"),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
api_key: APIKeyInfo = Security(
|
||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||
),
|
||||
) -> CreateCredentialResponse:
|
||||
@@ -596,7 +591,7 @@ async def create_credential(
|
||||
|
||||
# Store credentials
|
||||
try:
|
||||
await creds_manager.create(auth.user_id, credentials)
|
||||
await creds_manager.create(api_key.user_id, credentials)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store credentials: {e}")
|
||||
raise HTTPException(
|
||||
@@ -628,7 +623,7 @@ class DeleteCredentialResponse(BaseModel):
|
||||
async def delete_credential(
|
||||
provider: Annotated[str, Path(title="The provider")],
|
||||
cred_id: Annotated[str, Path(title="The credential ID to delete")],
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
api_key: APIKeyInfo = Security(
|
||||
require_permission(APIKeyPermission.DELETE_INTEGRATIONS)
|
||||
),
|
||||
) -> DeleteCredentialResponse:
|
||||
@@ -639,7 +634,7 @@ async def delete_credential(
|
||||
use the main API's delete endpoint which handles webhook cleanup and
|
||||
token revocation.
|
||||
"""
|
||||
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
|
||||
creds = await creds_manager.store.get_creds_by_id(api_key.user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
@@ -650,6 +645,6 @@ async def delete_credential(
|
||||
detail="Credentials do not match the specified provider",
|
||||
)
|
||||
|
||||
await creds_manager.delete(auth.user_id, cred_id)
|
||||
await creds_manager.delete(api_key.user_id, cred_id)
|
||||
|
||||
return DeleteCredentialResponse(deleted=True, credentials_id=cred_id)
|
||||
|
||||
@@ -14,7 +14,7 @@ from fastapi import APIRouter, Security
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools import find_agent_tool, run_agent_tool
|
||||
@@ -24,9 +24,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
tools_router = APIRouter(prefix="/tools", tags=["tools"])
|
||||
|
||||
# Note: We use Security() as a function parameter dependency (auth: APIAuthorizationInfo = Security(...))
|
||||
# Note: We use Security() as a function parameter dependency (api_key: APIKeyInfo = Security(...))
|
||||
# rather than in the decorator's dependencies= list. This avoids duplicate permission checks
|
||||
# while still enforcing auth AND giving us access to auth for extracting user_id.
|
||||
# while still enforcing auth AND giving us access to the api_key for extracting user_id.
|
||||
|
||||
|
||||
# Request models
|
||||
@@ -80,9 +80,7 @@ def _create_ephemeral_session(user_id: str | None) -> ChatSession:
|
||||
)
|
||||
async def find_agent(
|
||||
request: FindAgentRequest,
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.USE_TOOLS)
|
||||
),
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Search for agents in the marketplace based on capabilities and user needs.
|
||||
@@ -93,9 +91,9 @@ async def find_agent(
|
||||
Returns:
|
||||
List of matching agents or no results response
|
||||
"""
|
||||
session = _create_ephemeral_session(auth.user_id)
|
||||
session = _create_ephemeral_session(api_key.user_id)
|
||||
result = await find_agent_tool._execute(
|
||||
user_id=auth.user_id,
|
||||
user_id=api_key.user_id,
|
||||
session=session,
|
||||
query=request.query,
|
||||
)
|
||||
@@ -107,9 +105,7 @@ async def find_agent(
|
||||
)
|
||||
async def run_agent(
|
||||
request: RunAgentRequest,
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.USE_TOOLS)
|
||||
),
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run or schedule an agent from the marketplace.
|
||||
@@ -133,9 +129,9 @@ async def run_agent(
|
||||
- execution_started: If agent was run or scheduled successfully
|
||||
- error: If something went wrong
|
||||
"""
|
||||
session = _create_ephemeral_session(auth.user_id)
|
||||
session = _create_ephemeral_session(api_key.user_id)
|
||||
result = await run_agent_tool._execute(
|
||||
user_id=auth.user_id,
|
||||
user_id=api_key.user_id,
|
||||
session=session,
|
||||
username_agent_slug=request.username_agent_slug,
|
||||
inputs=request.inputs,
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Annotated, Any, Literal, Optional, Sequence
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import backend.data.block
|
||||
@@ -13,8 +12,7 @@ import backend.server.v2.store.cache as store_cache
|
||||
import backend.server.v2.store.model as store_model
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data import user as user_db
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.server.external.middleware import require_permission
|
||||
@@ -26,33 +24,27 @@ logger = logging.getLogger(__name__)
|
||||
v1_router = APIRouter()
|
||||
|
||||
|
||||
class UserInfoResponse(BaseModel):
|
||||
id: str
|
||||
name: Optional[str]
|
||||
email: str
|
||||
timezone: str = Field(
|
||||
description="The user's last known timezone (e.g. 'Europe/Amsterdam'), "
|
||||
"or 'not-set' if not set"
|
||||
)
|
||||
class NodeOutput(TypedDict):
|
||||
key: str
|
||||
value: Any
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/me",
|
||||
tags=["user", "meta"],
|
||||
)
|
||||
async def get_user_info(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.IDENTITY)
|
||||
),
|
||||
) -> UserInfoResponse:
|
||||
user = await user_db.get_user_by_id(auth.user_id)
|
||||
class ExecutionNode(TypedDict):
|
||||
node_id: str
|
||||
input: Any
|
||||
output: dict[str, Any]
|
||||
|
||||
return UserInfoResponse(
|
||||
id=user.id,
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
timezone=user.timezone,
|
||||
)
|
||||
|
||||
class ExecutionNodeOutput(TypedDict):
|
||||
node_id: str
|
||||
outputs: list[NodeOutput]
|
||||
|
||||
|
||||
class GraphExecutionResult(TypedDict):
|
||||
execution_id: str
|
||||
status: str
|
||||
nodes: list[ExecutionNode]
|
||||
output: Optional[list[dict[str, str]]]
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -73,9 +65,7 @@ async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
async def execute_graph_block(
|
||||
block_id: str,
|
||||
data: BlockInput,
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.EXECUTE_BLOCK)
|
||||
),
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_BLOCK)),
|
||||
) -> CompletedBlockOutput:
|
||||
obj = backend.data.block.get_block(block_id)
|
||||
if not obj:
|
||||
@@ -95,14 +85,12 @@ async def execute_graph(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.EXECUTE_GRAPH)
|
||||
),
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
graph_exec = await add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=auth.user_id,
|
||||
user_id=api_key.user_id,
|
||||
inputs=node_input,
|
||||
graph_version=graph_version,
|
||||
)
|
||||
@@ -112,19 +100,6 @@ async def execute_graph(
|
||||
raise HTTPException(status_code=400, detail=msg)
|
||||
|
||||
|
||||
class ExecutionNode(TypedDict):
|
||||
node_id: str
|
||||
input: Any
|
||||
output: dict[str, Any]
|
||||
|
||||
|
||||
class GraphExecutionResult(TypedDict):
|
||||
execution_id: str
|
||||
status: str
|
||||
nodes: list[ExecutionNode]
|
||||
output: Optional[list[dict[str, str]]]
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}/results",
|
||||
tags=["graphs"],
|
||||
@@ -132,12 +107,10 @@ class GraphExecutionResult(TypedDict):
|
||||
async def get_graph_execution_results(
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_GRAPH)
|
||||
),
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.READ_GRAPH)),
|
||||
) -> GraphExecutionResult:
|
||||
graph_exec = await execution_db.get_graph_execution(
|
||||
user_id=auth.user_id,
|
||||
user_id=api_key.user_id,
|
||||
execution_id=graph_exec_id,
|
||||
include_node_executions=True,
|
||||
)
|
||||
@@ -149,7 +122,7 @@ async def get_graph_execution_results(
|
||||
if not await graph_db.get_graph(
|
||||
graph_id=graph_exec.graph_id,
|
||||
version=graph_exec.graph_version,
|
||||
user_id=auth.user_id,
|
||||
user_id=api_key.user_id,
|
||||
):
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any, Literal, Optional
|
||||
import pydantic
|
||||
from prisma.enums import OnboardingStep
|
||||
|
||||
from backend.data.auth.api_key import APIKeyInfo, APIKeyPermission
|
||||
from backend.data.api_key import APIKeyInfo, APIKeyPermission
|
||||
from backend.data.graph import Graph
|
||||
from backend.util.timezone_name import TimeZoneName
|
||||
|
||||
|
||||
@@ -21,8 +21,6 @@ import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.server.auth
|
||||
import backend.server.routers.oauth
|
||||
import backend.server.routers.postmark.postmark
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.admin.credit_admin_routes
|
||||
@@ -256,7 +254,6 @@ app.add_exception_handler(ValueError, handle_internal_http_error(400))
|
||||
app.add_exception_handler(Exception, handle_internal_http_error(500))
|
||||
|
||||
app.include_router(backend.server.routers.v1.v1_router, tags=["v1"], prefix="/api")
|
||||
app.include_router(backend.server.auth.auth_router, tags=["auth"], prefix="/api")
|
||||
app.include_router(
|
||||
backend.server.v2.store.routes.router, tags=["v2"], prefix="/api/store"
|
||||
)
|
||||
@@ -300,11 +297,6 @@ app.include_router(
|
||||
tags=["v2", "chat"],
|
||||
prefix="/api/chat",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.routers.oauth.router,
|
||||
tags=["oauth"],
|
||||
prefix="/api/oauth",
|
||||
)
|
||||
|
||||
app.mount("/external-api", external_app)
|
||||
|
||||
|
||||
@@ -1,833 +0,0 @@
|
||||
"""
|
||||
OAuth 2.0 Provider Endpoints
|
||||
|
||||
Implements OAuth 2.0 Authorization Code flow with PKCE support.
|
||||
|
||||
Flow:
|
||||
1. User clicks "Login with AutoGPT" in 3rd party app
|
||||
2. App redirects user to /oauth/authorize with client_id, redirect_uri, scope, state
|
||||
3. User sees consent screen (if not already logged in, redirects to login first)
|
||||
4. User approves → backend creates authorization code
|
||||
5. User redirected back to app with code
|
||||
6. App exchanges code for access/refresh tokens at /oauth/token
|
||||
7. App uses access token to call external API endpoints
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from fastapi import APIRouter, Body, HTTPException, Security, UploadFile, status
|
||||
from gcloud.aio import storage as async_storage
|
||||
from PIL import Image
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.auth.oauth import (
|
||||
InvalidClientError,
|
||||
InvalidGrantError,
|
||||
OAuthApplicationInfo,
|
||||
TokenIntrospectionResult,
|
||||
consume_authorization_code,
|
||||
create_access_token,
|
||||
create_authorization_code,
|
||||
create_refresh_token,
|
||||
get_oauth_application,
|
||||
get_oauth_application_by_id,
|
||||
introspect_token,
|
||||
list_user_oauth_applications,
|
||||
refresh_tokens,
|
||||
revoke_access_token,
|
||||
revoke_refresh_token,
|
||||
update_oauth_application,
|
||||
validate_client_credentials,
|
||||
validate_redirect_uri,
|
||||
validate_scopes,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request/Response Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""OAuth 2.0 token response"""
|
||||
|
||||
token_type: Literal["Bearer"] = "Bearer"
|
||||
access_token: str
|
||||
access_token_expires_at: datetime
|
||||
refresh_token: str
|
||||
refresh_token_expires_at: datetime
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""OAuth 2.0 error response"""
|
||||
|
||||
error: str
|
||||
error_description: Optional[str] = None
|
||||
|
||||
|
||||
class OAuthApplicationPublicInfo(BaseModel):
|
||||
"""Public information about an OAuth application (for consent screen)"""
|
||||
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Application Info Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/app/{client_id}",
|
||||
responses={
|
||||
404: {"description": "Application not found or disabled"},
|
||||
},
|
||||
)
|
||||
async def get_oauth_app_info(
|
||||
client_id: str, user_id: str = Security(get_user_id)
|
||||
) -> OAuthApplicationPublicInfo:
|
||||
"""
|
||||
Get public information about an OAuth application.
|
||||
|
||||
This endpoint is used by the consent screen to display application details
|
||||
to the user before they authorize access.
|
||||
|
||||
Returns:
|
||||
- name: Application name
|
||||
- description: Application description (if provided)
|
||||
- scopes: List of scopes the application is allowed to request
|
||||
"""
|
||||
app = await get_oauth_application(client_id)
|
||||
if not app or not app.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found",
|
||||
)
|
||||
|
||||
return OAuthApplicationPublicInfo(
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
logo_url=app.logo_url,
|
||||
scopes=[s.value for s in app.scopes],
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authorization Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AuthorizeRequest(BaseModel):
|
||||
"""OAuth 2.0 authorization request"""
|
||||
|
||||
client_id: str = Field(description="Client identifier")
|
||||
redirect_uri: str = Field(description="Redirect URI")
|
||||
scopes: list[str] = Field(description="List of scopes")
|
||||
state: str = Field(description="Anti-CSRF token from client")
|
||||
response_type: str = Field(
|
||||
default="code", description="Must be 'code' for authorization code flow"
|
||||
)
|
||||
code_challenge: str = Field(description="PKCE code challenge (required)")
|
||||
code_challenge_method: Literal["S256", "plain"] = Field(
|
||||
default="S256", description="PKCE code challenge method (S256 recommended)"
|
||||
)
|
||||
|
||||
|
||||
class AuthorizeResponse(BaseModel):
|
||||
"""OAuth 2.0 authorization response with redirect URL"""
|
||||
|
||||
redirect_url: str = Field(description="URL to redirect the user to")
|
||||
|
||||
|
||||
@router.post("/authorize")
|
||||
async def authorize(
|
||||
request: AuthorizeRequest = Body(),
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> AuthorizeResponse:
|
||||
"""
|
||||
OAuth 2.0 Authorization Endpoint
|
||||
|
||||
User must be logged in (authenticated with Supabase JWT).
|
||||
This endpoint creates an authorization code and returns a redirect URL.
|
||||
|
||||
PKCE (Proof Key for Code Exchange) is REQUIRED for all authorization requests.
|
||||
|
||||
The frontend consent screen should call this endpoint after the user approves,
|
||||
then redirect the user to the returned `redirect_url`.
|
||||
|
||||
Request Body:
|
||||
- client_id: The OAuth application's client ID
|
||||
- redirect_uri: Where to redirect after authorization (must match registered URI)
|
||||
- scopes: List of permissions (e.g., "EXECUTE_GRAPH READ_GRAPH")
|
||||
- state: Anti-CSRF token provided by client (will be returned in redirect)
|
||||
- response_type: Must be "code" (for authorization code flow)
|
||||
- code_challenge: PKCE code challenge (required)
|
||||
- code_challenge_method: "S256" (recommended) or "plain"
|
||||
|
||||
Returns:
|
||||
- redirect_url: The URL to redirect the user to (includes authorization code)
|
||||
|
||||
Error cases return a redirect_url with error parameters, or raise HTTPException
|
||||
for critical errors (like invalid redirect_uri).
|
||||
"""
|
||||
try:
|
||||
# Validate response_type
|
||||
if request.response_type != "code":
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"unsupported_response_type",
|
||||
"Only 'code' response type is supported",
|
||||
)
|
||||
|
||||
# Get application
|
||||
app = await get_oauth_application(request.client_id)
|
||||
if not app:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_client",
|
||||
"Unknown client_id",
|
||||
)
|
||||
|
||||
if not app.is_active:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_client",
|
||||
"Application is not active",
|
||||
)
|
||||
|
||||
# Validate redirect URI
|
||||
if not validate_redirect_uri(app, request.redirect_uri):
|
||||
# For invalid redirect_uri, we can't redirect safely
|
||||
# Must return error instead
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
"Invalid redirect_uri. "
|
||||
f"Must be one of: {', '.join(app.redirect_uris)}"
|
||||
),
|
||||
)
|
||||
|
||||
# Parse and validate scopes
|
||||
try:
|
||||
requested_scopes = [APIKeyPermission(s.strip()) for s in request.scopes]
|
||||
except ValueError as e:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
f"Invalid scope: {e}",
|
||||
)
|
||||
|
||||
if not requested_scopes:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
"At least one scope is required",
|
||||
)
|
||||
|
||||
if not validate_scopes(app, requested_scopes):
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
"Application is not authorized for all requested scopes. "
|
||||
f"Allowed: {', '.join(s.value for s in app.scopes)}",
|
||||
)
|
||||
|
||||
# Create authorization code
|
||||
auth_code = await create_authorization_code(
|
||||
application_id=app.id,
|
||||
user_id=user_id,
|
||||
scopes=requested_scopes,
|
||||
redirect_uri=request.redirect_uri,
|
||||
code_challenge=request.code_challenge,
|
||||
code_challenge_method=request.code_challenge_method,
|
||||
)
|
||||
|
||||
# Build redirect URL with authorization code
|
||||
params = {
|
||||
"code": auth_code.code,
|
||||
"state": request.state,
|
||||
}
|
||||
redirect_url = f"{request.redirect_uri}?{urlencode(params)}"
|
||||
|
||||
logger.info(
|
||||
f"Authorization code issued for user #{user_id} "
|
||||
f"and app {app.name} (#{app.id})"
|
||||
)
|
||||
|
||||
return AuthorizeResponse(redirect_url=redirect_url)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in authorization endpoint: {e}", exc_info=True)
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"server_error",
|
||||
"An unexpected error occurred",
|
||||
)
|
||||
|
||||
|
||||
def _error_redirect_url(
|
||||
redirect_uri: str,
|
||||
state: str,
|
||||
error: str,
|
||||
error_description: Optional[str] = None,
|
||||
) -> AuthorizeResponse:
|
||||
"""Helper to build redirect URL with OAuth error parameters"""
|
||||
params = {
|
||||
"error": error,
|
||||
"state": state,
|
||||
}
|
||||
if error_description:
|
||||
params["error_description"] = error_description
|
||||
|
||||
redirect_url = f"{redirect_uri}?{urlencode(params)}"
|
||||
return AuthorizeResponse(redirect_url=redirect_url)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TokenRequestByCode(BaseModel):
|
||||
grant_type: Literal["authorization_code"]
|
||||
code: str = Field(description="Authorization code")
|
||||
redirect_uri: str = Field(
|
||||
description="Redirect URI (must match authorization request)"
|
||||
)
|
||||
client_id: str
|
||||
client_secret: str
|
||||
code_verifier: str = Field(description="PKCE code verifier")
|
||||
|
||||
|
||||
class TokenRequestByRefreshToken(BaseModel):
|
||||
grant_type: Literal["refresh_token"]
|
||||
refresh_token: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
|
||||
|
||||
@router.post("/token")
|
||||
async def token(
|
||||
request: TokenRequestByCode | TokenRequestByRefreshToken = Body(),
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
OAuth 2.0 Token Endpoint
|
||||
|
||||
Exchanges authorization code or refresh token for access token.
|
||||
|
||||
Grant Types:
|
||||
1. authorization_code: Exchange authorization code for tokens
|
||||
- Required: grant_type, code, redirect_uri, client_id, client_secret
|
||||
- Optional: code_verifier (required if PKCE was used)
|
||||
|
||||
2. refresh_token: Exchange refresh token for new access token
|
||||
- Required: grant_type, refresh_token, client_id, client_secret
|
||||
|
||||
Returns:
|
||||
- access_token: Bearer token for API access (1 hour TTL)
|
||||
- token_type: "Bearer"
|
||||
- expires_in: Seconds until access token expires
|
||||
- refresh_token: Token for refreshing access (30 days TTL)
|
||||
- scopes: List of scopes
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
app = await validate_client_credentials(
|
||||
request.client_id, request.client_secret
|
||||
)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Handle authorization_code grant
|
||||
if request.grant_type == "authorization_code":
|
||||
# Consume authorization code
|
||||
try:
|
||||
user_id, scopes = await consume_authorization_code(
|
||||
code=request.code,
|
||||
application_id=app.id,
|
||||
redirect_uri=request.redirect_uri,
|
||||
code_verifier=request.code_verifier,
|
||||
)
|
||||
except InvalidGrantError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Create access and refresh tokens
|
||||
access_token = await create_access_token(app.id, user_id, scopes)
|
||||
refresh_token = await create_refresh_token(app.id, user_id, scopes)
|
||||
|
||||
logger.info(
|
||||
f"Access token issued for user #{user_id} and app {app.name} (#{app.id})"
|
||||
"via authorization code"
|
||||
)
|
||||
|
||||
if not access_token.token or not refresh_token.token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to generate tokens",
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
token_type="Bearer",
|
||||
access_token=access_token.token.get_secret_value(),
|
||||
access_token_expires_at=access_token.expires_at,
|
||||
refresh_token=refresh_token.token.get_secret_value(),
|
||||
refresh_token_expires_at=refresh_token.expires_at,
|
||||
scopes=list(s.value for s in scopes),
|
||||
)
|
||||
|
||||
# Handle refresh_token grant
|
||||
elif request.grant_type == "refresh_token":
|
||||
# Refresh access token
|
||||
try:
|
||||
new_access_token, new_refresh_token = await refresh_tokens(
|
||||
request.refresh_token, app.id
|
||||
)
|
||||
except InvalidGrantError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Tokens refreshed for user #{new_access_token.user_id} "
|
||||
f"by app {app.name} (#{app.id})"
|
||||
)
|
||||
|
||||
if not new_access_token.token or not new_refresh_token.token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to generate tokens",
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
token_type="Bearer",
|
||||
access_token=new_access_token.token.get_secret_value(),
|
||||
access_token_expires_at=new_access_token.expires_at,
|
||||
refresh_token=new_refresh_token.token.get_secret_value(),
|
||||
refresh_token_expires_at=new_refresh_token.expires_at,
|
||||
scopes=list(s.value for s in new_access_token.scopes),
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported grant_type: {request.grant_type}. "
|
||||
"Must be 'authorization_code' or 'refresh_token'",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Introspection Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post("/introspect")
|
||||
async def introspect(
|
||||
token: str = Body(description="Token to introspect"),
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
|
||||
None, description="Hint about token type ('access_token' or 'refresh_token')"
|
||||
),
|
||||
client_id: str = Body(description="Client identifier"),
|
||||
client_secret: str = Body(description="Client secret"),
|
||||
) -> TokenIntrospectionResult:
|
||||
"""
|
||||
OAuth 2.0 Token Introspection Endpoint (RFC 7662)
|
||||
|
||||
Allows clients to check if a token is valid and get its metadata.
|
||||
|
||||
Returns:
|
||||
- active: Whether the token is currently active
|
||||
- scopes: List of authorized scopes (if active)
|
||||
- client_id: The client the token was issued to (if active)
|
||||
- user_id: The user the token represents (if active)
|
||||
- exp: Expiration timestamp (if active)
|
||||
- token_type: "access_token" or "refresh_token" (if active)
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
await validate_client_credentials(client_id, client_secret)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Introspect the token
|
||||
return await introspect_token(token, token_type_hint)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Revocation Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post("/revoke")
|
||||
async def revoke(
|
||||
token: str = Body(description="Token to revoke"),
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
|
||||
None, description="Hint about token type ('access_token' or 'refresh_token')"
|
||||
),
|
||||
client_id: str = Body(description="Client identifier"),
|
||||
client_secret: str = Body(description="Client secret"),
|
||||
):
|
||||
"""
|
||||
OAuth 2.0 Token Revocation Endpoint (RFC 7009)
|
||||
|
||||
Allows clients to revoke an access or refresh token.
|
||||
|
||||
Note: Revoking a refresh token does NOT revoke associated access tokens.
|
||||
Revoking an access token does NOT revoke the associated refresh token.
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
app = await validate_client_credentials(client_id, client_secret)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Try to revoke as access token first
|
||||
# Note: We pass app.id to ensure the token belongs to the authenticated app
|
||||
if token_type_hint != "refresh_token":
|
||||
revoked = await revoke_access_token(token, app.id)
|
||||
if revoked:
|
||||
logger.info(
|
||||
f"Access token revoked for app {app.name} (#{app.id}); "
|
||||
f"user #{revoked.user_id}"
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
# Try to revoke as refresh token
|
||||
revoked = await revoke_refresh_token(token, app.id)
|
||||
if revoked:
|
||||
logger.info(
|
||||
f"Refresh token revoked for app {app.name} (#{app.id}); "
|
||||
f"user #{revoked.user_id}"
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
# Per RFC 7009, revocation endpoint returns 200 even if token not found
|
||||
# or if token belongs to a different application.
|
||||
# This prevents token scanning attacks.
|
||||
logger.warning(f"Unsuccessful token revocation attempt by app {app.name} #{app.id}")
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Application Management Endpoints (for app owners)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get("/apps/mine")
|
||||
async def list_my_oauth_apps(
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> list[OAuthApplicationInfo]:
|
||||
"""
|
||||
List all OAuth applications owned by the current user.
|
||||
|
||||
Returns a list of OAuth applications with their details including:
|
||||
- id, name, description, logo_url
|
||||
- client_id (public identifier)
|
||||
- redirect_uris, grant_types, scopes
|
||||
- is_active status
|
||||
- created_at, updated_at timestamps
|
||||
|
||||
Note: client_secret is never returned for security reasons.
|
||||
"""
|
||||
return await list_user_oauth_applications(user_id)
|
||||
|
||||
|
||||
@router.patch("/apps/{app_id}/status")
|
||||
async def update_app_status(
|
||||
app_id: str,
|
||||
user_id: str = Security(get_user_id),
|
||||
is_active: bool = Body(description="Whether the app should be active", embed=True),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Enable or disable an OAuth application.
|
||||
|
||||
Only the application owner can update the status.
|
||||
When disabled, the application cannot be used for new authorizations
|
||||
and existing access tokens will fail validation.
|
||||
|
||||
Returns the updated application info.
|
||||
"""
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
action = "enabled" if is_active else "disabled"
|
||||
logger.info(f"OAuth app {updated_app.name} (#{app_id}) {action} by user #{user_id}")
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
class UpdateAppLogoRequest(BaseModel):
|
||||
logo_url: str = Field(description="URL of the uploaded logo image")
|
||||
|
||||
|
||||
@router.patch("/apps/{app_id}/logo")
|
||||
async def update_app_logo(
|
||||
app_id: str,
|
||||
request: UpdateAppLogoRequest = Body(),
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Update the logo URL for an OAuth application.
|
||||
|
||||
Only the application owner can update the logo.
|
||||
The logo should be uploaded first using the media upload endpoint,
|
||||
then this endpoint is called with the resulting URL.
|
||||
|
||||
Logo requirements:
|
||||
- Must be square (1:1 aspect ratio)
|
||||
- Minimum 512x512 pixels
|
||||
- Maximum 2048x2048 pixels
|
||||
|
||||
Returns the updated application info.
|
||||
"""
|
||||
if (
|
||||
not (app := await get_oauth_application_by_id(app_id))
|
||||
or app.owner_id != user_id
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth App not found",
|
||||
)
|
||||
|
||||
# Delete the current app logo file (if any and it's in our cloud storage)
|
||||
await _delete_app_current_logo_file(app)
|
||||
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
logo_url=request.logo_url,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth app {updated_app.name} (#{app_id}) logo updated by user #{user_id}"
|
||||
)
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
# Logo upload constraints
|
||||
LOGO_MIN_SIZE = 512
|
||||
LOGO_MAX_SIZE = 2048
|
||||
LOGO_ALLOWED_TYPES = {"image/jpeg", "image/png", "image/webp"}
|
||||
LOGO_MAX_FILE_SIZE = 3 * 1024 * 1024 # 3MB
|
||||
|
||||
|
||||
@router.post("/apps/{app_id}/logo/upload")
|
||||
async def upload_app_logo(
|
||||
app_id: str,
|
||||
file: UploadFile,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Upload a logo image for an OAuth application.
|
||||
|
||||
Requirements:
|
||||
- Image must be square (1:1 aspect ratio)
|
||||
- Minimum 512x512 pixels
|
||||
- Maximum 2048x2048 pixels
|
||||
- Allowed formats: JPEG, PNG, WebP
|
||||
- Maximum file size: 3MB
|
||||
|
||||
The image is uploaded to cloud storage and the app's logoUrl is updated.
|
||||
Returns the updated application info.
|
||||
"""
|
||||
# Verify ownership to reduce vulnerability to DoS(torage) or DoM(oney) attacks
|
||||
if (
|
||||
not (app := await get_oauth_application_by_id(app_id))
|
||||
or app.owner_id != user_id
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth App not found",
|
||||
)
|
||||
|
||||
# Check GCS configuration
|
||||
if not settings.config.media_gcs_bucket_name:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Media storage is not configured",
|
||||
)
|
||||
|
||||
# Validate content type
|
||||
content_type = file.content_type
|
||||
if content_type not in LOGO_ALLOWED_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid file type. Allowed: JPEG, PNG, WebP. Got: {content_type}",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
file_bytes = await file.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading logo file: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to read uploaded file",
|
||||
)
|
||||
|
||||
# Check file size
|
||||
if len(file_bytes) > LOGO_MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
"File too large. "
|
||||
f"Maximum size is {LOGO_MAX_FILE_SIZE // 1024 // 1024}MB"
|
||||
),
|
||||
)
|
||||
|
||||
# Validate image dimensions
|
||||
try:
|
||||
image = Image.open(io.BytesIO(file_bytes))
|
||||
width, height = image.size
|
||||
|
||||
if width != height:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo must be square. Got {width}x{height}",
|
||||
)
|
||||
|
||||
if width < LOGO_MIN_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo too small. Minimum {LOGO_MIN_SIZE}x{LOGO_MIN_SIZE}. "
|
||||
f"Got {width}x{height}",
|
||||
)
|
||||
|
||||
if width > LOGO_MAX_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo too large. Maximum {LOGO_MAX_SIZE}x{LOGO_MAX_SIZE}. "
|
||||
f"Got {width}x{height}",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating logo image: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid image file",
|
||||
)
|
||||
|
||||
# Scan for viruses
|
||||
filename = file.filename or "logo"
|
||||
await scan_content_safe(file_bytes, filename=filename)
|
||||
|
||||
# Generate unique filename
|
||||
file_ext = os.path.splitext(filename)[1].lower() or ".png"
|
||||
unique_filename = f"{uuid.uuid4()}{file_ext}"
|
||||
storage_path = f"oauth-apps/{app_id}/logo/{unique_filename}"
|
||||
|
||||
# Upload to GCS
|
||||
try:
|
||||
async with async_storage.Storage() as async_client:
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
|
||||
await async_client.upload(
|
||||
bucket_name, storage_path, file_bytes, content_type=content_type
|
||||
)
|
||||
|
||||
logo_url = f"https://storage.googleapis.com/{bucket_name}/{storage_path}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading logo to GCS: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to upload logo",
|
||||
)
|
||||
|
||||
# Delete the current app logo file (if any and it's in our cloud storage)
|
||||
await _delete_app_current_logo_file(app)
|
||||
|
||||
# Update the app with the new logo URL
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
logo_url=logo_url,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth app {updated_app.name} (#{app_id}) logo uploaded by user #{user_id}"
|
||||
)
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
async def _delete_app_current_logo_file(app: OAuthApplicationInfo):
|
||||
"""
|
||||
Delete the current logo file for the given app, if there is one in our cloud storage
|
||||
"""
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
storage_base_url = f"https://storage.googleapis.com/{bucket_name}/"
|
||||
|
||||
if app.logo_url and app.logo_url.startswith(storage_base_url):
|
||||
# Parse blob path from URL: https://storage.googleapis.com/{bucket}/{path}
|
||||
old_path = app.logo_url.replace(storage_base_url, "")
|
||||
try:
|
||||
async with async_storage.Storage() as async_client:
|
||||
await async_client.delete(bucket_name, old_path)
|
||||
logger.info(f"Deleted old logo for OAuth app #{app.id}: {old_path}")
|
||||
except Exception as e:
|
||||
# Log but don't fail - the new logo was uploaded successfully
|
||||
logger.warning(
|
||||
f"Failed to delete old logo for OAuth app #{app.id}: {e}", exc_info=e
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -56,7 +56,7 @@ async def postmark_webhook_handler(
|
||||
webhook: Annotated[
|
||||
PostmarkWebhook,
|
||||
Body(discriminator="RecordType"),
|
||||
],
|
||||
]
|
||||
):
|
||||
logger.info(f"Received webhook from Postmark: {webhook}")
|
||||
match webhook:
|
||||
|
||||
@@ -31,9 +31,9 @@ from typing_extensions import Optional, TypedDict
|
||||
import backend.server.integrations.router
|
||||
import backend.server.routers.analytics
|
||||
import backend.server.v2.library.db as library_db
|
||||
from backend.data import api_key as api_key_db
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.auth import api_key as api_key_db
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
|
||||
from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
|
||||
@@ -522,8 +522,8 @@ async def test_api_keys_with_newline_variations(mock_request):
|
||||
"valid\r\ntoken", # Windows newline
|
||||
"valid\rtoken", # Mac newline
|
||||
"valid\x85token", # NEL (Next Line)
|
||||
"valid\x0btoken", # Vertical Tab
|
||||
"valid\x0ctoken", # Form Feed
|
||||
"valid\x0Btoken", # Vertical Tab
|
||||
"valid\x0Ctoken", # Form Feed
|
||||
]
|
||||
|
||||
for api_key in newline_variations:
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from os import getenv
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from prisma.types import ProfileCreateInput
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
|
||||
@@ -51,16 +49,13 @@ async def setup_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data=cast(
|
||||
ProfileCreateInput,
|
||||
{
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
},
|
||||
)
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
)
|
||||
|
||||
# 2. Create a test graph with agent input -> agent output
|
||||
@@ -177,16 +172,13 @@ async def setup_llm_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data=cast(
|
||||
ProfileCreateInput,
|
||||
{
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for LLM tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
},
|
||||
)
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for LLM tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
)
|
||||
|
||||
# 2. Create test OpenAI credentials for the user
|
||||
@@ -340,16 +332,13 @@ async def setup_firecrawl_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data=cast(
|
||||
ProfileCreateInput,
|
||||
{
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for Firecrawl tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
},
|
||||
)
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for Firecrawl tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
)
|
||||
|
||||
# NOTE: We deliberately do NOT create Firecrawl credentials for this user
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Literal, Optional, cast
|
||||
from typing import Literal, Optional
|
||||
|
||||
import fastapi
|
||||
import prisma.errors
|
||||
import prisma.fields
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
from prisma.types import LibraryAgentCreateInput
|
||||
|
||||
import backend.data.graph as graph_db
|
||||
import backend.data.integrations as integrations_db
|
||||
@@ -803,21 +802,18 @@ async def add_store_agent_to_library(
|
||||
|
||||
# Create LibraryAgent entry
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
data=cast(
|
||||
LibraryAgentCreateInput,
|
||||
{
|
||||
"User": {"connect": {"id": user_id}},
|
||||
"AgentGraph": {
|
||||
"connect": {
|
||||
"graphVersionId": {"id": graph.id, "version": graph.version}
|
||||
}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
"settings": SafeJson(
|
||||
_initialize_graph_settings(graph_model).model_dump()
|
||||
),
|
||||
data={
|
||||
"User": {"connect": {"id": user_id}},
|
||||
"AgentGraph": {
|
||||
"connect": {
|
||||
"graphVersionId": {"id": graph.id, "version": graph.version}
|
||||
}
|
||||
},
|
||||
),
|
||||
"isCreatedByUser": False,
|
||||
"settings": SafeJson(
|
||||
_initialize_graph_settings(graph_model).model_dump()
|
||||
),
|
||||
},
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
|
||||
@@ -2,14 +2,13 @@ import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, cast
|
||||
from typing import Literal
|
||||
|
||||
import fastapi
|
||||
import prisma.enums
|
||||
import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
from prisma.types import SearchTermsCreateInput, StoreListingVersionCreateInput
|
||||
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.server.v2.store.model
|
||||
@@ -249,10 +248,7 @@ async def log_search_term(search_query: str):
|
||||
date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
try:
|
||||
await prisma.models.SearchTerms.prisma().create(
|
||||
data=cast(
|
||||
SearchTermsCreateInput,
|
||||
{"searchTerm": search_query, "createdDate": date},
|
||||
)
|
||||
data={"searchTerm": search_query, "createdDate": date}
|
||||
)
|
||||
except Exception as e:
|
||||
# Fail silently here so that logging search terms doesn't break the app
|
||||
@@ -1435,14 +1431,11 @@ async def _approve_sub_agent(
|
||||
# Create new version if no matching version found
|
||||
next_version = max((v.version for v in listing.Versions or []), default=0) + 1
|
||||
await prisma.models.StoreListingVersion.prisma(tx).create(
|
||||
data=cast(
|
||||
StoreListingVersionCreateInput,
|
||||
{
|
||||
**_create_sub_agent_version_data(sub_graph, heading, main_agent_name),
|
||||
"version": next_version,
|
||||
"storeListingId": listing.id,
|
||||
},
|
||||
)
|
||||
data={
|
||||
**_create_sub_agent_version_data(sub_graph, heading, main_agent_name),
|
||||
"version": next_version,
|
||||
"storeListingId": listing.id,
|
||||
}
|
||||
)
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": listing.id}, data={"hasApprovedVersion": True}
|
||||
|
||||
@@ -45,7 +45,7 @@ def mock_storage_client(mocker):
|
||||
|
||||
async def test_upload_media_success(mock_settings, mock_storage_client):
|
||||
# Create test JPEG data with valid signature
|
||||
test_data = b"\xff\xd8\xff" + b"test data"
|
||||
test_data = b"\xFF\xD8\xFF" + b"test data"
|
||||
|
||||
test_file = fastapi.UploadFile(
|
||||
filename="laptop.jpeg",
|
||||
@@ -83,7 +83,7 @@ async def test_upload_media_missing_credentials(monkeypatch):
|
||||
|
||||
test_file = fastapi.UploadFile(
|
||||
filename="laptop.jpeg",
|
||||
file=io.BytesIO(b"\xff\xd8\xff" + b"test data"), # Valid JPEG signature
|
||||
file=io.BytesIO(b"\xFF\xD8\xFF" + b"test data"), # Valid JPEG signature
|
||||
headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
|
||||
)
|
||||
|
||||
@@ -108,7 +108,7 @@ async def test_upload_media_video_type(mock_settings, mock_storage_client):
|
||||
|
||||
|
||||
async def test_upload_media_file_too_large(mock_settings, mock_storage_client):
|
||||
large_data = b"\xff\xd8\xff" + b"x" * (
|
||||
large_data = b"\xFF\xD8\xFF" + b"x" * (
|
||||
50 * 1024 * 1024 + 1
|
||||
) # 50MB + 1 byte with valid JPEG signature
|
||||
test_file = fastapi.UploadFile(
|
||||
|
||||
@@ -4,12 +4,14 @@ Centralized service client helpers with thread caching.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from backend.util.cache import thread_cached
|
||||
from backend.util.cache import cached, thread_cached
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from supabase import AClient, Client
|
||||
|
||||
from backend.data.execution import (
|
||||
AsyncRedisExecutionEventBus,
|
||||
RedisExecutionEventBus,
|
||||
@@ -114,6 +116,29 @@ def get_integration_credentials_store() -> "IntegrationCredentialsStore":
|
||||
return IntegrationCredentialsStore()
|
||||
|
||||
|
||||
# ============ Supabase Clients ============ #
|
||||
|
||||
|
||||
@cached(ttl_seconds=3600)
|
||||
def get_supabase() -> "Client":
|
||||
"""Get a process-cached synchronous Supabase client instance."""
|
||||
from supabase import create_client
|
||||
|
||||
return create_client(
|
||||
settings.secrets.supabase_url, settings.secrets.supabase_service_role_key
|
||||
)
|
||||
|
||||
|
||||
@cached(ttl_seconds=3600)
|
||||
async def get_async_supabase() -> "AClient":
|
||||
"""Get a process-cached asynchronous Supabase client instance."""
|
||||
from supabase import create_async_client
|
||||
|
||||
return await create_async_client(
|
||||
settings.secrets.supabase_url, settings.secrets.supabase_service_role_key
|
||||
)
|
||||
|
||||
|
||||
# ============ Notification Queue Helpers ============ #
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ Utilities for handling dynamic field names and delimiters in the AutoGPT Platfor
|
||||
|
||||
Dynamic fields allow graphs to connect complex data structures using special delimiters:
|
||||
- _#_ for dictionary keys (e.g., "values_#_name" → values["name"])
|
||||
- _$_ for list indices (e.g., "items_$_0" → items[0])
|
||||
- _$_ for list indices (e.g., "items_$_0" → items[0])
|
||||
- _@_ for object attributes (e.g., "obj_@_attr" → obj.attr)
|
||||
|
||||
This module provides utilities for:
|
||||
|
||||
@@ -83,7 +83,7 @@ def shutdown_launchdarkly() -> None:
|
||||
@cached(maxsize=1000, ttl_seconds=86400) # 1000 entries, 24 hours TTL
|
||||
async def _fetch_user_context_data(user_id: str) -> Context:
|
||||
"""
|
||||
Fetch user context for LaunchDarkly from the database.
|
||||
Fetch user context for LaunchDarkly from Supabase.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to fetch data for
|
||||
@@ -94,11 +94,12 @@ async def _fetch_user_context_data(user_id: str) -> Context:
|
||||
builder = Context.builder(user_id).kind("user").anonymous(True)
|
||||
|
||||
try:
|
||||
from backend.data.db import prisma
|
||||
from backend.util.clients import get_supabase
|
||||
|
||||
# If we have user data, update context
|
||||
user = await prisma.user.find_unique(where={"id": user_id})
|
||||
if user:
|
||||
response = get_supabase().auth.admin.get_user_by_id(user_id)
|
||||
if response and response.user:
|
||||
user = response.user
|
||||
builder.anonymous(False)
|
||||
if user.role:
|
||||
builder.set("role", user.role)
|
||||
|
||||
@@ -14,47 +14,12 @@ from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
TEMP_DIR = Path(tempfile.gettempdir()).resolve()
|
||||
|
||||
# Maximum filename length (conservative limit for most filesystems)
|
||||
MAX_FILENAME_LENGTH = 200
|
||||
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
"""
|
||||
Sanitize and truncate filename to prevent filesystem errors.
|
||||
"""
|
||||
# Remove or replace invalid characters
|
||||
sanitized = re.sub(r'[<>:"/\\|?*\n\r\t]', "_", filename)
|
||||
|
||||
# Truncate if too long
|
||||
if len(sanitized) > MAX_FILENAME_LENGTH:
|
||||
# Keep the extension if possible
|
||||
if "." in sanitized:
|
||||
name, ext = sanitized.rsplit(".", 1)
|
||||
max_name_length = MAX_FILENAME_LENGTH - len(ext) - 1
|
||||
sanitized = name[:max_name_length] + "." + ext
|
||||
else:
|
||||
sanitized = sanitized[:MAX_FILENAME_LENGTH]
|
||||
|
||||
# Ensure it's not empty or just dots
|
||||
if not sanitized or sanitized.strip(".") == "":
|
||||
sanitized = f"file_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def get_exec_file_path(graph_exec_id: str, path: str) -> str:
|
||||
"""
|
||||
Utility to build an absolute path in the {temp}/exec_file/{exec_id}/... folder.
|
||||
"""
|
||||
try:
|
||||
full_path = TEMP_DIR / "exec_file" / graph_exec_id / path
|
||||
return str(full_path)
|
||||
except OSError as e:
|
||||
if "File name too long" in str(e):
|
||||
raise ValueError(
|
||||
f"File path too long: {len(path)} characters. Maximum path length exceeded."
|
||||
) from e
|
||||
raise ValueError(f"Invalid file path: {e}") from e
|
||||
return str(TEMP_DIR / "exec_file" / graph_exec_id / path)
|
||||
|
||||
|
||||
def clean_exec_files(graph_exec_id: str, file: str = "") -> None:
|
||||
@@ -152,11 +117,8 @@ async def store_media_file(
|
||||
|
||||
# Generate filename from cloud path
|
||||
_, path_part = cloud_storage.parse_cloud_path(file)
|
||||
filename = sanitize_filename(Path(path_part).name or f"{uuid.uuid4()}.bin")
|
||||
try:
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
except OSError as e:
|
||||
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
||||
filename = Path(path_part).name or f"{uuid.uuid4()}.bin"
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
|
||||
# Check file size limit
|
||||
if len(cloud_content) > MAX_FILE_SIZE:
|
||||
@@ -182,10 +144,7 @@ async def store_media_file(
|
||||
# Generate filename and decode
|
||||
extension = _extension_from_mime(mime_type)
|
||||
filename = f"{uuid.uuid4()}{extension}"
|
||||
try:
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
except OSError as e:
|
||||
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
content = base64.b64decode(b64_content)
|
||||
|
||||
# Check file size limit
|
||||
@@ -201,11 +160,8 @@ async def store_media_file(
|
||||
elif file.startswith(("http://", "https://")):
|
||||
# URL
|
||||
parsed_url = urlparse(file)
|
||||
filename = sanitize_filename(Path(parsed_url.path).name or f"{uuid.uuid4()}")
|
||||
try:
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
except OSError as e:
|
||||
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
||||
filename = Path(parsed_url.path).name or f"{uuid.uuid4()}"
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
|
||||
# Download and save
|
||||
resp = await Requests().get(file)
|
||||
@@ -221,12 +177,8 @@ async def store_media_file(
|
||||
target_path.write_bytes(resp.content)
|
||||
|
||||
else:
|
||||
# Local path - sanitize the filename part to prevent long filename errors
|
||||
sanitized_file = sanitize_filename(file)
|
||||
try:
|
||||
target_path = _ensure_inside_base(base_path / sanitized_file, base_path)
|
||||
except OSError as e:
|
||||
raise ValueError(f"Invalid file path '{sanitized_file}': {e}") from e
|
||||
# Local path
|
||||
target_path = _ensure_inside_base(base_path / file, base_path)
|
||||
if not target_path.is_file():
|
||||
raise ValueError(f"Local file does not exist: {target_path}")
|
||||
|
||||
|
||||
@@ -21,26 +21,6 @@ from tenacity import (
|
||||
|
||||
from backend.util.json import loads
|
||||
|
||||
|
||||
class HTTPClientError(Exception):
|
||||
"""4xx client errors (400-499)"""
|
||||
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class HTTPServerError(Exception):
|
||||
"""5xx server errors (500-599)"""
|
||||
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
# Default User-Agent for all requests
|
||||
DEFAULT_USER_AGENT = "AutoGPT-Platform/1.0 (https://github.com/Significant-Gravitas/AutoGPT; info@agpt.co) aiohttp"
|
||||
|
||||
# Retry status codes for which we will automatically retry the request
|
||||
THROTTLE_RETRY_STATUS_CODES: set[int] = {429, 500, 502, 503, 504, 408}
|
||||
|
||||
@@ -470,10 +450,6 @@ class Requests:
|
||||
if self.extra_headers is not None:
|
||||
req_headers.update(self.extra_headers)
|
||||
|
||||
# Set default User-Agent if not provided
|
||||
if "User-Agent" not in req_headers and "user-agent" not in req_headers:
|
||||
req_headers["User-Agent"] = DEFAULT_USER_AGENT
|
||||
|
||||
# Override Host header if using IP connection
|
||||
if connector:
|
||||
req_headers["Host"] = hostname
|
||||
@@ -500,16 +476,9 @@ class Requests:
|
||||
response.raise_for_status()
|
||||
except ClientResponseError as e:
|
||||
body = await response.read()
|
||||
error_message = f"HTTP {response.status} Error: {response.reason}, Body: {body.decode(errors='replace')}"
|
||||
|
||||
# Raise specific exceptions based on status code range
|
||||
if 400 <= response.status <= 499:
|
||||
raise HTTPClientError(error_message, response.status) from e
|
||||
elif 500 <= response.status <= 599:
|
||||
raise HTTPServerError(error_message, response.status) from e
|
||||
else:
|
||||
# Generic fallback for other HTTP errors
|
||||
raise Exception(error_message) from e
|
||||
raise Exception(
|
||||
f"HTTP {response.status} Error: {response.reason}, Body: {body.decode(errors='replace')}"
|
||||
) from e
|
||||
|
||||
# If allowed and a redirect is received, follow the redirect manually
|
||||
if allow_redirects and response.status in (301, 302, 303, 307, 308):
|
||||
|
||||
@@ -362,13 +362,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="Hours between cloud storage cleanup runs (1-24 hours)",
|
||||
)
|
||||
|
||||
oauth_token_cleanup_interval_hours: int = Field(
|
||||
default=6,
|
||||
ge=1,
|
||||
le=24,
|
||||
description="Hours between OAuth token cleanup runs (1-24 hours)",
|
||||
)
|
||||
|
||||
upload_file_size_limit_mb: int = Field(
|
||||
default=256,
|
||||
ge=1,
|
||||
@@ -530,6 +523,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
"""Secrets for the server."""
|
||||
|
||||
supabase_url: str = Field(default="", description="Supabase URL")
|
||||
supabase_service_role_key: str = Field(
|
||||
default="", description="Supabase service role key"
|
||||
)
|
||||
|
||||
encryption_key: str = Field(default="", description="Encryption key")
|
||||
|
||||
rabbitmq_default_user: str = Field(default="", description="RabbitMQ default user")
|
||||
|
||||
@@ -222,9 +222,9 @@ class TestSafeJson:
|
||||
problematic_data = {
|
||||
"null_byte": "data with \x00 null",
|
||||
"bell_char": "data with \x07 bell",
|
||||
"form_feed": "data with \x0c feed",
|
||||
"escape_char": "data with \x1b escape",
|
||||
"delete_char": "data with \x7f delete",
|
||||
"form_feed": "data with \x0C feed",
|
||||
"escape_char": "data with \x1B escape",
|
||||
"delete_char": "data with \x7F delete",
|
||||
}
|
||||
|
||||
# SafeJson should successfully process data with control characters
|
||||
@@ -235,9 +235,9 @@ class TestSafeJson:
|
||||
result_data = result.data
|
||||
assert "\x00" not in str(result_data) # null byte removed
|
||||
assert "\x07" not in str(result_data) # bell removed
|
||||
assert "\x0c" not in str(result_data) # form feed removed
|
||||
assert "\x1b" not in str(result_data) # escape removed
|
||||
assert "\x7f" not in str(result_data) # delete removed
|
||||
assert "\x0C" not in str(result_data) # form feed removed
|
||||
assert "\x1B" not in str(result_data) # escape removed
|
||||
assert "\x7F" not in str(result_data) # delete removed
|
||||
|
||||
# Test that safe whitespace characters are preserved
|
||||
safe_data = {
|
||||
@@ -263,7 +263,7 @@ class TestSafeJson:
|
||||
def test_web_scraping_content_sanitization(self):
|
||||
"""Test sanitization of typical web scraping content with null characters."""
|
||||
# Simulate web content that might contain null bytes from SearchTheWebBlock
|
||||
web_content = "Article title\x00Hidden null\x01Start of heading\x08Backspace\x0cForm feed content\x1fUnit separator\x7fDelete char"
|
||||
web_content = "Article title\x00Hidden null\x01Start of heading\x08Backspace\x0CForm feed content\x1FUnit separator\x7FDelete char"
|
||||
|
||||
result = SafeJson(web_content)
|
||||
assert isinstance(result, Json)
|
||||
@@ -273,9 +273,9 @@ class TestSafeJson:
|
||||
assert "\x00" not in sanitized_content
|
||||
assert "\x01" not in sanitized_content
|
||||
assert "\x08" not in sanitized_content
|
||||
assert "\x0c" not in sanitized_content
|
||||
assert "\x1f" not in sanitized_content
|
||||
assert "\x7f" not in sanitized_content
|
||||
assert "\x0C" not in sanitized_content
|
||||
assert "\x1F" not in sanitized_content
|
||||
assert "\x7F" not in sanitized_content
|
||||
|
||||
# Verify the content is still readable
|
||||
assert "Article title" in sanitized_content
|
||||
@@ -391,7 +391,7 @@ class TestSafeJson:
|
||||
mixed_content = {
|
||||
"safe_and_unsafe": "Good text\twith tab\x00NULL BYTE\nand newline\x08BACKSPACE",
|
||||
"file_path_with_null": "C:\\temp\\file\x00.txt",
|
||||
"json_with_controls": '{"text": "data\x01\x0c\x1f"}',
|
||||
"json_with_controls": '{"text": "data\x01\x0C\x1F"}',
|
||||
}
|
||||
|
||||
result = SafeJson(mixed_content)
|
||||
@@ -419,13 +419,13 @@ class TestSafeJson:
|
||||
|
||||
# Create data with various problematic escape sequences that could cause JSON parsing errors
|
||||
problematic_output_data = {
|
||||
"web_content": "Article text\x00with null\x01and control\x08chars\x0c\x1f\x7f",
|
||||
"web_content": "Article text\x00with null\x01and control\x08chars\x0C\x1F\x7F",
|
||||
"file_path": "C:\\Users\\test\\file\x00.txt",
|
||||
"json_like_string": '{"text": "data\x00\x08\x1f"}',
|
||||
"json_like_string": '{"text": "data\x00\x08\x1F"}',
|
||||
"escaped_sequences": "Text with \\u0000 and \\u0008 sequences",
|
||||
"mixed_content": "Normal text\tproperly\nformatted\rwith\x00invalid\x08chars\x1fmixed",
|
||||
"mixed_content": "Normal text\tproperly\nformatted\rwith\x00invalid\x08chars\x1Fmixed",
|
||||
"large_text": "A" * 35000
|
||||
+ "\x00\x08\x1f"
|
||||
+ "\x00\x08\x1F"
|
||||
+ "B" * 5000, # Large text like in the error
|
||||
}
|
||||
|
||||
@@ -446,9 +446,9 @@ class TestSafeJson:
|
||||
assert "\x00" not in str(web_content)
|
||||
assert "\x01" not in str(web_content)
|
||||
assert "\x08" not in str(web_content)
|
||||
assert "\x0c" not in str(web_content)
|
||||
assert "\x1f" not in str(web_content)
|
||||
assert "\x7f" not in str(web_content)
|
||||
assert "\x0C" not in str(web_content)
|
||||
assert "\x1F" not in str(web_content)
|
||||
assert "\x7F" not in str(web_content)
|
||||
|
||||
# Check that legitimate content is preserved
|
||||
assert "Article text" in str(web_content)
|
||||
@@ -467,7 +467,7 @@ class TestSafeJson:
|
||||
assert "B" * 1000 in str(large_text) # B's preserved
|
||||
assert "\x00" not in str(large_text) # Control chars removed
|
||||
assert "\x08" not in str(large_text)
|
||||
assert "\x1f" not in str(large_text)
|
||||
assert "\x1F" not in str(large_text)
|
||||
|
||||
# Most importantly: ensure the result can be JSON-serialized without errors
|
||||
# This would have failed with the old approach
|
||||
@@ -602,7 +602,7 @@ class TestSafeJson:
|
||||
model = SamplePydanticModel(
|
||||
name="Test\x00User", # Has null byte
|
||||
age=30,
|
||||
metadata={"info": "data\x08with\x0ccontrols"},
|
||||
metadata={"info": "data\x08with\x0Ccontrols"},
|
||||
)
|
||||
|
||||
data = {"credential": model}
|
||||
@@ -616,7 +616,7 @@ class TestSafeJson:
|
||||
json_string = json.dumps(result.data)
|
||||
assert "\x00" not in json_string
|
||||
assert "\x08" not in json_string
|
||||
assert "\x0c" not in json_string
|
||||
assert "\x0C" not in json_string
|
||||
assert "TestUser" in json_string # Name preserved minus null byte
|
||||
|
||||
def test_deeply_nested_pydantic_models_control_char_sanitization(self):
|
||||
@@ -639,16 +639,16 @@ class TestSafeJson:
|
||||
|
||||
# Create test data with control characters at every nesting level
|
||||
inner = InnerModel(
|
||||
deep_string="Deepest\x00Level\x08Control\x0cChars", # Multiple control chars at deepest level
|
||||
deep_string="Deepest\x00Level\x08Control\x0CChars", # Multiple control chars at deepest level
|
||||
metadata={
|
||||
"nested_key": "Nested\x1fValue\x7fDelete"
|
||||
"nested_key": "Nested\x1FValue\x7FDelete"
|
||||
}, # Control chars in nested dict
|
||||
)
|
||||
|
||||
middle = MiddleModel(
|
||||
middle_string="Middle\x01StartOfHeading\x1fUnitSeparator",
|
||||
middle_string="Middle\x01StartOfHeading\x1FUnitSeparator",
|
||||
inner=inner,
|
||||
data="Some\x0bVerticalTab\x0eShiftOut",
|
||||
data="Some\x0BVerticalTab\x0EShiftOut",
|
||||
)
|
||||
|
||||
outer = OuterModel(outer_string="Outer\x00Null\x07Bell", middle=middle)
|
||||
@@ -659,7 +659,7 @@ class TestSafeJson:
|
||||
"nested_model": outer,
|
||||
"list_with_strings": [
|
||||
"List\x00Item1",
|
||||
"List\x0cItem2\x1f",
|
||||
"List\x0CItem2\x1F",
|
||||
{"dict_in_list": "Dict\x08Value"},
|
||||
],
|
||||
}
|
||||
@@ -684,10 +684,10 @@ class TestSafeJson:
|
||||
"\x06",
|
||||
"\x07",
|
||||
"\x08",
|
||||
"\x0b",
|
||||
"\x0c",
|
||||
"\x0e",
|
||||
"\x0f",
|
||||
"\x0B",
|
||||
"\x0C",
|
||||
"\x0E",
|
||||
"\x0F",
|
||||
"\x10",
|
||||
"\x11",
|
||||
"\x12",
|
||||
@@ -698,13 +698,13 @@ class TestSafeJson:
|
||||
"\x17",
|
||||
"\x18",
|
||||
"\x19",
|
||||
"\x1a",
|
||||
"\x1b",
|
||||
"\x1c",
|
||||
"\x1d",
|
||||
"\x1e",
|
||||
"\x1f",
|
||||
"\x7f",
|
||||
"\x1A",
|
||||
"\x1B",
|
||||
"\x1C",
|
||||
"\x1D",
|
||||
"\x1E",
|
||||
"\x1F",
|
||||
"\x7F",
|
||||
]
|
||||
|
||||
for char in control_chars:
|
||||
|
||||
@@ -5,13 +5,6 @@ from typing import Any, Type, TypeVar, Union, cast, get_args, get_origin, overlo
|
||||
from prisma import Json as PrismaJson
|
||||
|
||||
|
||||
def _is_type_or_subclass(origin: Any, target_type: type) -> bool:
|
||||
"""Check if origin is exactly the target type or a subclass of it."""
|
||||
return origin is target_type or (
|
||||
isinstance(origin, type) and issubclass(origin, target_type)
|
||||
)
|
||||
|
||||
|
||||
class ConversionError(ValueError):
|
||||
pass
|
||||
|
||||
@@ -145,11 +138,7 @@ def _try_convert(value: Any, target_type: Any, raise_on_mismatch: bool) -> Any:
|
||||
|
||||
if origin is None:
|
||||
origin = target_type
|
||||
# Early return for unsupported types (skip subclasses of supported types)
|
||||
supported_types = [list, dict, tuple, str, set, int, float, bool]
|
||||
if origin not in supported_types and not (
|
||||
isinstance(origin, type) and any(issubclass(origin, t) for t in supported_types)
|
||||
):
|
||||
if origin not in [list, dict, tuple, str, set, int, float, bool]:
|
||||
return value
|
||||
|
||||
# Handle the case when value is already of the target type
|
||||
@@ -179,47 +168,44 @@ def _try_convert(value: Any, target_type: Any, raise_on_mismatch: bool) -> Any:
|
||||
raise TypeError(f"Value {value} is not of expected type {target_type}")
|
||||
else:
|
||||
# Need to convert value to the origin type
|
||||
if _is_type_or_subclass(origin, list):
|
||||
converted_list = __convert_list(value)
|
||||
if origin is list:
|
||||
value = __convert_list(value)
|
||||
if args:
|
||||
converted_list = [convert(v, args[0]) for v in converted_list]
|
||||
return origin(converted_list) if origin is not list else converted_list
|
||||
elif _is_type_or_subclass(origin, dict):
|
||||
converted_dict = __convert_dict(value)
|
||||
return [convert(v, args[0]) for v in value]
|
||||
else:
|
||||
return value
|
||||
elif origin is dict:
|
||||
value = __convert_dict(value)
|
||||
if args:
|
||||
key_type, val_type = args
|
||||
converted_dict = {
|
||||
convert(k, key_type): convert(v, val_type)
|
||||
for k, v in converted_dict.items()
|
||||
return {
|
||||
convert(k, key_type): convert(v, val_type) for k, v in value.items()
|
||||
}
|
||||
return origin(converted_dict) if origin is not dict else converted_dict
|
||||
elif _is_type_or_subclass(origin, tuple):
|
||||
converted_tuple = __convert_tuple(value)
|
||||
else:
|
||||
return value
|
||||
elif origin is tuple:
|
||||
value = __convert_tuple(value)
|
||||
if args:
|
||||
if len(args) == 1:
|
||||
converted_tuple = tuple(
|
||||
convert(v, args[0]) for v in converted_tuple
|
||||
)
|
||||
return tuple(convert(v, args[0]) for v in value)
|
||||
else:
|
||||
converted_tuple = tuple(
|
||||
convert(v, t) for v, t in zip(converted_tuple, args)
|
||||
)
|
||||
return origin(converted_tuple) if origin is not tuple else converted_tuple
|
||||
elif _is_type_or_subclass(origin, str):
|
||||
converted_str = __convert_str(value)
|
||||
return origin(converted_str) if origin is not str else converted_str
|
||||
elif _is_type_or_subclass(origin, set):
|
||||
return tuple(convert(v, t) for v, t in zip(value, args))
|
||||
else:
|
||||
return value
|
||||
elif origin is str:
|
||||
return __convert_str(value)
|
||||
elif origin is set:
|
||||
value = __convert_set(value)
|
||||
if args:
|
||||
return {convert(v, args[0]) for v in value}
|
||||
else:
|
||||
return value
|
||||
elif _is_type_or_subclass(origin, bool):
|
||||
return __convert_bool(value)
|
||||
elif _is_type_or_subclass(origin, int):
|
||||
elif origin is int:
|
||||
return __convert_num(value, int)
|
||||
elif _is_type_or_subclass(origin, float):
|
||||
elif origin is float:
|
||||
return __convert_num(value, float)
|
||||
elif origin is bool:
|
||||
return __convert_bool(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
@@ -32,17 +32,3 @@ def test_type_conversion():
|
||||
assert convert("5", List[int]) == [5]
|
||||
assert convert("[5,4,2]", List[int]) == [5, 4, 2]
|
||||
assert convert([5, 4, 2], List[str]) == ["5", "4", "2"]
|
||||
|
||||
# Test the specific case that was failing: empty list to Optional[str]
|
||||
assert convert([], Optional[str]) == "[]"
|
||||
assert convert([], str) == "[]"
|
||||
|
||||
# Test the actual failing case: empty list to ShortTextType
|
||||
from backend.util.type import ShortTextType
|
||||
|
||||
assert convert([], Optional[ShortTextType]) == "[]"
|
||||
assert convert([], ShortTextType) == "[]"
|
||||
|
||||
# Test other empty list conversions
|
||||
assert convert([], int) == 0 # len([]) = 0
|
||||
assert convert([], Optional[int]) == 0
|
||||
|
||||
@@ -5,7 +5,7 @@ networks:
|
||||
name: shared-network
|
||||
|
||||
volumes:
|
||||
clamav-data:
|
||||
supabase-config:
|
||||
|
||||
x-agpt-services:
|
||||
&agpt-services
|
||||
@@ -13,18 +13,28 @@ x-agpt-services:
|
||||
- app-network
|
||||
- shared-network
|
||||
|
||||
services:
|
||||
x-supabase-services:
|
||||
&supabase-services
|
||||
networks:
|
||||
- app-network
|
||||
- shared-network
|
||||
|
||||
|
||||
volumes:
|
||||
clamav-data:
|
||||
|
||||
services:
|
||||
|
||||
db:
|
||||
<<: *agpt-services
|
||||
<<: *supabase-services
|
||||
extends:
|
||||
file: ../db/docker/docker-compose.yml
|
||||
service: db
|
||||
ports:
|
||||
- ${POSTGRES_PORT}:5432
|
||||
- ${POSTGRES_PORT}:5432 # We don't use Supavisor locally, so we expose the db directly.
|
||||
|
||||
vector:
|
||||
<<: *agpt-services
|
||||
<<: *supabase-services
|
||||
extends:
|
||||
file: ../db/docker/docker-compose.yml
|
||||
service: vector
|
||||
@@ -57,7 +67,6 @@ services:
|
||||
ports:
|
||||
- "5672:5672"
|
||||
- "15672:15672"
|
||||
|
||||
clamav:
|
||||
image: clamav/clamav-debian:latest
|
||||
ports:
|
||||
@@ -76,7 +85,6 @@ services:
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
networks:
|
||||
app-network-test:
|
||||
driver: bridge
|
||||
app-network-test:
|
||||
driver: bridge
|
||||
|
||||
@@ -5,14 +5,14 @@ Clean, streamlined load testing infrastructure for the AutoGPT Platform using k6
|
||||
## 🚀 Quick Start
|
||||
|
||||
```bash
|
||||
# 1. Set up API base URL (optional, defaults to local)
|
||||
export API_BASE_URL="http://localhost:8006"
|
||||
# 1. Set up Supabase service key (required for token generation)
|
||||
export SUPABASE_SERVICE_KEY="your-supabase-service-key"
|
||||
|
||||
# 2. Generate pre-authenticated tokens (first time setup - creates 160+ tokens with 24-hour expiry)
|
||||
# 2. Generate pre-authenticated tokens (first time setup - creates 160+ tokens with 24-hour expiry)
|
||||
node generate-tokens.js --count=160
|
||||
|
||||
# 3. Set up k6 cloud credentials (for cloud testing - see Credential Setup section below)
|
||||
export K6_CLOUD_TOKEN="your-k6-cloud-token"
|
||||
export K6_CLOUD_TOKEN="your-k6-cloud-token"
|
||||
export K6_CLOUD_PROJECT_ID="4254406"
|
||||
|
||||
# 4. Run orchestrated load tests locally
|
||||
@@ -85,11 +85,11 @@ npm run cloud
|
||||
### Pre-Authenticated Tokens
|
||||
|
||||
- **Generation**: Run `node generate-tokens.js --count=160` to create tokens
|
||||
- **File**: `configs/pre-authenticated-tokens.js` (gitignored for security)
|
||||
- **File**: `configs/pre-authenticated-tokens.js` (gitignored for security)
|
||||
- **Capacity**: 160+ tokens supporting high-concurrency testing
|
||||
- **Expiry**: Based on JWT token expiry settings (default: 15 min access, 7 day refresh)
|
||||
- **Benefit**: Eliminates auth rate limiting at scale
|
||||
- **Regeneration**: Run `node generate-tokens.js --count=160` when tokens expire
|
||||
- **Expiry**: 24 hours (86400 seconds) - extended for long-duration testing
|
||||
- **Benefit**: Eliminates Supabase auth rate limiting at scale
|
||||
- **Regeneration**: Run `node generate-tokens.js --count=160` when tokens expire after 24 hours
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
@@ -182,29 +182,29 @@ npm run cloud
|
||||
|
||||
### Required Setup
|
||||
|
||||
**1. API Base URL (Optional):**
|
||||
**1. Supabase Service Key (Required for all testing):**
|
||||
|
||||
```bash
|
||||
# For local testing (default)
|
||||
export API_BASE_URL="http://localhost:8006"
|
||||
# Option 1: From your local environment (if available)
|
||||
export SUPABASE_SERVICE_KEY="your-supabase-service-key"
|
||||
|
||||
# For dev environment
|
||||
export API_BASE_URL="https://dev-server.agpt.co"
|
||||
# Option 2: From Kubernetes secret (for platform developers)
|
||||
kubectl get secret supabase-service-key -o jsonpath='{.data.service-key}' | base64 -d
|
||||
|
||||
# For production (coordinate with team!)
|
||||
export API_BASE_URL="https://api.agpt.co"
|
||||
# Option 3: From Supabase dashboard
|
||||
# Go to Project Settings > API > service_role key (never commit this!)
|
||||
```
|
||||
|
||||
**2. Generate Pre-Authenticated Tokens (Required):**
|
||||
|
||||
```bash
|
||||
# Creates 160 tokens - prevents auth rate limiting
|
||||
# Creates 160 tokens with 24-hour expiry - prevents auth rate limiting
|
||||
node generate-tokens.js --count=160
|
||||
|
||||
# Generate fewer tokens for smaller tests (minimum 10)
|
||||
node generate-tokens.js --count=50
|
||||
|
||||
# Regenerate when tokens expire
|
||||
# Regenerate when tokens expire (every 24 hours)
|
||||
node generate-tokens.js --count=160
|
||||
```
|
||||
|
||||
|
||||
@@ -4,16 +4,25 @@ export const ENV_CONFIG = {
|
||||
API_BASE_URL: "https://dev-server.agpt.co",
|
||||
BUILDER_BASE_URL: "https://dev-builder.agpt.co",
|
||||
WS_BASE_URL: "wss://dev-ws-server.agpt.co",
|
||||
SUPABASE_URL: "https://adfjtextkuilwuhzdjpf.supabase.co",
|
||||
SUPABASE_ANON_KEY:
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImFkZmp0ZXh0a3VpbHd1aHpkanBmIiwicm9sZSI6ImFub24iLCJpYXQiOjE3MzAyNTE3MDIsImV4cCI6MjA0NTgyNzcwMn0.IuQNXsHEKJNxtS9nyFeqO0BGMYN8sPiObQhuJLSK9xk",
|
||||
},
|
||||
LOCAL: {
|
||||
API_BASE_URL: "http://localhost:8006",
|
||||
BUILDER_BASE_URL: "http://localhost:3000",
|
||||
WS_BASE_URL: "ws://localhost:8001",
|
||||
SUPABASE_URL: "http://localhost:8000",
|
||||
SUPABASE_ANON_KEY:
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE",
|
||||
},
|
||||
PROD: {
|
||||
API_BASE_URL: "https://api.agpt.co",
|
||||
BUILDER_BASE_URL: "https://builder.agpt.co",
|
||||
WS_BASE_URL: "wss://ws-server.agpt.co",
|
||||
SUPABASE_URL: "https://supabase.agpt.co",
|
||||
SUPABASE_ANON_KEY:
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImJnd3B3ZHN4YmxyeWloaW51dGJ4Iiwicm9sZSI6ImFub24iLCJpYXQiOjE3MzAyODYzMDUsImV4cCI6MjA0NTg2MjMwNX0.ISa2IofTdQIJmmX5JwKGGNajqjsD8bjaGBzK90SubE0",
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -4,19 +4,22 @@
|
||||
* Generate Pre-Authenticated Tokens for Load Testing
|
||||
* Creates configs/pre-authenticated-tokens.js with 350+ tokens
|
||||
*
|
||||
* This uses the native auth API to generate tokens
|
||||
* This replaces the old token generation scripts with a clean, single script
|
||||
*/
|
||||
|
||||
import https from "https";
|
||||
import http from "http";
|
||||
import fs from "fs";
|
||||
import path from "path";
|
||||
|
||||
// Get API base URL from environment (default to local)
|
||||
const API_BASE_URL = process.env.API_BASE_URL || "http://localhost:8006";
|
||||
const parsedUrl = new URL(API_BASE_URL);
|
||||
const isHttps = parsedUrl.protocol === "https:";
|
||||
const httpModule = isHttps ? https : http;
|
||||
// Get Supabase service key from environment (REQUIRED for token generation)
|
||||
const SUPABASE_SERVICE_KEY = process.env.SUPABASE_SERVICE_KEY;
|
||||
|
||||
if (!SUPABASE_SERVICE_KEY) {
|
||||
console.error("❌ SUPABASE_SERVICE_KEY environment variable is required");
|
||||
console.error("Get service key from kubectl or environment:");
|
||||
console.error('export SUPABASE_SERVICE_KEY="your-service-key"');
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
// Generate test users (loadtest4-50 are known to work)
|
||||
const TEST_USERS = [];
|
||||
@@ -28,7 +31,7 @@ for (let i = 4; i <= 50; i++) {
|
||||
}
|
||||
|
||||
console.log(
|
||||
`Generating pre-authenticated tokens from ${TEST_USERS.length} users...`,
|
||||
`🔐 Generating pre-authenticated tokens from ${TEST_USERS.length} users...`,
|
||||
);
|
||||
|
||||
async function authenticateUser(user, attempt = 1) {
|
||||
@@ -36,20 +39,22 @@ async function authenticateUser(user, attempt = 1) {
|
||||
const postData = JSON.stringify({
|
||||
email: user.email,
|
||||
password: user.password,
|
||||
expires_in: 86400, // 24 hours in seconds (24 * 60 * 60)
|
||||
});
|
||||
|
||||
const options = {
|
||||
hostname: parsedUrl.hostname,
|
||||
port: parsedUrl.port || (isHttps ? 443 : 80),
|
||||
path: "/api/auth/login",
|
||||
hostname: "adfjtextkuilwuhzdjpf.supabase.co",
|
||||
path: "/auth/v1/token?grant_type=password",
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${SUPABASE_SERVICE_KEY}`,
|
||||
apikey: SUPABASE_SERVICE_KEY,
|
||||
"Content-Type": "application/json",
|
||||
"Content-Length": postData.length,
|
||||
},
|
||||
};
|
||||
|
||||
const req = httpModule.request(options, (res) => {
|
||||
const req = https.request(options, (res) => {
|
||||
let data = "";
|
||||
res.on("data", (chunk) => (data += chunk));
|
||||
res.on("end", () => {
|
||||
@@ -60,29 +65,29 @@ async function authenticateUser(user, attempt = 1) {
|
||||
} else if (res.statusCode === 429) {
|
||||
// Rate limited - wait and retry
|
||||
console.log(
|
||||
`Rate limited for ${user.email}, waiting 5s (attempt ${attempt}/3)...`,
|
||||
`⏳ Rate limited for ${user.email}, waiting 5s (attempt ${attempt}/3)...`,
|
||||
);
|
||||
setTimeout(() => {
|
||||
if (attempt < 3) {
|
||||
authenticateUser(user, attempt + 1).then(resolve);
|
||||
} else {
|
||||
console.log(`Max retries exceeded for ${user.email}`);
|
||||
console.log(`❌ Max retries exceeded for ${user.email}`);
|
||||
resolve(null);
|
||||
}
|
||||
}, 5000);
|
||||
} else {
|
||||
console.log(`Auth failed for ${user.email}: ${res.statusCode}`);
|
||||
console.log(`❌ Auth failed for ${user.email}: ${res.statusCode}`);
|
||||
resolve(null);
|
||||
}
|
||||
} catch (e) {
|
||||
console.log(`Parse error for ${user.email}:`, e.message);
|
||||
console.log(`❌ Parse error for ${user.email}:`, e.message);
|
||||
resolve(null);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
req.on("error", (err) => {
|
||||
console.log(`Request error for ${user.email}:`, err.message);
|
||||
console.log(`❌ Request error for ${user.email}:`, err.message);
|
||||
resolve(null);
|
||||
});
|
||||
|
||||
@@ -92,8 +97,7 @@ async function authenticateUser(user, attempt = 1) {
|
||||
}
|
||||
|
||||
async function generateTokens() {
|
||||
console.log("Starting token generation...");
|
||||
console.log(`Using API: ${API_BASE_URL}`);
|
||||
console.log("🚀 Starting token generation...");
|
||||
console.log("Rate limit aware - this will take ~10-15 minutes");
|
||||
console.log("===========================================\n");
|
||||
|
||||
@@ -109,11 +113,11 @@ async function generateTokens() {
|
||||
150;
|
||||
const tokensPerUser = Math.ceil(targetTokens / TEST_USERS.length);
|
||||
console.log(
|
||||
`Generating ${tokensPerUser} tokens per user (${TEST_USERS.length} users) - Target: ${targetTokens}\n`,
|
||||
`📊 Generating ${tokensPerUser} tokens per user (${TEST_USERS.length} users) - Target: ${targetTokens}\n`,
|
||||
);
|
||||
|
||||
for (let round = 1; round <= tokensPerUser; round++) {
|
||||
console.log(`Round ${round}/${tokensPerUser}:`);
|
||||
console.log(`🔄 Round ${round}/${tokensPerUser}:`);
|
||||
|
||||
for (
|
||||
let i = 0;
|
||||
@@ -133,9 +137,9 @@ async function generateTokens() {
|
||||
generated: new Date().toISOString(),
|
||||
round: round,
|
||||
});
|
||||
console.log(`OK (${tokens.length}/${targetTokens})`);
|
||||
console.log(`✅ (${tokens.length}/${targetTokens})`);
|
||||
} else {
|
||||
console.log(`FAILED`);
|
||||
console.log(`❌`);
|
||||
}
|
||||
|
||||
// Respect rate limits - wait 500ms between requests
|
||||
@@ -148,13 +152,13 @@ async function generateTokens() {
|
||||
|
||||
// Wait longer between rounds
|
||||
if (round < tokensPerUser) {
|
||||
console.log(` Waiting 3s before next round...\n`);
|
||||
console.log(` ⏸️ Waiting 3s before next round...\n`);
|
||||
await new Promise((resolve) => setTimeout(resolve, 3000));
|
||||
}
|
||||
}
|
||||
|
||||
const duration = Math.round((Date.now() - startTime) / 1000);
|
||||
console.log(`\nGenerated ${tokens.length} tokens in ${duration}s`);
|
||||
console.log(`\n✅ Generated ${tokens.length} tokens in ${duration}s`);
|
||||
|
||||
// Create configs directory if it doesn't exist
|
||||
const configsDir = path.join(process.cwd(), "configs");
|
||||
@@ -167,9 +171,9 @@ async function generateTokens() {
|
||||
// Generated: ${new Date().toISOString()}
|
||||
// Total tokens: ${tokens.length}
|
||||
// Generation time: ${duration} seconds
|
||||
//
|
||||
// SECURITY: This file contains real authentication tokens
|
||||
// DO NOT COMMIT TO GIT - File is gitignored
|
||||
//
|
||||
// ⚠️ SECURITY: This file contains real authentication tokens
|
||||
// ⚠️ DO NOT COMMIT TO GIT - File is gitignored
|
||||
|
||||
export const PRE_AUTHENTICATED_TOKENS = ${JSON.stringify(tokens, null, 2)};
|
||||
|
||||
@@ -177,10 +181,10 @@ export function getPreAuthenticatedToken(vuId = 1) {
|
||||
if (PRE_AUTHENTICATED_TOKENS.length === 0) {
|
||||
throw new Error('No pre-authenticated tokens available');
|
||||
}
|
||||
|
||||
|
||||
const tokenIndex = (vuId - 1) % PRE_AUTHENTICATED_TOKENS.length;
|
||||
const tokenData = PRE_AUTHENTICATED_TOKENS[tokenIndex];
|
||||
|
||||
|
||||
return {
|
||||
access_token: tokenData.token,
|
||||
user: { email: tokenData.user },
|
||||
@@ -193,7 +197,7 @@ const LOAD_TEST_SESSION_ID = '${new Date().toISOString().slice(0, 16).replace(/:
|
||||
|
||||
export function getPreAuthenticatedHeaders(vuId = 1) {
|
||||
const authData = getPreAuthenticatedToken(vuId);
|
||||
|
||||
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': \`Bearer \${authData.access_token}\`,
|
||||
@@ -209,16 +213,16 @@ export const TOKEN_STATS = {
|
||||
generated: PRE_AUTHENTICATED_TOKENS[0]?.generated || 'unknown'
|
||||
};
|
||||
|
||||
console.log(\`Loaded \${TOKEN_STATS.total} pre-authenticated tokens from \${TOKEN_STATS.users} users\`);
|
||||
console.log(\`🔐 Loaded \${TOKEN_STATS.total} pre-authenticated tokens from \${TOKEN_STATS.users} users\`);
|
||||
`;
|
||||
|
||||
const tokenFile = path.join(configsDir, "pre-authenticated-tokens.js");
|
||||
fs.writeFileSync(tokenFile, jsContent);
|
||||
|
||||
console.log(`Saved to configs/pre-authenticated-tokens.js`);
|
||||
console.log(`Ready for ${tokens.length} concurrent VU load testing!`);
|
||||
console.log(`💾 Saved to configs/pre-authenticated-tokens.js`);
|
||||
console.log(`🚀 Ready for ${tokens.length} concurrent VU load testing!`);
|
||||
console.log(
|
||||
`\nSecurity Note: Token file is gitignored and will not be committed`,
|
||||
`\n🔒 Security Note: Token file is gitignored and will not be committed`,
|
||||
);
|
||||
|
||||
return tokens.length;
|
||||
|
||||
@@ -45,7 +45,7 @@ export default function () {
|
||||
// Handle authentication failure gracefully
|
||||
if (!headers || !headers.Authorization) {
|
||||
console.log(
|
||||
`VU ${__VU} has no valid pre-authentication token - skipping iteration`,
|
||||
`⚠️ VU ${__VU} has no valid pre-authentication token - skipping iteration`,
|
||||
);
|
||||
check(null, {
|
||||
"Authentication: Failed gracefully without crashing VU": () => true,
|
||||
@@ -53,57 +53,56 @@ export default function () {
|
||||
return; // Exit iteration gracefully without crashing
|
||||
}
|
||||
|
||||
console.log(`VU ${__VU} making ${requestsPerVU} concurrent requests...`);
|
||||
console.log(`🚀 VU ${__VU} making ${requestsPerVU} concurrent requests...`);
|
||||
|
||||
// Create array of request functions to run concurrently
|
||||
const requests = [];
|
||||
|
||||
for (let i = 0; i < requestsPerVU; i++) {
|
||||
// Health check endpoint
|
||||
requests.push({
|
||||
method: "GET",
|
||||
url: `${config.SUPABASE_URL}/rest/v1/`,
|
||||
params: { headers: { apikey: config.SUPABASE_ANON_KEY } },
|
||||
});
|
||||
|
||||
requests.push({
|
||||
method: "GET",
|
||||
url: `${config.API_BASE_URL}/health`,
|
||||
params: { headers },
|
||||
});
|
||||
|
||||
// API endpoint check
|
||||
requests.push({
|
||||
method: "GET",
|
||||
url: `${config.API_BASE_URL}/api`,
|
||||
params: { headers },
|
||||
});
|
||||
}
|
||||
|
||||
// Execute all requests concurrently
|
||||
const responses = http.batch(requests);
|
||||
|
||||
// Validate results
|
||||
let healthSuccesses = 0;
|
||||
let apiSuccesses = 0;
|
||||
let supabaseSuccesses = 0;
|
||||
let backendSuccesses = 0;
|
||||
|
||||
for (let i = 0; i < responses.length; i++) {
|
||||
const response = responses[i];
|
||||
|
||||
if (i % 2 === 0) {
|
||||
// Health check request
|
||||
const healthCheck = check(response, {
|
||||
"Health endpoint: Status is not 500": (r) => r.status !== 500,
|
||||
"Health endpoint: Response time < 5s": (r) =>
|
||||
// Supabase request
|
||||
const connectivityCheck = check(response, {
|
||||
"Supabase connectivity: Status is not 500": (r) => r.status !== 500,
|
||||
"Supabase connectivity: Response time < 5s": (r) =>
|
||||
r.timings.duration < 5000,
|
||||
});
|
||||
if (healthCheck) healthSuccesses++;
|
||||
if (connectivityCheck) supabaseSuccesses++;
|
||||
} else {
|
||||
// API request
|
||||
const apiCheck = check(response, {
|
||||
"API server: Responds (any status)": (r) => r.status > 0,
|
||||
"API server: Response time < 5s": (r) => r.timings.duration < 5000,
|
||||
// Backend request
|
||||
const backendCheck = check(response, {
|
||||
"Backend server: Responds (any status)": (r) => r.status > 0,
|
||||
"Backend server: Response time < 5s": (r) =>
|
||||
r.timings.duration < 5000,
|
||||
});
|
||||
if (apiCheck) apiSuccesses++;
|
||||
if (backendCheck) backendSuccesses++;
|
||||
}
|
||||
}
|
||||
|
||||
console.log(
|
||||
`VU ${__VU} completed: ${healthSuccesses}/${requestsPerVU} health, ${apiSuccesses}/${requestsPerVU} API requests successful`,
|
||||
`✅ VU ${__VU} completed: ${supabaseSuccesses}/${requestsPerVU} Supabase, ${backendSuccesses}/${requestsPerVU} backend requests successful`,
|
||||
);
|
||||
|
||||
// Basic auth validation (once per iteration)
|
||||
@@ -126,7 +125,7 @@ export default function () {
|
||||
parts[2] && parts[2].length > 10,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error(`Test failed: ${error.message}`);
|
||||
console.error(`💥 Test failed: ${error.message}`);
|
||||
check(null, {
|
||||
"Test execution: No errors": () => false,
|
||||
});
|
||||
@@ -134,5 +133,5 @@ export default function () {
|
||||
}
|
||||
|
||||
export function teardown(data) {
|
||||
console.log(`Basic connectivity test completed`);
|
||||
console.log(`🏁 Basic connectivity test completed`);
|
||||
}
|
||||
|
||||
@@ -1,129 +0,0 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "OAuthApplication" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"clientId" TEXT NOT NULL,
|
||||
"clientSecret" TEXT NOT NULL,
|
||||
"clientSecretSalt" TEXT NOT NULL,
|
||||
"redirectUris" TEXT[],
|
||||
"grantTypes" TEXT[] DEFAULT ARRAY['authorization_code', 'refresh_token']::TEXT[],
|
||||
"scopes" "APIKeyPermission"[],
|
||||
"ownerId" TEXT NOT NULL,
|
||||
"isActive" BOOLEAN NOT NULL DEFAULT true,
|
||||
|
||||
CONSTRAINT "OAuthApplication_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "OAuthAuthorizationCode" (
|
||||
"id" TEXT NOT NULL,
|
||||
"code" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"expiresAt" TIMESTAMP(3) NOT NULL,
|
||||
"applicationId" TEXT NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
"scopes" "APIKeyPermission"[],
|
||||
"redirectUri" TEXT NOT NULL,
|
||||
"codeChallenge" TEXT,
|
||||
"codeChallengeMethod" TEXT,
|
||||
"usedAt" TIMESTAMP(3),
|
||||
|
||||
CONSTRAINT "OAuthAuthorizationCode_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "OAuthAccessToken" (
|
||||
"id" TEXT NOT NULL,
|
||||
"token" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"expiresAt" TIMESTAMP(3) NOT NULL,
|
||||
"applicationId" TEXT NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
"scopes" "APIKeyPermission"[],
|
||||
"revokedAt" TIMESTAMP(3),
|
||||
|
||||
CONSTRAINT "OAuthAccessToken_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "OAuthRefreshToken" (
|
||||
"id" TEXT NOT NULL,
|
||||
"token" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"expiresAt" TIMESTAMP(3) NOT NULL,
|
||||
"applicationId" TEXT NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
"scopes" "APIKeyPermission"[],
|
||||
"revokedAt" TIMESTAMP(3),
|
||||
|
||||
CONSTRAINT "OAuthRefreshToken_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "OAuthApplication_clientId_key" ON "OAuthApplication"("clientId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthApplication_clientId_idx" ON "OAuthApplication"("clientId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthApplication_ownerId_idx" ON "OAuthApplication"("ownerId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "OAuthAuthorizationCode_code_key" ON "OAuthAuthorizationCode"("code");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthAuthorizationCode_code_idx" ON "OAuthAuthorizationCode"("code");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthAuthorizationCode_applicationId_userId_idx" ON "OAuthAuthorizationCode"("applicationId", "userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthAuthorizationCode_expiresAt_idx" ON "OAuthAuthorizationCode"("expiresAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "OAuthAccessToken_token_key" ON "OAuthAccessToken"("token");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthAccessToken_token_idx" ON "OAuthAccessToken"("token");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthAccessToken_userId_applicationId_idx" ON "OAuthAccessToken"("userId", "applicationId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthAccessToken_expiresAt_idx" ON "OAuthAccessToken"("expiresAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "OAuthRefreshToken_token_key" ON "OAuthRefreshToken"("token");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthRefreshToken_token_idx" ON "OAuthRefreshToken"("token");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthRefreshToken_userId_applicationId_idx" ON "OAuthRefreshToken"("userId", "applicationId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthRefreshToken_expiresAt_idx" ON "OAuthRefreshToken"("expiresAt");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OAuthApplication" ADD CONSTRAINT "OAuthApplication_ownerId_fkey" FOREIGN KEY ("ownerId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OAuthAuthorizationCode" ADD CONSTRAINT "OAuthAuthorizationCode_applicationId_fkey" FOREIGN KEY ("applicationId") REFERENCES "OAuthApplication"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OAuthAuthorizationCode" ADD CONSTRAINT "OAuthAuthorizationCode_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OAuthAccessToken" ADD CONSTRAINT "OAuthAccessToken_applicationId_fkey" FOREIGN KEY ("applicationId") REFERENCES "OAuthApplication"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OAuthAccessToken" ADD CONSTRAINT "OAuthAccessToken_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OAuthRefreshToken" ADD CONSTRAINT "OAuthRefreshToken_applicationId_fkey" FOREIGN KEY ("applicationId") REFERENCES "OAuthApplication"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OAuthRefreshToken" ADD CONSTRAINT "OAuthRefreshToken_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
@@ -1,5 +0,0 @@
|
||||
-- AlterEnum
|
||||
ALTER TYPE "APIKeyPermission" ADD VALUE 'IDENTITY';
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "OAuthApplication" ADD COLUMN "logoUrl" TEXT;
|
||||
@@ -1,65 +0,0 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- A unique constraint covering the columns `[googleId]` on the table `User` will be added. If there are existing duplicate values, this will fail.
|
||||
|
||||
*/
|
||||
-- AlterTable
|
||||
ALTER TABLE "User" ADD COLUMN "googleId" TEXT,
|
||||
ADD COLUMN "passwordHash" TEXT,
|
||||
ADD COLUMN "role" TEXT NOT NULL DEFAULT 'authenticated',
|
||||
ALTER COLUMN "emailVerified" SET DEFAULT false;
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "RefreshToken" (
|
||||
"id" TEXT NOT NULL,
|
||||
"token" TEXT NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
"expiresAt" TIMESTAMP(3) NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"revokedAt" TIMESTAMP(3),
|
||||
|
||||
CONSTRAINT "RefreshToken_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "PasswordResetToken" (
|
||||
"id" TEXT NOT NULL,
|
||||
"token" TEXT NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
"expiresAt" TIMESTAMP(3) NOT NULL,
|
||||
"usedAt" TIMESTAMP(3),
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
CONSTRAINT "PasswordResetToken_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "RefreshToken_token_key" ON "RefreshToken"("token");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "RefreshToken_userId_idx" ON "RefreshToken"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "RefreshToken_expiresAt_idx" ON "RefreshToken"("expiresAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "RefreshToken_token_idx" ON "RefreshToken"("token");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "PasswordResetToken_token_key" ON "PasswordResetToken"("token");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "PasswordResetToken_userId_idx" ON "PasswordResetToken"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "PasswordResetToken_token_idx" ON "PasswordResetToken"("token");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "User_googleId_key" ON "User"("googleId");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "RefreshToken" ADD CONSTRAINT "RefreshToken_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "PasswordResetToken" ADD CONSTRAINT "PasswordResetToken_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
@@ -1,23 +0,0 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "EmailVerificationToken" (
|
||||
"id" TEXT NOT NULL,
|
||||
"token" TEXT NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
"expiresAt" TIMESTAMP(3) NOT NULL,
|
||||
"usedAt" TIMESTAMP(3),
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
CONSTRAINT "EmailVerificationToken_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "EmailVerificationToken_token_key" ON "EmailVerificationToken"("token");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "EmailVerificationToken_userId_idx" ON "EmailVerificationToken"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "EmailVerificationToken_token_idx" ON "EmailVerificationToken"("token");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "EmailVerificationToken" ADD CONSTRAINT "EmailVerificationToken_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
205
autogpt_platform/backend/poetry.lock
generated
205
autogpt_platform/backend/poetry.lock
generated
@@ -391,21 +391,6 @@ files = [
|
||||
{file = "audioop_lts-0.2.2.tar.gz", hash = "sha256:64d0c62d88e67b98a1a5e71987b7aa7b5bcffc7dcee65b635823dbdd0a8dbbd0"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "authlib"
|
||||
version = "1.6.6"
|
||||
description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "authlib-1.6.6-py2.py3-none-any.whl", hash = "sha256:7d9e9bc535c13974313a87f53e8430eb6ea3d1cf6ae4f6efcd793f2e949143fd"},
|
||||
{file = "authlib-1.6.6.tar.gz", hash = "sha256:45770e8e056d0f283451d9996fbb59b70d45722b45d854d58f32878d0a40c38e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cryptography = "*"
|
||||
|
||||
[[package]]
|
||||
name = "autogpt-libs"
|
||||
version = "0.2.0"
|
||||
@@ -417,8 +402,6 @@ files = []
|
||||
develop = true
|
||||
|
||||
[package.dependencies]
|
||||
authlib = "^1.3.0"
|
||||
bcrypt = "^4.1.0"
|
||||
colorama = "^0.4.6"
|
||||
cryptography = "^45.0"
|
||||
expiringdict = "^1.2.2"
|
||||
@@ -429,6 +412,7 @@ pydantic = "^2.11.7"
|
||||
pydantic-settings = "^2.10.1"
|
||||
pyjwt = {version = "^2.10.1", extras = ["crypto"]}
|
||||
redis = "^6.2.0"
|
||||
supabase = "^2.16.0"
|
||||
uvicorn = "^0.35.0"
|
||||
|
||||
[package.source]
|
||||
@@ -477,71 +461,6 @@ files = [
|
||||
docs = ["furo", "jaraco.packaging (>=9.3)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
|
||||
testing = ["jaraco.test", "pytest (!=8.0.*)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "bcrypt"
|
||||
version = "4.3.0"
|
||||
description = "Modern password hashing for your software and your servers"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-macosx_10_12_universal2.whl", hash = "sha256:f01e060f14b6b57bbb72fc5b4a83ac21c443c9a2ee708e04a10e9192f90a6281"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5eeac541cefd0bb887a371ef73c62c3cd78535e4887b310626036a7c0a817bb"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59e1aa0e2cd871b08ca146ed08445038f42ff75968c7ae50d2fdd7860ade2180"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:0042b2e342e9ae3d2ed22727c1262f76cc4f345683b5c1715f0250cf4277294f"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74a8d21a09f5e025a9a23e7c0fd2c7fe8e7503e4d356c0a2c1486ba010619f09"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:0142b2cb84a009f8452c8c5a33ace5e3dfec4159e7735f5afe9a4d50a8ea722d"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_34_aarch64.whl", hash = "sha256:12fa6ce40cde3f0b899729dbd7d5e8811cb892d31b6f7d0334a1f37748b789fd"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_34_x86_64.whl", hash = "sha256:5bd3cca1f2aa5dbcf39e2aa13dd094ea181f48959e1071265de49cc2b82525af"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:335a420cfd63fc5bc27308e929bee231c15c85cc4c496610ffb17923abf7f231"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:0e30e5e67aed0187a1764911af023043b4542e70a7461ad20e837e94d23e1d6c"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:3b8d62290ebefd49ee0b3ce7500f5dbdcf13b81402c05f6dafab9a1e1b27212f"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2ef6630e0ec01376f59a006dc72918b1bf436c3b571b80fa1968d775fa02fe7d"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-win32.whl", hash = "sha256:7a4be4cbf241afee43f1c3969b9103a41b40bcb3a3f467ab19f891d9bc4642e4"},
|
||||
{file = "bcrypt-4.3.0-cp313-cp313t-win_amd64.whl", hash = "sha256:5c1949bf259a388863ced887c7861da1df681cb2388645766c89fdfd9004c669"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-macosx_10_12_universal2.whl", hash = "sha256:f81b0ed2639568bf14749112298f9e4e2b28853dab50a8b357e31798686a036d"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:864f8f19adbe13b7de11ba15d85d4a428c7e2f344bac110f667676a0ff84924b"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e36506d001e93bffe59754397572f21bb5dc7c83f54454c990c74a468cd589e"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:842d08d75d9fe9fb94b18b071090220697f9f184d4547179b60734846461ed59"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7c03296b85cb87db865d91da79bf63d5609284fc0cab9472fdd8367bbd830753"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:62f26585e8b219cdc909b6a0069efc5e4267e25d4a3770a364ac58024f62a761"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:beeefe437218a65322fbd0069eb437e7c98137e08f22c4660ac2dc795c31f8bb"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:97eea7408db3a5bcce4a55d13245ab3fa566e23b4c67cd227062bb49e26c585d"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:191354ebfe305e84f344c5964c7cd5f924a3bfc5d405c75ad07f232b6dffb49f"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:41261d64150858eeb5ff43c753c4b216991e0ae16614a308a15d909503617732"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:33752b1ba962ee793fa2b6321404bf20011fe45b9afd2a842139de3011898fef"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:50e6e80a4bfd23a25f5c05b90167c19030cf9f87930f7cb2eacb99f45d1c3304"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-win32.whl", hash = "sha256:67a561c4d9fb9465ec866177e7aebcad08fe23aaf6fbd692a6fab69088abfc51"},
|
||||
{file = "bcrypt-4.3.0-cp38-abi3-win_amd64.whl", hash = "sha256:584027857bc2843772114717a7490a37f68da563b3620f78a849bcb54dc11e62"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:0d3efb1157edebfd9128e4e46e2ac1a64e0c1fe46fb023158a407c7892b0f8c3"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08bacc884fd302b611226c01014eca277d48f0a05187666bca23aac0dad6fe24"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6746e6fec103fcd509b96bacdfdaa2fbde9a553245dbada284435173a6f1aef"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:afe327968aaf13fc143a56a3360cb27d4ad0345e34da12c7290f1b00b8fe9a8b"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:d9af79d322e735b1fc33404b5765108ae0ff232d4b54666d46730f8ac1a43676"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f1e3ffa1365e8702dc48c8b360fef8d7afeca482809c5e45e653af82ccd088c1"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:3004df1b323d10021fda07a813fd33e0fd57bef0e9a480bb143877f6cba996fe"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:531457e5c839d8caea9b589a1bcfe3756b0547d7814e9ce3d437f17da75c32b0"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:17a854d9a7a476a89dcef6c8bd119ad23e0f82557afbd2c442777a16408e614f"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:6fb1fd3ab08c0cbc6826a2e0447610c6f09e983a281b919ed721ad32236b8b23"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e965a9c1e9a393b8005031ff52583cedc15b7884fce7deb8b0346388837d6cfe"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:79e70b8342a33b52b55d93b3a59223a844962bef479f6a0ea318ebbcadf71505"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-win32.whl", hash = "sha256:b4d4e57f0a63fd0b358eb765063ff661328f69a04494427265950c71b992a39a"},
|
||||
{file = "bcrypt-4.3.0-cp39-abi3-win_amd64.whl", hash = "sha256:e53e074b120f2877a35cc6c736b8eb161377caae8925c17688bd46ba56daaa5b"},
|
||||
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c950d682f0952bafcceaf709761da0a32a942272fad381081b51096ffa46cea1"},
|
||||
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:107d53b5c67e0bbc3f03ebf5b030e0403d24dda980f8e244795335ba7b4a027d"},
|
||||
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:b693dbb82b3c27a1604a3dff5bfc5418a7e6a781bb795288141e5f80cf3a3492"},
|
||||
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:b6354d3760fcd31994a14c89659dee887f1351a06e5dac3c1142307172a79f90"},
|
||||
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:a839320bf27d474e52ef8cb16449bb2ce0ba03ca9f44daba6d93fa1d8828e48a"},
|
||||
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:bdc6a24e754a555d7316fa4774e64c6c3997d27ed2d1964d55920c7c227bc4ce"},
|
||||
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:55a935b8e9a1d2def0626c4269db3fcd26728cbff1e84f0341465c31c4ee56d8"},
|
||||
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:57967b7a28d855313a963aaea51bf6df89f833db4320da458e5b3c5ab6d4c938"},
|
||||
{file = "bcrypt-4.3.0.tar.gz", hash = "sha256:3a3fd2204178b6d2adcf09cb4f6426ffef54762577a7c9b54c159008cb288c18"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
tests = ["pytest (>=3.2.1,!=3.3.0)"]
|
||||
typecheck = ["mypy"]
|
||||
|
||||
[[package]]
|
||||
name = "black"
|
||||
version = "24.10.0"
|
||||
@@ -1062,6 +981,21 @@ files = [
|
||||
{file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deprecation"
|
||||
version = "2.1.0"
|
||||
description = "A library to handle automated deprecations"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"},
|
||||
{file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
packaging = "*"
|
||||
|
||||
[[package]]
|
||||
name = "discord-py"
|
||||
version = "2.5.2"
|
||||
@@ -1955,6 +1889,23 @@ files = [
|
||||
[package.dependencies]
|
||||
requests = ">=2.20.0,<3.0"
|
||||
|
||||
[[package]]
|
||||
name = "gotrue"
|
||||
version = "2.12.3"
|
||||
description = "Python Client Library for Supabase Auth"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "gotrue-2.12.3-py3-none-any.whl", hash = "sha256:b1a3c6a5fe3f92e854a026c4c19de58706a96fd5fbdcc3d620b2802f6a46a26b"},
|
||||
{file = "gotrue-2.12.3.tar.gz", hash = "sha256:f874cf9d0b2f0335bfbd0d6e29e3f7aff79998cd1c14d2ad814db8c06cee3852"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
|
||||
pydantic = ">=1.10,<3"
|
||||
pyjwt = ">=2.10.1,<3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "gravitasml"
|
||||
version = "0.1.3"
|
||||
@@ -4109,6 +4060,24 @@ docs = ["sphinx (>=1.7.1)"]
|
||||
redis = ["redis"]
|
||||
tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "pytest-timeout (>=2.1.0)", "redis", "sphinx (>=6.0.0)", "types-redis"]
|
||||
|
||||
[[package]]
|
||||
name = "postgrest"
|
||||
version = "1.1.1"
|
||||
description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "postgrest-1.1.1-py3-none-any.whl", hash = "sha256:98a6035ee1d14288484bfe36235942c5fb2d26af6d8120dfe3efbe007859251a"},
|
||||
{file = "postgrest-1.1.1.tar.gz", hash = "sha256:f3bb3e8c4602775c75c844a31f565f5f3dd584df4d36d683f0b67d01a86be322"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
deprecation = ">=2.1.0,<3.0.0"
|
||||
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
|
||||
pydantic = ">=1.9,<3.0"
|
||||
strenum = {version = ">=0.4.9,<0.5.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "posthog"
|
||||
version = "6.1.1"
|
||||
@@ -5353,6 +5322,23 @@ files = [
|
||||
[package.extras]
|
||||
all = ["numpy"]
|
||||
|
||||
[[package]]
|
||||
name = "realtime"
|
||||
version = "2.6.0"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "realtime-2.6.0-py3-none-any.whl", hash = "sha256:a0512d71044c2621455bc87d1c171739967edc161381994de54e0989ca6c348e"},
|
||||
{file = "realtime-2.6.0.tar.gz", hash = "sha256:f68743cff85d3113659fa19835a868674e720465649bf833e1cd47d7da0f7bbd"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pydantic = ">=2.11.7,<3.0.0"
|
||||
typing-extensions = ">=4.14.0"
|
||||
websockets = ">=11,<16"
|
||||
|
||||
[[package]]
|
||||
name = "redis"
|
||||
version = "6.2.0"
|
||||
@@ -6114,6 +6100,23 @@ typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\""
|
||||
[package.extras]
|
||||
full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"]
|
||||
|
||||
[[package]]
|
||||
name = "storage3"
|
||||
version = "0.12.0"
|
||||
description = "Supabase Storage client for Python."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "storage3-0.12.0-py3-none-any.whl", hash = "sha256:1c4585693ca42243ded1512b58e54c697111e91a20916cd14783eebc37e7c87d"},
|
||||
{file = "storage3-0.12.0.tar.gz", hash = "sha256:94243f20922d57738bf42e96b9f5582b4d166e8bf209eccf20b146909f3f71b0"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
deprecation = ">=2.1.0,<3.0.0"
|
||||
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
|
||||
python-dateutil = ">=2.8.2,<3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "strenum"
|
||||
version = "0.4.15"
|
||||
@@ -6147,6 +6150,42 @@ files = [
|
||||
requests = {version = ">=2.20", markers = "python_version >= \"3.0\""}
|
||||
typing-extensions = {version = ">=4.5.0", markers = "python_version >= \"3.7\""}
|
||||
|
||||
[[package]]
|
||||
name = "supabase"
|
||||
version = "2.17.0"
|
||||
description = "Supabase client for Python."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "supabase-2.17.0-py3-none-any.whl", hash = "sha256:2dd804fae8850cebccc9ab8711c2ee9e2f009e847f4c95c092a4423778e3c3f6"},
|
||||
{file = "supabase-2.17.0.tar.gz", hash = "sha256:3207314b540db7e3339fa2500bd977541517afb4d20b7ff93a89b97a05f9df38"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
gotrue = "2.12.3"
|
||||
httpx = ">=0.26,<0.29"
|
||||
postgrest = "1.1.1"
|
||||
realtime = "2.6.0"
|
||||
storage3 = "0.12.0"
|
||||
supafunc = "0.10.1"
|
||||
|
||||
[[package]]
|
||||
name = "supafunc"
|
||||
version = "0.10.1"
|
||||
description = "Library for Supabase Functions"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "supafunc-0.10.1-py3-none-any.whl", hash = "sha256:26df9bd25ff2ef56cb5bfb8962de98f43331f7f8ff69572bac3ed9c3a9672040"},
|
||||
{file = "supafunc-0.10.1.tar.gz", hash = "sha256:a5b33c8baecb6b5297d25da29a2503e2ec67ee6986f3d44c137e651b8a59a17d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
|
||||
strenum = ">=0.4.15,<0.5.0"
|
||||
|
||||
[[package]]
|
||||
name = "tenacity"
|
||||
version = "9.1.2"
|
||||
@@ -7240,4 +7279,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "d0beae09baf94b9a5e7ec787f7da14c9268da37b1dcde7f582b948f2ff121843"
|
||||
content-hash = "13b191b2a1989d3321ff713c66ff6f5f4f3b82d15df4d407e0e5dbf87d7522c4"
|
||||
|
||||
@@ -62,6 +62,7 @@ sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlal
|
||||
sqlalchemy = "^2.0.40"
|
||||
strenum = "^0.4.9"
|
||||
stripe = "^11.5.0"
|
||||
supabase = "2.17.0"
|
||||
tenacity = "^9.1.2"
|
||||
todoist-api-python = "^2.1.7"
|
||||
tweepy = "^4.16.0"
|
||||
@@ -81,7 +82,6 @@ firecrawl-py = "^4.3.6"
|
||||
exa-py = "^1.14.20"
|
||||
croniter = "^6.0.0"
|
||||
stagehand = "^0.5.1"
|
||||
bcrypt = ">=4.1.0,<5.0.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
aiohappyeyeballs = "^2.6.1"
|
||||
@@ -115,8 +115,6 @@ format = "linter:format"
|
||||
lint = "linter:lint"
|
||||
test = "run_tests:test"
|
||||
load-store-agents = "test.load_store_agents:run"
|
||||
export-api-schema = "backend.cli.generate_openapi_json:main"
|
||||
oauth-tool = "backend.cli.oauth_tool:cli"
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
|
||||
@@ -12,11 +12,11 @@ generator client {
|
||||
partial_type_generator = "backend/data/partial_types.py"
|
||||
}
|
||||
|
||||
// User model for authentication and platform data
|
||||
// User model to mirror Auth provider users
|
||||
model User {
|
||||
id String @id @default(uuid())
|
||||
id String @id // This should match the Supabase user ID
|
||||
email String @unique
|
||||
emailVerified Boolean @default(false)
|
||||
emailVerified Boolean @default(true)
|
||||
name String?
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
@@ -25,11 +25,6 @@ model User {
|
||||
stripeCustomerId String?
|
||||
topUpConfig Json?
|
||||
|
||||
// Authentication fields
|
||||
passwordHash String? // bcrypt hash (nullable for OAuth-only users)
|
||||
googleId String? @unique // Google OAuth user ID
|
||||
role String @default("authenticated") // user role
|
||||
|
||||
maxEmailsPerDay Int @default(3)
|
||||
notifyOnAgentRun Boolean @default(true)
|
||||
notifyOnZeroBalance Boolean @default(true)
|
||||
@@ -44,11 +39,6 @@ model User {
|
||||
|
||||
timezone String @default("not-set")
|
||||
|
||||
// Auth token relations
|
||||
RefreshTokens RefreshToken[]
|
||||
PasswordResetTokens PasswordResetToken[]
|
||||
EmailVerificationTokens EmailVerificationToken[]
|
||||
|
||||
// Relations
|
||||
|
||||
AgentGraphs AgentGraph[]
|
||||
@@ -71,55 +61,6 @@ model User {
|
||||
IntegrationWebhooks IntegrationWebhook[]
|
||||
NotificationBatches UserNotificationBatch[]
|
||||
PendingHumanReviews PendingHumanReview[]
|
||||
|
||||
// OAuth Provider relations
|
||||
OAuthApplications OAuthApplication[]
|
||||
OAuthAuthorizationCodes OAuthAuthorizationCode[]
|
||||
OAuthAccessTokens OAuthAccessToken[]
|
||||
OAuthRefreshTokens OAuthRefreshToken[]
|
||||
}
|
||||
|
||||
// Refresh tokens for JWT authentication
|
||||
model RefreshToken {
|
||||
id String @id @default(uuid())
|
||||
token String @unique // SHA-256 hashed refresh token
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
expiresAt DateTime
|
||||
createdAt DateTime @default(now())
|
||||
revokedAt DateTime?
|
||||
|
||||
@@index([userId])
|
||||
@@index([expiresAt])
|
||||
@@index([token])
|
||||
}
|
||||
|
||||
// Password reset tokens
|
||||
model PasswordResetToken {
|
||||
id String @id @default(uuid())
|
||||
token String @unique // SHA-256 hashed token
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
expiresAt DateTime
|
||||
usedAt DateTime?
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
@@index([userId])
|
||||
@@index([token])
|
||||
}
|
||||
|
||||
// Email verification tokens
|
||||
model EmailVerificationToken {
|
||||
id String @id @default(uuid())
|
||||
token String @unique // SHA-256 hashed token
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
expiresAt DateTime
|
||||
usedAt DateTime?
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
@@index([userId])
|
||||
@@index([token])
|
||||
}
|
||||
|
||||
enum OnboardingStep {
|
||||
@@ -983,7 +924,6 @@ enum SubmissionStatus {
|
||||
}
|
||||
|
||||
enum APIKeyPermission {
|
||||
IDENTITY // Info about the authenticated user
|
||||
EXECUTE_GRAPH // Can execute agent graphs
|
||||
READ_GRAPH // Can get graph versions and details
|
||||
EXECUTE_BLOCK // Can execute individual blocks
|
||||
@@ -1035,113 +975,3 @@ enum APIKeyStatus {
|
||||
REVOKED
|
||||
SUSPENDED
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////// OAUTH PROVIDER TABLES //////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// OAuth2 applications that can access AutoGPT on behalf of users
|
||||
model OAuthApplication {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
// Application metadata
|
||||
name String
|
||||
description String?
|
||||
logoUrl String? // URL to app logo stored in GCS
|
||||
clientId String @unique
|
||||
clientSecret String // Hashed with Scrypt (same as API keys)
|
||||
clientSecretSalt String // Salt for Scrypt hashing
|
||||
|
||||
// OAuth configuration
|
||||
redirectUris String[] // Allowed callback URLs
|
||||
grantTypes String[] @default(["authorization_code", "refresh_token"])
|
||||
scopes APIKeyPermission[] // Which permissions the app can request
|
||||
|
||||
// Application management
|
||||
ownerId String
|
||||
Owner User @relation(fields: [ownerId], references: [id], onDelete: Cascade)
|
||||
isActive Boolean @default(true)
|
||||
|
||||
// Relations
|
||||
AuthorizationCodes OAuthAuthorizationCode[]
|
||||
AccessTokens OAuthAccessToken[]
|
||||
RefreshTokens OAuthRefreshToken[]
|
||||
|
||||
@@index([clientId])
|
||||
@@index([ownerId])
|
||||
}
|
||||
|
||||
// Temporary authorization codes (10 min TTL)
|
||||
model OAuthAuthorizationCode {
|
||||
id String @id @default(uuid())
|
||||
code String @unique
|
||||
createdAt DateTime @default(now())
|
||||
expiresAt DateTime // Now + 10 minutes
|
||||
|
||||
applicationId String
|
||||
Application OAuthApplication @relation(fields: [applicationId], references: [id], onDelete: Cascade)
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
scopes APIKeyPermission[]
|
||||
redirectUri String // Must match one from application
|
||||
|
||||
// PKCE (Proof Key for Code Exchange) support
|
||||
codeChallenge String?
|
||||
codeChallengeMethod String? // "S256" or "plain"
|
||||
|
||||
usedAt DateTime? // Set when code is consumed
|
||||
|
||||
@@index([code])
|
||||
@@index([applicationId, userId])
|
||||
@@index([expiresAt]) // For cleanup
|
||||
}
|
||||
|
||||
// Access tokens (1 hour TTL)
|
||||
model OAuthAccessToken {
|
||||
id String @id @default(uuid())
|
||||
token String @unique // SHA256 hash of plaintext token
|
||||
createdAt DateTime @default(now())
|
||||
expiresAt DateTime // Now + 1 hour
|
||||
|
||||
applicationId String
|
||||
Application OAuthApplication @relation(fields: [applicationId], references: [id], onDelete: Cascade)
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
scopes APIKeyPermission[]
|
||||
|
||||
revokedAt DateTime? // Set when token is revoked
|
||||
|
||||
@@index([token]) // For token lookup
|
||||
@@index([userId, applicationId])
|
||||
@@index([expiresAt]) // For cleanup
|
||||
}
|
||||
|
||||
// Refresh tokens (30 days TTL)
|
||||
model OAuthRefreshToken {
|
||||
id String @id @default(uuid())
|
||||
token String @unique // SHA256 hash of plaintext token
|
||||
createdAt DateTime @default(now())
|
||||
expiresAt DateTime // Now + 30 days
|
||||
|
||||
applicationId String
|
||||
Application OAuthApplication @relation(fields: [applicationId], references: [id], onDelete: Cascade)
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
scopes APIKeyPermission[]
|
||||
|
||||
revokedAt DateTime? // Set when token is revoked
|
||||
|
||||
@@index([token]) // For token lookup
|
||||
@@index([userId, applicationId])
|
||||
@@index([expiresAt]) // For cleanup
|
||||
}
|
||||
|
||||
@@ -1,254 +0,0 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Migrate Large Tables: Stream execution history from source to destination
|
||||
#
|
||||
# This script streams the large execution tables that were excluded from
|
||||
# the initial migration. Run this AFTER migrate_to_gcp.sh completes.
|
||||
#
|
||||
# Tables migrated (in order of size):
|
||||
# - NotificationEvent (94 MB)
|
||||
# - AgentNodeExecutionKeyValueData (792 KB)
|
||||
# - AgentGraphExecution (1.3 GB)
|
||||
# - AgentNodeExecution (6 GB)
|
||||
# - AgentNodeExecutionInputOutput (30 GB)
|
||||
#
|
||||
# Usage:
|
||||
# ./scripts/migrate_big_tables.sh \
|
||||
# --source 'postgresql://user:pass@host:5432/db?schema=platform' \
|
||||
# --dest 'postgresql://user:pass@host:5432/db?schema=platform'
|
||||
#
|
||||
# Options:
|
||||
# --table <name> Migrate only a specific table
|
||||
# --dry-run Show what would be done without migrating
|
||||
#
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Colors
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
NC='\033[0m'
|
||||
|
||||
log_info() { echo -e "${BLUE}[INFO]${NC} $1"; }
|
||||
log_success() { echo -e "${GREEN}[SUCCESS]${NC} $1"; }
|
||||
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
||||
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
||||
|
||||
# Arguments
|
||||
SOURCE_URL=""
|
||||
DEST_URL=""
|
||||
DRY_RUN=false
|
||||
SINGLE_TABLE=""
|
||||
|
||||
# Tables to migrate (ordered smallest to largest)
|
||||
TABLES=(
|
||||
"NotificationEvent"
|
||||
"AgentNodeExecutionKeyValueData"
|
||||
"AgentGraphExecution"
|
||||
"AgentNodeExecution"
|
||||
"AgentNodeExecutionInputOutput"
|
||||
)
|
||||
|
||||
usage() {
|
||||
cat << EOF
|
||||
Usage: $(basename "$0") --source <url> --dest <url> [options]
|
||||
|
||||
Required:
|
||||
--source <url> Source database URL with ?schema=platform
|
||||
--dest <url> Destination database URL with ?schema=platform
|
||||
|
||||
Options:
|
||||
--table <name> Migrate only a specific table (e.g., AgentGraphExecution)
|
||||
--dry-run Show what would be done without migrating
|
||||
--help Show this help
|
||||
|
||||
Tables migrated (in order):
|
||||
1. NotificationEvent (94 MB)
|
||||
2. AgentNodeExecutionKeyValueData (792 KB)
|
||||
3. AgentGraphExecution (1.3 GB)
|
||||
4. AgentNodeExecution (6 GB)
|
||||
5. AgentNodeExecutionInputOutput (30 GB)
|
||||
|
||||
EOF
|
||||
exit 1
|
||||
}
|
||||
|
||||
parse_args() {
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--source) SOURCE_URL="$2"; shift 2 ;;
|
||||
--dest) DEST_URL="$2"; shift 2 ;;
|
||||
--table) SINGLE_TABLE="$2"; shift 2 ;;
|
||||
--dry-run) DRY_RUN=true; shift ;;
|
||||
--help|-h) usage ;;
|
||||
*) log_error "Unknown option: $1"; usage ;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [[ -z "$SOURCE_URL" ]]; then
|
||||
log_error "Missing --source"
|
||||
usage
|
||||
fi
|
||||
|
||||
if [[ -z "$DEST_URL" ]]; then
|
||||
log_error "Missing --dest"
|
||||
usage
|
||||
fi
|
||||
}
|
||||
|
||||
get_schema_from_url() {
|
||||
local url="$1"
|
||||
local schema=$(echo "$url" | sed -n 's/.*schema=\([^&]*\).*/\1/p')
|
||||
echo "${schema:-platform}"
|
||||
}
|
||||
|
||||
get_base_url() {
|
||||
local url="$1"
|
||||
echo "${url%%\?*}"
|
||||
}
|
||||
|
||||
get_table_size() {
|
||||
local base_url="$1"
|
||||
local schema="$2"
|
||||
local table="$3"
|
||||
|
||||
psql "${base_url}" -t -c "
|
||||
SELECT pg_size_pretty(pg_total_relation_size('${schema}.\"${table}\"'))
|
||||
" 2>/dev/null | tr -d ' ' || echo "unknown"
|
||||
}
|
||||
|
||||
get_table_count() {
|
||||
local base_url="$1"
|
||||
local schema="$2"
|
||||
local table="$3"
|
||||
|
||||
psql "${base_url}" -t -c "
|
||||
SELECT COUNT(*) FROM ${schema}.\"${table}\"
|
||||
" 2>/dev/null | tr -d ' ' || echo "0"
|
||||
}
|
||||
|
||||
migrate_table() {
|
||||
local table="$1"
|
||||
local source_base=$(get_base_url "$SOURCE_URL")
|
||||
local dest_base=$(get_base_url "$DEST_URL")
|
||||
local schema=$(get_schema_from_url "$SOURCE_URL")
|
||||
|
||||
log_info "=== Migrating ${table} ==="
|
||||
|
||||
# Get source stats
|
||||
local size=$(get_table_size "$source_base" "$schema" "$table")
|
||||
local count=$(get_table_count "$source_base" "$schema" "$table")
|
||||
log_info "Source: ${count} rows (${size})"
|
||||
|
||||
if [[ "$DRY_RUN" == true ]]; then
|
||||
log_info "DRY RUN: Would stream ${table} from source to destination"
|
||||
return
|
||||
fi
|
||||
|
||||
# Check if destination already has data
|
||||
local dest_count=$(get_table_count "$dest_base" "$schema" "$table")
|
||||
if [[ "$dest_count" != "0" ]]; then
|
||||
log_warn "Destination already has ${dest_count} rows in ${table}"
|
||||
read -p "Continue and add more rows? (y/N) " -n 1 -r
|
||||
echo ""
|
||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||
log_info "Skipping ${table}"
|
||||
return
|
||||
fi
|
||||
fi
|
||||
|
||||
log_info "Streaming ${table} (this may take a while for large tables)..."
|
||||
local start_time=$(date +%s)
|
||||
|
||||
# Stream directly from source to destination
|
||||
pg_dump "${source_base}" \
|
||||
--table="${schema}.\"${table}\"" \
|
||||
--data-only \
|
||||
--no-owner \
|
||||
--no-privileges \
|
||||
2>/dev/null \
|
||||
| grep -v '\\restrict' \
|
||||
| psql "${dest_base}" -q
|
||||
|
||||
local end_time=$(date +%s)
|
||||
local duration=$((end_time - start_time))
|
||||
|
||||
# Verify
|
||||
local new_dest_count=$(get_table_count "$dest_base" "$schema" "$table")
|
||||
log_success "${table}: ${new_dest_count} rows migrated in ${duration}s"
|
||||
}
|
||||
|
||||
main() {
|
||||
echo ""
|
||||
echo "========================================"
|
||||
echo " Migrate Large Tables"
|
||||
echo "========================================"
|
||||
echo ""
|
||||
|
||||
parse_args "$@"
|
||||
|
||||
local source_base=$(get_base_url "$SOURCE_URL")
|
||||
local dest_base=$(get_base_url "$DEST_URL")
|
||||
|
||||
log_info "Source: ${source_base}"
|
||||
log_info "Destination: ${dest_base}"
|
||||
[[ "$DRY_RUN" == true ]] && log_warn "DRY RUN MODE"
|
||||
echo ""
|
||||
|
||||
# Test connections
|
||||
log_info "Testing connections..."
|
||||
if ! psql "${source_base}" -c "SELECT 1" > /dev/null 2>&1; then
|
||||
log_error "Cannot connect to source"
|
||||
exit 1
|
||||
fi
|
||||
if ! psql "${dest_base}" -c "SELECT 1" > /dev/null 2>&1; then
|
||||
log_error "Cannot connect to destination"
|
||||
exit 1
|
||||
fi
|
||||
log_success "Connections OK"
|
||||
echo ""
|
||||
|
||||
# Determine which tables to migrate
|
||||
local tables_to_migrate=()
|
||||
if [[ -n "$SINGLE_TABLE" ]]; then
|
||||
tables_to_migrate=("$SINGLE_TABLE")
|
||||
else
|
||||
tables_to_migrate=("${TABLES[@]}")
|
||||
fi
|
||||
|
||||
# Show plan
|
||||
log_info "Tables to migrate:"
|
||||
local schema=$(get_schema_from_url "$SOURCE_URL")
|
||||
for table in "${tables_to_migrate[@]}"; do
|
||||
local size=$(get_table_size "$source_base" "$schema" "$table")
|
||||
echo " - ${table} (${size})"
|
||||
done
|
||||
echo ""
|
||||
|
||||
if [[ "$DRY_RUN" != true ]]; then
|
||||
log_warn "This will stream large amounts of data to the destination."
|
||||
read -p "Continue? (y/N) " -n 1 -r
|
||||
echo ""
|
||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||
log_info "Cancelled"
|
||||
exit 0
|
||||
fi
|
||||
fi
|
||||
|
||||
echo ""
|
||||
log_info "Starting migration at $(date)"
|
||||
echo ""
|
||||
|
||||
# Migrate each table
|
||||
for table in "${tables_to_migrate[@]}"; do
|
||||
migrate_table "$table"
|
||||
echo ""
|
||||
done
|
||||
|
||||
log_success "Migration completed at $(date)"
|
||||
echo ""
|
||||
}
|
||||
|
||||
main "$@"
|
||||
@@ -1,271 +0,0 @@
|
||||
"""
|
||||
Migration script to copy password hashes from Supabase auth.users to platform.User.
|
||||
|
||||
This script should be run BEFORE removing Supabase services to preserve user credentials.
|
||||
It copies bcrypt password hashes from Supabase's auth.users table to the platform.User table,
|
||||
allowing users to continue using their existing passwords after the migration.
|
||||
|
||||
Usage:
|
||||
cd backend
|
||||
poetry run python scripts/migrate_supabase_users.py [options]
|
||||
|
||||
Options:
|
||||
--dry-run Preview what would be migrated without making changes
|
||||
--database-url <url> Database URL (overrides DATABASE_URL env var)
|
||||
|
||||
Examples:
|
||||
# Using environment variable
|
||||
poetry run python scripts/migrate_supabase_users.py --dry-run
|
||||
|
||||
# Using explicit database URL
|
||||
poetry run python scripts/migrate_supabase_users.py \
|
||||
--database-url "postgresql://user:pass@host:5432/db?schema=platform"
|
||||
|
||||
Prerequisites:
|
||||
- Supabase services must be running (auth.users table must exist)
|
||||
- Database migration 'add_native_auth' must be applied first
|
||||
- Either DATABASE_URL env var or --database-url must be provided
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
from prisma import Prisma
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def migrate_credentials(db: Prisma) -> int:
|
||||
"""
|
||||
Copy bcrypt password hashes from auth.users to platform.User.
|
||||
|
||||
Returns the number of users updated.
|
||||
"""
|
||||
logger.info("Migrating user credentials from auth.users to platform.User...")
|
||||
|
||||
result = await db.execute_raw(
|
||||
"""
|
||||
UPDATE platform."User" u
|
||||
SET
|
||||
"passwordHash" = a.encrypted_password,
|
||||
"emailVerified" = (a.email_confirmed_at IS NOT NULL)
|
||||
FROM auth.users a
|
||||
WHERE u.id::text = a.id::text
|
||||
AND a.encrypted_password IS NOT NULL
|
||||
AND u."passwordHash" IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
logger.info(f"Updated {result} users with credentials")
|
||||
return result
|
||||
|
||||
|
||||
async def migrate_google_oauth_users(db: Prisma) -> int:
|
||||
"""
|
||||
Copy Google OAuth user IDs from auth.users to platform.User.
|
||||
|
||||
Returns the number of users updated.
|
||||
"""
|
||||
logger.info("Migrating Google OAuth users from auth.users to platform.User...")
|
||||
|
||||
result = await db.execute_raw(
|
||||
"""
|
||||
UPDATE platform."User" u
|
||||
SET "googleId" = (a.raw_app_meta_data->>'provider_id')::text
|
||||
FROM auth.users a
|
||||
WHERE u.id::text = a.id::text
|
||||
AND a.raw_app_meta_data->>'provider' = 'google'
|
||||
AND a.raw_app_meta_data->>'provider_id' IS NOT NULL
|
||||
AND u."googleId" IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
logger.info(f"Updated {result} users with Google OAuth IDs")
|
||||
return result
|
||||
|
||||
|
||||
async def get_migration_stats(db: Prisma) -> dict:
|
||||
"""Get statistics about the migration."""
|
||||
# Count users in platform.User
|
||||
platform_users = await db.user.count()
|
||||
|
||||
# Count users with credentials (not null)
|
||||
users_with_credentials = await db.user.count(
|
||||
where={"passwordHash": {"not": None}} # type: ignore
|
||||
)
|
||||
|
||||
# Count users with Google OAuth (not null)
|
||||
users_with_google = await db.user.count(
|
||||
where={"googleId": {"not": None}} # type: ignore
|
||||
)
|
||||
|
||||
# Count users without any auth method
|
||||
users_without_auth = await db.user.count(
|
||||
where={"passwordHash": None, "googleId": None}
|
||||
)
|
||||
|
||||
return {
|
||||
"total_platform_users": platform_users,
|
||||
"users_with_credentials": users_with_credentials,
|
||||
"users_with_google_oauth": users_with_google,
|
||||
"users_without_auth": users_without_auth,
|
||||
}
|
||||
|
||||
|
||||
async def verify_auth_users_exist(db: Prisma) -> bool:
|
||||
"""Check if auth.users table exists and has data."""
|
||||
try:
|
||||
result = await db.query_raw("SELECT COUNT(*) as count FROM auth.users")
|
||||
count = result[0]["count"] if result else 0
|
||||
logger.info(f"Found {count} users in auth.users table")
|
||||
return count > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Cannot access auth.users table: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def preview_migration(db: Prisma) -> dict:
|
||||
"""Preview what would be migrated without making changes."""
|
||||
logger.info("Previewing migration (dry-run mode)...")
|
||||
|
||||
# Count users that would have credentials migrated
|
||||
credentials_preview = await db.query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM platform."User" u
|
||||
JOIN auth.users a ON u.id::text = a.id::text
|
||||
WHERE a.encrypted_password IS NOT NULL
|
||||
AND u."passwordHash" IS NULL
|
||||
"""
|
||||
)
|
||||
credentials_to_migrate = (
|
||||
credentials_preview[0]["count"] if credentials_preview else 0
|
||||
)
|
||||
|
||||
# Count users that would have Google OAuth migrated
|
||||
google_preview = await db.query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM platform."User" u
|
||||
JOIN auth.users a ON u.id::text = a.id::text
|
||||
WHERE a.raw_app_meta_data->>'provider' = 'google'
|
||||
AND a.raw_app_meta_data->>'provider_id' IS NOT NULL
|
||||
AND u."googleId" IS NULL
|
||||
"""
|
||||
)
|
||||
google_to_migrate = google_preview[0]["count"] if google_preview else 0
|
||||
|
||||
return {
|
||||
"credentials_to_migrate": credentials_to_migrate,
|
||||
"google_oauth_to_migrate": google_to_migrate,
|
||||
}
|
||||
|
||||
|
||||
async def main(dry_run: bool = False):
|
||||
"""Run the migration."""
|
||||
logger.info("=" * 60)
|
||||
logger.info("Supabase User Migration Script")
|
||||
if dry_run:
|
||||
logger.info(">>> DRY RUN MODE - No changes will be made <<<")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Started at: {datetime.now().isoformat()}")
|
||||
|
||||
db = Prisma()
|
||||
await db.connect()
|
||||
|
||||
try:
|
||||
# Check if auth.users exists
|
||||
if not await verify_auth_users_exist(db):
|
||||
logger.error(
|
||||
"Cannot find auth.users table or it's empty. "
|
||||
"Make sure Supabase is running and has users."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Get stats before migration
|
||||
logger.info("\n--- Current State ---")
|
||||
stats_before = await get_migration_stats(db)
|
||||
for key, value in stats_before.items():
|
||||
logger.info(f" {key}: {value}")
|
||||
|
||||
if dry_run:
|
||||
# Preview mode - show what would be migrated
|
||||
logger.info("\n--- Preview (would be migrated) ---")
|
||||
preview = await preview_migration(db)
|
||||
logger.info(
|
||||
f" Credentials to migrate: {preview['credentials_to_migrate']}"
|
||||
)
|
||||
logger.info(
|
||||
f" Google OAuth IDs to migrate: {preview['google_oauth_to_migrate']}"
|
||||
)
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("Dry run complete. Run without --dry-run to perform migration.")
|
||||
logger.info("=" * 60)
|
||||
else:
|
||||
# Run actual migrations
|
||||
logger.info("\n--- Running Migration ---")
|
||||
credentials_migrated = await migrate_credentials(db)
|
||||
google_migrated = await migrate_google_oauth_users(db)
|
||||
|
||||
# Get stats after migration
|
||||
logger.info("\n--- After Migration ---")
|
||||
stats_after = await get_migration_stats(db)
|
||||
for key, value in stats_after.items():
|
||||
logger.info(f" {key}: {value}")
|
||||
|
||||
# Summary
|
||||
logger.info("\n--- Summary ---")
|
||||
logger.info(f"Credentials migrated: {credentials_migrated}")
|
||||
logger.info(f"Google OAuth IDs migrated: {google_migrated}")
|
||||
logger.info(
|
||||
f"Users still without auth: {stats_after['users_without_auth']} "
|
||||
"(these may be OAuth users from other providers)"
|
||||
)
|
||||
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("Migration completed successfully!")
|
||||
logger.info("=" * 60)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Migration failed: {e}")
|
||||
raise
|
||||
finally:
|
||||
await db.disconnect()
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Migrate user auth data from Supabase to native auth"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Preview what would be migrated without making changes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--database-url",
|
||||
type=str,
|
||||
help="Database URL (overrides DATABASE_URL env var)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
args = parse_args()
|
||||
|
||||
# Override DATABASE_URL if provided via command line
|
||||
if args.database_url:
|
||||
os.environ["DATABASE_URL"] = args.database_url
|
||||
os.environ["DIRECT_URL"] = args.database_url
|
||||
|
||||
asyncio.run(main(dry_run=args.dry_run))
|
||||
@@ -1,482 +0,0 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Database Migration Script: Supabase to GCP Cloud SQL
|
||||
#
|
||||
# This script migrates the AutoGPT Platform database from Supabase to a new PostgreSQL instance.
|
||||
#
|
||||
# Migration Steps:
|
||||
# 0. Nuke destination database (drop schema, recreate, apply migrations)
|
||||
# 1. Export platform schema data from source
|
||||
# 2. Export auth.users data from source (for password hashes, OAuth IDs)
|
||||
# 3. Import platform schema data to destination
|
||||
# 4. Update User table in destination with auth data
|
||||
# 5. Refresh materialized views
|
||||
#
|
||||
# Prerequisites:
|
||||
# - pg_dump and psql (PostgreSQL 15+)
|
||||
# - poetry installed (for Prisma migrations)
|
||||
# - Source and destination databases accessible
|
||||
#
|
||||
# Usage:
|
||||
# ./scripts/migrate_to_gcp.sh \
|
||||
# --source 'postgresql://user:pass@host:5432/db?schema=platform' \
|
||||
# --dest 'postgresql://user:pass@host:5432/db?schema=platform'
|
||||
#
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
NC='\033[0m'
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
BACKEND_DIR="$(dirname "$SCRIPT_DIR")"
|
||||
BACKUP_DIR="${BACKEND_DIR}/migration_backups"
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
|
||||
# Command line arguments
|
||||
SOURCE_URL=""
|
||||
DEST_URL=""
|
||||
DRY_RUN=false
|
||||
|
||||
log_info() { echo -e "${BLUE}[INFO]${NC} $1"; }
|
||||
log_success() { echo -e "${GREEN}[SUCCESS]${NC} $1"; }
|
||||
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
||||
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
||||
|
||||
usage() {
|
||||
cat << EOF
|
||||
Usage: $(basename "$0") --source <url> --dest <url> [options]
|
||||
|
||||
Required:
|
||||
--source <url> Source database URL with ?schema=platform
|
||||
--dest <url> Destination database URL with ?schema=platform
|
||||
|
||||
Options:
|
||||
--dry-run Preview without making changes
|
||||
--help Show this help
|
||||
|
||||
Migration Steps:
|
||||
0. Nuke destination database (DROP SCHEMA, recreate, apply Prisma migrations)
|
||||
1. Export platform schema data from source (READ-ONLY)
|
||||
2. Export auth.users data from source (READ-ONLY)
|
||||
3. Import platform data to destination
|
||||
4. Update User table with auth data (passwords, OAuth IDs)
|
||||
5. Refresh materialized views
|
||||
|
||||
EOF
|
||||
exit 1
|
||||
}
|
||||
|
||||
parse_args() {
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--source) SOURCE_URL="$2"; shift 2 ;;
|
||||
--dest) DEST_URL="$2"; shift 2 ;;
|
||||
--dry-run) DRY_RUN=true; shift ;;
|
||||
--help|-h) usage ;;
|
||||
*) log_error "Unknown option: $1"; usage ;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [[ -z "$SOURCE_URL" ]]; then
|
||||
log_error "Missing --source"
|
||||
usage
|
||||
fi
|
||||
|
||||
if [[ -z "$DEST_URL" ]]; then
|
||||
log_error "Missing --dest"
|
||||
usage
|
||||
fi
|
||||
}
|
||||
|
||||
get_schema_from_url() {
|
||||
local url="$1"
|
||||
local schema=$(echo "$url" | sed -n 's/.*schema=\([^&]*\).*/\1/p')
|
||||
echo "${schema:-platform}"
|
||||
}
|
||||
|
||||
get_base_url() {
|
||||
local url="$1"
|
||||
echo "${url%%\?*}"
|
||||
}
|
||||
|
||||
test_connections() {
|
||||
local source_base=$(get_base_url "$SOURCE_URL")
|
||||
local dest_base=$(get_base_url "$DEST_URL")
|
||||
|
||||
log_info "Testing source connection..."
|
||||
if ! psql "${source_base}" -c "SELECT 1" > /dev/null 2>&1; then
|
||||
log_error "Cannot connect to source database"
|
||||
psql "${source_base}" -c "SELECT 1" 2>&1 || true
|
||||
exit 1
|
||||
fi
|
||||
log_success "Source connection OK"
|
||||
|
||||
log_info "Testing destination connection..."
|
||||
if ! psql "${dest_base}" -c "SELECT 1" > /dev/null 2>&1; then
|
||||
log_error "Cannot connect to destination database"
|
||||
psql "${dest_base}" -c "SELECT 1" 2>&1 || true
|
||||
exit 1
|
||||
fi
|
||||
log_success "Destination connection OK"
|
||||
}
|
||||
|
||||
# ============================================
|
||||
# STEP 0: Nuke destination database
|
||||
# ============================================
|
||||
nuke_destination() {
|
||||
local schema=$(get_schema_from_url "$DEST_URL")
|
||||
local dest_base=$(get_base_url "$DEST_URL")
|
||||
|
||||
log_info "=== STEP 0: Nuking destination database ==="
|
||||
|
||||
if [[ "$DRY_RUN" == true ]]; then
|
||||
log_info "DRY RUN: Would drop and recreate schema '${schema}' in destination"
|
||||
return
|
||||
fi
|
||||
|
||||
# Show what exists in destination
|
||||
log_info "Current destination state:"
|
||||
local user_count=$(psql "${dest_base}" -t -c "SELECT COUNT(*) FROM ${schema}.\"User\"" 2>/dev/null | tr -d ' ' || echo "0")
|
||||
local graph_count=$(psql "${dest_base}" -t -c "SELECT COUNT(*) FROM ${schema}.\"AgentGraph\"" 2>/dev/null | tr -d ' ' || echo "0")
|
||||
echo " - Users: ${user_count}"
|
||||
echo " - AgentGraphs: ${graph_count}"
|
||||
|
||||
echo ""
|
||||
log_warn "⚠️ WARNING: This will PERMANENTLY DELETE all data in the destination database!"
|
||||
log_warn "Schema '${schema}' will be dropped and recreated."
|
||||
echo ""
|
||||
read -p "Type 'NUKE' to confirm deletion: " -r
|
||||
echo ""
|
||||
|
||||
if [[ "$REPLY" != "NUKE" ]]; then
|
||||
log_info "Cancelled - destination not modified"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
log_info "Dropping schema '${schema}'..."
|
||||
psql "${dest_base}" -c "DROP SCHEMA IF EXISTS ${schema} CASCADE;"
|
||||
|
||||
log_info "Recreating schema '${schema}'..."
|
||||
psql "${dest_base}" -c "CREATE SCHEMA ${schema};"
|
||||
|
||||
log_info "Applying Prisma migrations..."
|
||||
cd "${BACKEND_DIR}"
|
||||
DATABASE_URL="${DEST_URL}" DIRECT_URL="${DEST_URL}" poetry run prisma migrate deploy
|
||||
|
||||
log_success "Destination database reset complete"
|
||||
}
|
||||
|
||||
# ============================================
|
||||
# STEP 1: Export platform schema data
|
||||
# ============================================
|
||||
export_platform_data() {
|
||||
local schema=$(get_schema_from_url "$SOURCE_URL")
|
||||
local base_url=$(get_base_url "$SOURCE_URL")
|
||||
local output_file="${BACKUP_DIR}/platform_data_${TIMESTAMP}.sql"
|
||||
|
||||
log_info "=== STEP 1: Exporting platform schema data ==="
|
||||
mkdir -p "${BACKUP_DIR}"
|
||||
|
||||
if [[ "$DRY_RUN" == true ]]; then
|
||||
log_info "DRY RUN: Would export schema '${schema}' to ${output_file}"
|
||||
log_info "DRY RUN: Excluding large execution tables"
|
||||
touch "$output_file"
|
||||
echo "$output_file"
|
||||
return
|
||||
fi
|
||||
|
||||
log_info "Exporting from schema: ${schema}"
|
||||
log_info "EXCLUDING: AgentGraphExecution, AgentNodeExecution, AgentNodeExecutionInputOutput, AgentNodeExecutionKeyValueData, NotificationEvent"
|
||||
|
||||
pg_dump "${base_url}" \
|
||||
--schema="${schema}" \
|
||||
--format=plain \
|
||||
--no-owner \
|
||||
--no-privileges \
|
||||
--data-only \
|
||||
--exclude-table="${schema}.AgentGraphExecution" \
|
||||
--exclude-table="${schema}.AgentNodeExecution" \
|
||||
--exclude-table="${schema}.AgentNodeExecutionInputOutput" \
|
||||
--exclude-table="${schema}.AgentNodeExecutionKeyValueData" \
|
||||
--exclude-table="${schema}.NotificationEvent" \
|
||||
--file="${output_file}" 2>&1
|
||||
|
||||
# Remove Supabase-specific commands that break import
|
||||
sed -i.bak '/\\restrict/d' "${output_file}"
|
||||
rm -f "${output_file}.bak"
|
||||
|
||||
local size=$(du -h "${output_file}" | cut -f1)
|
||||
log_success "Platform data exported: ${output_file} (${size})"
|
||||
echo "$output_file"
|
||||
}
|
||||
|
||||
# ============================================
|
||||
# STEP 2: Export auth.users data
|
||||
# ============================================
|
||||
export_auth_data() {
|
||||
local base_url=$(get_base_url "$SOURCE_URL")
|
||||
local output_file="${BACKUP_DIR}/auth_users_${TIMESTAMP}.csv"
|
||||
|
||||
log_info "=== STEP 2: Exporting auth.users data ==="
|
||||
|
||||
# Check if auth.users exists
|
||||
local auth_exists=$(psql "${base_url}" -t -c "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = 'auth' AND table_name = 'users')" 2>/dev/null | tr -d ' ')
|
||||
|
||||
if [[ "$auth_exists" != "t" ]]; then
|
||||
log_warn "No auth.users table found - skipping auth export"
|
||||
echo ""
|
||||
return
|
||||
fi
|
||||
|
||||
if [[ "$DRY_RUN" == true ]]; then
|
||||
log_info "DRY RUN: Would export auth.users to ${output_file}"
|
||||
touch "$output_file"
|
||||
echo "$output_file"
|
||||
return
|
||||
fi
|
||||
|
||||
log_info "Extracting auth data (passwords, OAuth IDs, email verification)..."
|
||||
|
||||
psql "${base_url}" -c "\COPY (
|
||||
SELECT
|
||||
id,
|
||||
encrypted_password,
|
||||
(email_confirmed_at IS NOT NULL) as email_verified,
|
||||
CASE
|
||||
WHEN raw_app_meta_data->>'provider' = 'google'
|
||||
THEN raw_app_meta_data->>'provider_id'
|
||||
ELSE NULL
|
||||
END as google_id
|
||||
FROM auth.users
|
||||
WHERE encrypted_password IS NOT NULL
|
||||
OR raw_app_meta_data->>'provider' = 'google'
|
||||
) TO '${output_file}' WITH CSV HEADER"
|
||||
|
||||
local count=$(wc -l < "${output_file}" | tr -d ' ')
|
||||
log_success "Auth data exported: ${output_file} (${count} rows including header)"
|
||||
echo "$output_file"
|
||||
}
|
||||
|
||||
# ============================================
|
||||
# STEP 3: Import platform data to destination
|
||||
# ============================================
|
||||
import_platform_data() {
|
||||
local platform_file="$1"
|
||||
local dest_base=$(get_base_url "$DEST_URL")
|
||||
|
||||
log_info "=== STEP 3: Importing platform data to destination ==="
|
||||
|
||||
if [[ "$DRY_RUN" == true ]]; then
|
||||
log_info "DRY RUN: Would import ${platform_file} to destination"
|
||||
return
|
||||
fi
|
||||
|
||||
if [[ ! -f "$platform_file" ]]; then
|
||||
log_error "Platform data file not found: ${platform_file}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
log_info "Importing platform data (this may take a while)..."
|
||||
|
||||
# Import with error logging
|
||||
psql "${dest_base}" -f "${platform_file}" 2>&1 | tee "${BACKUP_DIR}/import_log_${TIMESTAMP}.txt" | head -100
|
||||
|
||||
log_success "Platform data import completed"
|
||||
}
|
||||
|
||||
# ============================================
|
||||
# STEP 4: Update User table with auth data
|
||||
# ============================================
|
||||
update_user_auth_data() {
|
||||
local auth_file="$1"
|
||||
local schema=$(get_schema_from_url "$DEST_URL")
|
||||
local dest_base=$(get_base_url "$DEST_URL")
|
||||
|
||||
log_info "=== STEP 4: Updating User table with auth data ==="
|
||||
|
||||
if [[ -z "$auth_file" || ! -f "$auth_file" ]]; then
|
||||
log_warn "No auth data file - skipping User auth update"
|
||||
return
|
||||
fi
|
||||
|
||||
if [[ "$DRY_RUN" == true ]]; then
|
||||
log_info "DRY RUN: Would update User table with auth data"
|
||||
return
|
||||
fi
|
||||
|
||||
log_info "Creating temporary table for auth data..."
|
||||
|
||||
psql "${dest_base}" << EOF
|
||||
-- Create temp table for auth data
|
||||
CREATE TEMP TABLE temp_auth_users (
|
||||
id UUID,
|
||||
encrypted_password TEXT,
|
||||
email_verified BOOLEAN,
|
||||
google_id TEXT
|
||||
);
|
||||
|
||||
-- Import CSV
|
||||
\COPY temp_auth_users FROM '${auth_file}' WITH CSV HEADER;
|
||||
|
||||
-- Update User table with password hashes
|
||||
UPDATE ${schema}."User" u
|
||||
SET "passwordHash" = t.encrypted_password
|
||||
FROM temp_auth_users t
|
||||
WHERE u.id = t.id
|
||||
AND t.encrypted_password IS NOT NULL
|
||||
AND u."passwordHash" IS NULL;
|
||||
|
||||
-- Update User table with email verification
|
||||
UPDATE ${schema}."User" u
|
||||
SET "emailVerified" = t.email_verified
|
||||
FROM temp_auth_users t
|
||||
WHERE u.id = t.id
|
||||
AND t.email_verified = true;
|
||||
|
||||
-- Update User table with Google OAuth IDs
|
||||
UPDATE ${schema}."User" u
|
||||
SET "googleId" = t.google_id
|
||||
FROM temp_auth_users t
|
||||
WHERE u.id = t.id
|
||||
AND t.google_id IS NOT NULL
|
||||
AND u."googleId" IS NULL;
|
||||
|
||||
-- Show results
|
||||
SELECT
|
||||
'Total Users' as metric, COUNT(*)::text as value FROM ${schema}."User"
|
||||
UNION ALL
|
||||
SELECT 'With Password', COUNT(*)::text FROM ${schema}."User" WHERE "passwordHash" IS NOT NULL
|
||||
UNION ALL
|
||||
SELECT 'With Google OAuth', COUNT(*)::text FROM ${schema}."User" WHERE "googleId" IS NOT NULL
|
||||
UNION ALL
|
||||
SELECT 'Email Verified', COUNT(*)::text FROM ${schema}."User" WHERE "emailVerified" = true;
|
||||
|
||||
DROP TABLE temp_auth_users;
|
||||
EOF
|
||||
|
||||
log_success "User auth data updated"
|
||||
}
|
||||
|
||||
# ============================================
|
||||
# STEP 5: Refresh materialized views
|
||||
# ============================================
|
||||
refresh_views() {
|
||||
local schema=$(get_schema_from_url "$DEST_URL")
|
||||
local dest_base=$(get_base_url "$DEST_URL")
|
||||
|
||||
log_info "=== STEP 5: Refreshing materialized views ==="
|
||||
|
||||
if [[ "$DRY_RUN" == true ]]; then
|
||||
log_info "DRY RUN: Would refresh materialized views"
|
||||
return
|
||||
fi
|
||||
|
||||
psql "${dest_base}" << EOF
|
||||
SET search_path TO ${schema};
|
||||
REFRESH MATERIALIZED VIEW "mv_agent_run_counts";
|
||||
REFRESH MATERIALIZED VIEW "mv_review_stats";
|
||||
|
||||
-- Reset sequences
|
||||
SELECT setval(
|
||||
pg_get_serial_sequence('${schema}."SearchTerms"', 'id'),
|
||||
COALESCE((SELECT MAX(id) FROM ${schema}."SearchTerms"), 0) + 1,
|
||||
false
|
||||
);
|
||||
EOF
|
||||
|
||||
log_success "Materialized views refreshed"
|
||||
}
|
||||
|
||||
# ============================================
|
||||
# Verification
|
||||
# ============================================
|
||||
verify_migration() {
|
||||
local source_base=$(get_base_url "$SOURCE_URL")
|
||||
local dest_base=$(get_base_url "$DEST_URL")
|
||||
local schema=$(get_schema_from_url "$SOURCE_URL")
|
||||
|
||||
log_info "=== VERIFICATION ==="
|
||||
|
||||
echo ""
|
||||
echo "Source counts:"
|
||||
psql "${source_base}" -c "SELECT 'User' as table_name, COUNT(*) FROM ${schema}.\"User\" UNION ALL SELECT 'AgentGraph', COUNT(*) FROM ${schema}.\"AgentGraph\" UNION ALL SELECT 'Profile', COUNT(*) FROM ${schema}.\"Profile\""
|
||||
|
||||
echo ""
|
||||
echo "Destination counts:"
|
||||
psql "${dest_base}" -c "SELECT 'User' as table_name, COUNT(*) FROM ${schema}.\"User\" UNION ALL SELECT 'AgentGraph', COUNT(*) FROM ${schema}.\"AgentGraph\" UNION ALL SELECT 'Profile', COUNT(*) FROM ${schema}.\"Profile\""
|
||||
}
|
||||
|
||||
# ============================================
|
||||
# Main
|
||||
# ============================================
|
||||
main() {
|
||||
echo ""
|
||||
echo "========================================"
|
||||
echo " Database Migration Script"
|
||||
echo "========================================"
|
||||
echo ""
|
||||
|
||||
parse_args "$@"
|
||||
|
||||
log_info "Source: $(get_base_url "$SOURCE_URL")"
|
||||
log_info "Destination: $(get_base_url "$DEST_URL")"
|
||||
[[ "$DRY_RUN" == true ]] && log_warn "DRY RUN MODE"
|
||||
echo ""
|
||||
|
||||
test_connections
|
||||
|
||||
echo ""
|
||||
|
||||
# Step 0: Nuke destination database (with confirmation)
|
||||
nuke_destination
|
||||
echo ""
|
||||
|
||||
if [[ "$DRY_RUN" != true ]]; then
|
||||
log_warn "This will migrate data to the destination database."
|
||||
read -p "Continue with migration? (y/N) " -n 1 -r
|
||||
echo ""
|
||||
[[ ! $REPLY =~ ^[Yy]$ ]] && { log_info "Cancelled"; exit 0; }
|
||||
fi
|
||||
|
||||
echo ""
|
||||
log_info "Starting migration at $(date)"
|
||||
echo ""
|
||||
|
||||
# Step 1: Export platform data (READ-ONLY on source)
|
||||
platform_file=$(export_platform_data)
|
||||
echo ""
|
||||
|
||||
# Step 2: Export auth data (READ-ONLY on source)
|
||||
auth_file=$(export_auth_data)
|
||||
echo ""
|
||||
|
||||
# Step 3: Import platform data to destination
|
||||
import_platform_data "$platform_file"
|
||||
echo ""
|
||||
|
||||
# Step 4: Update User table with auth data
|
||||
update_user_auth_data "$auth_file"
|
||||
echo ""
|
||||
|
||||
# Step 5: Refresh materialized views
|
||||
refresh_views
|
||||
echo ""
|
||||
|
||||
# Verification
|
||||
verify_migration
|
||||
|
||||
echo ""
|
||||
log_success "Migration completed at $(date)"
|
||||
echo ""
|
||||
echo "Files created:"
|
||||
echo " - Platform data: ${platform_file}"
|
||||
[[ -n "$auth_file" ]] && echo " - Auth data: ${auth_file}"
|
||||
echo ""
|
||||
}
|
||||
|
||||
main "$@"
|
||||
@@ -1,141 +0,0 @@
|
||||
-- Database Migration Verification Script
|
||||
-- Run this on both source (Supabase) and target (GCP) databases to compare
|
||||
|
||||
SET search_path TO platform;
|
||||
|
||||
-- ============================================
|
||||
-- TABLE ROW COUNTS
|
||||
-- ============================================
|
||||
|
||||
SELECT '=== TABLE ROW COUNTS ===' as section;
|
||||
|
||||
SELECT 'User' as table_name, COUNT(*) as row_count FROM "User"
|
||||
UNION ALL SELECT 'Profile', COUNT(*) FROM "Profile"
|
||||
UNION ALL SELECT 'UserOnboarding', COUNT(*) FROM "UserOnboarding"
|
||||
UNION ALL SELECT 'UserBalance', COUNT(*) FROM "UserBalance"
|
||||
UNION ALL SELECT 'AgentGraph', COUNT(*) FROM "AgentGraph"
|
||||
UNION ALL SELECT 'AgentNode', COUNT(*) FROM "AgentNode"
|
||||
UNION ALL SELECT 'AgentBlock', COUNT(*) FROM "AgentBlock"
|
||||
UNION ALL SELECT 'AgentNodeLink', COUNT(*) FROM "AgentNodeLink"
|
||||
UNION ALL SELECT 'AgentGraphExecution', COUNT(*) FROM "AgentGraphExecution"
|
||||
UNION ALL SELECT 'AgentNodeExecution', COUNT(*) FROM "AgentNodeExecution"
|
||||
UNION ALL SELECT 'AgentNodeExecutionInputOutput', COUNT(*) FROM "AgentNodeExecutionInputOutput"
|
||||
UNION ALL SELECT 'AgentNodeExecutionKeyValueData', COUNT(*) FROM "AgentNodeExecutionKeyValueData"
|
||||
UNION ALL SELECT 'AgentPreset', COUNT(*) FROM "AgentPreset"
|
||||
UNION ALL SELECT 'LibraryAgent', COUNT(*) FROM "LibraryAgent"
|
||||
UNION ALL SELECT 'StoreListing', COUNT(*) FROM "StoreListing"
|
||||
UNION ALL SELECT 'StoreListingVersion', COUNT(*) FROM "StoreListingVersion"
|
||||
UNION ALL SELECT 'StoreListingReview', COUNT(*) FROM "StoreListingReview"
|
||||
UNION ALL SELECT 'IntegrationWebhook', COUNT(*) FROM "IntegrationWebhook"
|
||||
UNION ALL SELECT 'APIKey', COUNT(*) FROM "APIKey"
|
||||
UNION ALL SELECT 'CreditTransaction', COUNT(*) FROM "CreditTransaction"
|
||||
UNION ALL SELECT 'CreditRefundRequest', COUNT(*) FROM "CreditRefundRequest"
|
||||
UNION ALL SELECT 'AnalyticsDetails', COUNT(*) FROM "AnalyticsDetails"
|
||||
UNION ALL SELECT 'AnalyticsMetrics', COUNT(*) FROM "AnalyticsMetrics"
|
||||
UNION ALL SELECT 'SearchTerms', COUNT(*) FROM "SearchTerms"
|
||||
UNION ALL SELECT 'NotificationEvent', COUNT(*) FROM "NotificationEvent"
|
||||
UNION ALL SELECT 'UserNotificationBatch', COUNT(*) FROM "UserNotificationBatch"
|
||||
UNION ALL SELECT 'BuilderSearchHistory', COUNT(*) FROM "BuilderSearchHistory"
|
||||
UNION ALL SELECT 'PendingHumanReview', COUNT(*) FROM "PendingHumanReview"
|
||||
UNION ALL SELECT 'RefreshToken', COUNT(*) FROM "RefreshToken"
|
||||
UNION ALL SELECT 'PasswordResetToken', COUNT(*) FROM "PasswordResetToken"
|
||||
ORDER BY table_name;
|
||||
|
||||
-- ============================================
|
||||
-- AUTH DATA VERIFICATION
|
||||
-- ============================================
|
||||
|
||||
SELECT '=== AUTH DATA VERIFICATION ===' as section;
|
||||
|
||||
SELECT
|
||||
COUNT(*) as total_users,
|
||||
COUNT("passwordHash") as users_with_password,
|
||||
COUNT("googleId") as users_with_google,
|
||||
COUNT(CASE WHEN "emailVerified" = true THEN 1 END) as verified_emails,
|
||||
COUNT(CASE WHEN "passwordHash" IS NULL AND "googleId" IS NULL THEN 1 END) as users_without_auth
|
||||
FROM "User";
|
||||
|
||||
-- ============================================
|
||||
-- VIEW VERIFICATION
|
||||
-- ============================================
|
||||
|
||||
SELECT '=== VIEW VERIFICATION ===' as section;
|
||||
|
||||
SELECT 'StoreAgent' as view_name, COUNT(*) as row_count FROM "StoreAgent"
|
||||
UNION ALL SELECT 'Creator', COUNT(*) FROM "Creator"
|
||||
UNION ALL SELECT 'StoreSubmission', COUNT(*) FROM "StoreSubmission";
|
||||
|
||||
-- ============================================
|
||||
-- MATERIALIZED VIEW VERIFICATION
|
||||
-- ============================================
|
||||
|
||||
SELECT '=== MATERIALIZED VIEW VERIFICATION ===' as section;
|
||||
|
||||
SELECT 'mv_agent_run_counts' as view_name, COUNT(*) as row_count FROM "mv_agent_run_counts"
|
||||
UNION ALL SELECT 'mv_review_stats', COUNT(*) FROM "mv_review_stats";
|
||||
|
||||
-- ============================================
|
||||
-- FOREIGN KEY INTEGRITY CHECKS
|
||||
-- ============================================
|
||||
|
||||
SELECT '=== FOREIGN KEY INTEGRITY (should all be 0) ===' as section;
|
||||
|
||||
SELECT 'Orphaned Profiles' as check_name,
|
||||
COUNT(*) as count
|
||||
FROM "Profile" p
|
||||
WHERE p."userId" IS NOT NULL
|
||||
AND NOT EXISTS (SELECT 1 FROM "User" u WHERE u.id = p."userId");
|
||||
|
||||
SELECT 'Orphaned AgentGraphs' as check_name,
|
||||
COUNT(*) as count
|
||||
FROM "AgentGraph" g
|
||||
WHERE NOT EXISTS (SELECT 1 FROM "User" u WHERE u.id = g."userId");
|
||||
|
||||
SELECT 'Orphaned AgentNodes' as check_name,
|
||||
COUNT(*) as count
|
||||
FROM "AgentNode" n
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM "AgentGraph" g
|
||||
WHERE g.id = n."agentGraphId" AND g.version = n."agentGraphVersion"
|
||||
);
|
||||
|
||||
SELECT 'Orphaned Executions' as check_name,
|
||||
COUNT(*) as count
|
||||
FROM "AgentGraphExecution" e
|
||||
WHERE NOT EXISTS (SELECT 1 FROM "User" u WHERE u.id = e."userId");
|
||||
|
||||
SELECT 'Orphaned LibraryAgents' as check_name,
|
||||
COUNT(*) as count
|
||||
FROM "LibraryAgent" l
|
||||
WHERE NOT EXISTS (SELECT 1 FROM "User" u WHERE u.id = l."userId");
|
||||
|
||||
-- ============================================
|
||||
-- SAMPLE DATA VERIFICATION
|
||||
-- ============================================
|
||||
|
||||
SELECT '=== SAMPLE USERS (first 5) ===' as section;
|
||||
|
||||
SELECT
|
||||
id,
|
||||
email,
|
||||
"emailVerified",
|
||||
CASE WHEN "passwordHash" IS NOT NULL THEN 'YES' ELSE 'NO' END as has_password,
|
||||
CASE WHEN "googleId" IS NOT NULL THEN 'YES' ELSE 'NO' END as has_google,
|
||||
"createdAt"
|
||||
FROM "User"
|
||||
ORDER BY "createdAt" DESC
|
||||
LIMIT 5;
|
||||
|
||||
-- ============================================
|
||||
-- STORE LISTINGS SAMPLE
|
||||
-- ============================================
|
||||
|
||||
SELECT '=== SAMPLE STORE LISTINGS (first 5) ===' as section;
|
||||
|
||||
SELECT
|
||||
id,
|
||||
slug,
|
||||
"isDeleted",
|
||||
"hasApprovedVersion"
|
||||
FROM "StoreListing"
|
||||
LIMIT 5;
|
||||
@@ -1,286 +0,0 @@
|
||||
"""
|
||||
Verification script to check scheduler data integrity after native auth migration.
|
||||
|
||||
This script verifies that all scheduled jobs reference valid users in the platform.User table.
|
||||
It can also clean up orphaned schedules (schedules for users that no longer exist).
|
||||
|
||||
Usage:
|
||||
cd backend
|
||||
poetry run python scripts/verify_scheduler_data.py [options]
|
||||
|
||||
Options:
|
||||
--dry-run Preview what would be cleaned up without making changes
|
||||
--cleanup Actually remove orphaned schedules
|
||||
--database-url <url> Database URL (overrides DATABASE_URL env var)
|
||||
|
||||
Examples:
|
||||
# Check for orphaned schedules (read-only)
|
||||
poetry run python scripts/verify_scheduler_data.py
|
||||
|
||||
# Preview cleanup
|
||||
poetry run python scripts/verify_scheduler_data.py --dry-run
|
||||
|
||||
# Actually clean up orphaned schedules
|
||||
poetry run python scripts/verify_scheduler_data.py --cleanup
|
||||
|
||||
Prerequisites:
|
||||
- Database must be accessible
|
||||
- Scheduler service must be running (for cleanup operations)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from urllib.parse import parse_qs, urlparse, urlunparse, urlencode
|
||||
|
||||
from prisma import Prisma
|
||||
from sqlalchemy import create_engine, text, MetaData
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_schema_from_url(database_url: str) -> tuple[str, str]:
|
||||
"""Extract schema from DATABASE_URL and return (schema, clean_url)."""
|
||||
parsed_url = urlparse(database_url)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
schema_list = query_params.pop("schema", None)
|
||||
schema = schema_list[0] if schema_list else "public"
|
||||
new_query = urlencode(query_params, doseq=True)
|
||||
new_parsed_url = parsed_url._replace(query=new_query)
|
||||
database_url_clean = str(urlunparse(new_parsed_url))
|
||||
return schema, database_url_clean
|
||||
|
||||
|
||||
async def get_all_user_ids(db: Prisma) -> set[str]:
|
||||
"""Get all user IDs from the platform.User table."""
|
||||
users = await db.user.find_many(select={"id": True})
|
||||
return {user.id for user in users}
|
||||
|
||||
|
||||
def get_scheduler_jobs(db_url: str, schema: str) -> list[dict]:
|
||||
"""Get all jobs from the apscheduler_jobs table."""
|
||||
engine = create_engine(db_url)
|
||||
jobs = []
|
||||
|
||||
with engine.connect() as conn:
|
||||
# Check if table exists
|
||||
result = conn.execute(
|
||||
text(
|
||||
f"""
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_schema = :schema
|
||||
AND table_name = 'apscheduler_jobs'
|
||||
)
|
||||
"""
|
||||
),
|
||||
{"schema": schema},
|
||||
)
|
||||
if not result.scalar():
|
||||
logger.warning(
|
||||
f"Table {schema}.apscheduler_jobs does not exist. "
|
||||
"Scheduler may not have been initialized yet."
|
||||
)
|
||||
return []
|
||||
|
||||
# Get all jobs
|
||||
result = conn.execute(
|
||||
text(f'SELECT id, job_state FROM {schema}."apscheduler_jobs"')
|
||||
)
|
||||
|
||||
for row in result:
|
||||
job_id = row[0]
|
||||
job_state = row[1]
|
||||
|
||||
try:
|
||||
# APScheduler stores job state as pickled data
|
||||
job_data = pickle.loads(job_state)
|
||||
kwargs = job_data.get("kwargs", {})
|
||||
|
||||
# Only process graph execution jobs (have user_id)
|
||||
if "user_id" in kwargs:
|
||||
jobs.append(
|
||||
{
|
||||
"id": job_id,
|
||||
"user_id": kwargs.get("user_id"),
|
||||
"graph_id": kwargs.get("graph_id"),
|
||||
"graph_version": kwargs.get("graph_version"),
|
||||
"cron": kwargs.get("cron"),
|
||||
"agent_name": kwargs.get("agent_name"),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse job {job_id}: {e}")
|
||||
|
||||
return jobs
|
||||
|
||||
|
||||
async def verify_scheduler_data(
|
||||
db: Prisma, db_url: str, schema: str
|
||||
) -> tuple[list[dict], list[dict]]:
|
||||
"""
|
||||
Verify scheduler data integrity.
|
||||
|
||||
Returns:
|
||||
Tuple of (valid_jobs, orphaned_jobs)
|
||||
"""
|
||||
logger.info("Fetching all users from platform.User...")
|
||||
user_ids = await get_all_user_ids(db)
|
||||
logger.info(f"Found {len(user_ids)} users in platform.User")
|
||||
|
||||
logger.info("Fetching scheduled jobs from apscheduler_jobs...")
|
||||
jobs = get_scheduler_jobs(db_url, schema)
|
||||
logger.info(f"Found {len(jobs)} scheduled graph execution jobs")
|
||||
|
||||
valid_jobs = []
|
||||
orphaned_jobs = []
|
||||
|
||||
for job in jobs:
|
||||
if job["user_id"] in user_ids:
|
||||
valid_jobs.append(job)
|
||||
else:
|
||||
orphaned_jobs.append(job)
|
||||
|
||||
return valid_jobs, orphaned_jobs
|
||||
|
||||
|
||||
async def cleanup_orphaned_schedules(orphaned_jobs: list[dict], db_url: str, schema: str):
|
||||
"""Remove orphaned schedules from the database."""
|
||||
if not orphaned_jobs:
|
||||
logger.info("No orphaned schedules to clean up")
|
||||
return
|
||||
|
||||
engine = create_engine(db_url)
|
||||
|
||||
with engine.connect() as conn:
|
||||
for job in orphaned_jobs:
|
||||
try:
|
||||
conn.execute(
|
||||
text(f'DELETE FROM {schema}."apscheduler_jobs" WHERE id = :job_id'),
|
||||
{"job_id": job["id"]},
|
||||
)
|
||||
logger.info(
|
||||
f"Deleted orphaned schedule {job['id']} "
|
||||
f"(user: {job['user_id']}, graph: {job['graph_id']})"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete schedule {job['id']}: {e}")
|
||||
|
||||
conn.commit()
|
||||
|
||||
logger.info(f"Cleaned up {len(orphaned_jobs)} orphaned schedules")
|
||||
|
||||
|
||||
async def main(dry_run: bool = False, cleanup: bool = False):
|
||||
"""Run the verification."""
|
||||
logger.info("=" * 60)
|
||||
logger.info("Scheduler Data Verification Script")
|
||||
if dry_run:
|
||||
logger.info(">>> DRY RUN MODE - No changes will be made <<<")
|
||||
elif cleanup:
|
||||
logger.info(">>> CLEANUP MODE - Orphaned schedules will be removed <<<")
|
||||
else:
|
||||
logger.info(">>> VERIFY MODE - Read-only check <<<")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Started at: {datetime.now().isoformat()}")
|
||||
|
||||
# Get database URL
|
||||
db_url = os.getenv("DIRECT_URL") or os.getenv("DATABASE_URL")
|
||||
if not db_url:
|
||||
logger.error("DATABASE_URL or DIRECT_URL environment variable not set")
|
||||
sys.exit(1)
|
||||
|
||||
schema, clean_db_url = _extract_schema_from_url(db_url)
|
||||
logger.info(f"Using schema: {schema}")
|
||||
|
||||
db = Prisma()
|
||||
await db.connect()
|
||||
|
||||
try:
|
||||
valid_jobs, orphaned_jobs = await verify_scheduler_data(db, clean_db_url, schema)
|
||||
|
||||
# Report results
|
||||
logger.info("\n--- Verification Results ---")
|
||||
logger.info(f"Valid scheduled jobs: {len(valid_jobs)}")
|
||||
logger.info(f"Orphaned scheduled jobs: {len(orphaned_jobs)}")
|
||||
|
||||
if orphaned_jobs:
|
||||
logger.warning("\n--- Orphaned Schedules (users not in platform.User) ---")
|
||||
for job in orphaned_jobs:
|
||||
logger.warning(
|
||||
f" Schedule ID: {job['id']}\n"
|
||||
f" User ID: {job['user_id']}\n"
|
||||
f" Graph ID: {job['graph_id']}\n"
|
||||
f" Cron: {job['cron']}\n"
|
||||
f" Agent: {job['agent_name'] or 'N/A'}"
|
||||
)
|
||||
|
||||
if cleanup and not dry_run:
|
||||
logger.info("\n--- Cleaning up orphaned schedules ---")
|
||||
await cleanup_orphaned_schedules(orphaned_jobs, clean_db_url, schema)
|
||||
elif dry_run:
|
||||
logger.info(
|
||||
f"\n[DRY RUN] Would delete {len(orphaned_jobs)} orphaned schedules"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"\nTo clean up orphaned schedules, run with --cleanup flag"
|
||||
)
|
||||
else:
|
||||
logger.info("\n✅ All scheduled jobs reference valid users!")
|
||||
|
||||
# Summary
|
||||
logger.info("\n" + "=" * 60)
|
||||
if orphaned_jobs and cleanup and not dry_run:
|
||||
logger.info("Cleanup completed successfully!")
|
||||
else:
|
||||
logger.info("Verification completed!")
|
||||
logger.info("=" * 60)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Verification failed: {e}")
|
||||
raise
|
||||
finally:
|
||||
await db.disconnect()
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Verify scheduler data integrity after native auth migration"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Preview what would be cleaned up without making changes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cleanup",
|
||||
action="store_true",
|
||||
help="Actually remove orphaned schedules",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--database-url",
|
||||
type=str,
|
||||
help="Database URL (overrides DATABASE_URL env var)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
# Override DATABASE_URL if provided via command line
|
||||
if args.database_url:
|
||||
os.environ["DATABASE_URL"] = args.database_url
|
||||
os.environ["DIRECT_URL"] = args.database_url
|
||||
|
||||
asyncio.run(main(dry_run=args.dry_run, cleanup=args.cleanup))
|
||||
@@ -19,21 +19,21 @@ images: {
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
from typing import Any, Dict, List, cast
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from faker import Faker
|
||||
from prisma.types import AgentBlockCreateInput
|
||||
|
||||
from backend.data.auth.api_key import create_api_key
|
||||
from backend.data.api_key import create_api_key
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.db import prisma
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
|
||||
# Import API functions from the backend
|
||||
from backend.server.auth.service import AuthService
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.server.v2.library.db import create_library_agent, create_preset
|
||||
from backend.server.v2.library.model import LibraryAgentPresetCreatable
|
||||
from backend.server.v2.store.db import create_store_submission, review_store_submission
|
||||
from backend.util.clients import get_supabase
|
||||
|
||||
faker = Faker()
|
||||
|
||||
@@ -107,10 +107,10 @@ class TestDataCreator:
|
||||
self.profiles: List[Dict[str, Any]] = []
|
||||
|
||||
async def create_test_users(self) -> List[Dict[str, Any]]:
|
||||
"""Create test users using native auth service."""
|
||||
"""Create test users using Supabase client."""
|
||||
print(f"Creating {NUM_USERS} test users...")
|
||||
|
||||
auth_service = AuthService()
|
||||
supabase = get_supabase()
|
||||
users = []
|
||||
|
||||
for i in range(NUM_USERS):
|
||||
@@ -122,35 +122,30 @@ class TestDataCreator:
|
||||
else:
|
||||
email = faker.unique.email()
|
||||
password = "testpassword123" # Standard test password
|
||||
user_id = f"test-user-{i}-{faker.uuid4()}"
|
||||
|
||||
# Try to create user with password using AuthService
|
||||
# Create user in Supabase Auth (if needed)
|
||||
try:
|
||||
user = await auth_service.register_user(
|
||||
email=email,
|
||||
password=password,
|
||||
name=faker.name(),
|
||||
auth_response = supabase.auth.admin.create_user(
|
||||
{"email": email, "password": password, "email_confirm": True}
|
||||
)
|
||||
users.append(
|
||||
{
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
}
|
||||
if auth_response.user:
|
||||
user_id = auth_response.user.id
|
||||
except Exception as supabase_error:
|
||||
print(
|
||||
f"Supabase user creation failed for {email}, using fallback: {supabase_error}"
|
||||
)
|
||||
except ValueError as e:
|
||||
# User already exists, get them instead
|
||||
print(f"User {email} already exists, fetching: {e}")
|
||||
existing_user = await auth_service.get_user_by_email(email)
|
||||
if existing_user:
|
||||
users.append(
|
||||
{
|
||||
"id": existing_user.id,
|
||||
"email": existing_user.email,
|
||||
"name": existing_user.name,
|
||||
"role": existing_user.role,
|
||||
}
|
||||
)
|
||||
# Fall back to direct database creation
|
||||
|
||||
# Create mock user data similar to what auth middleware would provide
|
||||
user_data = {
|
||||
"sub": user_id,
|
||||
"email": email,
|
||||
}
|
||||
|
||||
# Use the API function to create user in local database
|
||||
user = await get_or_create_user(user_data)
|
||||
users.append(user.model_dump())
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating user {i}: {e}")
|
||||
@@ -182,15 +177,12 @@ class TestDataCreator:
|
||||
for block in blocks_to_create:
|
||||
try:
|
||||
await prisma.agentblock.create(
|
||||
data=cast(
|
||||
AgentBlockCreateInput,
|
||||
{
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"inputSchema": "{}",
|
||||
"outputSchema": "{}",
|
||||
},
|
||||
)
|
||||
data={
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"inputSchema": "{}",
|
||||
"outputSchema": "{}",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error creating block {block.name}: {e}")
|
||||
@@ -472,7 +464,7 @@ class TestDataCreator:
|
||||
|
||||
api_keys = []
|
||||
for user in self.users:
|
||||
from backend.data.auth.api_key import APIKeyPermission
|
||||
from backend.data.api_key import APIKeyPermission
|
||||
|
||||
try:
|
||||
# Use the API function to create API key
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user