Compare commits


7 Commits

Author | SHA1 | Message | Date
Swifty | a86d750cf5 | Merge branch 'fix/integrations-credential-type' into swiftyos/dev | 2025-12-04 16:14:51 +01:00
Swifty | 13bd648731 | Merge branch 'swiftyos/vector-search' into swiftyos/dev | 2025-12-04 16:14:47 +01:00
Swifty | 3d7ee7cc29 | Merge branch 'swiftyos/add-default-agents' into swiftyos/dev | 2025-12-04 16:14:44 +01:00
Swifty | 1ea52934cd | add store agents for seeding test databases | 2025-12-04 16:07:58 +01:00
Swifty | 7b6db6e260 | add vector search | 2025-12-04 16:05:47 +01:00
Swifty | 2c9563353e | formatting | 2025-12-04 09:35:53 +01:00
Swifty | fb2a70e2d8 | pass credential type | 2025-12-04 09:21:12 +01:00
434 changed files with 8723 additions and 35775 deletions

View File

@@ -142,7 +142,7 @@ pnpm storybook # Start component development server
### Security & Middleware
**Cache Protection**: Backend includes middleware preventing sensitive data caching in browsers/proxies
**Authentication**: JWT-based with native authentication
**Authentication**: JWT-based with Supabase integration
**User ID Validation**: All data access requires user ID checks - verify this for any `data/*.py` changes
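The cache-protection middleware described above can be sketched as follows. This is an illustration only (route and header values are assumed, not the platform's actual middleware):

```python
# Minimal sketch of cache-protection middleware (assumed, not platform code):
# every response gets no-store headers so browsers/proxies never cache
# sensitive API data.
from fastapi import FastAPI, Request

app = FastAPI()

@app.middleware("http")
async def cache_protection(request: Request, call_next):
    response = await call_next(request)
    response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
    response.headers["Pragma"] = "no-cache"
    return response

@app.get("/health")
async def health():
    return {"status": "ok"}
```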
### Development Workflow
@@ -168,9 +168,9 @@ pnpm storybook # Start component development server
- `frontend/src/app/layout.tsx` - Root application layout
- `frontend/src/app/page.tsx` - Home page
- `frontend/src/lib/auth/` - Authentication client
- `frontend/src/lib/supabase/` - Authentication and database client
**Protected Routes**: Update `frontend/middleware.ts` when adding protected routes
**Protected Routes**: Update `frontend/lib/supabase/middleware.ts` when adding protected routes
### Agent Block System
@@ -194,7 +194,7 @@ Agents are built using a visual block-based system where each block performs a s
1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
3. **Platform**: `/.env.default` (shared) → `/.env` (user overrides)
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
4. Docker Compose `environment:` sections override file-based config
5. Shell environment variables have highest precedence
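The precedence chain above can be sketched in Python; the parsing helper below is illustrative (an assumption, not platform code) and only demonstrates that later sources override earlier ones:

```python
# Illustrative resolution of the env precedence above: defaults first, then
# user overrides, then process environment (docker-compose `environment:`
# sections and shell variables both arrive via os.environ).
import os

def load_env_file(path: str) -> dict[str, str]:
    values: dict[str, str] = {}
    try:
        with open(path) as f:
            for raw in f:
                line = raw.strip()
                if line and not line.startswith("#") and "=" in line:
                    key, _, value = line.partition("=")
                    values[key.strip()] = value.strip()
    except FileNotFoundError:
        pass
    return values

def resolve_config() -> dict[str, str]:
    config: dict[str, str] = {}
    for path in (".env.default", ".env"):  # lower precedence first
        config.update(load_env_file(path))
    config.update(os.environ)  # highest precedence
    return config
```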

View File

@@ -144,7 +144,11 @@ jobs:
"rabbitmq:management"
"clamav/clamav-debian:latest"
"busybox:latest"
"pgvector/pgvector:pg18"
"kong:2.8.1"
"supabase/gotrue:v2.170.0"
"supabase/postgres:15.8.1.049"
"supabase/postgres-meta:v0.86.1"
"supabase/studio:20250224-d10db0f"
)
# Check if any cached tar files exist (more reliable than cache-hit)

View File

@@ -160,7 +160,11 @@ jobs:
"rabbitmq:management"
"clamav/clamav-debian:latest"
"busybox:latest"
"pgvector/pgvector:pg18"
"kong:2.8.1"
"supabase/gotrue:v2.170.0"
"supabase/postgres:15.8.1.049"
"supabase/postgres-meta:v0.86.1"
"supabase/studio:20250224-d10db0f"
)
# Check if any cached tar files exist (more reliable than cache-hit)

View File

@@ -142,7 +142,11 @@ jobs:
"rabbitmq:management"
"clamav/clamav-debian:latest"
"busybox:latest"
"pgvector/pgvector:pg18"
"kong:2.8.1"
"supabase/gotrue:v2.170.0"
"supabase/postgres:15.8.1.049"
"supabase/postgres-meta:v0.86.1"
"supabase/studio:20250224-d10db0f"
)
# Check if any cached tar files exist (more reliable than cache-hit)

View File

@@ -2,13 +2,13 @@ name: AutoGPT Platform - Backend CI
on:
push:
branches: [master, dev, ci-test*, native-auth]
branches: [master, dev, ci-test*]
paths:
- ".github/workflows/platform-backend-ci.yml"
- "autogpt_platform/backend/**"
- "autogpt_platform/autogpt_libs/**"
pull_request:
branches: [master, dev, release-*, native-auth]
branches: [master, dev, release-*]
paths:
- ".github/workflows/platform-backend-ci.yml"
- "autogpt_platform/backend/**"
@@ -36,19 +36,6 @@ jobs:
runs-on: ubuntu-latest
services:
postgres:
image: pgvector/pgvector:pg18
ports:
- 5432:5432
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: your-super-secret-and-long-postgres-password
POSTGRES_DB: postgres
options: >-
--health-cmd "pg_isready -U postgres"
--health-interval 5s
--health-timeout 5s
--health-retries 10
redis:
image: redis:latest
ports:
@@ -91,6 +78,11 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Setup Supabase
uses: supabase/setup-cli@v1
with:
version: 1.178.1
- id: get_date
name: Get date
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
@@ -144,6 +136,16 @@ jobs:
- name: Generate Prisma Client
run: poetry run prisma generate
- id: supabase
name: Start Supabase
working-directory: .
run: |
supabase init
supabase start --exclude postgres-meta,realtime,storage-api,imgproxy,inbucket,studio,edge-runtime,logflare,vector,supavisor
supabase status -o env | sed 's/="/=/; s/"$//' >> $GITHUB_OUTPUT
# outputs:
# DB_URL, API_URL, GRAPHQL_URL, ANON_KEY, SERVICE_ROLE_KEY, JWT_SECRET
- name: Wait for ClamAV to be ready
run: |
echo "Waiting for ClamAV daemon to start..."
@@ -176,8 +178,8 @@ jobs:
- name: Run Database Migrations
run: poetry run prisma migrate dev --name updates
env:
DATABASE_URL: postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/postgres
DIRECT_URL: postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/postgres
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
- id: lint
name: Run Linter
@@ -193,9 +195,11 @@ jobs:
if: success() || (failure() && steps.lint.outcome == 'failure')
env:
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
DATABASE_URL: postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/postgres
DIRECT_URL: postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/postgres
JWT_SECRET: your-super-secret-jwt-token-with-at-least-32-characters-long
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
REDIS_HOST: "localhost"
REDIS_PORT: "6379"
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!

View File

@@ -2,12 +2,11 @@ name: AutoGPT Platform - Frontend CI
on:
push:
branches: [master, dev, native-auth]
branches: [master, dev]
paths:
- ".github/workflows/platform-frontend-ci.yml"
- "autogpt_platform/frontend/**"
pull_request:
branches: [master, dev, native-auth]
paths:
- ".github/workflows/platform-frontend-ci.yml"
- "autogpt_platform/frontend/**"
@@ -148,7 +147,7 @@ jobs:
- name: Enable corepack
run: corepack enable
- name: Copy default platform .env
- name: Copy default supabase .env
run: |
cp ../.env.default ../.env

View File

@@ -1,13 +1,12 @@
name: AutoGPT Platform - Fullstack CI
name: AutoGPT Platform - Frontend CI
on:
push:
branches: [master, dev, native-auth]
branches: [master, dev]
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- "autogpt_platform/**"
pull_request:
branches: [master, dev, native-auth]
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- "autogpt_platform/**"
@@ -59,11 +58,14 @@ jobs:
types:
runs-on: ubuntu-latest
needs: setup
timeout-minutes: 10
strategy:
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v4
@@ -73,6 +75,18 @@ jobs:
- name: Enable corepack
run: corepack enable
- name: Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Copy backend .env
run: |
cp ../backend/.env.default ../backend/.env
- name: Run docker compose
run: |
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
- name: Restore dependencies cache
uses: actions/cache@v4
with:
@@ -87,12 +101,36 @@ jobs:
- name: Setup .env
run: cp .env.default .env
- name: Wait for services to be ready
run: |
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
- name: Generate API queries
run: pnpm generate:api
run: pnpm generate:api:force
- name: Check for API schema changes
run: |
if ! git diff --exit-code src/app/api/openapi.json; then
echo "❌ API schema changes detected in src/app/api/openapi.json"
echo ""
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
echo "The API schema is now out of sync with the Front-end queries."
echo ""
echo "To fix this:"
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
echo "2. Run 'pnpm generate:api' locally"
echo "3. Run 'pnpm types' locally"
echo "4. Fix any TypeScript errors that may have been introduced"
echo "5. Commit and push your changes"
echo ""
exit 1
else
echo "✅ No API schema changes detected"
fi
- name: Run Typescript checks
run: pnpm types
env:
CI: true
PLAIN_OUTPUT: True

View File

@@ -11,7 +11,7 @@ jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v10
- uses: actions/stale@v9
with:
# operations-per-run: 5000
stale-issue-message: >

View File

@@ -61,6 +61,6 @@ jobs:
pull-requests: write
runs-on: ubuntu-latest
steps:
- uses: actions/labeler@v6
- uses: actions/labeler@v5
with:
sync-labels: true

View File

@@ -49,5 +49,5 @@ Use conventional commit messages for all commits (e.g. `feat(backend): add API`)
- Keep out-of-scope changes under 20% of the PR.
- Ensure PR descriptions are complete.
- For changes touching `data/*.py`, validate user ID checks or explain why not needed.
- If adding protected frontend routes, update `frontend/lib/auth/helpers.ts`.
- If adding protected frontend routes, update `frontend/lib/supabase/middleware.ts`.
- Use the Linear ticket branch structure if given (e.g. `codex/open-1668-resume-dropped-runs`).

View File

@@ -5,6 +5,12 @@
POSTGRES_PASSWORD=your-super-secret-and-long-postgres-password
JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
DASHBOARD_USERNAME=supabase
DASHBOARD_PASSWORD=this_password_is_insecure_and_should_be_updated
SECRET_KEY_BASE=UpNVntn3cDxHJpq99YMc1T1AQgQpc8kfYTuRgBiYa15BLrx8etQoXz3gZv1/u2oq
VAULT_ENC_KEY=your-encryption-key-32-chars-min
############
@@ -18,31 +24,100 @@ POSTGRES_PORT=5432
############
# Auth - Native authentication configuration
# Supavisor -- Database pooler
############
POOLER_PROXY_PORT_TRANSACTION=6543
POOLER_DEFAULT_POOL_SIZE=20
POOLER_MAX_CLIENT_CONN=100
POOLER_TENANT_ID=your-tenant-id
############
# API Proxy - Configuration for the Kong Reverse proxy.
############
KONG_HTTP_PORT=8000
KONG_HTTPS_PORT=8443
############
# API - Configuration for PostgREST.
############
PGRST_DB_SCHEMAS=public,storage,graphql_public
############
# Auth - Configuration for the GoTrue authentication server.
############
## General
SITE_URL=http://localhost:3000
ADDITIONAL_REDIRECT_URLS=
JWT_EXPIRY=3600
DISABLE_SIGNUP=false
API_EXTERNAL_URL=http://localhost:8000
# JWT token configuration
ACCESS_TOKEN_EXPIRE_MINUTES=15
REFRESH_TOKEN_EXPIRE_DAYS=7
JWT_ISSUER=autogpt-platform
## Mailer Config
MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify"
MAILER_URLPATHS_INVITE="/auth/v1/verify"
MAILER_URLPATHS_RECOVERY="/auth/v1/verify"
MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify"
# Google OAuth (optional)
GOOGLE_CLIENT_ID=
GOOGLE_CLIENT_SECRET=
## Email auth
ENABLE_EMAIL_SIGNUP=true
ENABLE_EMAIL_AUTOCONFIRM=false
SMTP_ADMIN_EMAIL=admin@example.com
SMTP_HOST=supabase-mail
SMTP_PORT=2500
SMTP_USER=fake_mail_user
SMTP_PASS=fake_mail_password
SMTP_SENDER_NAME=fake_sender
ENABLE_ANONYMOUS_USERS=false
## Phone auth
ENABLE_PHONE_SIGNUP=true
ENABLE_PHONE_AUTOCONFIRM=true
############
# Email configuration (optional)
# Studio - Configuration for the Dashboard
############
SMTP_HOST=
SMTP_PORT=587
SMTP_USER=
SMTP_PASS=
SMTP_FROM_EMAIL=noreply@example.com
STUDIO_DEFAULT_ORGANIZATION=Default Organization
STUDIO_DEFAULT_PROJECT=Default Project
STUDIO_PORT=3000
# replace if you intend to use Studio outside of localhost
SUPABASE_PUBLIC_URL=http://localhost:8000
# Enable webp support
IMGPROXY_ENABLE_WEBP_DETECTION=true
# Add your OpenAI API key to enable SQL Editor Assistant
OPENAI_API_KEY=
############
# Functions - Configuration for Functions
############
# NOTE: VERIFY_JWT applies to all functions. Per-function VERIFY_JWT is not supported yet.
FUNCTIONS_VERIFY_JWT=false
############
# Logs - Configuration for Logflare
# Please refer to https://supabase.com/docs/reference/self-hosting-analytics/introduction
############
LOGFLARE_LOGGER_BACKEND_API_KEY=your-super-secret-and-long-logflare-key
# Change vector.toml sinks to reflect this change
LOGFLARE_API_KEY=your-super-secret-and-long-logflare-key
# Docker socket location - this value will differ depending on your OS
DOCKER_SOCKET_LOCATION=/var/run/docker.sock
# Google Cloud Project details
GOOGLE_PROJECT_ID=GOOGLE_PROJECT_ID
GOOGLE_PROJECT_NUMBER=GOOGLE_PROJECT_NUMBER

View File

@@ -1,6 +1,6 @@
.PHONY: start-core stop-core logs-core format lint migrate run-backend run-frontend load-store-agents
# Run just PostgreSQL + Redis + RabbitMQ + ClamAV
# Run just Supabase + Redis + RabbitMQ
start-core:
docker compose up -d deps
@@ -44,12 +44,12 @@ test-data:
cd backend && poetry run python test/test_data_creator.py
load-store-agents:
cd backend && poetry run load-store-agents
cd backend && poetry run python test/load_store_agents.py
help:
@echo "Usage: make <target>"
@echo "Targets:"
@echo " start-core - Start just the core services (PostgreSQL, Redis, RabbitMQ, ClamAV) in background"
@echo " start-core - Start just the core services (Supabase, Redis, RabbitMQ) in background"
@echo " stop-core - Stop the core services"
@echo " reset-db - Reset the database by deleting the volume"
@echo " logs-core - Tail the logs for core services"

View File

@@ -57,9 +57,6 @@ class APIKeySmith:
def hash_key(self, raw_key: str) -> tuple[str, str]:
"""Migrate a legacy hash to secure hash format."""
if not raw_key.startswith(self.PREFIX):
raise ValueError("Key without 'agpt_' prefix would fail validation")
salt = self._generate_salt()
hash = self._hash_key_with_salt(raw_key, salt)
return hash, salt.hex()
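The salt helpers used by `hash_key` are not shown in this diff; one plausible shape for them, purely as an illustration:

```python
# Plausible implementations of the helpers above (assumptions -- the actual
# _generate_salt/_hash_key_with_salt are not part of this diff).
import hashlib
import secrets

def _generate_salt() -> bytes:
    return secrets.token_bytes(16)

def _hash_key_with_salt(raw_key: str, salt: bytes) -> str:
    # Salted SHA-256; the salt is returned alongside the hash (salt.hex()).
    return hashlib.sha256(salt + raw_key.encode()).hexdigest()
```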

View File

@@ -16,37 +16,17 @@ ALGO_RECOMMENDATION = (
"We highly recommend using an asymmetric algorithm such as ES256, "
"because when leaked, a shared secret would allow anyone to "
"forge valid tokens and impersonate users. "
"More info: https://pyjwt.readthedocs.io/en/stable/algorithms.html"
"More info: https://supabase.com/docs/guides/auth/signing-keys#choosing-the-right-signing-algorithm" # noqa
)
class Settings:
def __init__(self):
# JWT verification key (public key for asymmetric, shared secret for symmetric)
self.JWT_VERIFY_KEY: str = os.getenv(
"JWT_VERIFY_KEY", os.getenv("SUPABASE_JWT_SECRET", "")
).strip()
# JWT signing key (private key for asymmetric, shared secret for symmetric)
# Falls back to JWT_VERIFY_KEY for symmetric algorithms like HS256
self.JWT_SIGN_KEY: str = os.getenv("JWT_SIGN_KEY", self.JWT_VERIFY_KEY).strip()
self.JWT_ALGORITHM: str = os.getenv("JWT_SIGN_ALGORITHM", "HS256").strip()
# Token expiration settings
self.ACCESS_TOKEN_EXPIRE_MINUTES: int = int(
os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "15")
)
self.REFRESH_TOKEN_EXPIRE_DAYS: int = int(
os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7")
)
# JWT issuer claim
self.JWT_ISSUER: str = os.getenv("JWT_ISSUER", "autogpt-platform").strip()
# JWT audience claim
self.JWT_AUDIENCE: str = os.getenv("JWT_AUDIENCE", "authenticated").strip()
self.validate()
def validate(self):
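For reference, a minimal sketch of the symmetric (HS256) configuration these settings describe, where `JWT_SIGN_KEY` and `JWT_VERIFY_KEY` are the same shared secret. With the recommended asymmetric setup (e.g. ES256), you would sign with the private key and verify with its public key instead:

```python
# Illustrative PyJWT round-trip using the defaults from Settings above
# (HS256, issuer "autogpt-platform", audience "authenticated").
from datetime import datetime, timedelta, timezone

import jwt  # PyJWT

SECRET = "your-super-secret-jwt-token-with-at-least-32-characters-long"
now = datetime.now(timezone.utc)

token = jwt.encode(
    {
        "sub": "user-123",
        "aud": "authenticated",
        "iss": "autogpt-platform",
        "iat": now,
        "exp": now + timedelta(minutes=15),
    },
    SECRET,
    algorithm="HS256",
)

payload = jwt.decode(
    token,
    SECRET,
    algorithms=["HS256"],
    audience="authenticated",
    issuer="autogpt-platform",
)
assert payload["sub"] == "user-123"
```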

View File

@@ -1,8 +1,4 @@
import hashlib
import logging
import secrets
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any
import jwt
@@ -20,57 +16,6 @@ bearer_jwt_auth = HTTPBearer(
)
def create_access_token(
user_id: str,
email: str,
role: str = "authenticated",
email_verified: bool = False,
) -> str:
"""
Generate a new JWT access token.
:param user_id: The user's unique identifier
:param email: The user's email address
:param role: The user's role (default: "authenticated")
:param email_verified: Whether the user's email is verified
:return: Encoded JWT token
"""
settings = get_settings()
now = datetime.now(timezone.utc)
payload = {
"sub": user_id,
"email": email,
"role": role,
"email_verified": email_verified,
"aud": settings.JWT_AUDIENCE,
"iss": settings.JWT_ISSUER,
"iat": now,
"exp": now + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
"jti": str(uuid.uuid4()), # Unique token ID
}
return jwt.encode(payload, settings.JWT_SIGN_KEY, algorithm=settings.JWT_ALGORITHM)
def create_refresh_token() -> tuple[str, str]:
"""
Generate a new refresh token.
Returns a tuple of (raw_token, hashed_token).
The raw token should be sent to the client.
The hashed token should be stored in the database.
"""
raw_token = secrets.token_urlsafe(64)
hashed_token = hashlib.sha256(raw_token.encode()).hexdigest()
return raw_token, hashed_token
def hash_token(token: str) -> str:
"""Hash a token using SHA-256."""
return hashlib.sha256(token.encode()).hexdigest()
async def get_jwt_payload(
credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
) -> dict[str, Any]:
@@ -107,19 +52,11 @@ def parse_jwt_token(token: str) -> dict[str, Any]:
"""
settings = get_settings()
try:
# Build decode options
options = {
"verify_aud": True,
"verify_iss": bool(settings.JWT_ISSUER),
}
payload = jwt.decode(
token,
settings.JWT_VERIFY_KEY,
algorithms=[settings.JWT_ALGORITHM],
audience=settings.JWT_AUDIENCE,
issuer=settings.JWT_ISSUER if settings.JWT_ISSUER else None,
options=options,
audience="authenticated",
)
return payload
except jwt.ExpiredSignatureError:
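The refresh-token scheme removed above stores only a SHA-256 hash of the raw token. A sketch of how verification could then work server-side (an illustration, not code from this repository):

```python
# Sketch: recompute the hash of the presented token and compare it to the
# stored hash in constant time (hmac.compare_digest avoids timing leaks).
import hashlib
import hmac
import secrets

def create_refresh_token() -> tuple[str, str]:
    raw = secrets.token_urlsafe(64)  # sent to the client
    hashed = hashlib.sha256(raw.encode()).hexdigest()  # stored in the DB
    return raw, hashed

def verify_refresh_token(presented: str, stored_hash: str) -> bool:
    candidate = hashlib.sha256(presented.encode()).hexdigest()
    return hmac.compare_digest(candidate, stored_hash)

raw, stored = create_refresh_token()
assert verify_refresh_token(raw, stored)
```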

View File

@@ -11,7 +11,6 @@ class User:
email: str
phone_number: str
role: str
email_verified: bool = False
@classmethod
def from_payload(cls, payload):
@@ -19,6 +18,5 @@ class User:
user_id=payload["sub"],
email=payload.get("email", ""),
phone_number=payload.get("phone", ""),
role=payload.get("role", "authenticated"),
email_verified=payload.get("email_verified", False),
role=payload["role"],
)

View File

@@ -48,21 +48,6 @@ files = [
{file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
]
[[package]]
name = "authlib"
version = "1.6.6"
description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "authlib-1.6.6-py2.py3-none-any.whl", hash = "sha256:7d9e9bc535c13974313a87f53e8430eb6ea3d1cf6ae4f6efcd793f2e949143fd"},
{file = "authlib-1.6.6.tar.gz", hash = "sha256:45770e8e056d0f283451d9996fbb59b70d45722b45d854d58f32878d0a40c38e"},
]
[package.dependencies]
cryptography = "*"
[[package]]
name = "backports-asyncio-runner"
version = "1.2.0"
@@ -76,71 +61,6 @@ files = [
{file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
]
[[package]]
name = "bcrypt"
version = "4.3.0"
description = "Modern password hashing for your software and your servers"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "bcrypt-4.3.0-cp313-cp313t-macosx_10_12_universal2.whl", hash = "sha256:f01e060f14b6b57bbb72fc5b4a83ac21c443c9a2ee708e04a10e9192f90a6281"},
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5eeac541cefd0bb887a371ef73c62c3cd78535e4887b310626036a7c0a817bb"},
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59e1aa0e2cd871b08ca146ed08445038f42ff75968c7ae50d2fdd7860ade2180"},
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:0042b2e342e9ae3d2ed22727c1262f76cc4f345683b5c1715f0250cf4277294f"},
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74a8d21a09f5e025a9a23e7c0fd2c7fe8e7503e4d356c0a2c1486ba010619f09"},
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:0142b2cb84a009f8452c8c5a33ace5e3dfec4159e7735f5afe9a4d50a8ea722d"},
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_34_aarch64.whl", hash = "sha256:12fa6ce40cde3f0b899729dbd7d5e8811cb892d31b6f7d0334a1f37748b789fd"},
{file = "bcrypt-4.3.0-cp313-cp313t-manylinux_2_34_x86_64.whl", hash = "sha256:5bd3cca1f2aa5dbcf39e2aa13dd094ea181f48959e1071265de49cc2b82525af"},
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:335a420cfd63fc5bc27308e929bee231c15c85cc4c496610ffb17923abf7f231"},
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:0e30e5e67aed0187a1764911af023043b4542e70a7461ad20e837e94d23e1d6c"},
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:3b8d62290ebefd49ee0b3ce7500f5dbdcf13b81402c05f6dafab9a1e1b27212f"},
{file = "bcrypt-4.3.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2ef6630e0ec01376f59a006dc72918b1bf436c3b571b80fa1968d775fa02fe7d"},
{file = "bcrypt-4.3.0-cp313-cp313t-win32.whl", hash = "sha256:7a4be4cbf241afee43f1c3969b9103a41b40bcb3a3f467ab19f891d9bc4642e4"},
{file = "bcrypt-4.3.0-cp313-cp313t-win_amd64.whl", hash = "sha256:5c1949bf259a388863ced887c7861da1df681cb2388645766c89fdfd9004c669"},
{file = "bcrypt-4.3.0-cp38-abi3-macosx_10_12_universal2.whl", hash = "sha256:f81b0ed2639568bf14749112298f9e4e2b28853dab50a8b357e31798686a036d"},
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:864f8f19adbe13b7de11ba15d85d4a428c7e2f344bac110f667676a0ff84924b"},
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e36506d001e93bffe59754397572f21bb5dc7c83f54454c990c74a468cd589e"},
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:842d08d75d9fe9fb94b18b071090220697f9f184d4547179b60734846461ed59"},
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7c03296b85cb87db865d91da79bf63d5609284fc0cab9472fdd8367bbd830753"},
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:62f26585e8b219cdc909b6a0069efc5e4267e25d4a3770a364ac58024f62a761"},
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:beeefe437218a65322fbd0069eb437e7c98137e08f22c4660ac2dc795c31f8bb"},
{file = "bcrypt-4.3.0-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:97eea7408db3a5bcce4a55d13245ab3fa566e23b4c67cd227062bb49e26c585d"},
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:191354ebfe305e84f344c5964c7cd5f924a3bfc5d405c75ad07f232b6dffb49f"},
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:41261d64150858eeb5ff43c753c4b216991e0ae16614a308a15d909503617732"},
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:33752b1ba962ee793fa2b6321404bf20011fe45b9afd2a842139de3011898fef"},
{file = "bcrypt-4.3.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:50e6e80a4bfd23a25f5c05b90167c19030cf9f87930f7cb2eacb99f45d1c3304"},
{file = "bcrypt-4.3.0-cp38-abi3-win32.whl", hash = "sha256:67a561c4d9fb9465ec866177e7aebcad08fe23aaf6fbd692a6fab69088abfc51"},
{file = "bcrypt-4.3.0-cp38-abi3-win_amd64.whl", hash = "sha256:584027857bc2843772114717a7490a37f68da563b3620f78a849bcb54dc11e62"},
{file = "bcrypt-4.3.0-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:0d3efb1157edebfd9128e4e46e2ac1a64e0c1fe46fb023158a407c7892b0f8c3"},
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08bacc884fd302b611226c01014eca277d48f0a05187666bca23aac0dad6fe24"},
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6746e6fec103fcd509b96bacdfdaa2fbde9a553245dbada284435173a6f1aef"},
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:afe327968aaf13fc143a56a3360cb27d4ad0345e34da12c7290f1b00b8fe9a8b"},
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:d9af79d322e735b1fc33404b5765108ae0ff232d4b54666d46730f8ac1a43676"},
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f1e3ffa1365e8702dc48c8b360fef8d7afeca482809c5e45e653af82ccd088c1"},
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:3004df1b323d10021fda07a813fd33e0fd57bef0e9a480bb143877f6cba996fe"},
{file = "bcrypt-4.3.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:531457e5c839d8caea9b589a1bcfe3756b0547d7814e9ce3d437f17da75c32b0"},
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:17a854d9a7a476a89dcef6c8bd119ad23e0f82557afbd2c442777a16408e614f"},
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:6fb1fd3ab08c0cbc6826a2e0447610c6f09e983a281b919ed721ad32236b8b23"},
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e965a9c1e9a393b8005031ff52583cedc15b7884fce7deb8b0346388837d6cfe"},
{file = "bcrypt-4.3.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:79e70b8342a33b52b55d93b3a59223a844962bef479f6a0ea318ebbcadf71505"},
{file = "bcrypt-4.3.0-cp39-abi3-win32.whl", hash = "sha256:b4d4e57f0a63fd0b358eb765063ff661328f69a04494427265950c71b992a39a"},
{file = "bcrypt-4.3.0-cp39-abi3-win_amd64.whl", hash = "sha256:e53e074b120f2877a35cc6c736b8eb161377caae8925c17688bd46ba56daaa5b"},
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c950d682f0952bafcceaf709761da0a32a942272fad381081b51096ffa46cea1"},
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:107d53b5c67e0bbc3f03ebf5b030e0403d24dda980f8e244795335ba7b4a027d"},
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:b693dbb82b3c27a1604a3dff5bfc5418a7e6a781bb795288141e5f80cf3a3492"},
{file = "bcrypt-4.3.0-pp310-pypy310_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:b6354d3760fcd31994a14c89659dee887f1351a06e5dac3c1142307172a79f90"},
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:a839320bf27d474e52ef8cb16449bb2ce0ba03ca9f44daba6d93fa1d8828e48a"},
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:bdc6a24e754a555d7316fa4774e64c6c3997d27ed2d1964d55920c7c227bc4ce"},
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:55a935b8e9a1d2def0626c4269db3fcd26728cbff1e84f0341465c31c4ee56d8"},
{file = "bcrypt-4.3.0-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:57967b7a28d855313a963aaea51bf6df89f833db4320da458e5b3c5ab6d4c938"},
{file = "bcrypt-4.3.0.tar.gz", hash = "sha256:3a3fd2204178b6d2adcf09cb4f6426ffef54762577a7c9b54c159008cb288c18"},
]
[package.extras]
tests = ["pytest (>=3.2.1,!=3.3.0)"]
typecheck = ["mypy"]
[[package]]
name = "cachetools"
version = "5.5.2"
@@ -539,6 +459,21 @@ ssh = ["bcrypt (>=3.1.5)"]
test = ["certifi (>=2024)", "cryptography-vectors (==45.0.6)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
test-randomorder = ["pytest-randomly"]
[[package]]
name = "deprecation"
version = "2.1.0"
description = "A library to handle automated deprecations"
optional = false
python-versions = "*"
groups = ["main"]
files = [
{file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"},
{file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"},
]
[package.dependencies]
packaging = "*"
[[package]]
name = "exceptiongroup"
version = "1.3.0"
@@ -760,6 +695,23 @@ protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4
[package.extras]
grpc = ["grpcio (>=1.44.0,<2.0.0)"]
[[package]]
name = "gotrue"
version = "2.12.3"
description = "Python Client Library for Supabase Auth"
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
{file = "gotrue-2.12.3-py3-none-any.whl", hash = "sha256:b1a3c6a5fe3f92e854a026c4c19de58706a96fd5fbdcc3d620b2802f6a46a26b"},
{file = "gotrue-2.12.3.tar.gz", hash = "sha256:f874cf9d0b2f0335bfbd0d6e29e3f7aff79998cd1c14d2ad814db8c06cee3852"},
]
[package.dependencies]
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
pydantic = ">=1.10,<3"
pyjwt = ">=2.10.1,<3.0.0"
[[package]]
name = "grpc-google-iam-v1"
version = "0.14.2"
@@ -870,6 +822,94 @@ files = [
{file = "h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1"},
]
[[package]]
name = "h2"
version = "4.2.0"
description = "Pure-Python HTTP/2 protocol implementation"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "h2-4.2.0-py3-none-any.whl", hash = "sha256:479a53ad425bb29af087f3458a61d30780bc818e4ebcf01f0b536ba916462ed0"},
{file = "h2-4.2.0.tar.gz", hash = "sha256:c8a52129695e88b1a0578d8d2cc6842bbd79128ac685463b887ee278126ad01f"},
]
[package.dependencies]
hpack = ">=4.1,<5"
hyperframe = ">=6.1,<7"
[[package]]
name = "hpack"
version = "4.1.0"
description = "Pure-Python HPACK header encoding"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496"},
{file = "hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca"},
]
[[package]]
name = "httpcore"
version = "1.0.9"
description = "A minimal low-level HTTP client."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55"},
{file = "httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8"},
]
[package.dependencies]
certifi = "*"
h11 = ">=0.16"
[package.extras]
asyncio = ["anyio (>=4.0,<5.0)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
trio = ["trio (>=0.22.0,<1.0)"]
[[package]]
name = "httpx"
version = "0.28.1"
description = "The next generation HTTP client."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"},
{file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"},
]
[package.dependencies]
anyio = "*"
certifi = "*"
h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""}
httpcore = "==1.*"
idna = "*"
[package.extras]
brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "hyperframe"
version = "6.1.0"
description = "Pure-Python HTTP/2 framing"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5"},
{file = "hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08"},
]
[[package]]
name = "idna"
version = "3.10"
@@ -996,7 +1036,7 @@ version = "25.0"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
groups = ["main", "dev"]
files = [
{file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"},
{file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"},
@@ -1018,6 +1058,24 @@ files = [
dev = ["pre-commit", "tox"]
testing = ["coverage", "pytest", "pytest-benchmark"]
[[package]]
name = "postgrest"
version = "1.1.1"
description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
{file = "postgrest-1.1.1-py3-none-any.whl", hash = "sha256:98a6035ee1d14288484bfe36235942c5fb2d26af6d8120dfe3efbe007859251a"},
{file = "postgrest-1.1.1.tar.gz", hash = "sha256:f3bb3e8c4602775c75c844a31f565f5f3dd584df4d36d683f0b67d01a86be322"},
]
[package.dependencies]
deprecation = ">=2.1.0,<3.0.0"
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
pydantic = ">=1.9,<3.0"
strenum = {version = ">=0.4.9,<0.5.0", markers = "python_version < \"3.11\""}
[[package]]
name = "proto-plus"
version = "1.26.1"
@@ -1404,6 +1462,21 @@ pytest = ">=6.2.5"
[package.extras]
dev = ["pre-commit", "pytest-asyncio", "tox"]
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"
description = "Extensions to the standard Python datetime module"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
groups = ["main"]
files = [
{file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
{file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
]
[package.dependencies]
six = ">=1.5"
[[package]]
name = "python-dotenv"
version = "1.1.1"
@@ -1419,6 +1492,22 @@ files = [
[package.extras]
cli = ["click (>=5.0)"]
[[package]]
name = "realtime"
version = "2.5.3"
description = ""
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
{file = "realtime-2.5.3-py3-none-any.whl", hash = "sha256:eb0994636946eff04c4c7f044f980c8c633c7eb632994f549f61053a474ac970"},
{file = "realtime-2.5.3.tar.gz", hash = "sha256:0587594f3bc1c84bf007ff625075b86db6528843e03250dc84f4f2808be3d99a"},
]
[package.dependencies]
typing-extensions = ">=4.14.0,<5.0.0"
websockets = ">=11,<16"
[[package]]
name = "redis"
version = "6.2.0"
@@ -1517,6 +1606,18 @@ files = [
{file = "semver-3.0.4.tar.gz", hash = "sha256:afc7d8c584a5ed0a11033af086e8af226a9c0b206f313e0301f8dd7b6b589602"},
]
[[package]]
name = "six"
version = "1.17.0"
description = "Python 2 and 3 compatibility utilities"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
groups = ["main"]
files = [
{file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"},
{file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"},
]
[[package]]
name = "sniffio"
version = "1.3.1"
@@ -1548,6 +1649,76 @@ typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\""
[package.extras]
full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"]
[[package]]
name = "storage3"
version = "0.12.0"
description = "Supabase Storage client for Python."
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
{file = "storage3-0.12.0-py3-none-any.whl", hash = "sha256:1c4585693ca42243ded1512b58e54c697111e91a20916cd14783eebc37e7c87d"},
{file = "storage3-0.12.0.tar.gz", hash = "sha256:94243f20922d57738bf42e96b9f5582b4d166e8bf209eccf20b146909f3f71b0"},
]
[package.dependencies]
deprecation = ">=2.1.0,<3.0.0"
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
python-dateutil = ">=2.8.2,<3.0.0"
[[package]]
name = "strenum"
version = "0.4.15"
description = "An Enum that inherits from str."
optional = false
python-versions = "*"
groups = ["main"]
files = [
{file = "StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659"},
{file = "StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff"},
]
[package.extras]
docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"]
release = ["twine"]
test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"]
[[package]]
name = "supabase"
version = "2.16.0"
description = "Supabase client for Python."
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
{file = "supabase-2.16.0-py3-none-any.whl", hash = "sha256:99065caab3d90a56650bf39fbd0e49740995da3738ab28706c61bd7f2401db55"},
{file = "supabase-2.16.0.tar.gz", hash = "sha256:98f3810158012d4ec0e3083f2e5515f5e10b32bd71e7d458662140e963c1d164"},
]
[package.dependencies]
gotrue = ">=2.11.0,<3.0.0"
httpx = ">=0.26,<0.29"
postgrest = ">0.19,<1.2"
realtime = ">=2.4.0,<2.6.0"
storage3 = ">=0.10,<0.13"
supafunc = ">=0.9,<0.11"
[[package]]
name = "supafunc"
version = "0.10.1"
description = "Library for Supabase Functions"
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
{file = "supafunc-0.10.1-py3-none-any.whl", hash = "sha256:26df9bd25ff2ef56cb5bfb8962de98f43331f7f8ff69572bac3ed9c3a9672040"},
{file = "supafunc-0.10.1.tar.gz", hash = "sha256:a5b33c8baecb6b5297d25da29a2503e2ec67ee6986f3d44c137e651b8a59a17d"},
]
[package.dependencies]
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
strenum = ">=0.4.15,<0.5.0"
[[package]]
name = "tomli"
version = "2.2.1"
@@ -1656,6 +1827,85 @@ typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""}
[package.extras]
standard = ["colorama (>=0.4) ; sys_platform == \"win32\"", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.15.1) ; sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\"", "watchfiles (>=0.13)", "websockets (>=10.4)"]
[[package]]
name = "websockets"
version = "15.0.1"
description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "websockets-15.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d63efaa0cd96cf0c5fe4d581521d9fa87744540d4bc999ae6e08595a1014b45b"},
{file = "websockets-15.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac60e3b188ec7574cb761b08d50fcedf9d77f1530352db4eef1707fe9dee7205"},
{file = "websockets-15.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5756779642579d902eed757b21b0164cd6fe338506a8083eb58af5c372e39d9a"},
{file = "websockets-15.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fdfe3e2a29e4db3659dbd5bbf04560cea53dd9610273917799f1cde46aa725e"},
{file = "websockets-15.0.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c2529b320eb9e35af0fa3016c187dffb84a3ecc572bcee7c3ce302bfeba52bf"},
{file = "websockets-15.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac1e5c9054fe23226fb11e05a6e630837f074174c4c2f0fe442996112a6de4fb"},
{file = "websockets-15.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5df592cd503496351d6dc14f7cdad49f268d8e618f80dce0cd5a36b93c3fc08d"},
{file = "websockets-15.0.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0a34631031a8f05657e8e90903e656959234f3a04552259458aac0b0f9ae6fd9"},
{file = "websockets-15.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3d00075aa65772e7ce9e990cab3ff1de702aa09be3940d1dc88d5abf1ab8a09c"},
{file = "websockets-15.0.1-cp310-cp310-win32.whl", hash = "sha256:1234d4ef35db82f5446dca8e35a7da7964d02c127b095e172e54397fb6a6c256"},
{file = "websockets-15.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:39c1fec2c11dc8d89bba6b2bf1556af381611a173ac2b511cf7231622058af41"},
{file = "websockets-15.0.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:823c248b690b2fd9303ba00c4f66cd5e2d8c3ba4aa968b2779be9532a4dad431"},
{file = "websockets-15.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678999709e68425ae2593acf2e3ebcbcf2e69885a5ee78f9eb80e6e371f1bf57"},
{file = "websockets-15.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d50fd1ee42388dcfb2b3676132c78116490976f1300da28eb629272d5d93e905"},
{file = "websockets-15.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d99e5546bf73dbad5bf3547174cd6cb8ba7273062a23808ffea025ecb1cf8562"},
{file = "websockets-15.0.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:66dd88c918e3287efc22409d426c8f729688d89a0c587c88971a0faa2c2f3792"},
{file = "websockets-15.0.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8dd8327c795b3e3f219760fa603dcae1dcc148172290a8ab15158cf85a953413"},
{file = "websockets-15.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8fdc51055e6ff4adeb88d58a11042ec9a5eae317a0a53d12c062c8a8865909e8"},
{file = "websockets-15.0.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:693f0192126df6c2327cce3baa7c06f2a117575e32ab2308f7f8216c29d9e2e3"},
{file = "websockets-15.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:54479983bd5fb469c38f2f5c7e3a24f9a4e70594cd68cd1fa6b9340dadaff7cf"},
{file = "websockets-15.0.1-cp311-cp311-win32.whl", hash = "sha256:16b6c1b3e57799b9d38427dda63edcbe4926352c47cf88588c0be4ace18dac85"},
{file = "websockets-15.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:27ccee0071a0e75d22cb35849b1db43f2ecd3e161041ac1ee9d2352ddf72f065"},
{file = "websockets-15.0.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:3e90baa811a5d73f3ca0bcbf32064d663ed81318ab225ee4f427ad4e26e5aff3"},
{file = "websockets-15.0.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:592f1a9fe869c778694f0aa806ba0374e97648ab57936f092fd9d87f8bc03665"},
{file = "websockets-15.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0701bc3cfcb9164d04a14b149fd74be7347a530ad3bbf15ab2c678a2cd3dd9a2"},
{file = "websockets-15.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8b56bdcdb4505c8078cb6c7157d9811a85790f2f2b3632c7d1462ab5783d215"},
{file = "websockets-15.0.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0af68c55afbd5f07986df82831c7bff04846928ea8d1fd7f30052638788bc9b5"},
{file = "websockets-15.0.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64dee438fed052b52e4f98f76c5790513235efaa1ef7f3f2192c392cd7c91b65"},
{file = "websockets-15.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d5f6b181bb38171a8ad1d6aa58a67a6aa9d4b38d0f8c5f496b9e42561dfc62fe"},
{file = "websockets-15.0.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:5d54b09eba2bada6011aea5375542a157637b91029687eb4fdb2dab11059c1b4"},
{file = "websockets-15.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3be571a8b5afed347da347bfcf27ba12b069d9d7f42cb8c7028b5e98bbb12597"},
{file = "websockets-15.0.1-cp312-cp312-win32.whl", hash = "sha256:c338ffa0520bdb12fbc527265235639fb76e7bc7faafbb93f6ba80d9c06578a9"},
{file = "websockets-15.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:fcd5cf9e305d7b8338754470cf69cf81f420459dbae8a3b40cee57417f4614a7"},
{file = "websockets-15.0.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ee443ef070bb3b6ed74514f5efaa37a252af57c90eb33b956d35c8e9c10a1931"},
{file = "websockets-15.0.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5a939de6b7b4e18ca683218320fc67ea886038265fd1ed30173f5ce3f8e85675"},
{file = "websockets-15.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:746ee8dba912cd6fc889a8147168991d50ed70447bf18bcda7039f7d2e3d9151"},
{file = "websockets-15.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:595b6c3969023ecf9041b2936ac3827e4623bfa3ccf007575f04c5a6aa318c22"},
{file = "websockets-15.0.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c714d2fc58b5ca3e285461a4cc0c9a66bd0e24c5da9911e30158286c9b5be7f"},
{file = "websockets-15.0.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f3c1e2ab208db911594ae5b4f79addeb3501604a165019dd221c0bdcabe4db8"},
{file = "websockets-15.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:229cf1d3ca6c1804400b0a9790dc66528e08a6a1feec0d5040e8b9eb14422375"},
{file = "websockets-15.0.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:756c56e867a90fb00177d530dca4b097dd753cde348448a1012ed6c5131f8b7d"},
{file = "websockets-15.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:558d023b3df0bffe50a04e710bc87742de35060580a293c2a984299ed83bc4e4"},
{file = "websockets-15.0.1-cp313-cp313-win32.whl", hash = "sha256:ba9e56e8ceeeedb2e080147ba85ffcd5cd0711b89576b83784d8605a7df455fa"},
{file = "websockets-15.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:e09473f095a819042ecb2ab9465aee615bd9c2028e4ef7d933600a8401c79561"},
{file = "websockets-15.0.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5f4c04ead5aed67c8a1a20491d54cdfba5884507a48dd798ecaf13c74c4489f5"},
{file = "websockets-15.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:abdc0c6c8c648b4805c5eacd131910d2a7f6455dfd3becab248ef108e89ab16a"},
{file = "websockets-15.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a625e06551975f4b7ea7102bc43895b90742746797e2e14b70ed61c43a90f09b"},
{file = "websockets-15.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d591f8de75824cbb7acad4e05d2d710484f15f29d4a915092675ad3456f11770"},
{file = "websockets-15.0.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:47819cea040f31d670cc8d324bb6435c6f133b8c7a19ec3d61634e62f8d8f9eb"},
{file = "websockets-15.0.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac017dd64572e5c3bd01939121e4d16cf30e5d7e110a119399cf3133b63ad054"},
{file = "websockets-15.0.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:4a9fac8e469d04ce6c25bb2610dc535235bd4aa14996b4e6dbebf5e007eba5ee"},
{file = "websockets-15.0.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:363c6f671b761efcb30608d24925a382497c12c506b51661883c3e22337265ed"},
{file = "websockets-15.0.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:2034693ad3097d5355bfdacfffcbd3ef5694f9718ab7f29c29689a9eae841880"},
{file = "websockets-15.0.1-cp39-cp39-win32.whl", hash = "sha256:3b1ac0d3e594bf121308112697cf4b32be538fb1444468fb0a6ae4feebc83411"},
{file = "websockets-15.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:b7643a03db5c95c799b89b31c036d5f27eeb4d259c798e878d6937d71832b1e4"},
{file = "websockets-15.0.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:0c9e74d766f2818bb95f84c25be4dea09841ac0f734d1966f415e4edfc4ef1c3"},
{file = "websockets-15.0.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1009ee0c7739c08a0cd59de430d6de452a55e42d6b522de7aa15e6f67db0b8e1"},
{file = "websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76d1f20b1c7a2fa82367e04982e708723ba0e7b8d43aa643d3dcd404d74f1475"},
{file = "websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f29d80eb9a9263b8d109135351caf568cc3f80b9928bccde535c235de55c22d9"},
{file = "websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b359ed09954d7c18bbc1680f380c7301f92c60bf924171629c5db97febb12f04"},
{file = "websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122"},
{file = "websockets-15.0.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7f493881579c90fc262d9cdbaa05a6b54b3811c2f300766748db79f098db9940"},
{file = "websockets-15.0.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:47b099e1f4fbc95b701b6e85768e1fcdaf1630f3cbe4765fa216596f12310e2e"},
{file = "websockets-15.0.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:67f2b6de947f8c757db2db9c71527933ad0019737ec374a8a6be9a956786aaf9"},
{file = "websockets-15.0.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d08eb4c2b7d6c41da6ca0600c077e93f5adcfd979cd777d747e9ee624556da4b"},
{file = "websockets-15.0.1-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b826973a4a2ae47ba357e4e82fa44a463b8f168e1ca775ac64521442b19e87f"},
{file = "websockets-15.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:21c1fa28a6a7e3cbdc171c694398b6df4744613ce9b36b1a498e816787e28123"},
{file = "websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f"},
{file = "websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee"},
]
[[package]]
name = "zipp"
version = "3.23.0"
@@ -1679,4 +1929,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
content-hash = "de209c97aa0feb29d669a20e4422d51bdf3a0872ec37e85ce9b88ce726fcee7a"
content-hash = "0c40b63c3c921846cf05ccfb4e685d4959854b29c2c302245f9832e20aac6954"

View File

@@ -18,8 +18,7 @@ pydantic = "^2.11.7"
pydantic-settings = "^2.10.1"
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
redis = "^6.2.0"
bcrypt = "^4.1.0"
authlib = "^1.3.0"
supabase = "^2.16.0"
uvicorn = "^0.35.0"
[tool.poetry.group.dev.dependencies]

View File

@@ -27,15 +27,10 @@ REDIS_PORT=6379
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
# JWT Authentication
# Generate a secure random key: python -c "import secrets; print(secrets.token_urlsafe(32))"
JWT_SIGN_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
# Supabase Authentication
SUPABASE_URL=http://localhost:8000
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
JWT_SIGN_ALGORITHM=HS256
ACCESS_TOKEN_EXPIRE_MINUTES=15
REFRESH_TOKEN_EXPIRE_DAYS=7
JWT_ISSUER=autogpt-platform
JWT_AUDIENCE=authenticated
## ===== REQUIRED SECURITY KEYS ===== ##
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()

View File

@@ -18,6 +18,3 @@ load-tests/results/
load-tests/*.json
load-tests/*.log
load-tests/node_modules/*
# Migration backups (contain user data)
migration_backups/

View File

@@ -20,7 +20,6 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.exceptions import BlockExecutionError
from backend.util.request import Requests
TEST_CREDENTIALS = APIKeyCredentials(
@@ -247,11 +246,7 @@ class AIShortformVideoCreatorBlock(Block):
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise BlockExecutionError(
message="Video creation timed out",
block_name=self.name,
block_id=self.id,
)
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(
@@ -427,11 +422,7 @@ class AIAdMakerVideoCreatorBlock(Block):
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise BlockExecutionError(
message="Video creation timed out",
block_name=self.name,
block_id=self.id,
)
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(
@@ -608,11 +599,7 @@ class AIScreenshotToVideoAdBlock(Block):
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise BlockExecutionError(
message="Video creation timed out",
block_name=self.name,
block_id=self.id,
)
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(

View File

@@ -1371,7 +1371,7 @@ async def create_base(
if tables:
params["tables"] = tables
logger.debug(f"Creating Airtable base with params: {params}")
print(params)
response = await Requests().post(
"https://api.airtable.com/v0/meta/bases",

View File

@@ -106,10 +106,7 @@ class ConditionBlock(Block):
ComparisonOperator.LESS_THAN_OR_EQUAL: lambda a, b: a <= b,
}
try:
result = comparison_funcs[operator](value1, value2)
except Exception as e:
raise ValueError(f"Comparison failed: {e}") from e
result = comparison_funcs[operator](value1, value2)
yield "result", result

View File

@@ -319,7 +319,7 @@ class CostDollars(BaseModel):
# Helper functions for payload processing
def process_text_field(
text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None],
text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None]
) -> Optional[Union[bool, Dict[str, Any]]]:
"""Process text field for API payload."""
if text is None:
@@ -400,7 +400,7 @@ def process_contents_settings(contents: Optional[ContentSettings]) -> Dict[str,
def process_context_field(
context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None],
context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None]
) -> Optional[Union[bool, Dict[str, int]]]:
"""Process context field for API payload."""
if context is None:

View File

@@ -15,7 +15,6 @@ from backend.sdk import (
SchemaField,
cost,
)
from backend.util.exceptions import BlockExecutionError
from ._config import firecrawl
@@ -60,18 +59,11 @@ class FirecrawlExtractBlock(Block):
) -> BlockOutput:
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
try:
extract_result = app.extract(
urls=input_data.urls,
prompt=input_data.prompt,
schema=input_data.output_schema,
enable_web_search=input_data.enable_web_search,
)
except Exception as e:
raise BlockExecutionError(
message=f"Extract failed: {e}",
block_name=self.name,
block_id=self.id,
) from e
extract_result = app.extract(
urls=input_data.urls,
prompt=input_data.prompt,
schema=input_data.output_schema,
enable_web_search=input_data.enable_web_search,
)
yield "data", extract_result.data

View File

@@ -19,7 +19,6 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.exceptions import ModerationError
from backend.util.file import MediaFileType, store_media_file
TEST_CREDENTIALS = APIKeyCredentials(
@@ -154,8 +153,6 @@ class AIImageEditorBlock(Block):
),
aspect_ratio=input_data.aspect_ratio.value,
seed=input_data.seed,
user_id=user_id,
graph_exec_id=graph_exec_id,
)
yield "output_image", result
@@ -167,8 +164,6 @@ class AIImageEditorBlock(Block):
input_image_b64: Optional[str],
aspect_ratio: str,
seed: Optional[int],
user_id: str,
graph_exec_id: str,
) -> MediaFileType:
client = ReplicateClient(api_token=api_key.get_secret_value())
input_params = {
@@ -178,21 +173,11 @@ class AIImageEditorBlock(Block):
**({"seed": seed} if seed is not None else {}),
}
try:
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
model_name,
input=input_params,
wait=False,
)
except Exception as e:
if "flagged as sensitive" in str(e).lower():
raise ModerationError(
message="Content was flagged as sensitive by the model provider",
user_id=user_id,
graph_exec_id=graph_exec_id,
moderation_type="model_provider",
)
raise ValueError(f"Model execution failed: {e}") from e
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
model_name,
input=input_params,
wait=False,
)
if isinstance(output, list) and output:
output = output[0]

View File

@@ -1,108 +0,0 @@
{
"action": "created",
"discussion": {
"repository_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"category": {
"id": 12345678,
"node_id": "DIC_kwDOJKSTjM4CXXXX",
"repository_id": 614765452,
"emoji": ":pray:",
"name": "Q&A",
"description": "Ask the community for help",
"created_at": "2023-03-16T09:21:07Z",
"updated_at": "2023-03-16T09:21:07Z",
"slug": "q-a",
"is_answerable": true
},
"answer_html_url": null,
"answer_chosen_at": null,
"answer_chosen_by": null,
"html_url": "https://github.com/Significant-Gravitas/AutoGPT/discussions/9999",
"id": 5000000001,
"node_id": "D_kwDOJKSTjM4AYYYY",
"number": 9999,
"title": "How do I configure custom blocks?",
"user": {
"login": "curious-user",
"id": 22222222,
"node_id": "MDQ6VXNlcjIyMjIyMjIy",
"avatar_url": "https://avatars.githubusercontent.com/u/22222222?v=4",
"url": "https://api.github.com/users/curious-user",
"html_url": "https://github.com/curious-user",
"type": "User",
"site_admin": false
},
"state": "open",
"state_reason": null,
"locked": false,
"comments": 0,
"created_at": "2024-12-01T17:00:00Z",
"updated_at": "2024-12-01T17:00:00Z",
"author_association": "NONE",
"active_lock_reason": null,
"body": "## Question\n\nI'm trying to create a custom block for my specific use case. I've read the documentation but I'm not sure how to:\n\n1. Define the input/output schema\n2. Handle authentication\n3. Test my block locally\n\nCan someone point me to examples or provide guidance?\n\n## Environment\n\n- AutoGPT Platform version: latest\n- Python: 3.11",
"reactions": {
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/discussions/9999/reactions",
"total_count": 0,
"+1": 0,
"-1": 0,
"laugh": 0,
"hooray": 0,
"confused": 0,
"heart": 0,
"rocket": 0,
"eyes": 0
},
"timeline_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/discussions/9999/timeline"
},
"repository": {
"id": 614765452,
"node_id": "R_kgDOJKSTjA",
"name": "AutoGPT",
"full_name": "Significant-Gravitas/AutoGPT",
"private": false,
"owner": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"url": "https://api.github.com/users/Significant-Gravitas",
"html_url": "https://github.com/Significant-Gravitas",
"type": "Organization",
"site_admin": false
},
"html_url": "https://github.com/Significant-Gravitas/AutoGPT",
"description": "AutoGPT is the vision of accessible AI for everyone, to use and to build on.",
"fork": false,
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"created_at": "2023-03-16T09:21:07Z",
"updated_at": "2024-12-01T17:00:00Z",
"pushed_at": "2024-12-01T12:00:00Z",
"stargazers_count": 170000,
"watchers_count": 170000,
"language": "Python",
"has_discussions": true,
"forks_count": 45000,
"visibility": "public",
"default_branch": "master"
},
"organization": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"url": "https://api.github.com/orgs/Significant-Gravitas",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"description": ""
},
"sender": {
"login": "curious-user",
"id": 22222222,
"node_id": "MDQ6VXNlcjIyMjIyMjIy",
"avatar_url": "https://avatars.githubusercontent.com/u/22222222?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/curious-user",
"html_url": "https://github.com/curious-user",
"type": "User",
"site_admin": false
}
}

View File

@@ -1,112 +0,0 @@
{
"action": "opened",
"issue": {
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345",
"repository_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"labels_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/labels{/name}",
"comments_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/comments",
"events_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/events",
"html_url": "https://github.com/Significant-Gravitas/AutoGPT/issues/12345",
"id": 2000000001,
"node_id": "I_kwDOJKSTjM5wXXXX",
"number": 12345,
"title": "Bug: Application crashes when processing large files",
"user": {
"login": "bug-reporter",
"id": 11111111,
"node_id": "MDQ6VXNlcjExMTExMTEx",
"avatar_url": "https://avatars.githubusercontent.com/u/11111111?v=4",
"url": "https://api.github.com/users/bug-reporter",
"html_url": "https://github.com/bug-reporter",
"type": "User",
"site_admin": false
},
"labels": [
{
"id": 5272676214,
"node_id": "LA_kwDOJKSTjM8AAAABOkandg",
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/labels/bug",
"name": "bug",
"color": "d73a4a",
"default": true,
"description": "Something isn't working"
}
],
"state": "open",
"locked": false,
"assignee": null,
"assignees": [],
"milestone": null,
"comments": 0,
"created_at": "2024-12-01T16:00:00Z",
"updated_at": "2024-12-01T16:00:00Z",
"closed_at": null,
"author_association": "NONE",
"active_lock_reason": null,
"body": "## Description\n\nWhen I try to process a file larger than 100MB, the application crashes with an out of memory error.\n\n## Steps to Reproduce\n\n1. Open the application\n2. Select a file larger than 100MB\n3. Click 'Process'\n4. Application crashes\n\n## Expected Behavior\n\nThe application should handle large files gracefully.\n\n## Environment\n\n- OS: Ubuntu 22.04\n- Python: 3.11\n- AutoGPT Version: 1.0.0",
"reactions": {
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/reactions",
"total_count": 0,
"+1": 0,
"-1": 0,
"laugh": 0,
"hooray": 0,
"confused": 0,
"heart": 0,
"rocket": 0,
"eyes": 0
},
"timeline_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/timeline",
"state_reason": null
},
"repository": {
"id": 614765452,
"node_id": "R_kgDOJKSTjA",
"name": "AutoGPT",
"full_name": "Significant-Gravitas/AutoGPT",
"private": false,
"owner": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"url": "https://api.github.com/users/Significant-Gravitas",
"html_url": "https://github.com/Significant-Gravitas",
"type": "Organization",
"site_admin": false
},
"html_url": "https://github.com/Significant-Gravitas/AutoGPT",
"description": "AutoGPT is the vision of accessible AI for everyone, to use and to build on.",
"fork": false,
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"created_at": "2023-03-16T09:21:07Z",
"updated_at": "2024-12-01T16:00:00Z",
"pushed_at": "2024-12-01T12:00:00Z",
"stargazers_count": 170000,
"watchers_count": 170000,
"language": "Python",
"forks_count": 45000,
"open_issues_count": 190,
"visibility": "public",
"default_branch": "master"
},
"organization": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"url": "https://api.github.com/orgs/Significant-Gravitas",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"description": ""
},
"sender": {
"login": "bug-reporter",
"id": 11111111,
"node_id": "MDQ6VXNlcjExMTExMTEx",
"avatar_url": "https://avatars.githubusercontent.com/u/11111111?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/bug-reporter",
"html_url": "https://github.com/bug-reporter",
"type": "User",
"site_admin": false
}
}

View File

@@ -1,97 +0,0 @@
{
"action": "published",
"release": {
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/releases/123456789",
"assets_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/releases/123456789/assets",
"upload_url": "https://uploads.github.com/repos/Significant-Gravitas/AutoGPT/releases/123456789/assets{?name,label}",
"html_url": "https://github.com/Significant-Gravitas/AutoGPT/releases/tag/v1.0.0",
"id": 123456789,
"author": {
"login": "ntindle",
"id": 12345678,
"node_id": "MDQ6VXNlcjEyMzQ1Njc4",
"avatar_url": "https://avatars.githubusercontent.com/u/12345678?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/ntindle",
"html_url": "https://github.com/ntindle",
"type": "User",
"site_admin": false
},
"node_id": "RE_kwDOJKSTjM4HWwAA",
"tag_name": "v1.0.0",
"target_commitish": "master",
"name": "AutoGPT Platform v1.0.0",
"draft": false,
"prerelease": false,
"created_at": "2024-12-01T10:00:00Z",
"published_at": "2024-12-01T12:00:00Z",
"assets": [
{
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/releases/assets/987654321",
"id": 987654321,
"node_id": "RA_kwDOJKSTjM4HWwBB",
"name": "autogpt-v1.0.0.zip",
"label": "Release Package",
"content_type": "application/zip",
"state": "uploaded",
"size": 52428800,
"download_count": 0,
"created_at": "2024-12-01T11:30:00Z",
"updated_at": "2024-12-01T11:35:00Z",
"browser_download_url": "https://github.com/Significant-Gravitas/AutoGPT/releases/download/v1.0.0/autogpt-v1.0.0.zip"
}
],
"tarball_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/tarball/v1.0.0",
"zipball_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/zipball/v1.0.0",
"body": "## What's New\n\n- Feature 1: Amazing new capability\n- Feature 2: Performance improvements\n- Bug fixes and stability improvements\n\n## Breaking Changes\n\nNone\n\n## Contributors\n\nThanks to all our contributors!"
},
"repository": {
"id": 614765452,
"node_id": "R_kgDOJKSTjA",
"name": "AutoGPT",
"full_name": "Significant-Gravitas/AutoGPT",
"private": false,
"owner": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"url": "https://api.github.com/users/Significant-Gravitas",
"html_url": "https://github.com/Significant-Gravitas",
"type": "Organization",
"site_admin": false
},
"html_url": "https://github.com/Significant-Gravitas/AutoGPT",
"description": "AutoGPT is the vision of accessible AI for everyone, to use and to build on.",
"fork": false,
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"created_at": "2023-03-16T09:21:07Z",
"updated_at": "2024-12-01T12:00:00Z",
"pushed_at": "2024-12-01T12:00:00Z",
"stargazers_count": 170000,
"watchers_count": 170000,
"language": "Python",
"forks_count": 45000,
"visibility": "public",
"default_branch": "master"
},
"organization": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"url": "https://api.github.com/orgs/Significant-Gravitas",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"description": ""
},
"sender": {
"login": "ntindle",
"id": 12345678,
"node_id": "MDQ6VXNlcjEyMzQ1Njc4",
"avatar_url": "https://avatars.githubusercontent.com/u/12345678?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/ntindle",
"html_url": "https://github.com/ntindle",
"type": "User",
"site_admin": false
}
}

View File

@@ -1,53 +0,0 @@
{
"action": "created",
"starred_at": "2024-12-01T15:30:00Z",
"repository": {
"id": 614765452,
"node_id": "R_kgDOJKSTjA",
"name": "AutoGPT",
"full_name": "Significant-Gravitas/AutoGPT",
"private": false,
"owner": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"url": "https://api.github.com/users/Significant-Gravitas",
"html_url": "https://github.com/Significant-Gravitas",
"type": "Organization",
"site_admin": false
},
"html_url": "https://github.com/Significant-Gravitas/AutoGPT",
"description": "AutoGPT is the vision of accessible AI for everyone, to use and to build on.",
"fork": false,
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"created_at": "2023-03-16T09:21:07Z",
"updated_at": "2024-12-01T15:30:00Z",
"pushed_at": "2024-12-01T12:00:00Z",
"stargazers_count": 170001,
"watchers_count": 170001,
"language": "Python",
"forks_count": 45000,
"visibility": "public",
"default_branch": "master"
},
"organization": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"url": "https://api.github.com/orgs/Significant-Gravitas",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"description": ""
},
"sender": {
"login": "awesome-contributor",
"id": 98765432,
"node_id": "MDQ6VXNlcjk4NzY1NDMy",
"avatar_url": "https://avatars.githubusercontent.com/u/98765432?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/awesome-contributor",
"html_url": "https://github.com/awesome-contributor",
"type": "User",
"site_admin": false
}
}

View File

@@ -159,391 +159,3 @@ class GithubPullRequestTriggerBlock(GitHubTriggerBase, Block):
# --8<-- [end:GithubTriggerExample]
class GithubStarTriggerBlock(GitHubTriggerBase, Block):
"""Trigger block for GitHub star events - useful for milestone celebrations."""
EXAMPLE_PAYLOAD_FILE = (
Path(__file__).parent / "example_payloads" / "star.created.json"
)
class Input(GitHubTriggerBase.Input):
class EventsFilter(BaseModel):
"""
https://docs.github.com/en/webhooks/webhook-events-and-payloads#star
"""
created: bool = False
deleted: bool = False
events: EventsFilter = SchemaField(
title="Events", description="The star events to subscribe to"
)
class Output(GitHubTriggerBase.Output):
event: str = SchemaField(
description="The star event that triggered the webhook ('created' or 'deleted')"
)
starred_at: str = SchemaField(
description="ISO timestamp when the repo was starred (empty if deleted)"
)
stargazers_count: int = SchemaField(
description="Current number of stars on the repository"
)
repository_name: str = SchemaField(
description="Full name of the repository (owner/repo)"
)
repository_url: str = SchemaField(description="URL to the repository")
def __init__(self):
from backend.integrations.webhooks.github import GithubWebhookType
example_payload = json.loads(
self.EXAMPLE_PAYLOAD_FILE.read_text(encoding="utf-8")
)
super().__init__(
id="551e0a35-100b-49b7-89b8-3031322239b6",
description="This block triggers on GitHub star events. "
"Useful for celebrating milestones (e.g., 1k, 10k stars) or tracking engagement.",
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
input_schema=GithubStarTriggerBlock.Input,
output_schema=GithubStarTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.GITHUB,
webhook_type=GithubWebhookType.REPO,
resource_format="{repo}",
event_filter_input="events",
event_format="star.{event}",
),
test_input={
"repo": "Significant-Gravitas/AutoGPT",
"events": {"created": True},
"credentials": TEST_CREDENTIALS_INPUT,
"payload": example_payload,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", example_payload),
("triggered_by_user", example_payload["sender"]),
("event", example_payload["action"]),
("starred_at", example_payload.get("starred_at", "")),
("stargazers_count", example_payload["repository"]["stargazers_count"]),
("repository_name", example_payload["repository"]["full_name"]),
("repository_url", example_payload["repository"]["html_url"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput: # type: ignore
async for name, value in super().run(input_data, **kwargs):
yield name, value
yield "event", input_data.payload["action"]
yield "starred_at", input_data.payload.get("starred_at", "")
yield "stargazers_count", input_data.payload["repository"]["stargazers_count"]
yield "repository_name", input_data.payload["repository"]["full_name"]
yield "repository_url", input_data.payload["repository"]["html_url"]
class GithubReleaseTriggerBlock(GitHubTriggerBase, Block):
"""Trigger block for GitHub release events - ideal for announcing new versions."""
EXAMPLE_PAYLOAD_FILE = (
Path(__file__).parent / "example_payloads" / "release.published.json"
)
class Input(GitHubTriggerBase.Input):
class EventsFilter(BaseModel):
"""
https://docs.github.com/en/webhooks/webhook-events-and-payloads#release
"""
published: bool = False
unpublished: bool = False
created: bool = False
edited: bool = False
deleted: bool = False
prereleased: bool = False
released: bool = False
events: EventsFilter = SchemaField(
title="Events", description="The release events to subscribe to"
)
class Output(GitHubTriggerBase.Output):
event: str = SchemaField(
description="The release event that triggered the webhook (e.g., 'published')"
)
release: dict = SchemaField(description="The full release object")
release_url: str = SchemaField(description="URL to the release page")
tag_name: str = SchemaField(description="The release tag name (e.g., 'v1.0.0')")
release_name: str = SchemaField(description="Human-readable release name")
body: str = SchemaField(description="Release notes/description")
prerelease: bool = SchemaField(description="Whether this is a prerelease")
draft: bool = SchemaField(description="Whether this is a draft release")
assets: list = SchemaField(description="List of release assets/files")
def __init__(self):
from backend.integrations.webhooks.github import GithubWebhookType
example_payload = json.loads(
self.EXAMPLE_PAYLOAD_FILE.read_text(encoding="utf-8")
)
super().__init__(
id="2052dd1b-74e1-46ac-9c87-c7a0e057b60b",
description="This block triggers on GitHub release events. "
"Perfect for automating announcements to Discord, Twitter, or other platforms.",
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
input_schema=GithubReleaseTriggerBlock.Input,
output_schema=GithubReleaseTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.GITHUB,
webhook_type=GithubWebhookType.REPO,
resource_format="{repo}",
event_filter_input="events",
event_format="release.{event}",
),
test_input={
"repo": "Significant-Gravitas/AutoGPT",
"events": {"published": True},
"credentials": TEST_CREDENTIALS_INPUT,
"payload": example_payload,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", example_payload),
("triggered_by_user", example_payload["sender"]),
("event", example_payload["action"]),
("release", example_payload["release"]),
("release_url", example_payload["release"]["html_url"]),
("tag_name", example_payload["release"]["tag_name"]),
("release_name", example_payload["release"]["name"]),
("body", example_payload["release"]["body"]),
("prerelease", example_payload["release"]["prerelease"]),
("draft", example_payload["release"]["draft"]),
("assets", example_payload["release"]["assets"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput: # type: ignore
async for name, value in super().run(input_data, **kwargs):
yield name, value
release = input_data.payload["release"]
yield "event", input_data.payload["action"]
yield "release", release
yield "release_url", release["html_url"]
yield "tag_name", release["tag_name"]
yield "release_name", release.get("name", "")
yield "body", release.get("body", "")
yield "prerelease", release["prerelease"]
yield "draft", release["draft"]
yield "assets", release["assets"]
class GithubIssuesTriggerBlock(GitHubTriggerBase, Block):
"""Trigger block for GitHub issues events - great for triage and notifications."""
EXAMPLE_PAYLOAD_FILE = (
Path(__file__).parent / "example_payloads" / "issues.opened.json"
)
class Input(GitHubTriggerBase.Input):
class EventsFilter(BaseModel):
"""
https://docs.github.com/en/webhooks/webhook-events-and-payloads#issues
"""
opened: bool = False
edited: bool = False
deleted: bool = False
closed: bool = False
reopened: bool = False
assigned: bool = False
unassigned: bool = False
labeled: bool = False
unlabeled: bool = False
locked: bool = False
unlocked: bool = False
transferred: bool = False
milestoned: bool = False
demilestoned: bool = False
pinned: bool = False
unpinned: bool = False
events: EventsFilter = SchemaField(
title="Events", description="The issue events to subscribe to"
)
class Output(GitHubTriggerBase.Output):
event: str = SchemaField(
description="The issue event that triggered the webhook (e.g., 'opened')"
)
number: int = SchemaField(description="The issue number")
issue: dict = SchemaField(description="The full issue object")
issue_url: str = SchemaField(description="URL to the issue")
issue_title: str = SchemaField(description="The issue title")
issue_body: str = SchemaField(description="The issue body/description")
labels: list = SchemaField(description="List of labels on the issue")
assignees: list = SchemaField(description="List of assignees")
state: str = SchemaField(description="Issue state ('open' or 'closed')")
def __init__(self):
from backend.integrations.webhooks.github import GithubWebhookType
example_payload = json.loads(
self.EXAMPLE_PAYLOAD_FILE.read_text(encoding="utf-8")
)
super().__init__(
id="b2605464-e486-4bf4-aad3-d8a213c8a48a",
description="This block triggers on GitHub issues events. "
"Useful for automated triage, notifications, and welcoming first-time contributors.",
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
input_schema=GithubIssuesTriggerBlock.Input,
output_schema=GithubIssuesTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.GITHUB,
webhook_type=GithubWebhookType.REPO,
resource_format="{repo}",
event_filter_input="events",
event_format="issues.{event}",
),
test_input={
"repo": "Significant-Gravitas/AutoGPT",
"events": {"opened": True},
"credentials": TEST_CREDENTIALS_INPUT,
"payload": example_payload,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", example_payload),
("triggered_by_user", example_payload["sender"]),
("event", example_payload["action"]),
("number", example_payload["issue"]["number"]),
("issue", example_payload["issue"]),
("issue_url", example_payload["issue"]["html_url"]),
("issue_title", example_payload["issue"]["title"]),
("issue_body", example_payload["issue"]["body"]),
("labels", example_payload["issue"]["labels"]),
("assignees", example_payload["issue"]["assignees"]),
("state", example_payload["issue"]["state"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput: # type: ignore
async for name, value in super().run(input_data, **kwargs):
yield name, value
issue = input_data.payload["issue"]
yield "event", input_data.payload["action"]
yield "number", issue["number"]
yield "issue", issue
yield "issue_url", issue["html_url"]
yield "issue_title", issue["title"]
yield "issue_body", issue.get("body") or ""
yield "labels", issue["labels"]
yield "assignees", issue["assignees"]
yield "state", issue["state"]
class GithubDiscussionTriggerBlock(GitHubTriggerBase, Block):
"""Trigger block for GitHub discussion events - perfect for community Q&A sync."""
EXAMPLE_PAYLOAD_FILE = (
Path(__file__).parent / "example_payloads" / "discussion.created.json"
)
class Input(GitHubTriggerBase.Input):
class EventsFilter(BaseModel):
"""
https://docs.github.com/en/webhooks/webhook-events-and-payloads#discussion
"""
created: bool = False
edited: bool = False
deleted: bool = False
answered: bool = False
unanswered: bool = False
labeled: bool = False
unlabeled: bool = False
locked: bool = False
unlocked: bool = False
category_changed: bool = False
transferred: bool = False
pinned: bool = False
unpinned: bool = False
events: EventsFilter = SchemaField(
title="Events", description="The discussion events to subscribe to"
)
class Output(GitHubTriggerBase.Output):
event: str = SchemaField(
description="The discussion event that triggered the webhook"
)
number: int = SchemaField(description="The discussion number")
discussion: dict = SchemaField(description="The full discussion object")
discussion_url: str = SchemaField(description="URL to the discussion")
title: str = SchemaField(description="The discussion title")
body: str = SchemaField(description="The discussion body")
category: dict = SchemaField(description="The discussion category object")
category_name: str = SchemaField(description="Name of the category")
state: str = SchemaField(description="Discussion state")
def __init__(self):
from backend.integrations.webhooks.github import GithubWebhookType
example_payload = json.loads(
self.EXAMPLE_PAYLOAD_FILE.read_text(encoding="utf-8")
)
super().__init__(
id="87f847b3-d81a-424e-8e89-acadb5c9d52b",
description="This block triggers on GitHub Discussions events. "
"Great for syncing Q&A to Discord or auto-responding to common questions. "
"Note: Discussions must be enabled on the repository.",
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
input_schema=GithubDiscussionTriggerBlock.Input,
output_schema=GithubDiscussionTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.GITHUB,
webhook_type=GithubWebhookType.REPO,
resource_format="{repo}",
event_filter_input="events",
event_format="discussion.{event}",
),
test_input={
"repo": "Significant-Gravitas/AutoGPT",
"events": {"created": True},
"credentials": TEST_CREDENTIALS_INPUT,
"payload": example_payload,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", example_payload),
("triggered_by_user", example_payload["sender"]),
("event", example_payload["action"]),
("number", example_payload["discussion"]["number"]),
("discussion", example_payload["discussion"]),
("discussion_url", example_payload["discussion"]["html_url"]),
("title", example_payload["discussion"]["title"]),
("body", example_payload["discussion"]["body"]),
("category", example_payload["discussion"]["category"]),
("category_name", example_payload["discussion"]["category"]["name"]),
("state", example_payload["discussion"]["state"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput: # type: ignore
async for name, value in super().run(input_data, **kwargs):
yield name, value
discussion = input_data.payload["discussion"]
yield "event", input_data.payload["action"]
yield "number", discussion["number"]
yield "discussion", discussion
yield "discussion_url", discussion["html_url"]
yield "title", discussion["title"]
yield "body", discussion.get("body") or ""
yield "category", discussion["category"]
yield "category_name", discussion["category"]["name"]
yield "state", discussion["state"]

File diff suppressed because it is too large

View File

@@ -1,5 +1,5 @@
import logging
from typing import Any
from typing import Any, Literal
from prisma.enums import ReviewStatus
@@ -45,11 +45,11 @@ class HumanInTheLoopBlock(Block):
)
class Output(BlockSchemaOutput):
approved_data: Any = SchemaField(
description="The data when approved (may be modified by reviewer)"
reviewed_data: Any = SchemaField(
description="The data after human review (may be modified)"
)
rejected_data: Any = SchemaField(
description="The data when rejected (may be modified by reviewer)"
status: Literal["approved", "rejected"] = SchemaField(
description="Status of the review: 'approved' or 'rejected'"
)
review_message: str = SchemaField(
description="Any message provided by the reviewer", default=""
@@ -69,7 +69,8 @@ class HumanInTheLoopBlock(Block):
"editable": True,
},
test_output=[
("approved_data", {"name": "John Doe", "age": 30}),
("status", "approved"),
("reviewed_data", {"name": "John Doe", "age": 30}),
],
test_mock={
"get_or_create_human_review": lambda *_args, **_kwargs: ReviewResult(
@@ -115,7 +116,8 @@ class HumanInTheLoopBlock(Block):
logger.info(
f"HITL block skipping review for node {node_exec_id} - safe mode disabled"
)
yield "approved_data", input_data.data
yield "status", "approved"
yield "reviewed_data", input_data.data
yield "review_message", "Auto-approved (safe mode disabled)"
return
@@ -156,11 +158,12 @@ class HumanInTheLoopBlock(Block):
)
if result.status == ReviewStatus.APPROVED:
yield "approved_data", result.data
yield "status", "approved"
yield "reviewed_data", result.data
if result.message:
yield "review_message", result.message
elif result.status == ReviewStatus.REJECTED:
yield "rejected_data", result.data
yield "status", "rejected"
if result.message:
yield "review_message", result.message

View File

@@ -2,6 +2,7 @@ from enum import Enum
from typing import Any, Dict, Literal, Optional
from pydantic import SecretStr
from requests.exceptions import RequestException
from backend.data.block import (
Block,
@@ -331,8 +332,8 @@ class IdeogramModelBlock(Block):
try:
response = await Requests().post(url, headers=headers, json=data)
return response.json()["data"][0]["url"]
except Exception as e:
raise ValueError(f"Failed to fetch image with V3 endpoint: {e}") from e
except RequestException as e:
raise Exception(f"Failed to fetch image with V3 endpoint: {str(e)}")
async def _run_model_legacy(
self,
@@ -384,8 +385,8 @@ class IdeogramModelBlock(Block):
try:
response = await Requests().post(url, headers=headers, json=data)
return response.json()["data"][0]["url"]
except Exception as e:
raise ValueError(f"Failed to fetch image with legacy endpoint: {e}") from e
except RequestException as e:
raise Exception(f"Failed to fetch image with legacy endpoint: {str(e)}")
async def upscale_image(self, api_key: SecretStr, image_url: str):
url = "https://api.ideogram.ai/upscale"
@@ -412,5 +413,5 @@ class IdeogramModelBlock(Block):
return (response.json())["data"][0]["url"]
except Exception as e:
raise ValueError(f"Failed to upscale image: {e}") from e
except RequestException as e:
raise Exception(f"Failed to upscale image: {str(e)}")

View File

@@ -16,7 +16,6 @@ from backend.data.block import (
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
class SearchTheWebBlock(Block, GetRequest):
@@ -57,17 +56,7 @@ class SearchTheWebBlock(Block, GetRequest):
# Prepend the Jina Search URL to the encoded query
jina_search_url = f"https://s.jina.ai/{encoded_query}"
try:
results = await self.get_request(
jina_search_url, headers=headers, json=False
)
except Exception as e:
raise BlockExecutionError(
message=f"Search failed: {e}",
block_name=self.name,
block_id=self.id,
) from e
results = await self.get_request(jina_search_url, headers=headers, json=False)
# Output the search results
yield "results", results

View File

@@ -1,4 +1,3 @@
import logging
from datetime import datetime, timezone
from typing import Iterator, Literal
@@ -65,7 +64,6 @@ class RedditComment(BaseModel):
settings = Settings()
logger = logging.getLogger(__name__)
def get_praw(creds: RedditCredentials) -> praw.Reddit:
@@ -79,7 +77,7 @@ def get_praw(creds: RedditCredentials) -> praw.Reddit:
me = client.user.me()
if not me:
raise ValueError("Invalid Reddit credentials.")
logger.info(f"Logged in as Reddit user: {me.name}")
print(f"Logged in as Reddit user: {me.name}")
return client

View File

@@ -18,7 +18,6 @@ from backend.data.block import (
BlockSchemaOutput,
)
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
from backend.util.exceptions import BlockExecutionError, BlockInputError
logger = logging.getLogger(__name__)
@@ -112,27 +111,9 @@ class ReplicateModelBlock(Block):
yield "status", "succeeded"
yield "model_name", input_data.model_name
except Exception as e:
error_msg = str(e)
logger.error(f"Error running Replicate model: {error_msg}")
# Input validation errors (422, 400) → BlockInputError
if (
"422" in error_msg
or "Input validation failed" in error_msg
or "400" in error_msg
):
raise BlockInputError(
message=f"Invalid model inputs: {error_msg}",
block_name=self.name,
block_id=self.id,
) from e
# Everything else → BlockExecutionError
else:
raise BlockExecutionError(
message=f"Replicate model error: {error_msg}",
block_name=self.name,
block_id=self.id,
) from e
error_msg = f"Unexpected error running Replicate model: {str(e)}"
logger.error(error_msg)
raise RuntimeError(error_msg)
async def run_model(self, model_ref: str, model_inputs: dict, api_key: SecretStr):
"""

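The branch removed above classified Replicate failures by matching status-code substrings in the error message. A standalone sketch of that pattern (the error classes stand in for the ones in backend.util.exceptions):

class BlockInputError(ValueError):  # stand-in
    pass

class BlockExecutionError(RuntimeError):  # stand-in
    pass

def reraise_replicate_error(e: Exception) -> None:
    """Re-raise a raw Replicate failure as an input vs. execution error (sketch)."""
    msg = str(e)
    # 4xx validation failures point at bad model inputs rather than a failed run
    if "422" in msg or "400" in msg or "Input validation failed" in msg:
        raise BlockInputError(f"Invalid model inputs: {msg}") from e
    raise BlockExecutionError(f"Replicate model error: {msg}") from e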
View File

@@ -45,16 +45,10 @@ class GetWikipediaSummaryBlock(Block, GetRequest):
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
topic = input_data.topic
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
# Note: User-Agent is now automatically set by the request library
# to comply with Wikimedia's robot policy (https://w.wiki/4wJS)
try:
response = await self.get_request(url, json=True)
if "extract" not in response:
raise ValueError(f"Unable to parse Wikipedia response: {response}")
yield "summary", response["extract"]
except Exception as e:
raise ValueError(f"Failed to fetch Wikipedia summary: {e}") from e
response = await self.get_request(url, json=True)
if "extract" not in response:
raise RuntimeError(f"Unable to parse Wikipedia response: {response}")
yield "summary", response["extract"]
TEST_CREDENTIALS = APIKeyCredentials(

View File

@@ -1,11 +1,8 @@
import logging
import re
from collections import Counter
from concurrent.futures import Future
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
import backend.blocks.llm as llm
from backend.blocks.agent import AgentExecutorBlock
from backend.data.block import (
@@ -23,41 +20,16 @@ from backend.data.dynamic_fields import (
is_dynamic_field,
is_tool_pin,
)
from backend.data.execution import ExecutionContext
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util import json
from backend.util.clients import get_database_manager_async_client
from backend.util.prompt import MAIN_OBJECTIVE_PREFIX
if TYPE_CHECKING:
from backend.data.graph import Link, Node
from backend.executor.manager import ExecutionProcessor
logger = logging.getLogger(__name__)
class ToolInfo(BaseModel):
"""Processed tool call information."""
tool_call: Any # The original tool call object from LLM response
tool_name: str # The function name
tool_def: dict[str, Any] # The tool definition from tool_functions
input_data: dict[str, Any] # Processed input data ready for tool execution
field_mapping: dict[str, str] # Field name mapping for the tool
class ExecutionParams(BaseModel):
"""Tool execution parameters."""
user_id: str
graph_id: str
node_id: str
graph_version: int
graph_exec_id: str
node_exec_id: str
execution_context: "ExecutionContext"
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
"""
Return a list of tool_call_ids if the entry is a tool request.
@@ -133,50 +105,6 @@ def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
return {"role": "tool", "tool_call_id": call_id, "content": content}
def _combine_tool_responses(tool_outputs: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Combine multiple Anthropic tool responses into a single user message.
For non-Anthropic formats, returns the original list unchanged.
"""
if len(tool_outputs) <= 1:
return tool_outputs
# Anthropic responses have role="user", type="message", and content is a list with tool_result items
anthropic_responses = [
output
for output in tool_outputs
if (
output.get("role") == "user"
and output.get("type") == "message"
and isinstance(output.get("content"), list)
and any(
item.get("type") == "tool_result"
for item in output.get("content", [])
if isinstance(item, dict)
)
)
]
if len(anthropic_responses) > 1:
combined_content = [
item for response in anthropic_responses for item in response["content"]
]
combined_response = {
"role": "user",
"type": "message",
"content": combined_content,
}
non_anthropic_responses = [
output for output in tool_outputs if output not in anthropic_responses
]
return [combined_response] + non_anthropic_responses
return tool_outputs
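# For concreteness, the merge performed by _combine_tool_responses above, shown
# with two hypothetical tool results. Anthropic expects every tool_result that
# answers a single assistant turn to arrive in one user message, which is why
# the responses are combined:
tool_outputs = [
    {"role": "user", "type": "message",
     "content": [{"type": "tool_result", "tool_use_id": "call_1", "content": "ok"}]},
    {"role": "user", "type": "message",
     "content": [{"type": "tool_result", "tool_use_id": "call_2", "content": "42"}]},
]
combined = {
    "role": "user",
    "type": "message",
    "content": [item for out in tool_outputs for item in out["content"]],
}  # one user message carrying both tool_result items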
def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
"""
Safely convert raw_response to dictionary format for conversation history.
@@ -276,17 +204,6 @@ class SmartDecisionMakerBlock(Block):
default="localhost:11434",
description="Ollama host for local models",
)
agent_mode_max_iterations: int = SchemaField(
title="Agent Mode Max Iterations",
description="Maximum iterations for agent mode. 0 = traditional mode (single LLM call, yield tool calls for external execution), -1 = infinite agent mode (loop until finished), 1+ = agent mode with max iterations limit.",
advanced=True,
default=0,
)
conversation_compaction: bool = SchemaField(
default=True,
title="Context window auto-compaction",
description="Automatically compact the context window once it hits the limit",
)
@classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
@@ -589,7 +506,6 @@ class SmartDecisionMakerBlock(Block):
Returns the response if successful, raises ValueError if validation fails.
"""
resp = await llm.llm_call(
compress_prompt_to_fit=input_data.conversation_compaction,
credentials=credentials,
llm_model=input_data.model,
prompt=current_prompt,
@@ -677,291 +593,6 @@ class SmartDecisionMakerBlock(Block):
return resp
def _process_tool_calls(
self, response, tool_functions: list[dict[str, Any]]
) -> list[ToolInfo]:
"""Process tool calls and extract tool definitions, arguments, and input data.
Returns a list of tool info dicts with:
- tool_call: The original tool call object
- tool_name: The function name
- tool_def: The tool definition from tool_functions
- input_data: Processed input data dict (includes None values)
- field_mapping: Field name mapping for the tool
"""
if not response.tool_calls:
return []
processed_tools = []
for tool_call in response.tool_calls:
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
tool_def = next(
(
tool
for tool in tool_functions
if tool["function"]["name"] == tool_name
),
None,
)
if not tool_def:
if len(tool_functions) == 1:
tool_def = tool_functions[0]
else:
continue
# Build input data for the tool
input_data = {}
field_mapping = tool_def["function"].get("_field_mapping", {})
if "function" in tool_def and "parameters" in tool_def["function"]:
expected_args = tool_def["function"]["parameters"].get("properties", {})
for clean_arg_name in expected_args:
original_field_name = field_mapping.get(
clean_arg_name, clean_arg_name
)
arg_value = tool_args.get(clean_arg_name)
# Include all expected parameters, even if None (for backward compatibility with tests)
input_data[original_field_name] = arg_value
processed_tools.append(
ToolInfo(
tool_call=tool_call,
tool_name=tool_name,
tool_def=tool_def,
input_data=input_data,
field_mapping=field_mapping,
)
)
return processed_tools
def _update_conversation(
self, prompt: list[dict], response, tool_outputs: list | None = None
):
"""Update conversation history with response and tool outputs."""
# Don't add separate reasoning message with tool calls (breaks Anthropic's tool_use->tool_result pairing)
assistant_message = _convert_raw_response_to_dict(response.raw_response)
has_tool_calls = isinstance(assistant_message.get("content"), list) and any(
item.get("type") == "tool_use"
for item in assistant_message.get("content", [])
)
if response.reasoning and not has_tool_calls:
prompt.append(
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
)
prompt.append(assistant_message)
if tool_outputs:
prompt.extend(tool_outputs)
async def _execute_single_tool_with_manager(
self,
tool_info: ToolInfo,
execution_params: ExecutionParams,
execution_processor: "ExecutionProcessor",
) -> dict:
"""Execute a single tool using the execution manager for proper integration."""
# Lazy imports to avoid circular dependencies
from backend.data.execution import NodeExecutionEntry
tool_call = tool_info.tool_call
tool_def = tool_info.tool_def
raw_input_data = tool_info.input_data
# Get sink node and field mapping
sink_node_id = tool_def["function"]["_sink_node_id"]
# Use proper database operations for tool execution
db_client = get_database_manager_async_client()
# Get target node
target_node = await db_client.get_node(sink_node_id)
if not target_node:
raise ValueError(f"Target node {sink_node_id} not found")
# Create proper node execution using upsert_execution_input
node_exec_result = None
final_input_data = None
# Add all inputs to the execution
if not raw_input_data:
raise ValueError(f"Tool call has no input data: {tool_call}")
for input_name, input_value in raw_input_data.items():
node_exec_result, final_input_data = await db_client.upsert_execution_input(
node_id=sink_node_id,
graph_exec_id=execution_params.graph_exec_id,
input_name=input_name,
input_data=input_value,
)
assert node_exec_result is not None, "node_exec_result should not be None"
# Create NodeExecutionEntry for execution manager
node_exec_entry = NodeExecutionEntry(
user_id=execution_params.user_id,
graph_exec_id=execution_params.graph_exec_id,
graph_id=execution_params.graph_id,
graph_version=execution_params.graph_version,
node_exec_id=node_exec_result.node_exec_id,
node_id=sink_node_id,
block_id=target_node.block_id,
inputs=final_input_data or {},
execution_context=execution_params.execution_context,
)
# Use the execution manager to execute the tool node
try:
# Get NodeExecutionProgress from the execution manager's running nodes
node_exec_progress = execution_processor.running_node_execution[
sink_node_id
]
# Use the execution manager's own graph stats
graph_stats_pair = (
execution_processor.execution_stats,
execution_processor.execution_stats_lock,
)
# Create a completed future for the task tracking system
node_exec_future = Future()
node_exec_progress.add_task(
node_exec_id=node_exec_result.node_exec_id,
task=node_exec_future,
)
# Execute the node directly since we're in the SmartDecisionMaker context
node_exec_future.set_result(
await execution_processor.on_node_execution(
node_exec=node_exec_entry,
node_exec_progress=node_exec_progress,
nodes_input_masks=None,
graph_stats_pair=graph_stats_pair,
)
)
# Get outputs from database after execution completes using database manager client
node_outputs = await db_client.get_execution_outputs_by_node_exec_id(
node_exec_result.node_exec_id
)
# Create tool response
tool_response_content = (
json.dumps(node_outputs)
if node_outputs
else "Tool executed successfully"
)
return _create_tool_response(tool_call.id, tool_response_content)
except Exception as e:
logger.error(f"Tool execution with manager failed: {e}")
# Return error response
return _create_tool_response(
tool_call.id, f"Tool execution failed: {str(e)}"
)
async def _execute_tools_agent_mode(
self,
input_data,
credentials,
tool_functions: list[dict[str, Any]],
prompt: list[dict],
graph_exec_id: str,
node_id: str,
node_exec_id: str,
user_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
execution_processor: "ExecutionProcessor",
):
"""Execute tools in agent mode with a loop until finished."""
max_iterations = input_data.agent_mode_max_iterations
iteration = 0
# Execution parameters for tool execution
execution_params = ExecutionParams(
user_id=user_id,
graph_id=graph_id,
node_id=node_id,
graph_version=graph_version,
graph_exec_id=graph_exec_id,
node_exec_id=node_exec_id,
execution_context=execution_context,
)
current_prompt = list(prompt)
while max_iterations < 0 or iteration < max_iterations:
iteration += 1
logger.debug(f"Agent mode iteration {iteration}")
# Prepare prompt for this iteration
iteration_prompt = list(current_prompt)
# On the last iteration, add a special system message to encourage completion
if max_iterations > 0 and iteration == max_iterations:
last_iteration_message = {
"role": "system",
"content": f"{MAIN_OBJECTIVE_PREFIX}This is your last iteration ({iteration}/{max_iterations}). "
"Try to complete the task with the information you have. If you cannot fully complete it, "
"provide a summary of what you've accomplished and what remains to be done. "
"Prefer finishing with a clear response rather than making additional tool calls.",
}
iteration_prompt.append(last_iteration_message)
# Get LLM response
try:
response = await self._attempt_llm_call_with_validation(
credentials, input_data, iteration_prompt, tool_functions
)
except Exception as e:
yield "error", f"LLM call failed in agent mode iteration {iteration}: {str(e)}"
return
# Process tool calls
processed_tools = self._process_tool_calls(response, tool_functions)
# If no tool calls, we're done
if not processed_tools:
yield "finished", response.response
self._update_conversation(current_prompt, response)
yield "conversations", current_prompt
return
# Execute tools and collect responses
tool_outputs = []
for tool_info in processed_tools:
try:
tool_response = await self._execute_single_tool_with_manager(
tool_info, execution_params, execution_processor
)
tool_outputs.append(tool_response)
except Exception as e:
logger.error(f"Tool execution failed: {e}")
# Create error response for the tool
error_response = _create_tool_response(
tool_info.tool_call.id, f"Error: {str(e)}"
)
tool_outputs.append(error_response)
tool_outputs = _combine_tool_responses(tool_outputs)
self._update_conversation(current_prompt, response, tool_outputs)
# Yield intermediate conversation state
yield "conversations", current_prompt
# If we reach max iterations, yield the current state
if max_iterations < 0:
yield "finished", f"Agent mode completed after {iteration} iterations"
else:
yield "finished", f"Agent mode completed after {max_iterations} iterations (limit reached)"
yield "conversations", current_prompt
async def run(
self,
input_data: Input,
@@ -972,12 +603,8 @@ class SmartDecisionMakerBlock(Block):
graph_exec_id: str,
node_exec_id: str,
user_id: str,
graph_version: int,
execution_context: ExecutionContext,
execution_processor: "ExecutionProcessor",
**kwargs,
) -> BlockOutput:
tool_functions = await self._create_tool_node_signatures(node_id)
yield "tool_functions", json.dumps(tool_functions)
@@ -1021,52 +648,24 @@ class SmartDecisionMakerBlock(Block):
input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)
prefix = "[Main Objective Prompt]: "
if input_data.sys_prompt and not any(
p["role"] == "system" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
for p in prompt
p["role"] == "system" and p["content"].startswith(prefix) for p in prompt
):
prompt.append(
{
"role": "system",
"content": MAIN_OBJECTIVE_PREFIX + input_data.sys_prompt,
}
)
prompt.append({"role": "system", "content": prefix + input_data.sys_prompt})
if input_data.prompt and not any(
p["role"] == "user" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
for p in prompt
p["role"] == "user" and p["content"].startswith(prefix) for p in prompt
):
prompt.append(
{"role": "user", "content": MAIN_OBJECTIVE_PREFIX + input_data.prompt}
)
prompt.append({"role": "user", "content": prefix + input_data.prompt})
# Execute tools based on the selected mode
if input_data.agent_mode_max_iterations != 0:
# In agent mode, execute tools directly in a loop until finished
async for result in self._execute_tools_agent_mode(
input_data=input_data,
credentials=credentials,
tool_functions=tool_functions,
prompt=prompt,
graph_exec_id=graph_exec_id,
node_id=node_id,
node_exec_id=node_exec_id,
user_id=user_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
execution_processor=execution_processor,
):
yield result
return
# One-off mode: single LLM call and yield tool calls for external execution
current_prompt = list(prompt)
max_attempts = max(1, int(input_data.retry))
response = None
last_error = None
for _ in range(max_attempts):
for attempt in range(max_attempts):
try:
response = await self._attempt_llm_call_with_validation(
credentials, input_data, current_prompt, tool_functions

View File

@@ -1,11 +1,7 @@
import logging
import threading
from collections import defaultdict
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.data.execution import ExecutionContext
from backend.data.model import ProviderName, User
from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer
@@ -21,10 +17,10 @@ async def create_graph(s: SpinTestServer, g, u: User):
async def create_credentials(s: SpinTestServer, u: User):
import backend.blocks.llm as llm_module
import backend.blocks.llm as llm
provider = ProviderName.OPENAI
credentials = llm_module.TEST_CREDENTIALS
credentials = llm.TEST_CREDENTIALS
return await s.agent_server.test_create_credentials(u.id, provider, credentials)
@@ -200,6 +196,8 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
@pytest.mark.asyncio
async def test_smart_decision_maker_tracks_llm_stats():
"""Test that SmartDecisionMakerBlock correctly tracks LLM usage stats."""
from unittest.mock import MagicMock, patch
import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
@@ -218,6 +216,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
}
# Mock the _create_tool_node_signatures method to avoid database calls
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
@@ -235,19 +234,10 @@ async def test_smart_decision_maker_tracks_llm_stats():
prompt="Should I continue with this task?",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
# Execute the block
outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
@@ -256,9 +246,6 @@ async def test_smart_decision_maker_tracks_llm_stats():
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
@@ -276,6 +263,8 @@ async def test_smart_decision_maker_tracks_llm_stats():
@pytest.mark.asyncio
async def test_smart_decision_maker_parameter_validation():
"""Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
from unittest.mock import MagicMock, patch
import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
@@ -322,6 +311,8 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_with_typo.reasoning = None
mock_response_with_typo.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
@@ -338,17 +329,8 @@ async def test_smart_decision_maker_parameter_validation():
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2, # Set retry to 2 for testing
agent_mode_max_iterations=0,
)
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
# Should raise ValueError after retries due to typo'd parameter name
with pytest.raises(ValueError) as exc_info:
outputs = {}
@@ -360,9 +342,6 @@ async def test_smart_decision_maker_parameter_validation():
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
@@ -389,6 +368,8 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_missing_required.reasoning = None
mock_response_missing_required.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
@@ -404,17 +385,8 @@ async def test_smart_decision_maker_parameter_validation():
prompt="Search for keywords",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
# Should raise ValueError due to missing required parameter
with pytest.raises(ValueError) as exc_info:
outputs = {}
@@ -426,9 +398,6 @@ async def test_smart_decision_maker_parameter_validation():
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
@@ -449,6 +418,8 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_valid.reasoning = None
mock_response_valid.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
@@ -464,19 +435,10 @@ async def test_smart_decision_maker_parameter_validation():
prompt="Search for keywords",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
# Should succeed - optional parameter missing is OK
outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
@@ -485,9 +447,6 @@ async def test_smart_decision_maker_parameter_validation():
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
@@ -513,6 +472,8 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_all_params.reasoning = None
mock_response_all_params.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
@@ -528,19 +489,10 @@ async def test_smart_decision_maker_parameter_validation():
prompt="Search for keywords",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
# Should succeed with all parameters
outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
@@ -549,9 +501,6 @@ async def test_smart_decision_maker_parameter_validation():
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
@@ -564,6 +513,8 @@ async def test_smart_decision_maker_parameter_validation():
@pytest.mark.asyncio
async def test_smart_decision_maker_raw_response_conversion():
"""Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism."""
from unittest.mock import MagicMock, patch
import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
@@ -633,6 +584,7 @@ async def test_smart_decision_maker_raw_response_conversion():
)
# Mock llm_call to return different responses on different calls
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call", new_callable=AsyncMock
@@ -651,19 +603,10 @@ async def test_smart_decision_maker_raw_response_conversion():
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2,
agent_mode_max_iterations=0,
)
# Should succeed after retry, demonstrating our helper function works
outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
@@ -672,9 +615,6 @@ async def test_smart_decision_maker_raw_response_conversion():
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
@@ -710,6 +650,8 @@ async def test_smart_decision_maker_raw_response_conversion():
"I'll help you with that." # Ollama returns string
)
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
@@ -724,18 +666,9 @@ async def test_smart_decision_maker_raw_response_conversion():
prompt="Simple prompt",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
@@ -744,9 +677,6 @@ async def test_smart_decision_maker_raw_response_conversion():
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
@@ -766,6 +696,8 @@ async def test_smart_decision_maker_raw_response_conversion():
"content": "Test response",
} # Dict format
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
@@ -780,18 +712,9 @@ async def test_smart_decision_maker_raw_response_conversion():
prompt="Another test",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
@@ -800,260 +723,8 @@ async def test_smart_decision_maker_raw_response_conversion():
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
assert "finished" in outputs
assert outputs["finished"] == "Test response"
@pytest.mark.asyncio
async def test_smart_decision_maker_agent_mode():
"""Test that agent mode executes tools directly and loops until finished."""
import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
block = SmartDecisionMakerBlock()
# Mock tool call that requires multiple iterations
mock_tool_call_1 = MagicMock()
mock_tool_call_1.id = "call_1"
mock_tool_call_1.function.name = "search_keywords"
mock_tool_call_1.function.arguments = (
'{"query": "test", "max_keyword_difficulty": 50}'
)
mock_response_1 = MagicMock()
mock_response_1.response = None
mock_response_1.tool_calls = [mock_tool_call_1]
mock_response_1.prompt_tokens = 50
mock_response_1.completion_tokens = 25
mock_response_1.reasoning = "Using search tool"
mock_response_1.raw_response = {
"role": "assistant",
"content": None,
"tool_calls": [{"id": "call_1", "type": "function"}],
}
# Final response with no tool calls (finished)
mock_response_2 = MagicMock()
mock_response_2.response = "Task completed successfully"
mock_response_2.tool_calls = []
mock_response_2.prompt_tokens = 30
mock_response_2.completion_tokens = 15
mock_response_2.reasoning = None
mock_response_2.raw_response = {
"role": "assistant",
"content": "Task completed successfully",
}
# Mock the LLM call to return different responses on each iteration
llm_call_mock = AsyncMock()
llm_call_mock.side_effect = [mock_response_1, mock_response_2]
# Mock tool node signatures
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "search_keywords",
"_sink_node_id": "test-sink-node-id",
"_field_mapping": {},
"parameters": {
"properties": {
"query": {"type": "string"},
"max_keyword_difficulty": {"type": "integer"},
},
"required": ["query", "max_keyword_difficulty"],
},
},
}
]
# Mock database and execution components
mock_db_client = AsyncMock()
mock_node = MagicMock()
mock_node.block_id = "test-block-id"
mock_db_client.get_node.return_value = mock_node
# Mock upsert_execution_input to return proper NodeExecutionResult and input data
mock_node_exec_result = MagicMock()
mock_node_exec_result.node_exec_id = "test-tool-exec-id"
mock_input_data = {"query": "test", "max_keyword_difficulty": 50}
mock_db_client.upsert_execution_input.return_value = (
mock_node_exec_result,
mock_input_data,
)
# No longer need mock_execute_node since we use execution_processor.on_node_execution
with patch("backend.blocks.llm.llm_call", llm_call_mock), patch.object(
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
), patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
), patch(
"backend.executor.manager.async_update_node_execution_status",
new_callable=AsyncMock,
), patch(
"backend.integrations.creds_manager.IntegrationCredentialsManager"
):
# Create a mock execution context
mock_execution_context = ExecutionContext(
safe_mode=False,
)
# Create a mock execution processor for agent mode tests
mock_execution_processor = AsyncMock()
# Configure the execution processor mock with required attributes
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = threading.Lock()
# Mock the on_node_execution method to return successful stats
mock_node_stats = MagicMock()
mock_node_stats.error = None # No error
mock_execution_processor.on_node_execution = AsyncMock(
return_value=mock_node_stats
)
# Mock the get_execution_outputs_by_node_exec_id method
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
"result": {"status": "success", "data": "search completed"}
}
# Test agent mode with max_iterations = 3
input_data = SmartDecisionMakerBlock.Input(
prompt="Complete this task using tools",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=3, # Enable agent mode with 3 max iterations
)
outputs = {}
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph-id",
node_id="test-node-id",
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
# Verify agent mode behavior
assert "tool_functions" in outputs # tool_functions is yielded in both modes
assert "finished" in outputs
assert outputs["finished"] == "Task completed successfully"
assert "conversations" in outputs
# Verify the conversation includes tool responses
conversations = outputs["conversations"]
assert len(conversations) > 2 # Should have multiple conversation entries
# Verify LLM was called twice (once for tool call, once for finish)
assert llm_call_mock.call_count == 2
# Verify tool was executed via execution processor
assert mock_execution_processor.on_node_execution.call_count == 1
@pytest.mark.asyncio
async def test_smart_decision_maker_traditional_mode_default():
"""Test that default behavior (agent_mode_max_iterations=0) works as traditional mode."""
import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
block = SmartDecisionMakerBlock()
# Mock tool call
mock_tool_call = MagicMock()
mock_tool_call.function.name = "search_keywords"
mock_tool_call.function.arguments = (
'{"query": "test", "max_keyword_difficulty": 50}'
)
mock_response = MagicMock()
mock_response.response = None
mock_response.tool_calls = [mock_tool_call]
mock_response.prompt_tokens = 50
mock_response.completion_tokens = 25
mock_response.reasoning = None
mock_response.raw_response = {"role": "assistant", "content": None}
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "search_keywords",
"_sink_node_id": "test-sink-node-id",
"_field_mapping": {},
"parameters": {
"properties": {
"query": {"type": "string"},
"max_keyword_difficulty": {"type": "integer"},
},
"required": ["query", "max_keyword_difficulty"],
},
},
}
]
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response,
), patch.object(
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
):
# Test default behavior (traditional mode)
input_data = SmartDecisionMakerBlock.Input(
prompt="Test prompt",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0, # Traditional mode
)
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
outputs = {}
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph-id",
node_id="test-node-id",
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
# Verify traditional mode behavior
assert (
"tool_functions" in outputs
) # Should yield tool_functions in traditional mode
assert (
"tools_^_test-sink-node-id_~_query" in outputs
) # Should yield individual tool parameters
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
assert "conversations" in outputs

View File

@@ -1,7 +1,7 @@
"""Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""
import json
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from unittest.mock import AsyncMock, Mock, patch
import pytest
@@ -308,47 +308,10 @@ async def test_output_yielding_with_dynamic_fields():
) as mock_llm:
mock_llm.return_value = mock_response
# Mock the database manager to avoid HTTP calls during tool execution
with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
) as mock_db_manager, patch.object(
# Mock the function signature creation
with patch.object(
block, "_create_tool_node_signatures", new_callable=AsyncMock
) as mock_sig:
# Set up the mock database manager
mock_db_client = AsyncMock()
mock_db_manager.return_value = mock_db_client
# Mock the node retrieval
mock_target_node = Mock()
mock_target_node.id = "test-sink-node-id"
mock_target_node.block_id = "CreateDictionaryBlock"
mock_target_node.block = Mock()
mock_target_node.block.name = "Create Dictionary"
mock_db_client.get_node.return_value = mock_target_node
# Mock the execution result creation
mock_node_exec_result = Mock()
mock_node_exec_result.node_exec_id = "mock-node-exec-id"
mock_final_input_data = {
"values_#_name": "Alice",
"values_#_age": 30,
"values_#_email": "alice@example.com",
}
mock_db_client.upsert_execution_input.return_value = (
mock_node_exec_result,
mock_final_input_data,
)
# Mock the output retrieval
mock_outputs = {
"values_#_name": "Alice",
"values_#_age": 30,
"values_#_email": "alice@example.com",
}
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = (
mock_outputs
)
mock_sig.return_value = [
{
"type": "function",
@@ -374,16 +337,10 @@ async def test_output_yielding_with_dynamic_fields():
prompt="Create a user dictionary",
credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.LlmModel.GPT4O,
agent_mode_max_iterations=0, # Use traditional mode to test output yielding
)
# Run the block
outputs = {}
from backend.data.execution import ExecutionContext
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = MagicMock()
async for output_name, output_value in block.run(
input_data,
credentials=llm.TEST_CREDENTIALS,
@@ -392,9 +349,6 @@ async def test_output_yielding_with_dynamic_fields():
graph_exec_id="test_exec",
node_exec_id="test_node_exec",
user_id="test_user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_value
@@ -557,108 +511,45 @@ async def test_validation_errors_dont_pollute_conversation():
}
]
# Mock the database manager to avoid HTTP calls during tool execution
with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
) as mock_db_manager:
# Set up the mock database manager for agent mode
mock_db_client = AsyncMock()
mock_db_manager.return_value = mock_db_client
# Create input data
from backend.blocks import llm
# Mock the node retrieval
mock_target_node = Mock()
mock_target_node.id = "test-sink-node-id"
mock_target_node.block_id = "TestBlock"
mock_target_node.block = Mock()
mock_target_node.block.name = "Test Block"
mock_db_client.get_node.return_value = mock_target_node
input_data = block.input_schema(
prompt="Test prompt",
credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.LlmModel.GPT4O,
retry=3, # Allow retries
)
# Mock the execution result creation
mock_node_exec_result = Mock()
mock_node_exec_result.node_exec_id = "mock-node-exec-id"
mock_final_input_data = {"correct_param": "value"}
mock_db_client.upsert_execution_input.return_value = (
mock_node_exec_result,
mock_final_input_data,
)
# Run the block
outputs = {}
async for output_name, output_value in block.run(
input_data,
credentials=llm.TEST_CREDENTIALS,
graph_id="test_graph",
node_id="test_node",
graph_exec_id="test_exec",
node_exec_id="test_node_exec",
user_id="test_user",
):
outputs[output_name] = output_value
# Mock the output retrieval
mock_outputs = {"correct_param": "value"}
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = (
mock_outputs
)
# Verify we had 2 LLM calls (initial + retry)
assert call_count == 2
# Create input data
from backend.blocks import llm
# Check the final conversation output
final_conversation = outputs.get("conversations", [])
input_data = block.input_schema(
prompt="Test prompt",
credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.LlmModel.GPT4O,
retry=3, # Allow retries
agent_mode_max_iterations=1,
)
# The final conversation should NOT contain the validation error message
error_messages = [
msg
for msg in final_conversation
if msg.get("role") == "user"
and "parameter errors" in msg.get("content", "")
]
assert (
len(error_messages) == 0
), "Validation error leaked into final conversation"
# Run the block
outputs = {}
from backend.data.execution import ExecutionContext
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a proper mock execution processor for agent mode
from collections import defaultdict
mock_execution_processor = AsyncMock()
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = MagicMock()
# Create a mock NodeExecutionProgress for the sink node
mock_node_exec_progress = MagicMock()
mock_node_exec_progress.add_task = MagicMock()
mock_node_exec_progress.pop_output = MagicMock(
return_value=None
) # No outputs to process
# Set up running_node_execution as a defaultdict that returns our mock for any key
mock_execution_processor.running_node_execution = defaultdict(
lambda: mock_node_exec_progress
)
# Mock the on_node_execution method that gets called during tool execution
mock_node_stats = MagicMock()
mock_node_stats.error = None
mock_execution_processor.on_node_execution.return_value = (
mock_node_stats
)
async for output_name, output_value in block.run(
input_data,
credentials=llm.TEST_CREDENTIALS,
graph_id="test_graph",
node_id="test_node",
graph_exec_id="test_exec",
node_exec_id="test_node_exec",
user_id="test_user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_value
# Verify we had at least 1 LLM call
assert call_count >= 1
# Check the final conversation output
final_conversation = outputs.get("conversations", [])
# The final conversation should NOT contain validation error messages
# Even if retries don't happen in agent mode, we should not leak errors
error_messages = [
msg
for msg in final_conversation
if msg.get("role") == "user"
and "parameter errors" in msg.get("content", "")
]
assert (
len(error_messages) == 0
), "Validation error leaked into final conversation"
# The final conversation should only have the successful response
assert final_conversation[-1]["content"] == "valid"

View File

@@ -1 +0,0 @@
"""CLI utilities for backend development & administration"""

View File

@@ -1,57 +0,0 @@
#!/usr/bin/env python3
"""
Script to generate OpenAPI JSON specification for the FastAPI app.
This script imports the FastAPI app from backend.server.rest_api and outputs
the OpenAPI specification as JSON to stdout or a specified file.
Usage:
`poetry run python generate_openapi_json.py`
`poetry run python generate_openapi_json.py --output openapi.json`
`poetry run python generate_openapi_json.py --pretty true --output openapi.json`
"""
import json
import os
from pathlib import Path
import click
@click.command()
@click.option(
"--output",
type=click.Path(dir_okay=False, path_type=Path),
help="Output file path (default: stdout)",
)
@click.option(
"--pretty",
type=click.BOOL,
default=False,
help="Pretty-print JSON output (indented 2 spaces)",
)
def main(output: Path, pretty: bool):
"""Generate and output the OpenAPI JSON specification."""
openapi_schema = get_openapi_schema()
json_output = json.dumps(openapi_schema, indent=2 if pretty else None)
if output:
output.write_text(json_output)
click.echo(f"✅ OpenAPI specification written to {output}\n\nPreview:")
click.echo(f"\n{json_output[:500]} ...")
else:
print(json_output)
def get_openapi_schema():
"""Get the OpenAPI schema from the FastAPI app"""
from backend.server.rest_api import app
return app.openapi()
if __name__ == "__main__":
os.environ["LOG_LEVEL"] = "ERROR" # disable stdout log output
main()

File diff suppressed because it is too large

View File

@@ -1,45 +1,12 @@
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional
import prisma.types
from pydantic import BaseModel
from backend.data.db import query_raw_with_schema
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
class AccuracyAlertData(BaseModel):
"""Alert data when accuracy drops significantly."""
graph_id: str
user_id: Optional[str]
drop_percent: float
three_day_avg: float
seven_day_avg: float
detected_at: datetime
class AccuracyLatestData(BaseModel):
"""Latest execution accuracy data point."""
date: datetime
daily_score: Optional[float]
three_day_avg: Optional[float]
seven_day_avg: Optional[float]
fourteen_day_avg: Optional[float]
class AccuracyTrendsResponse(BaseModel):
"""Response model for accuracy trends and alerts."""
latest_data: AccuracyLatestData
alert: Optional[AccuracyAlertData]
historical_data: Optional[list[AccuracyLatestData]] = None
async def log_raw_analytics(
user_id: str,
type: str,
@@ -76,217 +43,3 @@ async def log_raw_metric(
)
return result
async def get_accuracy_trends_and_alerts(
graph_id: str,
days_back: int = 30,
user_id: Optional[str] = None,
drop_threshold: float = 10.0,
include_historical: bool = False,
) -> AccuracyTrendsResponse:
"""Get accuracy trends and detect alerts for a specific graph."""
query_template = """
WITH daily_scores AS (
SELECT
DATE(e."createdAt") as execution_date,
AVG(CASE
WHEN e.stats IS NOT NULL
AND e.stats::json->>'correctness_score' IS NOT NULL
AND e.stats::json->>'correctness_score' != 'null'
THEN (e.stats::json->>'correctness_score')::float * 100
ELSE NULL
END) as daily_score
FROM {schema_prefix}"AgentGraphExecution" e
WHERE e."agentGraphId" = $1::text
AND e."isDeleted" = false
AND e."createdAt" >= $2::timestamp
AND e."executionStatus" IN ('COMPLETED', 'FAILED', 'TERMINATED')
{user_filter}
GROUP BY DATE(e."createdAt")
HAVING COUNT(*) >= 3 -- Need at least 3 executions per day
),
trends AS (
SELECT
execution_date,
daily_score,
AVG(daily_score) OVER (
ORDER BY execution_date
ROWS BETWEEN 2 PRECEDING AND CURRENT ROW
) as three_day_avg,
AVG(daily_score) OVER (
ORDER BY execution_date
ROWS BETWEEN 6 PRECEDING AND CURRENT ROW
) as seven_day_avg,
AVG(daily_score) OVER (
ORDER BY execution_date
ROWS BETWEEN 13 PRECEDING AND CURRENT ROW
) as fourteen_day_avg
FROM daily_scores
)
SELECT *,
CASE
WHEN three_day_avg IS NOT NULL AND seven_day_avg IS NOT NULL AND seven_day_avg > 0
THEN ((seven_day_avg - three_day_avg) / seven_day_avg * 100)
ELSE NULL
END as drop_percent
FROM trends
ORDER BY execution_date DESC
{limit_clause}
"""
start_date = datetime.now(timezone.utc) - timedelta(days=days_back)
params = [graph_id, start_date]
user_filter = ""
if user_id:
user_filter = 'AND e."userId" = $3::text'
params.append(user_id)
# Determine limit clause
limit_clause = "" if include_historical else "LIMIT 1"
final_query = query_template.format(
schema_prefix="{schema_prefix}",
user_filter=user_filter,
limit_clause=limit_clause,
)
result = await query_raw_with_schema(final_query, *params)
if not result:
return AccuracyTrendsResponse(
latest_data=AccuracyLatestData(
date=datetime.now(timezone.utc),
daily_score=None,
three_day_avg=None,
seven_day_avg=None,
fourteen_day_avg=None,
),
alert=None,
)
latest = result[0]
alert = None
if (
latest["drop_percent"] is not None
and latest["drop_percent"] >= drop_threshold
and latest["three_day_avg"] is not None
and latest["seven_day_avg"] is not None
):
alert = AccuracyAlertData(
graph_id=graph_id,
user_id=user_id,
drop_percent=float(latest["drop_percent"]),
three_day_avg=float(latest["three_day_avg"]),
seven_day_avg=float(latest["seven_day_avg"]),
detected_at=datetime.now(timezone.utc),
)
# Prepare historical data if requested
historical_data = None
if include_historical:
historical_data = []
for row in result:
historical_data.append(
AccuracyLatestData(
date=row["execution_date"],
daily_score=(
float(row["daily_score"])
if row["daily_score"] is not None
else None
),
three_day_avg=(
float(row["three_day_avg"])
if row["three_day_avg"] is not None
else None
),
seven_day_avg=(
float(row["seven_day_avg"])
if row["seven_day_avg"] is not None
else None
),
fourteen_day_avg=(
float(row["fourteen_day_avg"])
if row["fourteen_day_avg"] is not None
else None
),
)
)
return AccuracyTrendsResponse(
latest_data=AccuracyLatestData(
date=latest["execution_date"],
daily_score=(
float(latest["daily_score"])
if latest["daily_score"] is not None
else None
),
three_day_avg=(
float(latest["three_day_avg"])
if latest["three_day_avg"] is not None
else None
),
seven_day_avg=(
float(latest["seven_day_avg"])
if latest["seven_day_avg"] is not None
else None
),
fourteen_day_avg=(
float(latest["fourteen_day_avg"])
if latest["fourteen_day_avg"] is not None
else None
),
),
alert=alert,
historical_data=historical_data,
)
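To make the alert arithmetic concrete, a small worked example with illustrative numbers; the query above computes the same trailing averages per execution date, and drop_percent = (seven_day_avg - three_day_avg) / seven_day_avg * 100:
daily_scores = [90.0, 88.0, 92.0, 89.0, 70.0, 65.0, 60.0]  # oldest -> newest
three_day_avg = sum(daily_scores[-3:]) / 3  # 65.0
seven_day_avg = sum(daily_scores[-7:]) / 7  # ~79.14
drop_percent = (seven_day_avg - three_day_avg) / seven_day_avg * 100  # ~17.9
assert drop_percent >= 10.0  # exceeds the default drop_threshold, so an alert fires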
class MarketplaceGraphData(BaseModel):
"""Data structure for marketplace graph monitoring."""
graph_id: str
user_id: Optional[str]
execution_count: int
async def get_marketplace_graphs_for_monitoring(
days_back: int = 30,
min_executions: int = 10,
) -> list[MarketplaceGraphData]:
"""Get published marketplace graphs with recent executions for monitoring."""
query_template = """
WITH marketplace_graphs AS (
SELECT DISTINCT
slv."agentGraphId" as graph_id,
slv."agentGraphVersion" as graph_version
FROM {schema_prefix}"StoreListing" sl
JOIN {schema_prefix}"StoreListingVersion" slv ON sl."activeVersionId" = slv."id"
WHERE sl."hasApprovedVersion" = true
AND sl."isDeleted" = false
)
SELECT DISTINCT
mg.graph_id,
NULL as user_id, -- Marketplace graphs don't have a specific user_id for monitoring
COUNT(*) as execution_count
FROM marketplace_graphs mg
JOIN {schema_prefix}"AgentGraphExecution" e ON e."agentGraphId" = mg.graph_id
WHERE e."createdAt" >= $1::timestamp
AND e."isDeleted" = false
AND e."executionStatus" IN ('COMPLETED', 'FAILED', 'TERMINATED')
GROUP BY mg.graph_id
HAVING COUNT(*) >= $2
ORDER BY execution_count DESC
"""
start_date = datetime.now(timezone.utc) - timedelta(days=days_back)
result = await query_raw_with_schema(query_template, start_date, min_executions)
return [
MarketplaceGraphData(
graph_id=row["graph_id"],
user_id=row["user_id"],
execution_count=int(row["execution_count"]),
)
for row in result
]

View File

@@ -1,24 +1,22 @@
import logging
import uuid
from datetime import datetime, timezone
from typing import Literal, Optional, cast
from typing import Optional
from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission, APIKeyStatus
from prisma.models import APIKey as PrismaAPIKey
from prisma.types import APIKeyCreateInput, APIKeyWhereUniqueInput
from pydantic import Field
from prisma.types import APIKeyWhereUniqueInput
from pydantic import BaseModel, Field
from backend.data.includes import MAX_USER_API_KEYS_FETCH
from backend.util.exceptions import NotAuthorizedError, NotFoundError
from .base import APIAuthorizationInfo
logger = logging.getLogger(__name__)
keysmith = APIKeySmith()
class APIKeyInfo(APIAuthorizationInfo):
class APIKeyInfo(BaseModel):
id: str
name: str
head: str = Field(
@@ -28,9 +26,12 @@ class APIKeyInfo(APIAuthorizationInfo):
description=f"The last {APIKeySmith.TAIL_LENGTH} characters of the key"
)
status: APIKeyStatus
permissions: list[APIKeyPermission]
created_at: datetime
last_used_at: Optional[datetime] = None
revoked_at: Optional[datetime] = None
description: Optional[str] = None
type: Literal["api_key"] = "api_key" # type: ignore
user_id: str
@staticmethod
def from_db(api_key: PrismaAPIKey):
@@ -40,7 +41,7 @@ class APIKeyInfo(APIAuthorizationInfo):
head=api_key.head,
tail=api_key.tail,
status=APIKeyStatus(api_key.status),
scopes=[APIKeyPermission(p) for p in api_key.permissions],
permissions=[APIKeyPermission(p) for p in api_key.permissions],
created_at=api_key.createdAt,
last_used_at=api_key.lastUsedAt,
revoked_at=api_key.revokedAt,
@@ -82,20 +83,17 @@ async def create_api_key(
generated_key = keysmith.generate_key()
saved_key_obj = await PrismaAPIKey.prisma().create(
data=cast(
APIKeyCreateInput,
{
"id": str(uuid.uuid4()),
"name": name,
"head": generated_key.head,
"tail": generated_key.tail,
"hash": generated_key.hash,
"salt": generated_key.salt,
"permissions": [p for p in permissions],
"description": description,
"userId": user_id,
},
)
data={
"id": str(uuid.uuid4()),
"name": name,
"head": generated_key.head,
"tail": generated_key.tail,
"hash": generated_key.hash,
"salt": generated_key.salt,
"permissions": [p for p in permissions],
"description": description,
"userId": user_id,
}
)
return APIKeyInfo.from_db(saved_key_obj), generated_key.key
@@ -213,7 +211,7 @@ async def suspend_api_key(key_id: str, user_id: str) -> APIKeyInfo:
def has_permission(api_key: APIKeyInfo, required_permission: APIKeyPermission) -> bool:
return required_permission in api_key.scopes
return required_permission in api_key.permissions
async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyInfo]:

View File

@@ -1,15 +0,0 @@
from datetime import datetime
from typing import Literal, Optional
from prisma.enums import APIKeyPermission
from pydantic import BaseModel
class APIAuthorizationInfo(BaseModel):
user_id: str
scopes: list[APIKeyPermission]
type: Literal["oauth", "api_key"]
created_at: datetime
expires_at: Optional[datetime] = None
last_used_at: Optional[datetime] = None
revoked_at: Optional[datetime] = None

View File

@@ -1,886 +0,0 @@
"""
OAuth 2.0 Provider Data Layer
Handles management of OAuth applications, authorization codes,
access tokens, and refresh tokens.
Hashing strategy:
- Access tokens & Refresh tokens: SHA256 (deterministic, allows direct lookup by hash)
- Client secrets: Scrypt with salt (lookup by client_id, then verify with salt)
"""
import hashlib
import logging
import secrets
import uuid
from datetime import datetime, timedelta, timezone
from typing import Literal, Optional, cast
from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission as APIPermission
from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
from prisma.models import OAuthApplication as PrismaOAuthApplication
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
from prisma.types import (
OAuthAccessTokenCreateInput,
OAuthApplicationUpdateInput,
OAuthAuthorizationCodeCreateInput,
OAuthRefreshTokenCreateInput,
)
from pydantic import BaseModel, Field, SecretStr
from .base import APIAuthorizationInfo
logger = logging.getLogger(__name__)
keysmith = APIKeySmith() # Only used for client secret hashing (Scrypt)
def _generate_token() -> str:
"""Generate a cryptographically secure random token."""
return secrets.token_urlsafe(32)
def _hash_token(token: str) -> str:
"""Hash a token using SHA256 (deterministic, for direct lookup)."""
return hashlib.sha256(token.encode()).hexdigest()
# Token TTLs
AUTHORIZATION_CODE_TTL = timedelta(minutes=10)
ACCESS_TOKEN_TTL = timedelta(hours=1)
REFRESH_TOKEN_TTL = timedelta(days=30)
ACCESS_TOKEN_PREFIX = "agpt_xt_"
REFRESH_TOKEN_PREFIX = "agpt_rt_"
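A minimal, self-contained sketch of why the module uses two hashing schemes: SHA256 is deterministic, so the stored digest doubles as a database lookup key, while scrypt requires the per-secret salt, forcing lookup by client_id first and verification afterwards. The scrypt parameters below are illustrative assumptions, not APIKeySmith's actual settings.
import hashlib
import secrets
token = "agpt_xt_" + secrets.token_urlsafe(32)
lookup_key = hashlib.sha256(token.encode()).hexdigest()  # same input -> same digest
salt = secrets.token_bytes(16)
digest = hashlib.scrypt(b"client-secret", salt=salt, n=2**14, r=8, p=1)
# Verification must recompute with the stored salt, then compare in constant time
assert secrets.compare_digest(
    digest, hashlib.scrypt(b"client-secret", salt=salt, n=2**14, r=8, p=1)
)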
# ============================================================================
# Exception Classes
# ============================================================================
class OAuthError(Exception):
"""Base OAuth error"""
pass
class InvalidClientError(OAuthError):
"""Invalid client_id or client_secret"""
pass
class InvalidGrantError(OAuthError):
"""Invalid or expired authorization code/refresh token"""
def __init__(self, reason: str):
self.reason = reason
super().__init__(f"Invalid grant: {reason}")
class InvalidTokenError(OAuthError):
"""Invalid, expired, or revoked token"""
def __init__(self, reason: str):
self.reason = reason
super().__init__(f"Invalid token: {reason}")
# ============================================================================
# Data Models
# ============================================================================
class OAuthApplicationInfo(BaseModel):
"""OAuth application information (without client secret hash)"""
id: str
name: str
description: Optional[str] = None
logo_url: Optional[str] = None
client_id: str
redirect_uris: list[str]
grant_types: list[str]
scopes: list[APIPermission]
owner_id: str
is_active: bool
created_at: datetime
updated_at: datetime
@staticmethod
def from_db(app: PrismaOAuthApplication):
return OAuthApplicationInfo(
id=app.id,
name=app.name,
description=app.description,
logo_url=app.logoUrl,
client_id=app.clientId,
redirect_uris=app.redirectUris,
grant_types=app.grantTypes,
scopes=[APIPermission(s) for s in app.scopes],
owner_id=app.ownerId,
is_active=app.isActive,
created_at=app.createdAt,
updated_at=app.updatedAt,
)
class OAuthApplicationInfoWithSecret(OAuthApplicationInfo):
"""OAuth application with client secret hash (for validation)"""
client_secret_hash: str
client_secret_salt: str
@staticmethod
def from_db(app: PrismaOAuthApplication):
return OAuthApplicationInfoWithSecret(
**OAuthApplicationInfo.from_db(app).model_dump(),
client_secret_hash=app.clientSecret,
client_secret_salt=app.clientSecretSalt,
)
def verify_secret(self, plaintext_secret: str) -> bool:
"""Verify a plaintext client secret against the stored hash"""
# Use keysmith.verify_key() with stored salt
return keysmith.verify_key(
plaintext_secret, self.client_secret_hash, self.client_secret_salt
)
class OAuthAuthorizationCodeInfo(BaseModel):
"""Authorization code information"""
id: str
code: str
created_at: datetime
expires_at: datetime
application_id: str
user_id: str
scopes: list[APIPermission]
redirect_uri: str
code_challenge: Optional[str] = None
code_challenge_method: Optional[str] = None
used_at: Optional[datetime] = None
@property
def is_used(self) -> bool:
return self.used_at is not None
@staticmethod
def from_db(code: PrismaOAuthAuthorizationCode):
return OAuthAuthorizationCodeInfo(
id=code.id,
code=code.code,
created_at=code.createdAt,
expires_at=code.expiresAt,
application_id=code.applicationId,
user_id=code.userId,
scopes=[APIPermission(s) for s in code.scopes],
redirect_uri=code.redirectUri,
code_challenge=code.codeChallenge,
code_challenge_method=code.codeChallengeMethod,
used_at=code.usedAt,
)
class OAuthAccessTokenInfo(APIAuthorizationInfo):
"""Access token information"""
id: str
expires_at: datetime # type: ignore
application_id: str
type: Literal["oauth"] = "oauth" # type: ignore
@staticmethod
def from_db(token: PrismaOAuthAccessToken):
return OAuthAccessTokenInfo(
id=token.id,
user_id=token.userId,
scopes=[APIPermission(s) for s in token.scopes],
created_at=token.createdAt,
expires_at=token.expiresAt,
last_used_at=None,
revoked_at=token.revokedAt,
application_id=token.applicationId,
)
class OAuthAccessToken(OAuthAccessTokenInfo):
"""Access token with plaintext token included (sensitive)"""
token: SecretStr = Field(description="Plaintext token (sensitive)")
@staticmethod
def from_db(token: PrismaOAuthAccessToken, plaintext_token: str): # type: ignore
return OAuthAccessToken(
**OAuthAccessTokenInfo.from_db(token).model_dump(),
token=SecretStr(plaintext_token),
)
class OAuthRefreshTokenInfo(BaseModel):
"""Refresh token information"""
id: str
user_id: str
scopes: list[APIPermission]
created_at: datetime
expires_at: datetime
application_id: str
revoked_at: Optional[datetime] = None
@property
def is_revoked(self) -> bool:
return self.revoked_at is not None
@staticmethod
def from_db(token: PrismaOAuthRefreshToken):
return OAuthRefreshTokenInfo(
id=token.id,
user_id=token.userId,
scopes=[APIPermission(s) for s in token.scopes],
created_at=token.createdAt,
expires_at=token.expiresAt,
application_id=token.applicationId,
revoked_at=token.revokedAt,
)
class OAuthRefreshToken(OAuthRefreshTokenInfo):
"""Refresh token with plaintext token included (sensitive)"""
token: SecretStr = Field(description="Plaintext token (sensitive)")
@staticmethod
def from_db(token: PrismaOAuthRefreshToken, plaintext_token: str): # type: ignore
return OAuthRefreshToken(
**OAuthRefreshTokenInfo.from_db(token).model_dump(),
token=SecretStr(plaintext_token),
)
class TokenIntrospectionResult(BaseModel):
"""Result of token introspection (RFC 7662)"""
active: bool
scopes: Optional[list[str]] = None
client_id: Optional[str] = None
user_id: Optional[str] = None
exp: Optional[int] = None # Unix timestamp
token_type: Optional[Literal["access_token", "refresh_token"]] = None
# ============================================================================
# OAuth Application Management
# ============================================================================
async def get_oauth_application(client_id: str) -> Optional[OAuthApplicationInfo]:
"""Get OAuth application by client ID (without secret)"""
app = await PrismaOAuthApplication.prisma().find_unique(
where={"clientId": client_id}
)
if not app:
return None
return OAuthApplicationInfo.from_db(app)
async def get_oauth_application_with_secret(
client_id: str,
) -> Optional[OAuthApplicationInfoWithSecret]:
"""Get OAuth application by client ID (with secret hash for validation)"""
app = await PrismaOAuthApplication.prisma().find_unique(
where={"clientId": client_id}
)
if not app:
return None
return OAuthApplicationInfoWithSecret.from_db(app)
async def validate_client_credentials(
client_id: str, client_secret: str
) -> OAuthApplicationInfo:
"""
Validate client credentials and return application info.
Raises:
InvalidClientError: If client_id or client_secret is invalid, or app is inactive
"""
app = await get_oauth_application_with_secret(client_id)
if not app:
raise InvalidClientError("Invalid client_id")
if not app.is_active:
raise InvalidClientError("Application is not active")
# Verify client secret
if not app.verify_secret(client_secret):
raise InvalidClientError("Invalid client_secret")
    # Return without secret hash or salt
    return OAuthApplicationInfo(
        **app.model_dump(exclude={"client_secret_hash", "client_secret_salt"})
    )
def validate_redirect_uri(app: OAuthApplicationInfo, redirect_uri: str) -> bool:
"""Validate that redirect URI is registered for the application"""
return redirect_uri in app.redirect_uris
def validate_scopes(
app: OAuthApplicationInfo, requested_scopes: list[APIPermission]
) -> bool:
"""Validate that all requested scopes are allowed for the application"""
return all(scope in app.scopes for scope in requested_scopes)
# ============================================================================
# Authorization Code Flow
# ============================================================================
def _generate_authorization_code() -> str:
"""Generate a cryptographically secure authorization code"""
# 32 bytes = 256 bits of entropy
return secrets.token_urlsafe(32)
async def create_authorization_code(
application_id: str,
user_id: str,
scopes: list[APIPermission],
redirect_uri: str,
code_challenge: Optional[str] = None,
code_challenge_method: Optional[Literal["S256", "plain"]] = None,
) -> OAuthAuthorizationCodeInfo:
"""
Create a new authorization code.
Expires in 10 minutes and can only be used once.
"""
code = _generate_authorization_code()
now = datetime.now(timezone.utc)
expires_at = now + AUTHORIZATION_CODE_TTL
saved_code = await PrismaOAuthAuthorizationCode.prisma().create(
data=cast(
OAuthAuthorizationCodeCreateInput,
{
"id": str(uuid.uuid4()),
"code": code,
"expiresAt": expires_at,
"applicationId": application_id,
"userId": user_id,
"scopes": [s for s in scopes],
"redirectUri": redirect_uri,
"codeChallenge": code_challenge,
"codeChallengeMethod": code_challenge_method,
},
)
)
return OAuthAuthorizationCodeInfo.from_db(saved_code)
async def consume_authorization_code(
code: str,
application_id: str,
redirect_uri: str,
code_verifier: Optional[str] = None,
) -> tuple[str, list[APIPermission]]:
"""
Consume an authorization code and return (user_id, scopes).
This marks the code as used and validates:
- Code exists and matches application
- Code is not expired
- Code has not been used
- Redirect URI matches
- PKCE code verifier matches (if code challenge was provided)
Raises:
InvalidGrantError: If code is invalid, expired, used, or PKCE fails
"""
auth_code = await PrismaOAuthAuthorizationCode.prisma().find_unique(
where={"code": code}
)
if not auth_code:
raise InvalidGrantError("authorization code not found")
# Validate application
if auth_code.applicationId != application_id:
raise InvalidGrantError(
"authorization code does not belong to this application"
)
# Check if already used
if auth_code.usedAt is not None:
raise InvalidGrantError(
f"authorization code already used at {auth_code.usedAt}"
)
# Check expiration
now = datetime.now(timezone.utc)
if auth_code.expiresAt < now:
raise InvalidGrantError("authorization code expired")
# Validate redirect URI
if auth_code.redirectUri != redirect_uri:
raise InvalidGrantError("redirect_uri mismatch")
# Validate PKCE if code challenge was provided
if auth_code.codeChallenge:
if not code_verifier:
raise InvalidGrantError("code_verifier required but not provided")
if not _verify_pkce(
code_verifier, auth_code.codeChallenge, auth_code.codeChallengeMethod
):
raise InvalidGrantError("PKCE verification failed")
# Mark code as used
await PrismaOAuthAuthorizationCode.prisma().update(
where={"code": code},
data={"usedAt": now},
)
return auth_code.userId, [APIPermission(s) for s in auth_code.scopes]
def _verify_pkce(
code_verifier: str, code_challenge: str, code_challenge_method: Optional[str]
) -> bool:
"""
Verify PKCE code verifier against code challenge.
Supports:
- S256: SHA256(code_verifier) == code_challenge
- plain: code_verifier == code_challenge
"""
if code_challenge_method == "S256":
# Hash the verifier with SHA256 and base64url encode
hashed = hashlib.sha256(code_verifier.encode("ascii")).digest()
computed_challenge = (
secrets.token_urlsafe(len(hashed)).encode("ascii").decode("ascii")
)
# For proper base64url encoding
import base64
computed_challenge = (
base64.urlsafe_b64encode(hashed).decode("ascii").rstrip("=")
)
return secrets.compare_digest(computed_challenge, code_challenge)
elif code_challenge_method == "plain" or code_challenge_method is None:
# Plain comparison
return secrets.compare_digest(code_verifier, code_challenge)
else:
logger.warning(f"Unsupported code challenge method: {code_challenge_method}")
return False
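For context, a minimal client-side sketch (per RFC 7636) producing the pair that the S256 branch above verifies; variable names are illustrative:
import base64
import hashlib
import secrets
code_verifier = secrets.token_urlsafe(32)  # 43-128 characters per RFC 7636
code_challenge = (
    base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("ascii")).digest())
    .decode("ascii")
    .rstrip("=")
)
# The client sends code_challenge on the authorize request and code_verifier on
# the token request; the server recomputes the challenge and compares, as above.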
# ============================================================================
# Access Token Management
# ============================================================================
async def create_access_token(
application_id: str, user_id: str, scopes: list[APIPermission]
) -> OAuthAccessToken:
"""
Create a new access token.
Returns OAuthAccessToken (with plaintext token).
"""
plaintext_token = ACCESS_TOKEN_PREFIX + _generate_token()
token_hash = _hash_token(plaintext_token)
now = datetime.now(timezone.utc)
expires_at = now + ACCESS_TOKEN_TTL
saved_token = await PrismaOAuthAccessToken.prisma().create(
data=cast(
OAuthAccessTokenCreateInput,
{
"id": str(uuid.uuid4()),
"token": token_hash, # SHA256 hash for direct lookup
"expiresAt": expires_at,
"applicationId": application_id,
"userId": user_id,
"scopes": [s for s in scopes],
},
)
)
return OAuthAccessToken.from_db(saved_token, plaintext_token=plaintext_token)
async def validate_access_token(
token: str,
) -> tuple[OAuthAccessTokenInfo, OAuthApplicationInfo]:
"""
Validate an access token and return token info.
Raises:
InvalidTokenError: If token is invalid, expired, or revoked
InvalidClientError: If the client application is not marked as active
"""
token_hash = _hash_token(token)
# Direct lookup by hash
access_token = await PrismaOAuthAccessToken.prisma().find_unique(
where={"token": token_hash}, include={"Application": True}
)
if not access_token:
raise InvalidTokenError("access token not found")
if not access_token.Application: # should be impossible
raise InvalidClientError("Client application not found")
if not access_token.Application.isActive:
raise InvalidClientError("Client application is disabled")
if access_token.revokedAt is not None:
raise InvalidTokenError("access token has been revoked")
# Check expiration
now = datetime.now(timezone.utc)
if access_token.expiresAt < now:
raise InvalidTokenError("access token expired")
return (
OAuthAccessTokenInfo.from_db(access_token),
OAuthApplicationInfo.from_db(access_token.Application),
)
async def revoke_access_token(
token: str, application_id: str
) -> OAuthAccessTokenInfo | None:
"""
Revoke an access token.
Args:
token: The plaintext access token to revoke
application_id: The application ID making the revocation request.
Only tokens belonging to this application will be revoked.
Returns:
OAuthAccessTokenInfo if token was found and revoked, None otherwise.
Note:
Always performs exactly 2 DB queries regardless of outcome to prevent
timing side-channel attacks that could reveal token existence.
"""
try:
token_hash = _hash_token(token)
# Use update_many to filter by both token and applicationId
updated_count = await PrismaOAuthAccessToken.prisma().update_many(
where={
"token": token_hash,
"applicationId": application_id,
"revokedAt": None,
},
data={"revokedAt": datetime.now(timezone.utc)},
)
# Always perform second query to ensure constant time
result = await PrismaOAuthAccessToken.prisma().find_unique(
where={"token": token_hash}
)
# Only return result if we actually revoked something
if updated_count == 0:
return None
return OAuthAccessTokenInfo.from_db(result) if result else None
except Exception as e:
logger.exception(f"Error revoking access token: {e}")
return None
# ============================================================================
# Refresh Token Management
# ============================================================================
async def create_refresh_token(
application_id: str, user_id: str, scopes: list[APIPermission]
) -> OAuthRefreshToken:
"""
Create a new refresh token.
Returns OAuthRefreshToken (with plaintext token).
"""
plaintext_token = REFRESH_TOKEN_PREFIX + _generate_token()
token_hash = _hash_token(plaintext_token)
now = datetime.now(timezone.utc)
expires_at = now + REFRESH_TOKEN_TTL
saved_token = await PrismaOAuthRefreshToken.prisma().create(
data=cast(
OAuthRefreshTokenCreateInput,
{
"id": str(uuid.uuid4()),
"token": token_hash, # SHA256 hash for direct lookup
"expiresAt": expires_at,
"applicationId": application_id,
"userId": user_id,
"scopes": [s for s in scopes],
},
)
)
return OAuthRefreshToken.from_db(saved_token, plaintext_token=plaintext_token)
async def refresh_tokens(
refresh_token: str, application_id: str
) -> tuple[OAuthAccessToken, OAuthRefreshToken]:
"""
Use a refresh token to create new access and refresh tokens.
Returns (new_access_token, new_refresh_token) both with plaintext tokens included.
Raises:
InvalidGrantError: If refresh token is invalid, expired, or revoked
"""
token_hash = _hash_token(refresh_token)
# Direct lookup by hash
rt = await PrismaOAuthRefreshToken.prisma().find_unique(where={"token": token_hash})
if not rt:
raise InvalidGrantError("refresh token not found")
# NOTE: no need to check Application.isActive, this is checked by the token endpoint
if rt.revokedAt is not None:
raise InvalidGrantError("refresh token has been revoked")
# Validate application
if rt.applicationId != application_id:
raise InvalidGrantError("refresh token does not belong to this application")
# Check expiration
now = datetime.now(timezone.utc)
if rt.expiresAt < now:
raise InvalidGrantError("refresh token expired")
# Revoke old refresh token
await PrismaOAuthRefreshToken.prisma().update(
where={"token": token_hash},
data={"revokedAt": now},
)
# Create new access and refresh tokens with same scopes
scopes = [APIPermission(s) for s in rt.scopes]
new_access_token = await create_access_token(
rt.applicationId,
rt.userId,
scopes,
)
new_refresh_token = await create_refresh_token(
rt.applicationId,
rt.userId,
scopes,
)
return new_access_token, new_refresh_token
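A hedged sketch of how a token endpoint might drive this rotation; the handler name and response shape are assumptions for illustration, not this module's API:
async def handle_refresh_grant(client_id: str, client_secret: str, token: str) -> dict:
    app = await validate_client_credentials(client_id, client_secret)  # raises InvalidClientError
    access, refresh = await refresh_tokens(token, app.id)  # raises InvalidGrantError
    return {
        "access_token": access.token.get_secret_value(),
        "refresh_token": refresh.token.get_secret_value(),
        "token_type": "Bearer",
        "expires_in": int((access.expires_at - access.created_at).total_seconds()),
    }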
async def revoke_refresh_token(
token: str, application_id: str
) -> OAuthRefreshTokenInfo | None:
"""
Revoke a refresh token.
Args:
token: The plaintext refresh token to revoke
application_id: The application ID making the revocation request.
Only tokens belonging to this application will be revoked.
Returns:
OAuthRefreshTokenInfo if token was found and revoked, None otherwise.
Note:
Always performs exactly 2 DB queries regardless of outcome to prevent
timing side-channel attacks that could reveal token existence.
"""
try:
token_hash = _hash_token(token)
# Use update_many to filter by both token and applicationId
updated_count = await PrismaOAuthRefreshToken.prisma().update_many(
where={
"token": token_hash,
"applicationId": application_id,
"revokedAt": None,
},
data={"revokedAt": datetime.now(timezone.utc)},
)
# Always perform second query to ensure constant time
result = await PrismaOAuthRefreshToken.prisma().find_unique(
where={"token": token_hash}
)
# Only return result if we actually revoked something
if updated_count == 0:
return None
return OAuthRefreshTokenInfo.from_db(result) if result else None
except Exception as e:
logger.exception(f"Error revoking refresh token: {e}")
return None
# ============================================================================
# Token Introspection
# ============================================================================
async def introspect_token(
token: str,
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = None,
) -> TokenIntrospectionResult:
"""
Introspect a token and return its metadata (RFC 7662).
Returns TokenIntrospectionResult with active=True and metadata if valid,
or active=False if the token is invalid/expired/revoked.
"""
# Try as access token first (or if hint says "access_token")
if token_type_hint != "refresh_token":
try:
token_info, app = await validate_access_token(token)
return TokenIntrospectionResult(
active=True,
scopes=list(s.value for s in token_info.scopes),
client_id=app.client_id if app else None,
user_id=token_info.user_id,
exp=int(token_info.expires_at.timestamp()),
token_type="access_token",
)
except InvalidTokenError:
pass # Try as refresh token
# Try as refresh token
token_hash = _hash_token(token)
refresh_token = await PrismaOAuthRefreshToken.prisma().find_unique(
where={"token": token_hash}
)
if refresh_token and refresh_token.revokedAt is None:
# Check if valid (not expired)
now = datetime.now(timezone.utc)
if refresh_token.expiresAt > now:
app = await get_oauth_application_by_id(refresh_token.applicationId)
return TokenIntrospectionResult(
active=True,
scopes=list(s for s in refresh_token.scopes),
client_id=app.client_id if app else None,
user_id=refresh_token.userId,
exp=int(refresh_token.expiresAt.timestamp()),
token_type="refresh_token",
)
# Token not found or inactive
return TokenIntrospectionResult(active=False)
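A small sketch of serializing this result into an RFC 7662 wire response; mapping user_id to "sub" and space-joining scopes follow the RFC's conventions, while the function name is illustrative:
async def introspection_response(token: str) -> dict:
    result = await introspect_token(token)
    if not result.active:
        return {"active": False}  # RFC 7662 requires only this field when inactive
    return {
        "active": True,
        "scope": " ".join(result.scopes or []),
        "client_id": result.client_id,
        "sub": result.user_id,
        "exp": result.exp,
    }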
async def get_oauth_application_by_id(app_id: str) -> Optional[OAuthApplicationInfo]:
"""Get OAuth application by ID"""
app = await PrismaOAuthApplication.prisma().find_unique(where={"id": app_id})
if not app:
return None
return OAuthApplicationInfo.from_db(app)
async def list_user_oauth_applications(user_id: str) -> list[OAuthApplicationInfo]:
"""Get all OAuth applications owned by a user"""
apps = await PrismaOAuthApplication.prisma().find_many(
where={"ownerId": user_id},
order={"createdAt": "desc"},
)
return [OAuthApplicationInfo.from_db(app) for app in apps]
async def update_oauth_application(
app_id: str,
*,
owner_id: str,
is_active: Optional[bool] = None,
logo_url: Optional[str] = None,
) -> Optional[OAuthApplicationInfo]:
"""
Update OAuth application active status.
Only the owner can update their app's status.
Returns the updated app info, or None if app not found or not owned by user.
"""
# First verify ownership
app = await PrismaOAuthApplication.prisma().find_first(
where={"id": app_id, "ownerId": owner_id}
)
if not app:
return None
patch: OAuthApplicationUpdateInput = {}
if is_active is not None:
patch["isActive"] = is_active
if logo_url:
patch["logoUrl"] = logo_url
if not patch:
return OAuthApplicationInfo.from_db(app) # return unchanged
updated_app = await PrismaOAuthApplication.prisma().update(
where={"id": app_id},
data=patch,
)
return OAuthApplicationInfo.from_db(updated_app) if updated_app else None
# ============================================================================
# Token Cleanup
# ============================================================================
async def cleanup_expired_oauth_tokens() -> dict[str, int]:
"""
Delete expired OAuth tokens from the database.
This removes:
- Expired authorization codes (10 min TTL)
- Expired access tokens (1 hour TTL)
- Expired refresh tokens (30 day TTL)
Returns a dict with counts of deleted tokens by type.
"""
now = datetime.now(timezone.utc)
# Delete expired authorization codes
codes_result = await PrismaOAuthAuthorizationCode.prisma().delete_many(
where={"expiresAt": {"lt": now}}
)
# Delete expired access tokens
access_result = await PrismaOAuthAccessToken.prisma().delete_many(
where={"expiresAt": {"lt": now}}
)
# Delete expired refresh tokens
refresh_result = await PrismaOAuthRefreshToken.prisma().delete_many(
where={"expiresAt": {"lt": now}}
)
deleted = {
"authorization_codes": codes_result,
"access_tokens": access_result,
"refresh_tokens": refresh_result,
}
total = sum(deleted.values())
if total > 0:
logger.info(f"Cleaned up {total} expired OAuth tokens: {deleted}")
return deleted

View File

@@ -601,18 +601,14 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
async for output_name, output_data in self._execute(input_data, **kwargs):
yield output_name, output_data
except Exception as ex:
if isinstance(ex, BlockError):
raise ex
else:
raise (
BlockExecutionError
if isinstance(ex, ValueError)
else BlockUnknownError
)(
if not isinstance(ex, BlockError):
raise BlockUnknownError(
message=str(ex),
block_name=self.name,
block_id=self.id,
) from ex
else:
raise ex
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
if error := self.input_schema.validate_data(input_data):

View File

@@ -5,14 +5,12 @@ This test was added to cover a previously untested code path that could lead to
incorrect balance capping behavior.
"""
from typing import cast
from uuid import uuid4
import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceUpsertInput, UserCreateInput
from backend.data.credit import UserCredit
from backend.util.json import SafeJson
@@ -23,14 +21,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for ceiling tests."""
try:
await User.prisma().create(
data=cast(
UserCreateInput,
{
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
},
)
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
)
except UniqueViolationError:
# User already exists, continue
@@ -38,10 +33,7 @@ async def create_test_user(user_id: str) -> None:
await UserBalance.prisma().upsert(
where={"userId": user_id},
data=cast(
UserBalanceUpsertInput,
{"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
),
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
)

View File

@@ -7,7 +7,6 @@ without race conditions, deadlocks, or inconsistent state.
import asyncio
import random
from typing import cast
from uuid import uuid4
import prisma.enums
@@ -15,7 +14,6 @@ import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceUpsertInput, UserCreateInput
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
from backend.util.exceptions import InsufficientBalanceError
@@ -30,14 +28,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user with initial balance."""
try:
await User.prisma().create(
data=cast(
UserCreateInput,
{
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
},
)
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
)
except UniqueViolationError:
# User already exists, continue
@@ -46,10 +41,7 @@ async def create_test_user(user_id: str) -> None:
# Ensure UserBalance record exists
await UserBalance.prisma().upsert(
where={"userId": user_id},
data=cast(
UserBalanceUpsertInput,
{"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
),
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
)
@@ -350,13 +342,10 @@ async def test_integer_overflow_protection(server: SpinTestServer):
# First, set balance near max
await UserBalance.prisma().upsert(
where={"userId": user_id},
data=cast(
UserBalanceUpsertInput,
{
"create": {"userId": user_id, "balance": max_int - 100},
"update": {"balance": max_int - 100},
},
),
data={
"create": {"userId": user_id, "balance": max_int - 100},
"update": {"balance": max_int - 100},
},
)
# Try to add more than possible - should clamp to POSTGRES_INT_MAX
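A minimal sketch of the clamping behavior this test expects, assuming POSTGRES_INT_MAX is PostgreSQL's 4-byte integer maximum (2**31 - 1); the helper is illustrative, not the production code path:
POSTGRES_INT_MAX = 2**31 - 1  # assumption: PostgreSQL int4 maximum
def clamped_add(balance: int, amount: int) -> int:
    # Credits are capped at the column maximum rather than allowed to wrap around
    return min(balance + amount, POSTGRES_INT_MAX)
assert clamped_add(POSTGRES_INT_MAX - 100, 1_000) == POSTGRES_INT_MAX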

View File

@@ -5,12 +5,9 @@ These tests run actual database operations to ensure SQL queries work correctly,
which would have caught the CreditTransactionType enum casting bug.
"""
from typing import cast
import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserCreateInput
from backend.data.credit import (
AutoTopUpConfig,
@@ -32,15 +29,12 @@ async def cleanup_test_user():
# Create the user first
try:
await User.prisma().create(
data=cast(
UserCreateInput,
{
"id": user_id,
"email": f"test-{user_id}@example.com",
"topUpConfig": SafeJson({}),
"timezone": "UTC",
},
)
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"topUpConfig": SafeJson({}),
"timezone": "UTC",
}
)
except Exception:
# User might already exist, that's fine

View File

@@ -6,19 +6,12 @@ are atomic and maintain data consistency.
"""
from datetime import datetime, timezone
from typing import cast
from unittest.mock import MagicMock, patch
import pytest
import stripe
from prisma.enums import CreditTransactionType
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
from prisma.types import (
CreditRefundRequestCreateInput,
CreditTransactionCreateInput,
UserBalanceCreateInput,
UserCreateInput,
)
from backend.data.credit import UserCredit
from backend.util.json import SafeJson
@@ -42,41 +35,32 @@ async def setup_test_user_with_topup():
# Create user
await User.prisma().create(
data=cast(
UserCreateInput,
{
"id": REFUND_TEST_USER_ID,
"email": f"{REFUND_TEST_USER_ID}@example.com",
"name": "Refund Test User",
},
)
data={
"id": REFUND_TEST_USER_ID,
"email": f"{REFUND_TEST_USER_ID}@example.com",
"name": "Refund Test User",
}
)
# Create user balance
await UserBalance.prisma().create(
data=cast(
UserBalanceCreateInput,
{
"userId": REFUND_TEST_USER_ID,
"balance": 1000, # $10
},
)
data={
"userId": REFUND_TEST_USER_ID,
"balance": 1000, # $10
}
)
# Create a top-up transaction that can be refunded
topup_tx = await CreditTransaction.prisma().create(
data=cast(
CreditTransactionCreateInput,
{
"userId": REFUND_TEST_USER_ID,
"amount": 1000,
"type": CreditTransactionType.TOP_UP,
"transactionKey": "pi_test_12345",
"runningBalance": 1000,
"isActive": True,
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
},
)
data={
"userId": REFUND_TEST_USER_ID,
"amount": 1000,
"type": CreditTransactionType.TOP_UP,
"transactionKey": "pi_test_12345",
"runningBalance": 1000,
"isActive": True,
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
}
)
return topup_tx
@@ -109,15 +93,12 @@ async def test_deduct_credits_atomic(server: SpinTestServer):
# Create refund request record (simulating webhook flow)
await CreditRefundRequest.prisma().create(
data=cast(
CreditRefundRequestCreateInput,
{
"userId": REFUND_TEST_USER_ID,
"amount": 500,
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
"reason": "Test refund",
},
)
data={
"userId": REFUND_TEST_USER_ID,
"amount": 500,
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
"reason": "Test refund",
}
)
# Call deduct_credits
@@ -305,15 +286,12 @@ async def test_concurrent_refunds(server: SpinTestServer):
refund_requests = []
for i in range(5):
req = await CreditRefundRequest.prisma().create(
data=cast(
CreditRefundRequestCreateInput,
{
"userId": REFUND_TEST_USER_ID,
"amount": 100, # $1 each
"transactionKey": topup_tx.transactionKey,
"reason": f"Test refund {i}",
},
)
data={
"userId": REFUND_TEST_USER_ID,
"amount": 100, # $1 each
"transactionKey": topup_tx.transactionKey,
"reason": f"Test refund {i}",
}
)
refund_requests.append(req)

View File

@@ -1,10 +1,8 @@
from datetime import datetime, timedelta, timezone
from typing import cast
import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, UserBalance
from prisma.types import CreditTransactionCreateInput, UserBalanceUpsertInput
from backend.blocks.llm import AITextGeneratorBlock
from backend.data.block import get_block
@@ -25,13 +23,10 @@ async def disable_test_user_transactions():
old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago
await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID},
data=cast(
UserBalanceUpsertInput,
{
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
"update": {"balance": 0, "updatedAt": old_date},
},
),
data={
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
"update": {"balance": 0, "updatedAt": old_date},
},
)
@@ -145,29 +140,23 @@ async def test_block_credit_reset(server: SpinTestServer):
# Manually create a transaction with month 1 timestamp to establish history
await CreditTransaction.prisma().create(
data=cast(
CreditTransactionCreateInput,
{
"userId": DEFAULT_USER_ID,
"amount": 100,
"type": CreditTransactionType.TOP_UP,
"runningBalance": 1100,
"isActive": True,
"createdAt": month1, # Set specific timestamp
},
)
data={
"userId": DEFAULT_USER_ID,
"amount": 100,
"type": CreditTransactionType.TOP_UP,
"runningBalance": 1100,
"isActive": True,
"createdAt": month1, # Set specific timestamp
}
)
# Update user balance to match
await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID},
data=cast(
UserBalanceUpsertInput,
{
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
"update": {"balance": 1100},
},
),
data={
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
"update": {"balance": 1100},
},
)
# Now test month 2 behavior
@@ -186,17 +175,14 @@ async def test_block_credit_reset(server: SpinTestServer):
# Create a month 2 transaction to update the last transaction time
await CreditTransaction.prisma().create(
data=cast(
CreditTransactionCreateInput,
{
"userId": DEFAULT_USER_ID,
"amount": -700, # Spent 700 to get to 400
"type": CreditTransactionType.USAGE,
"runningBalance": 400,
"isActive": True,
"createdAt": month2,
},
)
data={
"userId": DEFAULT_USER_ID,
"amount": -700, # Spent 700 to get to 400
"type": CreditTransactionType.USAGE,
"runningBalance": 400,
"isActive": True,
"createdAt": month2,
}
)
# Move to month 3

View File

@@ -6,14 +6,12 @@ doesn't underflow below POSTGRES_INT_MIN, which could cause integer wraparound i
"""
import asyncio
from typing import cast
from uuid import uuid4
import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceUpsertInput, UserCreateInput
from backend.data.credit import POSTGRES_INT_MIN, UserCredit
from backend.util.test import SpinTestServer
@@ -23,14 +21,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for underflow tests."""
try:
await User.prisma().create(
data=cast(
UserCreateInput,
{
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
},
)
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
)
except UniqueViolationError:
# User already exists, continue
@@ -38,10 +33,7 @@ async def create_test_user(user_id: str) -> None:
await UserBalance.prisma().upsert(
where={"userId": user_id},
data=cast(
UserBalanceUpsertInput,
{"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
),
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
)
@@ -78,13 +70,10 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
await UserBalance.prisma().upsert(
where={"userId": user_id},
data=cast(
UserBalanceUpsertInput,
{
"create": {"userId": user_id, "balance": initial_balance_target},
"update": {"balance": initial_balance_target},
},
),
data={
"create": {"userId": user_id, "balance": initial_balance_target},
"update": {"balance": initial_balance_target},
},
)
current_balance = await credit_system.get_credits(user_id)
@@ -121,13 +110,10 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
# Set balance to exactly POSTGRES_INT_MIN
await UserBalance.prisma().upsert(
where={"userId": user_id},
data=cast(
UserBalanceUpsertInput,
{
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
"update": {"balance": POSTGRES_INT_MIN},
},
),
data={
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
"update": {"balance": POSTGRES_INT_MIN},
},
)
edge_balance = await credit_system.get_credits(user_id)
@@ -166,13 +152,10 @@ async def test_underflow_protection_large_refunds(server: SpinTestServer):
test_balance = POSTGRES_INT_MIN + 1000
await UserBalance.prisma().upsert(
where={"userId": user_id},
data=cast(
UserBalanceUpsertInput,
{
"create": {"userId": user_id, "balance": test_balance},
"update": {"balance": test_balance},
},
),
data={
"create": {"userId": user_id, "balance": test_balance},
"update": {"balance": test_balance},
},
)
current_balance = await credit_system.get_credits(user_id)
@@ -234,13 +217,10 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
await UserBalance.prisma().upsert(
where={"userId": user_id},
data=cast(
UserBalanceUpsertInput,
{
"create": {"userId": user_id, "balance": initial_balance},
"update": {"balance": initial_balance},
},
),
data={
"create": {"userId": user_id, "balance": initial_balance},
"update": {"balance": initial_balance},
},
)
# Apply multiple refunds that would cumulatively underflow
@@ -315,13 +295,10 @@ async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
await UserBalance.prisma().upsert(
where={"userId": user_id},
data=cast(
UserBalanceUpsertInput,
{
"create": {"userId": user_id, "balance": initial_balance},
"update": {"balance": initial_balance},
},
),
data={
"create": {"userId": user_id, "balance": initial_balance},
"update": {"balance": initial_balance},
},
)
async def large_refund(amount: int, label: str):

View File

@@ -9,13 +9,11 @@ This test ensures that:
import asyncio
from datetime import datetime
from typing import cast
import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserCreateInput
from backend.data.credit import UsageTransactionMetadata, UserCredit
from backend.util.json import SafeJson
@@ -26,14 +24,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for migration tests."""
try:
await User.prisma().create(
data=cast(
UserCreateInput,
{
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
},
)
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
)
except UniqueViolationError:
# User already exists, continue
@@ -126,9 +121,7 @@ async def test_detect_stale_user_balance_queries(server: SpinTestServer):
try:
# Create UserBalance with specific value
await UserBalance.prisma().create(
data=cast(
UserBalanceCreateInput, {"userId": user_id, "balance": 5000}
) # $50
data={"userId": user_id, "balance": 5000} # $50
)
# Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
@@ -167,9 +160,7 @@ async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer
try:
# Set initial balance in UserBalance
await UserBalance.prisma().create(
data=cast(UserBalanceCreateInput, {"userId": user_id, "balance": 1000})
)
await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})
# Run concurrent operations to ensure they all use UserBalance atomic operations
async def concurrent_spend(amount: int, label: str):

View File

@@ -5,7 +5,6 @@ from enum import Enum
from multiprocessing import Manager
from queue import Empty
from typing import (
TYPE_CHECKING,
Annotated,
Any,
AsyncGenerator,
@@ -28,7 +27,6 @@ from prisma.models import (
AgentNodeExecutionKeyValueData,
)
from prisma.types import (
AgentGraphExecutionCreateInput,
AgentGraphExecutionUpdateManyMutationInput,
AgentGraphExecutionWhereInput,
AgentNodeExecutionCreateInput,
@@ -36,6 +34,7 @@ from prisma.types import (
AgentNodeExecutionKeyValueDataCreateInput,
AgentNodeExecutionUpdateInput,
AgentNodeExecutionWhereInput,
AgentNodeExecutionWhereUniqueInput,
)
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
from pydantic.fields import Field
@@ -66,9 +65,6 @@ from .includes import (
)
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
if TYPE_CHECKING:
pass
T = TypeVar("T")
logger = logging.getLogger(__name__)
@@ -709,40 +705,37 @@ async def create_graph_execution(
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
"""
result = await AgentGraphExecution.prisma().create(
data=cast(
AgentGraphExecutionCreateInput,
{
"agentGraphId": graph_id,
"agentGraphVersion": graph_version,
"executionStatus": ExecutionStatus.INCOMPLETE,
"inputs": SafeJson(inputs),
"credentialInputs": (
SafeJson(credential_inputs) if credential_inputs else Json({})
),
"nodesInputMasks": (
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
),
"NodeExecutions": {
"create": [
AgentNodeExecutionCreateInput(
agentNodeId=node_id,
executionStatus=ExecutionStatus.QUEUED,
queuedTime=datetime.now(tz=timezone.utc),
Input={
"create": [
{"name": name, "data": SafeJson(data)}
for name, data in node_input.items()
]
},
)
for node_id, node_input in starting_nodes_input
]
},
"userId": user_id,
"agentPresetId": preset_id,
"parentGraphExecutionId": parent_graph_exec_id,
data={
"agentGraphId": graph_id,
"agentGraphVersion": graph_version,
"executionStatus": ExecutionStatus.INCOMPLETE,
"inputs": SafeJson(inputs),
"credentialInputs": (
SafeJson(credential_inputs) if credential_inputs else Json({})
),
"nodesInputMasks": (
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
),
"NodeExecutions": {
"create": [
AgentNodeExecutionCreateInput(
agentNodeId=node_id,
executionStatus=ExecutionStatus.QUEUED,
queuedTime=datetime.now(tz=timezone.utc),
Input={
"create": [
{"name": name, "data": SafeJson(data)}
for name, data in node_input.items()
]
},
)
for node_id, node_input in starting_nodes_input
]
},
),
"userId": user_id,
"agentPresetId": preset_id,
"parentGraphExecutionId": parent_graph_exec_id,
},
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)
@@ -834,42 +827,15 @@ async def upsert_execution_output(
"""
Insert an AgentNodeExecutionInputOutput record as one of AgentNodeExecution.Output.
"""
data: AgentNodeExecutionInputOutputCreateInput = cast(
AgentNodeExecutionInputOutputCreateInput,
{
"name": output_name,
"referencedByOutputExecId": node_exec_id,
},
)
data: AgentNodeExecutionInputOutputCreateInput = {
"name": output_name,
"referencedByOutputExecId": node_exec_id,
}
if output_data is not None:
data["data"] = SafeJson(output_data)
await AgentNodeExecutionInputOutput.prisma().create(data=data)
async def get_execution_outputs_by_node_exec_id(
node_exec_id: str,
) -> dict[str, Any]:
"""
Get all execution outputs for a specific node execution ID.
Args:
node_exec_id: The node execution ID to get outputs for
Returns:
Dictionary mapping output names to their data values
"""
outputs = await AgentNodeExecutionInputOutput.prisma().find_many(
where={"referencedByOutputExecId": node_exec_id}
)
result = {}
for output in outputs:
if output.data is not None:
result[output.name] = type_utils.convert(output.data, JsonValue)
return result
async def update_graph_execution_start_time(
graph_exec_id: str,
) -> GraphExecution | None:
@@ -980,30 +946,25 @@ async def update_node_execution_status(
f"Invalid status transition: {status} has no valid source statuses"
)
# First verify the current status allows this transition
current_exec = await AgentNodeExecution.prisma().find_unique(
where={"id": node_exec_id}, include=EXECUTION_RESULT_INCLUDE
)
if not current_exec:
raise ValueError(f"Execution {node_exec_id} not found.")
# Check if current status allows the requested transition
if current_exec.executionStatus not in allowed_from:
# Status transition not allowed, return current state without updating
return NodeExecutionResult.from_db(current_exec)
# Status transition is valid, perform the update
updated_exec = await AgentNodeExecution.prisma().update(
where={"id": node_exec_id},
if res := await AgentNodeExecution.prisma().update(
where=cast(
AgentNodeExecutionWhereUniqueInput,
{
"id": node_exec_id,
"executionStatus": {"in": [s.value for s in allowed_from]},
},
),
data=_get_update_status_data(status, execution_data, stats),
include=EXECUTION_RESULT_INCLUDE,
)
):
return NodeExecutionResult.from_db(res)
if not updated_exec:
raise ValueError(f"Failed to update execution {node_exec_id}.")
if res := await AgentNodeExecution.prisma().find_unique(
where={"id": node_exec_id}, include=EXECUTION_RESULT_INCLUDE
):
return NodeExecutionResult.from_db(res)
return NodeExecutionResult.from_db(updated_exec)
raise ValueError(f"Execution {node_exec_id} not found.")
def _get_update_status_data(
@@ -1504,35 +1465,3 @@ async def get_graph_execution_by_share_token(
created_at=execution.createdAt,
outputs=outputs,
)
async def get_frequently_executed_graphs(
days_back: int = 30,
min_executions: int = 10,
) -> list[dict]:
"""Get graphs that have been frequently executed for monitoring."""
query_template = """
SELECT DISTINCT
e."agentGraphId" as graph_id,
e."userId" as user_id,
COUNT(*) as execution_count
FROM {schema_prefix}"AgentGraphExecution" e
WHERE e."createdAt" >= $1::timestamp
AND e."isDeleted" = false
AND e."executionStatus" IN ('COMPLETED', 'FAILED', 'TERMINATED')
GROUP BY e."agentGraphId", e."userId"
HAVING COUNT(*) >= $2
ORDER BY execution_count DESC
"""
start_date = datetime.now(timezone.utc) - timedelta(days=days_back)
result = await query_raw_with_schema(query_template, start_date, min_executions)
return [
{
"graph_id": row["graph_id"],
"user_id": row["user_id"],
"execution_count": int(row["execution_count"]),
}
for row in result
]

View File

@@ -6,11 +6,11 @@ Handles all database operations for pending human reviews.
import asyncio
import logging
from datetime import datetime, timezone
from typing import Optional, cast
from typing import Optional
from prisma.enums import ReviewStatus
from prisma.models import PendingHumanReview
from prisma.types import PendingHumanReviewUpdateInput, PendingHumanReviewUpsertInput
from prisma.types import PendingHumanReviewUpdateInput
from pydantic import BaseModel
from backend.server.v2.executions.review.model import (
@@ -66,23 +66,20 @@ async def get_or_create_human_review(
# Upsert - get existing or create new review
review = await PendingHumanReview.prisma().upsert(
where={"nodeExecId": node_exec_id},
data=cast(
PendingHumanReviewUpsertInput,
{
"create": {
"userId": user_id,
"nodeExecId": node_exec_id,
"graphExecId": graph_exec_id,
"graphId": graph_id,
"graphVersion": graph_version,
"payload": SafeJson(input_data),
"instructions": message,
"editable": editable,
"status": ReviewStatus.WAITING,
},
"update": {}, # Do nothing on update - keep existing review as is
data={
"create": {
"userId": user_id,
"nodeExecId": node_exec_id,
"graphExecId": graph_exec_id,
"graphId": graph_id,
"graphVersion": graph_version,
"payload": SafeJson(input_data),
"instructions": message,
"editable": editable,
"status": ReviewStatus.WAITING,
},
),
"update": {}, # Do nothing on update - keep existing review as is
},
)
logger.info(
@@ -103,7 +100,7 @@ async def get_or_create_human_review(
return None
else:
return ReviewResult(
data=review.payload,
data=review.payload if review.status == ReviewStatus.APPROVED else None,
status=review.status,
message=review.reviewMessage or "",
processed=review.processed,
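The empty `"update": {}` turns the upsert into an atomic get-or-create: an existing review row is returned untouched, a missing one is created in the same statement. A minimal sketch of the idiom (fields trimmed to the essentials; the real payload is shown above):

# Get-or-create in one round trip: "update": {} means an existing row is
# returned as-is, while a missing row is created atomically on the unique key.
from prisma.enums import ReviewStatus
from prisma.models import PendingHumanReview

async def get_or_create_review(node_exec_id: str, user_id: str):
    return await PendingHumanReview.prisma().upsert(
        where={"nodeExecId": node_exec_id},
        data={
            "create": {
                "userId": user_id,
                "nodeExecId": node_exec_id,
                "status": ReviewStatus.WAITING,
            },
            "update": {},  # never overwrite an existing review
        },
    )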

View File

@@ -22,7 +22,7 @@ from typing import (
from urllib.parse import urlparse
from uuid import uuid4
from prisma.enums import CreditTransactionType, OnboardingStep
from prisma.enums import CreditTransactionType
from pydantic import (
BaseModel,
ConfigDict,
@@ -868,20 +868,3 @@ class UserExecutionSummaryStats(BaseModel):
total_execution_time: float = Field(default=0)
average_execution_time: float = Field(default=0)
cost_breakdown: dict[str, float] = Field(default_factory=dict)
class UserOnboarding(BaseModel):
userId: str
completedSteps: list[OnboardingStep]
walletShown: bool
notified: list[OnboardingStep]
rewardedFor: list[OnboardingStep]
usageReason: Optional[str]
integrations: list[str]
otherIntegrations: Optional[str]
selectedStoreListingVersionId: Optional[str]
agentInput: Optional[dict[str, Any]]
onboardingAgentExecutionId: Optional[str]
agentRuns: int
lastRunAt: Optional[datetime]
consecutiveRunDays: int

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import AsyncGenerator
from pydantic import BaseModel, field_serializer
from pydantic import BaseModel
from backend.data.event_bus import AsyncRedisEventBus
from backend.server.model import NotificationPayload
@@ -15,11 +15,6 @@ class NotificationEvent(BaseModel):
user_id: str
payload: NotificationPayload
@field_serializer("payload")
def serialize_payload(self, payload: NotificationPayload):
"""Ensure extra fields survive Redis serialization."""
return payload.model_dump()
class AsyncRedisNotificationEventBus(AsyncRedisEventBus[NotificationEvent]):
Model = NotificationEvent # type: ignore

View File

@@ -1,30 +1,24 @@
import re
from datetime import datetime, timedelta, timezone
from typing import Any, Literal, Optional, cast
from zoneinfo import ZoneInfo
from datetime import datetime
from typing import Any, Optional
import prisma
import pydantic
from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import (
UserOnboardingCreateInput,
UserOnboardingUpdateInput,
UserOnboardingUpsertInput,
)
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
from backend.data import execution as execution_db
from backend.data.block import get_blocks
from backend.data.credit import get_user_credit_model
from backend.data.model import CredentialsMetaInput
from backend.data.notification_bus import (
AsyncRedisNotificationEventBus,
NotificationEvent,
)
from backend.data.user import get_user_by_id
from backend.server.model import OnboardingNotificationPayload
from backend.server.v2.store.model import StoreAgentDetails
from backend.util.cache import cached
from backend.util.json import SafeJson
from backend.util.timezone_utils import get_user_timezone_or_utc
# Mapping from user reason id to categories to search for when choosing agent to show
REASON_MAPPING: dict[str, list[str]] = {
@@ -37,20 +31,9 @@ REASON_MAPPING: dict[str, list[str]] = {
POINTS_AGENT_COUNT = 50 # Number of agents to calculate points for
MIN_AGENT_COUNT = 2 # Minimum number of marketplace agents to enable onboarding
FrontendOnboardingStep = Literal[
OnboardingStep.WELCOME,
OnboardingStep.USAGE_REASON,
OnboardingStep.INTEGRATIONS,
OnboardingStep.AGENT_CHOICE,
OnboardingStep.AGENT_NEW_RUN,
OnboardingStep.AGENT_INPUT,
OnboardingStep.CONGRATS,
OnboardingStep.MARKETPLACE_VISIT,
OnboardingStep.BUILDER_OPEN,
]
class UserOnboardingUpdate(pydantic.BaseModel):
completedSteps: Optional[list[OnboardingStep]] = None
walletShown: Optional[bool] = None
notified: Optional[list[OnboardingStep]] = None
usageReason: Optional[str] = None
@@ -59,6 +42,9 @@ class UserOnboardingUpdate(pydantic.BaseModel):
selectedStoreListingVersionId: Optional[str] = None
agentInput: Optional[dict[str, Any]] = None
onboardingAgentExecutionId: Optional[str] = None
agentRuns: Optional[int] = None
lastRunAt: Optional[datetime] = None
consecutiveRunDays: Optional[int] = None
async def get_user_onboarding(user_id: str):
@@ -97,6 +83,14 @@ async def reset_user_onboarding(user_id: str):
async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update: UserOnboardingUpdateInput = {}
onboarding = await get_user_onboarding(user_id)
if data.completedSteps is not None:
update["completedSteps"] = list(
set(data.completedSteps + onboarding.completedSteps)
)
for step in data.completedSteps:
if step not in onboarding.completedSteps:
await _reward_user(user_id, onboarding, step)
await _send_onboarding_notification(user_id, step)
if data.walletShown:
update["walletShown"] = data.walletShown
if data.notified is not None:
@@ -113,16 +107,19 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update["agentInput"] = SafeJson(data.agentInput)
if data.onboardingAgentExecutionId is not None:
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
if data.agentRuns is not None and data.agentRuns > onboarding.agentRuns:
update["agentRuns"] = data.agentRuns
if data.lastRunAt is not None:
update["lastRunAt"] = data.lastRunAt
if data.consecutiveRunDays is not None:
update["consecutiveRunDays"] = data.consecutiveRunDays
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data=cast(
UserOnboardingUpsertInput,
{
"create": {"userId": user_id, **update},
"update": update,
},
),
data={
"create": {"userId": user_id, **update},
"update": update,
},
)
@@ -164,12 +161,14 @@ async def _reward_user(user_id: str, onboarding: UserOnboarding, step: Onboardin
if step in onboarding.rewardedFor:
return
onboarding.rewardedFor.append(step)
user_credit_model = await get_user_credit_model(user_id)
await user_credit_model.onboarding_reward(user_id, reward, step)
await UserOnboarding.prisma().update(
where={"userId": user_id},
data={
"rewardedFor": list(set(onboarding.rewardedFor + [step])),
"completedSteps": list(set(onboarding.completedSteps + [step])),
"rewardedFor": onboarding.rewardedFor,
},
)
@@ -178,52 +177,31 @@ async def complete_onboarding_step(user_id: str, step: OnboardingStep):
"""
Completes the specified onboarding step for the user if not already completed.
"""
onboarding = await get_user_onboarding(user_id)
if step not in onboarding.completedSteps:
await UserOnboarding.prisma().update(
where={"userId": user_id},
data={
"completedSteps": list(set(onboarding.completedSteps + [step])),
},
await update_user_onboarding(
user_id,
UserOnboardingUpdate(completedSteps=onboarding.completedSteps + [step]),
)
await _reward_user(user_id, onboarding, step)
await _send_onboarding_notification(user_id, step)
async def _send_onboarding_notification(
user_id: str, step: OnboardingStep | None, event: str = "step_completed"
):
async def _send_onboarding_notification(user_id: str, step: OnboardingStep):
"""
Sends an onboarding notification to the user.
Sends an onboarding notification to the user for the specified step.
"""
payload = OnboardingNotificationPayload(
type="onboarding",
event=event,
step=step,
event="step_completed",
step=step.value,
)
await AsyncRedisNotificationEventBus().publish(
NotificationEvent(user_id=user_id, payload=payload)
)
async def complete_re_run_agent(user_id: str, graph_id: str) -> None:
"""
Complete RE_RUN_AGENT step when a user runs a graph they've run before.
Keeps overhead low by only counting executions if the step is still pending.
"""
onboarding = await get_user_onboarding(user_id)
if OnboardingStep.RE_RUN_AGENT in onboarding.completedSteps:
return
# Includes current execution, so count > 1 means there was at least one prior run.
previous_exec_count = await execution_db.get_graph_executions_count(
user_id=user_id, graph_id=graph_id
)
if previous_exec_count > 1:
await complete_onboarding_step(user_id, OnboardingStep.RE_RUN_AGENT)
def _clean_and_split(text: str) -> list[str]:
def clean_and_split(text: str) -> list[str]:
"""
Removes all special characters from a string, truncates it to 100 characters,
and splits it by whitespace and commas.
@@ -246,7 +224,7 @@ def _clean_and_split(text: str) -> list[str]:
return words
def _calculate_points(
def calculate_points(
agent, categories: list[str], custom: list[str], integrations: list[str]
) -> int:
"""
@@ -290,85 +268,18 @@ def _calculate_points(
return int(points)
def _normalize_datetime(value: datetime | None) -> datetime | None:
if value is None:
return None
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)
def get_credentials_blocks() -> dict[str, str]:
# Returns a dictionary of block id to credentials field name
creds: dict[str, str] = {}
blocks = get_blocks()
for id, block in blocks.items():
for field_name, field_info in block().input_schema.model_fields.items():
if field_info.annotation == CredentialsMetaInput:
creds[id] = field_name
return creds
def _calculate_consecutive_run_days(
last_run_at: datetime | None, current_consecutive_days: int, user_timezone: str
) -> tuple[datetime, int]:
tz = ZoneInfo(user_timezone)
local_now = datetime.now(tz)
normalized_last_run = _normalize_datetime(last_run_at)
if normalized_last_run is None:
return local_now.astimezone(timezone.utc), 1
last_run_local = normalized_last_run.astimezone(tz)
last_run_date = last_run_local.date()
today = local_now.date()
if last_run_date == today:
return local_now.astimezone(timezone.utc), current_consecutive_days
if last_run_date == today - timedelta(days=1):
return local_now.astimezone(timezone.utc), current_consecutive_days + 1
return local_now.astimezone(timezone.utc), 1
def _get_run_milestone_steps(
new_run_count: int, consecutive_days: int
) -> list[OnboardingStep]:
milestones: list[OnboardingStep] = []
if new_run_count >= 10:
milestones.append(OnboardingStep.RUN_AGENTS)
if new_run_count >= 100:
milestones.append(OnboardingStep.RUN_AGENTS_100)
if consecutive_days >= 3:
milestones.append(OnboardingStep.RUN_3_DAYS)
if consecutive_days >= 14:
milestones.append(OnboardingStep.RUN_14_DAYS)
return milestones
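Taken together, the two helpers above implement the streak logic: a run on the same local day keeps the streak, a run on the next day extends it, anything else resets it to 1, and milestones fire off the updated counters. A worked example with hypothetical values:

# Hypothetical worked example: a run one calendar day after the previous one
# extends a 2-day streak to 3, which unlocks RUN_3_DAYS alongside RUN_AGENTS.
from datetime import datetime, timedelta, timezone

yesterday = datetime.now(timezone.utc) - timedelta(days=1)
_, streak = _calculate_consecutive_run_days(yesterday, 2, "UTC")
assert streak == 3
assert _get_run_milestone_steps(new_run_count=10, consecutive_days=streak) == [
    OnboardingStep.RUN_AGENTS,
    OnboardingStep.RUN_3_DAYS,
]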
async def _get_user_timezone(user_id: str) -> str:
user = await get_user_by_id(user_id)
return get_user_timezone_or_utc(user.timezone if user else None)
async def increment_runs(user_id: str):
"""
Increment a user's run counters and trigger any onboarding milestones.
"""
user_timezone = await _get_user_timezone(user_id)
onboarding = await get_user_onboarding(user_id)
new_run_count = onboarding.agentRuns + 1
last_run_at, consecutive_run_days = _calculate_consecutive_run_days(
onboarding.lastRunAt, onboarding.consecutiveRunDays, user_timezone
)
await UserOnboarding.prisma().update(
where={"userId": user_id},
data={
"agentRuns": {"increment": 1},
"lastRunAt": last_run_at,
"consecutiveRunDays": consecutive_run_days,
},
)
milestones = _get_run_milestone_steps(new_run_count, consecutive_run_days)
new_steps = [step for step in milestones if step not in onboarding.completedSteps]
for step in new_steps:
await complete_onboarding_step(user_id, step)
# Send progress notification if no steps were completed, so client refetches onboarding state
if not new_steps:
await _send_onboarding_notification(user_id, None, event="increment_runs")
CREDENTIALS_FIELDS: dict[str, str] = get_credentials_blocks()
async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
@@ -377,7 +288,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
where_clause: dict[str, Any] = {}
custom = _clean_and_split((user_onboarding.usageReason or "").lower())
custom = clean_and_split((user_onboarding.usageReason or "").lower())
if categories:
where_clause["OR"] = [
@@ -425,7 +336,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
# Calculate points for the first X agents and choose the top 2
agent_points = []
for agent in storeAgents[:POINTS_AGENT_COUNT]:
points = _calculate_points(
points = calculate_points(
agent, categories, custom, user_onboarding.integrations
)
agent_points.append((agent, points))
@@ -439,7 +350,6 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
slug=agent.slug,
agent_name=agent.agent_name,
agent_video=agent.agent_video or "",
agent_output_demo=agent.agent_output_demo or "",
agent_image=agent.agent_image,
creator=agent.creator_username,
creator_avatar=agent.creator_avatar,

View File

@@ -3,18 +3,12 @@ from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cast
from backend.data import db
from backend.data.analytics import (
get_accuracy_trends_and_alerts,
get_marketplace_graphs_for_monitoring,
)
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
create_graph_execution,
get_block_error_stats,
get_child_graph_executions,
get_execution_kv_data,
get_execution_outputs_by_node_exec_id,
get_frequently_executed_graphs,
get_graph_execution_meta,
get_graph_executions,
get_graph_executions_count,
@@ -148,13 +142,9 @@ class DatabaseManager(AppService):
update_graph_execution_stats = _(update_graph_execution_stats)
upsert_execution_input = _(upsert_execution_input)
upsert_execution_output = _(upsert_execution_output)
get_execution_outputs_by_node_exec_id = _(get_execution_outputs_by_node_exec_id)
get_execution_kv_data = _(get_execution_kv_data)
set_execution_kv_data = _(set_execution_kv_data)
get_block_error_stats = _(get_block_error_stats)
get_accuracy_trends_and_alerts = _(get_accuracy_trends_and_alerts)
get_frequently_executed_graphs = _(get_frequently_executed_graphs)
get_marketplace_graphs_for_monitoring = _(get_marketplace_graphs_for_monitoring)
# Graphs
get_node = _(get_node)
@@ -236,10 +226,6 @@ class DatabaseManagerClient(AppServiceClient):
# Block error monitoring
get_block_error_stats = _(d.get_block_error_stats)
# Execution accuracy monitoring
get_accuracy_trends_and_alerts = _(d.get_accuracy_trends_and_alerts)
get_frequently_executed_graphs = _(d.get_frequently_executed_graphs)
get_marketplace_graphs_for_monitoring = _(d.get_marketplace_graphs_for_monitoring)
# Human In The Loop
has_pending_reviews_for_graph_exec = _(d.has_pending_reviews_for_graph_exec)
@@ -279,7 +265,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_user_integrations = d.get_user_integrations
upsert_execution_input = d.upsert_execution_input
upsert_execution_output = d.upsert_execution_output
get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
update_graph_execution_stats = d.update_graph_execution_stats
update_node_execution_status = d.update_node_execution_status
update_node_execution_status_batch = d.update_node_execution_status_batch

View File

@@ -133,8 +133,9 @@ def execute_graph(
cluster_lock: ClusterLock,
):
"""Execute graph using thread-local ExecutionProcessor instance"""
processor: ExecutionProcessor = _tls.processor
return processor.on_graph_execution(graph_exec_entry, cancel_event, cluster_lock)
return _tls.processor.on_graph_execution(
graph_exec_entry, cancel_event, cluster_lock
)
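`_tls` here is a `threading.local()`, so each worker thread gets its own `ExecutionProcessor` and no cross-thread state is shared. A hypothetical sketch of that setup (the real pool initializer wires this up elsewhere in the module):

# Hypothetical sketch of the per-thread processor behind `_tls.processor`:
# threading.local() gives each worker thread an independent attribute set.
import threading

_tls = threading.local()

def _init_worker() -> None:
    # Run once per worker thread, e.g. as a thread-pool initializer.
    _tls.processor = ExecutionProcessor()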
T = TypeVar("T")
@@ -142,8 +143,8 @@ T = TypeVar("T")
async def execute_node(
node: Node,
creds_manager: IntegrationCredentialsManager,
data: NodeExecutionEntry,
execution_processor: "ExecutionProcessor",
execution_stats: NodeExecutionStats | None = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> BlockOutput:
@@ -168,7 +169,6 @@ async def execute_node(
node_id = data.node_id
node_block = node.block
execution_context = data.execution_context
creds_manager = execution_processor.creds_manager
log_metadata = LogMetadata(
logger=_logger,
@@ -212,7 +212,6 @@ async def execute_node(
"node_exec_id": node_exec_id,
"user_id": user_id,
"execution_context": execution_context,
"execution_processor": execution_processor,
}
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
@@ -609,8 +608,8 @@ class ExecutionProcessor:
async for output_name, output_data in execute_node(
node=node,
creds_manager=self.creds_manager,
data=node_exec,
execution_processor=self,
execution_stats=stats,
nodes_input_masks=nodes_input_masks,
):
@@ -861,17 +860,12 @@ class ExecutionProcessor:
execution_stats_lock = threading.Lock()
# State holders ----------------------------------------------------
self.running_node_execution: dict[str, NodeExecutionProgress] = defaultdict(
running_node_execution: dict[str, NodeExecutionProgress] = defaultdict(
NodeExecutionProgress
)
self.running_node_evaluation: dict[str, Future] = {}
self.execution_stats = execution_stats
self.execution_stats_lock = execution_stats_lock
running_node_evaluation: dict[str, Future] = {}
execution_queue = ExecutionQueue[NodeExecutionEntry]()
running_node_execution = self.running_node_execution
running_node_evaluation = self.running_node_evaluation
try:
if db_client.get_credits(graph_exec.user_id) <= 0:
raise InsufficientBalanceError(

View File

@@ -23,18 +23,15 @@ from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import MetaData, create_engine
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
from backend.data.block import BlockInput
from backend.data.execution import GraphExecutionWithNodes
from backend.data.model import CredentialsMetaInput
from backend.data.onboarding import increment_runs
from backend.executor import utils as execution_utils
from backend.monitoring import (
NotificationJobArgs,
process_existing_batches,
process_weekly_summary,
report_block_error_rates,
report_execution_accuracy_alerts,
report_late_executions,
)
from backend.util.clients import get_scheduler_client
@@ -156,7 +153,6 @@ async def _execute_graph(**kwargs):
inputs=args.input_data,
graph_credentials_inputs=args.input_credentials,
)
await increment_runs(args.user_id)
elapsed = asyncio.get_event_loop().time() - start_time
logger.info(
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
@@ -243,17 +239,6 @@ def cleanup_expired_files():
run_async(cleanup_expired_files_async())
def cleanup_oauth_tokens():
"""Clean up expired OAuth tokens from the database."""
# Wait for completion
run_async(cleanup_expired_oauth_tokens())
def execution_accuracy_alerts():
"""Check execution accuracy and send alerts if drops are detected."""
return report_execution_accuracy_alerts()
# Monitoring functions are now imported from monitoring module
@@ -453,28 +438,6 @@ class Scheduler(AppService):
jobstore=Jobstores.EXECUTION.value,
)
# OAuth Token Cleanup - configurable interval
self.scheduler.add_job(
cleanup_oauth_tokens,
id="cleanup_oauth_tokens",
trigger="interval",
replace_existing=True,
seconds=config.oauth_token_cleanup_interval_hours
* 3600, # Convert hours to seconds
jobstore=Jobstores.EXECUTION.value,
)
# Execution Accuracy Monitoring - configurable interval
self.scheduler.add_job(
execution_accuracy_alerts,
id="report_execution_accuracy_alerts",
trigger="interval",
replace_existing=True,
seconds=config.execution_accuracy_check_interval_hours
* 3600, # Convert hours to seconds
jobstore=Jobstores.EXECUTION.value,
)
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
@@ -622,16 +585,6 @@ class Scheduler(AppService):
"""Manually trigger cleanup of expired cloud storage files."""
return cleanup_expired_files()
@expose
def execute_cleanup_oauth_tokens(self):
"""Manually trigger cleanup of expired OAuth tokens."""
return cleanup_oauth_tokens()
@expose
def execute_report_execution_accuracy_alerts(self):
"""Manually trigger execution accuracy alert checking."""
return execution_accuracy_alerts()
class SchedulerClient(AppServiceClient):
@classmethod

View File

@@ -0,0 +1,156 @@
"""
Embedding service for generating text embeddings using OpenAI.
Used for vector-based semantic search in the store.
"""
import logging
from typing import Optional
import openai
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
# Model configuration
# Using text-embedding-3-small (1536 dimensions) for compatibility with pgvector indexes
# pgvector IVFFlat/HNSW indexes have dimension limits (2000 for IVFFlat, varies for HNSW)
EMBEDDING_MODEL = "text-embedding-3-small"
EMBEDDING_DIMENSIONS = 1536
# Input validation limits
# OpenAI text-embedding-3 models support up to 8191 tokens (~32k chars)
# We set a conservative limit to prevent abuse
MAX_TEXT_LENGTH = 10000 # characters
MAX_BATCH_SIZE = 100 # maximum texts per batch request
class EmbeddingService:
"""Service for generating text embeddings using OpenAI."""
def __init__(self, api_key: Optional[str] = None):
settings = Settings()
self.api_key = (
api_key
or settings.secrets.openai_internal_api_key
or settings.secrets.openai_api_key
)
if not self.api_key:
raise ValueError(
"OpenAI API key not configured. "
"Set OPENAI_API_KEY or OPENAI_INTERNAL_API_KEY environment variable."
)
self.client = openai.AsyncOpenAI(api_key=self.api_key)
async def generate_embedding(self, text: str) -> list[float]:
"""
Generate embedding for a single text string.
Args:
text: The text to generate an embedding for.
Returns:
A list of floats representing the embedding vector.
Raises:
ValueError: If the text is empty or exceeds maximum length.
openai.APIError: If the OpenAI API call fails.
"""
# Input validation
if not text or not text.strip():
raise ValueError("Text cannot be empty")
if len(text) > MAX_TEXT_LENGTH:
raise ValueError(
f"Text exceeds maximum length of {MAX_TEXT_LENGTH} characters"
)
try:
response = await self.client.embeddings.create(
model=EMBEDDING_MODEL,
input=text,
dimensions=EMBEDDING_DIMENSIONS,
)
return response.data[0].embedding
except openai.APIError as e:
logger.error(f"OpenAI API error generating embedding: {e}")
raise
async def generate_embeddings(self, texts: list[str]) -> list[list[float]]:
"""
Generate embeddings for multiple texts (batch).
Args:
texts: List of texts to generate embeddings for.
Returns:
List of embedding vectors, one per input text.
Raises:
ValueError: If any text is invalid or batch size exceeds limit.
openai.APIError: If the OpenAI API call fails.
"""
# Input validation
if not texts:
raise ValueError("Texts list cannot be empty")
if len(texts) > MAX_BATCH_SIZE:
raise ValueError(f"Batch size exceeds maximum of {MAX_BATCH_SIZE} texts")
for i, text in enumerate(texts):
if not text or not text.strip():
raise ValueError(f"Text at index {i} cannot be empty")
if len(text) > MAX_TEXT_LENGTH:
raise ValueError(
f"Text at index {i} exceeds maximum length of {MAX_TEXT_LENGTH} characters"
)
try:
response = await self.client.embeddings.create(
model=EMBEDDING_MODEL,
input=texts,
dimensions=EMBEDDING_DIMENSIONS,
)
# Sort by index to ensure correct ordering
sorted_data = sorted(response.data, key=lambda x: x.index)
return [item.embedding for item in sorted_data]
except openai.APIError as e:
logger.error(f"OpenAI API error generating embeddings: {e}")
raise
def create_search_text(name: str, sub_heading: str, description: str) -> str:
"""
Combine fields into searchable text for embedding.
This creates a single text string from the agent's name, sub-heading,
and description, which is then converted to an embedding vector.
Args:
name: The agent name.
sub_heading: The agent sub-heading/tagline.
description: The agent description.
Returns:
A single string combining all non-empty fields.
"""
parts = [name or "", sub_heading or "", description or ""]
return " ".join(filter(None, parts)).strip()
# Singleton instance
_embedding_service: Optional[EmbeddingService] = None
async def get_embedding_service() -> EmbeddingService:
"""
Get or create the embedding service singleton.
Returns:
The shared EmbeddingService instance.
Raises:
ValueError: If OpenAI API key is not configured.
"""
global _embedding_service
if _embedding_service is None:
_embedding_service = EmbeddingService()
return _embedding_service
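A hypothetical end-to-end usage sketch of the service above (requires a real `OPENAI_API_KEY`; in production the comparison happens in Postgres via pgvector rather than in Python):

import asyncio

from backend.integrations.embeddings import create_search_text, get_embedding_service

async def main() -> None:
    service = await get_embedding_service()
    doc = create_search_text("Email Agent", "Sorts your inbox", "Labels mail, drafts replies")
    doc_vec = await service.generate_embedding(doc)
    query_vec = await service.generate_embedding("agent that organizes email")
    # text-embedding-3 vectors are unit-length, so the dot product equals the
    # cosine similarity (pgvector's <=> operator computes cosine distance).
    score = sum(a * b for a, b in zip(doc_vec, query_vec))
    print(f"similarity: {score:.3f}")

asyncio.run(main())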

View File

@@ -0,0 +1,231 @@
"""Tests for the embedding service.
This module tests:
- create_search_text utility function
- EmbeddingService input validation
- EmbeddingService API interaction (mocked)
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.integrations.embeddings import (
EMBEDDING_DIMENSIONS,
MAX_BATCH_SIZE,
MAX_TEXT_LENGTH,
EmbeddingService,
create_search_text,
)
class TestCreateSearchText:
"""Tests for the create_search_text utility function."""
def test_combines_all_fields(self):
result = create_search_text("Agent Name", "A cool agent", "Does amazing things")
assert result == "Agent Name A cool agent Does amazing things"
def test_handles_empty_name(self):
result = create_search_text("", "Sub heading", "Description")
assert result == "Sub heading Description"
def test_handles_empty_sub_heading(self):
result = create_search_text("Name", "", "Description")
assert result == "Name Description"
def test_handles_empty_description(self):
result = create_search_text("Name", "Sub heading", "")
assert result == "Name Sub heading"
def test_handles_all_empty(self):
result = create_search_text("", "", "")
assert result == ""
def test_handles_none_values(self):
# The function expects strings but should handle None gracefully
result = create_search_text(None, None, None) # type: ignore
assert result == ""
def test_preserves_content_strips_outer_whitespace(self):
# The function joins parts and strips the outer result
# Internal whitespace in each part is preserved
result = create_search_text(" Name ", " Sub ", " Desc ")
# Each part is joined with space, then outer strip applied
assert result == "Name Sub Desc"
def test_handles_only_whitespace(self):
# Whitespace-only parts pass the truthiness filter, but the final strip empties the result
result = create_search_text(" ", " ", " ")
assert result == ""
class TestEmbeddingServiceValidation:
"""Tests for EmbeddingService input validation."""
@pytest.fixture
def mock_settings(self):
"""Mock settings with a test API key."""
with patch("backend.integrations.embeddings.Settings") as mock:
mock_instance = MagicMock()
mock_instance.secrets.openai_internal_api_key = "test-api-key"
mock_instance.secrets.openai_api_key = ""
mock.return_value = mock_instance
yield mock
@pytest.fixture
def service(self, mock_settings):
"""Create an EmbeddingService instance with mocked settings."""
with patch("backend.integrations.embeddings.openai.AsyncOpenAI"):
return EmbeddingService()
def test_init_requires_api_key(self):
"""Test that initialization fails without an API key."""
with patch("backend.integrations.embeddings.Settings") as mock:
mock_instance = MagicMock()
mock_instance.secrets.openai_internal_api_key = ""
mock_instance.secrets.openai_api_key = ""
mock.return_value = mock_instance
with pytest.raises(ValueError, match="OpenAI API key not configured"):
EmbeddingService()
def test_init_accepts_explicit_api_key(self):
"""Test that explicit API key overrides settings."""
with patch("backend.integrations.embeddings.Settings") as mock:
mock_instance = MagicMock()
mock_instance.secrets.openai_internal_api_key = ""
mock_instance.secrets.openai_api_key = ""
mock.return_value = mock_instance
with patch("backend.integrations.embeddings.openai.AsyncOpenAI"):
service = EmbeddingService(api_key="explicit-key")
assert service.api_key == "explicit-key"
@pytest.mark.asyncio
async def test_generate_embedding_empty_text(self, service):
"""Test that empty text raises ValueError."""
with pytest.raises(ValueError, match="Text cannot be empty"):
await service.generate_embedding("")
@pytest.mark.asyncio
async def test_generate_embedding_whitespace_only(self, service):
"""Test that whitespace-only text raises ValueError."""
with pytest.raises(ValueError, match="Text cannot be empty"):
await service.generate_embedding(" ")
@pytest.mark.asyncio
async def test_generate_embedding_exceeds_max_length(self, service):
"""Test that text exceeding max length raises ValueError."""
long_text = "a" * (MAX_TEXT_LENGTH + 1)
with pytest.raises(ValueError, match="exceeds maximum length"):
await service.generate_embedding(long_text)
@pytest.mark.asyncio
async def test_generate_embeddings_empty_list(self, service):
"""Test that empty list raises ValueError."""
with pytest.raises(ValueError, match="Texts list cannot be empty"):
await service.generate_embeddings([])
@pytest.mark.asyncio
async def test_generate_embeddings_exceeds_batch_size(self, service):
"""Test that batch exceeding max size raises ValueError."""
texts = ["text"] * (MAX_BATCH_SIZE + 1)
with pytest.raises(ValueError, match="Batch size exceeds maximum"):
await service.generate_embeddings(texts)
@pytest.mark.asyncio
async def test_generate_embeddings_empty_text_in_batch(self, service):
"""Test that empty text in batch raises ValueError with index."""
with pytest.raises(ValueError, match="Text at index 1 cannot be empty"):
await service.generate_embeddings(["valid", "", "also valid"])
@pytest.mark.asyncio
async def test_generate_embeddings_long_text_in_batch(self, service):
"""Test that long text in batch raises ValueError with index."""
long_text = "a" * (MAX_TEXT_LENGTH + 1)
with pytest.raises(ValueError, match="Text at index 2 exceeds maximum length"):
await service.generate_embeddings(["short", "also short", long_text])
class TestEmbeddingServiceAPI:
"""Tests for EmbeddingService API interaction."""
@pytest.fixture
def mock_openai_client(self):
"""Create a mock OpenAI client."""
mock_client = MagicMock()
mock_client.embeddings = MagicMock()
return mock_client
@pytest.fixture
def service_with_mock_client(self, mock_openai_client):
"""Create an EmbeddingService with a mocked OpenAI client."""
with patch("backend.integrations.embeddings.Settings") as mock_settings:
mock_instance = MagicMock()
mock_instance.secrets.openai_internal_api_key = "test-key"
mock_instance.secrets.openai_api_key = ""
mock_settings.return_value = mock_instance
with patch(
"backend.integrations.embeddings.openai.AsyncOpenAI"
) as mock_openai:
mock_openai.return_value = mock_openai_client
service = EmbeddingService()
return service, mock_openai_client
@pytest.mark.asyncio
async def test_generate_embedding_success(self, service_with_mock_client):
"""Test successful embedding generation."""
service, mock_client = service_with_mock_client
# Create mock response
mock_embedding = [0.1] * EMBEDDING_DIMENSIONS
mock_response = MagicMock()
mock_response.data = [MagicMock(embedding=mock_embedding)]
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
result = await service.generate_embedding("test text")
assert result == mock_embedding
mock_client.embeddings.create.assert_called_once()
@pytest.mark.asyncio
async def test_generate_embeddings_success(self, service_with_mock_client):
"""Test successful batch embedding generation."""
service, mock_client = service_with_mock_client
# Create mock response with multiple embeddings
mock_embeddings = [[0.1] * EMBEDDING_DIMENSIONS, [0.2] * EMBEDDING_DIMENSIONS]
mock_response = MagicMock()
mock_response.data = [
MagicMock(embedding=mock_embeddings[0], index=0),
MagicMock(embedding=mock_embeddings[1], index=1),
]
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
result = await service.generate_embeddings(["text1", "text2"])
assert result == mock_embeddings
mock_client.embeddings.create.assert_called_once()
@pytest.mark.asyncio
async def test_generate_embeddings_preserves_order(self, service_with_mock_client):
"""Test that batch embeddings are returned in correct order even if API returns out of order."""
service, mock_client = service_with_mock_client
# Create mock response with embeddings out of order
mock_embeddings = [[0.1] * EMBEDDING_DIMENSIONS, [0.2] * EMBEDDING_DIMENSIONS]
mock_response = MagicMock()
# Return in reverse order
mock_response.data = [
MagicMock(embedding=mock_embeddings[1], index=1),
MagicMock(embedding=mock_embeddings[0], index=0),
]
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
result = await service.generate_embeddings(["text1", "text2"])
# Should be sorted by index
assert result[0] == mock_embeddings[0]
assert result[1] == mock_embeddings[1]

View File

@@ -18,9 +18,7 @@ class ManualWebhookManagerBase(BaseWebhooksManager[WT]):
ingress_url: str,
secret: str,
) -> tuple[str, dict]:
# TODO: pass ingress_url to user in frontend
# See: https://github.com/Significant-Gravitas/AutoGPT/issues/8537
logger.debug(f"Manual webhook registered with ingress URL: {ingress_url}")
print(ingress_url) # FIXME: pass URL to user in front end
return "", {}

View File

@@ -1,6 +1,5 @@
"""Monitoring module for platform health and alerting."""
from .accuracy_monitor import AccuracyMonitor, report_execution_accuracy_alerts
from .block_error_monitor import BlockErrorMonitor, report_block_error_rates
from .late_execution_monitor import (
LateExecutionException,
@@ -14,12 +13,10 @@ from .notification_monitor import (
)
__all__ = [
"AccuracyMonitor",
"BlockErrorMonitor",
"LateExecutionMonitor",
"LateExecutionException",
"NotificationJobArgs",
"report_execution_accuracy_alerts",
"report_block_error_rates",
"report_late_executions",
"process_existing_batches",

View File

@@ -1,107 +0,0 @@
"""Execution accuracy monitoring module."""
import logging
from backend.util.clients import (
get_database_manager_client,
get_notification_manager_client,
)
from backend.util.metrics import DiscordChannel, sentry_capture_error
from backend.util.settings import Config
logger = logging.getLogger(__name__)
config = Config()
class AccuracyMonitor:
"""Monitor execution accuracy trends and send alerts for drops."""
def __init__(self, drop_threshold: float = 10.0):
self.config = config
self.notification_client = get_notification_manager_client()
self.database_client = get_database_manager_client()
self.drop_threshold = drop_threshold
def check_execution_accuracy_alerts(self) -> str:
"""Check marketplace agents for accuracy drops and send alerts."""
try:
logger.info("Checking execution accuracy for marketplace agents")
# Get marketplace graphs using database client
graphs = self.database_client.get_marketplace_graphs_for_monitoring(
days_back=30, min_executions=10
)
alerts_found = 0
for graph_data in graphs:
result = self.database_client.get_accuracy_trends_and_alerts(
graph_id=graph_data.graph_id,
user_id=graph_data.user_id,
days_back=21, # 3 weeks
drop_threshold=self.drop_threshold,
)
if result.alert:
alert = result.alert
# Get graph details for better alert info
try:
graph_info = self.database_client.get_graph_metadata(
graph_id=alert.graph_id
)
graph_name = graph_info.name if graph_info else "Unknown Agent"
except Exception:
graph_name = "Unknown Agent"
# Create detailed alert message
alert_msg = (
f"🚨 **AGENT ACCURACY DROP DETECTED**\n\n"
f"**Agent:** {graph_name}\n"
f"**Graph ID:** `{alert.graph_id}`\n"
f"**Accuracy Drop:** {alert.drop_percent:.1f}%\n"
f"**Recent Performance:**\n"
f" • 3-day average: {alert.three_day_avg:.1f}%\n"
f" • 7-day average: {alert.seven_day_avg:.1f}%\n"
)
if alert.user_id:
alert_msg += f"**Owner:** {alert.user_id}\n"
# Send individual alert for each agent (not batched)
self.notification_client.discord_system_alert(
alert_msg, DiscordChannel.PRODUCT
)
alerts_found += 1
logger.warning(
f"Sent accuracy alert for agent: {graph_name} ({alert.graph_id})"
)
if alerts_found > 0:
return f"Alert sent for {alerts_found} agents with accuracy drops"
logger.info("No execution accuracy alerts detected")
return "No accuracy alerts detected"
except Exception as e:
logger.exception(f"Error checking execution accuracy alerts: {e}")
error = Exception(f"Error checking execution accuracy alerts: {e}")
msg = str(error)
sentry_capture_error(error)
self.notification_client.discord_system_alert(msg, DiscordChannel.PRODUCT)
return msg
def report_execution_accuracy_alerts(drop_threshold: float = 10.0) -> str:
"""
Check execution accuracy and send alerts if drops are detected.
Args:
drop_threshold: Percentage drop threshold to trigger alerts (default 10.0%)
Returns:
Status message indicating results of the check
"""
monitor = AccuracyMonitor(drop_threshold=drop_threshold)
return monitor.check_execution_accuracy_alerts()

View File

@@ -49,10 +49,11 @@
</p>
<ol style="margin-bottom: 10px;">
<li>
Connect to the database using your preferred database client.
Visit the Supabase Dashboard:
https://supabase.com/dashboard/project/bgwpwdsxblryihinutbx/editor
</li>
<li>
Navigate to the <strong>RefundRequest</strong> table in the <strong>platform</strong> schema.
Navigate to the <strong>RefundRequest</strong> table.
</li>
<li>
Filter the <code>transactionKey</code> column with the Transaction ID: <strong>{{ data.transaction_id }}</strong>.

View File

@@ -6,7 +6,7 @@ Usage: from backend.sdk import *
This module provides:
- All block base classes and types
- All credential and authentication components
- All cost tracking components
- All webhook components
- All utility functions

View File

@@ -1,7 +1,7 @@
"""
Integration between SDK provider costs and the execution cost system.
This module provides the glue between provider-defined base costs and the
BLOCK_COSTS configuration used by the execution system.
"""

View File

@@ -1,13 +0,0 @@
"""
Authentication module for the AutoGPT Platform.
This module provides FastAPI-based authentication supporting:
- Email/password authentication with bcrypt hashing
- Google OAuth authentication
- JWT token management (access + refresh tokens)
"""
from .routes import router as auth_router
from .service import AuthService
__all__ = ["auth_router", "AuthService"]

View File

@@ -1,170 +0,0 @@
"""
Direct email sending for authentication flows.
This module bypasses the notification queue system to ensure auth emails
(password reset, email verification) are sent immediately in all environments.
"""
import logging
import pathlib
from typing import Optional
from jinja2 import Environment, FileSystemLoader
from postmarker.core import PostmarkClient
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
# Template directory
TEMPLATE_DIR = pathlib.Path(__file__).parent / "templates"
class AuthEmailSender:
"""Handles direct email sending for authentication flows."""
def __init__(self):
if settings.secrets.postmark_server_api_token:
self.postmark = PostmarkClient(
server_token=settings.secrets.postmark_server_api_token
)
else:
logger.warning(
"Postmark server API token not found, auth email sending disabled"
)
self.postmark = None
# Set up Jinja2 environment for templates
self.jinja_env: Optional[Environment] = None
if TEMPLATE_DIR.exists():
self.jinja_env = Environment(
loader=FileSystemLoader(str(TEMPLATE_DIR)),
autoescape=True,
)
else:
logger.warning(f"Auth email templates directory not found: {TEMPLATE_DIR}")
def _get_frontend_url(self) -> str:
"""Get the frontend base URL for email links."""
return (
settings.config.frontend_base_url
or settings.config.platform_base_url
or "http://localhost:3000"
)
def _render_template(
self, template_name: str, subject: str, **context
) -> tuple[str, str]:
"""Render an email template with the base template wrapper."""
if not self.jinja_env:
raise RuntimeError("Email templates not available")
# Render the content template
content_template = self.jinja_env.get_template(template_name)
content = content_template.render(**context)
# Render with base template
base_template = self.jinja_env.get_template("base.html.jinja2")
html_body = base_template.render(
data={"title": subject, "message": content, "unsubscribe_link": None}
)
return subject, html_body
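The two-stage render above first produces the flow-specific content, then injects it as the `message` of the shared `base.html.jinja2` layout. A self-contained sketch of that composition (inline templates are stand-ins for the real files; the real environment uses `autoescape=True`, which presumes the base template marks `message` as safe):

from jinja2 import DictLoader, Environment

env = Environment(
    loader=DictLoader({
        "base.html.jinja2": "<h1>{{ data.title }}</h1><div>{{ data.message }}</div>",
        "password_reset.html.jinja2": '<a href="{{ reset_link }}">Reset password</a>',
    }),
    autoescape=False,  # simplification; see the note on autoescape above
)

content = env.get_template("password_reset.html.jinja2").render(
    reset_link="https://example.com/reset-password?token=abc123"
)
html = env.get_template("base.html.jinja2").render(
    data={"title": "Reset Your AutoGPT Password", "message": content}
)
print(html)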
def _send_email(self, to_email: str, subject: str, html_body: str) -> bool:
"""Send an email directly via Postmark."""
if not self.postmark:
logger.warning(
f"Postmark not configured. Would send email to {to_email}: {subject}"
)
return False
try:
self.postmark.emails.send( # type: ignore[attr-defined]
From=settings.config.postmark_sender_email,
To=to_email,
Subject=subject,
HtmlBody=html_body,
)
logger.info(f"Auth email sent to {to_email}: {subject}")
return True
except Exception as e:
logger.error(f"Failed to send auth email to {to_email}: {e}")
return False
def send_password_reset_email(
self, to_email: str, reset_token: str, user_name: Optional[str] = None
) -> bool:
"""
Send a password reset email.
Args:
to_email: Recipient email address
reset_token: The raw password reset token
user_name: Optional user name for personalization
Returns:
True if email was sent successfully, False otherwise
"""
try:
frontend_url = self._get_frontend_url()
reset_link = f"{frontend_url}/reset-password?token={reset_token}"
subject, html_body = self._render_template(
"password_reset.html.jinja2",
subject="Reset Your AutoGPT Password",
reset_link=reset_link,
user_name=user_name,
frontend_url=frontend_url,
)
return self._send_email(to_email, subject, html_body)
except Exception as e:
logger.error(f"Failed to send password reset email to {to_email}: {e}")
return False
def send_email_verification(
self, to_email: str, verification_token: str, user_name: Optional[str] = None
) -> bool:
"""
Send an email verification email.
Args:
to_email: Recipient email address
verification_token: The raw verification token
user_name: Optional user name for personalization
Returns:
True if email was sent successfully, False otherwise
"""
try:
frontend_url = self._get_frontend_url()
verification_link = (
f"{frontend_url}/verify-email?token={verification_token}"
)
subject, html_body = self._render_template(
"email_verification.html.jinja2",
subject="Verify Your AutoGPT Email",
verification_link=verification_link,
user_name=user_name,
frontend_url=frontend_url,
)
return self._send_email(to_email, subject, html_body)
except Exception as e:
logger.error(f"Failed to send verification email to {to_email}: {e}")
return False
# Singleton instance
_auth_email_sender: Optional[AuthEmailSender] = None
def get_auth_email_sender() -> AuthEmailSender:
"""Get or create the auth email sender singleton."""
global _auth_email_sender
if _auth_email_sender is None:
_auth_email_sender = AuthEmailSender()
return _auth_email_sender
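# Usage sketch (illustrative, not part of the module): queueing a reset email
# through the singleton. `user` and `raw_token` are hypothetical placeholders
# for values produced by the auth service.
#
#     from .email import get_auth_email_sender
#
#     sender = get_auth_email_sender()
#     sender.send_password_reset_email(
#         to_email=user.email,
#         reset_token=raw_token,
#         user_name=user.name,
#     )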

View File

@@ -1,505 +0,0 @@
"""
Authentication API routes.
Provides endpoints for:
- User registration and login
- Token refresh and logout
- Password reset
- Email verification
- Google OAuth
"""
import logging
import secrets
import time
from typing import Optional
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
from pydantic import BaseModel, EmailStr, Field
from backend.util.settings import Settings
from .email import get_auth_email_sender
from .service import AuthService
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth", tags=["auth"])
# Singleton auth service instance
_auth_service: Optional[AuthService] = None
# In-memory state storage for OAuth CSRF protection
# Format: {state_token: {"created_at": timestamp, "redirect_uri": optional_uri}}
# In production, use Redis for distributed state management
_oauth_states: dict[str, dict] = {}
_STATE_TTL_SECONDS = 600 # 10 minutes
def _cleanup_expired_states() -> None:
"""Remove expired OAuth states."""
now = time.time()
expired = [
k
for k, v in _oauth_states.items()
if now - v["created_at"] > _STATE_TTL_SECONDS
]
for k in expired:
del _oauth_states[k]
def _generate_state() -> str:
"""Generate a cryptographically secure state token."""
_cleanup_expired_states()
state = secrets.token_urlsafe(32)
_oauth_states[state] = {"created_at": time.time()}
return state
def _validate_state(state: str) -> bool:
"""Validate and consume a state token."""
if state not in _oauth_states:
return False
state_data = _oauth_states.pop(state)
if time.time() - state_data["created_at"] > _STATE_TTL_SECONDS:
return False
return True
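# Illustrative round trip for the two helpers above (single-process only; see
# the Redis note for distributed deployments):
#
#     state = _generate_state()          # embedded in the Google auth URL
#     assert _validate_state(state)      # first callback consumes the token
#     assert not _validate_state(state)  # a replayed state is rejected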
def get_auth_service() -> AuthService:
"""Get or create the auth service singleton."""
global _auth_service
if _auth_service is None:
_auth_service = AuthService()
return _auth_service
# ============= Request/Response Models =============
class RegisterRequest(BaseModel):
"""Request model for user registration."""
email: EmailStr
password: str = Field(..., min_length=8)
name: Optional[str] = None
class LoginRequest(BaseModel):
"""Request model for user login."""
email: EmailStr
password: str
class TokenResponse(BaseModel):
"""Response model for authentication tokens."""
access_token: str
refresh_token: str
token_type: str = "bearer"
expires_in: int
class RefreshRequest(BaseModel):
"""Request model for token refresh."""
refresh_token: str
class LogoutRequest(BaseModel):
"""Request model for logout."""
refresh_token: str
class PasswordResetRequest(BaseModel):
"""Request model for password reset request."""
email: EmailStr
class PasswordResetConfirm(BaseModel):
"""Request model for password reset confirmation."""
token: str
new_password: str = Field(..., min_length=8)
class MessageResponse(BaseModel):
"""Generic message response."""
message: str
class UserResponse(BaseModel):
"""Response model for user info."""
id: str
email: str
name: Optional[str]
email_verified: bool
role: str
# ============= Auth Endpoints =============
@router.post("/register", response_model=TokenResponse)
async def register(request: RegisterRequest, background_tasks: BackgroundTasks):
"""
Register a new user with email and password.
Returns access and refresh tokens on successful registration.
Sends a verification email in the background.
"""
auth_service = get_auth_service()
try:
user = await auth_service.register_user(
email=request.email,
password=request.password,
name=request.name,
)
# Create verification token and send email in background
# This is non-critical - don't fail registration if email fails
try:
verification_token = await auth_service.create_email_verification_token(
user.id
)
email_sender = get_auth_email_sender()
background_tasks.add_task(
email_sender.send_email_verification,
to_email=user.email,
verification_token=verification_token,
user_name=user.name,
)
except Exception as e:
logger.warning(f"Failed to queue verification email for {user.email}: {e}")
tokens = await auth_service.create_tokens(user)
return TokenResponse(**tokens)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/login", response_model=TokenResponse)
async def login(request: LoginRequest):
"""
Login with email and password.
Returns access and refresh tokens on successful authentication.
"""
auth_service = get_auth_service()
user = await auth_service.authenticate_user(request.email, request.password)
if not user:
raise HTTPException(status_code=401, detail="Invalid email or password")
tokens = await auth_service.create_tokens(user)
return TokenResponse(**tokens)
@router.post("/logout", response_model=MessageResponse)
async def logout(request: LogoutRequest):
"""
Logout by revoking the refresh token.
This invalidates the refresh token so it cannot be used to get new access tokens.
"""
auth_service = get_auth_service()
revoked = await auth_service.revoke_refresh_token(request.refresh_token)
if not revoked:
raise HTTPException(status_code=400, detail="Invalid refresh token")
return MessageResponse(message="Successfully logged out")
@router.post("/refresh", response_model=TokenResponse)
async def refresh_tokens(request: RefreshRequest):
"""
Refresh access token using a refresh token.
The old refresh token is invalidated and a new one is returned (token rotation).
"""
auth_service = get_auth_service()
tokens = await auth_service.refresh_access_token(request.refresh_token)
if not tokens:
raise HTTPException(status_code=401, detail="Invalid or expired refresh token")
return TokenResponse(**tokens)
@router.post("/password-reset/request", response_model=MessageResponse)
async def request_password_reset(
request: PasswordResetRequest, background_tasks: BackgroundTasks
):
"""
Request a password reset email.
Always returns success to prevent email enumeration attacks.
If the email exists, a password reset email will be sent.
"""
auth_service = get_auth_service()
user = await auth_service.get_user_by_email(request.email)
if user:
token = await auth_service.create_password_reset_token(user.id)
email_sender = get_auth_email_sender()
background_tasks.add_task(
email_sender.send_password_reset_email,
to_email=user.email,
reset_token=token,
user_name=user.name,
)
logger.info(f"Password reset email queued for user {user.id}")
# Always return success to prevent email enumeration
return MessageResponse(
message="If the email exists, a password reset link has been sent"
)
@router.post("/password-reset/confirm", response_model=MessageResponse)
async def confirm_password_reset(request: PasswordResetConfirm):
"""
Reset password using a password reset token.
All existing sessions (refresh tokens) will be invalidated.
"""
auth_service = get_auth_service()
success = await auth_service.reset_password(request.token, request.new_password)
if not success:
raise HTTPException(status_code=400, detail="Invalid or expired reset token")
return MessageResponse(message="Password has been reset successfully")
# ============= Email Verification Endpoints =============
class EmailVerificationRequest(BaseModel):
"""Request model for email verification."""
token: str
class ResendVerificationRequest(BaseModel):
"""Request model for resending verification email."""
email: EmailStr
@router.post("/email/verify", response_model=MessageResponse)
async def verify_email(request: EmailVerificationRequest):
"""
Verify email address using a verification token.
Marks the user's email as verified if the token is valid.
"""
auth_service = get_auth_service()
success = await auth_service.verify_email_token(request.token)
if not success:
raise HTTPException(
status_code=400, detail="Invalid or expired verification token"
)
return MessageResponse(message="Email verified successfully")
@router.post("/email/resend-verification", response_model=MessageResponse)
async def resend_verification_email(
request: ResendVerificationRequest, background_tasks: BackgroundTasks
):
"""
Resend email verification email.
Always returns success to prevent email enumeration attacks.
If the email exists and is not verified, a new verification email will be sent.
"""
auth_service = get_auth_service()
user = await auth_service.get_user_by_email(request.email)
if user and not user.emailVerified:
token = await auth_service.create_email_verification_token(user.id)
email_sender = get_auth_email_sender()
background_tasks.add_task(
email_sender.send_email_verification,
to_email=user.email,
verification_token=token,
user_name=user.name,
)
logger.info(f"Verification email queued for user {user.id}")
# Always return success to prevent email enumeration
return MessageResponse(
message="If the email exists and is not verified, a verification link has been sent"
)
# ============= Google OAuth Endpoints =============
# Google userinfo endpoint for fetching user profile
GOOGLE_USERINFO_ENDPOINT = "https://www.googleapis.com/oauth2/v2/userinfo"
class GoogleLoginResponse(BaseModel):
"""Response model for Google OAuth login initiation."""
url: str
def _get_google_oauth_handler():
"""Get a configured GoogleOAuthHandler instance."""
# Lazy import to avoid circular imports
from backend.integrations.oauth.google import GoogleOAuthHandler
settings = Settings()
client_id = settings.secrets.google_client_id
client_secret = settings.secrets.google_client_secret
if not client_id or not client_secret:
raise HTTPException(
status_code=500,
detail="Google OAuth is not configured. Set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET.",
)
# Construct the redirect URI - this should point to the frontend's callback
# which will then call our /auth/google/callback endpoint
frontend_base_url = settings.config.frontend_base_url or "http://localhost:3000"
redirect_uri = f"{frontend_base_url}/auth/callback"
return GoogleOAuthHandler(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
)
@router.get("/google/login", response_model=GoogleLoginResponse)
async def google_login(request: Request):
"""
Initiate Google OAuth flow.
Returns the Google OAuth authorization URL to redirect the user to.
"""
try:
handler = _get_google_oauth_handler()
state = _generate_state()
# Get the authorization URL with default scopes (email, profile, openid)
auth_url = handler.get_login_url(
scopes=[], # Will use DEFAULT_SCOPES from handler
state=state,
code_challenge=None, # Not using PKCE for server-side flow
)
logger.info(f"Generated Google OAuth URL for state: {state[:8]}...")
return GoogleLoginResponse(url=auth_url)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to initiate Google OAuth: {e}")
raise HTTPException(status_code=500, detail="Failed to initiate Google OAuth")
@router.get("/google/callback", response_model=TokenResponse)
async def google_callback(request: Request, code: str, state: Optional[str] = None):
"""
Handle Google OAuth callback.
Exchanges the authorization code for user info and creates/updates the user.
Returns access and refresh tokens.
"""
# Validate state to prevent CSRF attacks
if not state or not _validate_state(state):
logger.warning(
f"Invalid or missing OAuth state: {state[:8] if state else 'None'}..."
)
raise HTTPException(status_code=400, detail="Invalid or expired OAuth state")
try:
handler = _get_google_oauth_handler()
# Exchange the authorization code for Google credentials
logger.info("Exchanging authorization code for tokens...")
google_creds = await handler.exchange_code_for_tokens(
code=code,
scopes=[], # Will use the scopes from the initial request
code_verifier=None,
)
# The handler returns OAuth2Credentials with email in username field
email = google_creds.username
if not email:
raise HTTPException(
status_code=400, detail="Failed to retrieve email from Google"
)
# Fetch full user info to get Google user ID and name
# Lazy import to avoid circular imports
from google.auth.transport.requests import AuthorizedSession
from google.oauth2.credentials import Credentials
# We need to create Google Credentials object to use with AuthorizedSession
creds = Credentials(
token=google_creds.access_token.get_secret_value(),
refresh_token=(
google_creds.refresh_token.get_secret_value()
if google_creds.refresh_token
else None
),
token_uri="https://oauth2.googleapis.com/token",
client_id=handler.client_id,
client_secret=handler.client_secret,
)
session = AuthorizedSession(creds)
userinfo_response = session.get(GOOGLE_USERINFO_ENDPOINT)
if not userinfo_response.ok:
logger.error(
f"Failed to fetch Google userinfo: {userinfo_response.status_code}"
)
raise HTTPException(
status_code=400, detail="Failed to fetch user info from Google"
)
userinfo = userinfo_response.json()
google_id = userinfo.get("id")
name = userinfo.get("name")
email_verified = userinfo.get("verified_email", False)
if not google_id:
raise HTTPException(
status_code=400, detail="Failed to retrieve Google user ID"
)
logger.info(f"Google OAuth successful for user: {email}")
# Create or update the user in our database
auth_service = get_auth_service()
user = await auth_service.create_or_update_google_user(
google_id=google_id,
email=email,
name=name,
email_verified=email_verified,
)
# Generate our JWT tokens
tokens = await auth_service.create_tokens(user)
return TokenResponse(**tokens)
except HTTPException:
raise
except Exception as e:
logger.error(f"Google OAuth callback failed: {e}")
raise HTTPException(status_code=500, detail="Failed to complete Google OAuth")
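# Client-side sketch of the endpoints above. The base URL assumes the router
# is mounted under /api as in the REST server; the port and credentials are
# placeholders, and `httpx` is used purely for illustration.
#
#     import httpx
#
#     base = "http://localhost:8000/api/auth"
#     tokens = httpx.post(f"{base}/register", json={
#         "email": "user@example.com", "password": "s3cretpass1", "name": "User",
#     }).json()  # access_token, refresh_token, token_type, expires_in
#     tokens = httpx.post(
#         f"{base}/refresh", json={"refresh_token": tokens["refresh_token"]}
#     ).json()  # the old refresh token is now revoked (rotation)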

View File

@@ -1,499 +0,0 @@
"""
Core authentication service for password verification and token management.
"""
import logging
import re
from datetime import datetime, timedelta, timezone
from typing import Optional, cast
import bcrypt
from autogpt_libs.auth.config import get_settings
from autogpt_libs.auth.jwt_utils import (
create_access_token,
create_refresh_token,
hash_token,
)
from prisma.models import User as PrismaUser
from prisma.types import (
EmailVerificationTokenCreateInput,
PasswordResetTokenCreateInput,
ProfileCreateInput,
RefreshTokenCreateInput,
UserCreateInput,
)
from backend.data.db import prisma
logger = logging.getLogger(__name__)
class AuthService:
"""Handles authentication operations including password verification and token management."""
def __init__(self):
self.settings = get_settings()
def hash_password(self, password: str) -> str:
"""Hash a password using bcrypt."""
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
def verify_password(self, password: str, hashed: str) -> bool:
"""Verify a password against a bcrypt hash."""
try:
return bcrypt.checkpw(password.encode(), hashed.encode())
except Exception as e:
logger.warning(f"Password verification failed: {e}")
return False
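# Quick property check for the two helpers above: bcrypt salts internally, so
# hashing the same password twice yields different hashes that both verify.
#
#     svc = AuthService()
#     h1, h2 = svc.hash_password("hunter22"), svc.hash_password("hunter22")
#     assert h1 != h2
#     assert svc.verify_password("hunter22", h1) and svc.verify_password("hunter22", h2)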
async def register_user(
self,
email: str,
password: str,
name: Optional[str] = None,
) -> PrismaUser:
"""
Register a new user with email and password.
Creates both a User record and a Profile record.
:param email: User's email address
:param password: User's password (will be hashed)
:param name: Optional display name
:return: Created user record
:raises ValueError: If email is already registered
"""
# Check if user already exists
existing = await prisma.user.find_unique(where={"email": email})
if existing:
raise ValueError("Email already registered")
password_hash = self.hash_password(password)
# Generate a unique username from email
base_username = email.split("@")[0].lower()
# Remove any characters that aren't alphanumeric or underscore
base_username = re.sub(r"[^a-z0-9_]", "", base_username)
if not base_username:
base_username = "user"
# Check if username is unique, if not add a number suffix
username = base_username
counter = 1
while await prisma.profile.find_unique(where={"username": username}):
username = f"{base_username}{counter}"
counter += 1
user = await prisma.user.create(
data=cast(
UserCreateInput,
{
"email": email,
"passwordHash": password_hash,
"name": name,
"emailVerified": False,
"role": "authenticated",
},
)
)
# Create profile for the user
display_name = name or base_username
await prisma.profile.create(
data=cast(
ProfileCreateInput,
{
"userId": user.id,
"name": display_name,
"username": username,
"description": "",
"links": [],
},
)
)
logger.info(f"Registered new user: {user.id} with profile username: {username}")
return user
async def authenticate_user(
self, email: str, password: str
) -> Optional[PrismaUser]:
"""
Authenticate a user with email and password.
:param email: User's email address
:param password: User's password
:return: User record if authentication successful, None otherwise
"""
user = await prisma.user.find_unique(where={"email": email})
if not user:
logger.debug(f"Authentication failed: user not found for email {email}")
return None
if not user.passwordHash:
logger.debug(
f"Authentication failed: no password set for user {user.id} "
"(likely OAuth-only user)"
)
return None
if self.verify_password(password, user.passwordHash):
logger.debug(f"Authentication successful for user {user.id}")
return user
logger.debug(f"Authentication failed: invalid password for user {user.id}")
return None
async def create_tokens(self, user: PrismaUser) -> dict:
"""
Create access and refresh tokens for a user.
:param user: The user to create tokens for
:return: Dictionary with access_token, refresh_token, token_type, and expires_in
"""
# Create access token
access_token = create_access_token(
user_id=user.id,
email=user.email,
role=user.role or "authenticated",
email_verified=user.emailVerified,
)
# Create and store refresh token
raw_refresh_token, hashed_refresh_token = create_refresh_token()
expires_at = datetime.now(timezone.utc) + timedelta(
days=self.settings.REFRESH_TOKEN_EXPIRE_DAYS
)
await prisma.refreshtoken.create(
data=cast(
RefreshTokenCreateInput,
{
"token": hashed_refresh_token,
"userId": user.id,
"expiresAt": expires_at,
},
)
)
logger.debug(f"Created tokens for user {user.id}")
return {
"access_token": access_token,
"refresh_token": raw_refresh_token,
"token_type": "bearer",
"expires_in": self.settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
}
async def refresh_access_token(self, refresh_token: str) -> Optional[dict]:
"""
Refresh an access token using a refresh token.
Implements token rotation: the old refresh token is revoked and a new one is issued.
:param refresh_token: The refresh token
:return: New tokens if successful, None if refresh token is invalid/expired
"""
hashed_token = hash_token(refresh_token)
# Find the refresh token
stored_token = await prisma.refreshtoken.find_first(
where={
"token": hashed_token,
"revokedAt": None,
"expiresAt": {"gt": datetime.now(timezone.utc)},
},
include={"User": True},
)
if not stored_token or not stored_token.User:
logger.debug("Refresh token not found or expired")
return None
# Revoke the old token (token rotation)
await prisma.refreshtoken.update(
where={"id": stored_token.id},
data={"revokedAt": datetime.now(timezone.utc)},
)
logger.debug(f"Refreshed tokens for user {stored_token.User.id}")
# Create new tokens
return await self.create_tokens(stored_token.User)
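# Rotation invariant (illustrative; `svc` and `raw_refresh_token` are
# hypothetical): a refresh token may be used exactly once.
#
#     first = await svc.refresh_access_token(raw_refresh_token)   # new tokens
#     second = await svc.refresh_access_token(raw_refresh_token)  # None: revoked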
async def revoke_refresh_token(self, refresh_token: str) -> bool:
"""
Revoke a refresh token (logout).
:param refresh_token: The refresh token to revoke
:return: True if token was found and revoked, False otherwise
"""
hashed_token = hash_token(refresh_token)
result = await prisma.refreshtoken.update_many(
where={"token": hashed_token, "revokedAt": None},
data={"revokedAt": datetime.now(timezone.utc)},
)
if result > 0:
logger.debug("Refresh token revoked")
return True
logger.debug("Refresh token not found or already revoked")
return False
async def revoke_all_user_tokens(self, user_id: str) -> int:
"""
Revoke all refresh tokens for a user (logout from all devices).
:param user_id: The user's ID
:return: Number of tokens revoked
"""
result = await prisma.refreshtoken.update_many(
where={"userId": user_id, "revokedAt": None},
data={"revokedAt": datetime.now(timezone.utc)},
)
logger.debug(f"Revoked {result} tokens for user {user_id}")
return result
async def get_user_by_google_id(self, google_id: str) -> Optional[PrismaUser]:
"""Get a user by their Google OAuth ID."""
return await prisma.user.find_unique(where={"googleId": google_id})
async def get_user_by_email(self, email: str) -> Optional[PrismaUser]:
"""Get a user by their email address."""
return await prisma.user.find_unique(where={"email": email})
async def create_or_update_google_user(
self,
google_id: str,
email: str,
name: Optional[str] = None,
email_verified: bool = False,
) -> PrismaUser:
"""
Create or update a user from Google OAuth.
If a user with the Google ID exists, return them.
If a user with the email exists but no Google ID, link the account.
Otherwise, create a new user.
:param google_id: Google's unique user ID
:param email: User's email from Google
:param name: User's name from Google
:param email_verified: Whether Google has verified the email
:return: The user record
"""
# Check if user exists with this Google ID
user = await self.get_user_by_google_id(google_id)
if user:
return user
# Check if user exists with this email
user = await self.get_user_by_email(email)
if user:
# Link Google account to existing user
updated_user = await prisma.user.update(
where={"id": user.id},
data={
"googleId": google_id,
"emailVerified": email_verified or user.emailVerified,
},
)
if updated_user:
logger.info(f"Linked Google account to existing user {updated_user.id}")
return updated_user
return user
# Create new user with profile
# Generate a unique username from email
base_username = email.split("@")[0].lower()
base_username = re.sub(r"[^a-z0-9_]", "", base_username)
if not base_username:
base_username = "user"
username = base_username
counter = 1
while await prisma.profile.find_unique(where={"username": username}):
username = f"{base_username}{counter}"
counter += 1
user = await prisma.user.create(
data=cast(
UserCreateInput,
{
"email": email,
"googleId": google_id,
"name": name,
"emailVerified": email_verified,
"role": "authenticated",
},
)
)
# Create profile for the user
display_name = name or base_username
await prisma.profile.create(
data=cast(
ProfileCreateInput,
{
"userId": user.id,
"name": display_name,
"username": username,
"description": "",
"links": [],
},
)
)
logger.info(
f"Created new user from Google OAuth: {user.id} with profile: {username}"
)
return user
async def create_password_reset_token(self, user_id: str) -> str:
"""
Create a password reset token for a user.
:param user_id: The user's ID
:return: The raw token to send to the user
"""
raw_token, hashed_token = create_refresh_token() # Reuse token generation
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
await prisma.passwordresettoken.create(
data=cast(
PasswordResetTokenCreateInput,
{
"token": hashed_token,
"userId": user_id,
"expiresAt": expires_at,
},
)
)
return raw_token
async def create_email_verification_token(self, user_id: str) -> str:
"""
Create an email verification token for a user.
:param user_id: The user's ID
:return: The raw token to send to the user
"""
raw_token, hashed_token = create_refresh_token() # Reuse token generation
expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
await prisma.emailverificationtoken.create(
data=cast(
EmailVerificationTokenCreateInput,
{
"token": hashed_token,
"userId": user_id,
"expiresAt": expires_at,
},
)
)
return raw_token
async def verify_email_token(self, token: str) -> bool:
"""
Verify an email verification token and mark the user's email as verified.
:param token: The raw token from the user
:return: True if successful, False if token is invalid
"""
hashed_token = hash_token(token)
# Find and validate token
stored_token = await prisma.emailverificationtoken.find_first(
where={
"token": hashed_token,
"usedAt": None,
"expiresAt": {"gt": datetime.now(timezone.utc)},
}
)
if not stored_token:
return False
# Mark email as verified
await prisma.user.update(
where={"id": stored_token.userId},
data={"emailVerified": True},
)
# Mark token as used
await prisma.emailverificationtoken.update(
where={"id": stored_token.id},
data={"usedAt": datetime.now(timezone.utc)},
)
logger.info(f"Email verified for user {stored_token.userId}")
return True
async def verify_password_reset_token(self, token: str) -> Optional[str]:
"""
Verify a password reset token and return the user ID.
:param token: The raw token from the user
:return: User ID if valid, None otherwise
"""
hashed_token = hash_token(token)
stored_token = await prisma.passwordresettoken.find_first(
where={
"token": hashed_token,
"usedAt": None,
"expiresAt": {"gt": datetime.now(timezone.utc)},
}
)
if not stored_token:
return None
return stored_token.userId
async def reset_password(self, token: str, new_password: str) -> bool:
"""
Reset a user's password using a password reset token.
:param token: The password reset token
:param new_password: The new password
:return: True if successful, False if token is invalid
"""
hashed_token = hash_token(token)
# Find and validate token
stored_token = await prisma.passwordresettoken.find_first(
where={
"token": hashed_token,
"usedAt": None,
"expiresAt": {"gt": datetime.now(timezone.utc)},
}
)
if not stored_token:
return False
# Update password
password_hash = self.hash_password(new_password)
await prisma.user.update(
where={"id": stored_token.userId},
data={"passwordHash": password_hash},
)
# Mark token as used
await prisma.passwordresettoken.update(
where={"id": stored_token.id},
data={"usedAt": datetime.now(timezone.utc)},
)
# Revoke all refresh tokens for security
await self.revoke_all_user_tokens(stored_token.userId)
logger.info(f"Password reset for user {stored_token.userId}")
return True
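# Illustrative reset flow tying the token helpers together (`svc` and `user`
# are hypothetical):
#
#     token = await svc.create_password_reset_token(user.id)
#     assert await svc.verify_password_reset_token(token) == user.id
#     assert await svc.reset_password(token, "new-password-123")
#     # all of the user's refresh tokens are now revoked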

View File

@@ -1,302 +0,0 @@
{# Base Template for Auth Emails #}
{# Template variables:
data.message: the message to display in the email
data.title: the title of the email
data.unsubscribe_link: the link to unsubscribe from the email (optional for auth emails)
#}
<!doctype html>
<html lang="ltr" xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=yes">
<meta name="format-detection" content="telephone=no, date=no, address=no, email=no, url=no">
<meta name="x-apple-disable-message-reformatting">
<!--[if !mso]>
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<![endif]-->
<!--[if mso]>
<style>
* { font-family: sans-serif !important; }
</style>
<noscript>
<xml>
<o:OfficeDocumentSettings>
<o:PixelsPerInch>96</o:PixelsPerInch>
</o:OfficeDocumentSettings>
</xml>
</noscript>
<![endif]-->
<style type="text/css">
/* RESET STYLES */
html,
body {
margin: 0 !important;
padding: 0 !important;
width: 100% !important;
height: 100% !important;
}
body {
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
text-rendering: optimizeLegibility;
}
.document {
margin: 0 !important;
padding: 0 !important;
width: 100% !important;
}
img {
border: 0;
outline: none;
text-decoration: none;
-ms-interpolation-mode: bicubic;
}
table {
border-collapse: collapse;
}
table,
td {
mso-table-lspace: 0pt;
mso-table-rspace: 0pt;
}
body,
table,
td,
a {
-webkit-text-size-adjust: 100%;
-ms-text-size-adjust: 100%;
}
h1,
h2,
h3,
h4,
h5,
p {
margin: 0;
word-break: break-word;
}
/* iOS BLUE LINKS */
a[x-apple-data-detectors] {
color: inherit !important;
text-decoration: none !important;
font-size: inherit !important;
font-family: inherit !important;
font-weight: inherit !important;
line-height: inherit !important;
}
/* ANDROID CENTER FIX */
div[style*="margin: 16px 0;"] {
margin: 0 !important;
}
/* MEDIA QUERIES */
@media all and (max-width:639px) {
.wrapper {
width: 100% !important;
}
.container {
width: 100% !important;
min-width: 100% !important;
padding: 0 !important;
}
.row {
padding-left: 20px !important;
padding-right: 20px !important;
}
.col-mobile {
width: 20px !important;
}
.col {
display: block !important;
width: 100% !important;
}
.mobile-center {
text-align: center !important;
float: none !important;
}
.mobile-mx-auto {
margin: 0 auto !important;
float: none !important;
}
.mobile-left {
text-align: center !important;
float: left !important;
}
.mobile-hide {
display: none !important;
}
.img {
width: 100% !important;
height: auto !important;
}
.ml-btn {
width: 100% !important;
max-width: 100% !important;
}
.ml-btn-container {
width: 100% !important;
max-width: 100% !important;
}
}
</style>
<style type="text/css">
@import url("https://assets.mlcdn.com/fonts-v2.css?version=1729862");
</style>
<style type="text/css">
@media screen {
body {
font-family: 'Poppins', sans-serif;
}
}
</style>
<title>{{data.title}}</title>
</head>
<body style="margin: 0 !important; padding: 0 !important; background-color:#070629;">
<div class="document" role="article" aria-roledescription="email" aria-label lang dir="ltr"
style="background-color:#070629; line-height: 100%; font-size:medium; font-size:max(16px, 1rem);">
<!-- Main Content -->
<table width="100%" align="center" cellspacing="0" cellpadding="0" border="0">
<tr>
<td class="background" bgcolor="#070629" align="center" valign="top" style="padding: 0 8px;">
<!-- Email Content -->
<table class="container" align="center" width="640" cellpadding="0" cellspacing="0" border="0"
style="max-width: 640px;">
<tr>
<td align="center">
<!-- Logo Section -->
<table class="container ml-4 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
cellspacing="0" cellpadding="0" style="width: 640px; min-width: 640px;">
<tr>
<td class="ml-default-border container" height="40" style="line-height: 40px; min-width: 640px;">
</td>
</tr>
<tr>
<td>
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
<tr>
<td class="row" align="center" style="padding: 0 50px;">
<img
src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
border="0" alt="" width="120" class="logo"
style="max-width: 120px; display: inline-block;">
</td>
</tr>
</table>
</td>
</tr>
</table>
<!-- Main Content Section -->
<table class="container ml-6 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
cellspacing="0" cellpadding="0" style="color: #070629; width: 640px; min-width: 640px;">
<tr>
<td class="row" style="padding: 0 50px;">
{{data.message|safe}}
</td>
</tr>
</table>
<!-- Footer Section -->
<table class="container ml-10 ml-default-border" width="640" bgcolor="#ffffff" align="center" border="0"
cellspacing="0" cellpadding="0" style="width: 640px; min-width: 640px;">
<tr>
<td class="row" style="padding: 0 50px;">
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
<tr>
<td height="20" style="line-height: 20px;"></td>
</tr>
<tr>
<td>
<!-- Footer Content -->
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
<tr>
<td class="col" align="left" valign="middle" width="120">
<img
src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
border="0" alt="" width="120" class="logo"
style="max-width: 120px; display: inline-block;">
</td>
<td class="col" width="40" height="30" style="line-height: 30px;"></td>
<td class="col mobile-left" align="right" valign="middle" width="250">
<table role="presentation" cellpadding="0" cellspacing="0" border="0">
<tr>
<td align="center" valign="middle" width="18" style="padding: 0 5px 0 0;">
<a href="https://x.com/auto_gpt" target="blank" style="text-decoration: none;">
<img
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/x.png"
width="18" alt="x">
</a>
</td>
<td align="center" valign="middle" width="18" style="padding: 0 5px;">
<a href="https://discord.gg/autogpt" target="blank"
style="text-decoration: none;">
<img
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/discord.png"
width="18" alt="discord">
</a>
</td>
<td align="center" valign="middle" width="18" style="padding: 0 0 0 5px;">
<a href="https://agpt.co/" target="blank" style="text-decoration: none;">
<img
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/website.png"
width="18" alt="website">
</a>
</td>
</tr>
</table>
</td>
</tr>
</table>
</td>
</tr>
<tr>
<td height="15" style="line-height: 15px;"></td>
</tr>
<tr>
<td align="center" style="text-align: left!important;">
<p
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 12px; line-height: 150%; display: inline-block; margin-bottom: 0;">
This is an automated security email from AutoGPT. If you did not request this action, please ignore this email or contact support if you have concerns.
</p>
</td>
</tr>
<tr>
<td height="20" style="line-height: 20px;"></td>
</tr>
</table>
</td>
</tr>
</table>
</td>
</tr>
</table>
</td>
</tr>
</table>
</div>
</body>
</html>

View File

@@ -1,65 +0,0 @@
{# Email Verification Template #}
{# Variables:
verification_link: URL for email verification
user_name: Optional user name for personalization
frontend_url: Base frontend URL
#}
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
<tr>
<td height="30" style="line-height: 30px;"></td>
</tr>
<tr>
<td align="center">
<h1 style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 28px; line-height: 125%; font-weight: bold; margin-bottom: 20px;">
Verify Your Email
</h1>
</td>
</tr>
<tr>
<td align="left">
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
{% if user_name %}Hi {{ user_name }},{% else %}Hi,{% endif %}
</p>
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
Welcome to AutoGPT! Please verify your email address by clicking the button below:
</p>
</td>
</tr>
<tr>
<td align="center" style="padding: 20px 0;">
<table border="0" cellspacing="0" cellpadding="0">
<tr>
<td align="center" bgcolor="#4285F4" style="border-radius: 8px;">
<a href="{{ verification_link }}" target="_blank"
style="display: inline-block; padding: 16px 36px; font-family: 'Poppins', sans-serif; font-size: 16px; font-weight: 600; color: #ffffff; text-decoration: none; border-radius: 8px;">
Verify Email
</a>
</td>
</tr>
</table>
</td>
</tr>
<tr>
<td align="left">
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
This link will expire in <strong>24 hours</strong>.
</p>
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
If you didn't create an account with AutoGPT, you can safely ignore this email.
</p>
</td>
</tr>
<tr>
<td align="left">
<p style="font-family: 'Poppins', sans-serif; color: #888888; font-size: 14px; line-height: 165%; margin-bottom: 10px;">
If the button doesn't work, copy and paste this link into your browser:
</p>
<p style="font-family: 'Poppins', sans-serif; color: #4285F4; font-size: 14px; line-height: 165%; word-break: break-all;">
<a href="{{ verification_link }}" style="color: #4285F4; text-decoration: underline;">{{ verification_link }}</a>
</p>
</td>
</tr>
<tr>
<td height="30" style="line-height: 30px;"></td>
</tr>
</table>

View File

@@ -1,65 +0,0 @@
{# Password Reset Email Template #}
{# Variables:
reset_link: URL for password reset
user_name: Optional user name for personalization
frontend_url: Base frontend URL
#}
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
<tr>
<td height="30" style="line-height: 30px;"></td>
</tr>
<tr>
<td align="center">
<h1 style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 28px; line-height: 125%; font-weight: bold; margin-bottom: 20px;">
Reset Your Password
</h1>
</td>
</tr>
<tr>
<td align="left">
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
{% if user_name %}Hi {{ user_name }},{% else %}Hi,{% endif %}
</p>
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
We received a request to reset your password for your AutoGPT account. Click the button below to create a new password:
</p>
</td>
</tr>
<tr>
<td align="center" style="padding: 20px 0;">
<table border="0" cellspacing="0" cellpadding="0">
<tr>
<td align="center" bgcolor="#4285F4" style="border-radius: 8px;">
<a href="{{ reset_link }}" target="_blank"
style="display: inline-block; padding: 16px 36px; font-family: 'Poppins', sans-serif; font-size: 16px; font-weight: 600; color: #ffffff; text-decoration: none; border-radius: 8px;">
Reset Password
</a>
</td>
</tr>
</table>
</td>
</tr>
<tr>
<td align="left">
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
This link will expire in <strong>1 hour</strong> for security reasons.
</p>
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-bottom: 20px;">
If you didn't request a password reset, you can safely ignore this email. Your password will remain unchanged.
</p>
</td>
</tr>
<tr>
<td align="left">
<p style="font-family: 'Poppins', sans-serif; color: #888888; font-size: 14px; line-height: 165%; margin-bottom: 10px;">
If the button doesn't work, copy and paste this link into your browser:
</p>
<p style="font-family: 'Poppins', sans-serif; color: #4285F4; font-size: 14px; line-height: 165%; word-break: break-all;">
<a href="{{ reset_link }}" style="color: #4285F4; text-decoration: underline;">{{ reset_link }}</a>
</p>
</td>
</tr>
<tr>
<td height="30" style="line-height: 30px;"></td>
</tr>
</table>

View File

@@ -1,107 +1,36 @@
from fastapi import HTTPException, Security, status
from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer
from fastapi import HTTPException, Security
from fastapi.security import APIKeyHeader
from prisma.enums import APIKeyPermission
from backend.data.auth.api_key import APIKeyInfo, validate_api_key
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.auth.oauth import (
InvalidClientError,
InvalidTokenError,
OAuthAccessTokenInfo,
validate_access_token,
)
from backend.data.api_key import APIKeyInfo, has_permission, validate_api_key
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
bearer_auth = HTTPBearer(auto_error=False)
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
"""Middleware for API key authentication only"""
"""Base middleware for API key authentication"""
if api_key is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing API key"
)
raise HTTPException(status_code=401, detail="Missing API key")
api_key_obj = await validate_api_key(api_key)
if not api_key_obj:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
)
raise HTTPException(status_code=401, detail="Invalid API key")
return api_key_obj
async def require_access_token(
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
) -> OAuthAccessTokenInfo:
"""Middleware for OAuth access token authentication only"""
if bearer is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing Authorization header",
)
try:
token_info, _ = await validate_access_token(bearer.credentials)
except (InvalidClientError, InvalidTokenError) as e:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
return token_info
async def require_auth(
api_key: str | None = Security(api_key_header),
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
) -> APIAuthorizationInfo:
"""
Unified authentication middleware supporting both API keys and OAuth tokens.
Supports two authentication methods, which are checked in order:
1. X-API-Key header (existing API key authentication)
2. Authorization: Bearer <token> header (OAuth access token)
Returns:
APIAuthorizationInfo: base class of both APIKeyInfo and OAuthAccessTokenInfo.
"""
# Try API key first
if api_key is not None:
api_key_info = await validate_api_key(api_key)
if api_key_info:
return api_key_info
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
)
# Try OAuth bearer token
if bearer is not None:
try:
token_info, _ = await validate_access_token(bearer.credentials)
return token_info
except (InvalidClientError, InvalidTokenError) as e:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
# No credentials provided
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing authentication. Provide API key or access token.",
)
def require_permission(permission: APIKeyPermission):
"""
Dependency function for checking specific permissions
(works with API keys and OAuth tokens)
"""
"""Dependency function for checking specific permissions"""
async def check_permission(
auth: APIAuthorizationInfo = Security(require_auth),
) -> APIAuthorizationInfo:
if permission not in auth.scopes:
api_key: APIKeyInfo = Security(require_api_key),
) -> APIKeyInfo:
if not has_permission(api_key, permission):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing required permission: {permission.value}",
status_code=403,
detail=f"API key lacks the required permission '{permission}'",
)
return auth
return api_key
return check_permission
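# Usage sketch for the dependency factory above (the route path is
# hypothetical; READ_INTEGRATIONS is one of the real permissions):
#
#     from fastapi import APIRouter, Security
#     from prisma.enums import APIKeyPermission
#
#     from backend.data.api_key import APIKeyInfo
#     from backend.server.external.middleware import require_permission
#
#     router = APIRouter()
#
#     @router.get("/example")
#     async def example(
#         api_key: APIKeyInfo = Security(
#             require_permission(APIKeyPermission.READ_INTEGRATIONS)
#         ),
#     ):
#         return {"user_id": api_key.user_id}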

View File

@@ -16,7 +16,7 @@ from fastapi import APIRouter, Body, HTTPException, Path, Security, status
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field, SecretStr
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.api_key import APIKeyInfo
from backend.data.model import (
APIKeyCredentials,
Credentials,
@@ -255,7 +255,7 @@ def _get_oauth_handler_for_external(
@integrations_router.get("/providers", response_model=list[ProviderInfo])
async def list_providers(
auth: APIAuthorizationInfo = Security(
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.READ_INTEGRATIONS)
),
) -> list[ProviderInfo]:
@@ -319,7 +319,7 @@ async def list_providers(
async def initiate_oauth(
provider: Annotated[str, Path(title="The OAuth provider")],
request: OAuthInitiateRequest,
auth: APIAuthorizationInfo = Security(
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
),
) -> OAuthInitiateResponse:
@@ -337,10 +337,7 @@ async def initiate_oauth(
if not validate_callback_url(request.callback_url):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
f"Callback URL origin is not allowed. "
f"Allowed origins: {settings.config.external_oauth_callback_origins}",
),
detail=f"Callback URL origin is not allowed. Allowed origins: {settings.config.external_oauth_callback_origins}",
)
# Validate provider
@@ -362,15 +359,13 @@ async def initiate_oauth(
)
# Store state token with external flow metadata
# Note: initiated_by_api_key_id is only available for API key auth, not OAuth
api_key_id = getattr(auth, "id", None) if auth.type == "api_key" else None
state_token, code_challenge = await creds_manager.store.store_state_token(
user_id=auth.user_id,
user_id=api_key.user_id,
provider=provider if isinstance(provider_name, str) else provider_name.value,
scopes=request.scopes,
callback_url=request.callback_url,
state_metadata=request.state_metadata,
initiated_by_api_key_id=api_key_id,
initiated_by_api_key_id=api_key.id,
)
# Build login URL
@@ -398,7 +393,7 @@ async def initiate_oauth(
async def complete_oauth(
provider: Annotated[str, Path(title="The OAuth provider")],
request: OAuthCompleteRequest,
auth: APIAuthorizationInfo = Security(
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
),
) -> OAuthCompleteResponse:
@@ -411,7 +406,7 @@ async def complete_oauth(
"""
# Verify state token
valid_state = await creds_manager.store.verify_state_token(
auth.user_id, request.state_token, provider
api_key.user_id, request.state_token, provider
)
if not valid_state:
@@ -458,7 +453,7 @@ async def complete_oauth(
)
# Store credentials
await creds_manager.create(auth.user_id, credentials)
await creds_manager.create(api_key.user_id, credentials)
logger.info(f"Successfully completed external OAuth for provider {provider}")
@@ -475,7 +470,7 @@ async def complete_oauth(
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
async def list_credentials(
auth: APIAuthorizationInfo = Security(
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.READ_INTEGRATIONS)
),
) -> list[CredentialSummary]:
@@ -484,7 +479,7 @@ async def list_credentials(
Returns metadata about each credential without exposing sensitive tokens.
"""
credentials = await creds_manager.store.get_all_creds(auth.user_id)
credentials = await creds_manager.store.get_all_creds(api_key.user_id)
return [
CredentialSummary(
id=cred.id,
@@ -504,7 +499,7 @@ async def list_credentials(
)
async def list_credentials_by_provider(
provider: Annotated[str, Path(title="The provider to list credentials for")],
auth: APIAuthorizationInfo = Security(
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.READ_INTEGRATIONS)
),
) -> list[CredentialSummary]:
@@ -512,7 +507,7 @@ async def list_credentials_by_provider(
List credentials for a specific provider.
"""
credentials = await creds_manager.store.get_creds_by_provider(
auth.user_id, provider
api_key.user_id, provider
)
return [
CredentialSummary(
@@ -541,7 +536,7 @@ async def create_credential(
CreateUserPasswordCredentialRequest,
CreateHostScopedCredentialRequest,
] = Body(..., discriminator="type"),
auth: APIAuthorizationInfo = Security(
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
),
) -> CreateCredentialResponse:
@@ -596,7 +591,7 @@ async def create_credential(
# Store credentials
try:
await creds_manager.create(auth.user_id, credentials)
await creds_manager.create(api_key.user_id, credentials)
except Exception as e:
logger.error(f"Failed to store credentials: {e}")
raise HTTPException(
@@ -628,7 +623,7 @@ class DeleteCredentialResponse(BaseModel):
async def delete_credential(
provider: Annotated[str, Path(title="The provider")],
cred_id: Annotated[str, Path(title="The credential ID to delete")],
auth: APIAuthorizationInfo = Security(
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.DELETE_INTEGRATIONS)
),
) -> DeleteCredentialResponse:
@@ -639,7 +634,7 @@ async def delete_credential(
use the main API's delete endpoint which handles webhook cleanup and
token revocation.
"""
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
creds = await creds_manager.store.get_creds_by_id(api_key.user_id, cred_id)
if not creds:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
@@ -650,6 +645,6 @@ async def delete_credential(
detail="Credentials do not match the specified provider",
)
await creds_manager.delete(auth.user_id, cred_id)
await creds_manager.delete(api_key.user_id, cred_id)
return DeleteCredentialResponse(deleted=True, credentials_id=cred_id)

View File

@@ -14,7 +14,7 @@ from fastapi import APIRouter, Security
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.api_key import APIKeyInfo
from backend.server.external.middleware import require_permission
from backend.server.v2.chat.model import ChatSession
from backend.server.v2.chat.tools import find_agent_tool, run_agent_tool
@@ -24,9 +24,9 @@ logger = logging.getLogger(__name__)
tools_router = APIRouter(prefix="/tools", tags=["tools"])
# Note: We use Security() as a function parameter dependency (auth: APIAuthorizationInfo = Security(...))
# Note: We use Security() as a function parameter dependency (api_key: APIKeyInfo = Security(...))
# rather than in the decorator's dependencies= list. This avoids duplicate permission checks
# while still enforcing auth AND giving us access to auth for extracting user_id.
# while still enforcing auth AND giving us access to the api_key for extracting user_id.
# Request models
@@ -80,9 +80,7 @@ def _create_ephemeral_session(user_id: str | None) -> ChatSession:
)
async def find_agent(
request: FindAgentRequest,
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.USE_TOOLS)
),
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
) -> dict[str, Any]:
"""
Search for agents in the marketplace based on capabilities and user needs.
@@ -93,9 +91,9 @@ async def find_agent(
Returns:
List of matching agents or no results response
"""
session = _create_ephemeral_session(auth.user_id)
session = _create_ephemeral_session(api_key.user_id)
result = await find_agent_tool._execute(
user_id=auth.user_id,
user_id=api_key.user_id,
session=session,
query=request.query,
)
@@ -107,9 +105,7 @@ async def find_agent(
)
async def run_agent(
request: RunAgentRequest,
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.USE_TOOLS)
),
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
) -> dict[str, Any]:
"""
Run or schedule an agent from the marketplace.
@@ -133,9 +129,9 @@ async def run_agent(
- execution_started: If agent was run or scheduled successfully
- error: If something went wrong
"""
session = _create_ephemeral_session(auth.user_id)
session = _create_ephemeral_session(api_key.user_id)
result = await run_agent_tool._execute(
user_id=auth.user_id,
user_id=api_key.user_id,
session=session,
username_agent_slug=request.username_agent_slug,
inputs=request.inputs,

View File

@@ -5,7 +5,6 @@ from typing import Annotated, Any, Literal, Optional, Sequence
from fastapi import APIRouter, Body, HTTPException, Security
from prisma.enums import AgentExecutionStatus, APIKeyPermission
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
import backend.data.block
@@ -13,8 +12,7 @@ import backend.server.v2.store.cache as store_cache
import backend.server.v2.store.model as store_model
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import user as user_db
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.api_key import APIKeyInfo
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.executor.utils import add_graph_execution
from backend.server.external.middleware import require_permission
@@ -26,33 +24,27 @@ logger = logging.getLogger(__name__)
v1_router = APIRouter()
class UserInfoResponse(BaseModel):
id: str
name: Optional[str]
email: str
timezone: str = Field(
description="The user's last known timezone (e.g. 'Europe/Amsterdam'), "
"or 'not-set' if not set"
)
class NodeOutput(TypedDict):
key: str
value: Any
@v1_router.get(
path="/me",
tags=["user", "meta"],
)
async def get_user_info(
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.IDENTITY)
),
) -> UserInfoResponse:
user = await user_db.get_user_by_id(auth.user_id)
class ExecutionNode(TypedDict):
node_id: str
input: Any
output: dict[str, Any]
return UserInfoResponse(
id=user.id,
name=user.name,
email=user.email,
timezone=user.timezone,
)
class ExecutionNodeOutput(TypedDict):
node_id: str
outputs: list[NodeOutput]
class GraphExecutionResult(TypedDict):
execution_id: str
status: str
nodes: list[ExecutionNode]
output: Optional[list[dict[str, str]]]
@v1_router.get(
@@ -73,9 +65,7 @@ async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
async def execute_graph_block(
block_id: str,
data: BlockInput,
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.EXECUTE_BLOCK)
),
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_BLOCK)),
) -> CompletedBlockOutput:
obj = backend.data.block.get_block(block_id)
if not obj:
@@ -95,14 +85,12 @@ async def execute_graph(
graph_id: str,
graph_version: int,
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.EXECUTE_GRAPH)
),
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
) -> dict[str, Any]:
try:
graph_exec = await add_graph_execution(
graph_id=graph_id,
user_id=auth.user_id,
user_id=api_key.user_id,
inputs=node_input,
graph_version=graph_version,
)
@@ -112,19 +100,6 @@ async def execute_graph(
raise HTTPException(status_code=400, detail=msg)
class ExecutionNode(TypedDict):
node_id: str
input: Any
output: dict[str, Any]
class GraphExecutionResult(TypedDict):
execution_id: str
status: str
nodes: list[ExecutionNode]
output: Optional[list[dict[str, str]]]
@v1_router.get(
path="/graphs/{graph_id}/executions/{graph_exec_id}/results",
tags=["graphs"],
@@ -132,12 +107,10 @@ class GraphExecutionResult(TypedDict):
async def get_graph_execution_results(
graph_id: str,
graph_exec_id: str,
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_GRAPH)
),
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.READ_GRAPH)),
) -> GraphExecutionResult:
graph_exec = await execution_db.get_graph_execution(
user_id=auth.user_id,
user_id=api_key.user_id,
execution_id=graph_exec_id,
include_node_executions=True,
)
@@ -149,7 +122,7 @@ async def get_graph_execution_results(
if not await graph_db.get_graph(
graph_id=graph_exec.graph_id,
version=graph_exec.graph_version,
user_id=auth.user_id,
user_id=api_key.user_id,
):
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")

View File

@@ -33,11 +33,7 @@ from backend.data.model import (
OAuth2Credentials,
UserIntegrations,
)
from backend.data.onboarding import (
OnboardingStep,
complete_onboarding_step,
increment_runs,
)
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
@@ -381,7 +377,6 @@ async def webhook_ingress_generic(
return
await complete_onboarding_step(user_id, OnboardingStep.TRIGGER_WEBHOOK)
await increment_runs(user_id)
# Execute all triggers concurrently for better performance
tasks = []

View File

@@ -1,10 +1,9 @@
import enum
from typing import Any, Literal, Optional
from typing import Any, Optional
import pydantic
from prisma.enums import OnboardingStep
from backend.data.auth.api_key import APIKeyInfo, APIKeyPermission
from backend.data.api_key import APIKeyInfo, APIKeyPermission
from backend.data.graph import Graph
from backend.util.timezone_name import TimeZoneName
@@ -36,13 +35,8 @@ class WSSubscribeGraphExecutionsRequest(pydantic.BaseModel):
graph_id: str
GraphCreationSource = Literal["builder", "upload"]
GraphExecutionSource = Literal["builder", "library", "onboarding"]
class CreateGraph(pydantic.BaseModel):
graph: Graph
source: GraphCreationSource | None = None
class CreateAPIKeyRequest(pydantic.BaseModel):
@@ -89,8 +83,6 @@ class NotificationPayload(pydantic.BaseModel):
type: str
event: str
model_config = pydantic.ConfigDict(extra="allow")
class OnboardingNotificationPayload(NotificationPayload):
step: OnboardingStep | None
step: str

View File

@@ -21,8 +21,6 @@ import backend.data.db
import backend.data.graph
import backend.data.user
import backend.integrations.webhooks.utils
import backend.server.auth
import backend.server.routers.oauth
import backend.server.routers.postmark.postmark
import backend.server.routers.v1
import backend.server.v2.admin.credit_admin_routes
@@ -256,7 +254,6 @@ app.add_exception_handler(ValueError, handle_internal_http_error(400))
app.add_exception_handler(Exception, handle_internal_http_error(500))
app.include_router(backend.server.routers.v1.v1_router, tags=["v1"], prefix="/api")
app.include_router(backend.server.auth.auth_router, tags=["auth"], prefix="/api")
app.include_router(
backend.server.v2.store.routes.router, tags=["v2"], prefix="/api/store"
)
@@ -300,11 +297,6 @@ app.include_router(
tags=["v2", "chat"],
prefix="/api/chat",
)
app.include_router(
backend.server.routers.oauth.router,
tags=["oauth"],
prefix="/api/oauth",
)
app.mount("/external-api", external_app)

View File

@@ -1,833 +0,0 @@
"""
OAuth 2.0 Provider Endpoints
Implements OAuth 2.0 Authorization Code flow with PKCE support.
Flow:
1. User clicks "Login with AutoGPT" in 3rd party app
2. App redirects user to /oauth/authorize with client_id, redirect_uri, scope, state
3. User sees consent screen (if not already logged in, redirects to login first)
4. User approves → backend creates authorization code
5. User redirected back to app with code
6. App exchanges code for access/refresh tokens at /oauth/token
7. App uses access token to call external API endpoints
"""
import io
import logging
import os
import uuid
from datetime import datetime
from typing import Literal, Optional
from urllib.parse import urlencode
from autogpt_libs.auth import get_user_id
from fastapi import APIRouter, Body, HTTPException, Security, UploadFile, status
from gcloud.aio import storage as async_storage
from PIL import Image
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field
from backend.data.auth.oauth import (
InvalidClientError,
InvalidGrantError,
OAuthApplicationInfo,
TokenIntrospectionResult,
consume_authorization_code,
create_access_token,
create_authorization_code,
create_refresh_token,
get_oauth_application,
get_oauth_application_by_id,
introspect_token,
list_user_oauth_applications,
refresh_tokens,
revoke_access_token,
revoke_refresh_token,
update_oauth_application,
validate_client_credentials,
validate_redirect_uri,
validate_scopes,
)
from backend.util.settings import Settings
from backend.util.virus_scanner import scan_content_safe
settings = Settings()
logger = logging.getLogger(__name__)
router = APIRouter()
# ============================================================================
# Request/Response Models
# ============================================================================
class TokenResponse(BaseModel):
"""OAuth 2.0 token response"""
token_type: Literal["Bearer"] = "Bearer"
access_token: str
access_token_expires_at: datetime
refresh_token: str
refresh_token_expires_at: datetime
scopes: list[str]
class ErrorResponse(BaseModel):
"""OAuth 2.0 error response"""
error: str
error_description: Optional[str] = None
class OAuthApplicationPublicInfo(BaseModel):
"""Public information about an OAuth application (for consent screen)"""
name: str
description: Optional[str] = None
logo_url: Optional[str] = None
scopes: list[str]
# ============================================================================
# Application Info Endpoint
# ============================================================================
@router.get(
"/app/{client_id}",
responses={
404: {"description": "Application not found or disabled"},
},
)
async def get_oauth_app_info(
client_id: str, user_id: str = Security(get_user_id)
) -> OAuthApplicationPublicInfo:
"""
Get public information about an OAuth application.
This endpoint is used by the consent screen to display application details
to the user before they authorize access.
Returns:
- name: Application name
- description: Application description (if provided)
- logo_url: Application logo URL (if provided)
- scopes: List of scopes the application is allowed to request
"""
app = await get_oauth_application(client_id)
if not app or not app.is_active:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Application not found",
)
return OAuthApplicationPublicInfo(
name=app.name,
description=app.description,
logo_url=app.logo_url,
scopes=[s.value for s in app.scopes],
)
# ============================================================================
# Authorization Endpoint
# ============================================================================
class AuthorizeRequest(BaseModel):
"""OAuth 2.0 authorization request"""
client_id: str = Field(description="Client identifier")
redirect_uri: str = Field(description="Redirect URI")
scopes: list[str] = Field(description="List of scopes")
state: str = Field(description="Anti-CSRF token from client")
response_type: str = Field(
default="code", description="Must be 'code' for authorization code flow"
)
code_challenge: str = Field(description="PKCE code challenge (required)")
code_challenge_method: Literal["S256", "plain"] = Field(
default="S256", description="PKCE code challenge method (S256 recommended)"
)
class AuthorizeResponse(BaseModel):
"""OAuth 2.0 authorization response with redirect URL"""
redirect_url: str = Field(description="URL to redirect the user to")
@router.post("/authorize")
async def authorize(
request: AuthorizeRequest = Body(),
user_id: str = Security(get_user_id),
) -> AuthorizeResponse:
"""
OAuth 2.0 Authorization Endpoint
User must be logged in (authenticated with Supabase JWT).
This endpoint creates an authorization code and returns a redirect URL.
PKCE (Proof Key for Code Exchange) is REQUIRED for all authorization requests.
The frontend consent screen should call this endpoint after the user approves,
then redirect the user to the returned `redirect_url`.
Request Body:
- client_id: The OAuth application's client ID
- redirect_uri: Where to redirect after authorization (must match registered URI)
- scopes: List of permissions (e.g., "EXECUTE_GRAPH READ_GRAPH")
- state: Anti-CSRF token provided by client (will be returned in redirect)
- response_type: Must be "code" (for authorization code flow)
- code_challenge: PKCE code challenge (required)
- code_challenge_method: "S256" (recommended) or "plain"
Returns:
- redirect_url: The URL to redirect the user to (includes authorization code)
Error cases return a redirect_url with error parameters, or raise HTTPException
for critical errors (like invalid redirect_uri).
"""
try:
# Validate response_type
if request.response_type != "code":
return _error_redirect_url(
request.redirect_uri,
request.state,
"unsupported_response_type",
"Only 'code' response type is supported",
)
# Get application
app = await get_oauth_application(request.client_id)
if not app:
return _error_redirect_url(
request.redirect_uri,
request.state,
"invalid_client",
"Unknown client_id",
)
if not app.is_active:
return _error_redirect_url(
request.redirect_uri,
request.state,
"invalid_client",
"Application is not active",
)
# Validate redirect URI
if not validate_redirect_uri(app, request.redirect_uri):
# For invalid redirect_uri, we can't redirect safely
# Must return error instead
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
"Invalid redirect_uri. "
f"Must be one of: {', '.join(app.redirect_uris)}"
),
)
# Parse and validate scopes
try:
requested_scopes = [APIKeyPermission(s.strip()) for s in request.scopes]
except ValueError as e:
return _error_redirect_url(
request.redirect_uri,
request.state,
"invalid_scope",
f"Invalid scope: {e}",
)
if not requested_scopes:
return _error_redirect_url(
request.redirect_uri,
request.state,
"invalid_scope",
"At least one scope is required",
)
if not validate_scopes(app, requested_scopes):
return _error_redirect_url(
request.redirect_uri,
request.state,
"invalid_scope",
"Application is not authorized for all requested scopes. "
f"Allowed: {', '.join(s.value for s in app.scopes)}",
)
# Create authorization code
auth_code = await create_authorization_code(
application_id=app.id,
user_id=user_id,
scopes=requested_scopes,
redirect_uri=request.redirect_uri,
code_challenge=request.code_challenge,
code_challenge_method=request.code_challenge_method,
)
# Build redirect URL with authorization code
params = {
"code": auth_code.code,
"state": request.state,
}
redirect_url = f"{request.redirect_uri}?{urlencode(params)}"
logger.info(
f"Authorization code issued for user #{user_id} "
f"and app {app.name} (#{app.id})"
)
return AuthorizeResponse(redirect_url=redirect_url)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in authorization endpoint: {e}", exc_info=True)
return _error_redirect_url(
request.redirect_uri,
request.state,
"server_error",
"An unexpected error occurred",
)
def _error_redirect_url(
redirect_uri: str,
state: str,
error: str,
error_description: Optional[str] = None,
) -> AuthorizeResponse:
"""Helper to build redirect URL with OAuth error parameters"""
params = {
"error": error,
"state": state,
}
if error_description:
params["error_description"] = error_description
redirect_url = f"{redirect_uri}?{urlencode(params)}"
return AuthorizeResponse(redirect_url=redirect_url)
# ============================================================================
# Token Endpoint
# ============================================================================
class TokenRequestByCode(BaseModel):
grant_type: Literal["authorization_code"]
code: str = Field(description="Authorization code")
redirect_uri: str = Field(
description="Redirect URI (must match authorization request)"
)
client_id: str
client_secret: str
code_verifier: str = Field(description="PKCE code verifier")
class TokenRequestByRefreshToken(BaseModel):
grant_type: Literal["refresh_token"]
refresh_token: str
client_id: str
client_secret: str
@router.post("/token")
async def token(
request: TokenRequestByCode | TokenRequestByRefreshToken = Body(),
) -> TokenResponse:
"""
OAuth 2.0 Token Endpoint
Exchanges authorization code or refresh token for access token.
Grant Types:
1. authorization_code: Exchange authorization code for tokens
- Required: grant_type, code, redirect_uri, client_id, client_secret
- Optional: code_verifier (required if PKCE was used)
2. refresh_token: Exchange refresh token for new access token
- Required: grant_type, refresh_token, client_id, client_secret
Returns:
- access_token: Bearer token for API access (1 hour TTL)
- token_type: "Bearer"
- access_token_expires_at: Timestamp when the access token expires
- refresh_token: Token for refreshing access (30 days TTL)
- refresh_token_expires_at: Timestamp when the refresh token expires
- scopes: List of scopes
"""
# Validate client credentials
try:
app = await validate_client_credentials(
request.client_id, request.client_secret
)
except InvalidClientError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
)
# Handle authorization_code grant
if request.grant_type == "authorization_code":
# Consume authorization code
try:
user_id, scopes = await consume_authorization_code(
code=request.code,
application_id=app.id,
redirect_uri=request.redirect_uri,
code_verifier=request.code_verifier,
)
except InvalidGrantError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
# Create access and refresh tokens
access_token = await create_access_token(app.id, user_id, scopes)
refresh_token = await create_refresh_token(app.id, user_id, scopes)
logger.info(
f"Access token issued for user #{user_id} and app {app.name} (#{app.id})"
"via authorization code"
)
if not access_token.token or not refresh_token.token:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to generate tokens",
)
return TokenResponse(
token_type="Bearer",
access_token=access_token.token.get_secret_value(),
access_token_expires_at=access_token.expires_at,
refresh_token=refresh_token.token.get_secret_value(),
refresh_token_expires_at=refresh_token.expires_at,
scopes=list(s.value for s in scopes),
)
# Handle refresh_token grant
elif request.grant_type == "refresh_token":
# Refresh access token
try:
new_access_token, new_refresh_token = await refresh_tokens(
request.refresh_token, app.id
)
except InvalidGrantError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
logger.info(
f"Tokens refreshed for user #{new_access_token.user_id} "
f"by app {app.name} (#{app.id})"
)
if not new_access_token.token or not new_refresh_token.token:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to generate tokens",
)
return TokenResponse(
token_type="Bearer",
access_token=new_access_token.token.get_secret_value(),
access_token_expires_at=new_access_token.expires_at,
refresh_token=new_refresh_token.token.get_secret_value(),
refresh_token_expires_at=new_refresh_token.expires_at,
scopes=list(s.value for s in new_access_token.scopes),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported grant_type: {request.grant_type}. "
"Must be 'authorization_code' or 'refresh_token'",
)
# ============================================================================
# Token Introspection Endpoint
# ============================================================================
@router.post("/introspect")
async def introspect(
token: str = Body(description="Token to introspect"),
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
None, description="Hint about token type ('access_token' or 'refresh_token')"
),
client_id: str = Body(description="Client identifier"),
client_secret: str = Body(description="Client secret"),
) -> TokenIntrospectionResult:
"""
OAuth 2.0 Token Introspection Endpoint (RFC 7662)
Allows clients to check if a token is valid and get its metadata.
Returns:
- active: Whether the token is currently active
- scopes: List of authorized scopes (if active)
- client_id: The client the token was issued to (if active)
- user_id: The user the token represents (if active)
- exp: Expiration timestamp (if active)
- token_type: "access_token" or "refresh_token" (if active)
"""
# Validate client credentials
try:
await validate_client_credentials(client_id, client_secret)
except InvalidClientError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
)
# Introspect the token
return await introspect_token(token, token_type_hint)
# ============================================================================
# Token Revocation Endpoint
# ============================================================================
@router.post("/revoke")
async def revoke(
token: str = Body(description="Token to revoke"),
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
None, description="Hint about token type ('access_token' or 'refresh_token')"
),
client_id: str = Body(description="Client identifier"),
client_secret: str = Body(description="Client secret"),
):
"""
OAuth 2.0 Token Revocation Endpoint (RFC 7009)
Allows clients to revoke an access or refresh token.
Note: Revoking a refresh token does NOT revoke associated access tokens.
Revoking an access token does NOT revoke the associated refresh token.
"""
# Validate client credentials
try:
app = await validate_client_credentials(client_id, client_secret)
except InvalidClientError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
)
# Try to revoke as access token first
# Note: We pass app.id to ensure the token belongs to the authenticated app
if token_type_hint != "refresh_token":
revoked = await revoke_access_token(token, app.id)
if revoked:
logger.info(
f"Access token revoked for app {app.name} (#{app.id}); "
f"user #{revoked.user_id}"
)
return {"status": "ok"}
# Try to revoke as refresh token
revoked = await revoke_refresh_token(token, app.id)
if revoked:
logger.info(
f"Refresh token revoked for app {app.name} (#{app.id}); "
f"user #{revoked.user_id}"
)
return {"status": "ok"}
# Per RFC 7009, revocation endpoint returns 200 even if token not found
# or if token belongs to a different application.
# This prevents token scanning attacks.
logger.warning(f"Unsuccessful token revocation attempt by app {app.name} #{app.id}")
return {"status": "ok"}
# ============================================================================
# Application Management Endpoints (for app owners)
# ============================================================================
@router.get("/apps/mine")
async def list_my_oauth_apps(
user_id: str = Security(get_user_id),
) -> list[OAuthApplicationInfo]:
"""
List all OAuth applications owned by the current user.
Returns a list of OAuth applications with their details including:
- id, name, description, logo_url
- client_id (public identifier)
- redirect_uris, grant_types, scopes
- is_active status
- created_at, updated_at timestamps
Note: client_secret is never returned for security reasons.
"""
return await list_user_oauth_applications(user_id)
@router.patch("/apps/{app_id}/status")
async def update_app_status(
app_id: str,
user_id: str = Security(get_user_id),
is_active: bool = Body(description="Whether the app should be active", embed=True),
) -> OAuthApplicationInfo:
"""
Enable or disable an OAuth application.
Only the application owner can update the status.
When disabled, the application cannot be used for new authorizations
and existing access tokens will fail validation.
Returns the updated application info.
"""
updated_app = await update_oauth_application(
app_id=app_id,
owner_id=user_id,
is_active=is_active,
)
if not updated_app:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Application not found or you don't have permission to update it",
)
action = "enabled" if is_active else "disabled"
logger.info(f"OAuth app {updated_app.name} (#{app_id}) {action} by user #{user_id}")
return updated_app
class UpdateAppLogoRequest(BaseModel):
logo_url: str = Field(description="URL of the uploaded logo image")
@router.patch("/apps/{app_id}/logo")
async def update_app_logo(
app_id: str,
request: UpdateAppLogoRequest = Body(),
user_id: str = Security(get_user_id),
) -> OAuthApplicationInfo:
"""
Update the logo URL for an OAuth application.
Only the application owner can update the logo.
The logo should be uploaded first using the media upload endpoint,
then this endpoint is called with the resulting URL.
Logo requirements:
- Must be square (1:1 aspect ratio)
- Minimum 512x512 pixels
- Maximum 2048x2048 pixels
Returns the updated application info.
"""
if (
not (app := await get_oauth_application_by_id(app_id))
or app.owner_id != user_id
):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth App not found",
)
# Delete the current app logo file (if any and it's in our cloud storage)
await _delete_app_current_logo_file(app)
updated_app = await update_oauth_application(
app_id=app_id,
owner_id=user_id,
logo_url=request.logo_url,
)
if not updated_app:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Application not found or you don't have permission to update it",
)
logger.info(
f"OAuth app {updated_app.name} (#{app_id}) logo updated by user #{user_id}"
)
return updated_app
# Logo upload constraints
LOGO_MIN_SIZE = 512
LOGO_MAX_SIZE = 2048
LOGO_ALLOWED_TYPES = {"image/jpeg", "image/png", "image/webp"}
LOGO_MAX_FILE_SIZE = 3 * 1024 * 1024 # 3MB
@router.post("/apps/{app_id}/logo/upload")
async def upload_app_logo(
app_id: str,
file: UploadFile,
user_id: str = Security(get_user_id),
) -> OAuthApplicationInfo:
"""
Upload a logo image for an OAuth application.
Requirements:
- Image must be square (1:1 aspect ratio)
- Minimum 512x512 pixels
- Maximum 2048x2048 pixels
- Allowed formats: JPEG, PNG, WebP
- Maximum file size: 3MB
The image is uploaded to cloud storage and the app's logoUrl is updated.
Returns the updated application info.
"""
# Verify ownership to reduce vulnerability to DoS(torage) or DoM(oney) attacks
if (
not (app := await get_oauth_application_by_id(app_id))
or app.owner_id != user_id
):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth App not found",
)
# Check GCS configuration
if not settings.config.media_gcs_bucket_name:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Media storage is not configured",
)
# Validate content type
content_type = file.content_type
if content_type not in LOGO_ALLOWED_TYPES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid file type. Allowed: JPEG, PNG, WebP. Got: {content_type}",
)
# Read file content
try:
file_bytes = await file.read()
except Exception as e:
logger.error(f"Error reading logo file: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to read uploaded file",
)
# Check file size
if len(file_bytes) > LOGO_MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
"File too large. "
f"Maximum size is {LOGO_MAX_FILE_SIZE // 1024 // 1024}MB"
),
)
# Validate image dimensions
try:
image = Image.open(io.BytesIO(file_bytes))
width, height = image.size
if width != height:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Logo must be square. Got {width}x{height}",
)
if width < LOGO_MIN_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Logo too small. Minimum {LOGO_MIN_SIZE}x{LOGO_MIN_SIZE}. "
f"Got {width}x{height}",
)
if width > LOGO_MAX_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Logo too large. Maximum {LOGO_MAX_SIZE}x{LOGO_MAX_SIZE}. "
f"Got {width}x{height}",
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error validating logo image: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid image file",
)
# Scan for viruses
filename = file.filename or "logo"
await scan_content_safe(file_bytes, filename=filename)
# Generate unique filename
file_ext = os.path.splitext(filename)[1].lower() or ".png"
unique_filename = f"{uuid.uuid4()}{file_ext}"
storage_path = f"oauth-apps/{app_id}/logo/{unique_filename}"
# Upload to GCS
try:
async with async_storage.Storage() as async_client:
bucket_name = settings.config.media_gcs_bucket_name
await async_client.upload(
bucket_name, storage_path, file_bytes, content_type=content_type
)
logo_url = f"https://storage.googleapis.com/{bucket_name}/{storage_path}"
except Exception as e:
logger.error(f"Error uploading logo to GCS: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to upload logo",
)
# Delete the current app logo file (if any and it's in our cloud storage)
await _delete_app_current_logo_file(app)
# Update the app with the new logo URL
updated_app = await update_oauth_application(
app_id=app_id,
owner_id=user_id,
logo_url=logo_url,
)
if not updated_app:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Application not found or you don't have permission to update it",
)
logger.info(
f"OAuth app {updated_app.name} (#{app_id}) logo uploaded by user #{user_id}"
)
return updated_app
async def _delete_app_current_logo_file(app: OAuthApplicationInfo):
"""
Delete the current logo file for the given app, if there is one in our cloud storage
"""
bucket_name = settings.config.media_gcs_bucket_name
storage_base_url = f"https://storage.googleapis.com/{bucket_name}/"
if app.logo_url and app.logo_url.startswith(storage_base_url):
# Parse blob path from URL: https://storage.googleapis.com/{bucket}/{path}
old_path = app.logo_url.replace(storage_base_url, "")
try:
async with async_storage.Storage() as async_client:
await async_client.delete(bucket_name, old_path)
logger.info(f"Deleted old logo for OAuth app #{app.id}: {old_path}")
except Exception as e:
# Log but don't fail - the new logo was uploaded successfully
logger.warning(
f"Failed to delete old logo for OAuth app #{app.id}: {e}", exc_info=e
)
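To make the removed flow concrete: a minimal client-side sketch of the authorization-code-plus-PKCE exchange against these endpoints. The host, client ID, and secret are placeholders, and `httpx` is just one possible HTTP client; the request fields mirror the `AuthorizeRequest` and `TokenRequestByCode`/`TokenRequestByRefreshToken` models above.

```python
import base64
import hashlib
import secrets

import httpx

BASE = "https://platform.example.com/api/oauth"  # hypothetical deployment URL

# PKCE: generate a random verifier, derive the S256 challenge from it
code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode()
code_challenge = (
    base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
    .rstrip(b"=")
    .decode()
)

# Step 1: the consent screen POSTs this after the user approves; the response
# contains a redirect_url carrying ?code=...&state=...
authorize_request = {
    "client_id": "my-client-id",                      # placeholder registration
    "redirect_uri": "https://app.example.com/callback",
    "scopes": ["EXECUTE_GRAPH", "READ_GRAPH"],
    "state": secrets.token_urlsafe(16),               # anti-CSRF, echoed back
    "response_type": "code",
    "code_challenge": code_challenge,
    "code_challenge_method": "S256",
}


def exchange_code(code: str) -> dict:
    """Step 2: trade the authorization code for access + refresh tokens."""
    resp = httpx.post(
        f"{BASE}/token",
        json={
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": authorize_request["redirect_uri"],
            "client_id": "my-client-id",
            "client_secret": "my-client-secret",      # placeholder secret
            "code_verifier": code_verifier,           # proves PKCE possession
        },
    )
    resp.raise_for_status()
    return resp.json()  # access/refresh tokens, expiry timestamps, scopes


def refresh(refresh_token: str) -> dict:
    """Rotate tokens before the access token's one-hour TTL runs out."""
    resp = httpx.post(
        f"{BASE}/token",
        json={
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": "my-client-id",
            "client_secret": "my-client-secret",
        },
    )
    resp.raise_for_status()
    return resp.json()
```

`/introspect` and `/revoke` take the same `client_id`/`client_secret` pair in the request body, so one client object can cover the whole token lifecycle.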

File diff suppressed because it is too large

View File

@@ -56,7 +56,7 @@ async def postmark_webhook_handler(
webhook: Annotated[
PostmarkWebhook,
Body(discriminator="RecordType"),
],
]
):
logger.info(f"Received webhook from Postmark: {webhook}")
match webhook:

View File

@@ -5,7 +5,7 @@ import time
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from typing import Annotated, Any, Sequence, get_args
from typing import Annotated, Any, Sequence
import pydantic
import stripe
@@ -31,9 +31,9 @@ from typing_extensions import Optional, TypedDict
import backend.server.integrations.router
import backend.server.routers.analytics
import backend.server.v2.library.db as library_db
from backend.data import api_key as api_key_db
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.auth import api_key as api_key_db
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
from backend.data.credit import (
AutoTopUpConfig,
@@ -45,17 +45,12 @@ from backend.data.credit import (
set_auto_top_up,
)
from backend.data.graph import GraphSettings
from backend.data.model import CredentialsMetaInput, UserOnboarding
from backend.data.model import CredentialsMetaInput
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.onboarding import (
FrontendOnboardingStep,
OnboardingStep,
UserOnboardingUpdate,
complete_onboarding_step,
complete_re_run_agent,
get_recommended_agents,
get_user_onboarding,
increment_runs,
onboarding_enabled,
reset_user_onboarding,
update_user_onboarding,
@@ -83,7 +78,6 @@ from backend.server.model import (
CreateAPIKeyRequest,
CreateAPIKeyResponse,
CreateGraph,
GraphExecutionSource,
RequestTopUp,
SetGraphActiveVersion,
TimezoneResponse,
@@ -91,7 +85,6 @@ from backend.server.model import (
UpdateTimezoneRequest,
UploadFileResponse,
)
from backend.server.v2.store.model import StoreAgentDetails
from backend.util.cache import cached
from backend.util.clients import get_scheduler_client
from backend.util.cloud_storage import get_cloud_storage_handler
@@ -281,10 +274,9 @@ async def update_preferences(
@v1_router.get(
"/onboarding",
summary="Onboarding state",
summary="Get onboarding status",
tags=["onboarding"],
dependencies=[Security(requires_user)],
response_model=UserOnboarding,
)
async def get_onboarding(user_id: Annotated[str, Security(get_user_id)]):
return await get_user_onboarding(user_id)
@@ -292,10 +284,9 @@ async def get_onboarding(user_id: Annotated[str, Security(get_user_id)]):
@v1_router.patch(
"/onboarding",
summary="Update onboarding state",
summary="Update onboarding progress",
tags=["onboarding"],
dependencies=[Security(requires_user)],
response_model=UserOnboarding,
)
async def update_onboarding(
user_id: Annotated[str, Security(get_user_id)], data: UserOnboardingUpdate
@@ -303,39 +294,25 @@ async def update_onboarding(
return await update_user_onboarding(user_id, data)
@v1_router.post(
"/onboarding/step",
summary="Complete onboarding step",
tags=["onboarding"],
dependencies=[Security(requires_user)],
)
async def onboarding_complete_step(
user_id: Annotated[str, Security(get_user_id)], step: FrontendOnboardingStep
):
if step not in get_args(FrontendOnboardingStep):
raise HTTPException(status_code=400, detail="Invalid onboarding step")
return await complete_onboarding_step(user_id, step)
@v1_router.get(
"/onboarding/agents",
summary="Recommended onboarding agents",
summary="Get recommended agents",
tags=["onboarding"],
dependencies=[Security(requires_user)],
)
async def get_onboarding_agents(
user_id: Annotated[str, Security(get_user_id)],
) -> list[StoreAgentDetails]:
):
return await get_recommended_agents(user_id)
@v1_router.get(
"/onboarding/enabled",
summary="Is onboarding enabled",
summary="Check onboarding enabled",
tags=["onboarding", "public"],
dependencies=[Security(requires_user)],
)
async def is_onboarding_enabled() -> bool:
async def is_onboarding_enabled():
return await onboarding_enabled()
@@ -344,7 +321,6 @@ async def is_onboarding_enabled() -> bool:
summary="Reset onboarding progress",
tags=["onboarding"],
dependencies=[Security(requires_user)],
response_model=UserOnboarding,
)
async def reset_onboarding(user_id: Annotated[str, Security(get_user_id)]):
return await reset_user_onboarding(user_id)
@@ -833,12 +809,7 @@ async def create_new_graph(
# as the graph already valid and no sub-graphs are returned back.
await graph_db.create_graph(graph, user_id=user_id)
await library_db.create_library_agent(graph, user_id=user_id)
activated_graph = await on_graph_activate(graph, user_id=user_id)
if create_graph.source == "builder":
await complete_onboarding_step(user_id, OnboardingStep.BUILDER_SAVE_AGENT)
return activated_graph
return await on_graph_activate(graph, user_id=user_id)
@v1_router.delete(
@@ -996,7 +967,6 @@ async def execute_graph(
credentials_inputs: Annotated[
dict[str, CredentialsMetaInput], Body(..., embed=True, default_factory=dict)
],
source: Annotated[GraphExecutionSource | None, Body(embed=True)] = None,
graph_version: Optional[int] = None,
preset_id: Optional[str] = None,
) -> execution_db.GraphExecutionMeta:
@@ -1020,14 +990,6 @@ async def execute_graph(
# Record successful graph execution
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
record_graph_operation(operation="execute", status="success")
await increment_runs(user_id)
await complete_re_run_agent(user_id, graph_id)
if source == "library":
await complete_onboarding_step(
user_id, OnboardingStep.MARKETPLACE_RUN_AGENT
)
elif source == "builder":
await complete_onboarding_step(user_id, OnboardingStep.BUILDER_RUN_AGENT)
return result
except GraphValidationError as e:
# Record failed graph execution
@@ -1141,15 +1103,6 @@ async def list_graph_executions(
filtered_executions = await hide_activity_summaries_if_disabled(
paginated_result.executions, user_id
)
onboarding = await get_user_onboarding(user_id)
if (
onboarding.onboardingAgentExecutionId
and onboarding.onboardingAgentExecutionId
in [exec.id for exec in filtered_executions]
and OnboardingStep.GET_RESULTS not in onboarding.completedSteps
):
await complete_onboarding_step(user_id, OnboardingStep.GET_RESULTS)
return execution_db.GraphExecutionsPaginated(
executions=filtered_executions, pagination=paginated_result.pagination
)
@@ -1187,12 +1140,6 @@ async def get_graph_execution(
# Apply feature flags to filter out disabled features
result = await hide_activity_summary_if_disabled(result, user_id)
onboarding = await get_user_onboarding(user_id)
if (
onboarding.onboardingAgentExecutionId == graph_exec_id
and OnboardingStep.GET_RESULTS not in onboarding.completedSteps
):
await complete_onboarding_step(user_id, OnboardingStep.GET_RESULTS)
return result
@@ -1369,8 +1316,6 @@ async def create_graph_execution_schedule(
result.next_run_time, user_timezone
)
await complete_onboarding_step(user_id, OnboardingStep.SCHEDULE_AGENT)
return result

View File

@@ -522,8 +522,8 @@ async def test_api_keys_with_newline_variations(mock_request):
"valid\r\ntoken", # Windows newline
"valid\rtoken", # Mac newline
"valid\x85token", # NEL (Next Line)
"valid\x0btoken", # Vertical Tab
"valid\x0ctoken", # Form Feed
"valid\x0Btoken", # Vertical Tab
"valid\x0Ctoken", # Form Feed
]
for api_key in newline_variations:

View File

@@ -8,10 +8,6 @@ from fastapi import APIRouter, HTTPException, Security
from pydantic import BaseModel, Field
from backend.blocks.llm import LlmModel
from backend.data.analytics import (
AccuracyTrendsResponse,
get_accuracy_trends_and_alerts,
)
from backend.data.execution import (
ExecutionStatus,
GraphExecutionMeta,
@@ -87,18 +83,6 @@ class ExecutionAnalyticsConfig(BaseModel):
recommended_model: str
class AccuracyTrendsRequest(BaseModel):
graph_id: str = Field(..., description="Graph ID to analyze", min_length=1)
user_id: Optional[str] = Field(None, description="Optional user ID filter")
days_back: int = Field(30, description="Number of days to look back", ge=7, le=90)
drop_threshold: float = Field(
10.0, description="Alert threshold percentage", ge=1.0, le=50.0
)
include_historical: bool = Field(
False, description="Include historical data for charts"
)
router = APIRouter(
prefix="/admin",
tags=["admin", "execution_analytics"],
@@ -435,40 +419,3 @@ async def _process_batch(
return await asyncio.gather(
*[process_single_execution(execution) for execution in executions]
)
@router.get(
"/execution_accuracy_trends",
response_model=AccuracyTrendsResponse,
summary="Get Execution Accuracy Trends and Alerts",
)
async def get_execution_accuracy_trends(
graph_id: str,
user_id: Optional[str] = None,
days_back: int = 30,
drop_threshold: float = 10.0,
include_historical: bool = False,
admin_user_id: str = Security(get_user_id),
) -> AccuracyTrendsResponse:
"""
Get execution accuracy trends with moving averages and alert detection.
Simple single-query approach.
"""
logger.info(
f"Admin user {admin_user_id} requesting accuracy trends for graph {graph_id}"
)
try:
result = await get_accuracy_trends_and_alerts(
graph_id=graph_id,
days_back=days_back,
user_id=user_id,
drop_threshold=drop_threshold,
include_historical=include_historical,
)
return result
except Exception as e:
logger.exception(f"Error getting accuracy trends for graph {graph_id}: {e}")
raise HTTPException(status_code=500, detail=str(e))
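The `_process_batch` helper above fans the per-execution work out with `asyncio.gather`, which runs the coroutines concurrently and returns their results in input order. A minimal sketch of that pattern, with a sleep standing in for the real per-execution analysis:

```python
import asyncio


async def process_single_execution(execution_id: int) -> int:
    await asyncio.sleep(0.01)  # stand-in for the real analysis work
    return execution_id * 2


async def _process_batch(executions: list[int]) -> list[int]:
    # gather schedules every coroutine concurrently; results keep input order
    return await asyncio.gather(
        *[process_single_execution(execution) for execution in executions]
    )


print(asyncio.run(_process_batch([1, 2, 3])))  # [2, 4, 6]
```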

View File

@@ -1,16 +1,9 @@
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from difflib import SequenceMatcher
from typing import Sequence
import prisma
import backend.data.block
import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model
import backend.server.v2.store.db as store_db
import backend.server.v2.store.model as store_model
from backend.blocks import load_all_blocks
from backend.blocks.llm import LlmModel
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
@@ -21,36 +14,17 @@ from backend.server.v2.builder.model import (
BlockResponse,
BlockType,
CountResponse,
FilterType,
Provider,
ProviderResponse,
SearchEntry,
SearchBlocksResponse,
)
from backend.util.cache import cached
from backend.util.models import Pagination
logger = logging.getLogger(__name__)
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
MAX_LIBRARY_AGENT_RESULTS = 100
MAX_MARKETPLACE_AGENT_RESULTS = 100
MIN_SCORE_FOR_FILTERED_RESULTS = 10.0
SearchResultItem = BlockInfo | library_model.LibraryAgent | store_model.StoreAgent
@dataclass
class _ScoredItem:
item: SearchResultItem
filter_type: FilterType
score: float
sort_key: str
@dataclass
class _SearchCacheEntry:
items: list[SearchResultItem]
total_items: dict[FilterType, int]
_static_counts_cache: dict | None = None
_suggested_blocks: list[BlockInfo] | None = None
def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse]:
@@ -156,244 +130,71 @@ def get_block_by_id(block_id: str) -> BlockInfo | None:
return None
async def update_search(user_id: str, search: SearchEntry) -> str:
def search_blocks(
include_blocks: bool = True,
include_integrations: bool = True,
query: str = "",
page: int = 1,
page_size: int = 50,
) -> SearchBlocksResponse:
"""
Upsert a search request for the user and return the search ID.
Get blocks based on the filter and query.
`providers` only applies for `integrations` filter.
"""
if search.search_id:
# Update existing search
await prisma.models.BuilderSearchHistory.prisma().update(
where={
"id": search.search_id,
},
data={
"searchQuery": search.search_query or "",
"filter": search.filter or [], # type: ignore
"byCreator": search.by_creator or [],
},
)
return search.search_id
else:
# Create new search
new_search = await prisma.models.BuilderSearchHistory.prisma().create(
data={
"userId": user_id,
"searchQuery": search.search_query or "",
"filter": search.filter or [], # type: ignore
"byCreator": search.by_creator or [],
}
)
return new_search.id
blocks: list[AnyBlockSchema] = []
query = query.lower()
async def get_recent_searches(user_id: str, limit: int = 5) -> list[SearchEntry]:
"""
Get the user's most recent search requests.
"""
searches = await prisma.models.BuilderSearchHistory.prisma().find_many(
where={
"userId": user_id,
},
order={
"updatedAt": "desc",
},
take=limit,
)
return [
SearchEntry(
search_query=s.searchQuery,
filter=s.filter, # type: ignore
by_creator=s.byCreator,
search_id=s.id,
)
for s in searches
]
async def get_sorted_search_results(
*,
user_id: str,
search_query: str | None,
filters: Sequence[FilterType],
by_creator: Sequence[str] | None = None,
) -> _SearchCacheEntry:
normalized_filters: tuple[FilterType, ...] = tuple(sorted(set(filters or [])))
normalized_creators: tuple[str, ...] = tuple(sorted(set(by_creator or [])))
return await _build_cached_search_results(
user_id=user_id,
search_query=search_query or "",
filters=normalized_filters,
by_creator=normalized_creators,
)
@cached(ttl_seconds=300, shared_cache=True)
async def _build_cached_search_results(
user_id: str,
search_query: str,
filters: tuple[FilterType, ...],
by_creator: tuple[str, ...],
) -> _SearchCacheEntry:
normalized_query = (search_query or "").strip().lower()
include_blocks = "blocks" in filters
include_integrations = "integrations" in filters
include_library_agents = "my_agents" in filters
include_marketplace_agents = "marketplace_agents" in filters
scored_items: list[_ScoredItem] = []
total_items: dict[FilterType, int] = {
"blocks": 0,
"integrations": 0,
"marketplace_agents": 0,
"my_agents": 0,
}
block_results, block_total, integration_total = _collect_block_results(
normalized_query=normalized_query,
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
if include_library_agents:
library_response = await library_db.list_library_agents(
user_id=user_id,
search_term=search_query or None,
page=1,
page_size=MAX_LIBRARY_AGENT_RESULTS,
)
total_items["my_agents"] = library_response.pagination.total_items
scored_items.extend(
_build_library_items(
agents=library_response.agents,
normalized_query=normalized_query,
)
)
if include_marketplace_agents:
marketplace_response = await store_db.get_store_agents(
creators=list(by_creator) or None,
search_query=search_query or None,
page=1,
page_size=MAX_MARKETPLACE_AGENT_RESULTS,
)
total_items["marketplace_agents"] = marketplace_response.pagination.total_items
scored_items.extend(
_build_marketplace_items(
agents=marketplace_response.agents,
normalized_query=normalized_query,
)
)
sorted_items = sorted(
scored_items,
key=lambda entry: (-entry.score, entry.sort_key, entry.filter_type),
)
return _SearchCacheEntry(
items=[entry.item for entry in sorted_items],
total_items=total_items,
)
def _collect_block_results(
*,
normalized_query: str,
include_blocks: bool,
include_integrations: bool,
) -> tuple[list[_ScoredItem], int, int]:
results: list[_ScoredItem] = []
total = 0
skip = (page - 1) * page_size
take = page_size
block_count = 0
integration_count = 0
if not include_blocks and not include_integrations:
return results, block_count, integration_count
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
# Skip disabled blocks
if block.disabled:
continue
block_info = block.get_info()
# Skip blocks that don't match the query
if (
query not in block.name.lower()
and query not in block.description.lower()
and not _matches_llm_model(block.input_schema, query)
):
continue
keep = False
credentials = list(block.input_schema.get_credentials_fields().values())
is_integration = len(credentials) > 0
if is_integration and not include_integrations:
continue
if not is_integration and not include_blocks:
continue
score = _score_block(block, block_info, normalized_query)
if not _should_include_item(score, normalized_query):
continue
filter_type: FilterType = "integrations" if is_integration else "blocks"
if is_integration:
if include_integrations and len(credentials) > 0:
keep = True
integration_count += 1
else:
if include_blocks and len(credentials) == 0:
keep = True
block_count += 1
results.append(
_ScoredItem(
item=block_info,
filter_type=filter_type,
score=score,
sort_key=_get_item_name(block_info),
)
)
return results, block_count, integration_count
def _build_library_items(
*,
agents: list[library_model.LibraryAgent],
normalized_query: str,
) -> list[_ScoredItem]:
results: list[_ScoredItem] = []
for agent in agents:
score = _score_library_agent(agent, normalized_query)
if not _should_include_item(score, normalized_query):
if not keep:
continue
results.append(
_ScoredItem(
item=agent,
filter_type="my_agents",
score=score,
sort_key=_get_item_name(agent),
)
)
return results
def _build_marketplace_items(
*,
agents: list[store_model.StoreAgent],
normalized_query: str,
) -> list[_ScoredItem]:
results: list[_ScoredItem] = []
for agent in agents:
score = _score_store_agent(agent, normalized_query)
if not _should_include_item(score, normalized_query):
total += 1
if skip > 0:
skip -= 1
continue
if take > 0:
take -= 1
blocks.append(block)
results.append(
_ScoredItem(
item=agent,
filter_type="marketplace_agents",
score=score,
sort_key=_get_item_name(agent),
)
)
return results
return SearchBlocksResponse(
blocks=BlockResponse(
blocks=[b.get_info() for b in blocks],
pagination=Pagination(
total_items=total,
total_pages=(total + page_size - 1) // page_size,
current_page=page,
page_size=page_size,
),
),
total_block_count=block_count,
total_integration_count=integration_count,
)
def get_providers(
@@ -450,12 +251,16 @@ async def get_counts(user_id: str) -> CountResponse:
)
@cached(ttl_seconds=3600)
async def _get_static_counts():
"""
Get counts of blocks, integrations, and marketplace agents.
This is cached to avoid unnecessary database queries and calculations.
Can't use functools.cache here because the function is async.
"""
global _static_counts_cache
if _static_counts_cache is not None:
return _static_counts_cache
all_blocks = 0
input_blocks = 0
action_blocks = 0
@@ -482,7 +287,7 @@ async def _get_static_counts():
marketplace_agents = await prisma.models.StoreAgent.prisma().count()
return {
_static_counts_cache = {
"all_blocks": all_blocks,
"input_blocks": input_blocks,
"action_blocks": action_blocks,
@@ -491,6 +296,8 @@ async def _get_static_counts():
"marketplace_agents": marketplace_agents,
}
return _static_counts_cache
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
for field in schema_cls.model_fields.values():
@@ -501,123 +308,6 @@ def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
return False
def _score_block(
block: AnyBlockSchema,
block_info: BlockInfo,
normalized_query: str,
) -> float:
if not normalized_query:
return 0.0
name = block_info.name.lower()
description = block_info.description.lower()
score = _score_primary_fields(name, description, normalized_query)
category_text = " ".join(
category.get("category", "").lower() for category in block_info.categories
)
score += _score_additional_field(category_text, normalized_query, 12, 6)
credentials_info = block.input_schema.get_credentials_fields_info().values()
provider_names = [
provider.value.lower()
for info in credentials_info
for provider in info.provider
]
provider_text = " ".join(provider_names)
score += _score_additional_field(provider_text, normalized_query, 15, 6)
if _matches_llm_model(block.input_schema, normalized_query):
score += 20
return score
def _score_library_agent(
agent: library_model.LibraryAgent,
normalized_query: str,
) -> float:
if not normalized_query:
return 0.0
name = agent.name.lower()
description = (agent.description or "").lower()
instructions = (agent.instructions or "").lower()
score = _score_primary_fields(name, description, normalized_query)
score += _score_additional_field(instructions, normalized_query, 15, 6)
score += _score_additional_field(
agent.creator_name.lower(), normalized_query, 10, 5
)
return score
def _score_store_agent(
agent: store_model.StoreAgent,
normalized_query: str,
) -> float:
if not normalized_query:
return 0.0
name = agent.agent_name.lower()
description = agent.description.lower()
sub_heading = agent.sub_heading.lower()
score = _score_primary_fields(name, description, normalized_query)
score += _score_additional_field(sub_heading, normalized_query, 12, 6)
score += _score_additional_field(agent.creator.lower(), normalized_query, 10, 5)
return score
def _score_primary_fields(name: str, description: str, query: str) -> float:
score = 0.0
if name == query:
score += 120
elif name.startswith(query):
score += 90
elif query in name:
score += 60
score += SequenceMatcher(None, name, query).ratio() * 50
if description:
if query in description:
score += 30
score += SequenceMatcher(None, description, query).ratio() * 25
return score
def _score_additional_field(
value: str,
query: str,
contains_weight: float,
similarity_weight: float,
) -> float:
if not value or not query:
return 0.0
score = 0.0
if query in value:
score += contains_weight
score += SequenceMatcher(None, value, query).ratio() * similarity_weight
return score
def _should_include_item(score: float, normalized_query: str) -> bool:
if not normalized_query:
return True
return score >= MIN_SCORE_FOR_FILTERED_RESULTS
def _get_item_name(item: SearchResultItem) -> str:
if isinstance(item, BlockInfo):
return item.name.lower()
if isinstance(item, library_model.LibraryAgent):
return item.name.lower()
return item.agent_name.lower()
@cached(ttl_seconds=3600)
def _get_all_providers() -> dict[ProviderName, Provider]:
providers: dict[ProviderName, Provider] = {}
@@ -639,9 +329,13 @@ def _get_all_providers() -> dict[ProviderName, Provider]:
return providers
@cached(ttl_seconds=3600)
async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
suggested_blocks = []
global _suggested_blocks
if _suggested_blocks is not None and len(_suggested_blocks) >= count:
return _suggested_blocks[:count]
_suggested_blocks = []
# Sum the number of executions for each block type
# Prisma cannot group by nested relations, so we do a raw query
# Calculate the cutoff timestamp
@@ -682,7 +376,7 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
# Sort blocks by execution count
blocks.sort(key=lambda x: x[1], reverse=True)
suggested_blocks = [block[0] for block in blocks]
_suggested_blocks = [block[0] for block in blocks]
# Return the top blocks
return suggested_blocks[:count]
return _suggested_blocks[:count]
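The `skip`/`take` counters in `search_blocks` above implement single-pass pagination: every match increments the total, but only the current page's slice is collected. The same pattern in isolation, with a hypothetical `keep` predicate standing in for the query and filter checks:

```python
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")


def paginate(
    items: Iterable[T], keep: Callable[[T], bool], page: int, page_size: int
) -> tuple[list[T], int]:
    """Count every match, but only collect the requested page's slice."""
    skip = (page - 1) * page_size
    take = page_size
    total = 0
    page_items: list[T] = []
    for item in items:
        if not keep(item):
            continue
        total += 1       # total counts all matches, not just this page
        if skip > 0:
            skip -= 1    # consume items that belong to earlier pages
            continue
        if take > 0:
            take -= 1
            page_items.append(item)
    return page_items, total


# Page 2 of the even numbers in 0..20, three per page:
print(paginate(range(21), lambda n: n % 2 == 0, page=2, page_size=3))
# -> ([6, 8, 10], 11)
```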

View File

@@ -18,17 +18,10 @@ FilterType = Literal[
BlockType = Literal["all", "input", "action", "output"]
class SearchEntry(BaseModel):
search_query: str | None = None
filter: list[FilterType] | None = None
by_creator: list[str] | None = None
search_id: str | None = None
# Suggestions
class SuggestionsResponse(BaseModel):
otto_suggestions: list[str]
recent_searches: list[SearchEntry]
recent_searches: list[str]
providers: list[ProviderName]
top_blocks: list[BlockInfo]
@@ -39,7 +32,7 @@ class BlockCategoryResponse(BaseModel):
total_blocks: int
blocks: list[BlockInfo]
model_config = {"use_enum_values": False} # Use enum names like "AI"
model_config = {"use_enum_values": False} # <== use enum names like "AI"
# Input/Action/Output and see all for block categories
@@ -60,11 +53,17 @@ class ProviderResponse(BaseModel):
pagination: Pagination
class SearchBlocksResponse(BaseModel):
blocks: BlockResponse
total_block_count: int
total_integration_count: int
class SearchResponse(BaseModel):
items: list[BlockInfo | library_model.LibraryAgent | store_model.StoreAgent]
search_id: str
total_items: dict[FilterType, int]
pagination: Pagination
page: int
more_pages: bool
class CountResponse(BaseModel):

View File

@@ -6,6 +6,10 @@ from autogpt_libs.auth.dependencies import get_user_id, requires_user
import backend.server.v2.builder.db as builder_db
import backend.server.v2.builder.model as builder_model
import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model
import backend.server.v2.store.db as store_db
import backend.server.v2.store.model as store_model
from backend.integrations.providers import ProviderName
from backend.util.models import Pagination
@@ -41,9 +45,7 @@ def sanitize_query(query: str | None) -> str | None:
summary="Get Builder suggestions",
response_model=builder_model.SuggestionsResponse,
)
async def get_suggestions(
user_id: Annotated[str, fastapi.Security(get_user_id)],
) -> builder_model.SuggestionsResponse:
async def get_suggestions() -> builder_model.SuggestionsResponse:
"""
Get all suggestions for the Blocks Menu.
"""
@@ -53,7 +55,11 @@ async def get_suggestions(
"Help me create a list",
"Help me feed my data to Google Maps",
],
recent_searches=await builder_db.get_recent_searches(user_id),
recent_searches=[
"image generation",
"deepfake",
"competitor analysis",
],
providers=[
ProviderName.TWITTER,
ProviderName.GITHUB,
@@ -141,6 +147,7 @@ async def get_providers(
)
# Not using the POST method because orval on the frontend doesn't support Infinite Query with POST.
@router.get(
"/search",
summary="Builder search",
@@ -150,7 +157,7 @@ async def get_providers(
async def search(
user_id: Annotated[str, fastapi.Security(get_user_id)],
search_query: Annotated[str | None, fastapi.Query()] = None,
filter: Annotated[list[builder_model.FilterType] | None, fastapi.Query()] = None,
filter: Annotated[list[str] | None, fastapi.Query()] = None,
search_id: Annotated[str | None, fastapi.Query()] = None,
by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
page: Annotated[int, fastapi.Query()] = 1,
@@ -169,43 +176,69 @@ async def search(
]
search_query = sanitize_query(search_query)
# Get all possible results
cached_results = await builder_db.get_sorted_search_results(
user_id=user_id,
search_query=search_query,
filters=filter,
by_creator=by_creator,
)
# Paginate results
total_combined_items = len(cached_results.items)
pagination = Pagination(
total_items=total_combined_items,
total_pages=(total_combined_items + page_size - 1) // page_size,
current_page=page,
page_size=page_size,
)
start_idx = (page - 1) * page_size
end_idx = start_idx + page_size
paginated_items = cached_results.items[start_idx:end_idx]
# Update the search entry by id
search_id = await builder_db.update_search(
user_id,
builder_model.SearchEntry(
search_query=search_query,
filter=filter,
by_creator=by_creator,
search_id=search_id,
# Blocks&Integrations
blocks = builder_model.SearchBlocksResponse(
blocks=builder_model.BlockResponse(
blocks=[],
pagination=Pagination.empty(),
),
total_block_count=0,
total_integration_count=0,
)
if "blocks" in filter or "integrations" in filter:
blocks = builder_db.search_blocks(
include_blocks="blocks" in filter,
include_integrations="integrations" in filter,
query=search_query or "",
page=page,
page_size=page_size,
)
# Library Agents
my_agents = library_model.LibraryAgentResponse(
agents=[],
pagination=Pagination.empty(),
)
if "my_agents" in filter:
my_agents = await library_db.list_library_agents(
user_id=user_id,
search_term=search_query,
page=page,
page_size=page_size,
)
# Marketplace Agents
marketplace_agents = store_model.StoreAgentsResponse(
agents=[],
pagination=Pagination.empty(),
)
if "marketplace_agents" in filter:
marketplace_agents = await store_db.get_store_agents(
creators=by_creator,
search_query=search_query,
page=page,
page_size=page_size,
)
more_pages = False
if (
blocks.blocks.pagination.current_page < blocks.blocks.pagination.total_pages
or my_agents.pagination.current_page < my_agents.pagination.total_pages
or marketplace_agents.pagination.current_page
< marketplace_agents.pagination.total_pages
):
more_pages = True
return builder_model.SearchResponse(
items=paginated_items,
search_id=search_id,
total_items=cached_results.total_items,
pagination=pagination,
items=blocks.blocks.blocks + my_agents.agents + marketplace_agents.agents,
total_items={
"blocks": blocks.total_block_count,
"integrations": blocks.total_integration_count,
"marketplace_agents": marketplace_agents.pagination.total_items,
"my_agents": my_agents.pagination.total_items,
},
page=page,
more_pages=more_pages,
)

View File

@@ -1,10 +1,8 @@
import uuid
from datetime import UTC, datetime
from os import getenv
from typing import cast
import pytest
from prisma.types import ProfileCreateInput
from pydantic import SecretStr
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
@@ -51,16 +49,13 @@ async def setup_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data=cast(
ProfileCreateInput,
{
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile",
"links": [], # Required field - empty array for test profiles
},
)
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile",
"links": [], # Required field - empty array for test profiles
}
)
# 2. Create a test graph with agent input -> agent output
@@ -177,16 +172,13 @@ async def setup_llm_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data=cast(
ProfileCreateInput,
{
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile for LLM tests",
"links": [], # Required field - empty array for test profiles
},
)
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile for LLM tests",
"links": [], # Required field - empty array for test profiles
}
)
# 2. Create test OpenAI credentials for the user
@@ -340,16 +332,13 @@ async def setup_firecrawl_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data=cast(
ProfileCreateInput,
{
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile for Firecrawl tests",
"links": [], # Required field - empty array for test profiles
},
)
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile for Firecrawl tests",
"links": [], # Required field - empty array for test profiles
}
)
# NOTE: We deliberately do NOT create Firecrawl credentials for this user

View File

@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
router = APIRouter(
tags=["v2", "executions", "review"],
tags=["executions", "review", "private"],
dependencies=[Security(autogpt_auth_lib.requires_user)],
)
@@ -134,14 +134,18 @@ async def process_review_action(
# Build review decisions map
review_decisions = {}
for review in request.reviews:
review_status = (
ReviewStatus.APPROVED if review.approved else ReviewStatus.REJECTED
)
review_decisions[review.node_exec_id] = (
review_status,
review.reviewed_data,
review.message,
)
if review.approved:
review_decisions[review.node_exec_id] = (
ReviewStatus.APPROVED,
review.reviewed_data,
review.message,
)
else:
review_decisions[review.node_exec_id] = (
ReviewStatus.REJECTED,
None,
review.message,
)
# Process all reviews
updated_reviews = await process_all_reviews_for_execution(

View File

@@ -1,13 +1,12 @@
import asyncio
import logging
from typing import Literal, Optional, cast
from typing import Literal, Optional
import fastapi
import prisma.errors
import prisma.fields
import prisma.models
import prisma.types
from prisma.types import LibraryAgentCreateInput
import backend.data.graph as graph_db
import backend.data.integrations as integrations_db
@@ -803,21 +802,18 @@ async def add_store_agent_to_library(
# Create LibraryAgent entry
added_agent = await prisma.models.LibraryAgent.prisma().create(
data=cast(
LibraryAgentCreateInput,
{
"User": {"connect": {"id": user_id}},
"AgentGraph": {
"connect": {
"graphVersionId": {"id": graph.id, "version": graph.version}
}
},
"isCreatedByUser": False,
"settings": SafeJson(
_initialize_graph_settings(graph_model).model_dump()
),
data={
"User": {"connect": {"id": user_id}},
"AgentGraph": {
"connect": {
"graphVersionId": {"id": graph.id, "version": graph.version}
}
},
),
"isCreatedByUser": False,
"settings": SafeJson(
_initialize_graph_settings(graph_model).model_dump()
),
},
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
),

View File

@@ -1,15 +1,13 @@
import logging
from typing import Literal, Optional
from typing import Optional
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Body, HTTPException, Query, Security, status
from fastapi.responses import Response
from prisma.enums import OnboardingStep
import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
from backend.data.onboarding import complete_onboarding_step
from backend.util.exceptions import DatabaseError, NotFoundError
logger = logging.getLogger(__name__)
@@ -202,9 +200,6 @@ async def get_library_agent_by_store_listing_version_id(
)
async def add_marketplace_agent_to_library(
store_listing_version_id: str = Body(embed=True),
source: Literal["onboarding", "marketplace"] = Body(
default="marketplace", embed=True
),
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryAgent:
"""
@@ -222,15 +217,10 @@ async def add_marketplace_agent_to_library(
HTTPException(500): If a server/database error occurs.
"""
try:
agent = await library_db.add_store_agent_to_library(
return await library_db.add_store_agent_to_library(
store_listing_version_id=store_listing_version_id,
user_id=user_id,
)
if source != "onboarding":
await complete_onboarding_step(
user_id, OnboardingStep.MARKETPLACE_ADD_AGENT
)
return agent
except store_exceptions.AgentNotFoundError as e:
logger.warning(

View File

@@ -10,7 +10,6 @@ from backend.data.execution import GraphExecutionMeta
from backend.data.graph import get_graph
from backend.data.integrations import get_webhook
from backend.data.model import CredentialsMetaInput
from backend.data.onboarding import increment_runs
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks import get_webhook_manager
@@ -402,8 +401,6 @@ async def execute_preset(
merged_node_input = preset.inputs | inputs
merged_credential_inputs = preset.credentials | credential_inputs
await increment_runs(user_id)
return await add_graph_execution(
user_id=user_id,
graph_id=preset.graph_id,

Some files were not shown because too many files have changed in this diff