Compare commits

..

2 Commits

Author SHA1 Message Date
Nicholas Tindle
516ae38329 Merge branch 'dev' into gmail-reply-draft 2025-08-26 00:25:22 -05:00
Nicholas Tindle
d20ac49211 feat(backend): add deduplication for email addresses and add Gmail draft reply functionality 2025-08-25 08:15:09 -05:00
713 changed files with 7658 additions and 36852 deletions

View File

@@ -1,97 +0,0 @@
name: Auto Fix CI Failures
on:
workflow_run:
workflows: ["CI"]
types:
- completed
permissions:
contents: write
pull-requests: write
actions: read
issues: write
id-token: write # Required for OIDC token exchange
jobs:
auto-fix:
if: |
github.event.workflow_run.conclusion == 'failure' &&
github.event.workflow_run.pull_requests[0] &&
!startsWith(github.event.workflow_run.head_branch, 'claude-auto-fix-ci-')
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
ref: ${{ github.event.workflow_run.head_branch }}
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Setup git identity
run: |
git config --global user.email "claude[bot]@users.noreply.github.com"
git config --global user.name "claude[bot]"
- name: Create fix branch
id: branch
run: |
BRANCH_NAME="claude-auto-fix-ci-${{ github.event.workflow_run.head_branch }}-${{ github.run_id }}"
git checkout -b "$BRANCH_NAME"
echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT
- name: Get CI failure details
id: failure_details
uses: actions/github-script@v7
with:
script: |
const run = await github.rest.actions.getWorkflowRun({
owner: context.repo.owner,
repo: context.repo.repo,
run_id: ${{ github.event.workflow_run.id }}
});
const jobs = await github.rest.actions.listJobsForWorkflowRun({
owner: context.repo.owner,
repo: context.repo.repo,
run_id: ${{ github.event.workflow_run.id }}
});
const failedJobs = jobs.data.jobs.filter(job => job.conclusion === 'failure');
let errorLogs = [];
for (const job of failedJobs) {
const logs = await github.rest.actions.downloadJobLogsForWorkflowRun({
owner: context.repo.owner,
repo: context.repo.repo,
job_id: job.id
});
errorLogs.push({
jobName: job.name,
logs: logs.data
});
}
return {
runUrl: run.data.html_url,
failedJobs: failedJobs.map(j => j.name),
errorLogs: errorLogs
};
- name: Fix CI failures with Claude
id: claude
uses: anthropics/claude-code-action@v1
with:
prompt: |
/fix-ci
Failed CI Run: ${{ fromJSON(steps.failure_details.outputs.result).runUrl }}
Failed Jobs: ${{ join(fromJSON(steps.failure_details.outputs.result).failedJobs, ', ') }}
PR Number: ${{ github.event.workflow_run.pull_requests[0].number }}
Branch Name: ${{ steps.branch.outputs.branch_name }}
Base Branch: ${{ github.event.workflow_run.head_branch }}
Repository: ${{ github.repository }}
Error logs:
${{ toJSON(fromJSON(steps.failure_details.outputs.result).errorLogs) }}
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
claude_args: "--allowedTools 'Edit,MultiEdit,Write,Read,Glob,Grep,LS,Bash(git:*),Bash(bun:*),Bash(npm:*),Bash(npx:*),Bash(gh:*)'"

View File

@@ -1,379 +0,0 @@
# Claude Dependabot PR Review Workflow
#
# This workflow automatically runs Claude analysis on Dependabot PRs to:
# - Identify dependency changes and their versions
# - Look up changelogs for updated packages
# - Assess breaking changes and security impacts
# - Provide actionable recommendations for the development team
#
# Triggered on: Dependabot PRs (opened, synchronize)
# Requirements: ANTHROPIC_API_KEY secret must be configured
name: Claude Dependabot PR Review
on:
pull_request:
types: [opened, synchronize]
jobs:
dependabot-review:
# Only run on Dependabot PRs
if: github.actor == 'dependabot[bot]'
runs-on: ubuntu-latest
timeout-minutes: 30
permissions:
contents: write
pull-requests: read
issues: read
id-token: write
actions: read # Required for CI access
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 1
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
- name: Install Poetry
run: |
# Extract Poetry version from backend/poetry.lock (matches CI)
cd autogpt_platform/backend
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
# Install Poetry
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
# Add Poetry to PATH
echo "$HOME/.local/bin" >> $GITHUB_PATH
- name: Check poetry.lock
working-directory: autogpt_platform/backend
run: |
poetry lock
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
echo "Warning: poetry.lock not up to date, but continuing for setup"
git checkout poetry.lock # Reset for clean setup
fi
- name: Install Python dependencies
working-directory: autogpt_platform/backend
run: poetry install
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Set pnpm store directory
run: |
pnpm config set store-dir ~/.pnpm-store
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install JavaScript dependencies
working-directory: autogpt_platform/frontend
run: pnpm install --frozen-lockfile
# Install Playwright browsers for frontend testing
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
# - name: Install Playwright browsers
# working-directory: autogpt_platform/frontend
# run: pnpm playwright install --with-deps chromium
# Docker setup for development environment
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Copy default environment files
working-directory: autogpt_platform
run: |
# Copy default environment files for development
cp .env.default .env
cp backend/.env.default backend/.env
cp frontend/.env.default frontend/.env
# Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache
id: docker-cache
uses: actions/cache@v4
with:
path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
restore-keys: |
docker-images-v2-${{ runner.os }}-
docker-images-v1-${{ runner.os }}-
- name: Load or pull Docker images
working-directory: autogpt_platform
run: |
mkdir -p ~/docker-cache
# Define image list for easy maintenance
IMAGES=(
"redis:latest"
"rabbitmq:management"
"clamav/clamav-debian:latest"
"busybox:latest"
"kong:2.8.1"
"supabase/gotrue:v2.170.0"
"supabase/postgres:15.8.1.049"
"supabase/postgres-meta:v0.86.1"
"supabase/studio:20250224-d10db0f"
)
# Check if any cached tar files exist (more reliable than cache-hit)
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
echo "Docker cache found, loading images in parallel..."
for image in "${IMAGES[@]}"; do
# Convert image name to filename (replace : and / with -)
filename=$(echo "$image" | tr ':/' '--')
if [ -f ~/docker-cache/${filename}.tar ]; then
echo "Loading $image..."
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
fi
done
wait
echo "All cached images loaded"
else
echo "No Docker cache found, pulling images in parallel..."
# Pull all images in parallel
for image in "${IMAGES[@]}"; do
docker pull "$image" &
done
wait
# Only save cache on main branches (not PRs) to avoid cache pollution
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
echo "Saving Docker images to cache in parallel..."
for image in "${IMAGES[@]}"; do
# Convert image name to filename (replace : and / with -)
filename=$(echo "$image" | tr ':/' '--')
echo "Saving $image..."
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
done
wait
echo "Docker image cache saved"
else
echo "Skipping cache save for PR/feature branch"
fi
fi
echo "Docker images ready for use"
# Phase 2: Build migrate service with GitHub Actions cache
- name: Build migrate Docker image with cache
working-directory: autogpt_platform
run: |
# Build the migrate image with buildx for GHA caching
docker buildx build \
--cache-from type=gha \
--cache-to type=gha,mode=max \
--target migrate \
--tag autogpt_platform-migrate:latest \
--load \
-f backend/Dockerfile \
..
# Start services using pre-built images
- name: Start Docker services for development
working-directory: autogpt_platform
run: |
# Start essential services (migrate image already built with correct tag)
docker compose --profile local up deps --no-build --detach
echo "Waiting for services to be ready..."
# Wait for database to be ready
echo "Checking database readiness..."
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
echo " Waiting for database..."
sleep 2
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
# Check migrate service status
echo "Checking migration status..."
docker compose ps migrate || echo " Migrate service not visible in ps output"
# Wait for migrate service to complete
echo "Waiting for migrations to complete..."
timeout 30 bash -c '
ATTEMPTS=0
while [ $ATTEMPTS -lt 15 ]; do
ATTEMPTS=$((ATTEMPTS + 1))
# Check using docker directly (more reliable than docker compose ps)
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
if [ -z "$CONTAINER_STATUS" ]; then
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
echo "✅ Migrations completed successfully"
docker compose logs migrate --tail=5 2>/dev/null || true
exit 0
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
echo "❌ Migrations failed with exit code: $EXIT_CODE"
echo "Migration logs:"
docker compose logs migrate --tail=20 2>/dev/null || true
exit 1
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
else
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
fi
sleep 2
done
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
echo "Final container check:"
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
echo "Migration logs (if available):"
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
' || echo "⚠️ Migration check completed with warnings, continuing..."
# Brief wait for other services to stabilize
echo "Waiting 5 seconds for other services to stabilize..."
sleep 5
# Verify installations and provide environment info
- name: Verify setup and show environment info
run: |
echo "=== Python Setup ==="
python --version
poetry --version
echo "=== Node.js Setup ==="
node --version
pnpm --version
echo "=== Additional Tools ==="
docker --version
docker compose version
gh --version || true
echo "=== Services Status ==="
cd autogpt_platform
docker compose ps || true
echo "=== Backend Dependencies ==="
cd backend
poetry show | head -10 || true
echo "=== Frontend Dependencies ==="
cd ../frontend
pnpm list --depth=0 | head -10 || true
echo "=== Environment Files ==="
ls -la ../.env* || true
ls -la .env* || true
ls -la ../backend/.env* || true
echo "✅ AutoGPT Platform development environment setup complete!"
echo "🚀 Ready for development with Docker services running"
echo "📝 Backend server: poetry run serve (port 8000)"
echo "🌐 Frontend server: pnpm dev (port 3000)"
- name: Run Claude Dependabot Analysis
id: claude_review
uses: anthropics/claude-code-action@v1
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
claude_args: |
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
prompt: |
You are Claude, an AI assistant specialized in reviewing Dependabot dependency update PRs.
Your primary tasks are:
1. **Analyze the dependency changes** in this Dependabot PR
2. **Look up changelogs** for all updated dependencies to understand what changed
3. **Identify breaking changes** and assess potential impact on the AutoGPT codebase
4. **Provide actionable recommendations** for the development team
## Analysis Process:
1. **Identify Changed Dependencies**:
- Use git diff to see what dependencies were updated
- Parse package.json, poetry.lock, requirements files, etc.
- List all package versions: old → new
2. **Changelog Research**:
- For each updated dependency, look up its changelog/release notes
- Use WebFetch to access GitHub releases, NPM package pages, PyPI project pages. The pr should also have some details
- Focus on versions between the old and new versions
- Identify: breaking changes, deprecations, security fixes, new features
3. **Breaking Change Assessment**:
- Categorize changes: BREAKING, MAJOR, MINOR, PATCH, SECURITY
- Assess impact on AutoGPT's usage patterns
- Check if AutoGPT uses affected APIs/features
- Look for migration guides or upgrade instructions
4. **Codebase Impact Analysis**:
- Search the AutoGPT codebase for usage of changed APIs
- Identify files that might be affected by breaking changes
- Check test files for deprecated usage patterns
- Look for configuration changes needed
## Output Format:
Provide a comprehensive review comment with:
### 🔍 Dependency Analysis Summary
- List of updated packages with version changes
- Overall risk assessment (LOW/MEDIUM/HIGH)
### 📋 Detailed Changelog Review
For each updated dependency:
- **Package**: name (old_version → new_version)
- **Changes**: Summary of key changes
- **Breaking Changes**: List any breaking changes
- **Security Fixes**: Note security improvements
- **Migration Notes**: Any upgrade steps needed
### ⚠️ Impact Assessment
- **Breaking Changes Found**: Yes/No with details
- **Affected Files**: List AutoGPT files that may need updates
- **Test Impact**: Any tests that may need updating
- **Configuration Changes**: Required config updates
### 🛠️ Recommendations
- **Action Required**: What the team should do
- **Testing Focus**: Areas to test thoroughly
- **Follow-up Tasks**: Any additional work needed
- **Merge Recommendation**: APPROVE/REVIEW_NEEDED/HOLD
### 📚 Useful Links
- Links to relevant changelogs, migration guides, documentation
Be thorough but concise. Focus on actionable insights that help the development team make informed decisions about the dependency updates.

View File

@@ -30,296 +30,18 @@ jobs:
github.event.issue.author_association == 'COLLABORATOR'
)
runs-on: ubuntu-latest
timeout-minutes: 45
permissions:
contents: write
contents: read
pull-requests: read
issues: read
id-token: write
actions: read # Required for CI access
steps:
- name: Checkout code
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 1
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
- name: Install Poetry
run: |
# Extract Poetry version from backend/poetry.lock (matches CI)
cd autogpt_platform/backend
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
# Install Poetry
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
# Add Poetry to PATH
echo "$HOME/.local/bin" >> $GITHUB_PATH
- name: Check poetry.lock
working-directory: autogpt_platform/backend
run: |
poetry lock
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
echo "Warning: poetry.lock not up to date, but continuing for setup"
git checkout poetry.lock # Reset for clean setup
fi
- name: Install Python dependencies
working-directory: autogpt_platform/backend
run: poetry install
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Set pnpm store directory
run: |
pnpm config set store-dir ~/.pnpm-store
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install JavaScript dependencies
working-directory: autogpt_platform/frontend
run: pnpm install --frozen-lockfile
# Install Playwright browsers for frontend testing
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
# - name: Install Playwright browsers
# working-directory: autogpt_platform/frontend
# run: pnpm playwright install --with-deps chromium
# Docker setup for development environment
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Copy default environment files
working-directory: autogpt_platform
run: |
# Copy default environment files for development
cp .env.default .env
cp backend/.env.default backend/.env
cp frontend/.env.default frontend/.env
# Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache
id: docker-cache
uses: actions/cache@v4
with:
path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
restore-keys: |
docker-images-v2-${{ runner.os }}-
docker-images-v1-${{ runner.os }}-
- name: Load or pull Docker images
working-directory: autogpt_platform
run: |
mkdir -p ~/docker-cache
# Define image list for easy maintenance
IMAGES=(
"redis:latest"
"rabbitmq:management"
"clamav/clamav-debian:latest"
"busybox:latest"
"kong:2.8.1"
"supabase/gotrue:v2.170.0"
"supabase/postgres:15.8.1.049"
"supabase/postgres-meta:v0.86.1"
"supabase/studio:20250224-d10db0f"
)
# Check if any cached tar files exist (more reliable than cache-hit)
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
echo "Docker cache found, loading images in parallel..."
for image in "${IMAGES[@]}"; do
# Convert image name to filename (replace : and / with -)
filename=$(echo "$image" | tr ':/' '--')
if [ -f ~/docker-cache/${filename}.tar ]; then
echo "Loading $image..."
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
fi
done
wait
echo "All cached images loaded"
else
echo "No Docker cache found, pulling images in parallel..."
# Pull all images in parallel
for image in "${IMAGES[@]}"; do
docker pull "$image" &
done
wait
# Only save cache on main branches (not PRs) to avoid cache pollution
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
echo "Saving Docker images to cache in parallel..."
for image in "${IMAGES[@]}"; do
# Convert image name to filename (replace : and / with -)
filename=$(echo "$image" | tr ':/' '--')
echo "Saving $image..."
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
done
wait
echo "Docker image cache saved"
else
echo "Skipping cache save for PR/feature branch"
fi
fi
echo "Docker images ready for use"
# Phase 2: Build migrate service with GitHub Actions cache
- name: Build migrate Docker image with cache
working-directory: autogpt_platform
run: |
# Build the migrate image with buildx for GHA caching
docker buildx build \
--cache-from type=gha \
--cache-to type=gha,mode=max \
--target migrate \
--tag autogpt_platform-migrate:latest \
--load \
-f backend/Dockerfile \
..
# Start services using pre-built images
- name: Start Docker services for development
working-directory: autogpt_platform
run: |
# Start essential services (migrate image already built with correct tag)
docker compose --profile local up deps --no-build --detach
echo "Waiting for services to be ready..."
# Wait for database to be ready
echo "Checking database readiness..."
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
echo " Waiting for database..."
sleep 2
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
# Check migrate service status
echo "Checking migration status..."
docker compose ps migrate || echo " Migrate service not visible in ps output"
# Wait for migrate service to complete
echo "Waiting for migrations to complete..."
timeout 30 bash -c '
ATTEMPTS=0
while [ $ATTEMPTS -lt 15 ]; do
ATTEMPTS=$((ATTEMPTS + 1))
# Check using docker directly (more reliable than docker compose ps)
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
if [ -z "$CONTAINER_STATUS" ]; then
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
echo "✅ Migrations completed successfully"
docker compose logs migrate --tail=5 2>/dev/null || true
exit 0
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
echo "❌ Migrations failed with exit code: $EXIT_CODE"
echo "Migration logs:"
docker compose logs migrate --tail=20 2>/dev/null || true
exit 1
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
else
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
fi
sleep 2
done
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
echo "Final container check:"
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
echo "Migration logs (if available):"
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
' || echo "⚠️ Migration check completed with warnings, continuing..."
# Brief wait for other services to stabilize
echo "Waiting 5 seconds for other services to stabilize..."
sleep 5
# Verify installations and provide environment info
- name: Verify setup and show environment info
run: |
echo "=== Python Setup ==="
python --version
poetry --version
echo "=== Node.js Setup ==="
node --version
pnpm --version
echo "=== Additional Tools ==="
docker --version
docker compose version
gh --version || true
echo "=== Services Status ==="
cd autogpt_platform
docker compose ps || true
echo "=== Backend Dependencies ==="
cd backend
poetry show | head -10 || true
echo "=== Frontend Dependencies ==="
cd ../frontend
pnpm list --depth=0 | head -10 || true
echo "=== Environment Files ==="
ls -la ../.env* || true
ls -la .env* || true
ls -la ../backend/.env* || true
echo "✅ AutoGPT Platform development environment setup complete!"
echo "🚀 Ready for development with Docker services running"
echo "📝 Backend server: poetry run serve (port 8000)"
echo "🌐 Frontend server: pnpm dev (port 3000)"
- name: Run Claude Code
id: claude
uses: anthropics/claude-code-action@v1
uses: anthropics/claude-code-action@beta
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
claude_args: |
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*), Bash(gh pr edit:*)"
--model opus
additional_permissions: |
actions: read

View File

@@ -3,7 +3,6 @@ name: AutoGPT Platform - Deploy Prod Environment
on:
release:
types: [published]
workflow_dispatch:
permissions:
contents: 'read'
@@ -18,8 +17,6 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name || 'master' }}
- name: Set up Python
uses: actions/setup-python@v5
@@ -39,7 +36,7 @@ jobs:
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}
trigger:
needs: migrate
runs-on: ubuntu-latest
@@ -50,5 +47,4 @@ jobs:
token: ${{ secrets.DEPLOY_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
event-type: build_deploy_prod
client-payload: |
{"ref": "${{ github.ref_name || 'master' }}", "repository": "${{ github.repository }}"}
client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'

View File

@@ -5,13 +5,6 @@ on:
branches: [ dev ]
paths:
- 'autogpt_platform/**'
workflow_dispatch:
inputs:
git_ref:
description: 'Git ref (branch/tag) of AutoGPT to deploy'
required: true
default: 'master'
type: string
permissions:
contents: 'read'
@@ -26,8 +19,6 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
ref: ${{ github.event.inputs.git_ref || github.ref_name }}
- name: Set up Python
uses: actions/setup-python@v5
@@ -57,4 +48,4 @@ jobs:
token: ${{ secrets.DEPLOY_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
event-type: build_deploy_dev
client-payload: '{"ref": "${{ github.event.inputs.git_ref || github.ref }}", "repository": "${{ github.repository }}"}'
client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'

View File

@@ -37,7 +37,9 @@ jobs:
services:
redis:
image: redis:latest
image: bitnami/redis:6.2
env:
REDIS_PASSWORD: testpassword
ports:
- 6379:6379
rabbitmq:
@@ -199,9 +201,10 @@ jobs:
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
REDIS_HOST: "localhost"
REDIS_PORT: "6379"
REDIS_PASSWORD: "testpassword"
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
env:

View File

@@ -160,7 +160,7 @@ jobs:
- name: Run docker compose
run: |
NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
docker compose -f ../docker-compose.yml up -d
env:
DOCKER_BUILDKIT: 1
BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache

View File

@@ -61,27 +61,24 @@ poetry run pytest path/to/test.py --snapshot-update
```bash
# Install dependencies
cd frontend && pnpm i
cd frontend && npm install
# Start development server
pnpm dev
npm run dev
# Run E2E tests
pnpm test
npm run test
# Run Storybook for component development
pnpm storybook
npm run storybook
# Build production
pnpm build
npm run build
# Type checking
pnpm types
npm run types
```
We have a components library in autogpt_platform/frontend/src/components/atoms that should be used when adding new pages and components.
## Architecture Overview
### Backend Architecture
@@ -152,21 +149,12 @@ Key models (defined in `/backend/schema.prisma`):
**Adding a new block:**
Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `/backend/backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
2. Inherit from `Block` base class
3. Define input/output schemas
4. Implement `run` method
5. Register in block registry
6. Generate the block uuid using `uuid.uuid4()`
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?

View File

@@ -0,0 +1,35 @@
import hashlib
import secrets
from typing import NamedTuple
class APIKeyContainer(NamedTuple):
"""Container for API key parts."""
raw: str
prefix: str
postfix: str
hash: str
class APIKeyManager:
PREFIX: str = "agpt_"
PREFIX_LENGTH: int = 8
POSTFIX_LENGTH: int = 8
def generate_api_key(self) -> APIKeyContainer:
"""Generate a new API key with all its parts."""
raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
return APIKeyContainer(
raw=raw_key,
prefix=raw_key[: self.PREFIX_LENGTH],
postfix=raw_key[-self.POSTFIX_LENGTH :],
hash=hashlib.sha256(raw_key.encode()).hexdigest(),
)
def verify_api_key(self, provided_key: str, stored_hash: str) -> bool:
"""Verify if a provided API key matches the stored hash."""
if not provided_key.startswith(self.PREFIX):
return False
provided_hash = hashlib.sha256(provided_key.encode()).hexdigest()
return secrets.compare_digest(provided_hash, stored_hash)

View File

@@ -1,78 +0,0 @@
import hashlib
import secrets
from typing import NamedTuple
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
class APIKeyContainer(NamedTuple):
"""Container for API key parts."""
key: str
head: str
tail: str
hash: str
salt: str
class APIKeySmith:
PREFIX: str = "agpt_"
HEAD_LENGTH: int = 8
TAIL_LENGTH: int = 8
def generate_key(self) -> APIKeyContainer:
"""Generate a new API key with secure hashing."""
raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
hash, salt = self.hash_key(raw_key)
return APIKeyContainer(
key=raw_key,
head=raw_key[: self.HEAD_LENGTH],
tail=raw_key[-self.TAIL_LENGTH :],
hash=hash,
salt=salt,
)
def verify_key(
self, provided_key: str, known_hash: str, known_salt: str | None = None
) -> bool:
"""
Verify an API key against a known hash (+ salt).
Supports verifying both legacy SHA256 and secure Scrypt hashes.
"""
if not provided_key.startswith(self.PREFIX):
return False
# Handle legacy SHA256 hashes (migration support)
if known_salt is None:
legacy_hash = hashlib.sha256(provided_key.encode()).hexdigest()
return secrets.compare_digest(legacy_hash, known_hash)
try:
salt_bytes = bytes.fromhex(known_salt)
provided_hash = self._hash_key_with_salt(provided_key, salt_bytes)
return secrets.compare_digest(provided_hash, known_hash)
except (ValueError, TypeError):
return False
def hash_key(self, raw_key: str) -> tuple[str, str]:
"""Migrate a legacy hash to secure hash format."""
salt = self._generate_salt()
hash = self._hash_key_with_salt(raw_key, salt)
return hash, salt.hex()
def _generate_salt(self) -> bytes:
"""Generate a random salt for hashing."""
return secrets.token_bytes(32)
def _hash_key_with_salt(self, raw_key: str, salt: bytes) -> str:
"""Hash API key using Scrypt with salt."""
kdf = Scrypt(
length=32,
salt=salt,
n=2**14, # CPU/memory cost parameter
r=8, # Block size parameter
p=1, # Parallelization parameter
)
key_hash = kdf.derive(raw_key.encode())
return key_hash.hex()

View File

@@ -1,79 +0,0 @@
import hashlib
from autogpt_libs.api_key.keysmith import APIKeySmith
def test_generate_api_key():
keysmith = APIKeySmith()
key = keysmith.generate_key()
assert key.key.startswith(keysmith.PREFIX)
assert key.head == key.key[: keysmith.HEAD_LENGTH]
assert key.tail == key.key[-keysmith.TAIL_LENGTH :]
assert len(key.hash) == 64 # 32 bytes hex encoded
assert len(key.salt) == 64 # 32 bytes hex encoded
def test_verify_new_secure_key():
keysmith = APIKeySmith()
key = keysmith.generate_key()
# Test correct key validates
assert keysmith.verify_key(key.key, key.hash, key.salt) is True
# Test wrong key fails
wrong_key = f"{keysmith.PREFIX}wrongkey123"
assert keysmith.verify_key(wrong_key, key.hash, key.salt) is False
def test_verify_legacy_key():
keysmith = APIKeySmith()
legacy_key = f"{keysmith.PREFIX}legacykey123"
legacy_hash = hashlib.sha256(legacy_key.encode()).hexdigest()
# Test legacy key validates without salt
assert keysmith.verify_key(legacy_key, legacy_hash) is True
# Test wrong legacy key fails
wrong_key = f"{keysmith.PREFIX}wronglegacy"
assert keysmith.verify_key(wrong_key, legacy_hash) is False
def test_rehash_existing_key():
keysmith = APIKeySmith()
legacy_key = f"{keysmith.PREFIX}migratekey123"
# Migrate the legacy key
new_hash, new_salt = keysmith.hash_key(legacy_key)
# Verify migrated key works
assert keysmith.verify_key(legacy_key, new_hash, new_salt) is True
# Verify different key fails with migrated hash
wrong_key = f"{keysmith.PREFIX}wrongkey"
assert keysmith.verify_key(wrong_key, new_hash, new_salt) is False
def test_invalid_key_prefix():
keysmith = APIKeySmith()
key = keysmith.generate_key()
# Test key without proper prefix fails
invalid_key = "invalid_prefix_key"
assert keysmith.verify_key(invalid_key, key.hash, key.salt) is False
def test_secure_hash_requires_salt():
keysmith = APIKeySmith()
key = keysmith.generate_key()
# Secure hash without salt should fail
assert keysmith.verify_key(key.key, key.hash) is False
def test_invalid_salt_format():
keysmith = APIKeySmith()
key = keysmith.generate_key()
# Invalid salt format should fail gracefully
assert keysmith.verify_key(key.key, key.hash, "invalid_hex") is False

View File

@@ -1,13 +1,13 @@
from .config import verify_settings
from .dependencies import get_user_id, requires_admin_user, requires_user
from .helpers import add_auth_responses_to_openapi
from .depends import requires_admin_user, requires_user
from .jwt_utils import parse_jwt_token
from .middleware import APIKeyValidator, auth_middleware
from .models import User
__all__ = [
"verify_settings",
"get_user_id",
"requires_admin_user",
"parse_jwt_token",
"requires_user",
"add_auth_responses_to_openapi",
"requires_admin_user",
"APIKeyValidator",
"auth_middleware",
"User",
]

View File

@@ -1,90 +1,11 @@
import logging
import os
from jwt.algorithms import get_default_algorithms, has_crypto
logger = logging.getLogger(__name__)
class AuthConfigError(ValueError):
"""Raised when authentication configuration is invalid."""
pass
ALGO_RECOMMENDATION = (
"We highly recommend using an asymmetric algorithm such as ES256, "
"because when leaked, a shared secret would allow anyone to "
"forge valid tokens and impersonate users. "
"More info: https://supabase.com/docs/guides/auth/signing-keys#choosing-the-right-signing-algorithm" # noqa
)
class Settings:
def __init__(self):
self.JWT_VERIFY_KEY: str = os.getenv(
"JWT_VERIFY_KEY", os.getenv("SUPABASE_JWT_SECRET", "")
).strip()
self.JWT_ALGORITHM: str = os.getenv("JWT_SIGN_ALGORITHM", "HS256").strip()
self.validate()
def validate(self):
if not self.JWT_VERIFY_KEY:
raise AuthConfigError(
"JWT_VERIFY_KEY must be set. "
"An empty JWT secret would allow anyone to forge valid tokens."
)
if len(self.JWT_VERIFY_KEY) < 32:
logger.warning(
"⚠️ JWT_VERIFY_KEY appears weak (less than 32 characters). "
"Consider using a longer, cryptographically secure secret."
)
supported_algorithms = get_default_algorithms().keys()
if not has_crypto:
logger.warning(
"⚠️ Asymmetric JWT verification is not available "
"because the 'cryptography' package is not installed. "
+ ALGO_RECOMMENDATION
)
if (
self.JWT_ALGORITHM not in supported_algorithms
or self.JWT_ALGORITHM == "none"
):
raise AuthConfigError(
f"Invalid JWT_SIGN_ALGORITHM: '{self.JWT_ALGORITHM}'. "
"Supported algorithms are listed on "
"https://pyjwt.readthedocs.io/en/stable/algorithms.html"
)
if self.JWT_ALGORITHM.startswith("HS"):
logger.warning(
f"⚠️ JWT_SIGN_ALGORITHM is set to '{self.JWT_ALGORITHM}', "
"a symmetric shared-key signature algorithm. " + ALGO_RECOMMENDATION
)
self.JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
self.JWT_ALGORITHM: str = "HS256"
_settings: Settings = None # type: ignore
def get_settings() -> Settings:
global _settings
if not _settings:
_settings = Settings()
return _settings
def verify_settings() -> None:
global _settings
if not _settings:
_settings = Settings() # calls validation indirectly
return
_settings.validate()
settings = Settings()

View File

@@ -1,306 +0,0 @@
"""
Comprehensive tests for auth configuration to ensure 100% line and branch coverage.
These tests verify critical security checks preventing JWT token forgery.
"""
import logging
import os
import pytest
from pytest_mock import MockerFixture
from autogpt_libs.auth.config import AuthConfigError, Settings
def test_environment_variable_precedence(mocker: MockerFixture):
"""Test that environment variables take precedence over defaults."""
secret = "environment-secret-key-with-proper-length-123456"
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == secret
def test_environment_variable_backwards_compatible(mocker: MockerFixture):
"""Test that SUPABASE_JWT_SECRET is read if JWT_VERIFY_KEY is not set."""
secret = "environment-secret-key-with-proper-length-123456"
mocker.patch.dict(os.environ, {"SUPABASE_JWT_SECRET": secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == secret
def test_auth_config_error_inheritance():
"""Test that AuthConfigError is properly defined as an Exception."""
assert issubclass(AuthConfigError, Exception)
error = AuthConfigError("test message")
assert str(error) == "test message"
def test_settings_static_after_creation(mocker: MockerFixture):
"""Test that settings maintain their values after creation."""
secret = "immutable-secret-key-with-proper-length-12345"
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
settings = Settings()
original_secret = settings.JWT_VERIFY_KEY
# Changing environment after creation shouldn't affect settings
os.environ["JWT_VERIFY_KEY"] = "different-secret"
assert settings.JWT_VERIFY_KEY == original_secret
def test_settings_load_with_valid_secret(mocker: MockerFixture):
"""Test auth enabled with a valid JWT secret."""
valid_secret = "a" * 32 # 32 character secret
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": valid_secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == valid_secret
def test_settings_load_with_strong_secret(mocker: MockerFixture):
"""Test auth enabled with a cryptographically strong secret."""
strong_secret = "super-secret-jwt-token-with-at-least-32-characters-long"
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": strong_secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == strong_secret
assert len(settings.JWT_VERIFY_KEY) >= 32
def test_secret_empty_raises_error(mocker: MockerFixture):
"""Test that auth enabled with empty secret raises AuthConfigError."""
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": ""}, clear=True)
with pytest.raises(Exception) as exc_info:
Settings()
assert "JWT_VERIFY_KEY" in str(exc_info.value)
def test_secret_missing_raises_error(mocker: MockerFixture):
"""Test that auth enabled without secret env var raises AuthConfigError."""
mocker.patch.dict(os.environ, {}, clear=True)
with pytest.raises(Exception) as exc_info:
Settings()
assert "JWT_VERIFY_KEY" in str(exc_info.value)
@pytest.mark.parametrize("secret", [" ", " ", "\t", "\n", " \t\n "])
def test_secret_only_whitespace_raises_error(mocker: MockerFixture, secret: str):
"""Test that auth enabled with whitespace-only secret raises error."""
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
with pytest.raises(ValueError):
Settings()
def test_secret_weak_logs_warning(
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
):
"""Test that weak JWT secret triggers warning log."""
weak_secret = "short" # Less than 32 characters
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": weak_secret}, clear=True)
with caplog.at_level(logging.WARNING):
settings = Settings()
assert settings.JWT_VERIFY_KEY == weak_secret
assert "key appears weak" in caplog.text.lower()
assert "less than 32 characters" in caplog.text
def test_secret_31_char_logs_warning(
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
):
"""Test that 31-character secret triggers warning (boundary test)."""
secret_31 = "a" * 31 # Exactly 31 characters
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret_31}, clear=True)
with caplog.at_level(logging.WARNING):
settings = Settings()
assert len(settings.JWT_VERIFY_KEY) == 31
assert "key appears weak" in caplog.text.lower()
def test_secret_32_char_no_warning(
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
):
"""Test that 32-character secret does not trigger warning (boundary test)."""
secret_32 = "a" * 32 # Exactly 32 characters
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret_32}, clear=True)
with caplog.at_level(logging.WARNING):
settings = Settings()
assert len(settings.JWT_VERIFY_KEY) == 32
assert "JWT secret appears weak" not in caplog.text
def test_secret_whitespace_stripped(mocker: MockerFixture):
"""Test that JWT secret whitespace is stripped."""
secret = "a" * 32
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": f" {secret} "}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == secret
def test_secret_with_special_characters(mocker: MockerFixture):
"""Test JWT secret with special characters."""
special_secret = "!@#$%^&*()_+-=[]{}|;:,.<>?`~" + "a" * 10 # 40 chars total
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": special_secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == special_secret
def test_secret_with_unicode(mocker: MockerFixture):
"""Test JWT secret with unicode characters."""
unicode_secret = "秘密🔐キー" + "a" * 25 # Ensure >32 bytes
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": unicode_secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == unicode_secret
def test_secret_very_long(mocker: MockerFixture):
"""Test JWT secret with excessive length."""
long_secret = "a" * 1000 # 1000 character secret
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": long_secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == long_secret
assert len(settings.JWT_VERIFY_KEY) == 1000
def test_secret_with_newline(mocker: MockerFixture):
"""Test JWT secret containing newlines."""
multiline_secret = "secret\nwith\nnewlines" + "a" * 20
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": multiline_secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == multiline_secret
def test_secret_base64_encoded(mocker: MockerFixture):
"""Test JWT secret that looks like base64."""
base64_secret = "dGhpc19pc19hX3NlY3JldF9rZXlfd2l0aF9wcm9wZXJfbGVuZ3Ro"
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": base64_secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == base64_secret
def test_secret_numeric_only(mocker: MockerFixture):
"""Test JWT secret with only numbers."""
numeric_secret = "1234567890" * 4 # 40 character numeric secret
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": numeric_secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == numeric_secret
def test_algorithm_default_hs256(mocker: MockerFixture):
"""Test that JWT algorithm defaults to HS256."""
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": "a" * 32}, clear=True)
settings = Settings()
assert settings.JWT_ALGORITHM == "HS256"
def test_algorithm_whitespace_stripped(mocker: MockerFixture):
"""Test that JWT algorithm whitespace is stripped."""
secret = "a" * 32
mocker.patch.dict(
os.environ,
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": " HS256 "},
clear=True,
)
settings = Settings()
assert settings.JWT_ALGORITHM == "HS256"
def test_no_crypto_warning(mocker: MockerFixture, caplog: pytest.LogCaptureFixture):
"""Test warning when crypto package is not available."""
secret = "a" * 32
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
# Mock has_crypto to return False
mocker.patch("autogpt_libs.auth.config.has_crypto", False)
with caplog.at_level(logging.WARNING):
Settings()
assert "Asymmetric JWT verification is not available" in caplog.text
assert "cryptography" in caplog.text
def test_algorithm_invalid_raises_error(mocker: MockerFixture):
"""Test that invalid JWT algorithm raises AuthConfigError."""
secret = "a" * 32
mocker.patch.dict(
os.environ,
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": "INVALID_ALG"},
clear=True,
)
with pytest.raises(AuthConfigError) as exc_info:
Settings()
assert "Invalid JWT_SIGN_ALGORITHM" in str(exc_info.value)
assert "INVALID_ALG" in str(exc_info.value)
def test_algorithm_none_raises_error(mocker: MockerFixture):
"""Test that 'none' algorithm raises AuthConfigError."""
secret = "a" * 32
mocker.patch.dict(
os.environ,
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": "none"},
clear=True,
)
with pytest.raises(AuthConfigError) as exc_info:
Settings()
assert "Invalid JWT_SIGN_ALGORITHM" in str(exc_info.value)
@pytest.mark.parametrize("algorithm", ["HS256", "HS384", "HS512"])
def test_algorithm_symmetric_warning(
mocker: MockerFixture, caplog: pytest.LogCaptureFixture, algorithm: str
):
"""Test warning for symmetric algorithms (HS256, HS384, HS512)."""
secret = "a" * 32
mocker.patch.dict(
os.environ,
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": algorithm},
clear=True,
)
with caplog.at_level(logging.WARNING):
settings = Settings()
assert algorithm in caplog.text
assert "symmetric shared-key signature algorithm" in caplog.text
assert settings.JWT_ALGORITHM == algorithm
@pytest.mark.parametrize(
"algorithm",
["ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512"],
)
def test_algorithm_asymmetric_no_warning(
mocker: MockerFixture, caplog: pytest.LogCaptureFixture, algorithm: str
):
"""Test that asymmetric algorithms do not trigger warning."""
secret = "a" * 32
mocker.patch.dict(
os.environ,
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": algorithm},
clear=True,
)
with caplog.at_level(logging.WARNING):
settings = Settings()
# Should not contain the symmetric algorithm warning
assert "symmetric shared-key signature algorithm" not in caplog.text
assert settings.JWT_ALGORITHM == algorithm

View File

@@ -1,45 +0,0 @@
"""
FastAPI dependency functions for JWT-based authentication and authorization.
These are the high-level dependency functions used in route definitions.
"""
import fastapi
from .jwt_utils import get_jwt_payload, verify_user
from .models import User
def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
"""
FastAPI dependency that requires a valid authenticated user.
Raises:
HTTPException: 401 for authentication failures
"""
return verify_user(jwt_payload, admin_only=False)
def requires_admin_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
"""
FastAPI dependency that requires a valid admin user.
Raises:
HTTPException: 401 for authentication failures, 403 for insufficient permissions
"""
return verify_user(jwt_payload, admin_only=True)
def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
"""
FastAPI dependency that returns the ID of the authenticated user.
Raises:
HTTPException: 401 for authentication failures or missing user ID
"""
user_id = jwt_payload.get("sub")
if not user_id:
raise fastapi.HTTPException(
status_code=401, detail="User ID not found in token"
)
return user_id

View File

@@ -1,335 +0,0 @@
"""
Comprehensive integration tests for authentication dependencies.
Tests the full authentication flow from HTTP requests to user validation.
"""
import os
import pytest
from fastapi import FastAPI, HTTPException, Security
from fastapi.testclient import TestClient
from pytest_mock import MockerFixture
from autogpt_libs.auth.dependencies import (
get_user_id,
requires_admin_user,
requires_user,
)
from autogpt_libs.auth.models import User
class TestAuthDependencies:
"""Test suite for authentication dependency functions."""
@pytest.fixture
def app(self):
"""Create a test FastAPI application."""
app = FastAPI()
@app.get("/user")
def get_user_endpoint(user: User = Security(requires_user)):
return {"user_id": user.user_id, "role": user.role}
@app.get("/admin")
def get_admin_endpoint(user: User = Security(requires_admin_user)):
return {"user_id": user.user_id, "role": user.role}
@app.get("/user-id")
def get_user_id_endpoint(user_id: str = Security(get_user_id)):
return {"user_id": user_id}
return app
@pytest.fixture
def client(self, app):
"""Create a test client."""
return TestClient(app)
def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
"""Test requires_user with valid JWT payload."""
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
# Mock get_jwt_payload to return our test payload
mocker.patch(
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user = requires_user(jwt_payload)
assert isinstance(user, User)
assert user.user_id == "user-123"
assert user.role == "user"
def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
"""Test requires_user accepts admin users."""
jwt_payload = {
"sub": "admin-456",
"role": "admin",
"email": "admin@example.com",
}
mocker.patch(
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user = requires_user(jwt_payload)
assert user.user_id == "admin-456"
assert user.role == "admin"
def test_requires_user_missing_sub(self):
"""Test requires_user with missing user ID."""
jwt_payload = {"role": "user", "email": "user@example.com"}
with pytest.raises(HTTPException) as exc_info:
requires_user(jwt_payload)
assert exc_info.value.status_code == 401
assert "User ID not found" in exc_info.value.detail
def test_requires_user_empty_sub(self):
"""Test requires_user with empty user ID."""
jwt_payload = {"sub": "", "role": "user"}
with pytest.raises(HTTPException) as exc_info:
requires_user(jwt_payload)
assert exc_info.value.status_code == 401
def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
"""Test requires_admin_user with admin role."""
jwt_payload = {
"sub": "admin-789",
"role": "admin",
"email": "admin@example.com",
}
mocker.patch(
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user = requires_admin_user(jwt_payload)
assert user.user_id == "admin-789"
assert user.role == "admin"
def test_requires_admin_user_with_regular_user(self):
"""Test requires_admin_user rejects regular users."""
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
with pytest.raises(HTTPException) as exc_info:
requires_admin_user(jwt_payload)
assert exc_info.value.status_code == 403
assert "Admin access required" in exc_info.value.detail
def test_requires_admin_user_missing_role(self):
"""Test requires_admin_user with missing role."""
jwt_payload = {"sub": "user-123", "email": "user@example.com"}
with pytest.raises(KeyError):
requires_admin_user(jwt_payload)
def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
"""Test get_user_id extracts user ID correctly."""
jwt_payload = {"sub": "user-id-xyz", "role": "user"}
mocker.patch(
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user_id = get_user_id(jwt_payload)
assert user_id == "user-id-xyz"
def test_get_user_id_missing_sub(self):
"""Test get_user_id with missing user ID."""
jwt_payload = {"role": "user"}
with pytest.raises(HTTPException) as exc_info:
get_user_id(jwt_payload)
assert exc_info.value.status_code == 401
assert "User ID not found" in exc_info.value.detail
def test_get_user_id_none_sub(self):
"""Test get_user_id with None user ID."""
jwt_payload = {"sub": None, "role": "user"}
with pytest.raises(HTTPException) as exc_info:
get_user_id(jwt_payload)
assert exc_info.value.status_code == 401
class TestAuthDependenciesIntegration:
"""Integration tests for auth dependencies with FastAPI."""
acceptable_jwt_secret = "test-secret-with-proper-length-123456"
@pytest.fixture
def create_token(self, mocker: MockerFixture):
"""Helper to create JWT tokens."""
import jwt
mocker.patch.dict(
os.environ,
{"JWT_VERIFY_KEY": self.acceptable_jwt_secret},
clear=True,
)
def _create_token(payload, secret=self.acceptable_jwt_secret):
return jwt.encode(payload, secret, algorithm="HS256")
return _create_token
def test_endpoint_auth_enabled_no_token(self):
"""Test endpoints require token when auth is enabled."""
app = FastAPI()
@app.get("/test")
def test_endpoint(user: User = Security(requires_user)):
return {"user_id": user.user_id}
client = TestClient(app)
# Should fail without auth header
response = client.get("/test")
assert response.status_code == 401
def test_endpoint_with_valid_token(self, create_token):
"""Test endpoint with valid JWT token."""
app = FastAPI()
@app.get("/test")
def test_endpoint(user: User = Security(requires_user)):
return {"user_id": user.user_id, "role": user.role}
client = TestClient(app)
token = create_token(
{"sub": "test-user", "role": "user", "aud": "authenticated"},
secret=self.acceptable_jwt_secret,
)
response = client.get("/test", headers={"Authorization": f"Bearer {token}"})
assert response.status_code == 200
assert response.json()["user_id"] == "test-user"
def test_admin_endpoint_requires_admin_role(self, create_token):
"""Test admin endpoint rejects non-admin users."""
app = FastAPI()
@app.get("/admin")
def admin_endpoint(user: User = Security(requires_admin_user)):
return {"user_id": user.user_id}
client = TestClient(app)
# Regular user token
user_token = create_token(
{"sub": "regular-user", "role": "user", "aud": "authenticated"},
secret=self.acceptable_jwt_secret,
)
response = client.get(
"/admin", headers={"Authorization": f"Bearer {user_token}"}
)
assert response.status_code == 403
# Admin token
admin_token = create_token(
{"sub": "admin-user", "role": "admin", "aud": "authenticated"},
secret=self.acceptable_jwt_secret,
)
response = client.get(
"/admin", headers={"Authorization": f"Bearer {admin_token}"}
)
assert response.status_code == 200
assert response.json()["user_id"] == "admin-user"
class TestAuthDependenciesEdgeCases:
"""Edge case tests for authentication dependencies."""
def test_dependency_with_complex_payload(self):
"""Test dependencies handle complex JWT payloads."""
complex_payload = {
"sub": "user-123",
"role": "admin",
"email": "test@example.com",
"app_metadata": {"provider": "email", "providers": ["email"]},
"user_metadata": {
"full_name": "Test User",
"avatar_url": "https://example.com/avatar.jpg",
},
"aud": "authenticated",
"iat": 1234567890,
"exp": 9999999999,
}
user = requires_user(complex_payload)
assert user.user_id == "user-123"
assert user.email == "test@example.com"
admin = requires_admin_user(complex_payload)
assert admin.role == "admin"
def test_dependency_with_unicode_in_payload(self):
"""Test dependencies handle unicode in JWT payloads."""
unicode_payload = {
"sub": "user-😀-123",
"role": "user",
"email": "测试@example.com",
"name": "日本語",
}
user = requires_user(unicode_payload)
assert "😀" in user.user_id
assert user.email == "测试@example.com"
def test_dependency_with_null_values(self):
"""Test dependencies handle null values in payload."""
null_payload = {
"sub": "user-123",
"role": "user",
"email": None,
"phone": None,
"metadata": None,
}
user = requires_user(null_payload)
assert user.user_id == "user-123"
assert user.email is None
def test_concurrent_requests_isolation(self):
"""Test that concurrent requests don't interfere with each other."""
payload1 = {"sub": "user-1", "role": "user"}
payload2 = {"sub": "user-2", "role": "admin"}
# Simulate concurrent processing
user1 = requires_user(payload1)
user2 = requires_admin_user(payload2)
assert user1.user_id == "user-1"
assert user2.user_id == "user-2"
assert user1.role == "user"
assert user2.role == "admin"
@pytest.mark.parametrize(
"payload,expected_error,admin_only",
[
(None, "Authorization header is missing", False),
({}, "User ID not found", False),
({"sub": ""}, "User ID not found", False),
({"role": "user"}, "User ID not found", False),
({"sub": "user", "role": "user"}, "Admin access required", True),
],
)
def test_dependency_error_cases(
self, payload, expected_error: str, admin_only: bool
):
"""Test that errors propagate correctly through dependencies."""
# Import verify_user to test it directly since dependencies use FastAPI Security
from autogpt_libs.auth.jwt_utils import verify_user
with pytest.raises(HTTPException) as exc_info:
verify_user(payload, admin_only=admin_only)
assert expected_error in exc_info.value.detail
def test_dependency_valid_user(self):
"""Test valid user case for dependency."""
# Import verify_user to test it directly since dependencies use FastAPI Security
from autogpt_libs.auth.jwt_utils import verify_user
# Valid case
user = verify_user({"sub": "user", "role": "user"}, admin_only=False)
assert user.user_id == "user"

View File

@@ -0,0 +1,46 @@
import fastapi
from .config import settings
from .middleware import auth_middleware
from .models import DEFAULT_USER_ID, User
def requires_user(payload: dict = fastapi.Depends(auth_middleware)) -> User:
return verify_user(payload, admin_only=False)
def requires_admin_user(
payload: dict = fastapi.Depends(auth_middleware),
) -> User:
return verify_user(payload, admin_only=True)
def verify_user(payload: dict | None, admin_only: bool) -> User:
if not payload:
if settings.ENABLE_AUTH:
raise fastapi.HTTPException(
status_code=401, detail="Authorization header is missing"
)
# This handles the case when authentication is disabled
payload = {"sub": DEFAULT_USER_ID, "role": "admin"}
user_id = payload.get("sub")
if not user_id:
raise fastapi.HTTPException(
status_code=401, detail="User ID not found in token"
)
if admin_only and payload["role"] != "admin":
raise fastapi.HTTPException(status_code=403, detail="Admin access required")
return User.from_payload(payload)
def get_user_id(payload: dict = fastapi.Depends(auth_middleware)) -> str:
user_id = payload.get("sub")
if not user_id:
raise fastapi.HTTPException(
status_code=401, detail="User ID not found in token"
)
return user_id

View File

@@ -0,0 +1,68 @@
import pytest
from .depends import requires_admin_user, requires_user, verify_user
def test_verify_user_no_payload():
user = verify_user(None, admin_only=False)
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
assert user.role == "admin"
def test_verify_user_no_user_id():
with pytest.raises(Exception):
verify_user({"role": "admin"}, admin_only=False)
def test_verify_user_not_admin():
with pytest.raises(Exception):
verify_user(
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"},
admin_only=True,
)
def test_verify_user_with_admin_role():
user = verify_user(
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "admin"},
admin_only=True,
)
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
assert user.role == "admin"
def test_verify_user_with_user_role():
user = verify_user(
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"},
admin_only=False,
)
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
assert user.role == "user"
def test_requires_user():
user = requires_user(
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"}
)
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
assert user.role == "user"
def test_requires_user_no_user_id():
with pytest.raises(Exception):
requires_user({"role": "user"})
def test_requires_admin_user():
user = requires_admin_user(
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "admin"}
)
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
assert user.role == "admin"
def test_requires_admin_user_not_admin():
with pytest.raises(Exception):
requires_admin_user(
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"}
)

View File

@@ -1,68 +0,0 @@
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from .jwt_utils import bearer_jwt_auth
def add_auth_responses_to_openapi(app: FastAPI) -> None:
"""
Set up custom OpenAPI schema generation that adds 401 responses
to all authenticated endpoints.
This is needed when using HTTPBearer with auto_error=False to get proper
401 responses instead of 403, but FastAPI only automatically adds security
responses when auto_error=True.
"""
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
version=app.version,
description=app.description,
routes=app.routes,
)
# Add 401 response to all endpoints that have security requirements
for path, methods in openapi_schema["paths"].items():
for method, details in methods.items():
security_schemas = [
schema
for auth_option in details.get("security", [])
for schema in auth_option.keys()
]
if bearer_jwt_auth.scheme_name not in security_schemas:
continue
if "responses" not in details:
details["responses"] = {}
details["responses"]["401"] = {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
}
# Ensure #/components/responses exists
if "components" not in openapi_schema:
openapi_schema["components"] = {}
if "responses" not in openapi_schema["components"]:
openapi_schema["components"]["responses"] = {}
# Define 401 response
openapi_schema["components"]["responses"]["HTTP401NotAuthenticatedError"] = {
"description": "Authentication required",
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {"detail": {"type": "string"}},
}
}
},
}
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi

View File

@@ -1,435 +0,0 @@
"""
Comprehensive tests for auth helpers module to achieve 100% coverage.
Tests OpenAPI schema generation and authentication response handling.
"""
from unittest import mock
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from autogpt_libs.auth.helpers import add_auth_responses_to_openapi
from autogpt_libs.auth.jwt_utils import bearer_jwt_auth
def test_add_auth_responses_to_openapi_basic():
"""Test adding 401 responses to OpenAPI schema."""
app = FastAPI(title="Test App", version="1.0.0")
# Add some test endpoints with authentication
from fastapi import Depends
from autogpt_libs.auth.dependencies import requires_user
@app.get("/protected", dependencies=[Depends(requires_user)])
def protected_endpoint():
return {"message": "Protected"}
@app.get("/public")
def public_endpoint():
return {"message": "Public"}
# Apply the OpenAPI customization
add_auth_responses_to_openapi(app)
# Get the OpenAPI schema
schema = app.openapi()
# Verify basic schema properties
assert schema["info"]["title"] == "Test App"
assert schema["info"]["version"] == "1.0.0"
# Verify 401 response component is added
assert "components" in schema
assert "responses" in schema["components"]
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
# Verify 401 response structure
error_response = schema["components"]["responses"]["HTTP401NotAuthenticatedError"]
assert error_response["description"] == "Authentication required"
assert "application/json" in error_response["content"]
assert "schema" in error_response["content"]["application/json"]
# Verify schema properties
response_schema = error_response["content"]["application/json"]["schema"]
assert response_schema["type"] == "object"
assert "detail" in response_schema["properties"]
assert response_schema["properties"]["detail"]["type"] == "string"
def test_add_auth_responses_to_openapi_with_security():
"""Test that 401 responses are added only to secured endpoints."""
app = FastAPI()
# Mock endpoint with security
from fastapi import Security
from autogpt_libs.auth.dependencies import get_user_id
@app.get("/secured")
def secured_endpoint(user_id: str = Security(get_user_id)):
return {"user_id": user_id}
@app.post("/also-secured")
def another_secured(user_id: str = Security(get_user_id)):
return {"status": "ok"}
@app.get("/unsecured")
def unsecured_endpoint():
return {"public": True}
# Apply OpenAPI customization
add_auth_responses_to_openapi(app)
# Get schema
schema = app.openapi()
# Check that secured endpoints have 401 responses
if "/secured" in schema["paths"]:
if "get" in schema["paths"]["/secured"]:
secured_get = schema["paths"]["/secured"]["get"]
if "responses" in secured_get:
assert "401" in secured_get["responses"]
assert (
secured_get["responses"]["401"]["$ref"]
== "#/components/responses/HTTP401NotAuthenticatedError"
)
if "/also-secured" in schema["paths"]:
if "post" in schema["paths"]["/also-secured"]:
secured_post = schema["paths"]["/also-secured"]["post"]
if "responses" in secured_post:
assert "401" in secured_post["responses"]
# Check that unsecured endpoint does not have 401 response
if "/unsecured" in schema["paths"]:
if "get" in schema["paths"]["/unsecured"]:
unsecured_get = schema["paths"]["/unsecured"]["get"]
if "responses" in unsecured_get:
assert "401" not in unsecured_get.get("responses", {})
def test_add_auth_responses_to_openapi_cached_schema():
"""Test that OpenAPI schema is cached after first generation."""
app = FastAPI()
# Apply customization
add_auth_responses_to_openapi(app)
# Get schema twice
schema1 = app.openapi()
schema2 = app.openapi()
# Should return the same cached object
assert schema1 is schema2
def test_add_auth_responses_to_openapi_existing_responses():
"""Test handling endpoints that already have responses defined."""
app = FastAPI()
from fastapi import Security
from autogpt_libs.auth.jwt_utils import get_jwt_payload
@app.get(
"/with-responses",
responses={
200: {"description": "Success"},
404: {"description": "Not found"},
},
)
def endpoint_with_responses(jwt: dict = Security(get_jwt_payload)):
return {"data": "test"}
# Apply customization
add_auth_responses_to_openapi(app)
schema = app.openapi()
# Check that existing responses are preserved and 401 is added
if "/with-responses" in schema["paths"]:
if "get" in schema["paths"]["/with-responses"]:
responses = schema["paths"]["/with-responses"]["get"].get("responses", {})
# Original responses should be preserved
if "200" in responses:
assert responses["200"]["description"] == "Success"
if "404" in responses:
assert responses["404"]["description"] == "Not found"
# 401 should be added
if "401" in responses:
assert (
responses["401"]["$ref"]
== "#/components/responses/HTTP401NotAuthenticatedError"
)
def test_add_auth_responses_to_openapi_no_security_endpoints():
"""Test with app that has no secured endpoints."""
app = FastAPI()
@app.get("/public1")
def public1():
return {"message": "public1"}
@app.post("/public2")
def public2():
return {"message": "public2"}
# Apply customization
add_auth_responses_to_openapi(app)
schema = app.openapi()
# Component should still be added for consistency
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
# But no endpoints should have 401 responses
for path in schema["paths"].values():
for method in path.values():
if isinstance(method, dict) and "responses" in method:
assert "401" not in method["responses"]
def test_add_auth_responses_to_openapi_multiple_security_schemes():
"""Test endpoints with multiple security requirements."""
app = FastAPI()
from fastapi import Security
from autogpt_libs.auth.dependencies import requires_admin_user, requires_user
from autogpt_libs.auth.models import User
@app.get("/multi-auth")
def multi_auth(
user: User = Security(requires_user),
admin: User = Security(requires_admin_user),
):
return {"status": "super secure"}
# Apply customization
add_auth_responses_to_openapi(app)
schema = app.openapi()
# Should have 401 response
if "/multi-auth" in schema["paths"]:
if "get" in schema["paths"]["/multi-auth"]:
responses = schema["paths"]["/multi-auth"]["get"].get("responses", {})
if "401" in responses:
assert (
responses["401"]["$ref"]
== "#/components/responses/HTTP401NotAuthenticatedError"
)
def test_add_auth_responses_to_openapi_empty_components():
"""Test when OpenAPI schema has no components section initially."""
app = FastAPI()
# Mock get_openapi to return schema without components
original_get_openapi = get_openapi
def mock_get_openapi(*args, **kwargs):
schema = original_get_openapi(*args, **kwargs)
# Remove components if it exists
if "components" in schema:
del schema["components"]
return schema
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
# Apply customization
add_auth_responses_to_openapi(app)
schema = app.openapi()
# Components should be created
assert "components" in schema
assert "responses" in schema["components"]
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
def test_add_auth_responses_to_openapi_all_http_methods():
"""Test that all HTTP methods are handled correctly."""
app = FastAPI()
from fastapi import Security
from autogpt_libs.auth.jwt_utils import get_jwt_payload
@app.get("/resource")
def get_resource(jwt: dict = Security(get_jwt_payload)):
return {"method": "GET"}
@app.post("/resource")
def post_resource(jwt: dict = Security(get_jwt_payload)):
return {"method": "POST"}
@app.put("/resource")
def put_resource(jwt: dict = Security(get_jwt_payload)):
return {"method": "PUT"}
@app.patch("/resource")
def patch_resource(jwt: dict = Security(get_jwt_payload)):
return {"method": "PATCH"}
@app.delete("/resource")
def delete_resource(jwt: dict = Security(get_jwt_payload)):
return {"method": "DELETE"}
# Apply customization
add_auth_responses_to_openapi(app)
schema = app.openapi()
# All methods should have 401 response
if "/resource" in schema["paths"]:
for method in ["get", "post", "put", "patch", "delete"]:
if method in schema["paths"]["/resource"]:
method_spec = schema["paths"]["/resource"][method]
if "responses" in method_spec:
assert "401" in method_spec["responses"]
def test_bearer_jwt_auth_scheme_config():
"""Test that bearer_jwt_auth is configured correctly."""
assert bearer_jwt_auth.scheme_name == "HTTPBearerJWT"
assert bearer_jwt_auth.auto_error is False
def test_add_auth_responses_with_no_routes():
"""Test OpenAPI generation with app that has no routes."""
app = FastAPI(title="Empty App")
# Apply customization to empty app
add_auth_responses_to_openapi(app)
schema = app.openapi()
# Should still have basic structure
assert schema["info"]["title"] == "Empty App"
assert "components" in schema
assert "responses" in schema["components"]
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
def test_custom_openapi_function_replacement():
"""Test that the custom openapi function properly replaces the default."""
app = FastAPI()
# Store original function
original_openapi = app.openapi
# Apply customization
add_auth_responses_to_openapi(app)
# Function should be replaced
assert app.openapi != original_openapi
assert callable(app.openapi)
def test_endpoint_without_responses_section():
"""Test endpoint that has security but no responses section initially."""
app = FastAPI()
from fastapi import Security
from fastapi.openapi.utils import get_openapi as original_get_openapi
from autogpt_libs.auth.jwt_utils import get_jwt_payload
# Create endpoint
@app.get("/no-responses")
def endpoint_without_responses(jwt: dict = Security(get_jwt_payload)):
return {"data": "test"}
# Mock get_openapi to remove responses from the endpoint
def mock_get_openapi(*args, **kwargs):
schema = original_get_openapi(*args, **kwargs)
# Remove responses from our endpoint to trigger line 40
if "/no-responses" in schema.get("paths", {}):
if "get" in schema["paths"]["/no-responses"]:
# Delete responses to force the code to create it
if "responses" in schema["paths"]["/no-responses"]["get"]:
del schema["paths"]["/no-responses"]["get"]["responses"]
return schema
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
# Apply customization
add_auth_responses_to_openapi(app)
# Get schema and verify 401 was added
schema = app.openapi()
# The endpoint should now have 401 response
if "/no-responses" in schema["paths"]:
if "get" in schema["paths"]["/no-responses"]:
responses = schema["paths"]["/no-responses"]["get"].get("responses", {})
assert "401" in responses
assert (
responses["401"]["$ref"]
== "#/components/responses/HTTP401NotAuthenticatedError"
)
def test_components_with_existing_responses():
"""Test when components already has a responses section."""
app = FastAPI()
# Mock get_openapi to return schema with existing components/responses
from fastapi.openapi.utils import get_openapi as original_get_openapi
def mock_get_openapi(*args, **kwargs):
schema = original_get_openapi(*args, **kwargs)
# Add existing components/responses
if "components" not in schema:
schema["components"] = {}
schema["components"]["responses"] = {
"ExistingResponse": {"description": "An existing response"}
}
return schema
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
# Apply customization
add_auth_responses_to_openapi(app)
schema = app.openapi()
# Both responses should exist
assert "ExistingResponse" in schema["components"]["responses"]
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
# Verify our 401 response structure
error_response = schema["components"]["responses"][
"HTTP401NotAuthenticatedError"
]
assert error_response["description"] == "Authentication required"
def test_openapi_schema_persistence():
"""Test that modifications to OpenAPI schema persist correctly."""
app = FastAPI()
from fastapi import Security
from autogpt_libs.auth.jwt_utils import get_jwt_payload
@app.get("/test")
def test_endpoint(jwt: dict = Security(get_jwt_payload)):
return {"test": True}
# Apply customization
add_auth_responses_to_openapi(app)
# Get schema multiple times
schema1 = app.openapi()
# Modify the cached schema (shouldn't affect future calls)
schema1["info"]["title"] = "Modified Title"
# Clear cache and get again
app.openapi_schema = None
schema2 = app.openapi()
# Should regenerate with original title
assert schema2["info"]["title"] == app.title
assert schema2["info"]["title"] != "Modified Title"

View File

@@ -1,48 +1,11 @@
import logging
from typing import Any
from typing import Any, Dict
import jwt
from fastapi import HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from .config import get_settings
from .models import User
logger = logging.getLogger(__name__)
# Bearer token authentication scheme
bearer_jwt_auth = HTTPBearer(
bearerFormat="jwt", scheme_name="HTTPBearerJWT", auto_error=False
)
from .config import settings
def get_jwt_payload(
credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
) -> dict[str, Any]:
"""
Extract and validate JWT payload from HTTP Authorization header.
This is the core authentication function that handles:
- Reading the `Authorization` header to obtain the JWT token
- Verifying the JWT token's signature
- Decoding the JWT token's payload
:param credentials: HTTP Authorization credentials from bearer token
:return: JWT payload dictionary
:raises HTTPException: 401 if authentication fails
"""
if not credentials:
raise HTTPException(status_code=401, detail="Authorization header is missing")
try:
payload = parse_jwt_token(credentials.credentials)
logger.debug("Token decoded successfully")
return payload
except ValueError as e:
raise HTTPException(status_code=401, detail=str(e))
def parse_jwt_token(token: str) -> dict[str, Any]:
def parse_jwt_token(token: str) -> Dict[str, Any]:
"""
Parse and validate a JWT token.
@@ -50,11 +13,10 @@ def parse_jwt_token(token: str) -> dict[str, Any]:
:return: The decoded payload
:raises ValueError: If the token is invalid or expired
"""
settings = get_settings()
try:
payload = jwt.decode(
token,
settings.JWT_VERIFY_KEY,
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM],
audience="authenticated",
)
@@ -63,18 +25,3 @@ def parse_jwt_token(token: str) -> dict[str, Any]:
raise ValueError("Token has expired")
except jwt.InvalidTokenError as e:
raise ValueError(f"Invalid token: {str(e)}")
def verify_user(jwt_payload: dict | None, admin_only: bool) -> User:
if jwt_payload is None:
raise HTTPException(status_code=401, detail="Authorization header is missing")
user_id = jwt_payload.get("sub")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found in token")
if admin_only and jwt_payload["role"] != "admin":
raise HTTPException(status_code=403, detail="Admin access required")
return User.from_payload(jwt_payload)

View File

@@ -1,308 +0,0 @@
"""
Comprehensive tests for JWT token parsing and validation.
Ensures 100% line and branch coverage for JWT security functions.
"""
import os
from datetime import datetime, timedelta, timezone
import jwt
import pytest
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
from pytest_mock import MockerFixture
from autogpt_libs.auth import config, jwt_utils
from autogpt_libs.auth.config import Settings
from autogpt_libs.auth.models import User
MOCK_JWT_SECRET = "test-secret-key-with-at-least-32-characters"
TEST_USER_PAYLOAD = {
"sub": "test-user-id",
"role": "user",
"aud": "authenticated",
"email": "test@example.com",
}
TEST_ADMIN_PAYLOAD = {
"sub": "admin-user-id",
"role": "admin",
"aud": "authenticated",
"email": "admin@example.com",
}
@pytest.fixture(autouse=True)
def mock_config(mocker: MockerFixture):
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": MOCK_JWT_SECRET}, clear=True)
mocker.patch.object(config, "_settings", Settings())
yield
def create_token(payload, secret=None, algorithm="HS256"):
"""Helper to create JWT tokens."""
if secret is None:
secret = MOCK_JWT_SECRET
return jwt.encode(payload, secret, algorithm=algorithm)
def test_parse_jwt_token_valid():
"""Test parsing a valid JWT token."""
token = create_token(TEST_USER_PAYLOAD)
result = jwt_utils.parse_jwt_token(token)
assert result["sub"] == "test-user-id"
assert result["role"] == "user"
assert result["aud"] == "authenticated"
def test_parse_jwt_token_expired():
"""Test parsing an expired JWT token."""
expired_payload = {
**TEST_USER_PAYLOAD,
"exp": datetime.now(timezone.utc) - timedelta(hours=1),
}
token = create_token(expired_payload)
with pytest.raises(ValueError) as exc_info:
jwt_utils.parse_jwt_token(token)
assert "Token has expired" in str(exc_info.value)
def test_parse_jwt_token_invalid_signature():
"""Test parsing a token with invalid signature."""
# Create token with different secret
token = create_token(TEST_USER_PAYLOAD, secret="wrong-secret")
with pytest.raises(ValueError) as exc_info:
jwt_utils.parse_jwt_token(token)
assert "Invalid token" in str(exc_info.value)
def test_parse_jwt_token_malformed():
"""Test parsing a malformed token."""
malformed_tokens = [
"not.a.token",
"invalid",
"",
# Header only
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9",
# No signature
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0",
]
for token in malformed_tokens:
with pytest.raises(ValueError) as exc_info:
jwt_utils.parse_jwt_token(token)
assert "Invalid token" in str(exc_info.value)
def test_parse_jwt_token_wrong_audience():
"""Test parsing a token with wrong audience."""
wrong_aud_payload = {**TEST_USER_PAYLOAD, "aud": "wrong-audience"}
token = create_token(wrong_aud_payload)
with pytest.raises(ValueError) as exc_info:
jwt_utils.parse_jwt_token(token)
assert "Invalid token" in str(exc_info.value)
def test_parse_jwt_token_missing_audience():
"""Test parsing a token without audience claim."""
no_aud_payload = {k: v for k, v in TEST_USER_PAYLOAD.items() if k != "aud"}
token = create_token(no_aud_payload)
with pytest.raises(ValueError) as exc_info:
jwt_utils.parse_jwt_token(token)
assert "Invalid token" in str(exc_info.value)
def test_get_jwt_payload_with_valid_token():
"""Test extracting JWT payload with valid bearer token."""
token = create_token(TEST_USER_PAYLOAD)
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
result = jwt_utils.get_jwt_payload(credentials)
assert result["sub"] == "test-user-id"
assert result["role"] == "user"
def test_get_jwt_payload_no_credentials():
"""Test JWT payload when no credentials provided."""
with pytest.raises(HTTPException) as exc_info:
jwt_utils.get_jwt_payload(None)
assert exc_info.value.status_code == 401
assert "Authorization header is missing" in exc_info.value.detail
def test_get_jwt_payload_invalid_token():
"""Test JWT payload extraction with invalid token."""
credentials = HTTPAuthorizationCredentials(
scheme="Bearer", credentials="invalid.token.here"
)
with pytest.raises(HTTPException) as exc_info:
jwt_utils.get_jwt_payload(credentials)
assert exc_info.value.status_code == 401
assert "Invalid token" in exc_info.value.detail
def test_verify_user_with_valid_user():
"""Test verifying a valid user."""
user = jwt_utils.verify_user(TEST_USER_PAYLOAD, admin_only=False)
assert isinstance(user, User)
assert user.user_id == "test-user-id"
assert user.role == "user"
assert user.email == "test@example.com"
def test_verify_user_with_admin():
"""Test verifying an admin user."""
user = jwt_utils.verify_user(TEST_ADMIN_PAYLOAD, admin_only=True)
assert isinstance(user, User)
assert user.user_id == "admin-user-id"
assert user.role == "admin"
def test_verify_user_admin_only_with_regular_user():
"""Test verifying regular user when admin is required."""
with pytest.raises(HTTPException) as exc_info:
jwt_utils.verify_user(TEST_USER_PAYLOAD, admin_only=True)
assert exc_info.value.status_code == 403
assert "Admin access required" in exc_info.value.detail
def test_verify_user_no_payload():
"""Test verifying user with no payload."""
with pytest.raises(HTTPException) as exc_info:
jwt_utils.verify_user(None, admin_only=False)
assert exc_info.value.status_code == 401
assert "Authorization header is missing" in exc_info.value.detail
def test_verify_user_missing_sub():
"""Test verifying user with payload missing 'sub' field."""
invalid_payload = {"role": "user", "email": "test@example.com"}
with pytest.raises(HTTPException) as exc_info:
jwt_utils.verify_user(invalid_payload, admin_only=False)
assert exc_info.value.status_code == 401
assert "User ID not found in token" in exc_info.value.detail
def test_verify_user_empty_sub():
"""Test verifying user with empty 'sub' field."""
invalid_payload = {"sub": "", "role": "user"}
with pytest.raises(HTTPException) as exc_info:
jwt_utils.verify_user(invalid_payload, admin_only=False)
assert exc_info.value.status_code == 401
assert "User ID not found in token" in exc_info.value.detail
def test_verify_user_none_sub():
"""Test verifying user with None 'sub' field."""
invalid_payload = {"sub": None, "role": "user"}
with pytest.raises(HTTPException) as exc_info:
jwt_utils.verify_user(invalid_payload, admin_only=False)
assert exc_info.value.status_code == 401
assert "User ID not found in token" in exc_info.value.detail
def test_verify_user_missing_role_admin_check():
"""Test verifying admin when role field is missing."""
no_role_payload = {"sub": "user-id"}
with pytest.raises(KeyError):
# This will raise KeyError when checking payload["role"]
jwt_utils.verify_user(no_role_payload, admin_only=True)
# ======================== EDGE CASES ======================== #
def test_jwt_with_additional_claims():
"""Test JWT token with additional custom claims."""
extra_claims_payload = {
"sub": "user-id",
"role": "user",
"aud": "authenticated",
"custom_claim": "custom_value",
"permissions": ["read", "write"],
"metadata": {"key": "value"},
}
token = create_token(extra_claims_payload)
result = jwt_utils.parse_jwt_token(token)
assert result["sub"] == "user-id"
assert result["custom_claim"] == "custom_value"
assert result["permissions"] == ["read", "write"]
def test_jwt_with_numeric_sub():
"""Test JWT token with numeric user ID."""
payload = {
"sub": 12345, # Numeric ID
"role": "user",
"aud": "authenticated",
}
# Should convert to string internally
user = jwt_utils.verify_user(payload, admin_only=False)
assert user.user_id == 12345
def test_jwt_with_very_long_sub():
"""Test JWT token with very long user ID."""
long_id = "a" * 1000
payload = {
"sub": long_id,
"role": "user",
"aud": "authenticated",
}
user = jwt_utils.verify_user(payload, admin_only=False)
assert user.user_id == long_id
def test_jwt_with_special_characters_in_claims():
"""Test JWT token with special characters in claims."""
payload = {
"sub": "user@example.com/special-chars!@#$%",
"role": "admin",
"aud": "authenticated",
"email": "test+special@example.com",
}
user = jwt_utils.verify_user(payload, admin_only=True)
assert "special-chars!@#$%" in user.user_id
def test_jwt_with_future_iat():
"""Test JWT token with issued-at time in future."""
future_payload = {
"sub": "user-id",
"role": "user",
"aud": "authenticated",
"iat": datetime.now(timezone.utc) + timedelta(hours=1),
}
token = create_token(future_payload)
# PyJWT validates iat claim and should reject future tokens
with pytest.raises(ValueError, match="not yet valid"):
jwt_utils.parse_jwt_token(token)
def test_jwt_with_different_algorithms():
"""Test that only HS256 algorithm is accepted."""
payload = {
"sub": "user-id",
"role": "user",
"aud": "authenticated",
}
# Try different algorithms
algorithms = ["HS384", "HS512", "none"]
for algo in algorithms:
if algo == "none":
# Special case for 'none' algorithm (security vulnerability if accepted)
token = create_token(payload, "", algorithm="none")
else:
token = create_token(payload, algorithm=algo)
with pytest.raises(ValueError) as exc_info:
jwt_utils.parse_jwt_token(token)
assert "Invalid token" in str(exc_info.value)

View File

@@ -0,0 +1,139 @@
import inspect
import logging
import secrets
from typing import Any, Callable, Optional
from fastapi import HTTPException, Request, Security
from fastapi.security import APIKeyHeader, HTTPBearer
from starlette.status import HTTP_401_UNAUTHORIZED
from .config import settings
from .jwt_utils import parse_jwt_token
logger = logging.getLogger(__name__)
bearer_auth = HTTPBearer(auto_error=False)
async def auth_middleware(request: Request):
if not settings.ENABLE_AUTH:
# If authentication is disabled, allow the request to proceed
logger.warning("Auth disabled")
return {}
credentials = await bearer_auth(request)
if not credentials:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
payload = parse_jwt_token(credentials.credentials)
request.state.user = payload
logger.debug("Token decoded successfully")
except ValueError as e:
raise HTTPException(status_code=401, detail=str(e))
return payload
class APIKeyValidator:
"""
Configurable API key validator that supports custom validation functions
for FastAPI applications.
This class provides a flexible way to implement API key authentication with optional
custom validation logic. It can be used for simple token matching
or more complex validation scenarios like database lookups.
Examples:
Simple token validation:
```python
validator = APIKeyValidator(
header_name="X-API-Key",
expected_token="your-secret-token"
)
@app.get("/protected", dependencies=[Depends(validator.get_dependency())])
def protected_endpoint():
return {"message": "Access granted"}
```
Custom validation with database lookup:
```python
async def validate_with_db(api_key: str):
api_key_obj = await db.get_api_key(api_key)
return api_key_obj if api_key_obj and api_key_obj.is_active else None
validator = APIKeyValidator(
header_name="X-API-Key",
validate_fn=validate_with_db
)
```
Args:
header_name (str): The name of the header containing the API key
expected_token (Optional[str]): The expected API key value for simple token matching
validate_fn (Optional[Callable]): Custom validation function that takes an API key
string and returns a boolean or object. Can be async.
error_status (int): HTTP status code to use for validation errors
error_message (str): Error message to return when validation fails
"""
def __init__(
self,
header_name: str,
expected_token: Optional[str] = None,
validate_fn: Optional[Callable[[str], bool]] = None,
error_status: int = HTTP_401_UNAUTHORIZED,
error_message: str = "Invalid API key",
):
# Create the APIKeyHeader as a class property
self.security_scheme = APIKeyHeader(name=header_name)
self.expected_token = expected_token
self.custom_validate_fn = validate_fn
self.error_status = error_status
self.error_message = error_message
async def default_validator(self, api_key: str) -> bool:
if not self.expected_token:
raise ValueError(
"Expected Token Required to be set when uisng API Key Validator default validation"
)
return secrets.compare_digest(api_key, self.expected_token)
async def __call__(
self, request: Request, api_key: str = Security(APIKeyHeader)
) -> Any:
if api_key is None:
raise HTTPException(status_code=self.error_status, detail="Missing API key")
# Use custom validation if provided, otherwise use default equality check
validator = self.custom_validate_fn or self.default_validator
result = (
await validator(api_key)
if inspect.iscoroutinefunction(validator)
else validator(api_key)
)
if not result:
raise HTTPException(
status_code=self.error_status, detail=self.error_message
)
# Store validation result in request state if it's not just a boolean
if result is not True:
request.state.api_key = result
return result
def get_dependency(self):
"""
Returns a callable dependency that FastAPI will recognize as a security scheme
"""
async def validate_api_key(
request: Request, api_key: str = Security(self.security_scheme)
) -> Any:
return await self(request, api_key)
# This helps FastAPI recognize it as a security dependency
validate_api_key.__name__ = f"validate_{self.security_scheme.model.name}"
return validate_api_key

View File

@@ -1,5 +1,3 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -15,8 +13,8 @@ class RateLimitSettings(BaseSettings):
default="6379", description="Redis port", validation_alias="REDIS_PORT"
)
redis_password: Optional[str] = Field(
default=None,
redis_password: str = Field(
default="password",
description="Redis password",
validation_alias="REDIS_PASSWORD",
)

View File

@@ -11,7 +11,7 @@ class RateLimiter:
self,
redis_host: str = RATE_LIMIT_SETTINGS.redis_host,
redis_port: str = RATE_LIMIT_SETTINGS.redis_port,
redis_password: str | None = RATE_LIMIT_SETTINGS.redis_password,
redis_password: str = RATE_LIMIT_SETTINGS.redis_password,
requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute,
):
self.redis = Redis(

View File

@@ -1,68 +1,90 @@
import asyncio
import inspect
import logging
import threading
import time
from functools import wraps
from typing import (
Any,
Awaitable,
Callable,
ParamSpec,
Protocol,
Tuple,
TypeVar,
cast,
overload,
runtime_checkable,
)
P = ParamSpec("P")
R = TypeVar("R")
R_co = TypeVar("R_co", covariant=True)
logger = logging.getLogger(__name__)
def _make_hashable_key(
args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[Any, ...]:
"""
Convert args and kwargs into a hashable cache key.
@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
pass
Handles unhashable types like dict, list, set by converting them to
their sorted string representations.
"""
def make_hashable(obj: Any) -> Any:
"""Recursively convert an object to a hashable representation."""
if isinstance(obj, dict):
# Sort dict items to ensure consistent ordering
return (
"__dict__",
tuple(sorted((k, make_hashable(v)) for k, v in obj.items())),
)
elif isinstance(obj, (list, tuple)):
return ("__list__", tuple(make_hashable(item) for item in obj))
elif isinstance(obj, set):
return ("__set__", tuple(sorted(make_hashable(item) for item in obj)))
elif hasattr(obj, "__dict__"):
# Handle objects with __dict__ attribute
return ("__obj__", obj.__class__.__name__, make_hashable(obj.__dict__))
else:
# For basic hashable types (str, int, bool, None, etc.)
try:
hash(obj)
return obj
except TypeError:
# Fallback: convert to string representation
return ("__str__", str(obj))
@overload
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
pass
hashable_args = tuple(make_hashable(arg) for arg in args)
hashable_kwargs = tuple(sorted((k, make_hashable(v)) for k, v in kwargs.items()))
return (hashable_args, hashable_kwargs)
def thread_cached(
func: Callable[P, R] | Callable[P, Awaitable[R]],
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
thread_local = threading.local()
def _clear():
if hasattr(thread_local, "cache"):
del thread_local.cache
if inspect.iscoroutinefunction(func):
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = await cast(Callable[P, Awaitable[R]], func)(
*args, **kwargs
)
return cache[key]
setattr(async_wrapper, "clear_cache", _clear)
return async_wrapper
else:
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]
setattr(sync_wrapper, "clear_cache", _clear)
return sync_wrapper
def clear_thread_cache(func: Callable) -> None:
if clear := getattr(func, "clear_cache", None):
clear()
FuncT = TypeVar("FuncT")
R_co = TypeVar("R_co", covariant=True)
@runtime_checkable
class CachedFunction(Protocol[P, R_co]):
"""Protocol for cached functions with cache management methods."""
class AsyncCachedFunction(Protocol[P, R_co]):
"""Protocol for async functions with cache management methods."""
def cache_clear(self) -> None:
"""Clear all cached entries."""
@@ -72,169 +94,101 @@ class CachedFunction(Protocol[P, R_co]):
"""Get cache statistics."""
return {}
def cache_delete(self, *args: P.args, **kwargs: P.kwargs) -> bool:
"""Delete a specific cache entry by its arguments. Returns True if entry existed."""
return False
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
"""Call the cached function."""
return None # type: ignore
def cached(
*,
maxsize: int = 128,
ttl_seconds: int | None = None,
) -> Callable[[Callable], CachedFunction]:
def async_ttl_cache(
maxsize: int = 128, ttl_seconds: int | None = None
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
"""
Thundering herd safe cache decorator for both sync and async functions.
TTL (Time To Live) cache decorator for async functions.
Uses double-checked locking to prevent multiple threads/coroutines from
executing the expensive operation simultaneously during cache misses.
Similar to functools.lru_cache but works with async functions and includes optional TTL.
Args:
func: The function to cache (when used without parentheses)
maxsize: Maximum number of cached entries
ttl_seconds: Time to live in seconds. If None, entries never expire
ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)
Returns:
Decorated function or decorator
Decorator function
Example:
@cache() # Default: maxsize=128, no TTL
def expensive_sync_operation(param: str) -> dict:
# With TTL
@async_ttl_cache(maxsize=1000, ttl_seconds=300)
async def api_call(param: str) -> dict:
return {"result": param}
@cache() # Works with async too
async def expensive_async_operation(param: str) -> dict:
return {"result": param}
@cache(maxsize=1000, ttl_seconds=300) # Custom maxsize and TTL
def another_operation(param: str) -> dict:
# Without TTL (permanent cache like lru_cache)
@async_ttl_cache(maxsize=1000)
async def expensive_computation(param: str) -> dict:
return {"result": param}
"""
def decorator(target_func):
# Cache storage and locks
cache_storage = {}
def decorator(
async_func: Callable[P, Awaitable[R]],
) -> AsyncCachedFunction[P, R]:
# Cache storage - use union type to handle both cases
cache_storage: dict[tuple, R | Tuple[R, float]] = {}
if inspect.iscoroutinefunction(target_func):
# Async function with asyncio.Lock
cache_lock = asyncio.Lock()
@wraps(async_func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# Create cache key from arguments
key = (args, tuple(sorted(kwargs.items())))
current_time = time.time()
@wraps(target_func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
key = _make_hashable_key(args, kwargs)
current_time = time.time()
# Fast path: check cache without lock
if key in cache_storage:
if ttl_seconds is None:
logger.debug(f"Cache hit for {target_func.__name__}")
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(f"Cache hit for {target_func.__name__}")
return result
# Slow path: acquire lock for cache miss/expiry
async with cache_lock:
# Double-check: another coroutine might have populated cache
if key in cache_storage:
if ttl_seconds is None:
return cache_storage[key]
# Check if we have a valid cached entry
if key in cache_storage:
if ttl_seconds is None:
# No TTL - return cached result directly
logger.debug(
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
)
return cast(R, cache_storage[key])
else:
# With TTL - check expiration
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
)
return cast(R, result)
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
return result
# Expired entry
del cache_storage[key]
logger.debug(
f"Cache entry expired for {async_func.__name__}"
)
# Cache miss - execute function
logger.debug(f"Cache miss for {target_func.__name__}")
result = await target_func(*args, **kwargs)
# Cache miss or expired - fetch fresh data
logger.debug(
f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
)
result = await async_func(*args, **kwargs)
# Store result
if ttl_seconds is None:
cache_storage[key] = result
else:
cache_storage[key] = (result, current_time)
# Store in cache
if ttl_seconds is None:
cache_storage[key] = result
else:
cache_storage[key] = (result, current_time)
# Cleanup if needed
if len(cache_storage) > maxsize:
cutoff = maxsize // 2
oldest_keys = (
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
)
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
# Simple cleanup when cache gets too large
if len(cache_storage) > maxsize:
# Remove oldest entries (simple FIFO cleanup)
cutoff = maxsize // 2
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
logger.debug(
f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
)
return result
return result
wrapper = async_wrapper
else:
# Sync function with threading.Lock
cache_lock = threading.Lock()
@wraps(target_func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs):
key = _make_hashable_key(args, kwargs)
current_time = time.time()
# Fast path: check cache without lock
if key in cache_storage:
if ttl_seconds is None:
logger.debug(f"Cache hit for {target_func.__name__}")
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(f"Cache hit for {target_func.__name__}")
return result
# Slow path: acquire lock for cache miss/expiry
with cache_lock:
# Double-check: another thread might have populated cache
if key in cache_storage:
if ttl_seconds is None:
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {target_func.__name__}")
result = target_func(*args, **kwargs)
# Store result
if ttl_seconds is None:
cache_storage[key] = result
else:
cache_storage[key] = (result, current_time)
# Cleanup if needed
if len(cache_storage) > maxsize:
cutoff = maxsize // 2
oldest_keys = (
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
)
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
return result
wrapper = sync_wrapper
# Add cache management methods
# Add cache management methods (similar to functools.lru_cache)
def cache_clear() -> None:
cache_storage.clear()
@@ -245,84 +199,68 @@ def cached(
"ttl_seconds": ttl_seconds,
}
def cache_delete(*args, **kwargs) -> bool:
"""Delete a specific cache entry. Returns True if entry existed."""
key = _make_hashable_key(args, kwargs)
if key in cache_storage:
del cache_storage[key]
return True
return False
# Attach methods to wrapper
setattr(wrapper, "cache_clear", cache_clear)
setattr(wrapper, "cache_info", cache_info)
setattr(wrapper, "cache_delete", cache_delete)
return cast(CachedFunction, wrapper)
return cast(AsyncCachedFunction[P, R], wrapper)
return decorator
def thread_cached(func):
"""
Thread-local cache decorator for both sync and async functions.
@overload
def async_cache(
func: Callable[P, Awaitable[R]],
) -> AsyncCachedFunction[P, R]:
pass
Each thread gets its own cache, which is useful for request-scoped caching
in web applications where you want to cache within a single request but
not across requests.
@overload
def async_cache(
func: None = None,
*,
maxsize: int = 128,
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
pass
def async_cache(
func: Callable[P, Awaitable[R]] | None = None,
*,
maxsize: int = 128,
) -> (
AsyncCachedFunction[P, R]
| Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]
):
"""
Process-level cache decorator for async functions (no TTL).
Similar to functools.lru_cache but works with async functions.
This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.
Args:
func: The function to cache
func: The async function to cache (when used without parentheses)
maxsize: Maximum number of cached entries
Returns:
Decorated function with thread-local caching
Decorated function or decorator
Example:
@thread_cached
def expensive_operation(param: str) -> dict:
# Without parentheses (uses default maxsize=128)
@async_cache
async def get_data(param: str) -> dict:
return {"result": param}
@thread_cached # Works with async too
async def expensive_async_operation(param: str) -> dict:
# With parentheses and custom maxsize
@async_cache(maxsize=1000)
async def expensive_computation(param: str) -> dict:
# Expensive computation here
return {"result": param}
"""
thread_local = threading.local()
def _clear():
if hasattr(thread_local, "cache"):
del thread_local.cache
if inspect.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = _make_hashable_key(args, kwargs)
if key not in cache:
cache[key] = await func(*args, **kwargs)
return cache[key]
setattr(async_wrapper, "clear_cache", _clear)
return async_wrapper
if func is None:
# Called with parentheses @async_cache() or @async_cache(maxsize=...)
return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
else:
@wraps(func)
def sync_wrapper(*args, **kwargs):
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = _make_hashable_key(args, kwargs)
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]
setattr(sync_wrapper, "clear_cache", _clear)
return sync_wrapper
def clear_thread_cache(func: Callable) -> None:
"""Clear thread-local cache for a function."""
if clear := getattr(func, "clear_cache", None):
clear()
# Called without parentheses @async_cache
decorator = async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
return decorator(func)

View File

@@ -16,7 +16,12 @@ from unittest.mock import Mock
import pytest
from autogpt_libs.utils.cache import cached, clear_thread_cache, thread_cached
from autogpt_libs.utils.cache import (
async_cache,
async_ttl_cache,
clear_thread_cache,
thread_cached,
)
class TestThreadCached:
@@ -325,202 +330,102 @@ class TestThreadCached:
assert mock.call_count == 2
class TestCache:
"""Tests for the unified @cache decorator (works for both sync and async)."""
def test_basic_sync_caching(self):
"""Test basic sync caching functionality."""
call_count = 0
@cached()
def expensive_sync_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
return x + y
# First call
result1 = expensive_sync_function(1, 2)
assert result1 == 3
assert call_count == 1
# Second call with same args - should use cache
result2 = expensive_sync_function(1, 2)
assert result2 == 3
assert call_count == 1
# Different args - should call function again
result3 = expensive_sync_function(2, 3)
assert result3 == 5
assert call_count == 2
class TestAsyncTTLCache:
"""Tests for the @async_ttl_cache decorator."""
@pytest.mark.asyncio
async def test_basic_async_caching(self):
"""Test basic async caching functionality."""
async def test_basic_caching(self):
"""Test basic caching functionality."""
call_count = 0
@cached()
async def expensive_async_function(x: int, y: int = 0) -> int:
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def cached_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01) # Simulate async work
return x + y
# First call
result1 = await expensive_async_function(1, 2)
result1 = await cached_function(1, 2)
assert result1 == 3
assert call_count == 1
# Second call with same args - should use cache
result2 = await expensive_async_function(1, 2)
result2 = await cached_function(1, 2)
assert result2 == 3
assert call_count == 1
assert call_count == 1 # No additional call
# Different args - should call function again
result3 = await expensive_async_function(2, 3)
result3 = await cached_function(2, 3)
assert result3 == 5
assert call_count == 2
def test_sync_thundering_herd_protection(self):
"""Test that concurrent sync calls don't cause thundering herd."""
call_count = 0
results = []
@cached()
def slow_function(x: int) -> int:
nonlocal call_count
call_count += 1
time.sleep(0.1) # Simulate expensive operation
return x * x
def worker():
result = slow_function(5)
results.append(result)
# Launch multiple concurrent threads
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(worker) for _ in range(5)]
for future in futures:
future.result()
# All results should be the same
assert all(result == 25 for result in results)
# Only one thread should have executed the expensive operation
assert call_count == 1
@pytest.mark.asyncio
async def test_async_thundering_herd_protection(self):
"""Test that concurrent async calls don't cause thundering herd."""
async def test_ttl_expiration(self):
"""Test that cache entries expire after TTL."""
call_count = 0
@cached()
async def slow_async_function(x: int) -> int:
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
async def short_lived_cache(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.1) # Simulate expensive operation
return x * x
# Launch concurrent coroutines
tasks = [slow_async_function(7) for _ in range(5)]
results = await asyncio.gather(*tasks)
# All results should be the same
assert all(result == 49 for result in results)
# Only one coroutine should have executed the expensive operation
assert call_count == 1
def test_ttl_functionality(self):
"""Test TTL functionality with sync function."""
call_count = 0
@cached(maxsize=10, ttl_seconds=1) # Short TTL
def ttl_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 3
return x * 2
# First call
result1 = ttl_function(3)
assert result1 == 9
result1 = await short_lived_cache(5)
assert result1 == 10
assert call_count == 1
# Second call immediately - should use cache
result2 = ttl_function(3)
assert result2 == 9
assert call_count == 1
# Wait for TTL to expire
time.sleep(1.1)
# Third call after expiration - should call function again
result3 = ttl_function(3)
assert result3 == 9
assert call_count == 2
@pytest.mark.asyncio
async def test_async_ttl_functionality(self):
"""Test TTL functionality with async function."""
call_count = 0
@cached(maxsize=10, ttl_seconds=1) # Short TTL
async def async_ttl_function(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
return x * 4
# First call
result1 = await async_ttl_function(3)
assert result1 == 12
assert call_count == 1
# Second call immediately - should use cache
result2 = await async_ttl_function(3)
assert result2 == 12
result2 = await short_lived_cache(5)
assert result2 == 10
assert call_count == 1
# Wait for TTL to expire
await asyncio.sleep(1.1)
# Third call after expiration - should call function again
result3 = await async_ttl_function(3)
assert result3 == 12
result3 = await short_lived_cache(5)
assert result3 == 10
assert call_count == 2
def test_cache_info(self):
@pytest.mark.asyncio
async def test_cache_info(self):
"""Test cache info functionality."""
@cached(maxsize=10, ttl_seconds=60)
def info_test_function(x: int) -> int:
@async_ttl_cache(maxsize=5, ttl_seconds=300)
async def info_test_function(x: int) -> int:
return x * 3
# Check initial cache info
info = info_test_function.cache_info()
assert info["size"] == 0
assert info["maxsize"] == 10
assert info["ttl_seconds"] == 60
assert info["maxsize"] == 5
assert info["ttl_seconds"] == 300
# Add an entry
info_test_function(1)
await info_test_function(1)
info = info_test_function.cache_info()
assert info["size"] == 1
def test_cache_clear(self):
@pytest.mark.asyncio
async def test_cache_clear(self):
"""Test cache clearing functionality."""
call_count = 0
@cached()
def clearable_function(x: int) -> int:
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def clearable_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 4
# First call
result1 = clearable_function(2)
result1 = await clearable_function(2)
assert result1 == 8
assert call_count == 1
# Second call - should use cache
result2 = clearable_function(2)
result2 = await clearable_function(2)
assert result2 == 8
assert call_count == 1
@@ -528,149 +433,273 @@ class TestCache:
clearable_function.cache_clear()
# Third call after clear - should call function again
result3 = clearable_function(2)
result3 = await clearable_function(2)
assert result3 == 8
assert call_count == 2
@pytest.mark.asyncio
async def test_async_cache_clear(self):
"""Test cache clearing functionality with async function."""
async def test_maxsize_cleanup(self):
"""Test that cache cleans up when maxsize is exceeded."""
call_count = 0
@cached()
async def async_clearable_function(x: int) -> int:
@async_ttl_cache(maxsize=3, ttl_seconds=60)
async def size_limited_function(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
return x * 5
return x**2
# First call
result1 = await async_clearable_function(2)
# Fill cache to maxsize
await size_limited_function(1) # call_count: 1
await size_limited_function(2) # call_count: 2
await size_limited_function(3) # call_count: 3
info = size_limited_function.cache_info()
assert info["size"] == 3
# Add one more entry - should trigger cleanup
await size_limited_function(4) # call_count: 4
# Cache size should be reduced (cleanup removes oldest entries)
info = size_limited_function.cache_info()
assert info["size"] is not None and info["size"] <= 3 # Should be cleaned up
@pytest.mark.asyncio
async def test_argument_variations(self):
"""Test caching with different argument patterns."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
nonlocal call_count
call_count += 1
return f"{a}-{b}-{c}"
# Different ways to call with same logical arguments
result1 = await arg_test_function(1, "test", c=200)
assert call_count == 1
# Same arguments, same order - should use cache
result2 = await arg_test_function(1, "test", c=200)
assert call_count == 1
assert result1 == result2
# Different arguments - should call function
result3 = await arg_test_function(2, "test", c=200)
assert call_count == 2
assert result1 != result3
@pytest.mark.asyncio
async def test_exception_handling(self):
"""Test that exceptions are not cached."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def exception_function(x: int) -> int:
nonlocal call_count
call_count += 1
if x < 0:
raise ValueError("Negative value not allowed")
return x * 2
# Successful call - should be cached
result1 = await exception_function(5)
assert result1 == 10
assert call_count == 1
# Second call - should use cache
result2 = await async_clearable_function(2)
# Same successful call - should use cache
result2 = await exception_function(5)
assert result2 == 10
assert call_count == 1
# Clear cache
async_clearable_function.cache_clear()
# Third call after clear - should call function again
result3 = await async_clearable_function(2)
assert result3 == 10
# Exception call - should not be cached
with pytest.raises(ValueError):
await exception_function(-1)
assert call_count == 2
# Same exception call - should call again (not cached)
with pytest.raises(ValueError):
await exception_function(-1)
assert call_count == 3
@pytest.mark.asyncio
async def test_async_function_returns_results_not_coroutines(self):
"""Test that cached async functions return actual results, not coroutines."""
async def test_concurrent_calls(self):
"""Test caching behavior with concurrent calls."""
call_count = 0
@cached()
async def async_result_function(x: int) -> str:
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def concurrent_function(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
return f"result_{x}"
await asyncio.sleep(0.05) # Simulate work
return x * x
# Launch concurrent calls with same arguments
tasks = [concurrent_function(3) for _ in range(5)]
results = await asyncio.gather(*tasks)
# All results should be the same
assert all(result == 9 for result in results)
# Note: Due to race conditions, call_count might be up to 5 for concurrent calls
# This tests that the cache doesn't break under concurrent access
assert 1 <= call_count <= 5
class TestAsyncCache:
"""Tests for the @async_cache decorator (no TTL)."""
@pytest.mark.asyncio
async def test_basic_caching_no_ttl(self):
"""Test basic caching functionality without TTL."""
call_count = 0
@async_cache(maxsize=10)
async def cached_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01) # Simulate async work
return x + y
# First call
result1 = await async_result_function(1)
assert result1 == "result_1"
assert isinstance(result1, str) # Should be string, not coroutine
result1 = await cached_function(1, 2)
assert result1 == 3
assert call_count == 1
# Second call - should return cached result (string), not coroutine
result2 = await async_result_function(1)
assert result2 == "result_1"
assert isinstance(result2, str) # Should be string, not coroutine
assert call_count == 1 # Function should not be called again
# Second call with same args - should use cache
result2 = await cached_function(1, 2)
assert result2 == 3
assert call_count == 1 # No additional call
# Verify results are identical
assert result1 is result2 # Should be same cached object
# Third call after some time - should still use cache (no TTL)
await asyncio.sleep(0.05)
result3 = await cached_function(1, 2)
assert result3 == 3
assert call_count == 1 # Still no additional call
def test_cache_delete(self):
"""Test selective cache deletion functionality."""
call_count = 0
@cached()
def deletable_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 6
# First call for x=1
result1 = deletable_function(1)
assert result1 == 6
assert call_count == 1
# First call for x=2
result2 = deletable_function(2)
assert result2 == 12
# Different args - should call function again
result4 = await cached_function(2, 3)
assert result4 == 5
assert call_count == 2
# Second calls - should use cache
assert deletable_function(1) == 6
assert deletable_function(2) == 12
assert call_count == 2
# Delete specific entry for x=1
was_deleted = deletable_function.cache_delete(1)
assert was_deleted is True
# Call with x=1 should execute function again
result3 = deletable_function(1)
assert result3 == 6
assert call_count == 3
# Call with x=2 should still use cache
assert deletable_function(2) == 12
assert call_count == 3
# Try to delete non-existent entry
was_deleted = deletable_function.cache_delete(99)
assert was_deleted is False
@pytest.mark.asyncio
async def test_async_cache_delete(self):
"""Test selective cache deletion functionality with async function."""
async def test_no_ttl_vs_ttl_behavior(self):
"""Test the difference between TTL and no-TTL caching."""
ttl_call_count = 0
no_ttl_call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
async def ttl_function(x: int) -> int:
nonlocal ttl_call_count
ttl_call_count += 1
return x * 2
@async_cache(maxsize=10) # No TTL
async def no_ttl_function(x: int) -> int:
nonlocal no_ttl_call_count
no_ttl_call_count += 1
return x * 2
# First calls
await ttl_function(5)
await no_ttl_function(5)
assert ttl_call_count == 1
assert no_ttl_call_count == 1
# Wait for TTL to expire
await asyncio.sleep(1.1)
# Second calls after TTL expiry
await ttl_function(5) # Should call function again (TTL expired)
await no_ttl_function(5) # Should use cache (no TTL)
assert ttl_call_count == 2 # TTL function called again
assert no_ttl_call_count == 1 # No-TTL function still cached
@pytest.mark.asyncio
async def test_async_cache_info(self):
"""Test cache info for no-TTL cache."""
@async_cache(maxsize=5)
async def info_test_function(x: int) -> int:
return x * 3
# Check initial cache info
info = info_test_function.cache_info()
assert info["size"] == 0
assert info["maxsize"] == 5
assert info["ttl_seconds"] is None # No TTL
# Add an entry
await info_test_function(1)
info = info_test_function.cache_info()
assert info["size"] == 1
class TestTTLOptional:
"""Tests for optional TTL functionality."""
@pytest.mark.asyncio
async def test_ttl_none_behavior(self):
"""Test that ttl_seconds=None works like no TTL."""
call_count = 0
@cached()
async def async_deletable_function(x: int) -> int:
@async_ttl_cache(maxsize=10, ttl_seconds=None)
async def no_ttl_via_none(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
return x * 7
return x**2
# First call for x=1
result1 = await async_deletable_function(1)
assert result1 == 7
# First call
result1 = await no_ttl_via_none(3)
assert result1 == 9
assert call_count == 1
# First call for x=2
result2 = await async_deletable_function(2)
assert result2 == 14
assert call_count == 2
# Wait (would expire if there was TTL)
await asyncio.sleep(0.1)
# Second calls - should use cache
assert await async_deletable_function(1) == 7
assert await async_deletable_function(2) == 14
assert call_count == 2
# Second call - should still use cache
result2 = await no_ttl_via_none(3)
assert result2 == 9
assert call_count == 1 # No additional call
# Delete specific entry for x=1
was_deleted = async_deletable_function.cache_delete(1)
assert was_deleted is True
# Check cache info
info = no_ttl_via_none.cache_info()
assert info["ttl_seconds"] is None
# Call with x=1 should execute function again
result3 = await async_deletable_function(1)
assert result3 == 7
assert call_count == 3
@pytest.mark.asyncio
async def test_cache_options_comparison(self):
"""Test different cache options work as expected."""
ttl_calls = 0
no_ttl_calls = 0
# Call with x=2 should still use cache
assert await async_deletable_function(2) == 14
assert call_count == 3
@async_ttl_cache(maxsize=10, ttl_seconds=1) # With TTL
async def ttl_function(x: int) -> int:
nonlocal ttl_calls
ttl_calls += 1
return x * 10
# Try to delete non-existent entry
was_deleted = async_deletable_function.cache_delete(99)
assert was_deleted is False
@async_cache(maxsize=10) # Process-level cache (no TTL)
async def process_function(x: int) -> int:
nonlocal no_ttl_calls
no_ttl_calls += 1
return x * 10
# Both should cache initially
await ttl_function(3)
await process_function(3)
assert ttl_calls == 1
assert no_ttl_calls == 1
# Immediate second calls - both should use cache
await ttl_function(3)
await process_function(3)
assert ttl_calls == 1
assert no_ttl_calls == 1
# Wait for TTL to expire
await asyncio.sleep(1.1)
# After TTL expiry
await ttl_function(3) # Should call function again
await process_function(3) # Should still use cache
assert ttl_calls == 2 # TTL cache expired, called again
assert no_ttl_calls == 1 # Process cache never expires

View File

@@ -54,7 +54,7 @@ version = "1.2.0"
description = "Backport of asyncio.Runner, a context manager that controls event loop life cycle."
optional = false
python-versions = "<3.11,>=3.8"
groups = ["dev"]
groups = ["main"]
markers = "python_version < \"3.11\""
files = [
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
@@ -85,87 +85,6 @@ files = [
{file = "certifi-2025.7.14.tar.gz", hash = "sha256:8ea99dbdfaaf2ba2f9bac77b9249ef62ec5218e7c2b2e903378ed5fccf765995"},
]
[[package]]
name = "cffi"
version = "1.17.1"
description = "Foreign Function Interface for Python calling C code."
optional = false
python-versions = ">=3.8"
groups = ["main"]
markers = "platform_python_implementation != \"PyPy\""
files = [
{file = "cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14"},
{file = "cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67"},
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382"},
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702"},
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3"},
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6"},
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17"},
{file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8"},
{file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e"},
{file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be"},
{file = "cffi-1.17.1-cp310-cp310-win32.whl", hash = "sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c"},
{file = "cffi-1.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15"},
{file = "cffi-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401"},
{file = "cffi-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf"},
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4"},
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41"},
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1"},
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6"},
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d"},
{file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6"},
{file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f"},
{file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b"},
{file = "cffi-1.17.1-cp311-cp311-win32.whl", hash = "sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655"},
{file = "cffi-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0"},
{file = "cffi-1.17.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4"},
{file = "cffi-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c"},
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36"},
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5"},
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff"},
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99"},
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93"},
{file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3"},
{file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8"},
{file = "cffi-1.17.1-cp312-cp312-win32.whl", hash = "sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65"},
{file = "cffi-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903"},
{file = "cffi-1.17.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e"},
{file = "cffi-1.17.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2"},
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3"},
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683"},
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5"},
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4"},
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd"},
{file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed"},
{file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9"},
{file = "cffi-1.17.1-cp313-cp313-win32.whl", hash = "sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d"},
{file = "cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a"},
{file = "cffi-1.17.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:636062ea65bd0195bc012fea9321aca499c0504409f413dc88af450b57ffd03b"},
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7eac2ef9b63c79431bc4b25f1cd649d7f061a28808cbc6c47b534bd789ef964"},
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e221cf152cff04059d011ee126477f0d9588303eb57e88923578ace7baad17f9"},
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:31000ec67d4221a71bd3f67df918b1f88f676f1c3b535a7eb473255fdc0b83fc"},
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f17be4345073b0a7b8ea599688f692ac3ef23ce28e5df79c04de519dbc4912c"},
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e2b1fac190ae3ebfe37b979cc1ce69c81f4e4fe5746bb401dca63a9062cdaf1"},
{file = "cffi-1.17.1-cp38-cp38-win32.whl", hash = "sha256:7596d6620d3fa590f677e9ee430df2958d2d6d6de2feeae5b20e82c00b76fbf8"},
{file = "cffi-1.17.1-cp38-cp38-win_amd64.whl", hash = "sha256:78122be759c3f8a014ce010908ae03364d00a1f81ab5c7f4a7a5120607ea56e1"},
{file = "cffi-1.17.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b2ab587605f4ba0bf81dc0cb08a41bd1c0a5906bd59243d56bad7668a6fc6c16"},
{file = "cffi-1.17.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:28b16024becceed8c6dfbc75629e27788d8a3f9030691a1dbf9821a128b22c36"},
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d599671f396c4723d016dbddb72fe8e0397082b0a77a4fab8028923bec050e8"},
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca74b8dbe6e8e8263c0ffd60277de77dcee6c837a3d0881d8c1ead7268c9e576"},
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7f5baafcc48261359e14bcd6d9bff6d4b28d9103847c9e136694cb0501aef87"},
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98e3969bcff97cae1b2def8ba499ea3d6f31ddfdb7635374834cf89a1a08ecf0"},
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdf5ce3acdfd1661132f2a9c19cac174758dc2352bfe37d98aa7512c6b7178b3"},
{file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9755e4345d1ec879e3849e62222a18c7174d65a6a92d5b346b1863912168b595"},
{file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f1e22e8c4419538cb197e4dd60acc919d7696e5ef98ee4da4e01d3f8cfa4cc5a"},
{file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c03e868a0b3bc35839ba98e74211ed2b05d2119be4e8a0f224fba9384f1fe02e"},
{file = "cffi-1.17.1-cp39-cp39-win32.whl", hash = "sha256:e31ae45bc2e29f6b2abd0de1cc3b9d5205aa847cafaecb8af1476a609a2f6eb7"},
{file = "cffi-1.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662"},
{file = "cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824"},
]
[package.dependencies]
pycparser = "*"
[[package]]
name = "charset-normalizer"
version = "3.4.2"
@@ -289,176 +208,12 @@ version = "0.4.6"
description = "Cross-platform colored terminal text."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
groups = ["main", "dev"]
groups = ["main"]
files = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
[[package]]
name = "coverage"
version = "7.10.5"
description = "Code coverage measurement for Python"
optional = false
python-versions = ">=3.9"
groups = ["dev"]
files = [
{file = "coverage-7.10.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c6a5c3414bfc7451b879141ce772c546985163cf553f08e0f135f0699a911801"},
{file = "coverage-7.10.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bc8e4d99ce82f1710cc3c125adc30fd1487d3cf6c2cd4994d78d68a47b16989a"},
{file = "coverage-7.10.5-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:02252dc1216e512a9311f596b3169fad54abcb13827a8d76d5630c798a50a754"},
{file = "coverage-7.10.5-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:73269df37883e02d460bee0cc16be90509faea1e3bd105d77360b512d5bb9c33"},
{file = "coverage-7.10.5-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1f8a81b0614642f91c9effd53eec284f965577591f51f547a1cbeb32035b4c2f"},
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6a29f8e0adb7f8c2b95fa2d4566a1d6e6722e0a637634c6563cb1ab844427dd9"},
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fcf6ab569436b4a647d4e91accba12509ad9f2554bc93d3aee23cc596e7f99c3"},
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:90dc3d6fb222b194a5de60af8d190bedeeddcbc7add317e4a3cd333ee6b7c879"},
{file = "coverage-7.10.5-cp310-cp310-win32.whl", hash = "sha256:414a568cd545f9dc75f0686a0049393de8098414b58ea071e03395505b73d7a8"},
{file = "coverage-7.10.5-cp310-cp310-win_amd64.whl", hash = "sha256:e551f9d03347196271935fd3c0c165f0e8c049220280c1120de0084d65e9c7ff"},
{file = "coverage-7.10.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c177e6ffe2ebc7c410785307758ee21258aa8e8092b44d09a2da767834f075f2"},
{file = "coverage-7.10.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:14d6071c51ad0f703d6440827eaa46386169b5fdced42631d5a5ac419616046f"},
{file = "coverage-7.10.5-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:61f78c7c3bc272a410c5ae3fde7792b4ffb4acc03d35a7df73ca8978826bb7ab"},
{file = "coverage-7.10.5-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f39071caa126f69d63f99b324fb08c7b1da2ec28cbb1fe7b5b1799926492f65c"},
{file = "coverage-7.10.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:343a023193f04d46edc46b2616cdbee68c94dd10208ecd3adc56fcc54ef2baa1"},
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:585ffe93ae5894d1ebdee69fc0b0d4b7c75d8007983692fb300ac98eed146f78"},
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b0ef4e66f006ed181df29b59921bd8fc7ed7cd6a9289295cd8b2824b49b570df"},
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:eb7b0bbf7cc1d0453b843eca7b5fa017874735bef9bfdfa4121373d2cc885ed6"},
{file = "coverage-7.10.5-cp311-cp311-win32.whl", hash = "sha256:1d043a8a06987cc0c98516e57c4d3fc2c1591364831e9deb59c9e1b4937e8caf"},
{file = "coverage-7.10.5-cp311-cp311-win_amd64.whl", hash = "sha256:fefafcca09c3ac56372ef64a40f5fe17c5592fab906e0fdffd09543f3012ba50"},
{file = "coverage-7.10.5-cp311-cp311-win_arm64.whl", hash = "sha256:7e78b767da8b5fc5b2faa69bb001edafcd6f3995b42a331c53ef9572c55ceb82"},
{file = "coverage-7.10.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c2d05c7e73c60a4cecc7d9b60dbfd603b4ebc0adafaef371445b47d0f805c8a9"},
{file = "coverage-7.10.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:32ddaa3b2c509778ed5373b177eb2bf5662405493baeff52278a0b4f9415188b"},
{file = "coverage-7.10.5-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:dd382410039fe062097aa0292ab6335a3f1e7af7bba2ef8d27dcda484918f20c"},
{file = "coverage-7.10.5-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7fa22800f3908df31cea6fb230f20ac49e343515d968cc3a42b30d5c3ebf9b5a"},
{file = "coverage-7.10.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f366a57ac81f5e12797136552f5b7502fa053c861a009b91b80ed51f2ce651c6"},
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5f1dc8f1980a272ad4a6c84cba7981792344dad33bf5869361576b7aef42733a"},
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:2285c04ee8676f7938b02b4936d9b9b672064daab3187c20f73a55f3d70e6b4a"},
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c2492e4dd9daab63f5f56286f8a04c51323d237631eb98505d87e4c4ff19ec34"},
{file = "coverage-7.10.5-cp312-cp312-win32.whl", hash = "sha256:38a9109c4ee8135d5df5505384fc2f20287a47ccbe0b3f04c53c9a1989c2bbaf"},
{file = "coverage-7.10.5-cp312-cp312-win_amd64.whl", hash = "sha256:6b87f1ad60b30bc3c43c66afa7db6b22a3109902e28c5094957626a0143a001f"},
{file = "coverage-7.10.5-cp312-cp312-win_arm64.whl", hash = "sha256:672a6c1da5aea6c629819a0e1461e89d244f78d7b60c424ecf4f1f2556c041d8"},
{file = "coverage-7.10.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ef3b83594d933020f54cf65ea1f4405d1f4e41a009c46df629dd964fcb6e907c"},
{file = "coverage-7.10.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2b96bfdf7c0ea9faebce088a3ecb2382819da4fbc05c7b80040dbc428df6af44"},
{file = "coverage-7.10.5-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:63df1fdaffa42d914d5c4d293e838937638bf75c794cf20bee12978fc8c4e3bc"},
{file = "coverage-7.10.5-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8002dc6a049aac0e81ecec97abfb08c01ef0c1fbf962d0c98da3950ace89b869"},
{file = "coverage-7.10.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:63d4bb2966d6f5f705a6b0c6784c8969c468dbc4bcf9d9ded8bff1c7e092451f"},
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1f672efc0731a6846b157389b6e6d5d5e9e59d1d1a23a5c66a99fd58339914d5"},
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:3f39cef43d08049e8afc1fde4a5da8510fc6be843f8dea350ee46e2a26b2f54c"},
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2968647e3ed5a6c019a419264386b013979ff1fb67dd11f5c9886c43d6a31fc2"},
{file = "coverage-7.10.5-cp313-cp313-win32.whl", hash = "sha256:0d511dda38595b2b6934c2b730a1fd57a3635c6aa2a04cb74714cdfdd53846f4"},
{file = "coverage-7.10.5-cp313-cp313-win_amd64.whl", hash = "sha256:9a86281794a393513cf117177fd39c796b3f8e3759bb2764259a2abba5cce54b"},
{file = "coverage-7.10.5-cp313-cp313-win_arm64.whl", hash = "sha256:cebd8e906eb98bb09c10d1feed16096700b1198d482267f8bf0474e63a7b8d84"},
{file = "coverage-7.10.5-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0520dff502da5e09d0d20781df74d8189ab334a1e40d5bafe2efaa4158e2d9e7"},
{file = "coverage-7.10.5-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:d9cd64aca68f503ed3f1f18c7c9174cbb797baba02ca8ab5112f9d1c0328cd4b"},
{file = "coverage-7.10.5-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0913dd1613a33b13c4f84aa6e3f4198c1a21ee28ccb4f674985c1f22109f0aae"},
{file = "coverage-7.10.5-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1b7181c0feeb06ed8a02da02792f42f829a7b29990fef52eff257fef0885d760"},
{file = "coverage-7.10.5-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36d42b7396b605f774d4372dd9c49bed71cbabce4ae1ccd074d155709dd8f235"},
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b4fdc777e05c4940b297bf47bf7eedd56a39a61dc23ba798e4b830d585486ca5"},
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:42144e8e346de44a6f1dbd0a56575dd8ab8dfa7e9007da02ea5b1c30ab33a7db"},
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:66c644cbd7aed8fe266d5917e2c9f65458a51cfe5eeff9c05f15b335f697066e"},
{file = "coverage-7.10.5-cp313-cp313t-win32.whl", hash = "sha256:2d1b73023854068c44b0c554578a4e1ef1b050ed07cf8b431549e624a29a66ee"},
{file = "coverage-7.10.5-cp313-cp313t-win_amd64.whl", hash = "sha256:54a1532c8a642d8cc0bd5a9a51f5a9dcc440294fd06e9dda55e743c5ec1a8f14"},
{file = "coverage-7.10.5-cp313-cp313t-win_arm64.whl", hash = "sha256:74d5b63fe3f5f5d372253a4ef92492c11a4305f3550631beaa432fc9df16fcff"},
{file = "coverage-7.10.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:68c5e0bc5f44f68053369fa0d94459c84548a77660a5f2561c5e5f1e3bed7031"},
{file = "coverage-7.10.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:cf33134ffae93865e32e1e37df043bef15a5e857d8caebc0099d225c579b0fa3"},
{file = "coverage-7.10.5-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:ad8fa9d5193bafcf668231294241302b5e683a0518bf1e33a9a0dfb142ec3031"},
{file = "coverage-7.10.5-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:146fa1531973d38ab4b689bc764592fe6c2f913e7e80a39e7eeafd11f0ef6db2"},
{file = "coverage-7.10.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6013a37b8a4854c478d3219ee8bc2392dea51602dd0803a12d6f6182a0061762"},
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:eb90fe20db9c3d930fa2ad7a308207ab5b86bf6a76f54ab6a40be4012d88fcae"},
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:384b34482272e960c438703cafe63316dfbea124ac62006a455c8410bf2a2262"},
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:467dc74bd0a1a7de2bedf8deaf6811f43602cb532bd34d81ffd6038d6d8abe99"},
{file = "coverage-7.10.5-cp314-cp314-win32.whl", hash = "sha256:556d23d4e6393ca898b2e63a5bca91e9ac2d5fb13299ec286cd69a09a7187fde"},
{file = "coverage-7.10.5-cp314-cp314-win_amd64.whl", hash = "sha256:f4446a9547681533c8fa3e3c6cf62121eeee616e6a92bd9201c6edd91beffe13"},
{file = "coverage-7.10.5-cp314-cp314-win_arm64.whl", hash = "sha256:5e78bd9cf65da4c303bf663de0d73bf69f81e878bf72a94e9af67137c69b9fe9"},
{file = "coverage-7.10.5-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:5661bf987d91ec756a47c7e5df4fbcb949f39e32f9334ccd3f43233bbb65e508"},
{file = "coverage-7.10.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a46473129244db42a720439a26984f8c6f834762fc4573616c1f37f13994b357"},
{file = "coverage-7.10.5-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:1f64b8d3415d60f24b058b58d859e9512624bdfa57a2d1f8aff93c1ec45c429b"},
{file = "coverage-7.10.5-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:44d43de99a9d90b20e0163f9770542357f58860a26e24dc1d924643bd6aa7cb4"},
{file = "coverage-7.10.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a931a87e5ddb6b6404e65443b742cb1c14959622777f2a4efd81fba84f5d91ba"},
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:f9559b906a100029274448f4c8b8b0a127daa4dade5661dfd821b8c188058842"},
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:b08801e25e3b4526ef9ced1aa29344131a8f5213c60c03c18fe4c6170ffa2874"},
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ed9749bb8eda35f8b636fb7632f1c62f735a236a5d4edadd8bbcc5ea0542e732"},
{file = "coverage-7.10.5-cp314-cp314t-win32.whl", hash = "sha256:609b60d123fc2cc63ccee6d17e4676699075db72d14ac3c107cc4976d516f2df"},
{file = "coverage-7.10.5-cp314-cp314t-win_amd64.whl", hash = "sha256:0666cf3d2c1626b5a3463fd5b05f5e21f99e6aec40a3192eee4d07a15970b07f"},
{file = "coverage-7.10.5-cp314-cp314t-win_arm64.whl", hash = "sha256:bc85eb2d35e760120540afddd3044a5bf69118a91a296a8b3940dfc4fdcfe1e2"},
{file = "coverage-7.10.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:62835c1b00c4a4ace24c1a88561a5a59b612fbb83a525d1c70ff5720c97c0610"},
{file = "coverage-7.10.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5255b3bbcc1d32a4069d6403820ac8e6dbcc1d68cb28a60a1ebf17e47028e898"},
{file = "coverage-7.10.5-cp39-cp39-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3876385722e335d6e991c430302c24251ef9c2a9701b2b390f5473199b1b8ebf"},
{file = "coverage-7.10.5-cp39-cp39-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8048ce4b149c93447a55d279078c8ae98b08a6951a3c4d2d7e87f4efc7bfe100"},
{file = "coverage-7.10.5-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4028e7558e268dd8bcf4d9484aad393cafa654c24b4885f6f9474bf53183a82a"},
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:03f47dc870eec0367fcdd603ca6a01517d2504e83dc18dbfafae37faec66129a"},
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:2d488d7d42b6ded7ea0704884f89dcabd2619505457de8fc9a6011c62106f6e5"},
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b3dcf2ead47fa8be14224ee817dfc1df98043af568fe120a22f81c0eb3c34ad2"},
{file = "coverage-7.10.5-cp39-cp39-win32.whl", hash = "sha256:02650a11324b80057b8c9c29487020073d5e98a498f1857f37e3f9b6ea1b2426"},
{file = "coverage-7.10.5-cp39-cp39-win_amd64.whl", hash = "sha256:b45264dd450a10f9e03237b41a9a24e85cbb1e278e5a32adb1a303f58f0017f3"},
{file = "coverage-7.10.5-py3-none-any.whl", hash = "sha256:0be24d35e4db1d23d0db5c0f6a74a962e2ec83c426b5cac09f4234aadef38e4a"},
{file = "coverage-7.10.5.tar.gz", hash = "sha256:f2e57716a78bc3ae80b2207be0709a3b2b63b9f2dcf9740ee6ac03588a2015b6"},
]
[package.dependencies]
tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
[package.extras]
toml = ["tomli ; python_full_version <= \"3.11.0a6\""]
[[package]]
name = "cryptography"
version = "45.0.6"
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
optional = false
python-versions = "!=3.9.0,!=3.9.1,>=3.7"
groups = ["main"]
files = [
{file = "cryptography-45.0.6-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:048e7ad9e08cf4c0ab07ff7f36cc3115924e22e2266e034450a890d9e312dd74"},
{file = "cryptography-45.0.6-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:44647c5d796f5fc042bbc6d61307d04bf29bccb74d188f18051b635f20a9c75f"},
{file = "cryptography-45.0.6-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e40b80ecf35ec265c452eea0ba94c9587ca763e739b8e559c128d23bff7ebbbf"},
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:00e8724bdad672d75e6f069b27970883179bd472cd24a63f6e620ca7e41cc0c5"},
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7a3085d1b319d35296176af31c90338eeb2ddac8104661df79f80e1d9787b8b2"},
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1b7fa6a1c1188c7ee32e47590d16a5a0646270921f8020efc9a511648e1b2e08"},
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:275ba5cc0d9e320cd70f8e7b96d9e59903c815ca579ab96c1e37278d231fc402"},
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:f4028f29a9f38a2025abedb2e409973709c660d44319c61762202206ed577c42"},
{file = "cryptography-45.0.6-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ee411a1b977f40bd075392c80c10b58025ee5c6b47a822a33c1198598a7a5f05"},
{file = "cryptography-45.0.6-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e2a21a8eda2d86bb604934b6b37691585bd095c1f788530c1fcefc53a82b3453"},
{file = "cryptography-45.0.6-cp311-abi3-win32.whl", hash = "sha256:d063341378d7ee9c91f9d23b431a3502fc8bfacd54ef0a27baa72a0843b29159"},
{file = "cryptography-45.0.6-cp311-abi3-win_amd64.whl", hash = "sha256:833dc32dfc1e39b7376a87b9a6a4288a10aae234631268486558920029b086ec"},
{file = "cryptography-45.0.6-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:3436128a60a5e5490603ab2adbabc8763613f638513ffa7d311c900a8349a2a0"},
{file = "cryptography-45.0.6-cp37-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0d9ef57b6768d9fa58e92f4947cea96ade1233c0e236db22ba44748ffedca394"},
{file = "cryptography-45.0.6-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea3c42f2016a5bbf71825537c2ad753f2870191134933196bee408aac397b3d9"},
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:20ae4906a13716139d6d762ceb3e0e7e110f7955f3bc3876e3a07f5daadec5f3"},
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:2dac5ec199038b8e131365e2324c03d20e97fe214af051d20c49db129844e8b3"},
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:18f878a34b90d688982e43f4b700408b478102dd58b3e39de21b5ebf6509c301"},
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:5bd6020c80c5b2b2242d6c48487d7b85700f5e0038e67b29d706f98440d66eb5"},
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:eccddbd986e43014263eda489abbddfbc287af5cddfd690477993dbb31e31016"},
{file = "cryptography-45.0.6-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:550ae02148206beb722cfe4ef0933f9352bab26b087af00e48fdfb9ade35c5b3"},
{file = "cryptography-45.0.6-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5b64e668fc3528e77efa51ca70fadcd6610e8ab231e3e06ae2bab3b31c2b8ed9"},
{file = "cryptography-45.0.6-cp37-abi3-win32.whl", hash = "sha256:780c40fb751c7d2b0c6786ceee6b6f871e86e8718a8ff4bc35073ac353c7cd02"},
{file = "cryptography-45.0.6-cp37-abi3-win_amd64.whl", hash = "sha256:20d15aed3ee522faac1a39fbfdfee25d17b1284bafd808e1640a74846d7c4d1b"},
{file = "cryptography-45.0.6-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:705bb7c7ecc3d79a50f236adda12ca331c8e7ecfbea51edd931ce5a7a7c4f012"},
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:826b46dae41a1155a0c0e66fafba43d0ede1dc16570b95e40c4d83bfcf0a451d"},
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:cc4d66f5dc4dc37b89cfef1bd5044387f7a1f6f0abb490815628501909332d5d"},
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:f68f833a9d445cc49f01097d95c83a850795921b3f7cc6488731e69bde3288da"},
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:3b5bf5267e98661b9b888a9250d05b063220dfa917a8203744454573c7eb79db"},
{file = "cryptography-45.0.6-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2384f2ab18d9be88a6e4f8972923405e2dbb8d3e16c6b43f15ca491d7831bd18"},
{file = "cryptography-45.0.6-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fc022c1fa5acff6def2fc6d7819bbbd31ccddfe67d075331a65d9cfb28a20983"},
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:3de77e4df42ac8d4e4d6cdb342d989803ad37707cf8f3fbf7b088c9cbdd46427"},
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:599c8d7df950aa68baa7e98f7b73f4f414c9f02d0e8104a30c0182a07732638b"},
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:31a2b9a10530a1cb04ffd6aa1cd4d3be9ed49f7d77a4dafe198f3b382f41545c"},
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:e5b3dda1b00fb41da3af4c5ef3f922a200e33ee5ba0f0bc9ecf0b0c173958385"},
{file = "cryptography-45.0.6-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:629127cfdcdc6806dfe234734d7cb8ac54edaf572148274fa377a7d3405b0043"},
{file = "cryptography-45.0.6.tar.gz", hash = "sha256:5c966c732cf6e4a276ce83b6e4c729edda2df6929083a952cc7da973c539c719"},
]
[package.dependencies]
cffi = {version = ">=1.14", markers = "platform_python_implementation != \"PyPy\""}
[package.extras]
docs = ["sphinx (>=5.3.0)", "sphinx-inline-tabs ; python_full_version >= \"3.8.0\"", "sphinx-rtd-theme (>=3.0.0) ; python_full_version >= \"3.8.0\""]
docstest = ["pyenchant (>=3)", "readme-renderer (>=30.0)", "sphinxcontrib-spelling (>=7.3.1)"]
nox = ["nox (>=2024.4.15)", "nox[uv] (>=2024.3.2) ; python_full_version >= \"3.8.0\""]
pep8test = ["check-sdist ; python_full_version >= \"3.8.0\"", "click (>=8.0.1)", "mypy (>=1.4)", "ruff (>=0.3.6)"]
sdist = ["build (>=1.0.0)"]
ssh = ["bcrypt (>=3.1.5)"]
test = ["certifi (>=2024)", "cryptography-vectors (==45.0.6)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
test-randomorder = ["pytest-randomly"]
[[package]]
name = "deprecation"
version = "2.1.0"
@@ -480,7 +235,7 @@ version = "1.3.0"
description = "Backport of PEP 654 (exception groups)"
optional = false
python-versions = ">=3.7"
groups = ["main", "dev"]
groups = ["main"]
markers = "python_version < \"3.11\""
files = [
{file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"},
@@ -955,7 +710,7 @@ version = "2.1.0"
description = "brain-dead simple config-ini parsing"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
groups = ["main"]
files = [
{file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"},
{file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"},
@@ -1002,18 +757,6 @@ dynamodb = ["boto3 (>=1.9.71)"]
redis = ["redis (>=2.10.5)"]
test-filesource = ["pyyaml (>=5.3.1)", "watchdog (>=3.0.0)"]
[[package]]
name = "nodeenv"
version = "1.9.1"
description = "Node.js virtual environment builder"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
groups = ["dev"]
files = [
{file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"},
{file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"},
]
[[package]]
name = "opentelemetry-api"
version = "1.35.0"
@@ -1036,7 +779,7 @@ version = "25.0"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.8"
groups = ["main", "dev"]
groups = ["main"]
files = [
{file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"},
{file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"},
@@ -1048,7 +791,7 @@ version = "1.6.0"
description = "plugin and hook calling mechanisms for python"
optional = false
python-versions = ">=3.9"
groups = ["dev"]
groups = ["main"]
files = [
{file = "pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746"},
{file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"},
@@ -1140,19 +883,6 @@ files = [
[package.dependencies]
pyasn1 = ">=0.6.1,<0.7.0"
[[package]]
name = "pycparser"
version = "2.22"
description = "C parser in Python"
optional = false
python-versions = ">=3.8"
groups = ["main"]
markers = "platform_python_implementation != \"PyPy\""
files = [
{file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"},
{file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"},
]
[[package]]
name = "pydantic"
version = "2.11.7"
@@ -1317,7 +1047,7 @@ version = "2.19.2"
description = "Pygments is a syntax highlighting package written in Python."
optional = false
python-versions = ">=3.8"
groups = ["dev"]
groups = ["main"]
files = [
{file = "pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b"},
{file = "pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887"},
@@ -1338,9 +1068,6 @@ files = [
{file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"},
]
[package.dependencies]
cryptography = {version = ">=3.4.0", optional = true, markers = "extra == \"crypto\""}
[package.extras]
crypto = ["cryptography (>=3.4.0)"]
dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx", "sphinx-rtd-theme", "zope.interface"]
@@ -1359,34 +1086,13 @@ files = [
{file = "pyrfc3339-2.0.1.tar.gz", hash = "sha256:e47843379ea35c1296c3b6c67a948a1a490ae0584edfcbdea0eaffb5dd29960b"},
]
[[package]]
name = "pyright"
version = "1.1.404"
description = "Command line wrapper for pyright"
optional = false
python-versions = ">=3.7"
groups = ["dev"]
files = [
{file = "pyright-1.1.404-py3-none-any.whl", hash = "sha256:c7b7ff1fdb7219c643079e4c3e7d4125f0dafcc19d253b47e898d130ea426419"},
{file = "pyright-1.1.404.tar.gz", hash = "sha256:455e881a558ca6be9ecca0b30ce08aa78343ecc031d37a198ffa9a7a1abeb63e"},
]
[package.dependencies]
nodeenv = ">=1.6.0"
typing-extensions = ">=4.1"
[package.extras]
all = ["nodejs-wheel-binaries", "twine (>=3.4.1)"]
dev = ["twine (>=3.4.1)"]
nodejs = ["nodejs-wheel-binaries"]
[[package]]
name = "pytest"
version = "8.4.1"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.9"
groups = ["dev"]
groups = ["main"]
files = [
{file = "pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7"},
{file = "pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c"},
@@ -1410,7 +1116,7 @@ version = "1.1.0"
description = "Pytest support for asyncio"
optional = false
python-versions = ">=3.9"
groups = ["dev"]
groups = ["main"]
files = [
{file = "pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf"},
{file = "pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea"},
@@ -1424,33 +1130,13 @@ pytest = ">=8.2,<9"
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"]
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
[[package]]
name = "pytest-cov"
version = "6.2.1"
description = "Pytest plugin for measuring coverage."
optional = false
python-versions = ">=3.9"
groups = ["dev"]
files = [
{file = "pytest_cov-6.2.1-py3-none-any.whl", hash = "sha256:f5bc4c23f42f1cdd23c70b1dab1bbaef4fc505ba950d53e0081d0730dd7e86d5"},
{file = "pytest_cov-6.2.1.tar.gz", hash = "sha256:25cc6cc0a5358204b8108ecedc51a9b57b34cc6b8c967cc2c01a4e00d8a67da2"},
]
[package.dependencies]
coverage = {version = ">=7.5", extras = ["toml"]}
pluggy = ">=1.2"
pytest = ">=6.2.5"
[package.extras]
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
[[package]]
name = "pytest-mock"
version = "3.14.1"
description = "Thin-wrapper around the mock package for easier use with pytest"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
groups = ["main"]
files = [
{file = "pytest_mock-3.14.1-py3-none-any.whl", hash = "sha256:178aefcd11307d874b4cd3100344e7e2d888d9791a6a1d9bfe90fbc1b74fd1d0"},
{file = "pytest_mock-3.14.1.tar.gz", hash = "sha256:159e9edac4c451ce77a5cdb9fc5d1100708d2dd4ba3c3df572f14097351af80e"},
@@ -1567,31 +1253,31 @@ pyasn1 = ">=0.1.3"
[[package]]
name = "ruff"
version = "0.12.11"
version = "0.12.9"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
groups = ["dev"]
files = [
{file = "ruff-0.12.11-py3-none-linux_armv6l.whl", hash = "sha256:93fce71e1cac3a8bf9200e63a38ac5c078f3b6baebffb74ba5274fb2ab276065"},
{file = "ruff-0.12.11-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b8e33ac7b28c772440afa80cebb972ffd823621ded90404f29e5ab6d1e2d4b93"},
{file = "ruff-0.12.11-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d69fb9d4937aa19adb2e9f058bc4fbfe986c2040acb1a4a9747734834eaa0bfd"},
{file = "ruff-0.12.11-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:411954eca8464595077a93e580e2918d0a01a19317af0a72132283e28ae21bee"},
{file = "ruff-0.12.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6a2c0a2e1a450f387bf2c6237c727dd22191ae8c00e448e0672d624b2bbd7fb0"},
{file = "ruff-0.12.11-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ca4c3a7f937725fd2413c0e884b5248a19369ab9bdd850b5781348ba283f644"},
{file = "ruff-0.12.11-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4d1df0098124006f6a66ecf3581a7f7e754c4df7644b2e6704cd7ca80ff95211"},
{file = "ruff-0.12.11-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a8dd5f230efc99a24ace3b77e3555d3fbc0343aeed3fc84c8d89e75ab2ff793"},
{file = "ruff-0.12.11-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4dc75533039d0ed04cd33fb8ca9ac9620b99672fe7ff1533b6402206901c34ee"},
{file = "ruff-0.12.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4fc58f9266d62c6eccc75261a665f26b4ef64840887fc6cbc552ce5b29f96cc8"},
{file = "ruff-0.12.11-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:5a0113bd6eafd545146440225fe60b4e9489f59eb5f5f107acd715ba5f0b3d2f"},
{file = "ruff-0.12.11-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0d737b4059d66295c3ea5720e6efc152623bb83fde5444209b69cd33a53e2000"},
{file = "ruff-0.12.11-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:916fc5defee32dbc1fc1650b576a8fed68f5e8256e2180d4d9855aea43d6aab2"},
{file = "ruff-0.12.11-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c984f07d7adb42d3ded5be894fb4007f30f82c87559438b4879fe7aa08c62b39"},
{file = "ruff-0.12.11-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e07fbb89f2e9249f219d88331c833860489b49cdf4b032b8e4432e9b13e8a4b9"},
{file = "ruff-0.12.11-py3-none-win32.whl", hash = "sha256:c792e8f597c9c756e9bcd4d87cf407a00b60af77078c96f7b6366ea2ce9ba9d3"},
{file = "ruff-0.12.11-py3-none-win_amd64.whl", hash = "sha256:a3283325960307915b6deb3576b96919ee89432ebd9c48771ca12ee8afe4a0fd"},
{file = "ruff-0.12.11-py3-none-win_arm64.whl", hash = "sha256:bae4d6e6a2676f8fb0f98b74594a048bae1b944aab17e9f5d504062303c6dbea"},
{file = "ruff-0.12.11.tar.gz", hash = "sha256:c6b09ae8426a65bbee5425b9d0b82796dbb07cb1af045743c79bfb163001165d"},
{file = "ruff-0.12.9-py3-none-linux_armv6l.whl", hash = "sha256:fcebc6c79fcae3f220d05585229463621f5dbf24d79fdc4936d9302e177cfa3e"},
{file = "ruff-0.12.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:aed9d15f8c5755c0e74467731a007fcad41f19bcce41cd75f768bbd687f8535f"},
{file = "ruff-0.12.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5b15ea354c6ff0d7423814ba6d44be2807644d0c05e9ed60caca87e963e93f70"},
{file = "ruff-0.12.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d596c2d0393c2502eaabfef723bd74ca35348a8dac4267d18a94910087807c53"},
{file = "ruff-0.12.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1b15599931a1a7a03c388b9c5df1bfa62be7ede6eb7ef753b272381f39c3d0ff"},
{file = "ruff-0.12.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3d02faa2977fb6f3f32ddb7828e212b7dd499c59eb896ae6c03ea5c303575756"},
{file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:17d5b6b0b3a25259b69ebcba87908496e6830e03acfb929ef9fd4c58675fa2ea"},
{file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:72db7521860e246adbb43f6ef464dd2a532ef2ef1f5dd0d470455b8d9f1773e0"},
{file = "ruff-0.12.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a03242c1522b4e0885af63320ad754d53983c9599157ee33e77d748363c561ce"},
{file = "ruff-0.12.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fc83e4e9751e6c13b5046d7162f205d0a7bac5840183c5beebf824b08a27340"},
{file = "ruff-0.12.9-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:881465ed56ba4dd26a691954650de6ad389a2d1fdb130fe51ff18a25639fe4bb"},
{file = "ruff-0.12.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:43f07a3ccfc62cdb4d3a3348bf0588358a66da756aa113e071b8ca8c3b9826af"},
{file = "ruff-0.12.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:07adb221c54b6bba24387911e5734357f042e5669fa5718920ee728aba3cbadc"},
{file = "ruff-0.12.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f5cd34fabfdea3933ab85d72359f118035882a01bff15bd1d2b15261d85d5f66"},
{file = "ruff-0.12.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:f6be1d2ca0686c54564da8e7ee9e25f93bdd6868263805f8c0b8fc6a449db6d7"},
{file = "ruff-0.12.9-py3-none-win32.whl", hash = "sha256:cc7a37bd2509974379d0115cc5608a1a4a6c4bff1b452ea69db83c8855d53f93"},
{file = "ruff-0.12.9-py3-none-win_amd64.whl", hash = "sha256:6fb15b1977309741d7d098c8a3cb7a30bc112760a00fb6efb7abc85f00ba5908"},
{file = "ruff-0.12.9-py3-none-win_arm64.whl", hash = "sha256:63c8c819739d86b96d500cce885956a1a48ab056bbcbc61b747ad494b2485089"},
{file = "ruff-0.12.9.tar.gz", hash = "sha256:fbd94b2e3c623f659962934e52c2bea6fc6da11f667a427a368adaf3af2c866a"},
]
[[package]]
@@ -1725,7 +1411,7 @@ version = "2.2.1"
description = "A lil' TOML parser"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
groups = ["main"]
markers = "python_version < \"3.11\""
files = [
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
@@ -1768,7 +1454,7 @@ version = "4.14.1"
description = "Backported and Experimental Type Hints for Python 3.9+"
optional = false
python-versions = ">=3.9"
groups = ["main", "dev"]
groups = ["main"]
files = [
{file = "typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76"},
{file = "typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36"},
@@ -1929,4 +1615,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
content-hash = "0c40b63c3c921846cf05ccfb4e685d4959854b29c2c302245f9832e20aac6954"
content-hash = "4cc687aabe5865665fb8c4ccc0ea7e0af80b41e401ca37919f57efa6e0b5be00"

View File

@@ -9,25 +9,21 @@ packages = [{ include = "autogpt_libs" }]
[tool.poetry.dependencies]
python = ">=3.10,<4.0"
colorama = "^0.4.6"
cryptography = "^45.0"
expiringdict = "^1.2.2"
fastapi = "^0.116.1"
google-cloud-logging = "^3.12.1"
launchdarkly-server-sdk = "^9.12.0"
pydantic = "^2.11.7"
pydantic-settings = "^2.10.1"
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
pyjwt = "^2.10.1"
pytest-asyncio = "^1.1.0"
pytest-mock = "^3.14.1"
redis = "^6.2.0"
supabase = "^2.16.0"
uvicorn = "^0.35.0"
[tool.poetry.group.dev.dependencies]
pyright = "^1.1.404"
pytest = "^8.4.1"
pytest-asyncio = "^1.1.0"
pytest-mock = "^3.14.1"
pytest-cov = "^6.2.1"
ruff = "^0.12.11"
ruff = "^0.12.9"
[build-system]
requires = ["poetry-core"]

View File

@@ -16,12 +16,13 @@ DB_SCHEMA=platform
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
PRISMA_SCHEMA="postgres/schema.prisma"
ENABLE_AUTH=true
## ===== REQUIRED SERVICE CREDENTIALS ===== ##
# Redis Configuration
REDIS_HOST=localhost
REDIS_PORT=6379
# REDIS_PASSWORD=
REDIS_PASSWORD=password
# RabbitMQ Credentials
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
@@ -30,7 +31,7 @@ RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
# Supabase Authentication
SUPABASE_URL=http://localhost:8000
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
## ===== REQUIRED SECURITY KEYS ===== ##
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()
@@ -66,11 +67,6 @@ NVIDIA_API_KEY=
GITHUB_CLIENT_ID=
GITHUB_CLIENT_SECRET=
# Notion OAuth App server credentials - https://developers.notion.com/docs/authorization
# Configure a public integration
NOTION_CLIENT_ID=
NOTION_CLIENT_SECRET=
# Google OAuth App server credentials - https://console.cloud.google.com/apis/credentials, and enable gmail api and set scopes
# https://console.cloud.google.com/apis/credentials/consent ?project=<your_project_id>
# You'll need to add/enable the following scopes (minimum):
@@ -179,4 +175,4 @@ SMARTLEAD_API_KEY=
ZEROBOUNCE_API_KEY=
# Other Services
AUTOMOD_API_KEY=
AUTOMOD_API_KEY=

View File

@@ -9,12 +9,4 @@ secrets/*
!secrets/.gitkeep
*.ignore.*
*.ign.*
# Load test results and reports
load-tests/*_RESULTS.md
load-tests/*_REPORT.md
load-tests/results/
load-tests/*.json
load-tests/*.log
load-tests/node_modules/*
*.ign.*

View File

@@ -9,15 +9,8 @@ WORKDIR /app
RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy
# Install Node.js repository key and setup
# Update package list and install Python and build dependencies
RUN apt-get update --allow-releaseinfo-change --fix-missing \
&& apt-get install -y curl ca-certificates gnupg \
&& mkdir -p /etc/apt/keyrings \
&& curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg \
&& echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" | tee /etc/apt/sources.list.d/nodesource.list
# Update package list and install Python, Node.js, and build dependencies
RUN apt-get update \
&& apt-get install -y \
python3.13 \
python3.13-dev \
@@ -27,9 +20,7 @@ RUN apt-get update \
libpq5 \
libz-dev \
libssl-dev \
postgresql-client \
nodejs \
&& rm -rf /var/lib/apt/lists/*
postgresql-client
ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
@@ -63,18 +54,13 @@ ENV PATH=/opt/poetry/bin:$PATH
# Install Python without upgrading system-managed packages
RUN apt-get update && apt-get install -y \
python3.13 \
python3-pip \
&& rm -rf /var/lib/apt/lists/*
python3-pip
# Copy only necessary files from builder
COPY --from=builder /app /app
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
# Copy Node.js installation for Prisma
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
COPY --from=builder /usr/bin/npm /usr/bin/npm
COPY --from=builder /usr/bin/npx /usr/bin/npx
# Copy Prisma binaries
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"

View File

@@ -132,58 +132,17 @@ def test_endpoint_success(snapshot: Snapshot):
### Testing with Authentication
For the main API routes that use JWT authentication, auth is provided by the `autogpt_libs.auth` module. If the test actually uses the `user_id`, the recommended approach for testing is to mock the `get_jwt_payload` function, which underpins all higher-level auth functions used in the API (`requires_user`, `requires_admin_user`, `get_user_id`).
If the test doesn't need the `user_id` specifically, mocking is not necessary as during tests auth is disabled anyway (see `conftest.py`).
#### Using Global Auth Fixtures
Two global auth fixtures are provided by `backend/server/conftest.py`:
- `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
- `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")
These provide the easiest way to set up authentication mocking in test modules:
```python
import fastapi
import fastapi.testclient
import pytest
from backend.server.v2.myroute import router
def override_auth_middleware():
return {"sub": "test-user-id"}
app = fastapi.FastAPI()
app.include_router(router)
client = fastapi.testclient.TestClient(app)
def override_get_user_id():
return "test-user-id"
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module"""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user['get_jwt_payload']
yield
app.dependency_overrides.clear()
app.dependency_overrides[auth_middleware] = override_auth_middleware
app.dependency_overrides[get_user_id] = override_get_user_id
```
For admin-only endpoints, use `mock_jwt_admin` instead:
```python
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_admin):
"""Setup auth overrides for admin tests"""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin['get_jwt_payload']
yield
app.dependency_overrides.clear()
```
The IDs are also available separately as fixtures:
- `test_user_id`
- `admin_user_id`
- `target_user_id` (for admin <-> user operations)
### Mocking External Services
```python
@@ -194,10 +153,10 @@ def test_external_api_call(mocker, snapshot):
"backend.services.external_api.call",
return_value=mock_response
)
response = client.post("/api/process")
assert response.status_code == 200
snapshot.snapshot_dir = "snapshots"
snapshot.assert_match(
json.dumps(response.json(), indent=2, sort_keys=True),
@@ -228,17 +187,6 @@ def test_external_api_call(mocker, snapshot):
- Use `async def` with `@pytest.mark.asyncio` for testing async functions directly
### 5. Fixtures
#### Global Fixtures (conftest.py)
Authentication fixtures are available globally from `conftest.py`:
- `mock_jwt_user` - Standard user authentication
- `mock_jwt_admin` - Admin user authentication
- `configured_snapshot` - Pre-configured snapshot fixture
#### Custom Fixtures
Create reusable fixtures for common test data:
```python
@@ -254,18 +202,9 @@ def test_create_user(sample_user, snapshot):
# ... test implementation
```
#### Test Isolation
All tests must use fixtures that ensure proper isolation:
- Authentication overrides are automatically cleaned up after each test
- Database connections are properly managed with cleanup
- Mock objects are reset between tests
## CI/CD Integration
The GitHub Actions workflow automatically runs tests on:
- Pull requests
- Pushes to main branch
@@ -277,19 +216,16 @@ Snapshot tests work in CI by:
## Troubleshooting
### Snapshot Mismatches
- Review the diff carefully
- If changes are expected: `poetry run pytest --snapshot-update`
- If changes are unexpected: Fix the code causing the difference
### Async Test Issues
- Ensure async functions use `@pytest.mark.asyncio`
- Use `AsyncMock` for mocking async functions
- FastAPI TestClient handles async automatically
### Import Errors
- Check that all dependencies are in `pyproject.toml`
- Run `poetry install` to ensure dependencies are installed
- Verify import paths are correct
@@ -298,4 +234,4 @@ Snapshot tests work in CI by:
Snapshot testing provides a powerful way to ensure API responses remain consistent. Combined with traditional assertions, it creates a robust test suite that catches regressions while remaining maintainable.
Remember: Good tests are as important as good code!
Remember: Good tests are as important as good code!

View File

@@ -1,3 +1,4 @@
import functools
import importlib
import logging
import os
@@ -5,8 +6,6 @@ import re
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar
from autogpt_libs.utils.cache import cached
logger = logging.getLogger(__name__)
@@ -16,7 +15,7 @@ if TYPE_CHECKING:
T = TypeVar("T")
@cached()
@functools.cache
def load_all_blocks() -> dict[str, type["Block"]]:
from backend.data.block import Block
from backend.util.settings import Config

View File

@@ -1,6 +1,8 @@
import logging
from typing import Any, Optional
from pydantic import JsonValue
from backend.data.block import (
Block,
BlockCategory,
@@ -10,7 +12,7 @@ from backend.data.block import (
BlockType,
get_block,
)
from backend.data.execution import ExecutionStatus, NodesInputMasks
from backend.data.execution import ExecutionStatus
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.json import validate_with_jsonschema
from backend.util.retry import func_retry
@@ -31,7 +33,7 @@ class AgentExecutorBlock(Block):
input_schema: dict = SchemaField(description="Input schema for the graph")
output_schema: dict = SchemaField(description="Output schema for the graph")
nodes_input_masks: Optional[NodesInputMasks] = SchemaField(
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = SchemaField(
default=None, hidden=True
)

View File

@@ -1,154 +0,0 @@
from enum import Enum
from typing import Literal
from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import MediaFileType
class GeminiImageModel(str, Enum):
NANO_BANANA = "google/nano-banana"
class OutputFormat(str, Enum):
JPG = "jpg"
PNG = "png"
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="replicate",
api_key=SecretStr("mock-replicate-api-key"),
title="Mock Replicate API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
class AIImageCustomizerBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal[ProviderName.REPLICATE], Literal["api_key"]
] = CredentialsField(
description="Replicate API key with permissions for Google Gemini image models",
)
prompt: str = SchemaField(
description="A text description of the image you want to generate",
title="Prompt",
)
model: GeminiImageModel = SchemaField(
description="The AI model to use for image generation and editing",
default=GeminiImageModel.NANO_BANANA,
title="Model",
)
images: list[MediaFileType] = SchemaField(
description="Optional list of input images to reference or modify",
default=[],
title="Input Images",
)
output_format: OutputFormat = SchemaField(
description="Format of the output image",
default=OutputFormat.PNG,
title="Output Format",
)
class Output(BlockSchema):
image_url: MediaFileType = SchemaField(description="URL of the generated image")
error: str = SchemaField(description="Error message if generation failed")
def __init__(self):
super().__init__(
id="d76bbe4c-930e-4894-8469-b66775511f71",
description=(
"Generate and edit custom images using Google's Nano-Banana model from Gemini 2.5. "
"Provide a prompt and optional reference images to create or modify images."
),
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
input_schema=AIImageCustomizerBlock.Input,
output_schema=AIImageCustomizerBlock.Output,
test_input={
"prompt": "Make the scene more vibrant and colorful",
"model": GeminiImageModel.NANO_BANANA,
"images": [],
"output_format": OutputFormat.JPG,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
("image_url", "https://replicate.delivery/generated-image.jpg"),
],
test_mock={
"run_model": lambda *args, **kwargs: MediaFileType(
"https://replicate.delivery/generated-image.jpg"
),
},
test_credentials=TEST_CREDENTIALS,
)
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
try:
result = await self.run_model(
api_key=credentials.api_key,
model_name=input_data.model.value,
prompt=input_data.prompt,
images=input_data.images,
output_format=input_data.output_format.value,
)
yield "image_url", result
except Exception as e:
yield "error", str(e)
async def run_model(
self,
api_key: SecretStr,
model_name: str,
prompt: str,
images: list[MediaFileType],
output_format: str,
) -> MediaFileType:
client = ReplicateClient(api_token=api_key.get_secret_value())
input_params: dict = {
"prompt": prompt,
"output_format": output_format,
}
# Add images to input if provided (API expects "image_input" parameter)
if images:
input_params["image_input"] = [str(img) for img in images]
output: FileOutput | str = await client.async_run( # type: ignore
model_name,
input=input_params,
wait=False,
)
if isinstance(output, FileOutput):
return MediaFileType(output.url)
if isinstance(output, str):
return MediaFileType(output)
raise ValueError("No output received from the model")

View File

@@ -661,167 +661,6 @@ async def update_field(
#################################################################
async def get_table_schema(
credentials: Credentials,
base_id: str,
table_id_or_name: str,
) -> dict:
"""
Get the schema for a specific table, including all field definitions.
Args:
credentials: Airtable API credentials
base_id: The base ID
table_id_or_name: The table ID or name
Returns:
Dict containing table schema with fields information
"""
# First get all tables to find the right one
response = await Requests().get(
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables",
headers={"Authorization": credentials.auth_header()},
)
data = response.json()
tables = data.get("tables", [])
# Find the matching table
for table in tables:
if table.get("id") == table_id_or_name or table.get("name") == table_id_or_name:
return table
raise ValueError(f"Table '{table_id_or_name}' not found in base '{base_id}'")
def get_empty_value_for_field(field_type: str) -> Any:
"""
Return the appropriate empty value for a given Airtable field type.
Args:
field_type: The Airtable field type
Returns:
The appropriate empty value for that field type
"""
# Fields that should be false when empty
if field_type == "checkbox":
return False
# Fields that should be empty arrays
if field_type in [
"multipleSelects",
"multipleRecordLinks",
"multipleAttachments",
"multipleLookupValues",
"multipleCollaborators",
]:
return []
# Fields that should be 0 when empty (numeric types)
if field_type in [
"number",
"percent",
"currency",
"rating",
"duration",
"count",
"autoNumber",
]:
return 0
# Fields that should be empty strings
if field_type in [
"singleLineText",
"multilineText",
"email",
"url",
"phoneNumber",
"richText",
"barcode",
]:
return ""
# Everything else gets null (dates, single selects, formulas, etc.)
return None
async def normalize_records(
records: list[dict],
table_schema: dict,
include_field_metadata: bool = False,
) -> dict:
"""
Normalize Airtable records to include all fields with proper empty values.
Args:
records: List of record objects from Airtable API
table_schema: Table schema containing field definitions
include_field_metadata: Whether to include field metadata in response
Returns:
Dict with normalized records and optionally field metadata
"""
fields = table_schema.get("fields", [])
# Normalize each record
normalized_records = []
for record in records:
normalized = {
"id": record.get("id"),
"createdTime": record.get("createdTime"),
"fields": {},
}
# Add existing fields
existing_fields = record.get("fields", {})
# Add all fields from schema, using empty values for missing ones
for field in fields:
field_name = field["name"]
field_type = field["type"]
if field_name in existing_fields:
# Field exists, use its value
normalized["fields"][field_name] = existing_fields[field_name]
else:
# Field is missing, add appropriate empty value
normalized["fields"][field_name] = get_empty_value_for_field(field_type)
normalized_records.append(normalized)
# Build result dictionary
if include_field_metadata:
field_metadata = {}
for field in fields:
metadata = {"type": field["type"], "id": field["id"]}
# Add type-specific metadata
options = field.get("options", {})
if field["type"] == "currency" and "symbol" in options:
metadata["symbol"] = options["symbol"]
metadata["precision"] = options.get("precision", 2)
elif field["type"] == "duration" and "durationFormat" in options:
metadata["format"] = options["durationFormat"]
elif field["type"] == "percent" and "precision" in options:
metadata["precision"] = options["precision"]
elif (
field["type"] in ["singleSelect", "multipleSelects"]
and "choices" in options
):
metadata["choices"] = [choice["name"] for choice in options["choices"]]
elif field["type"] == "rating" and "max" in options:
metadata["max"] = options["max"]
metadata["icon"] = options.get("icon", "star")
metadata["color"] = options.get("color", "yellowBright")
field_metadata[field["name"]] = metadata
return {"records": normalized_records, "field_metadata": field_metadata}
else:
return {"records": normalized_records}
async def list_records(
credentials: Credentials,
base_id: str,
@@ -1410,26 +1249,3 @@ async def list_bases(
)
return response.json()
async def get_base_tables(
credentials: Credentials,
base_id: str,
) -> list[dict]:
"""
Get all tables for a specific base.
Args:
credentials: Airtable API credentials
base_id: The ID of the base
Returns:
list[dict]: List of table objects with their schemas
"""
response = await Requests().get(
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables",
headers={"Authorization": credentials.auth_header()},
)
data = response.json()
return data.get("tables", [])

View File

@@ -14,13 +14,13 @@ from backend.sdk import (
SchemaField,
)
from ._api import create_base, get_base_tables, list_bases
from ._api import create_base, list_bases
from ._config import airtable
class AirtableCreateBaseBlock(Block):
"""
Creates a new base in an Airtable workspace, or returns existing base if one with the same name exists.
Creates a new base in an Airtable workspace.
"""
class Input(BlockSchema):
@@ -31,10 +31,6 @@ class AirtableCreateBaseBlock(Block):
description="The workspace ID where the base will be created"
)
name: str = SchemaField(description="The name of the new base")
find_existing: bool = SchemaField(
description="If true, return existing base with same name instead of creating duplicate",
default=True,
)
tables: list[dict] = SchemaField(
description="At least one table and field must be specified. Array of table objects to create in the base. Each table should have 'name' and 'fields' properties",
default=[
@@ -54,18 +50,14 @@ class AirtableCreateBaseBlock(Block):
)
class Output(BlockSchema):
base_id: str = SchemaField(description="The ID of the created or found base")
base_id: str = SchemaField(description="The ID of the created base")
tables: list[dict] = SchemaField(description="Array of table objects")
table: dict = SchemaField(description="A single table object")
was_created: bool = SchemaField(
description="True if a new base was created, False if existing was found",
default=True,
)
def __init__(self):
super().__init__(
id="f59b88a8-54ce-4676-a508-fd614b4e8dce",
description="Create or find a base in Airtable",
description="Create a new base in Airtable",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
@@ -74,31 +66,6 @@ class AirtableCreateBaseBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# If find_existing is true, check if a base with this name already exists
if input_data.find_existing:
# List all bases to check for existing one with same name
# Note: Airtable API doesn't have a direct search, so we need to list and filter
existing_bases = await list_bases(credentials)
for base in existing_bases.get("bases", []):
if base.get("name") == input_data.name:
# Base already exists, return it
base_id = base.get("id")
yield "base_id", base_id
yield "was_created", False
# Get the tables for this base
try:
tables = await get_base_tables(credentials, base_id)
yield "tables", tables
for table in tables:
yield "table", table
except Exception:
# If we can't get tables, return empty list
yield "tables", []
return
# No existing base found or find_existing is false, create new one
data = await create_base(
credentials,
input_data.workspace_id,
@@ -107,7 +74,6 @@ class AirtableCreateBaseBlock(Block):
)
yield "base_id", data.get("id", None)
yield "was_created", True
yield "tables", data.get("tables", [])
for table in data.get("tables", []):
yield "table", table

View File

@@ -2,7 +2,7 @@
Airtable record operation blocks.
"""
from typing import Optional, cast
from typing import Optional
from backend.sdk import (
APIKeyCredentials,
@@ -18,9 +18,7 @@ from ._api import (
create_record,
delete_multiple_records,
get_record,
get_table_schema,
list_records,
normalize_records,
update_multiple_records,
)
from ._config import airtable
@@ -56,24 +54,12 @@ class AirtableListRecordsBlock(Block):
return_fields: list[str] = SchemaField(
description="Specific fields to return (comma-separated)", default=[]
)
normalize_output: bool = SchemaField(
description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
default=True,
)
include_field_metadata: bool = SchemaField(
description="Include field type and configuration metadata (requires normalize_output=true)",
default=False,
)
class Output(BlockSchema):
records: list[dict] = SchemaField(description="Array of record objects")
offset: Optional[str] = SchemaField(
description="Offset for next page (null if no more records)", default=None
)
field_metadata: Optional[dict] = SchemaField(
description="Field type and configuration metadata (only when include_field_metadata=true)",
default=None,
)
def __init__(self):
super().__init__(
@@ -87,7 +73,6 @@ class AirtableListRecordsBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
data = await list_records(
credentials,
input_data.base_id,
@@ -103,33 +88,8 @@ class AirtableListRecordsBlock(Block):
fields=input_data.return_fields if input_data.return_fields else None,
)
records = data.get("records", [])
# Normalize output if requested
if input_data.normalize_output:
# Fetch table schema
table_schema = await get_table_schema(
credentials, input_data.base_id, input_data.table_id_or_name
)
# Normalize the records
normalized_data = await normalize_records(
records,
table_schema,
include_field_metadata=input_data.include_field_metadata,
)
yield "records", normalized_data["records"]
yield "offset", data.get("offset", None)
if (
input_data.include_field_metadata
and "field_metadata" in normalized_data
):
yield "field_metadata", normalized_data["field_metadata"]
else:
yield "records", records
yield "offset", data.get("offset", None)
yield "records", data.get("records", [])
yield "offset", data.get("offset", None)
class AirtableGetRecordBlock(Block):
@@ -144,23 +104,11 @@ class AirtableGetRecordBlock(Block):
base_id: str = SchemaField(description="The Airtable base ID")
table_id_or_name: str = SchemaField(description="Table ID or name")
record_id: str = SchemaField(description="The record ID to retrieve")
normalize_output: bool = SchemaField(
description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
default=True,
)
include_field_metadata: bool = SchemaField(
description="Include field type and configuration metadata (requires normalize_output=true)",
default=False,
)
class Output(BlockSchema):
id: str = SchemaField(description="The record ID")
fields: dict = SchemaField(description="The record fields")
created_time: str = SchemaField(description="The record created time")
field_metadata: Optional[dict] = SchemaField(
description="Field type and configuration metadata (only when include_field_metadata=true)",
default=None,
)
def __init__(self):
super().__init__(
@@ -174,7 +122,6 @@ class AirtableGetRecordBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
record = await get_record(
credentials,
input_data.base_id,
@@ -182,34 +129,9 @@ class AirtableGetRecordBlock(Block):
input_data.record_id,
)
# Normalize output if requested
if input_data.normalize_output:
# Fetch table schema
table_schema = await get_table_schema(
credentials, input_data.base_id, input_data.table_id_or_name
)
# Normalize the single record (wrap in list and unwrap result)
normalized_data = await normalize_records(
[record],
table_schema,
include_field_metadata=input_data.include_field_metadata,
)
normalized_record = normalized_data["records"][0]
yield "id", normalized_record.get("id", None)
yield "fields", normalized_record.get("fields", None)
yield "created_time", normalized_record.get("createdTime", None)
if (
input_data.include_field_metadata
and "field_metadata" in normalized_data
):
yield "field_metadata", normalized_data["field_metadata"]
else:
yield "id", record.get("id", None)
yield "fields", record.get("fields", None)
yield "created_time", record.get("createdTime", None)
yield "id", record.get("id", None)
yield "fields", record.get("fields", None)
yield "created_time", record.get("createdTime", None)
class AirtableCreateRecordsBlock(Block):
@@ -226,10 +148,6 @@ class AirtableCreateRecordsBlock(Block):
records: list[dict] = SchemaField(
description="Array of records to create (each with 'fields' object)"
)
skip_normalization: bool = SchemaField(
description="Skip output normalization to get raw Airtable response (faster but may have missing fields)",
default=False,
)
typecast: bool = SchemaField(
description="Automatically convert string values to appropriate types",
default=False,
@@ -255,7 +173,7 @@ class AirtableCreateRecordsBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# The create_record API expects records in a specific format
data = await create_record(
credentials,
input_data.base_id,
@@ -264,22 +182,8 @@ class AirtableCreateRecordsBlock(Block):
typecast=input_data.typecast if input_data.typecast else None,
return_fields_by_field_id=input_data.return_fields_by_field_id,
)
result_records = cast(list[dict], data.get("records", []))
# Normalize output unless explicitly disabled
if not input_data.skip_normalization and result_records:
# Fetch table schema
table_schema = await get_table_schema(
credentials, input_data.base_id, input_data.table_id_or_name
)
# Normalize the records
normalized_data = await normalize_records(
result_records, table_schema, include_field_metadata=False
)
result_records = normalized_data["records"]
yield "records", result_records
yield "records", data.get("records", [])
details = data.get("details", None)
if details:
yield "details", details

View File

@@ -1,3 +0,0 @@
from .text_overlay import BannerbearTextOverlayBlock
__all__ = ["BannerbearTextOverlayBlock"]

View File

@@ -1,8 +0,0 @@
from backend.sdk import BlockCostType, ProviderBuilder
bannerbear = (
ProviderBuilder("bannerbear")
.with_api_key("BANNERBEAR_API_KEY", "Bannerbear API Key")
.with_base_cost(1, BlockCostType.RUN)
.build()
)

View File

@@ -1,239 +0,0 @@
import uuid
from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
pass
from pydantic import SecretStr
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
Requests,
SchemaField,
)
from ._config import bannerbear
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="bannerbear",
api_key=SecretStr("mock-bannerbear-api-key"),
title="Mock Bannerbear API Key",
)
class TextModification(BlockSchema):
name: str = SchemaField(
description="The name of the layer to modify in the template"
)
text: str = SchemaField(description="The text content to add to this layer")
color: str = SchemaField(
description="Hex color code for the text (e.g., '#FF0000')",
default="",
advanced=True,
)
font_family: str = SchemaField(
description="Font family to use for the text",
default="",
advanced=True,
)
font_size: int = SchemaField(
description="Font size in pixels",
default=0,
advanced=True,
)
font_weight: str = SchemaField(
description="Font weight (e.g., 'bold', 'normal')",
default="",
advanced=True,
)
text_align: str = SchemaField(
description="Text alignment (left, center, right)",
default="",
advanced=True,
)
class BannerbearTextOverlayBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = bannerbear.credentials_field(
description="API credentials for Bannerbear"
)
template_id: str = SchemaField(
description="The unique ID of your Bannerbear template"
)
project_id: str = SchemaField(
description="Optional: Project ID (required when using Master API Key)",
default="",
advanced=True,
)
text_modifications: List[TextModification] = SchemaField(
description="List of text layers to modify in the template"
)
image_url: str = SchemaField(
description="Optional: URL of an image to use in the template",
default="",
advanced=True,
)
image_layer_name: str = SchemaField(
description="Optional: Name of the image layer in the template",
default="photo",
advanced=True,
)
webhook_url: str = SchemaField(
description="Optional: URL to receive webhook notification when image is ready",
default="",
advanced=True,
)
metadata: str = SchemaField(
description="Optional: Custom metadata to attach to the image",
default="",
advanced=True,
)
class Output(BlockSchema):
success: bool = SchemaField(
description="Whether the image generation was successfully initiated"
)
image_url: str = SchemaField(
description="URL of the generated image (if synchronous) or placeholder"
)
uid: str = SchemaField(description="Unique identifier for the generated image")
status: str = SchemaField(description="Status of the image generation")
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="c7d3a5c2-05fc-450e-8dce-3b0e04626009",
description="Add text overlay to images using Bannerbear templates. Perfect for creating social media graphics, marketing materials, and dynamic image content.",
categories={BlockCategory.PRODUCTIVITY, BlockCategory.AI},
input_schema=self.Input,
output_schema=self.Output,
test_input={
"template_id": "jJWBKNELpQPvbX5R93Gk",
"text_modifications": [
{
"name": "headline",
"text": "Amazing Product Launch!",
"color": "#FF0000",
},
{
"name": "subtitle",
"text": "50% OFF Today Only",
},
],
"credentials": {
"provider": "bannerbear",
"id": str(uuid.uuid4()),
"type": "api_key",
},
},
test_output=[
("success", True),
("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
("uid", "test-uid-123"),
("status", "completed"),
],
test_mock={
"_make_api_request": lambda *args, **kwargs: {
"uid": "test-uid-123",
"status": "completed",
"image_url": "https://cdn.bannerbear.com/test-image.jpg",
}
},
test_credentials=TEST_CREDENTIALS,
)
async def _make_api_request(self, payload: dict, api_key: str) -> dict:
"""Make the actual API request to Bannerbear. This is separated for easy mocking in tests."""
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
response = await Requests().post(
"https://sync.api.bannerbear.com/v2/images",
headers=headers,
json=payload,
)
if response.status in [200, 201, 202]:
return response.json()
else:
error_msg = f"API request failed with status {response.status}"
if response.text:
try:
error_data = response.json()
error_msg = (
f"{error_msg}: {error_data.get('message', response.text)}"
)
except Exception:
error_msg = f"{error_msg}: {response.text}"
raise Exception(error_msg)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# Build the modifications array
modifications = []
# Add text modifications
for text_mod in input_data.text_modifications:
mod_data: Dict[str, Any] = {
"name": text_mod.name,
"text": text_mod.text,
}
# Add optional text styling parameters only if they have values
if text_mod.color and text_mod.color.strip():
mod_data["color"] = text_mod.color
if text_mod.font_family and text_mod.font_family.strip():
mod_data["font_family"] = text_mod.font_family
if text_mod.font_size and text_mod.font_size > 0:
mod_data["font_size"] = text_mod.font_size
if text_mod.font_weight and text_mod.font_weight.strip():
mod_data["font_weight"] = text_mod.font_weight
if text_mod.text_align and text_mod.text_align.strip():
mod_data["text_align"] = text_mod.text_align
modifications.append(mod_data)
# Add image modification if provided and not empty
if input_data.image_url and input_data.image_url.strip():
modifications.append(
{
"name": input_data.image_layer_name,
"image_url": input_data.image_url,
}
)
# Build the request payload - only include non-empty optional fields
payload = {
"template": input_data.template_id,
"modifications": modifications,
}
# Add project_id if provided (required for Master API keys)
if input_data.project_id and input_data.project_id.strip():
payload["project_id"] = input_data.project_id
if input_data.webhook_url and input_data.webhook_url.strip():
payload["webhook_url"] = input_data.webhook_url
if input_data.metadata and input_data.metadata.strip():
payload["metadata"] = input_data.metadata
# Make the API request using the private method
data = await self._make_api_request(
payload, credentials.api_key.get_secret_value()
)
# Synchronous request - image should be ready
yield "success", True
yield "image_url", data.get("image_url", "")
yield "uid", data.get("uid", "")
yield "status", data.get("status", "completed")

View File

@@ -113,7 +113,6 @@ class DataForSeoClient:
include_serp_info: bool = False,
include_clickstream_data: bool = False,
limit: int = 100,
depth: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
Get related keywords from DataForSEO Labs.
@@ -126,7 +125,6 @@ class DataForSeoClient:
include_serp_info: Include SERP data
include_clickstream_data: Include clickstream metrics
limit: Maximum number of results (up to 3000)
depth: Keyword search depth (0-4), controls number of returned keywords
Returns:
API response with related keywords
@@ -150,8 +148,6 @@ class DataForSeoClient:
task_data["include_clickstream_data"] = include_clickstream_data
if limit is not None:
task_data["limit"] = limit
if depth is not None:
task_data["depth"] = depth
payload = [task_data]

View File

@@ -78,12 +78,6 @@ class DataForSeoRelatedKeywordsBlock(Block):
ge=1,
le=3000,
)
depth: int = SchemaField(
description="Keyword search depth (0-4). Controls the number of returned keywords: 0=1 keyword, 1=~8 keywords, 2=~72 keywords, 3=~584 keywords, 4=~4680 keywords",
default=1,
ge=0,
le=4,
)
class Output(BlockSchema):
related_keywords: List[RelatedKeyword] = SchemaField(
@@ -160,7 +154,6 @@ class DataForSeoRelatedKeywordsBlock(Block):
include_serp_info=input_data.include_serp_info,
include_clickstream_data=input_data.include_clickstream_data,
limit=input_data.limit,
depth=input_data.depth,
)
async def run(

View File

@@ -93,11 +93,11 @@ class Webset(BaseModel):
"""
Set of key-value pairs you want to associate with this object.
"""
created_at: Annotated[datetime | None, Field(alias="createdAt")] = None
created_at: Annotated[datetime, Field(alias="createdAt")] | None = None
"""
The date and time the webset was created
"""
updated_at: Annotated[datetime | None, Field(alias="updatedAt")] = None
updated_at: Annotated[datetime, Field(alias="updatedAt")] | None = None
"""
The date and time the webset was last updated
"""

View File

@@ -39,6 +39,18 @@ def serialize_email_recipients(recipients: list[str]) -> str:
return ", ".join(recipients)
def deduplicate_email_addresses(addresses: list[str]) -> list[str]:
"""Deduplicate email addresses while preserving order.
Args:
addresses: List of email addresses that may contain duplicates or None values
Returns:
List of unique email addresses with None values filtered out
"""
return list(dict.fromkeys(filter(None, addresses)))
def _make_mime_text(
body: str,
content_type: Optional[Literal["auto", "plain", "html"]] = None,
@@ -1094,117 +1106,6 @@ class GmailGetThreadBlock(GmailBase):
return thread
async def _build_reply_message(
service, input_data, graph_exec_id: str, user_id: str
) -> tuple[str, str]:
"""
Builds a reply MIME message for Gmail threads.
Returns:
tuple: (base64-encoded raw message, threadId)
"""
# Get parent message for reply context
parent = await asyncio.to_thread(
lambda: service.users()
.messages()
.get(
userId="me",
id=input_data.parentMessageId,
format="metadata",
metadataHeaders=[
"Subject",
"References",
"Message-ID",
"From",
"To",
"Cc",
"Reply-To",
],
)
.execute()
)
# Build headers dictionary, preserving all values for duplicate headers
headers = {}
for h in parent.get("payload", {}).get("headers", []):
name = h["name"].lower()
value = h["value"]
if name in headers:
# For duplicate headers, keep the first occurrence (most relevant for reply context)
continue
headers[name] = value
# Determine recipients if not specified
if not (input_data.to or input_data.cc or input_data.bcc):
if input_data.replyAll:
recipients = [parseaddr(headers.get("from", ""))[1]]
recipients += [addr for _, addr in getaddresses([headers.get("to", "")])]
recipients += [addr for _, addr in getaddresses([headers.get("cc", "")])]
# Use dict.fromkeys() for O(n) deduplication while preserving order
input_data.to = list(dict.fromkeys(filter(None, recipients)))
else:
# Check Reply-To header first, fall back to From header
reply_to = headers.get("reply-to", "")
from_addr = headers.get("from", "")
sender = parseaddr(reply_to if reply_to else from_addr)[1]
input_data.to = [sender] if sender else []
# Set subject with Re: prefix if not already present
if input_data.subject:
subject = input_data.subject
else:
parent_subject = headers.get("subject", "").strip()
# Only add "Re:" if not already present (case-insensitive check)
if parent_subject.lower().startswith("re:"):
subject = parent_subject
else:
subject = f"Re: {parent_subject}" if parent_subject else "Re:"
# Build references header for proper threading
references = headers.get("references", "").split()
if headers.get("message-id"):
references.append(headers["message-id"])
# Create MIME message
msg = MIMEMultipart()
if input_data.to:
msg["To"] = ", ".join(input_data.to)
if input_data.cc:
msg["Cc"] = ", ".join(input_data.cc)
if input_data.bcc:
msg["Bcc"] = ", ".join(input_data.bcc)
msg["Subject"] = subject
if headers.get("message-id"):
msg["In-Reply-To"] = headers["message-id"]
if references:
msg["References"] = " ".join(references)
# Use the helper function for consistent content type handling
msg.attach(_make_mime_text(input_data.body, input_data.content_type))
# Handle attachments
for attach in input_data.attachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
return_content=False,
)
abs_path = get_exec_file_path(graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())
encoders.encode_base64(part)
part.add_header(
"Content-Disposition", f"attachment; filename={Path(abs_path).name}"
)
msg.attach(part)
# Encode message
raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
return raw, input_data.threadId
class GmailReplyBlock(GmailBase):
"""
Replies to Gmail threads with intelligent content type detection.
@@ -1341,31 +1242,102 @@ class GmailReplyBlock(GmailBase):
async def _reply(
self, service, input_data: Input, graph_exec_id: str, user_id: str
) -> dict:
# Build the reply message using the shared helper
raw, thread_id = await _build_reply_message(
service, input_data, graph_exec_id, user_id
parent = await asyncio.to_thread(
lambda: service.users()
.messages()
.get(
userId="me",
id=input_data.parentMessageId,
format="metadata",
metadataHeaders=[
"Subject",
"References",
"Message-ID",
"From",
"To",
"Cc",
"Reply-To",
],
)
.execute()
)
# Send the message
headers = {
h["name"].lower(): h["value"]
for h in parent.get("payload", {}).get("headers", [])
}
if not (input_data.to or input_data.cc or input_data.bcc):
if input_data.replyAll:
recipients = [parseaddr(headers.get("from", ""))[1]]
recipients += [
addr for _, addr in getaddresses([headers.get("to", "")])
]
recipients += [
addr for _, addr in getaddresses([headers.get("cc", "")])
]
# Deduplicate recipients while preserving order
input_data.to = deduplicate_email_addresses(recipients)
else:
sender = parseaddr(headers.get("reply-to", headers.get("from", "")))[1]
input_data.to = [sender] if sender else []
subject = input_data.subject or (f"Re: {headers.get('subject', '')}".strip())
references = headers.get("references", "").split()
if headers.get("message-id"):
references.append(headers["message-id"])
msg = MIMEMultipart()
if input_data.to:
msg["To"] = ", ".join(input_data.to)
if input_data.cc:
msg["Cc"] = ", ".join(input_data.cc)
if input_data.bcc:
msg["Bcc"] = ", ".join(input_data.bcc)
msg["Subject"] = subject
if headers.get("message-id"):
msg["In-Reply-To"] = headers["message-id"]
if references:
msg["References"] = " ".join(references)
# Use the new helper function for consistent content type handling
msg.attach(_make_mime_text(input_data.body, input_data.content_type))
for attach in input_data.attachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
return_content=False,
)
abs_path = get_exec_file_path(graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())
encoders.encode_base64(part)
part.add_header(
"Content-Disposition", f"attachment; filename={Path(abs_path).name}"
)
msg.attach(part)
raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
return await asyncio.to_thread(
lambda: service.users()
.messages()
.send(userId="me", body={"threadId": thread_id, "raw": raw})
.send(userId="me", body={"threadId": input_data.threadId, "raw": raw})
.execute()
)
class GmailDraftReplyBlock(GmailBase):
class GmailCreateDraftReplyBlock(GmailBase):
"""
Creates draft replies to Gmail threads with intelligent content type detection.
Features:
- Automatic HTML detection: Draft replies containing HTML tags are formatted as text/html
- No hard-wrap for plain text: Plain text draft replies preserve natural line flow
- No hard-wrap for plain text: Plain text drafts preserve natural line flow
- Manual content type override: Use content_type parameter to force specific format
- Reply-all functionality: Option to reply to all original recipients
- Reply-all functionality: Option to draft reply to all original recipients
- Thread preservation: Maintains proper email threading with headers
- Full Unicode/emoji support with UTF-8 encoding
- Attachment support for multiple files
"""
class Input(BlockSchema):
@@ -1405,31 +1377,31 @@ class GmailDraftReplyBlock(GmailBase):
def __init__(self):
super().__init__(
id="d7a9f3e2-8b4c-4d6f-9e1a-3c5b7f8d2a6e",
description="Create draft replies to Gmail threads with automatic HTML detection and proper text formatting. Plain text draft replies maintain natural paragraph flow without 78-character line wrapping. HTML content is automatically detected and formatted correctly.",
id="8f2e9d3c-4b1a-4c7e-9a2f-1d3e5f7a9b1c",
description="Create draft replies to Gmail threads with automatic HTML detection and proper text formatting. Drafts maintain proper email threading and can be edited before sending.",
categories={BlockCategory.COMMUNICATION},
input_schema=GmailDraftReplyBlock.Input,
output_schema=GmailDraftReplyBlock.Output,
input_schema=GmailCreateDraftReplyBlock.Input,
output_schema=GmailCreateDraftReplyBlock.Output,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
test_input={
"threadId": "t1",
"parentMessageId": "m1",
"body": "Thanks for your message. I'll review and get back to you.",
"body": "Thanks for your message. I'll draft a response.",
"replyAll": False,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("draftId", "draft1"),
("messageId", "m2"),
("messageId", "msg1"),
("threadId", "t1"),
("status", "draft_created"),
("status", "draft_reply_created"),
],
test_mock={
"_create_draft_reply": lambda *args, **kwargs: {
"id": "draft1",
"message": {"id": "m2", "threadId": "t1"},
}
"message": {"id": "msg1", "threadId": "t1"},
},
},
)
@@ -1443,26 +1415,117 @@ class GmailDraftReplyBlock(GmailBase):
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
draft = await self._create_draft_reply(
result = await self._create_draft_reply(
service,
input_data,
graph_exec_id,
user_id,
)
yield "draftId", draft["id"]
yield "messageId", draft["message"]["id"]
yield "threadId", draft["message"].get("threadId", input_data.threadId)
yield "status", "draft_created"
yield "draftId", result["id"]
yield "messageId", result["message"]["id"]
yield "threadId", result["message"].get("threadId", input_data.threadId)
yield "status", "draft_reply_created"
async def _create_draft_reply(
self, service, input_data: Input, graph_exec_id: str, user_id: str
) -> dict:
# Build the reply message using the shared helper
raw, thread_id = await _build_reply_message(
service, input_data, graph_exec_id, user_id
# Fetch parent message metadata
parent = await asyncio.to_thread(
lambda: service.users()
.messages()
.get(
userId="me",
id=input_data.parentMessageId,
format="metadata",
metadataHeaders=[
"Subject",
"References",
"Message-ID",
"From",
"To",
"Cc",
"Reply-To",
],
)
.execute()
)
# Create draft with proper thread association
headers = {
h["name"].lower(): h["value"]
for h in parent.get("payload", {}).get("headers", [])
}
# Auto-populate recipients if not provided
if not (input_data.to or input_data.cc or input_data.bcc):
if input_data.replyAll:
# Reply all - include all original recipients
recipients = [parseaddr(headers.get("from", ""))[1]]
recipients += [
addr for _, addr in getaddresses([headers.get("to", "")])
]
recipients += [
addr for _, addr in getaddresses([headers.get("cc", "")])
]
# Deduplicate recipients
dedup: list[str] = []
for r in recipients:
if r and r not in dedup:
dedup.append(r)
input_data.to = dedup
else:
# Reply to sender only
sender = parseaddr(headers.get("reply-to", headers.get("from", "")))[1]
input_data.to = [sender] if sender else []
# Generate subject with Re: prefix if needed
subject = input_data.subject or (f"Re: {headers.get('subject', '')}".strip())
# Build References header chain
references = headers.get("references", "").split()
if headers.get("message-id"):
references.append(headers["message-id"])
# Create MIME message with threading headers
msg = MIMEMultipart()
if input_data.to:
msg["To"] = ", ".join(input_data.to)
if input_data.cc:
msg["Cc"] = ", ".join(input_data.cc)
if input_data.bcc:
msg["Bcc"] = ", ".join(input_data.bcc)
msg["Subject"] = subject
# Set threading headers for proper conversation grouping
if headers.get("message-id"):
msg["In-Reply-To"] = headers["message-id"]
if references:
msg["References"] = " ".join(references)
# Add body with proper content type handling
msg.attach(_make_mime_text(input_data.body, input_data.content_type))
# Handle attachments if any
for attach in input_data.attachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
return_content=False,
)
abs_path = get_exec_file_path(graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())
encoders.encode_base64(part)
part.add_header(
"Content-Disposition", f"attachment; filename={Path(abs_path).name}"
)
msg.attach(part)
# Encode message for Gmail API
raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
# Create draft with threadId to ensure it appears as a reply
draft = await asyncio.to_thread(
lambda: service.users()
.drafts()
@@ -1470,7 +1533,7 @@ class GmailDraftReplyBlock(GmailBase):
userId="me",
body={
"message": {
"threadId": thread_id,
"threadId": input_data.threadId,
"raw": raw,
}
},

View File

@@ -30,7 +30,6 @@ TEST_CREDENTIALS_INPUT = {
class IdeogramModelName(str, Enum):
V3 = "V_3"
V2 = "V_2"
V1 = "V_1"
V1_TURBO = "V_1_TURBO"
@@ -96,8 +95,8 @@ class IdeogramModelBlock(Block):
title="Prompt",
)
ideogram_model_name: IdeogramModelName = SchemaField(
description="The name of the Image Generation Model, e.g., V_3",
default=IdeogramModelName.V3,
description="The name of the Image Generation Model, e.g., V_2",
default=IdeogramModelName.V2,
title="Image Generation Model",
advanced=False,
)
@@ -237,111 +236,6 @@ class IdeogramModelBlock(Block):
negative_prompt: Optional[str],
color_palette_name: str,
custom_colors: Optional[list[str]],
):
# Use V3 endpoint for V3 model, legacy endpoint for others
if model_name == "V_3":
return await self._run_model_v3(
api_key,
prompt,
seed,
aspect_ratio,
magic_prompt_option,
style_type,
negative_prompt,
color_palette_name,
custom_colors,
)
else:
return await self._run_model_legacy(
api_key,
model_name,
prompt,
seed,
aspect_ratio,
magic_prompt_option,
style_type,
negative_prompt,
color_palette_name,
custom_colors,
)
async def _run_model_v3(
self,
api_key: SecretStr,
prompt: str,
seed: Optional[int],
aspect_ratio: str,
magic_prompt_option: str,
style_type: str,
negative_prompt: Optional[str],
color_palette_name: str,
custom_colors: Optional[list[str]],
):
url = "https://api.ideogram.ai/v1/ideogram-v3/generate"
headers = {
"Api-Key": api_key.get_secret_value(),
"Content-Type": "application/json",
}
# Map legacy aspect ratio values to V3 format
aspect_ratio_map = {
"ASPECT_10_16": "10x16",
"ASPECT_16_10": "16x10",
"ASPECT_9_16": "9x16",
"ASPECT_16_9": "16x9",
"ASPECT_3_2": "3x2",
"ASPECT_2_3": "2x3",
"ASPECT_4_3": "4x3",
"ASPECT_3_4": "3x4",
"ASPECT_1_1": "1x1",
"ASPECT_1_3": "1x3",
"ASPECT_3_1": "3x1",
# Additional V3 supported ratios
"ASPECT_1_2": "1x2",
"ASPECT_2_1": "2x1",
"ASPECT_4_5": "4x5",
"ASPECT_5_4": "5x4",
}
v3_aspect_ratio = aspect_ratio_map.get(
aspect_ratio, "1x1"
) # Default to 1x1 if not found
# Use JSON for V3 endpoint (simpler than multipart/form-data)
data: Dict[str, Any] = {
"prompt": prompt,
"aspect_ratio": v3_aspect_ratio,
"magic_prompt": magic_prompt_option,
"style_type": style_type,
}
if seed is not None:
data["seed"] = seed
if negative_prompt:
data["negative_prompt"] = negative_prompt
# Note: V3 endpoint may have different color palette support
# For now, we'll omit color palettes for V3 to avoid errors
try:
response = await Requests().post(url, headers=headers, json=data)
return response.json()["data"][0]["url"]
except RequestException as e:
raise Exception(f"Failed to fetch image with V3 endpoint: {str(e)}")
async def _run_model_legacy(
self,
api_key: SecretStr,
model_name: str,
prompt: str,
seed: Optional[int],
aspect_ratio: str,
magic_prompt_option: str,
style_type: str,
negative_prompt: Optional[str],
color_palette_name: str,
custom_colors: Optional[list[str]],
):
url = "https://api.ideogram.ai/generate"
headers = {
@@ -355,33 +249,28 @@ class IdeogramModelBlock(Block):
"model": model_name,
"aspect_ratio": aspect_ratio,
"magic_prompt_option": magic_prompt_option,
"style_type": style_type,
}
}
# Only add style_type for V2, V2_TURBO, and V3 models (V1 models don't support it)
if model_name in ["V_2", "V_2_TURBO", "V_3"]:
data["image_request"]["style_type"] = style_type
if seed is not None:
data["image_request"]["seed"] = seed
if negative_prompt:
data["image_request"]["negative_prompt"] = negative_prompt
# Only add color palette for V2 and V2_TURBO models (V1 models don't support it)
if model_name in ["V_2", "V_2_TURBO"]:
if color_palette_name != "NONE":
data["color_palette"] = {"name": color_palette_name}
elif custom_colors:
data["color_palette"] = {
"members": [{"color_hex": color} for color in custom_colors]
}
if color_palette_name != "NONE":
data["color_palette"] = {"name": color_palette_name}
elif custom_colors:
data["color_palette"] = {
"members": [{"color_hex": color} for color in custom_colors]
}
try:
response = await Requests().post(url, headers=headers, json=data)
return response.json()["data"][0]["url"]
except RequestException as e:
raise Exception(f"Failed to fetch image with legacy endpoint: {str(e)}")
raise Exception(f"Failed to fetch image: {str(e)}")
async def upscale_image(self, api_key: SecretStr, image_url: str):
url = "https://api.ideogram.ai/upscale"

View File

@@ -10,6 +10,7 @@ from backend.util.settings import Config
from backend.util.text import TextFormatter
from backend.util.type import LongTextType, MediaFileType, ShortTextType
formatter = TextFormatter()
config = Config()
@@ -131,11 +132,6 @@ class AgentOutputBlock(Block):
default="",
advanced=True,
)
escape_html: bool = SchemaField(
default=False,
advanced=True,
description="Whether to escape special characters in the inserted values to be HTML-safe. Enable for HTML output, disable for plain text.",
)
advanced: bool = SchemaField(
description="Whether to treat the output as advanced.",
default=False,
@@ -197,7 +193,6 @@ class AgentOutputBlock(Block):
"""
if input_data.format:
try:
formatter = TextFormatter(autoescape=input_data.escape_html)
yield "output", formatter.format_string(
input_data.format, {input_data.name: input_data.value}
)

View File

@@ -1,9 +1,5 @@
# This file contains a lot of prompt block strings that would trigger "line too long"
# flake8: noqa: E501
import ast
import logging
import re
import secrets
from abc import ABC
from enum import Enum, EnumMeta
from json import JSONDecodeError
@@ -31,7 +27,7 @@ from backend.util.prompt import compress_prompt, estimate_token_count
from backend.util.text import TextFormatter
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
fmt = TextFormatter(autoescape=False)
fmt = TextFormatter()
LLMProviderName = Literal[
ProviderName.AIML_API,
@@ -208,13 +204,13 @@ MODEL_METADATA = {
"anthropic", 200000, 32000
), # claude-opus-4-1-20250805
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
"anthropic", 200000, 32000
"anthropic", 200000, 8192
), # claude-4-opus-20250514
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
"anthropic", 200000, 64000
"anthropic", 200000, 8192
), # claude-4-sonnet-20250514
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
"anthropic", 200000, 64000
"anthropic", 200000, 8192
), # claude-3-7-sonnet-20250219
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(
"anthropic", 200000, 8192
@@ -386,9 +382,7 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
return None
def get_parallel_tool_calls_param(
llm_model: LlmModel, parallel_tool_calls: bool | None
):
def get_parallel_tool_calls_param(llm_model: LlmModel, parallel_tool_calls):
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
if llm_model.startswith("o") or parallel_tool_calls is None:
return openai.NOT_GIVEN
@@ -399,8 +393,8 @@ async def llm_call(
credentials: APIKeyCredentials,
llm_model: LlmModel,
prompt: list[dict],
json_format: bool,
max_tokens: int | None,
force_json_output: bool = False,
tools: list[dict] | None = None,
ollama_host: str = "localhost:11434",
parallel_tool_calls=None,
@@ -413,7 +407,7 @@ async def llm_call(
credentials: The API key credentials to use.
llm_model: The LLM model to use.
prompt: The prompt to send to the LLM.
force_json_output: Whether the response should be in JSON format.
json_format: Whether the response should be in JSON format.
max_tokens: The maximum number of tokens to generate in the chat completion.
tools: The tools to use in the chat completion.
ollama_host: The host for ollama to use.
@@ -452,7 +446,7 @@ async def llm_call(
llm_model, parallel_tool_calls
)
if force_json_output:
if json_format:
response_format = {"type": "json_object"}
response = await oai_client.chat.completions.create(
@@ -565,7 +559,7 @@ async def llm_call(
raise ValueError("Groq does not support tools.")
client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
response_format = {"type": "json_object"} if force_json_output else None
response_format = {"type": "json_object"} if json_format else None
response = await client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
@@ -723,7 +717,7 @@ async def llm_call(
)
response_format = None
if force_json_output:
if json_format:
response_format = {"type": "json_object"}
parallel_tool_calls_param = get_parallel_tool_calls_param(
@@ -786,17 +780,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
description="The language model to use for answering the prompt.",
advanced=False,
)
force_json_output: bool = SchemaField(
title="Restrict LLM to pure JSON output",
default=False,
description=(
"Whether to force the LLM to produce a JSON-only response. "
"This can increase the block's reliability, "
"but may also reduce the quality of the response "
"because it prohibits the LLM from reasoning "
"before providing its JSON response."
),
)
credentials: AICredentials = AICredentialsField()
sys_prompt: str = SchemaField(
title="System Prompt",
@@ -865,18 +848,17 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
"llm_call": lambda *args, **kwargs: LLMResponse(
raw_response="",
prompt=[""],
response=(
'<json_output id="test123456">{\n'
' "key1": "key1Value",\n'
' "key2": "key2Value"\n'
"}</json_output>"
response=json.dumps(
{
"key1": "key1Value",
"key2": "key2Value",
}
),
tool_calls=None,
prompt_tokens=0,
completion_tokens=0,
reasoning=None,
),
"get_collision_proof_output_tag_id": lambda *args: "test123456",
)
},
)
@@ -885,9 +867,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
credentials: APIKeyCredentials,
llm_model: LlmModel,
prompt: list[dict],
json_format: bool,
compress_prompt_to_fit: bool,
max_tokens: int | None,
force_json_output: bool = False,
compress_prompt_to_fit: bool = True,
tools: list[dict] | None = None,
ollama_host: str = "localhost:11434",
) -> LLMResponse:
@@ -900,8 +882,8 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
credentials=credentials,
llm_model=llm_model,
prompt=prompt,
json_format=json_format,
max_tokens=max_tokens,
force_json_output=force_json_output,
tools=tools,
ollama_host=ollama_host,
compress_prompt_to_fit=compress_prompt_to_fit,
@@ -913,6 +895,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
logger.debug(f"Calling LLM with input data: {input_data}")
prompt = [json.to_dict(p) for p in input_data.conversation_history]
def trim_prompt(s: str) -> str:
lines = s.strip().split("\n")
return "\n".join([line.strip().lstrip("|") for line in lines])
values = input_data.prompt_values
if values:
input_data.prompt = fmt.format_string(input_data.prompt, values)
@@ -921,15 +907,27 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
if input_data.sys_prompt:
prompt.append({"role": "system", "content": input_data.sys_prompt})
# Use a one-time unique tag to prevent collisions with user/LLM content
output_tag_id = self.get_collision_proof_output_tag_id()
output_tag_start = f'<json_output id="{output_tag_id}">'
if input_data.expected_format:
sys_prompt = self.response_format_instructions(
input_data.expected_format,
list_mode=input_data.list_result,
pure_json_mode=input_data.force_json_output,
output_tag_start=output_tag_start,
expected_format = [
f'"{k}": "{v}"' for k, v in input_data.expected_format.items()
]
if input_data.list_result:
format_prompt = (
f'"results": [\n {{\n {", ".join(expected_format)}\n }}\n]'
)
else:
format_prompt = "\n ".join(expected_format)
sys_prompt = trim_prompt(
f"""
|Reply strictly only in the following JSON format:
|{{
| {format_prompt}
|}}
|
|Ensure the response is valid JSON. Do not include any additional text outside of the JSON.
|If you cannot provide all the keys, provide an empty string for the values you cannot answer.
"""
)
prompt.append({"role": "system", "content": sys_prompt})
@@ -947,21 +945,18 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
except JSONDecodeError as e:
return f"JSON decode error: {e}"
error_feedback_message = ""
logger.debug(f"LLM request: {prompt}")
retry_prompt = ""
llm_model = input_data.model
for retry_count in range(input_data.retry):
logger.debug(f"LLM request: {prompt}")
try:
llm_response = await self.llm_call(
credentials=credentials,
llm_model=llm_model,
prompt=prompt,
compress_prompt_to_fit=input_data.compress_prompt_to_fit,
force_json_output=(
input_data.force_json_output
and bool(input_data.expected_format)
),
json_format=bool(input_data.expected_format),
ollama_host=input_data.ollama_host,
max_tokens=input_data.max_tokens,
)
@@ -975,55 +970,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
if input_data.expected_format:
try:
response_obj = self.get_json_from_response(
response_text,
pure_json_mode=input_data.force_json_output,
output_tag_start=output_tag_start,
)
except (ValueError, JSONDecodeError) as parse_error:
censored_response = re.sub(r"[A-Za-z0-9]", "*", response_text)
response_snippet = (
f"{censored_response[:50]}...{censored_response[-30:]}"
)
logger.warning(
f"Error getting JSON from LLM response: {parse_error}\n\n"
f"Response start+end: `{response_snippet}`"
)
prompt.append({"role": "assistant", "content": response_text})
error_feedback_message = self.invalid_response_feedback(
parse_error,
was_parseable=False,
list_mode=input_data.list_result,
pure_json_mode=input_data.force_json_output,
output_tag_start=output_tag_start,
)
prompt.append(
{"role": "user", "content": error_feedback_message}
)
continue
response_obj = json.loads(response_text)
# Handle object response for `force_json_output`+`list_result`
if input_data.list_result and isinstance(response_obj, dict):
if "results" in response_obj and isinstance(
response_obj["results"], list
):
response_obj = response_obj["results"]
else:
error_feedback_message = (
"Expected an array of objects in the 'results' key, "
f"but got: {response_obj}"
)
prompt.append(
{"role": "assistant", "content": response_text}
)
prompt.append(
{"role": "user", "content": error_feedback_message}
)
continue
if "results" in response_obj:
response_obj = response_obj.get("results", [])
elif len(response_obj) == 1:
response_obj = list(response_obj.values())
validation_errors = "\n".join(
response_error = "\n".join(
[
validation_error
for response_item in (
@@ -1035,7 +991,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
]
)
if not validation_errors:
if not response_error:
self.merge_stats(
NodeExecutionStats(
llm_call_count=retry_count + 1,
@@ -1045,16 +1001,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
yield "response", response_obj
yield "prompt", self.prompt
return
prompt.append({"role": "assistant", "content": response_text})
error_feedback_message = self.invalid_response_feedback(
validation_errors,
was_parseable=True,
list_mode=input_data.list_result,
pure_json_mode=input_data.force_json_output,
output_tag_start=output_tag_start,
)
prompt.append({"role": "user", "content": error_feedback_message})
else:
self.merge_stats(
NodeExecutionStats(
@@ -1065,6 +1011,21 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
yield "response", {"response": response_text}
yield "prompt", self.prompt
return
retry_prompt = trim_prompt(
f"""
|This is your previous error response:
|--
|{response_text}
|--
|
|And this is the error:
|--
|{response_error}
|--
"""
)
prompt.append({"role": "user", "content": retry_prompt})
except Exception as e:
logger.exception(f"Error calling LLM: {e}")
if (
@@ -1077,133 +1038,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
logger.debug(
f"Reducing max_tokens to {input_data.max_tokens} for next attempt"
)
# Don't add retry prompt for token limit errors,
# just retry with lower maximum output tokens
retry_prompt = f"Error calling LLM: {e}"
error_feedback_message = f"Error calling LLM: {e}"
raise RuntimeError(error_feedback_message)
def response_format_instructions(
self,
expected_object_format: dict[str, str],
*,
list_mode: bool,
pure_json_mode: bool,
output_tag_start: str,
) -> str:
expected_output_format = json.dumps(expected_object_format, indent=2)
output_type = "object" if not list_mode else "array"
outer_output_type = "object" if pure_json_mode else output_type
if output_type == "array":
indented_obj_format = expected_output_format.replace("\n", "\n ")
expected_output_format = f"[\n {indented_obj_format},\n ...\n]"
if pure_json_mode:
indented_list_format = expected_output_format.replace("\n", "\n ")
expected_output_format = (
"{\n"
' "reasoning": "... (optional)",\n' # for better performance
f' "results": {indented_list_format}\n'
"}"
)
# Preserve indentation in prompt
expected_output_format = expected_output_format.replace("\n", "\n|")
# Prepare prompt
if not pure_json_mode:
expected_output_format = (
f"{output_tag_start}\n{expected_output_format}\n</json_output>"
)
instructions = f"""
|In your response you MUST include a valid JSON {outer_output_type} strictly following this format:
|{expected_output_format}
|
|If you cannot provide all the keys, you MUST provide an empty string for the values you cannot answer.
""".strip()
if not pure_json_mode:
instructions += f"""
|
|You MUST enclose your final JSON answer in {output_tag_start}...</json_output> tags, even if the user specifies a different tag.
|There MUST be exactly ONE {output_tag_start}...</json_output> block in your response, which MUST ONLY contain the JSON {outer_output_type} and nothing else. Other text outside this block is allowed.
""".strip()
return trim_prompt(instructions)
def invalid_response_feedback(
self,
error,
*,
was_parseable: bool,
list_mode: bool,
pure_json_mode: bool,
output_tag_start: str,
) -> str:
outer_output_type = "object" if not list_mode or pure_json_mode else "array"
if was_parseable:
complaint = f"Your previous response did not match the expected {outer_output_type} format."
else:
complaint = f"Your previous response did not contain a parseable JSON {outer_output_type}."
indented_parse_error = str(error).replace("\n", "\n|")
instruction = (
f"Please provide a {output_tag_start}...</json_output> block containing a"
if not pure_json_mode
else "Please provide a"
) + f" valid JSON {outer_output_type} that matches the expected format."
return trim_prompt(
f"""
|{complaint}
|
|{indented_parse_error}
|
|{instruction}
"""
)
def get_json_from_response(
self, response_text: str, *, pure_json_mode: bool, output_tag_start: str
) -> dict[str, Any] | list[dict[str, Any]]:
if pure_json_mode:
# Handle pure JSON responses
try:
return json.loads(response_text)
except JSONDecodeError as first_parse_error:
# If that didn't work, try finding the { and } to deal with possible ```json fences etc.
json_start = response_text.find("{")
json_end = response_text.rfind("}")
try:
return json.loads(response_text[json_start : json_end + 1])
except JSONDecodeError:
# Raise the original error, as it's more likely to be relevant
raise first_parse_error from None
if output_tag_start not in response_text:
raise ValueError(
"Response does not contain the expected "
f"{output_tag_start}...</json_output> block."
)
json_output = (
response_text.split(output_tag_start, 1)[1]
.rsplit("</json_output>", 1)[0]
.strip()
)
return json.loads(json_output)
def get_collision_proof_output_tag_id(self) -> str:
return secrets.token_hex(8)
def trim_prompt(s: str) -> str:
"""Removes indentation up to and including `|` from a multi-line prompt."""
lines = s.strip().split("\n")
return "\n".join([line.strip().lstrip("|") for line in lines])
raise RuntimeError(retry_prompt)
class AITextGeneratorBlock(AIBlockBase):

View File

@@ -1,536 +0,0 @@
"""
Notion API helper functions and client for making authenticated requests.
"""
from typing import Any, Dict, List, Optional
from backend.data.model import OAuth2Credentials
from backend.util.request import Requests
NOTION_VERSION = "2022-06-28"
class NotionAPIException(Exception):
"""Exception raised for Notion API errors."""
def __init__(self, message: str, status_code: int):
super().__init__(message)
self.status_code = status_code
class NotionClient:
"""Client for interacting with the Notion API."""
def __init__(self, credentials: OAuth2Credentials):
self.credentials = credentials
self.headers = {
"Authorization": credentials.auth_header(),
"Notion-Version": NOTION_VERSION,
"Content-Type": "application/json",
}
self.requests = Requests()
async def get_page(self, page_id: str) -> dict:
"""
Fetch a page by ID.
Args:
page_id: The ID of the page to fetch.
Returns:
The page object from Notion API.
"""
url = f"https://api.notion.com/v1/pages/{page_id}"
response = await self.requests.get(url, headers=self.headers)
if not response.ok:
raise NotionAPIException(
f"Failed to fetch page: {response.status} - {response.text()}",
response.status,
)
return response.json()
async def get_blocks(self, block_id: str, recursive: bool = True) -> List[dict]:
"""
Fetch all blocks from a page or block.
Args:
block_id: The ID of the page or block to fetch children from.
recursive: Whether to fetch nested blocks recursively.
Returns:
List of block objects.
"""
blocks = []
cursor = None
while True:
url = f"https://api.notion.com/v1/blocks/{block_id}/children"
params = {"page_size": 100}
if cursor:
params["start_cursor"] = cursor
response = await self.requests.get(url, headers=self.headers, params=params)
if not response.ok:
raise NotionAPIException(
f"Failed to fetch blocks: {response.status} - {response.text()}",
response.status,
)
data = response.json()
current_blocks = data.get("results", [])
# If recursive, fetch children for blocks that have them
if recursive:
for block in current_blocks:
if block.get("has_children"):
block["children"] = await self.get_blocks(
block["id"], recursive=True
)
blocks.extend(current_blocks)
if not data.get("has_more"):
break
cursor = data.get("next_cursor")
return blocks
async def query_database(
self,
database_id: str,
filter_obj: Optional[dict] = None,
sorts: Optional[List[dict]] = None,
page_size: int = 100,
) -> dict:
"""
Query a database with optional filters and sorts.
Args:
database_id: The ID of the database to query.
filter_obj: Optional filter object for the query.
sorts: Optional list of sort objects.
page_size: Number of results per page.
Returns:
Query results including pages and pagination info.
"""
url = f"https://api.notion.com/v1/databases/{database_id}/query"
payload: Dict[str, Any] = {"page_size": page_size}
if filter_obj:
payload["filter"] = filter_obj
if sorts:
payload["sorts"] = sorts
response = await self.requests.post(url, headers=self.headers, json=payload)
if not response.ok:
raise NotionAPIException(
f"Failed to query database: {response.status} - {response.text()}",
response.status,
)
return response.json()
async def create_page(
self,
parent: dict,
properties: dict,
children: Optional[List[dict]] = None,
icon: Optional[dict] = None,
cover: Optional[dict] = None,
) -> dict:
"""
Create a new page.
Args:
parent: Parent object (page_id or database_id).
properties: Page properties.
children: Optional list of block children.
icon: Optional icon object.
cover: Optional cover object.
Returns:
The created page object.
"""
url = "https://api.notion.com/v1/pages"
payload: Dict[str, Any] = {"parent": parent, "properties": properties}
if children:
payload["children"] = children
if icon:
payload["icon"] = icon
if cover:
payload["cover"] = cover
response = await self.requests.post(url, headers=self.headers, json=payload)
if not response.ok:
raise NotionAPIException(
f"Failed to create page: {response.status} - {response.text()}",
response.status,
)
return response.json()
async def update_page(self, page_id: str, properties: dict) -> dict:
"""
Update a page's properties.
Args:
page_id: The ID of the page to update.
properties: Properties to update.
Returns:
The updated page object.
"""
url = f"https://api.notion.com/v1/pages/{page_id}"
response = await self.requests.patch(
url, headers=self.headers, json={"properties": properties}
)
if not response.ok:
raise NotionAPIException(
f"Failed to update page: {response.status} - {response.text()}",
response.status,
)
return response.json()
async def append_blocks(self, block_id: str, children: List[dict]) -> dict:
"""
Append blocks to a page or block.
Args:
block_id: The ID of the page or block to append to.
children: List of block objects to append.
Returns:
Response with the created blocks.
"""
url = f"https://api.notion.com/v1/blocks/{block_id}/children"
response = await self.requests.patch(
url, headers=self.headers, json={"children": children}
)
if not response.ok:
raise NotionAPIException(
f"Failed to append blocks: {response.status} - {response.text()}",
response.status,
)
return response.json()
async def search(
self,
query: str = "",
filter_obj: Optional[dict] = None,
sort: Optional[dict] = None,
page_size: int = 100,
) -> dict:
"""
Search for pages and databases.
Args:
query: Search query text.
filter_obj: Optional filter object.
sort: Optional sort object.
page_size: Number of results per page.
Returns:
Search results.
"""
url = "https://api.notion.com/v1/search"
payload: Dict[str, Any] = {"page_size": page_size}
if query:
payload["query"] = query
if filter_obj:
payload["filter"] = filter_obj
if sort:
payload["sort"] = sort
response = await self.requests.post(url, headers=self.headers, json=payload)
if not response.ok:
raise NotionAPIException(
f"Search failed: {response.status} - {response.text()}", response.status
)
return response.json()
# Conversion helper functions
def parse_rich_text(rich_text_array: List[dict]) -> str:
"""
Extract plain text from a Notion rich text array.
Args:
rich_text_array: Array of rich text objects from Notion.
Returns:
Plain text string.
"""
if not rich_text_array:
return ""
text_parts = []
for text_obj in rich_text_array:
if "plain_text" in text_obj:
text_parts.append(text_obj["plain_text"])
return "".join(text_parts)
def rich_text_to_markdown(rich_text_array: List[dict]) -> str:
"""
Convert Notion rich text array to markdown with formatting.
Args:
rich_text_array: Array of rich text objects from Notion.
Returns:
Markdown formatted string.
"""
if not rich_text_array:
return ""
markdown_parts = []
for text_obj in rich_text_array:
text = text_obj.get("plain_text", "")
annotations = text_obj.get("annotations", {})
# Apply formatting based on annotations
if annotations.get("code"):
text = f"`{text}`"
else:
if annotations.get("bold"):
text = f"**{text}**"
if annotations.get("italic"):
text = f"*{text}*"
if annotations.get("strikethrough"):
text = f"~~{text}~~"
if annotations.get("underline"):
text = f"<u>{text}</u>"
# Handle links
if text_obj.get("href"):
text = f"[{text}]({text_obj['href']})"
markdown_parts.append(text)
return "".join(markdown_parts)
def block_to_markdown(block: dict, indent_level: int = 0) -> str:
"""
Convert a single Notion block to markdown.
Args:
block: Block object from Notion API.
indent_level: Current indentation level for nested blocks.
Returns:
Markdown string representation of the block.
"""
block_type = block.get("type")
indent = " " * indent_level
markdown_lines = []
# Handle different block types
if block_type == "paragraph":
text = rich_text_to_markdown(block["paragraph"].get("rich_text", []))
if text:
markdown_lines.append(f"{indent}{text}")
elif block_type == "heading_1":
text = parse_rich_text(block["heading_1"].get("rich_text", []))
markdown_lines.append(f"{indent}# {text}")
elif block_type == "heading_2":
text = parse_rich_text(block["heading_2"].get("rich_text", []))
markdown_lines.append(f"{indent}## {text}")
elif block_type == "heading_3":
text = parse_rich_text(block["heading_3"].get("rich_text", []))
markdown_lines.append(f"{indent}### {text}")
elif block_type == "bulleted_list_item":
text = rich_text_to_markdown(block["bulleted_list_item"].get("rich_text", []))
markdown_lines.append(f"{indent}- {text}")
elif block_type == "numbered_list_item":
text = rich_text_to_markdown(block["numbered_list_item"].get("rich_text", []))
# Note: This is simplified - proper numbering would need context
markdown_lines.append(f"{indent}1. {text}")
elif block_type == "to_do":
text = rich_text_to_markdown(block["to_do"].get("rich_text", []))
checked = "x" if block["to_do"].get("checked") else " "
markdown_lines.append(f"{indent}- [{checked}] {text}")
elif block_type == "toggle":
text = rich_text_to_markdown(block["toggle"].get("rich_text", []))
markdown_lines.append(f"{indent}<details>")
markdown_lines.append(f"{indent}<summary>{text}</summary>")
markdown_lines.append(f"{indent}")
# Process children if they exist
if block.get("children"):
for child in block["children"]:
child_markdown = block_to_markdown(child, indent_level + 1)
if child_markdown:
markdown_lines.append(child_markdown)
markdown_lines.append(f"{indent}</details>")
elif block_type == "code":
code = parse_rich_text(block["code"].get("rich_text", []))
language = block["code"].get("language", "")
markdown_lines.append(f"{indent}```{language}")
markdown_lines.append(f"{indent}{code}")
markdown_lines.append(f"{indent}```")
elif block_type == "quote":
text = rich_text_to_markdown(block["quote"].get("rich_text", []))
markdown_lines.append(f"{indent}> {text}")
elif block_type == "divider":
markdown_lines.append(f"{indent}---")
elif block_type == "image":
image = block["image"]
url = image.get("external", {}).get("url") or image.get("file", {}).get(
"url", ""
)
caption = parse_rich_text(image.get("caption", []))
alt_text = caption if caption else "Image"
markdown_lines.append(f"{indent}![{alt_text}]({url})")
if caption:
markdown_lines.append(f"{indent}*{caption}*")
elif block_type == "video":
video = block["video"]
url = video.get("external", {}).get("url") or video.get("file", {}).get(
"url", ""
)
caption = parse_rich_text(video.get("caption", []))
markdown_lines.append(f"{indent}[Video]({url})")
if caption:
markdown_lines.append(f"{indent}*{caption}*")
elif block_type == "file":
file = block["file"]
url = file.get("external", {}).get("url") or file.get("file", {}).get("url", "")
caption = parse_rich_text(file.get("caption", []))
name = caption if caption else "File"
markdown_lines.append(f"{indent}[{name}]({url})")
elif block_type == "bookmark":
url = block["bookmark"].get("url", "")
caption = parse_rich_text(block["bookmark"].get("caption", []))
markdown_lines.append(f"{indent}[{caption if caption else url}]({url})")
elif block_type == "equation":
expression = block["equation"].get("expression", "")
markdown_lines.append(f"{indent}$${expression}$$")
elif block_type == "callout":
text = rich_text_to_markdown(block["callout"].get("rich_text", []))
icon = block["callout"].get("icon", {})
if icon.get("emoji"):
markdown_lines.append(f"{indent}> {icon['emoji']} {text}")
else:
markdown_lines.append(f"{indent}> {text}")
elif block_type == "child_page":
title = block["child_page"].get("title", "Untitled")
markdown_lines.append(f"{indent}📄 [{title}](notion://page/{block['id']})")
elif block_type == "child_database":
title = block["child_database"].get("title", "Untitled Database")
markdown_lines.append(f"{indent}🗂️ [{title}](notion://database/{block['id']})")
elif block_type == "table":
# Tables are complex - for now just indicate there's a table
markdown_lines.append(
f"{indent}[Table with {block['table'].get('table_width', 0)} columns]"
)
elif block_type == "column_list":
# Process columns
if block.get("children"):
markdown_lines.append(f"{indent}<div style='display: flex'>")
for column in block["children"]:
markdown_lines.append(f"{indent}<div style='flex: 1'>")
if column.get("children"):
for child in column["children"]:
child_markdown = block_to_markdown(child, indent_level + 1)
if child_markdown:
markdown_lines.append(child_markdown)
markdown_lines.append(f"{indent}</div>")
markdown_lines.append(f"{indent}</div>")
# Handle children for blocks that haven't been processed yet
elif block.get("children") and block_type not in ["toggle", "column_list"]:
for child in block["children"]:
child_markdown = block_to_markdown(child, indent_level)
if child_markdown:
markdown_lines.append(child_markdown)
return "\n".join(markdown_lines) if markdown_lines else ""
def blocks_to_markdown(blocks: List[dict]) -> str:
"""
Convert a list of Notion blocks to a markdown document.
Args:
blocks: List of block objects from Notion API.
Returns:
Complete markdown document as a string.
"""
markdown_parts = []
for i, block in enumerate(blocks):
markdown = block_to_markdown(block)
if markdown:
markdown_parts.append(markdown)
# Add spacing between top-level blocks (except lists)
if i < len(blocks) - 1:
next_type = blocks[i + 1].get("type", "")
current_type = block.get("type", "")
# Don't add extra spacing between list items
list_types = {"bulleted_list_item", "numbered_list_item", "to_do"}
if not (current_type in list_types and next_type in list_types):
markdown_parts.append("")
return "\n".join(markdown_parts)
def extract_page_title(page: dict) -> str:
"""
Extract the title from a Notion page object.
Args:
page: Page object from Notion API.
Returns:
Page title as a string.
"""
properties = page.get("properties", {})
# Find the title property (it has type "title")
for prop_name, prop_value in properties.items():
if prop_value.get("type") == "title":
return parse_rich_text(prop_value.get("title", []))
return "Untitled"

View File

@@ -1,42 +0,0 @@
from typing import Literal
from pydantic import SecretStr
from backend.data.model import CredentialsField, CredentialsMetaInput, OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets
secrets = Secrets()
NOTION_OAUTH_IS_CONFIGURED = bool(
secrets.notion_client_id and secrets.notion_client_secret
)
NotionCredentials = OAuth2Credentials
NotionCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.NOTION], Literal["oauth2"]
]
def NotionCredentialsField() -> NotionCredentialsInput:
"""Creates a Notion OAuth2 credentials field."""
return CredentialsField(
description="Connect your Notion account. Ensure the pages/databases are shared with the integration."
)
# Test credentials for Notion OAuth2
TEST_CREDENTIALS = OAuth2Credentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="notion",
access_token=SecretStr("test_access_token"),
title="Mock Notion OAuth",
scopes=["read_content", "insert_content", "update_content"],
username="testuser",
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}

View File

@@ -1,360 +0,0 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import model_validator
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import OAuth2Credentials, SchemaField
from ._api import NotionClient
from ._auth import (
NOTION_OAUTH_IS_CONFIGURED,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
NotionCredentialsField,
NotionCredentialsInput,
)
class NotionCreatePageBlock(Block):
"""Create a new page in Notion with content."""
class Input(BlockSchema):
credentials: NotionCredentialsInput = NotionCredentialsField()
parent_page_id: Optional[str] = SchemaField(
description="Parent page ID to create the page under. Either this OR parent_database_id is required.",
default=None,
)
parent_database_id: Optional[str] = SchemaField(
description="Parent database ID to create the page in. Either this OR parent_page_id is required.",
default=None,
)
title: str = SchemaField(
description="Title of the new page",
)
content: Optional[str] = SchemaField(
description="Content for the page. Can be plain text or markdown - will be converted to Notion blocks.",
default=None,
)
properties: Optional[Dict[str, Any]] = SchemaField(
description="Additional properties for database pages (e.g., {'Status': 'In Progress', 'Priority': 'High'})",
default=None,
)
icon_emoji: Optional[str] = SchemaField(
description="Emoji to use as the page icon (e.g., '📄', '🚀')", default=None
)
@model_validator(mode="after")
def validate_parent(self):
"""Ensure either parent_page_id or parent_database_id is provided."""
if not self.parent_page_id and not self.parent_database_id:
raise ValueError(
"Either parent_page_id or parent_database_id must be provided"
)
if self.parent_page_id and self.parent_database_id:
raise ValueError(
"Only one of parent_page_id or parent_database_id should be provided, not both"
)
return self
class Output(BlockSchema):
page_id: str = SchemaField(description="ID of the created page.")
page_url: str = SchemaField(description="URL of the created page.")
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="c15febe0-66ce-4c6f-aebd-5ab351653804",
description="Create a new page in Notion. Requires EITHER a parent_page_id OR parent_database_id. Supports markdown content.",
categories={BlockCategory.PRODUCTIVITY},
input_schema=NotionCreatePageBlock.Input,
output_schema=NotionCreatePageBlock.Output,
disabled=not NOTION_OAUTH_IS_CONFIGURED,
test_input={
"parent_page_id": "00000000-0000-0000-0000-000000000000",
"title": "Test Page",
"content": "This is test content.",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
("page_id", "12345678-1234-1234-1234-123456789012"),
(
"page_url",
"https://notion.so/Test-Page-12345678123412341234123456789012",
),
],
test_credentials=TEST_CREDENTIALS,
test_mock={
"create_page": lambda *args, **kwargs: (
"12345678-1234-1234-1234-123456789012",
"https://notion.so/Test-Page-12345678123412341234123456789012",
)
},
)
@staticmethod
def _markdown_to_blocks(content: str) -> List[dict]:
"""Convert markdown content to Notion block objects."""
if not content:
return []
blocks = []
lines = content.split("\n")
i = 0
while i < len(lines):
line = lines[i]
# Skip empty lines
if not line.strip():
i += 1
continue
# Headings
if line.startswith("### "):
blocks.append(
{
"type": "heading_3",
"heading_3": {
"rich_text": [
{"type": "text", "text": {"content": line[4:].strip()}}
]
},
}
)
elif line.startswith("## "):
blocks.append(
{
"type": "heading_2",
"heading_2": {
"rich_text": [
{"type": "text", "text": {"content": line[3:].strip()}}
]
},
}
)
elif line.startswith("# "):
blocks.append(
{
"type": "heading_1",
"heading_1": {
"rich_text": [
{"type": "text", "text": {"content": line[2:].strip()}}
]
},
}
)
# Bullet points
elif line.strip().startswith("- "):
blocks.append(
{
"type": "bulleted_list_item",
"bulleted_list_item": {
"rich_text": [
{
"type": "text",
"text": {"content": line.strip()[2:].strip()},
}
]
},
}
)
# Numbered list
elif line.strip() and line.strip()[0].isdigit() and ". " in line:
content_start = line.find(". ") + 2
blocks.append(
{
"type": "numbered_list_item",
"numbered_list_item": {
"rich_text": [
{
"type": "text",
"text": {"content": line[content_start:].strip()},
}
]
},
}
)
# Code block
elif line.strip().startswith("```"):
code_lines = []
language = line[3:].strip() or "plain text"
i += 1
while i < len(lines) and not lines[i].strip().startswith("```"):
code_lines.append(lines[i])
i += 1
blocks.append(
{
"type": "code",
"code": {
"rich_text": [
{
"type": "text",
"text": {"content": "\n".join(code_lines)},
}
],
"language": language,
},
}
)
# Quote
elif line.strip().startswith("> "):
blocks.append(
{
"type": "quote",
"quote": {
"rich_text": [
{
"type": "text",
"text": {"content": line.strip()[2:].strip()},
}
]
},
}
)
# Horizontal rule
elif line.strip() in ["---", "***", "___"]:
blocks.append({"type": "divider", "divider": {}})
# Regular paragraph
else:
# Parse for basic markdown formatting
text_content = line.strip()
rich_text = []
# Simple bold/italic parsing (this is simplified)
if "**" in text_content or "*" in text_content:
# For now, just pass as plain text
# A full implementation would parse and create proper annotations
rich_text = [{"type": "text", "text": {"content": text_content}}]
else:
rich_text = [{"type": "text", "text": {"content": text_content}}]
blocks.append(
{"type": "paragraph", "paragraph": {"rich_text": rich_text}}
)
i += 1
return blocks
@staticmethod
def _build_properties(
title: str, additional_properties: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Build properties object for page creation."""
properties: Dict[str, Any] = {
"title": {"title": [{"type": "text", "text": {"content": title}}]}
}
if additional_properties:
for key, value in additional_properties.items():
if key.lower() == "title":
continue # Skip title as we already have it
# Try to intelligently map property types
if isinstance(value, bool):
properties[key] = {"checkbox": value}
elif isinstance(value, (int, float)):
properties[key] = {"number": value}
elif isinstance(value, list):
# Assume multi-select
properties[key] = {
"multi_select": [{"name": str(item)} for item in value]
}
elif isinstance(value, str):
# Could be select, rich_text, or other types
# For simplicity, try common patterns
if key.lower() in ["status", "priority", "type", "category"]:
properties[key] = {"select": {"name": value}}
elif key.lower() in ["url", "link"]:
properties[key] = {"url": value}
elif key.lower() in ["email"]:
properties[key] = {"email": value}
else:
properties[key] = {
"rich_text": [{"type": "text", "text": {"content": value}}]
}
return properties
@staticmethod
async def create_page(
credentials: OAuth2Credentials,
title: str,
parent_page_id: Optional[str] = None,
parent_database_id: Optional[str] = None,
content: Optional[str] = None,
properties: Optional[Dict[str, Any]] = None,
icon_emoji: Optional[str] = None,
) -> tuple[str, str]:
"""
Create a new Notion page.
Returns:
Tuple of (page_id, page_url)
"""
if not parent_page_id and not parent_database_id:
raise ValueError(
"Either parent_page_id or parent_database_id must be provided"
)
if parent_page_id and parent_database_id:
raise ValueError(
"Only one of parent_page_id or parent_database_id should be provided, not both"
)
client = NotionClient(credentials)
# Build parent object
if parent_page_id:
parent = {"type": "page_id", "page_id": parent_page_id}
else:
parent = {"type": "database_id", "database_id": parent_database_id}
# Build properties
page_properties = NotionCreatePageBlock._build_properties(title, properties)
# Convert content to blocks if provided
children = None
if content:
children = NotionCreatePageBlock._markdown_to_blocks(content)
# Build icon if provided
icon = None
if icon_emoji:
icon = {"type": "emoji", "emoji": icon_emoji}
# Create the page
result = await client.create_page(
parent=parent, properties=page_properties, children=children, icon=icon
)
page_id = result.get("id", "")
page_url = result.get("url", "")
if not page_id or not page_url:
raise ValueError("Failed to get page ID or URL from Notion response")
return page_id, page_url
async def run(
self,
input_data: Input,
*,
credentials: OAuth2Credentials,
**kwargs,
) -> BlockOutput:
try:
page_id, page_url = await self.create_page(
credentials,
input_data.title,
input_data.parent_page_id,
input_data.parent_database_id,
input_data.content,
input_data.properties,
input_data.icon_emoji,
)
yield "page_id", page_id
yield "page_url", page_url
except Exception as e:
yield "error", str(e) if str(e) else "Unknown error"

View File

@@ -1,285 +0,0 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import OAuth2Credentials, SchemaField
from ._api import NotionClient, parse_rich_text
from ._auth import (
NOTION_OAUTH_IS_CONFIGURED,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
NotionCredentialsField,
NotionCredentialsInput,
)
class NotionReadDatabaseBlock(Block):
"""Query a Notion database and retrieve entries with their properties."""
class Input(BlockSchema):
credentials: NotionCredentialsInput = NotionCredentialsField()
database_id: str = SchemaField(
description="Notion database ID. Must be accessible by the connected integration.",
)
filter_property: Optional[str] = SchemaField(
description="Property name to filter by (e.g., 'Status', 'Priority')",
default=None,
)
filter_value: Optional[str] = SchemaField(
description="Value to filter for in the specified property", default=None
)
sort_property: Optional[str] = SchemaField(
description="Property name to sort by", default=None
)
sort_direction: Optional[str] = SchemaField(
description="Sort direction: 'ascending' or 'descending'",
default="ascending",
)
limit: int = SchemaField(
description="Maximum number of entries to retrieve",
default=100,
ge=1,
le=100,
)
class Output(BlockSchema):
entries: List[Dict[str, Any]] = SchemaField(
description="List of database entries with their properties."
)
entry: Dict[str, Any] = SchemaField(
description="Individual database entry (yields one per entry found)."
)
entry_ids: List[str] = SchemaField(
description="List of entry IDs for batch operations."
)
entry_id: str = SchemaField(
description="Individual entry ID (yields one per entry found)."
)
count: int = SchemaField(description="Number of entries retrieved.")
database_title: str = SchemaField(description="Title of the database.")
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="fcd53135-88c9-4ba3-be50-cc6936286e6c",
description="Query a Notion database with optional filtering and sorting, returning structured entries.",
categories={BlockCategory.PRODUCTIVITY},
input_schema=NotionReadDatabaseBlock.Input,
output_schema=NotionReadDatabaseBlock.Output,
disabled=not NOTION_OAUTH_IS_CONFIGURED,
test_input={
"database_id": "00000000-0000-0000-0000-000000000000",
"limit": 10,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
(
"entries",
[{"Name": "Test Entry", "Status": "Active", "_id": "test-123"}],
),
("entry_ids", ["test-123"]),
(
"entry",
{"Name": "Test Entry", "Status": "Active", "_id": "test-123"},
),
("entry_id", "test-123"),
("count", 1),
("database_title", "Test Database"),
],
test_credentials=TEST_CREDENTIALS,
test_mock={
"query_database": lambda *args, **kwargs: (
[{"Name": "Test Entry", "Status": "Active", "_id": "test-123"}],
1,
"Test Database",
)
},
)
@staticmethod
def _parse_property_value(prop: dict) -> Any:
"""Parse a Notion property value into a simple Python type."""
prop_type = prop.get("type")
if prop_type == "title":
return parse_rich_text(prop.get("title", []))
elif prop_type == "rich_text":
return parse_rich_text(prop.get("rich_text", []))
elif prop_type == "number":
return prop.get("number")
elif prop_type == "select":
select = prop.get("select")
return select.get("name") if select else None
elif prop_type == "multi_select":
return [item.get("name") for item in prop.get("multi_select", [])]
elif prop_type == "date":
date = prop.get("date")
if date:
return date.get("start")
return None
elif prop_type == "checkbox":
return prop.get("checkbox", False)
elif prop_type == "url":
return prop.get("url")
elif prop_type == "email":
return prop.get("email")
elif prop_type == "phone_number":
return prop.get("phone_number")
elif prop_type == "people":
return [
person.get("name", person.get("id"))
for person in prop.get("people", [])
]
elif prop_type == "files":
files = prop.get("files", [])
return [
f.get(
"name",
f.get("external", {}).get("url", f.get("file", {}).get("url")),
)
for f in files
]
elif prop_type == "relation":
return [rel.get("id") for rel in prop.get("relation", [])]
elif prop_type == "formula":
formula = prop.get("formula", {})
return formula.get(formula.get("type"))
elif prop_type == "rollup":
rollup = prop.get("rollup", {})
return rollup.get(rollup.get("type"))
elif prop_type == "created_time":
return prop.get("created_time")
elif prop_type == "created_by":
return prop.get("created_by", {}).get(
"name", prop.get("created_by", {}).get("id")
)
elif prop_type == "last_edited_time":
return prop.get("last_edited_time")
elif prop_type == "last_edited_by":
return prop.get("last_edited_by", {}).get(
"name", prop.get("last_edited_by", {}).get("id")
)
else:
# Return the raw value for unknown types
return prop
@staticmethod
def _build_filter(property_name: str, value: str) -> dict:
"""Build a simple filter object for a property."""
# This is a simplified filter - in reality, you'd need to know the property type
# For now, we'll try common filter types
return {
"or": [
{"property": property_name, "rich_text": {"contains": value}},
{"property": property_name, "title": {"contains": value}},
{"property": property_name, "select": {"equals": value}},
{"property": property_name, "multi_select": {"contains": value}},
]
}
@staticmethod
async def query_database(
credentials: OAuth2Credentials,
database_id: str,
filter_property: Optional[str] = None,
filter_value: Optional[str] = None,
sort_property: Optional[str] = None,
sort_direction: str = "ascending",
limit: int = 100,
) -> tuple[List[Dict[str, Any]], int, str]:
"""
Query a Notion database and parse the results.
Returns:
Tuple of (entries_list, count, database_title)
"""
client = NotionClient(credentials)
# Build filter if specified
filter_obj = None
if filter_property and filter_value:
filter_obj = NotionReadDatabaseBlock._build_filter(
filter_property, filter_value
)
# Build sorts if specified
sorts = None
if sort_property:
sorts = [{"property": sort_property, "direction": sort_direction}]
# Query the database
result = await client.query_database(
database_id, filter_obj=filter_obj, sorts=sorts, page_size=limit
)
# Parse the entries
entries = []
for page in result.get("results", []):
entry = {}
properties = page.get("properties", {})
for prop_name, prop_value in properties.items():
entry[prop_name] = NotionReadDatabaseBlock._parse_property_value(
prop_value
)
# Add metadata
entry["_id"] = page.get("id")
entry["_url"] = page.get("url")
entry["_created_time"] = page.get("created_time")
entry["_last_edited_time"] = page.get("last_edited_time")
entries.append(entry)
# Get database title (we need to make a separate call for this)
try:
database_url = f"https://api.notion.com/v1/databases/{database_id}"
db_response = await client.requests.get(
database_url, headers=client.headers
)
if db_response.ok:
db_data = db_response.json()
db_title = parse_rich_text(db_data.get("title", []))
else:
db_title = "Unknown Database"
except Exception:
db_title = "Unknown Database"
return entries, len(entries), db_title
async def run(
self,
input_data: Input,
*,
credentials: OAuth2Credentials,
**kwargs,
) -> BlockOutput:
try:
entries, count, db_title = await self.query_database(
credentials,
input_data.database_id,
input_data.filter_property,
input_data.filter_value,
input_data.sort_property,
input_data.sort_direction or "ascending",
input_data.limit,
)
# Yield the complete list for batch operations
yield "entries", entries
# Extract and yield IDs as a list for batch operations
entry_ids = [entry["_id"] for entry in entries if "_id" in entry]
yield "entry_ids", entry_ids
# Yield each individual entry and its ID for single connections
for entry in entries:
yield "entry", entry
if "_id" in entry:
yield "entry_id", entry["_id"]
yield "count", count
yield "database_title", db_title
except Exception as e:
yield "error", str(e) if str(e) else "Unknown error"

View File

@@ -1,64 +0,0 @@
from __future__ import annotations
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import OAuth2Credentials, SchemaField
from ._api import NotionClient
from ._auth import (
NOTION_OAUTH_IS_CONFIGURED,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
NotionCredentialsField,
NotionCredentialsInput,
)
class NotionReadPageBlock(Block):
"""Read a Notion page by ID and return its raw JSON."""
class Input(BlockSchema):
credentials: NotionCredentialsInput = NotionCredentialsField()
page_id: str = SchemaField(
description="Notion page ID. Must be accessible by the connected integration. You can get this from the page URL notion.so/A-Page-586edd711467478da59fe3ce29a1ffab would be 586edd711467478da59fe35e29a1ffab",
)
class Output(BlockSchema):
page: dict = SchemaField(description="Raw Notion page JSON.")
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="5246cc1d-34b7-452b-8fc5-3fb25fd8f542",
description="Read a Notion page by its ID and return its raw JSON.",
categories={BlockCategory.PRODUCTIVITY},
input_schema=NotionReadPageBlock.Input,
output_schema=NotionReadPageBlock.Output,
disabled=not NOTION_OAUTH_IS_CONFIGURED,
test_input={
"page_id": "00000000-0000-0000-0000-000000000000",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[("page", dict)],
test_credentials=TEST_CREDENTIALS,
test_mock={
"get_page": lambda *args, **kwargs: {"object": "page", "id": "mocked"}
},
)
@staticmethod
async def get_page(credentials: OAuth2Credentials, page_id: str) -> dict:
client = NotionClient(credentials)
return await client.get_page(page_id)
async def run(
self,
input_data: Input,
*,
credentials: OAuth2Credentials,
**kwargs,
) -> BlockOutput:
try:
page = await self.get_page(credentials, input_data.page_id)
yield "page", page
except Exception as e:
yield "error", str(e) if str(e) else "Unknown error"

View File

@@ -1,109 +0,0 @@
from __future__ import annotations
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import OAuth2Credentials, SchemaField
from ._api import NotionClient, blocks_to_markdown, extract_page_title
from ._auth import (
NOTION_OAUTH_IS_CONFIGURED,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
NotionCredentialsField,
NotionCredentialsInput,
)
class NotionReadPageMarkdownBlock(Block):
"""Read a Notion page and convert it to clean Markdown format."""
class Input(BlockSchema):
credentials: NotionCredentialsInput = NotionCredentialsField()
page_id: str = SchemaField(
description="Notion page ID. Must be accessible by the connected integration. You can get this from the page URL notion.so/A-Page-586edd711467478da59fe35e29a1ffab would be 586edd711467478da59fe35e29a1ffab",
)
include_title: bool = SchemaField(
description="Whether to include the page title as a header in the markdown",
default=True,
)
class Output(BlockSchema):
markdown: str = SchemaField(description="Page content in Markdown format.")
title: str = SchemaField(description="Page title.")
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="d1312c4d-fae2-4e70-893d-f4d07cce1d4e",
description="Read a Notion page and convert it to Markdown format with proper formatting for headings, lists, links, and rich text.",
categories={BlockCategory.PRODUCTIVITY},
input_schema=NotionReadPageMarkdownBlock.Input,
output_schema=NotionReadPageMarkdownBlock.Output,
disabled=not NOTION_OAUTH_IS_CONFIGURED,
test_input={
"page_id": "00000000-0000-0000-0000-000000000000",
"include_title": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
("markdown", "# Test Page\n\nThis is test content."),
("title", "Test Page"),
],
test_credentials=TEST_CREDENTIALS,
test_mock={
"get_page_markdown": lambda *args, **kwargs: (
"# Test Page\n\nThis is test content.",
"Test Page",
)
},
)
@staticmethod
async def get_page_markdown(
credentials: OAuth2Credentials, page_id: str, include_title: bool = True
) -> tuple[str, str]:
"""
Get a Notion page and convert it to markdown.
Args:
credentials: OAuth2 credentials for Notion.
page_id: The ID of the page to fetch.
include_title: Whether to include the page title in the markdown.
Returns:
Tuple of (markdown_content, title)
"""
client = NotionClient(credentials)
# Get page metadata
page = await client.get_page(page_id)
title = extract_page_title(page)
# Get all blocks from the page
blocks = await client.get_blocks(page_id, recursive=True)
# Convert blocks to markdown
content_markdown = blocks_to_markdown(blocks)
# Combine title and content if requested
if include_title and title:
full_markdown = f"# {title}\n\n{content_markdown}"
else:
full_markdown = content_markdown
return full_markdown, title
async def run(
self,
input_data: Input,
*,
credentials: OAuth2Credentials,
**kwargs,
) -> BlockOutput:
try:
markdown, title = await self.get_page_markdown(
credentials, input_data.page_id, input_data.include_title
)
yield "markdown", markdown
yield "title", title
except Exception as e:
yield "error", str(e) if str(e) else "Unknown error"

View File

@@ -1,225 +0,0 @@
from __future__ import annotations
from typing import List, Optional
from pydantic import BaseModel
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import OAuth2Credentials, SchemaField
from ._api import NotionClient, extract_page_title, parse_rich_text
from ._auth import (
NOTION_OAUTH_IS_CONFIGURED,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
NotionCredentialsField,
NotionCredentialsInput,
)
class NotionSearchResult(BaseModel):
"""Typed model for Notion search results."""
id: str
type: str # 'page' or 'database'
title: str
url: str
created_time: Optional[str] = None
last_edited_time: Optional[str] = None
parent_type: Optional[str] = None # 'page', 'database', or 'workspace'
parent_id: Optional[str] = None
icon: Optional[str] = None # emoji icon if present
is_inline: Optional[bool] = None # for databases only
class NotionSearchBlock(Block):
"""Search across your Notion workspace for pages and databases."""
class Input(BlockSchema):
credentials: NotionCredentialsInput = NotionCredentialsField()
query: str = SchemaField(
description="Search query text. Leave empty to get all accessible pages/databases.",
default="",
)
filter_type: Optional[str] = SchemaField(
description="Filter results by type: 'page' or 'database'. Leave empty for both.",
default=None,
)
limit: int = SchemaField(
description="Maximum number of results to return", default=20, ge=1, le=100
)
class Output(BlockSchema):
results: List[NotionSearchResult] = SchemaField(
description="List of search results with title, type, URL, and metadata."
)
result: NotionSearchResult = SchemaField(
description="Individual search result (yields one per result found)."
)
result_ids: List[str] = SchemaField(
description="List of IDs from search results for batch operations."
)
count: int = SchemaField(description="Number of results found.")
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="313515dd-9848-46ea-9cd6-3c627c892c56",
description="Search your Notion workspace for pages and databases by text query.",
categories={BlockCategory.PRODUCTIVITY, BlockCategory.SEARCH},
input_schema=NotionSearchBlock.Input,
output_schema=NotionSearchBlock.Output,
disabled=not NOTION_OAUTH_IS_CONFIGURED,
test_input={
"query": "project",
"limit": 5,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
(
"results",
[
NotionSearchResult(
id="123",
type="page",
title="Project Plan",
url="https://notion.so/Project-Plan-123",
)
],
),
("result_ids", ["123"]),
(
"result",
NotionSearchResult(
id="123",
type="page",
title="Project Plan",
url="https://notion.so/Project-Plan-123",
),
),
("count", 1),
],
test_credentials=TEST_CREDENTIALS,
test_mock={
"search_workspace": lambda *args, **kwargs: (
[
NotionSearchResult(
id="123",
type="page",
title="Project Plan",
url="https://notion.so/Project-Plan-123",
)
],
1,
)
},
)
@staticmethod
async def search_workspace(
credentials: OAuth2Credentials,
query: str = "",
filter_type: Optional[str] = None,
limit: int = 20,
) -> tuple[List[NotionSearchResult], int]:
"""
Search the Notion workspace.
Returns:
Tuple of (results_list, count)
"""
client = NotionClient(credentials)
# Build filter if type is specified
filter_obj = None
if filter_type:
filter_obj = {"property": "object", "value": filter_type}
# Execute search
response = await client.search(
query=query, filter_obj=filter_obj, page_size=limit
)
# Parse results
results = []
for item in response.get("results", []):
result_data = {
"id": item.get("id", ""),
"type": item.get("object", ""),
"url": item.get("url", ""),
"created_time": item.get("created_time"),
"last_edited_time": item.get("last_edited_time"),
"title": "", # Will be set below
}
# Extract title based on type
if item.get("object") == "page":
# For pages, get the title from properties
result_data["title"] = extract_page_title(item)
# Add parent info
parent = item.get("parent", {})
if parent.get("type") == "page_id":
result_data["parent_type"] = "page"
result_data["parent_id"] = parent.get("page_id")
elif parent.get("type") == "database_id":
result_data["parent_type"] = "database"
result_data["parent_id"] = parent.get("database_id")
elif parent.get("type") == "workspace":
result_data["parent_type"] = "workspace"
# Add icon if present
icon = item.get("icon")
if icon and icon.get("type") == "emoji":
result_data["icon"] = icon.get("emoji")
elif item.get("object") == "database":
# For databases, get title from the title array
result_data["title"] = parse_rich_text(item.get("title", []))
# Add database-specific metadata
result_data["is_inline"] = item.get("is_inline", False)
# Add parent info
parent = item.get("parent", {})
if parent.get("type") == "page_id":
result_data["parent_type"] = "page"
result_data["parent_id"] = parent.get("page_id")
elif parent.get("type") == "workspace":
result_data["parent_type"] = "workspace"
# Add icon if present
icon = item.get("icon")
if icon and icon.get("type") == "emoji":
result_data["icon"] = icon.get("emoji")
results.append(NotionSearchResult(**result_data))
return results, len(results)
async def run(
self,
input_data: Input,
*,
credentials: OAuth2Credentials,
**kwargs,
) -> BlockOutput:
try:
results, count = await self.search_workspace(
credentials, input_data.query, input_data.filter_type, input_data.limit
)
# Yield the complete list for batch operations
yield "results", results
# Extract and yield IDs as a list for batch operations
result_ids = [r.id for r in results]
yield "result_ids", result_ids
# Yield each individual result for single connections
for result in results:
yield "result", result
yield "count", count
except Exception as e:
yield "error", str(e) if str(e) else "Unknown error"

View File

@@ -14,7 +14,6 @@ from backend.data.block import (
BlockType,
)
from backend.data.model import NodeExecutionStats, SchemaField
from backend.data.optional_block import get_optional_config
from backend.util import json
from backend.util.clients import get_database_manager_async_client
@@ -389,9 +388,7 @@ class SmartDecisionMakerBlock(Block):
return {"type": "function", "function": tool_function}
@staticmethod
async def _create_function_signature(
node_id: str, user_id: str | None = None, check_optional: bool = True
) -> list[dict[str, Any]]:
async def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
"""
Creates function signatures for tools linked to a specified node within a graph.
@@ -401,8 +398,6 @@ class SmartDecisionMakerBlock(Block):
Args:
node_id: The node_id for which to create function signatures.
user_id: The ID of the user, used for checking credential-based optional blocks.
check_optional: Whether to check and skip optional blocks based on their conditions.
Returns:
list[dict[str, Any]]: A list of dictionaries, each representing a function signature
@@ -434,41 +429,6 @@ class SmartDecisionMakerBlock(Block):
if not sink_node:
raise ValueError(f"Sink node not found: {links[0].sink_id}")
# todo: use the renamed value of metadata when available
# Check if this node is marked as optional and should be skipped
optional_config = get_optional_config(sink_node.metadata)
if optional_config and optional_config.enabled and check_optional:
# Check conditions to determine if block should be skipped
skip_block = False
# Check credential availability if configured
if optional_config.conditions.on_missing_credentials and user_id:
# Get credential fields from the block
sink_block = sink_node.block
if hasattr(sink_block, "input_schema"):
creds_fields = sink_block.input_schema.get_credentials_fields()
if creds_fields:
# For Smart Decision Maker, we simplify by assuming
# credentials might be missing if optional is enabled
# Full check happens at execution time
logger.debug(
f"Optional block {sink_node.id} may be skipped based on credentials"
)
# Continue to exclude from available tools
continue
# If other conditions exist but can't be checked now, be conservative
if (
optional_config.conditions.input_flag
or optional_config.conditions.kv_flag
):
# These runtime flags can't be checked here, so exclude the block
logger.debug(
f"Optional block {sink_node.id} excluded due to runtime conditions"
)
continue
if sink_node.block_id == AgentExecutorBlock().id:
return_tool_functions.append(
await SmartDecisionMakerBlock._create_agent_function_signature(
@@ -496,9 +456,7 @@ class SmartDecisionMakerBlock(Block):
user_id: str,
**kwargs,
) -> BlockOutput:
tool_functions = await self._create_function_signature(
node_id, user_id=user_id, check_optional=True
)
tool_functions = await self._create_function_signature(node_id)
yield "tool_functions", json.dumps(tool_functions)
input_data.conversation_history = input_data.conversation_history or []
@@ -565,6 +523,7 @@ class SmartDecisionMakerBlock(Block):
credentials=credentials,
llm_model=input_data.model,
prompt=prompt,
json_format=False,
max_tokens=input_data.max_tokens,
tools=tool_functions,
ollama_host=input_data.ollama_host,

View File

@@ -1,8 +0,0 @@
from backend.sdk import BlockCostType, ProviderBuilder
stagehand = (
ProviderBuilder("stagehand")
.with_api_key("STAGEHAND_API_KEY", "Stagehand API Key")
.with_base_cost(1, BlockCostType.RUN)
.build()
)

View File

@@ -1,393 +0,0 @@
import logging
import signal
import threading
from contextlib import contextmanager
from enum import Enum
# Monkey patch Stagehands to prevent signal handling in worker threads
import stagehand.main
from stagehand import Stagehand
from backend.blocks.llm import (
MODEL_METADATA,
AICredentials,
AICredentialsField,
LlmModel,
ModelMetadata,
)
from backend.blocks.stagehand._config import stagehand as stagehand_provider
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
SchemaField,
)
# Store the original method
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers
def safe_register_signal_handlers(self):
"""Only register signal handlers in the main thread"""
if threading.current_thread() is threading.main_thread():
original_register_signal_handlers(self)
else:
# Skip signal handling in worker threads
pass
# Replace the method
stagehand.main.Stagehand._register_signal_handlers = safe_register_signal_handlers
@contextmanager
def disable_signal_handling():
"""Context manager to temporarily disable signal handling"""
if threading.current_thread() is not threading.main_thread():
# In worker threads, temporarily replace signal.signal with a no-op
original_signal = signal.signal
def noop_signal(*args, **kwargs):
pass
signal.signal = noop_signal
try:
yield
finally:
signal.signal = original_signal
else:
# In main thread, don't modify anything
yield
logger = logging.getLogger(__name__)
class StagehandRecommendedLlmModel(str, Enum):
"""
This is subset of LLModel from autogpt_platform/backend/backend/blocks/llm.py
It contains only the models recommended by Stagehand
"""
# OpenAI
GPT41 = "gpt-4.1-2025-04-14"
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
# Anthropic
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
@property
def provider_name(self) -> str:
"""
Returns the provider name for the model in the required format for Stagehand:
provider/model_name
"""
model_metadata = MODEL_METADATA[LlmModel(self.value)]
model_name = self.value
if len(model_name.split("/")) == 1 and not self.value.startswith(
model_metadata.provider
):
assert (
model_metadata.provider != "open_router"
), "Logic failed and open_router provider attempted to be prepended to model name! in stagehand/_config.py"
model_name = f"{model_metadata.provider}/{model_name}"
logger.error(f"Model name: {model_name}")
return model_name
@property
def provider(self) -> str:
return MODEL_METADATA[LlmModel(self.value)].provider
@property
def metadata(self) -> ModelMetadata:
return MODEL_METADATA[LlmModel(self.value)]
@property
def context_window(self) -> int:
return MODEL_METADATA[LlmModel(self.value)].context_window
@property
def max_output_tokens(self) -> int | None:
return MODEL_METADATA[LlmModel(self.value)].max_output_tokens
class StagehandObserveBlock(Block):
class Input(BlockSchema):
# Browserbase credentials (Stagehand provider) or raw API key
stagehand_credentials: CredentialsMetaInput = (
stagehand_provider.credentials_field(
description="Stagehand/Browserbase API key"
)
)
browserbase_project_id: str = SchemaField(
description="Browserbase project ID (required if using Browserbase)",
)
# Model selection and credentials (provider-discriminated like llm.py)
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
url: str = SchemaField(
description="URL to navigate to.",
)
instruction: str = SchemaField(
description="Natural language description of elements or actions to discover.",
)
iframes: bool = SchemaField(
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
default=True,
)
domSettleTimeoutMs: int = SchemaField(
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
default=45000,
)
class Output(BlockSchema):
selector: str = SchemaField(description="XPath selector to locate element.")
description: str = SchemaField(description="Human-readable description")
method: str | None = SchemaField(description="Suggested action method")
arguments: list[str] | None = SchemaField(
description="Additional action parameters"
)
def __init__(self):
super().__init__(
id="d3863944-0eaf-45c4-a0c9-63e0fe1ee8b9",
description="Find suggested actions for your workflows",
categories={BlockCategory.AI, BlockCategory.DEVELOPER_TOOLS},
input_schema=StagehandObserveBlock.Input,
output_schema=StagehandObserveBlock.Output,
)
async def run(
self,
input_data: Input,
*,
stagehand_credentials: APIKeyCredentials,
model_credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
with disable_signal_handling():
stagehand = Stagehand(
api_key=stagehand_credentials.api_key.get_secret_value(),
project_id=input_data.browserbase_project_id,
model_name=input_data.model.provider_name,
model_api_key=model_credentials.api_key.get_secret_value(),
)
await stagehand.init()
page = stagehand.page
assert page is not None, "Stagehand page is not initialized"
await page.goto(input_data.url)
observe_results = await page.observe(
input_data.instruction,
iframes=input_data.iframes,
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
)
for result in observe_results:
yield "selector", result.selector
yield "description", result.description
yield "method", result.method
yield "arguments", result.arguments
class StagehandActBlock(Block):
class Input(BlockSchema):
# Browserbase credentials (Stagehand provider) or raw API key
stagehand_credentials: CredentialsMetaInput = (
stagehand_provider.credentials_field(
description="Stagehand/Browserbase API key"
)
)
browserbase_project_id: str = SchemaField(
description="Browserbase project ID (required if using Browserbase)",
)
# Model selection and credentials (provider-discriminated like llm.py)
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
url: str = SchemaField(
description="URL to navigate to.",
)
action: list[str] = SchemaField(
description="Action to perform. Suggested actions are: click, fill, type, press, scroll, select from dropdown. For multi-step actions, add an entry for each step.",
)
variables: dict[str, str] = SchemaField(
description="Variables to use in the action. Variables contains data you want the action to use.",
default_factory=dict,
)
iframes: bool = SchemaField(
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
default=True,
)
domSettleTimeoutMs: int = SchemaField(
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
default=45000,
)
timeoutMs: int = SchemaField(
description="Timeout in milliseconds for DOM ready. Extended timeout for slow-loading forms",
default=60000,
)
class Output(BlockSchema):
success: bool = SchemaField(
description="Whether the action was completed successfully"
)
message: str = SchemaField(description="Details about the actions execution.")
action: str = SchemaField(description="Action performed")
def __init__(self):
super().__init__(
id="86eba68b-9549-4c0b-a0db-47d85a56cc27",
description="Interact with a web page by performing actions on a web page. Use it to build self-healing and deterministic automations that adapt to website chang.",
categories={BlockCategory.AI, BlockCategory.DEVELOPER_TOOLS},
input_schema=StagehandActBlock.Input,
output_schema=StagehandActBlock.Output,
)
async def run(
self,
input_data: Input,
*,
stagehand_credentials: APIKeyCredentials,
model_credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
with disable_signal_handling():
stagehand = Stagehand(
api_key=stagehand_credentials.api_key.get_secret_value(),
project_id=input_data.browserbase_project_id,
model_name=input_data.model.provider_name,
model_api_key=model_credentials.api_key.get_secret_value(),
)
await stagehand.init()
page = stagehand.page
assert page is not None, "Stagehand page is not initialized"
await page.goto(input_data.url)
for action in input_data.action:
action_results = await page.act(
action,
variables=input_data.variables,
iframes=input_data.iframes,
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
timeoutMs=input_data.timeoutMs,
)
yield "success", action_results.success
yield "message", action_results.message
yield "action", action_results.action
class StagehandExtractBlock(Block):
class Input(BlockSchema):
# Browserbase credentials (Stagehand provider) or raw API key
stagehand_credentials: CredentialsMetaInput = (
stagehand_provider.credentials_field(
description="Stagehand/Browserbase API key"
)
)
browserbase_project_id: str = SchemaField(
description="Browserbase project ID (required if using Browserbase)",
)
# Model selection and credentials (provider-discriminated like llm.py)
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
url: str = SchemaField(
description="URL to navigate to.",
)
instruction: str = SchemaField(
description="Natural language description of elements or actions to discover.",
)
iframes: bool = SchemaField(
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
default=True,
)
domSettleTimeoutMs: int = SchemaField(
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
default=45000,
)
class Output(BlockSchema):
extraction: str = SchemaField(description="Extracted data from the page.")
def __init__(self):
super().__init__(
id="fd3c0b18-2ba6-46ae-9339-fcb40537ad98",
description="Extract structured data from a webpage.",
categories={BlockCategory.AI, BlockCategory.DEVELOPER_TOOLS},
input_schema=StagehandExtractBlock.Input,
output_schema=StagehandExtractBlock.Output,
)
async def run(
self,
input_data: Input,
*,
stagehand_credentials: APIKeyCredentials,
model_credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
with disable_signal_handling():
stagehand = Stagehand(
api_key=stagehand_credentials.api_key.get_secret_value(),
project_id=input_data.browserbase_project_id,
model_name=input_data.model.provider_name,
model_api_key=model_credentials.api_key.get_secret_value(),
)
await stagehand.init()
page = stagehand.page
assert page is not None, "Stagehand page is not initialized"
await page.goto(input_data.url)
extraction = await page.extract(
input_data.instruction,
iframes=input_data.iframes,
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
)
yield "extraction", str(extraction.model_dump()["extraction"])

View File

@@ -30,6 +30,7 @@ class TestLLMStatsTracking:
credentials=llm.TEST_CREDENTIALS,
llm_model=llm.LlmModel.GPT4O,
prompt=[{"role": "user", "content": "Hello"}],
json_format=False,
max_tokens=100,
)
@@ -41,8 +42,6 @@ class TestLLMStatsTracking:
@pytest.mark.asyncio
async def test_ai_structured_response_block_tracks_stats(self):
"""Test that AIStructuredResponseGeneratorBlock correctly tracks stats."""
from unittest.mock import patch
import backend.blocks.llm as llm
block = llm.AIStructuredResponseGeneratorBlock()
@@ -52,7 +51,7 @@ class TestLLMStatsTracking:
return llm.LLMResponse(
raw_response="",
prompt=[],
response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
response='{"key1": "value1", "key2": "value2"}',
tool_calls=None,
prompt_tokens=15,
completion_tokens=25,
@@ -70,12 +69,10 @@ class TestLLMStatsTracking:
)
outputs = {}
# Mock secrets.token_hex to return consistent ID
with patch("secrets.token_hex", return_value="test123456"):
async for output_name, output_data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[output_name] = output_data
async for output_name, output_data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[output_name] = output_data
# Check stats
assert block.execution_stats.input_token_count == 15
@@ -146,7 +143,7 @@ class TestLLMStatsTracking:
return llm.LLMResponse(
raw_response="",
prompt=[],
response='<json_output id="test123456">{"wrong": "format"}</json_output>',
response='{"wrong": "format"}',
tool_calls=None,
prompt_tokens=10,
completion_tokens=15,
@@ -157,7 +154,7 @@ class TestLLMStatsTracking:
return llm.LLMResponse(
raw_response="",
prompt=[],
response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
response='{"key1": "value1", "key2": "value2"}',
tool_calls=None,
prompt_tokens=20,
completion_tokens=25,
@@ -176,12 +173,10 @@ class TestLLMStatsTracking:
)
outputs = {}
# Mock secrets.token_hex to return consistent ID
with patch("secrets.token_hex", return_value="test123456"):
async for output_name, output_data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[output_name] = output_data
async for output_name, output_data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[output_name] = output_data
# Check stats - should accumulate both calls
# For 2 attempts: attempt 1 (failed) + attempt 2 (success) = 2 total
@@ -274,8 +269,7 @@ class TestLLMStatsTracking:
mock_response.choices = [
MagicMock(
message=MagicMock(
content='<json_output id="test123456">{"summary": "Test chunk summary"}</json_output>',
tool_calls=None,
content='{"summary": "Test chunk summary"}', tool_calls=None
)
)
]
@@ -283,7 +277,7 @@ class TestLLMStatsTracking:
mock_response.choices = [
MagicMock(
message=MagicMock(
content='<json_output id="test123456">{"final_summary": "Test final summary"}</json_output>',
content='{"final_summary": "Test final summary"}',
tool_calls=None,
)
)
@@ -304,13 +298,11 @@ class TestLLMStatsTracking:
max_tokens=1000, # Large enough to avoid chunking
)
# Mock secrets.token_hex to return consistent ID
with patch("secrets.token_hex", return_value="test123456"):
outputs = {}
async for output_name, output_data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[output_name] = output_data
outputs = {}
async for output_name, output_data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[output_name] = output_data
print(f"Actual calls made: {call_count}")
print(f"Block stats: {block.execution_stats}")
@@ -465,7 +457,7 @@ class TestLLMStatsTracking:
return llm.LLMResponse(
raw_response="",
prompt=[],
response='<json_output id="test123456">{"result": "test"}</json_output>',
response='{"result": "test"}',
tool_calls=None,
prompt_tokens=10,
completion_tokens=20,
@@ -484,12 +476,10 @@ class TestLLMStatsTracking:
# Run the block
outputs = {}
# Mock secrets.token_hex to return consistent ID
with patch("secrets.token_hex", return_value="test123456"):
async for output_name, output_data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[output_name] = output_data
async for output_name, output_data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[output_name] = output_data
# Block finished - now grab and assert stats
assert block.execution_stats is not None

View File

@@ -35,19 +35,20 @@ async def execute_graph(
logger.info("Input data: %s", input_data)
# --- Test adding new executions --- #
graph_exec = await agent_server.test_execute_graph(
response = await agent_server.test_execute_graph(
user_id=test_user.id,
graph_id=test_graph.id,
graph_version=test_graph.version,
node_input=input_data,
)
logger.info("Created execution with ID: %s", graph_exec.id)
graph_exec_id = response.graph_exec_id
logger.info("Created execution with ID: %s", graph_exec_id)
# Execution queue should be empty
logger.info("Waiting for execution to complete...")
result = await wait_execution(test_user.id, graph_exec.id, 30)
result = await wait_execution(test_user.id, graph_exec_id, 30)
logger.info("Execution completed with %d results", len(result))
return graph_exec.id
return graph_exec_id
@pytest.mark.asyncio(loop_scope="session")

View File

@@ -172,11 +172,6 @@ class FillTextTemplateBlock(Block):
format: str = SchemaField(
description="Template to format the text using `values`. Use Jinja2 syntax."
)
escape_html: bool = SchemaField(
default=False,
advanced=True,
description="Whether to escape special characters in the inserted values to be HTML-safe. Enable for HTML output, disable for plain text.",
)
class Output(BlockSchema):
output: str = SchemaField(description="Formatted text")
@@ -210,7 +205,6 @@ class FillTextTemplateBlock(Block):
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
formatter = text.TextFormatter(autoescape=input_data.escape_html)
yield "output", formatter.format_string(input_data.format, input_data.values)

View File

@@ -4,8 +4,6 @@ from logging import getLogger
from typing import Any, Dict, List, Union
from urllib.parse import urlencode
from pydantic import field_serializer
from backend.sdk import BaseModel, Credentials, Requests
logger = getLogger(__name__)
@@ -384,9 +382,8 @@ class CreatePostRequest(BaseModel):
# Advanced
metadata: List[Dict[str, Any]] | None = None
@field_serializer("date")
def serialize_date(self, value: datetime | None) -> str | None:
return value.isoformat() if value else None
class Config:
json_encoders = {datetime: lambda v: v.isoformat()}
class PostAuthor(BaseModel):

View File

@@ -6,6 +6,8 @@ from dotenv import load_dotenv
from backend.util.logging import configure_logging
os.environ["ENABLE_AUTH"] = "false"
load_dotenv()
# NOTE: You can run tests like with the --log-cli-level=INFO to see the logs

View File

@@ -1,8 +1,5 @@
from backend.server.v2.library.model import LibraryAgentPreset
from .graph import NodeModel
from .integrations import Webhook # noqa: F401
# Resolve Webhook forward references
# Resolve Webhook <- NodeModel forward reference
NodeModel.model_rebuild()
LibraryAgentPreset.model_rebuild()

View File

@@ -1,31 +1,57 @@
import logging
import uuid
from datetime import datetime, timezone
from typing import Optional
from typing import List, Optional
from autogpt_libs.api_key.keysmith import APIKeySmith
from autogpt_libs.api_key.key_manager import APIKeyManager
from prisma.enums import APIKeyPermission, APIKeyStatus
from prisma.errors import PrismaError
from prisma.models import APIKey as PrismaAPIKey
from prisma.types import APIKeyWhereUniqueInput
from pydantic import BaseModel, Field
from prisma.types import (
APIKeyCreateInput,
APIKeyUpdateInput,
APIKeyWhereInput,
APIKeyWhereUniqueInput,
)
from pydantic import BaseModel
from backend.util.exceptions import NotAuthorizedError, NotFoundError
from backend.data.db import BaseDbModel
logger = logging.getLogger(__name__)
keysmith = APIKeySmith()
class APIKeyInfo(BaseModel):
id: str
# Some basic exceptions
class APIKeyError(Exception):
"""Base exception for API key operations"""
pass
class APIKeyNotFoundError(APIKeyError):
"""Raised when an API key is not found"""
pass
class APIKeyPermissionError(APIKeyError):
"""Raised when there are permission issues with API key operations"""
pass
class APIKeyValidationError(APIKeyError):
"""Raised when API key validation fails"""
pass
class APIKey(BaseDbModel):
name: str
head: str = Field(
description=f"The first {APIKeySmith.HEAD_LENGTH} characters of the key"
)
tail: str = Field(
description=f"The last {APIKeySmith.TAIL_LENGTH} characters of the key"
)
status: APIKeyStatus
permissions: list[APIKeyPermission]
prefix: str
key: str
status: APIKeyStatus = APIKeyStatus.ACTIVE
permissions: List[APIKeyPermission]
postfix: str
created_at: datetime
last_used_at: Optional[datetime] = None
revoked_at: Optional[datetime] = None
@@ -34,211 +60,266 @@ class APIKeyInfo(BaseModel):
@staticmethod
def from_db(api_key: PrismaAPIKey):
return APIKeyInfo(
id=api_key.id,
name=api_key.name,
head=api_key.head,
tail=api_key.tail,
status=APIKeyStatus(api_key.status),
permissions=[APIKeyPermission(p) for p in api_key.permissions],
created_at=api_key.createdAt,
last_used_at=api_key.lastUsedAt,
revoked_at=api_key.revokedAt,
description=api_key.description,
user_id=api_key.userId,
)
try:
return APIKey(
id=api_key.id,
name=api_key.name,
prefix=api_key.prefix,
postfix=api_key.postfix,
key=api_key.key,
status=APIKeyStatus(api_key.status),
permissions=[APIKeyPermission(p) for p in api_key.permissions],
created_at=api_key.createdAt,
last_used_at=api_key.lastUsedAt,
revoked_at=api_key.revokedAt,
description=api_key.description,
user_id=api_key.userId,
)
except Exception as e:
logger.error(f"Error creating APIKey from db: {str(e)}")
raise APIKeyError(f"Failed to create API key object: {str(e)}")
class APIKeyInfoWithHash(APIKeyInfo):
hash: str
salt: str | None = None # None for legacy keys
def match(self, plaintext_key: str) -> bool:
"""Returns whether the given key matches this API key object."""
return keysmith.verify_key(plaintext_key, self.hash, self.salt)
class APIKeyWithoutHash(BaseModel):
id: str
name: str
prefix: str
postfix: str
status: APIKeyStatus
permissions: List[APIKeyPermission]
created_at: datetime
last_used_at: Optional[datetime]
revoked_at: Optional[datetime]
description: Optional[str]
user_id: str
@staticmethod
def from_db(api_key: PrismaAPIKey):
return APIKeyInfoWithHash(
**APIKeyInfo.from_db(api_key).model_dump(),
hash=api_key.hash,
salt=api_key.salt,
)
def without_hash(self) -> APIKeyInfo:
return APIKeyInfo(**self.model_dump(exclude={"hash", "salt"}))
try:
return APIKeyWithoutHash(
id=api_key.id,
name=api_key.name,
prefix=api_key.prefix,
postfix=api_key.postfix,
status=APIKeyStatus(api_key.status),
permissions=[APIKeyPermission(p) for p in api_key.permissions],
created_at=api_key.createdAt,
last_used_at=api_key.lastUsedAt,
revoked_at=api_key.revokedAt,
description=api_key.description,
user_id=api_key.userId,
)
except Exception as e:
logger.error(f"Error creating APIKeyWithoutHash from db: {str(e)}")
raise APIKeyError(f"Failed to create API key object: {str(e)}")
async def create_api_key(
async def generate_api_key(
name: str,
user_id: str,
permissions: list[APIKeyPermission],
permissions: List[APIKeyPermission],
description: Optional[str] = None,
) -> tuple[APIKeyInfo, str]:
) -> tuple[APIKeyWithoutHash, str]:
"""
Generate a new API key and store it in the database.
Returns the API key object (without hash) and the plain text key.
"""
generated_key = keysmith.generate_key()
try:
api_manager = APIKeyManager()
key = api_manager.generate_api_key()
saved_key_obj = await PrismaAPIKey.prisma().create(
data={
"id": str(uuid.uuid4()),
"name": name,
"head": generated_key.head,
"tail": generated_key.tail,
"hash": generated_key.hash,
"salt": generated_key.salt,
"permissions": [p for p in permissions],
"description": description,
"userId": user_id,
}
)
api_key = await PrismaAPIKey.prisma().create(
data=APIKeyCreateInput(
id=str(uuid.uuid4()),
name=name,
prefix=key.prefix,
postfix=key.postfix,
key=key.hash,
permissions=[p for p in permissions],
description=description,
userId=user_id,
)
)
return APIKeyInfo.from_db(saved_key_obj), generated_key.key
api_key_without_hash = APIKeyWithoutHash.from_db(api_key)
return api_key_without_hash, key.raw
except PrismaError as e:
logger.error(f"Database error while generating API key: {str(e)}")
raise APIKeyError(f"Failed to generate API key: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error while generating API key: {str(e)}")
raise APIKeyError(f"Failed to generate API key: {str(e)}")
async def get_active_api_keys_by_head(head: str) -> list[APIKeyInfoWithHash]:
results = await PrismaAPIKey.prisma().find_many(
where={"head": head, "status": APIKeyStatus.ACTIVE}
)
return [APIKeyInfoWithHash.from_db(key) for key in results]
async def validate_api_key(plaintext_key: str) -> Optional[APIKeyInfo]:
async def validate_api_key(plain_text_key: str) -> Optional[APIKey]:
"""
Validate an API key and return the API key object if valid and active.
Validate an API key and return the API key object if valid.
"""
try:
if not plaintext_key.startswith(APIKeySmith.PREFIX):
if not plain_text_key.startswith(APIKeyManager.PREFIX):
logger.warning("Invalid API key format")
return None
head = plaintext_key[: APIKeySmith.HEAD_LENGTH]
potential_matches = await get_active_api_keys_by_head(head)
prefix = plain_text_key[: APIKeyManager.PREFIX_LENGTH]
api_manager = APIKeyManager()
matched_api_key = next(
(pm for pm in potential_matches if pm.match(plaintext_key)),
None,
api_key = await PrismaAPIKey.prisma().find_first(
where=APIKeyWhereInput(prefix=prefix, status=(APIKeyStatus.ACTIVE))
)
if not matched_api_key:
# API key not found or invalid
if not api_key:
logger.warning(f"No active API key found with prefix {prefix}")
return None
# Migrate legacy keys to secure format on successful validation
if matched_api_key.salt is None:
matched_api_key = await _migrate_key_to_secure_hash(
plaintext_key, matched_api_key
is_valid = api_manager.verify_api_key(plain_text_key, api_key.key)
if not is_valid:
logger.warning("API key verification failed")
return None
return APIKey.from_db(api_key)
except Exception as e:
logger.error(f"Error validating API key: {str(e)}")
raise APIKeyValidationError(f"Failed to validate API key: {str(e)}")
async def revoke_api_key(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
try:
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
if not api_key:
raise APIKeyNotFoundError(f"API key with id {key_id} not found")
if api_key.userId != user_id:
raise APIKeyPermissionError(
"You do not have permission to revoke this API key."
)
return matched_api_key.without_hash()
except Exception as e:
logger.error(f"Error while validating API key: {e}")
raise RuntimeError("Failed to validate API key") from e
async def _migrate_key_to_secure_hash(
plaintext_key: str, key_obj: APIKeyInfoWithHash
) -> APIKeyInfoWithHash:
"""Replace the SHA256 hash of a legacy API key with a salted Scrypt hash."""
try:
new_hash, new_salt = keysmith.hash_key(plaintext_key)
await PrismaAPIKey.prisma().update(
where={"id": key_obj.id}, data={"hash": new_hash, "salt": new_salt}
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
updated_api_key = await PrismaAPIKey.prisma().update(
where=where_clause,
data=APIKeyUpdateInput(
status=APIKeyStatus.REVOKED, revokedAt=datetime.now(timezone.utc)
),
)
logger.info(f"Migrated legacy API key #{key_obj.id} to secure format")
# Update the API key object with new values for return
key_obj.hash = new_hash
key_obj.salt = new_salt
except Exception as e:
logger.error(f"Failed to migrate legacy API key #{key_obj.id}: {e}")
return key_obj
async def revoke_api_key(key_id: str, user_id: str) -> APIKeyInfo:
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
if not api_key:
raise NotFoundError(f"API key with id {key_id} not found")
if api_key.userId != user_id:
raise NotAuthorizedError("You do not have permission to revoke this API key.")
updated_api_key = await PrismaAPIKey.prisma().update(
where={"id": key_id},
data={
"status": APIKeyStatus.REVOKED,
"revokedAt": datetime.now(timezone.utc),
},
)
if not updated_api_key:
raise NotFoundError(f"API key #{key_id} vanished while trying to revoke.")
return APIKeyInfo.from_db(updated_api_key)
async def list_user_api_keys(user_id: str) -> list[APIKeyInfo]:
api_keys = await PrismaAPIKey.prisma().find_many(
where={"userId": user_id}, order={"createdAt": "desc"}
)
return [APIKeyInfo.from_db(key) for key in api_keys]
async def suspend_api_key(key_id: str, user_id: str) -> APIKeyInfo:
selector: APIKeyWhereUniqueInput = {"id": key_id}
api_key = await PrismaAPIKey.prisma().find_unique(where=selector)
if not api_key:
raise NotFoundError(f"API key with id {key_id} not found")
if api_key.userId != user_id:
raise NotAuthorizedError("You do not have permission to suspend this API key.")
updated_api_key = await PrismaAPIKey.prisma().update(
where=selector, data={"status": APIKeyStatus.SUSPENDED}
)
if not updated_api_key:
raise NotFoundError(f"API key #{key_id} vanished while trying to suspend.")
return APIKeyInfo.from_db(updated_api_key)
def has_permission(api_key: APIKeyInfo, required_permission: APIKeyPermission) -> bool:
return required_permission in api_key.permissions
async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyInfo]:
api_key = await PrismaAPIKey.prisma().find_first(
where={"id": key_id, "userId": user_id}
)
if not api_key:
if updated_api_key:
return APIKeyWithoutHash.from_db(updated_api_key)
return None
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
raise e
except PrismaError as e:
logger.error(f"Database error while revoking API key: {str(e)}")
raise APIKeyError(f"Failed to revoke API key: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error while revoking API key: {str(e)}")
raise APIKeyError(f"Failed to revoke API key: {str(e)}")
return APIKeyInfo.from_db(api_key)
async def list_user_api_keys(user_id: str) -> List[APIKeyWithoutHash]:
try:
where_clause: APIKeyWhereInput = {"userId": user_id}
api_keys = await PrismaAPIKey.prisma().find_many(
where=where_clause, order={"createdAt": "desc"}
)
return [APIKeyWithoutHash.from_db(key) for key in api_keys]
except PrismaError as e:
logger.error(f"Database error while listing API keys: {str(e)}")
raise APIKeyError(f"Failed to list API keys: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error while listing API keys: {str(e)}")
raise APIKeyError(f"Failed to list API keys: {str(e)}")
async def suspend_api_key(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
try:
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
if not api_key:
raise APIKeyNotFoundError(f"API key with id {key_id} not found")
if api_key.userId != user_id:
raise APIKeyPermissionError(
"You do not have permission to suspend this API key."
)
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
updated_api_key = await PrismaAPIKey.prisma().update(
where=where_clause,
data=APIKeyUpdateInput(status=APIKeyStatus.SUSPENDED),
)
if updated_api_key:
return APIKeyWithoutHash.from_db(updated_api_key)
return None
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
raise e
except PrismaError as e:
logger.error(f"Database error while suspending API key: {str(e)}")
raise APIKeyError(f"Failed to suspend API key: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error while suspending API key: {str(e)}")
raise APIKeyError(f"Failed to suspend API key: {str(e)}")
def has_permission(api_key: APIKey, required_permission: APIKeyPermission) -> bool:
try:
return required_permission in api_key.permissions
except Exception as e:
logger.error(f"Error checking API key permissions: {str(e)}")
return False
async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
try:
api_key = await PrismaAPIKey.prisma().find_first(
where=APIKeyWhereInput(id=key_id, userId=user_id)
)
if not api_key:
return None
return APIKeyWithoutHash.from_db(api_key)
except PrismaError as e:
logger.error(f"Database error while getting API key: {str(e)}")
raise APIKeyError(f"Failed to get API key: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error while getting API key: {str(e)}")
raise APIKeyError(f"Failed to get API key: {str(e)}")
async def update_api_key_permissions(
key_id: str, user_id: str, permissions: list[APIKeyPermission]
) -> APIKeyInfo:
key_id: str, user_id: str, permissions: List[APIKeyPermission]
) -> Optional[APIKeyWithoutHash]:
"""
Update the permissions of an API key.
"""
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
try:
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
if api_key is None:
raise NotFoundError("No such API key found.")
if api_key is None:
raise APIKeyNotFoundError("No such API key found.")
if api_key.userId != user_id:
raise NotAuthorizedError("You do not have permission to update this API key.")
if api_key.userId != user_id:
raise APIKeyPermissionError(
"You do not have permission to update this API key."
)
updated_api_key = await PrismaAPIKey.prisma().update(
where={"id": key_id},
data={"permissions": permissions},
)
if not updated_api_key:
raise NotFoundError(f"API key #{key_id} vanished while trying to update.")
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
updated_api_key = await PrismaAPIKey.prisma().update(
where=where_clause,
data=APIKeyUpdateInput(permissions=permissions),
)
return APIKeyInfo.from_db(updated_api_key)
if updated_api_key:
return APIKeyWithoutHash.from_db(updated_api_key)
return None
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
raise e
except PrismaError as e:
logger.error(f"Database error while updating API key permissions: {str(e)}")
raise APIKeyError(f"Failed to update API key permissions: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error while updating API key permissions: {str(e)}")
raise APIKeyError(f"Failed to update API key permissions: {str(e)}")

View File

@@ -1,3 +1,4 @@
import functools
import inspect
import logging
import os
@@ -7,7 +8,6 @@ from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Generic,
Optional,
@@ -20,7 +20,6 @@ from typing import (
import jsonref
import jsonschema
from autogpt_libs.utils.cache import cached
from prisma.models import AgentBlock
from prisma.types import AgentBlockCreateInput
from pydantic import BaseModel
@@ -45,10 +44,9 @@ if TYPE_CHECKING:
app_config = Config()
BlockData = tuple[str, Any] # Input & Output data should be a tuple of (name, data).
BlockInput = dict[str, Any] # Input: 1 input pin consumes 1 data.
BlockOutputEntry = tuple[str, Any] # Output data should be a tuple of (name, value).
BlockOutput = AsyncGen[BlockOutputEntry, None] # Output: 1 output pin produces n data.
BlockTestOutput = BlockOutputEntry | tuple[str, Callable[[Any], bool]]
BlockOutput = AsyncGen[BlockData, None] # Output: 1 output pin produces n data.
CompletedBlockOutput = dict[str, list[Any]] # Completed stream, collected as a dict.
@@ -91,45 +89,6 @@ class BlockCategory(Enum):
return {"category": self.name, "description": self.value}
class BlockCostType(str, Enum):
RUN = "run" # cost X credits per run
BYTE = "byte" # cost X credits per byte
SECOND = "second" # cost X credits per second
class BlockCost(BaseModel):
cost_amount: int
cost_filter: BlockInput
cost_type: BlockCostType
def __init__(
self,
cost_amount: int,
cost_type: BlockCostType = BlockCostType.RUN,
cost_filter: Optional[BlockInput] = None,
**data: Any,
) -> None:
super().__init__(
cost_amount=cost_amount,
cost_filter=cost_filter or {},
cost_type=cost_type,
**data,
)
class BlockInfo(BaseModel):
id: str
name: str
inputSchema: dict[str, Any]
outputSchema: dict[str, Any]
costs: list[BlockCost]
description: str
categories: list[dict[str, str]]
contributors: list[dict[str, Any]]
staticOutput: bool
uiType: str
class BlockSchema(BaseModel):
cached_jsonschema: ClassVar[dict[str, Any]]
@@ -347,7 +306,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
input_schema: Type[BlockSchemaInputType] = EmptySchema,
output_schema: Type[BlockSchemaOutputType] = EmptySchema,
test_input: BlockInput | list[BlockInput] | None = None,
test_output: BlockTestOutput | list[BlockTestOutput] | None = None,
test_output: BlockData | list[BlockData] | None = None,
test_mock: dict[str, Any] | None = None,
test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
disabled: bool = False,
@@ -493,24 +452,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
"uiType": self.block_type.value,
}
def get_info(self) -> BlockInfo:
from backend.data.credit import get_block_cost
return BlockInfo(
id=self.id,
name=self.name,
inputSchema=self.input_schema.jsonschema(),
outputSchema=self.output_schema.jsonschema(),
costs=get_block_cost(self),
description=self.description,
categories=[category.dict() for category in self.categories],
contributors=[
contributor.model_dump() for contributor in self.contributors
],
staticOutput=self.static_output,
uiType=self.block_type.value,
)
async def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
if error := self.input_schema.validate_data(input_data):
raise ValueError(
@@ -722,7 +663,7 @@ def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
return cls() if cls else None
@cached()
@functools.cache
def get_webhook_block_ids() -> Sequence[str]:
return [
id
@@ -731,7 +672,7 @@ def get_webhook_block_ids() -> Sequence[str]:
]
@cached()
@functools.cache
def get_io_block_ids() -> Sequence[str]:
return [
id

View File

@@ -29,7 +29,8 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
from backend.data.block import Block, BlockCost, BlockCostType
from backend.data.block import Block
from backend.data.cost import BlockCost, BlockCostType
from backend.integrations.credentials_store import (
aiml_api_credentials,
anthropic_credentials,
@@ -306,18 +307,7 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
"type": ideogram_credentials.type,
}
},
),
BlockCost(
cost_amount=18,
cost_filter={
"ideogram_model_name": "V_3",
"credentials": {
"id": ideogram_credentials.id,
"provider": ideogram_credentials.provider,
"type": ideogram_credentials.type,
},
},
),
)
],
AIShortformVideoCreatorBlock: [
BlockCost(

View File

@@ -0,0 +1,32 @@
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from backend.data.block import BlockInput
class BlockCostType(str, Enum):
RUN = "run" # cost X credits per run
BYTE = "byte" # cost X credits per byte
SECOND = "second" # cost X credits per second
class BlockCost(BaseModel):
cost_amount: int
cost_filter: BlockInput
cost_type: BlockCostType
def __init__(
self,
cost_amount: int,
cost_type: BlockCostType = BlockCostType.RUN,
cost_filter: Optional[BlockInput] = None,
**data: Any,
) -> None:
super().__init__(
cost_amount=cost_amount,
cost_filter=cost_filter or {},
cost_type=cost_type,
**data,
)

View File

@@ -2,7 +2,7 @@ import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, cast
from typing import Any, cast
import stripe
from prisma import Json
@@ -23,6 +23,7 @@ from pydantic import BaseModel
from backend.data import db
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCost
from backend.data.model import (
AutoTopUpConfig,
RefundRequest,
@@ -40,9 +41,6 @@ from backend.util.models import Pagination
from backend.util.retry import func_retry
from backend.util.settings import Settings
if TYPE_CHECKING:
from backend.data.block import Block, BlockCost
settings = Settings()
stripe.api_key = settings.secrets.stripe_api_key
logger = logging.getLogger(__name__)
@@ -999,14 +997,10 @@ def get_user_credit_model() -> UserCreditBase:
return UserCredit()
def get_block_costs() -> dict[str, list["BlockCost"]]:
def get_block_costs() -> dict[str, list[BlockCost]]:
return {block().id: costs for block, costs in BLOCK_COSTS.items()}
def get_block_cost(block: "Block") -> list["BlockCost"]:
return BLOCK_COSTS.get(block.__class__, [])
async def get_stripe_customer_id(user_id: str) -> str:
user = await get_user_by_id(user_id)

View File

@@ -83,7 +83,7 @@ async def disconnect():
# Transaction timeout constant (in milliseconds)
TRANSACTION_TIMEOUT = 30000 # 30 seconds - Increased from 15s to prevent timeout errors during graph creation under load
TRANSACTION_TIMEOUT = 15000 # 15 seconds - Increased from 5s to prevent timeout errors
@asynccontextmanager

View File

@@ -11,14 +11,11 @@ from typing import (
Generator,
Generic,
Literal,
Mapping,
Optional,
TypeVar,
cast,
overload,
)
from prisma import Json
from prisma.enums import AgentExecutionStatus
from prisma.models import (
AgentGraphExecution,
@@ -27,6 +24,7 @@ from prisma.models import (
AgentNodeExecutionKeyValueData,
)
from prisma.types import (
AgentGraphExecutionCreateInput,
AgentGraphExecutionUpdateManyMutationInput,
AgentGraphExecutionWhereInput,
AgentNodeExecutionCreateInput,
@@ -62,7 +60,7 @@ from .includes import (
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
graph_execution_include,
)
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
from .model import GraphExecutionStats, NodeExecutionStats
T = TypeVar("T")
@@ -89,37 +87,6 @@ class BlockErrorStats(BaseModel):
ExecutionStatus = AgentExecutionStatus
NodeInputMask = Mapping[str, JsonValue]
NodesInputMasks = Mapping[str, NodeInputMask]
# dest: source
VALID_STATUS_TRANSITIONS = {
ExecutionStatus.QUEUED: [
ExecutionStatus.INCOMPLETE,
],
ExecutionStatus.RUNNING: [
ExecutionStatus.INCOMPLETE,
ExecutionStatus.QUEUED,
ExecutionStatus.TERMINATED, # For resuming halted execution
],
ExecutionStatus.COMPLETED: [
ExecutionStatus.RUNNING,
],
ExecutionStatus.FAILED: [
ExecutionStatus.INCOMPLETE,
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
],
ExecutionStatus.TERMINATED: [
ExecutionStatus.INCOMPLETE,
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
],
ExecutionStatus.SKIPPED: [
ExecutionStatus.INCOMPLETE,
ExecutionStatus.QUEUED,
],
}
class GraphExecutionMeta(BaseDbModel):
@@ -127,15 +94,10 @@ class GraphExecutionMeta(BaseDbModel):
user_id: str
graph_id: str
graph_version: int
inputs: Optional[BlockInput] # no default -> required in the OpenAPI spec
credential_inputs: Optional[dict[str, CredentialsMetaInput]]
nodes_input_masks: Optional[dict[str, BlockInput]]
preset_id: Optional[str]
preset_id: Optional[str] = None
status: ExecutionStatus
started_at: datetime
ended_at: datetime
is_shared: bool = False
share_token: Optional[str] = None
class Stats(BaseModel):
model_config = ConfigDict(
@@ -217,18 +179,6 @@ class GraphExecutionMeta(BaseDbModel):
user_id=_graph_exec.userId,
graph_id=_graph_exec.agentGraphId,
graph_version=_graph_exec.agentGraphVersion,
inputs=cast(BlockInput | None, _graph_exec.inputs),
credential_inputs=(
{
name: CredentialsMetaInput.model_validate(cmi)
for name, cmi in cast(dict, _graph_exec.credentialInputs).items()
}
if _graph_exec.credentialInputs
else None
),
nodes_input_masks=cast(
dict[str, BlockInput] | None, _graph_exec.nodesInputMasks
),
preset_id=_graph_exec.agentPresetId,
status=ExecutionStatus(_graph_exec.executionStatus),
started_at=start_time,
@@ -252,13 +202,11 @@ class GraphExecutionMeta(BaseDbModel):
if stats
else None
),
is_shared=_graph_exec.isShared,
share_token=_graph_exec.shareToken,
)
class GraphExecution(GraphExecutionMeta):
inputs: BlockInput # type: ignore - incompatible override is intentional
inputs: BlockInput
outputs: CompletedBlockOutput
@staticmethod
@@ -278,18 +226,15 @@ class GraphExecution(GraphExecutionMeta):
)
inputs = {
**(
graph_exec.inputs
or {
# fallback: extract inputs from Agent Input Blocks
exec.input_data["name"]: exec.input_data.get("value")
for exec in complete_node_executions
if (
(block := get_block(exec.block_id))
and block.block_type == BlockType.INPUT
)
}
),
**{
# inputs from Agent Input Blocks
exec.input_data["name"]: exec.input_data.get("value")
for exec in complete_node_executions
if (
(block := get_block(exec.block_id))
and block.block_type == BlockType.INPUT
)
},
**{
# input from webhook-triggered block
"payload": exec.input_data["payload"]
@@ -307,13 +252,14 @@ class GraphExecution(GraphExecutionMeta):
if (
block := get_block(exec.block_id)
) and block.block_type == BlockType.OUTPUT:
outputs[exec.input_data["name"]].append(exec.input_data.get("value"))
outputs[exec.input_data["name"]].append(
exec.input_data.get("value", None)
)
return GraphExecution(
**{
field_name: getattr(graph_exec, field_name)
for field_name in GraphExecutionMeta.model_fields
if field_name != "inputs"
},
inputs=inputs,
outputs=outputs,
@@ -346,17 +292,13 @@ class GraphExecutionWithNodes(GraphExecution):
node_executions=node_executions,
)
def to_graph_execution_entry(
self,
user_context: "UserContext",
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
):
def to_graph_execution_entry(self, user_context: "UserContext"):
return GraphExecutionEntry(
user_id=self.user_id,
graph_id=self.graph_id,
graph_version=self.graph_version or 0,
graph_exec_id=self.id,
nodes_input_masks=compiled_nodes_input_masks,
nodes_input_masks={}, # FIXME: store credentials on AgentGraphExecution
user_context=user_context,
)
@@ -393,7 +335,7 @@ class NodeExecutionResult(BaseModel):
else:
input_data: BlockInput = defaultdict()
for data in _node_exec.Input or []:
input_data[data.name] = type_utils.convert(data.data, JsonValue)
input_data[data.name] = type_utils.convert(data.data, type[Any])
output_data: CompletedBlockOutput = defaultdict(list)
@@ -402,7 +344,7 @@ class NodeExecutionResult(BaseModel):
output_data[name].extend(messages)
else:
for data in _node_exec.Output or []:
output_data[data.name].append(type_utils.convert(data.data, JsonValue))
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
graph_execution: AgentGraphExecution | None = _node_exec.GraphExecution
if graph_execution:
@@ -597,12 +539,9 @@ async def get_graph_execution(
async def create_graph_execution(
graph_id: str,
graph_version: int,
starting_nodes_input: list[tuple[str, BlockInput]], # list[(node_id, BlockInput)]
inputs: Mapping[str, JsonValue],
starting_nodes_input: list[tuple[str, BlockInput]],
user_id: str,
preset_id: Optional[str] = None,
credential_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
preset_id: str | None = None,
) -> GraphExecutionWithNodes:
"""
Create a new AgentGraphExecution record.
@@ -610,18 +549,11 @@ async def create_graph_execution(
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
"""
result = await AgentGraphExecution.prisma().create(
data={
"agentGraphId": graph_id,
"agentGraphVersion": graph_version,
"executionStatus": ExecutionStatus.INCOMPLETE,
"inputs": SafeJson(inputs),
"credentialInputs": (
SafeJson(credential_inputs) if credential_inputs else Json({})
),
"nodesInputMasks": (
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
),
"NodeExecutions": {
data=AgentGraphExecutionCreateInput(
agentGraphId=graph_id,
agentGraphVersion=graph_version,
executionStatus=ExecutionStatus.QUEUED,
NodeExecutions={
"create": [
AgentNodeExecutionCreateInput(
agentNodeId=node_id,
@@ -637,9 +569,9 @@ async def create_graph_execution(
for node_id, node_input in starting_nodes_input
]
},
"userId": user_id,
"agentPresetId": preset_id,
},
userId=user_id,
agentPresetId=preset_id,
),
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)
@@ -650,7 +582,7 @@ async def upsert_execution_input(
node_id: str,
graph_exec_id: str,
input_name: str,
input_data: JsonValue,
input_data: Any,
node_exec_id: str | None = None,
) -> tuple[str, BlockInput]:
"""
@@ -699,7 +631,7 @@ async def upsert_execution_input(
)
return existing_execution.id, {
**{
input_data.name: type_utils.convert(input_data.data, JsonValue)
input_data.name: type_utils.convert(input_data.data, type[Any])
for input_data in existing_execution.Input or []
},
input_name: input_data,
@@ -760,11 +692,6 @@ async def update_graph_execution_stats(
status: ExecutionStatus | None = None,
stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
if not status and not stats:
raise ValueError(
f"Must provide either status or stats to update for execution {graph_exec_id}"
)
update_data: AgentGraphExecutionUpdateManyMutationInput = {}
if stats:
@@ -776,25 +703,20 @@ async def update_graph_execution_stats(
if status:
update_data["executionStatus"] = status
where_clause: AgentGraphExecutionWhereInput = {"id": graph_exec_id}
if status:
if allowed_from := VALID_STATUS_TRANSITIONS.get(status, []):
# Add OR clause to check if current status is one of the allowed source statuses
where_clause["AND"] = [
{"id": graph_exec_id},
{"OR": [{"executionStatus": s} for s in allowed_from]},
]
else:
raise ValueError(
f"Status {status} cannot be set via update for execution {graph_exec_id}. "
f"This status can only be set at creation or is not a valid target status."
)
await AgentGraphExecution.prisma().update_many(
where=where_clause,
updated_count = await AgentGraphExecution.prisma().update_many(
where={
"id": graph_exec_id,
"OR": [
{"executionStatus": ExecutionStatus.RUNNING},
{"executionStatus": ExecutionStatus.QUEUED},
# Terminated graph can be resumed.
{"executionStatus": ExecutionStatus.TERMINATED},
],
},
data=update_data,
)
if updated_count == 0:
return None
graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise(
where={"id": graph_exec_id},
@@ -802,7 +724,6 @@ async def update_graph_execution_stats(
[*get_io_block_ids(), *get_webhook_block_ids()]
),
)
return GraphExecution.from_db(graph_exec)
@@ -967,7 +888,7 @@ class GraphExecutionEntry(BaseModel):
graph_exec_id: str
graph_id: str
graph_version: int
nodes_input_masks: Optional[NodesInputMasks] = None
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None
user_context: UserContext
@@ -1029,18 +950,6 @@ class NodeExecutionEvent(NodeExecutionResult):
)
class SharedExecutionResponse(BaseModel):
"""Public-safe response for shared executions"""
id: str
graph_name: str
graph_description: Optional[str]
status: ExecutionStatus
created_at: datetime
outputs: CompletedBlockOutput # Only the final outputs, no intermediate data
# Deliberately exclude: user_id, inputs, credentials, node details
ExecutionEvent = Annotated[
GraphExecutionEvent | NodeExecutionEvent, Field(discriminator="event_type")
]
@@ -1218,98 +1127,3 @@ async def get_block_error_stats(
)
for row in result
]
async def update_graph_execution_share_status(
execution_id: str,
user_id: str,
is_shared: bool,
share_token: str | None,
shared_at: datetime | None,
) -> None:
"""Update the sharing status of a graph execution."""
await AgentGraphExecution.prisma().update(
where={"id": execution_id},
data={
"isShared": is_shared,
"shareToken": share_token,
"sharedAt": shared_at,
},
)
async def get_graph_execution_by_share_token(
share_token: str,
) -> SharedExecutionResponse | None:
"""Get a shared execution with limited public-safe data."""
execution = await AgentGraphExecution.prisma().find_first(
where={
"shareToken": share_token,
"isShared": True,
"isDeleted": False,
},
include={
"AgentGraph": True,
"NodeExecutions": {
"include": {
"Output": True,
"Node": {
"include": {
"AgentBlock": True,
}
},
},
},
},
)
if not execution:
return None
# Extract outputs from OUTPUT blocks only (consistent with GraphExecution.from_db)
outputs: CompletedBlockOutput = defaultdict(list)
if execution.NodeExecutions:
for node_exec in execution.NodeExecutions:
if node_exec.Node and node_exec.Node.agentBlockId:
# Get the block definition to check its type
block = get_block(node_exec.Node.agentBlockId)
if block and block.block_type == BlockType.OUTPUT:
# For OUTPUT blocks, the data is stored in executionData or Input
# The executionData contains the structured input with 'name' and 'value' fields
if hasattr(node_exec, "executionData") and node_exec.executionData:
exec_data = type_utils.convert(
node_exec.executionData, dict[str, Any]
)
if "name" in exec_data:
name = exec_data["name"]
value = exec_data.get("value")
outputs[name].append(value)
elif node_exec.Input:
# Build input_data from Input relation
input_data = {}
for data in node_exec.Input:
if data.name and data.data is not None:
input_data[data.name] = type_utils.convert(
data.data, JsonValue
)
if "name" in input_data:
name = input_data["name"]
value = input_data.get("value")
outputs[name].append(value)
return SharedExecutionResponse(
id=execution.id,
graph_name=(
execution.AgentGraph.name
if (execution.AgentGraph and execution.AgentGraph.name)
else "Untitled Agent"
),
graph_description=(
execution.AgentGraph.description if execution.AgentGraph else None
),
status=ExecutionStatus(execution.executionStatus),
created_at=execution.createdAt,
outputs=outputs,
)

View File

@@ -1,7 +1,6 @@
import logging
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from prisma.enums import SubmissionStatus
@@ -13,7 +12,7 @@ from prisma.types import (
AgentNodeLinkCreateInput,
StoreListingVersionWhereInput,
)
from pydantic import BaseModel, Field, create_model
from pydantic import Field, JsonValue, create_model
from pydantic.fields import computed_field
from backend.blocks.agent import AgentExecutorBlock
@@ -29,14 +28,12 @@ from backend.data.model import (
from backend.integrations.providers import ProviderName
from backend.util import type as type_utils
from backend.util.json import SafeJson
from backend.util.models import Pagination
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
from .db import BaseDbModel, query_raw_with_schema, transaction
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
if TYPE_CHECKING:
from .execution import NodesInputMasks
from .integrations import Webhook
logger = logging.getLogger(__name__)
@@ -162,8 +159,6 @@ class BaseGraph(BaseDbModel):
is_active: bool = True
name: str
description: str
instructions: str | None = None
recommended_schedule_cron: str | None = None
nodes: list[Node] = []
links: list[Link] = []
forked_from_id: str | None = None
@@ -210,35 +205,6 @@ class BaseGraph(BaseDbModel):
None,
)
@computed_field
@property
def trigger_setup_info(self) -> "GraphTriggerInfo | None":
if not (
self.webhook_input_node
and (trigger_block := self.webhook_input_node.block).webhook_config
):
return None
return GraphTriggerInfo(
provider=trigger_block.webhook_config.provider,
config_schema={
**(json_schema := trigger_block.input_schema.jsonschema()),
"properties": {
pn: sub_schema
for pn, sub_schema in json_schema["properties"].items()
if not is_credentials_field_name(pn)
},
"required": [
pn
for pn in json_schema.get("required", [])
if not is_credentials_field_name(pn)
],
},
credentials_input_name=next(
iter(trigger_block.input_schema.get_credentials_fields()), None
),
)
@staticmethod
def _generate_schema(
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
@@ -272,14 +238,6 @@ class BaseGraph(BaseDbModel):
}
class GraphTriggerInfo(BaseModel):
provider: ProviderName
config_schema: dict[str, Any] = Field(
description="Input schema for the trigger block"
)
credentials_input_name: Optional[str]
class Graph(BaseGraph):
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs
@@ -384,8 +342,6 @@ class GraphModel(Graph):
user_id: str
nodes: list[NodeModel] = [] # type: ignore
created_at: datetime
@property
def starting_nodes(self) -> list[NodeModel]:
outbound_nodes = {link.sink_id for link in self.links}
@@ -398,10 +354,6 @@ class GraphModel(Graph):
if node.id not in outbound_nodes or node.id in input_nodes
]
@property
def webhook_input_node(self) -> NodeModel | None: # type: ignore
return cast(NodeModel, super().webhook_input_node)
def meta(self) -> "GraphMeta":
"""
Returns a GraphMeta object with metadata about the graph.
@@ -462,7 +414,7 @@ class GraphModel(Graph):
def validate_graph(
self,
for_run: bool = False,
nodes_input_masks: Optional["NodesInputMasks"] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
):
"""
Validate graph structure and raise `ValueError` on issues.
@@ -476,7 +428,7 @@ class GraphModel(Graph):
def _validate_graph(
graph: BaseGraph,
for_run: bool = False,
nodes_input_masks: Optional["NodesInputMasks"] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> None:
errors = GraphModel._validate_graph_get_errors(
graph, for_run, nodes_input_masks
@@ -490,7 +442,7 @@ class GraphModel(Graph):
def validate_graph_get_errors(
self,
for_run: bool = False,
nodes_input_masks: Optional["NodesInputMasks"] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> dict[str, dict[str, str]]:
"""
Validate graph and return structured errors per node.
@@ -512,7 +464,7 @@ class GraphModel(Graph):
def _validate_graph_get_errors(
graph: BaseGraph,
for_run: bool = False,
nodes_input_masks: Optional["NodesInputMasks"] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> dict[str, dict[str, str]]:
"""
Validate graph and return structured errors per node.
@@ -703,12 +655,9 @@ class GraphModel(Graph):
version=graph.version,
forked_from_id=graph.forkedFromId,
forked_from_version=graph.forkedFromVersion,
created_at=graph.createdAt,
is_active=graph.isActive,
name=graph.name or "",
description=graph.description or "",
instructions=graph.instructions,
recommended_schedule_cron=graph.recommendedScheduleCron,
nodes=[NodeModel.from_db(node, for_export) for node in graph.Nodes or []],
links=list(
{
@@ -747,13 +696,6 @@ class GraphMeta(Graph):
return GraphMeta(**graph.model_dump())
class GraphsPaginated(BaseModel):
"""Response schema for paginated graphs."""
graphs: list[GraphMeta]
pagination: Pagination
# --------------------- CRUD functions --------------------- #
@@ -782,42 +724,31 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
return NodeModel.from_db(node)
async def list_graphs_paginated(
async def list_graphs(
user_id: str,
page: int = 1,
page_size: int = 25,
filter_by: Literal["active"] | None = "active",
) -> GraphsPaginated:
) -> list[GraphMeta]:
"""
Retrieves paginated graph metadata objects.
Retrieves graph metadata objects.
Default behaviour is to get all currently active graphs.
Args:
user_id: The ID of the user that owns the graphs.
page: Page number (1-based).
page_size: Number of graphs per page.
filter_by: An optional filter to either select graphs.
user_id: The ID of the user that owns the graph.
Returns:
GraphsPaginated: Paginated list of graph metadata.
list[GraphMeta]: A list of objects representing the retrieved graphs.
"""
where_clause: AgentGraphWhereInput = {"userId": user_id}
if filter_by == "active":
where_clause["isActive"] = True
# Get total count
total_count = await AgentGraph.prisma().count(where=where_clause)
total_pages = (total_count + page_size - 1) // page_size
# Get paginated results
offset = (page - 1) * page_size
graphs = await AgentGraph.prisma().find_many(
where=where_clause,
distinct=["id"],
order={"version": "desc"},
include=AGENT_GRAPH_INCLUDE,
skip=offset,
take=page_size,
)
graph_models: list[GraphMeta] = []
@@ -831,15 +762,7 @@ async def list_graphs_paginated(
logger.error(f"Error processing graph {graph.id}: {e}")
continue
return GraphsPaginated(
graphs=graph_models,
pagination=Pagination(
total_items=total_count,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)
return graph_models
async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph | None:
@@ -1122,7 +1045,6 @@ async def __create_graph(tx, graph: Graph, user_id: str):
version=graph.version,
name=graph.name,
description=graph.description,
recommendedScheduleCron=graph.recommended_schedule_cron,
isActive=graph.is_active,
userId=user_id,
forkedFromId=graph.forked_from_id,
@@ -1181,7 +1103,6 @@ def make_graph_model(creatable_graph: Graph, user_id: str) -> GraphModel:
return GraphModel(
**creatable_graph.model_dump(exclude={"nodes"}),
user_id=user_id,
created_at=datetime.now(tz=timezone.utc),
nodes=[
NodeModel(
**creatable_node.model_dump(),

View File

@@ -2,6 +2,7 @@ import json
from typing import Any
from uuid import UUID
import autogpt_libs.auth.models
import fastapi.exceptions
import pytest
from pytest_snapshot.plugin import Snapshot
@@ -316,7 +317,12 @@ async def test_access_store_listing_graph(server: SpinTestServer):
is_approved=True,
comments="Test comments",
),
user_id=admin_user.id,
autogpt_libs.auth.models.User(
user_id=admin_user.id,
role="admin",
email=admin_user.email,
phone_number="1234567890",
),
)
# Now we check the graph can be accessed by a user that does not own the graph

View File

@@ -59,15 +59,9 @@ def graph_execution_include(
}
AGENT_PRESET_INCLUDE: prisma.types.AgentPresetInclude = {
"InputPresets": True,
"Webhook": True,
}
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
"AgentNodes": {"include": AGENT_NODE_INCLUDE},
"AgentPresets": {"include": AGENT_PRESET_INCLUDE},
"AgentPresets": {"include": {"InputPresets": True}},
}

View File

@@ -1,59 +0,0 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class ConditionOperator(str, Enum):
AND = "and"
OR = "or"
class OptionalBlockConditions(BaseModel):
"""Conditions that determine when a block should be skipped"""
on_missing_credentials: bool = Field(
default=False,
description="Skip block if any required credentials are missing",
)
check_skip_input: bool = Field(
default=True,
description="Check the standard 'skip' input to control skip behavior",
)
kv_flag: Optional[str] = Field(
default=None,
description="Key-value store flag name that controls skip behavior",
)
operator: ConditionOperator = Field(
default=ConditionOperator.OR,
description="Logical operator for combining conditions (AND/OR)",
)
class OptionalBlockConfig(BaseModel):
"""Configuration for making a block optional/skippable"""
enabled: bool = Field(
default=False,
description="Whether this block can be optionally skipped",
)
conditions: OptionalBlockConditions = Field(
default_factory=OptionalBlockConditions,
description="Conditions that trigger skipping",
)
skip_message: Optional[str] = Field(
default=None,
description="Custom message to log when block is skipped",
)
def get_optional_config(node_metadata: dict) -> Optional[OptionalBlockConfig]:
"""Extract optional block configuration from node metadata"""
if "optional" not in node_metadata:
return None
optional_data = node_metadata.get("optional", {})
if not optional_data:
return None
return OptionalBlockConfig(**optional_data)

View File

@@ -1,7 +1,8 @@
import logging
import os
from functools import cache
from autogpt_libs.utils.cache import cached, thread_cached
from autogpt_libs.utils.cache import thread_cached
from dotenv import load_dotenv
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
@@ -12,7 +13,7 @@ load_dotenv()
HOST = os.getenv("REDIS_HOST", "localhost")
PORT = int(os.getenv("REDIS_PORT", "6379"))
PASSWORD = os.getenv("REDIS_PASSWORD", None)
PASSWORD = os.getenv("REDIS_PASSWORD", "password")
logger = logging.getLogger(__name__)
@@ -34,7 +35,7 @@ def disconnect():
get_redis().close()
@cached()
@cache
def get_redis() -> Redis:
return connect()

View File

@@ -7,7 +7,6 @@ from typing import Optional, cast
from urllib.parse import quote_plus
from autogpt_libs.auth.models import DEFAULT_USER_ID
from autogpt_libs.utils.cache import cached
from fastapi import HTTPException
from prisma.enums import NotificationType
from prisma.models import User as PrismaUser
@@ -24,11 +23,7 @@ from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
# Cache decorator alias for consistent user lookup caching
cache_user_lookup = cached(maxsize=1000, ttl_seconds=300)
@cache_user_lookup
async def get_or_create_user(user_data: dict) -> User:
try:
user_id = user_data.get("sub")
@@ -54,7 +49,6 @@ async def get_or_create_user(user_data: dict) -> User:
raise DatabaseError(f"Failed to get or create user {user_data}: {e}") from e
@cache_user_lookup
async def get_user_by_id(user_id: str) -> User:
user = await prisma.user.find_unique(where={"id": user_id})
if not user:
@@ -70,7 +64,6 @@ async def get_user_email_by_id(user_id: str) -> Optional[str]:
raise DatabaseError(f"Failed to get user email for user {user_id}: {e}") from e
@cache_user_lookup
async def get_user_by_email(email: str) -> Optional[User]:
try:
user = await prisma.user.find_unique(where={"email": email})
@@ -81,17 +74,7 @@ async def get_user_by_email(email: str) -> Optional[User]:
async def update_user_email(user_id: str, email: str):
try:
# Get old email first for cache invalidation
old_user = await prisma.user.find_unique(where={"id": user_id})
old_email = old_user.email if old_user else None
await prisma.user.update(where={"id": user_id}, data={"email": email})
# Selectively invalidate only the specific user entries
get_user_by_id.cache_delete(user_id)
if old_email:
get_user_by_email.cache_delete(old_email)
get_user_by_email.cache_delete(email)
except Exception as e:
raise DatabaseError(
f"Failed to update user email for user {user_id}: {e}"
@@ -131,8 +114,6 @@ async def update_user_integrations(user_id: str, data: UserIntegrations):
where={"id": user_id},
data={"integrations": encrypted_data},
)
# Invalidate cache for this user
get_user_by_id.cache_delete(user_id)
async def migrate_and_encrypt_user_integrations():
@@ -304,10 +285,6 @@ async def update_user_notification_preference(
)
if not user:
raise ValueError(f"User not found with ID: {user_id}")
# Invalidate cache for this user since notification preferences are part of user data
get_user_by_id.cache_delete(user_id)
preferences: dict[NotificationType, bool] = {
NotificationType.AGENT_RUN: user.notifyOnAgentRun or True,
NotificationType.ZERO_BALANCE: user.notifyOnZeroBalance or True,
@@ -346,8 +323,6 @@ async def set_user_email_verification(user_id: str, verified: bool) -> None:
where={"id": user_id},
data={"emailVerified": verified},
)
# Invalidate cache for this user
get_user_by_id.cache_delete(user_id)
except Exception as e:
raise DatabaseError(
f"Failed to set email verification status for user {user_id}: {e}"
@@ -432,10 +407,6 @@ async def update_user_timezone(user_id: str, timezone: str) -> User:
)
if not user:
raise ValueError(f"User not found with ID: {user_id}")
# Invalidate cache for this user
get_user_by_id.cache_delete(user_id)
return User.from_db(user)
except Exception as e:
raise DatabaseError(f"Failed to update timezone for user {user_id}: {e}") from e

View File

@@ -107,7 +107,7 @@ async def generate_activity_status_for_execution(
# Check if we have OpenAI API key
try:
settings = Settings()
if not settings.secrets.openai_internal_api_key:
if not settings.secrets.openai_api_key:
logger.debug(
"OpenAI API key not configured, skipping activity status generation"
)
@@ -187,7 +187,7 @@ async def generate_activity_status_for_execution(
credentials = APIKeyCredentials(
id="openai",
provider="openai",
api_key=SecretStr(settings.secrets.openai_internal_api_key),
api_key=SecretStr(settings.secrets.openai_api_key),
title="System OpenAI",
)
@@ -423,6 +423,7 @@ async def _call_llm_direct(
credentials=credentials,
llm_model=LlmModel.GPT4O_MINI,
prompt=prompt,
json_format=False,
max_tokens=150,
compress_prompt_to_fit=True,
)

View File

@@ -468,7 +468,7 @@ class TestGenerateActivityStatusForExecution:
):
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
mock_settings.return_value.secrets.openai_api_key = "test_key"
mock_llm.return_value = (
"I analyzed your data and provided the requested insights."
)
@@ -520,7 +520,7 @@ class TestGenerateActivityStatusForExecution:
"backend.executor.activity_status_generator.is_feature_enabled",
return_value=True,
):
mock_settings.return_value.secrets.openai_internal_api_key = ""
mock_settings.return_value.secrets.openai_api_key = ""
result = await generate_activity_status_for_execution(
graph_exec_id="test_exec",
@@ -546,7 +546,7 @@ class TestGenerateActivityStatusForExecution:
"backend.executor.activity_status_generator.is_feature_enabled",
return_value=True,
):
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
mock_settings.return_value.secrets.openai_api_key = "test_key"
result = await generate_activity_status_for_execution(
graph_exec_id="test_exec",
@@ -581,7 +581,7 @@ class TestGenerateActivityStatusForExecution:
):
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
mock_settings.return_value.secrets.openai_api_key = "test_key"
mock_llm.return_value = "Agent completed execution."
result = await generate_activity_status_for_execution(
@@ -633,7 +633,7 @@ class TestIntegration:
):
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
mock_settings.return_value.secrets.openai_api_key = "test_key"
mock_response = LLMResponse(
raw_response={},

View File

@@ -85,16 +85,6 @@ class DatabaseManager(AppService):
async def health_check(self) -> str:
if not db.is_connected():
raise UnhealthyServiceError("Database is not connected")
try:
# Test actual database connectivity by executing a simple query
# This will fail if Prisma query engine is not responding
result = await db.query_raw_with_schema("SELECT 1 as health_check")
if not result or result[0].get("health_check") != 1:
raise UnhealthyServiceError("Database query test failed")
except Exception as e:
raise UnhealthyServiceError(f"Database health check failed: {e}")
return await super().health_check()
@classmethod

View File

@@ -10,10 +10,10 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from pydantic import JsonValue
from redis.asyncio.lock import Lock as RedisLock
from backend.blocks.io import AgentOutputBlock
from backend.blocks.persistence import get_storage_key
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
AgentRunData,
@@ -28,7 +28,6 @@ from backend.executor.activity_status_generator import (
)
from backend.executor.utils import LogMetadata
from backend.notifications.notifications import queue_notification
from backend.util.clients import get_database_manager_async_client
from backend.util.exceptions import InsufficientBalanceError, ModerationError
if TYPE_CHECKING:
@@ -39,9 +38,9 @@ from prometheus_client import Gauge, start_http_server
from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis_client as redis
from backend.data.block import (
BlockData,
BlockInput,
BlockOutput,
BlockOutputEntry,
BlockSchema,
get_block,
)
@@ -53,11 +52,9 @@ from backend.data.execution import (
GraphExecutionEntry,
NodeExecutionEntry,
NodeExecutionResult,
NodesInputMasks,
UserContext,
)
from backend.data.graph import Link, Node
from backend.data.optional_block import get_optional_config
from backend.executor.utils import (
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
@@ -76,6 +73,7 @@ from backend.server.v2.AutoMod.manager import automod_manager
from backend.util import json
from backend.util.clients import (
get_async_execution_event_bus,
get_database_manager_async_client,
get_database_manager_client,
get_execution_event_bus,
get_notification_manager_client,
@@ -128,92 +126,12 @@ def execute_graph(
T = TypeVar("T")
async def should_skip_node(
node: Node,
creds_manager: IntegrationCredentialsManager,
user_id: str,
user_context: UserContext,
input_data: BlockInput,
graph_id: str,
) -> tuple[bool, str]:
"""
Check if a node should be skipped based on optional configuration.
Returns:
Tuple of (should_skip, skip_reason)
"""
optional_config = get_optional_config(node.metadata)
if not optional_config or not optional_config.enabled:
return False, ""
conditions = optional_config.conditions
skip_reasons = []
conditions_met = []
# Check credential availability
if conditions.on_missing_credentials:
node_block = node.block
input_model = cast(type[BlockSchema], node_block.input_schema)
for field_name, input_type in input_model.get_credentials_fields().items():
if field_name in input_data:
credentials_meta = input_type(**input_data[field_name])
# Check if credentials exist without acquiring lock
if not await creds_manager.exists(user_id, credentials_meta.id):
skip_reasons.append(f"Missing credentials: {field_name}")
conditions_met.append(True)
break
else:
# All credentials exist
if conditions.on_missing_credentials:
conditions_met.append(False)
# Check standard skip_run_block input (automatically added for optional blocks)
if conditions.check_skip_input and "skip_run_block" in input_data:
skip_value = input_data.get("skip_run_block", False)
if skip_value is True: # Skip if input is True
skip_reasons.append("Skip input is true")
conditions_met.append(True)
else:
conditions_met.append(False)
# Check key-value flag
if conditions.kv_flag:
# Determine storage key (assume within_agent scope for now)
storage_key = get_storage_key(conditions.kv_flag, "within_agent", graph_id)
# Retrieve the value from KV store
kv_value = await get_database_manager_async_client().get_execution_kv_data(
user_id=user_id,
key=storage_key,
)
# Skip if flag is True (treat missing/None as False)
if kv_value is True:
skip_reasons.append(f"KV flag '{conditions.kv_flag}' is true")
conditions_met.append(True)
else:
conditions_met.append(False)
# Apply logical operator
if not conditions_met:
return False, ""
if conditions.operator == "and":
should_skip = all(conditions_met)
else: # OR
should_skip = any(conditions_met)
skip_message = optional_config.skip_message or "; ".join(skip_reasons)
return should_skip, skip_message
async def execute_node(
node: Node,
creds_manager: IntegrationCredentialsManager,
data: NodeExecutionEntry,
execution_stats: NodeExecutionStats | None = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> BlockOutput:
"""
Execute a node in the graph. This will trigger a block execution on a node,
@@ -319,12 +237,12 @@ async def execute_node(
async def _enqueue_next_nodes(
db_client: "DatabaseManagerAsyncClient",
node: Node,
output: BlockOutputEntry,
output: BlockData,
user_id: str,
graph_exec_id: str,
graph_id: str,
log_metadata: LogMetadata,
nodes_input_masks: Optional[NodesInputMasks],
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
user_context: UserContext,
) -> list[NodeExecutionEntry]:
async def add_enqueued_execution(
@@ -501,7 +419,7 @@ class ExecutionProcessor:
self,
node_exec: NodeExecutionEntry,
node_exec_progress: NodeExecutionProgress,
nodes_input_masks: Optional[NodesInputMasks],
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
) -> NodeExecutionStats:
log_metadata = LogMetadata(
@@ -569,7 +487,7 @@ class ExecutionProcessor:
stats: NodeExecutionStats,
db_client: "DatabaseManagerAsyncClient",
log_metadata: LogMetadata,
nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> ExecutionStatus:
status = ExecutionStatus.RUNNING
@@ -593,28 +511,6 @@ class ExecutionProcessor:
)
try:
# Check if node should be skipped
should_skip, skip_reason = await should_skip_node(
node=node,
creds_manager=self.creds_manager,
user_id=node_exec.user_id,
user_context=node_exec.user_context,
input_data=node_exec.inputs,
graph_id=node_exec.graph_id,
)
if should_skip:
log_metadata.info(
f"Skipping node execution {node_exec.node_exec_id}: {skip_reason}"
)
await async_update_node_execution_status(
db_client=db_client,
exec_id=node_exec.node_exec_id,
status=ExecutionStatus.SKIPPED,
stats={"skip_reason": skip_reason},
)
return ExecutionStatus.SKIPPED
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
await async_update_node_execution_status(
db_client=db_client,
@@ -709,7 +605,7 @@ class ExecutionProcessor:
)
return
if exec_meta.status in [ExecutionStatus.QUEUED, ExecutionStatus.INCOMPLETE]:
if exec_meta.status == ExecutionStatus.QUEUED:
log_metadata.info(f"⚙️ Starting graph execution #{graph_exec.graph_exec_id}")
exec_meta.status = ExecutionStatus.RUNNING
send_execution_update(
@@ -1157,7 +1053,7 @@ class ExecutionProcessor:
node_id: str,
graph_exec: GraphExecutionEntry,
log_metadata: LogMetadata,
nodes_input_masks: Optional[NodesInputMasks],
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
execution_queue: ExecutionQueue[NodeExecutionEntry],
) -> None:
"""Process a node's output, update its status, and enqueue next nodes.
@@ -1498,14 +1394,14 @@ class ExecutionManager(AppProcess):
delivery_tag = method.delivery_tag
@func_retry
def _ack_message(reject: bool, requeue: bool):
def _ack_message(reject: bool = False):
"""Acknowledge or reject the message based on execution status."""
# Connection can be lost, so always get a fresh channel
channel = self.run_client.get_channel()
if reject:
channel.connection.add_callback_threadsafe(
lambda: channel.basic_nack(delivery_tag, requeue=requeue)
lambda: channel.basic_nack(delivery_tag, requeue=True)
)
else:
channel.connection.add_callback_threadsafe(
@@ -1517,13 +1413,13 @@ class ExecutionManager(AppProcess):
logger.info(
f"[{self.service_name}] Rejecting new execution during shutdown"
)
_ack_message(reject=True, requeue=True)
_ack_message(reject=True)
return
# Check if we can accept more runs
self._cleanup_completed_runs()
if len(self.active_graph_runs) >= self.pool_size:
_ack_message(reject=True, requeue=True)
_ack_message(reject=True)
return
try:
@@ -1532,7 +1428,7 @@ class ExecutionManager(AppProcess):
logger.error(
f"[{self.service_name}] Could not parse run message: {e}, body={body}"
)
_ack_message(reject=True, requeue=False)
_ack_message(reject=True)
return
graph_exec_id = graph_exec_entry.graph_exec_id
@@ -1544,7 +1440,7 @@ class ExecutionManager(AppProcess):
logger.error(
f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
)
_ack_message(reject=True, requeue=False)
_ack_message(reject=True)
return
cancel_event = threading.Event()
@@ -1560,9 +1456,9 @@ class ExecutionManager(AppProcess):
logger.error(
f"[{self.service_name}] Execution for {graph_exec_id} failed: {type(exec_error)} {exec_error}"
)
_ack_message(reject=True, requeue=True)
_ack_message(reject=True)
else:
_ack_message(reject=False, requeue=False)
_ack_message(reject=False)
except BaseException as e:
logger.exception(
f"[{self.service_name}] Error in run completion callback: {e}"

View File

@@ -1,5 +1,6 @@
import logging
import autogpt_libs.auth.models
import fastapi.responses
import pytest
@@ -35,20 +36,21 @@ async def execute_graph(
logger.info(f"Input data: {input_data}")
# --- Test adding new executions --- #
graph_exec = await agent_server.test_execute_graph(
response = await agent_server.test_execute_graph(
user_id=test_user.id,
graph_id=test_graph.id,
graph_version=test_graph.version,
node_input=input_data,
)
logger.info(f"Created execution with ID: {graph_exec.id}")
graph_exec_id = response.graph_exec_id
logger.info(f"Created execution with ID: {graph_exec_id}")
# Execution queue should be empty
logger.info("Waiting for execution to complete...")
result = await wait_execution(test_user.id, graph_exec.id, 30)
result = await wait_execution(test_user.id, graph_exec_id, 30)
logger.info(f"Execution completed with {len(result)} results")
assert len(result) == num_execs
return graph_exec.id
return graph_exec_id
async def assert_sample_graph_executions(
@@ -378,7 +380,7 @@ async def test_execute_preset(server: SpinTestServer):
# Verify execution
assert result is not None
graph_exec_id = result.id
graph_exec_id = result["id"]
# Wait for execution to complete
executions = await wait_execution(test_user.id, graph_exec_id)
@@ -467,7 +469,7 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
# Verify execution
assert result is not None, "Result must not be None"
graph_exec_id = result.id
graph_exec_id = result["id"]
# Wait for execution to complete
executions = await wait_execution(test_user.id, graph_exec_id)
@@ -519,7 +521,12 @@ async def test_store_listing_graph(server: SpinTestServer):
is_approved=True,
comments="Test comments",
),
user_id=admin_user.id,
autogpt_libs.auth.models.User(
user_id=admin_user.id,
role="admin",
email=admin_user.email,
phone_number="1234567890",
),
)
alt_test_user = admin_user

View File

@@ -191,22 +191,15 @@ class GraphExecutionJobInfo(GraphExecutionJobArgs):
id: str
name: str
next_run_time: str
timezone: str = Field(default="UTC", description="Timezone used for scheduling")
@staticmethod
def from_db(
job_args: GraphExecutionJobArgs, job_obj: JobObj
) -> "GraphExecutionJobInfo":
# Extract timezone from the trigger if it's a CronTrigger
timezone_str = "UTC"
if hasattr(job_obj.trigger, "timezone"):
timezone_str = str(job_obj.trigger.timezone)
return GraphExecutionJobInfo(
id=job_obj.id,
name=job_obj.name,
next_run_time=job_obj.next_run_time.isoformat(),
timezone=timezone_str,
**job_args.model_dump(),
)
@@ -402,7 +395,6 @@ class Scheduler(AppService):
input_data: BlockInput,
input_credentials: dict[str, CredentialsMetaInput],
name: Optional[str] = None,
user_timezone: str | None = None,
) -> GraphExecutionJobInfo:
# Validate the graph before scheduling to prevent runtime failures
# We don't need the return value, just want the validation to run
@@ -416,18 +408,7 @@ class Scheduler(AppService):
)
)
# Use provided timezone or default to UTC
# Note: Timezone should be passed from the client to avoid database lookups
if not user_timezone:
user_timezone = "UTC"
logger.warning(
f"No timezone provided for user {user_id}, using UTC for scheduling. "
f"Client should pass user's timezone for correct scheduling."
)
logger.info(
f"Scheduling job for user {user_id} with timezone {user_timezone} (cron: {cron})"
)
logger.info(f"Scheduling job for user {user_id} in UTC (cron: {cron})")
job_args = GraphExecutionJobArgs(
user_id=user_id,
@@ -441,12 +422,12 @@ class Scheduler(AppService):
execute_graph,
kwargs=job_args.model_dump(),
name=name,
trigger=CronTrigger.from_crontab(cron, timezone=user_timezone),
trigger=CronTrigger.from_crontab(cron, timezone="UTC"),
jobstore=Jobstores.EXECUTION.value,
replace_existing=True,
)
logger.info(
f"Added job {job.id} with cron schedule '{cron}' in timezone {user_timezone}, input data: {input_data}"
f"Added job {job.id} with cron schedule '{cron}' in UTC, input data: {input_data}"
)
return GraphExecutionJobInfo.from_db(job_args, job)

View File

@@ -1,7 +1,6 @@
from typing import cast
import pytest
from pytest_mock import MockerFixture
from backend.executor.utils import merge_execution_input, parse_execution_output
from backend.util.mock import MockObject
@@ -277,147 +276,3 @@ def test_merge_execution_input():
result = merge_execution_input(data)
assert "mixed" in result
assert result["mixed"].attr[0]["key"] == "value3"
@pytest.mark.asyncio
async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
"""
Verify that calling the function with its own output creates the same execution again.
"""
from backend.data.execution import GraphExecutionWithNodes
from backend.data.model import CredentialsMetaInput
from backend.executor.utils import add_graph_execution
from backend.integrations.providers import ProviderName
# Mock data
graph_id = "test-graph-id"
user_id = "test-user-id"
inputs = {"test_input": "test_value"}
preset_id = "test-preset-id"
graph_version = 1
graph_credentials_inputs = {
"cred_key": CredentialsMetaInput(
id="cred-id", provider=ProviderName("test_provider"), type="oauth2"
)
}
nodes_input_masks = {"node1": {"input1": "masked_value"}}
# Mock the graph object returned by validate_and_construct_node_execution_input
mock_graph = mocker.MagicMock()
mock_graph.version = graph_version
# Mock the starting nodes input and compiled nodes input masks
starting_nodes_input = [
("node1", {"input1": "value1"}),
("node2", {"input1": "value2"}),
]
compiled_nodes_input_masks = {"node1": {"input1": "compiled_mask"}}
# Mock the graph execution object
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
mock_graph_exec.id = "execution-id-123"
mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
# Mock user context
mock_user_context = {"user_id": user_id, "context": "test_context"}
# Mock the queue and event bus
mock_queue = mocker.AsyncMock()
mock_event_bus = mocker.MagicMock()
mock_event_bus.publish = mocker.AsyncMock()
# Setup mocks
mock_validate = mocker.patch(
"backend.executor.utils.validate_and_construct_node_execution_input"
)
mock_edb = mocker.patch("backend.executor.utils.execution_db")
mock_prisma = mocker.patch("backend.executor.utils.prisma")
mock_get_user_context = mocker.patch("backend.executor.utils.get_user_context")
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
mock_get_event_bus = mocker.patch(
"backend.executor.utils.get_async_execution_event_bus"
)
# Setup mock returns
mock_validate.return_value = (
mock_graph,
starting_nodes_input,
compiled_nodes_input_masks,
)
mock_prisma.is_connected.return_value = True
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
mock_edb.update_graph_execution_stats = mocker.AsyncMock(
return_value=mock_graph_exec
)
mock_edb.update_node_execution_status_batch = mocker.AsyncMock()
mock_get_user_context.return_value = mock_user_context
mock_get_queue.return_value = mock_queue
mock_get_event_bus.return_value = mock_event_bus
# Call the function - first execution
result1 = await add_graph_execution(
graph_id=graph_id,
user_id=user_id,
inputs=inputs,
preset_id=preset_id,
graph_version=graph_version,
graph_credentials_inputs=graph_credentials_inputs,
nodes_input_masks=nodes_input_masks,
)
# Store the parameters used in the first call to create_graph_execution
first_call_kwargs = mock_edb.create_graph_execution.call_args[1]
# Verify the create_graph_execution was called with correct parameters
mock_edb.create_graph_execution.assert_called_once_with(
user_id=user_id,
graph_id=graph_id,
graph_version=mock_graph.version,
inputs=inputs,
credential_inputs=graph_credentials_inputs,
nodes_input_masks=nodes_input_masks,
starting_nodes_input=starting_nodes_input,
preset_id=preset_id,
)
# Set up the graph execution mock to have properties we can extract
mock_graph_exec.graph_id = graph_id
mock_graph_exec.user_id = user_id
mock_graph_exec.graph_version = graph_version
mock_graph_exec.inputs = inputs
mock_graph_exec.credential_inputs = graph_credentials_inputs
mock_graph_exec.nodes_input_masks = nodes_input_masks
mock_graph_exec.preset_id = preset_id
# Create a second mock execution for the sanity check
mock_graph_exec_2 = mocker.MagicMock(spec=GraphExecutionWithNodes)
mock_graph_exec_2.id = "execution-id-456"
mock_graph_exec_2.to_graph_execution_entry.return_value = mocker.MagicMock()
# Reset mocks and set up for second call
mock_edb.create_graph_execution.reset_mock()
mock_edb.create_graph_execution.return_value = mock_graph_exec_2
mock_validate.reset_mock()
# Sanity check: call add_graph_execution with properties from first result
# This should create the same execution parameters
result2 = await add_graph_execution(
graph_id=mock_graph_exec.graph_id,
user_id=mock_graph_exec.user_id,
inputs=mock_graph_exec.inputs,
preset_id=mock_graph_exec.preset_id,
graph_version=mock_graph_exec.graph_version,
graph_credentials_inputs=mock_graph_exec.credential_inputs,
nodes_input_masks=mock_graph_exec.nodes_input_masks,
)
# Verify that create_graph_execution was called with identical parameters
second_call_kwargs = mock_edb.create_graph_execution.call_args[1]
# The sanity check: both calls should use identical parameters
assert first_call_kwargs == second_call_kwargs
# Both executions should succeed (though they create different objects)
assert result1 == mock_graph_exec
assert result2 == mock_graph_exec_2

View File

@@ -4,27 +4,20 @@ import threading
import time
from collections import defaultdict
from concurrent.futures import Future
from typing import Any, Mapping, Optional, cast
from typing import Any, Optional
from pydantic import BaseModel, JsonValue, ValidationError
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.block import (
Block,
BlockCostType,
BlockInput,
BlockOutputEntry,
BlockType,
get_block,
)
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCostType
from backend.data.db import prisma
from backend.data.execution import (
ExecutionStatus,
GraphExecutionStats,
GraphExecutionWithNodes,
NodesInputMasks,
UserContext,
)
from backend.data.graph import GraphModel, Node
@@ -246,7 +239,7 @@ def _tokenise(path: str) -> list[tuple[str, str]] | None:
# --------------------------------------------------------------------------- #
def parse_execution_output(output: BlockOutputEntry, name: str) -> JsonValue | None:
def parse_execution_output(output: BlockData, name: str) -> Any | None:
"""
Retrieve a nested value out of `output` using the flattened *name*.
@@ -270,7 +263,7 @@ def parse_execution_output(output: BlockOutputEntry, name: str) -> JsonValue | N
if tokens is None:
return None
cur: JsonValue = data
cur: Any = data
for delim, ident in tokens:
if delim == LIST_SPLIT:
# list[index]
@@ -435,7 +428,7 @@ def validate_exec(
async def _validate_node_input_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> dict[str, dict[str, str]]:
"""
Checks all credentials for all nodes of the graph and returns structured errors.
@@ -515,8 +508,8 @@ async def _validate_node_input_credentials(
def make_node_credentials_input_map(
graph: GraphModel,
graph_credentials_input: Mapping[str, CredentialsMetaInput],
) -> NodesInputMasks:
graph_credentials_input: dict[str, CredentialsMetaInput],
) -> dict[str, dict[str, JsonValue]]:
"""
Maps credentials for an execution to the correct nodes.
@@ -551,8 +544,8 @@ def make_node_credentials_input_map(
async def validate_graph_with_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> Mapping[str, Mapping[str, str]]:
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> dict[str, dict[str, str]]:
"""
Validate graph including credentials and return structured errors per node.
@@ -582,7 +575,7 @@ async def _construct_starting_node_execution_input(
graph: GraphModel,
user_id: str,
graph_inputs: BlockInput,
nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> list[tuple[str, BlockInput]]:
"""
Validates and prepares the input data for executing a graph.
@@ -623,7 +616,7 @@ async def _construct_starting_node_execution_input(
# Extract request input data, and assign it to the input pin.
if block.block_type == BlockType.INPUT:
input_name = cast(str | None, node.input_default.get("name"))
input_name = node.input_default.get("name")
if input_name and input_name in graph_inputs:
input_data = {"value": graph_inputs[input_name]}
@@ -650,9 +643,9 @@ async def validate_and_construct_node_execution_input(
user_id: str,
graph_inputs: BlockInput,
graph_version: Optional[int] = None,
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> tuple[GraphModel, list[tuple[str, BlockInput]], dict[str, dict[str, JsonValue]]]:
"""
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
This centralizes the logic used by both scheduler validation and actual execution.
@@ -666,9 +659,7 @@ async def validate_and_construct_node_execution_input(
nodes_input_masks: Node inputs to use.
Returns:
GraphModel: Full graph object for the given `graph_id`.
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
dict[str, BlockInput]: Node input masks including all passed-in credentials.
tuple[GraphModel, list[tuple[str, BlockInput]]]: Graph model and list of tuples for node execution input.
Raises:
NotFoundError: If the graph is not found.
@@ -709,11 +700,11 @@ async def validate_and_construct_node_execution_input(
def _merge_nodes_input_masks(
overrides_map_1: NodesInputMasks,
overrides_map_2: NodesInputMasks,
) -> NodesInputMasks:
overrides_map_1: dict[str, dict[str, JsonValue]],
overrides_map_2: dict[str, dict[str, JsonValue]],
) -> dict[str, dict[str, JsonValue]]:
"""Perform a per-node merge of input overrides"""
result = dict(overrides_map_1).copy()
result = overrides_map_1.copy()
for node_id, overrides2 in overrides_map_2.items():
if node_id in result:
result[node_id] = {**result[node_id], **overrides2}
@@ -863,8 +854,8 @@ async def add_graph_execution(
inputs: Optional[BlockInput] = None,
preset_id: Optional[str] = None,
graph_version: Optional[int] = None,
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> GraphExecutionWithNodes:
"""
Adds a graph execution to the queue and returns the execution entry.
@@ -888,7 +879,7 @@ async def add_graph_execution(
else:
edb = get_database_manager_async_client()
graph, starting_nodes_input, compiled_nodes_input_masks = (
graph, starting_nodes_input, nodes_input_masks = (
await validate_and_construct_node_execution_input(
graph_id=graph_id,
user_id=user_id,
@@ -901,43 +892,37 @@ async def add_graph_execution(
graph_exec = None
try:
# Sanity check: running add_graph_execution with the properties of
# the graph_exec created here should create the same execution again.
graph_exec = await edb.create_graph_execution(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
inputs=inputs or {},
credential_inputs=graph_credentials_inputs,
nodes_input_masks=nodes_input_masks,
starting_nodes_input=starting_nodes_input,
preset_id=preset_id,
)
graph_exec_entry = graph_exec.to_graph_execution_entry(
user_context=await get_user_context(user_id),
compiled_nodes_input_masks=compiled_nodes_input_masks,
)
# Fetch user context for the graph execution
user_context = await get_user_context(user_id)
queue = await get_async_execution_queue()
graph_exec_entry = graph_exec.to_graph_execution_entry(user_context)
if nodes_input_masks:
graph_exec_entry.nodes_input_masks = nodes_input_masks
logger.info(
f"Created graph execution #{graph_exec.id} for graph "
f"#{graph_id} with {len(starting_nodes_input)} starting nodes. "
f"Now publishing to execution queue."
)
exec_queue = await get_async_execution_queue()
await exec_queue.publish_message(
await queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
message=graph_exec_entry.model_dump_json(),
exchange=GRAPH_EXECUTION_EXCHANGE,
)
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
graph_exec.status = ExecutionStatus.QUEUED
await edb.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=graph_exec.status,
)
await get_async_execution_event_bus().publish(graph_exec)
bus = get_async_execution_event_bus()
await bus.publish(graph_exec)
return graph_exec
except BaseException as e:
@@ -967,7 +952,7 @@ async def add_graph_execution(
class ExecutionOutputEntry(BaseModel):
node: Node
node_exec_id: str
data: BlockOutputEntry
data: BlockData
class NodeExecutionProgress:

View File

@@ -1,14 +1,13 @@
import functools
from typing import TYPE_CHECKING
from autogpt_libs.utils.cache import cached
if TYPE_CHECKING:
from ..providers import ProviderName
from ._base import BaseWebhooksManager
# --8<-- [start:load_webhook_managers]
@cached()
@functools.cache
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
webhook_managers = {}

View File

@@ -7,9 +7,10 @@ from backend.data.graph import set_node_webhook
from backend.integrations.creds_manager import IntegrationCredentialsManager
from . import get_webhook_manager, supports_webhooks
from .utils import setup_webhook_for_block
if TYPE_CHECKING:
from backend.data.graph import BaseGraph, GraphModel, NodeModel
from backend.data.graph import BaseGraph, GraphModel, Node, NodeModel
from backend.data.model import Credentials
from ._base import BaseWebhooksManager
@@ -42,19 +43,32 @@ async def _on_graph_activate(graph: "BaseGraph", user_id: str) -> "BaseGraph": .
async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str):
get_credentials = credentials_manager.cached_getter(user_id)
updated_nodes = []
for new_node in graph.nodes:
block_input_schema = cast(BlockSchema, new_node.block.input_schema)
for creds_field_name in block_input_schema.get_credentials_fields().keys():
# Prevent saving graph with non-existent credentials
if (
creds_meta := new_node.input_default.get(creds_field_name)
) and not await get_credentials(creds_meta["id"]):
raise ValueError(
f"Node #{new_node.id} input '{creds_field_name}' updated with "
f"non-existent credentials #{creds_meta['id']}"
node_credentials = None
if (
# Webhook-triggered blocks are only allowed to have 1 credentials input
(
creds_field_name := next(
iter(block_input_schema.get_credentials_fields()), None
)
)
and (creds_meta := new_node.input_default.get(creds_field_name))
and not (node_credentials := await get_credentials(creds_meta["id"]))
):
raise ValueError(
f"Node #{new_node.id} input '{creds_field_name}' updated with "
f"non-existent credentials #{creds_meta['id']}"
)
updated_node = await on_node_activate(
user_id, graph.id, new_node, credentials=node_credentials
)
updated_nodes.append(updated_node)
graph.nodes = updated_nodes
return graph
@@ -71,14 +85,20 @@ async def on_graph_deactivate(graph: "GraphModel", user_id: str):
block_input_schema = cast(BlockSchema, node.block.input_schema)
node_credentials = None
for creds_field_name in block_input_schema.get_credentials_fields().keys():
if (creds_meta := node.input_default.get(creds_field_name)) and not (
node_credentials := await get_credentials(creds_meta["id"])
):
logger.warning(
f"Node #{node.id} input '{creds_field_name}' referenced "
f"non-existent credentials #{creds_meta['id']}"
if (
# Webhook-triggered blocks are only allowed to have 1 credentials input
(
creds_field_name := next(
iter(block_input_schema.get_credentials_fields()), None
)
)
and (creds_meta := node.input_default.get(creds_field_name))
and not (node_credentials := await get_credentials(creds_meta["id"]))
):
logger.error(
f"Node #{node.id} input '{creds_field_name}' referenced non-existent "
f"credentials #{creds_meta['id']}"
)
updated_node = await on_node_deactivate(
user_id, node, credentials=node_credentials
@@ -89,6 +109,32 @@ async def on_graph_deactivate(graph: "GraphModel", user_id: str):
return graph
async def on_node_activate(
user_id: str,
graph_id: str,
node: "Node",
*,
credentials: Optional["Credentials"] = None,
) -> "Node":
"""Hook to be called when the node is activated/created"""
if node.block.webhook_config:
new_webhook, feedback = await setup_webhook_for_block(
user_id=user_id,
trigger_block=node.block,
trigger_config=node.input_default,
for_graph_id=graph_id,
)
if new_webhook:
node = await set_node_webhook(node.id, new_webhook.id)
else:
logger.debug(
f"Node #{node.id} does not have everything for a webhook: {feedback}"
)
return node
async def on_node_deactivate(
user_id: str,
node: "NodeModel",

View File

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional, cast
from pydantic import JsonValue
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.settings import Config
from . import get_webhook_manager, supports_webhooks
@@ -12,7 +13,6 @@ if TYPE_CHECKING:
from backend.data.block import Block, BlockSchema
from backend.data.integrations import Webhook
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
logger = logging.getLogger(__name__)
app_config = Config()
@@ -20,7 +20,7 @@ credentials_manager = IntegrationCredentialsManager()
# TODO: add test to assert this matches the actual API route
def webhook_ingress_url(provider_name: "ProviderName", webhook_id: str) -> str:
def webhook_ingress_url(provider_name: ProviderName, webhook_id: str) -> str:
return (
f"{app_config.platform_base_url}/api/integrations/{provider_name.value}"
f"/webhooks/{webhook_id}/ingress"
@@ -144,69 +144,3 @@ async def setup_webhook_for_block(
)
logger.debug(f"Acquired webhook: {webhook}")
return webhook, None
async def migrate_legacy_triggered_graphs():
from prisma.models import AgentGraph
from backend.data.graph import AGENT_GRAPH_INCLUDE, GraphModel, set_node_webhook
from backend.data.model import is_credentials_field_name
from backend.server.v2.library.db import create_preset
from backend.server.v2.library.model import LibraryAgentPresetCreatable
triggered_graphs = [
GraphModel.from_db(_graph)
for _graph in await AgentGraph.prisma().find_many(
where={
"isActive": True,
"Nodes": {"some": {"NOT": [{"webhookId": None}]}},
},
include=AGENT_GRAPH_INCLUDE,
)
]
n_migrated_webhooks = 0
for graph in triggered_graphs:
try:
if not (
(trigger_node := graph.webhook_input_node) and trigger_node.webhook_id
):
continue
# Use trigger node's inputs for the preset
preset_credentials = {
field_name: creds_meta
for field_name, creds_meta in trigger_node.input_default.items()
if is_credentials_field_name(field_name)
}
preset_inputs = {
field_name: value
for field_name, value in trigger_node.input_default.items()
if not is_credentials_field_name(field_name)
}
# Create a triggered preset for the graph
await create_preset(
graph.user_id,
LibraryAgentPresetCreatable(
graph_id=graph.id,
graph_version=graph.version,
inputs=preset_inputs,
credentials=preset_credentials,
name=graph.name,
description=graph.description,
webhook_id=trigger_node.webhook_id,
is_active=True,
),
)
# Detach webhook from the graph node
await set_node_webhook(trigger_node.id, None)
n_migrated_webhooks += 1
except Exception as e:
logger.error(f"Failed to migrate graph #{graph.id} trigger to preset: {e}")
continue
logger.info(f"Migrated {n_migrated_webhooks} node triggers to triggered presets")

View File

@@ -1,287 +0,0 @@
"""
Prometheus instrumentation for FastAPI services.
This module provides centralized metrics collection and instrumentation
for all FastAPI services in the AutoGPT platform.
"""
import logging
from typing import Optional
from fastapi import FastAPI
from prometheus_client import Counter, Gauge, Histogram, Info
from prometheus_fastapi_instrumentator import Instrumentator, metrics
logger = logging.getLogger(__name__)
# Custom business metrics with controlled cardinality
GRAPH_EXECUTIONS = Counter(
"autogpt_graph_executions_total",
"Total number of graph executions",
labelnames=[
"status"
], # Removed graph_id and user_id to prevent cardinality explosion
)
GRAPH_EXECUTIONS_BY_USER = Counter(
"autogpt_graph_executions_by_user_total",
"Total number of graph executions by user (sampled)",
labelnames=["status"], # Only status, user_id tracked separately when needed
)
BLOCK_EXECUTIONS = Counter(
"autogpt_block_executions_total",
"Total number of block executions",
labelnames=["block_type", "status"], # block_type is bounded
)
BLOCK_DURATION = Histogram(
"autogpt_block_duration_seconds",
"Duration of block executions in seconds",
labelnames=["block_type"],
buckets=[0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60],
)
WEBSOCKET_CONNECTIONS = Gauge(
"autogpt_websocket_connections_total",
"Total number of active WebSocket connections",
# Removed user_id label - track total only to prevent cardinality explosion
)
SCHEDULER_JOBS = Gauge(
"autogpt_scheduler_jobs",
"Current number of scheduled jobs",
labelnames=["job_type", "status"],
)
DATABASE_QUERIES = Histogram(
"autogpt_database_query_duration_seconds",
"Duration of database queries in seconds",
labelnames=["operation", "table"],
buckets=[0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5],
)
RABBITMQ_MESSAGES = Counter(
"autogpt_rabbitmq_messages_total",
"Total number of RabbitMQ messages",
labelnames=["queue", "status"],
)
AUTHENTICATION_ATTEMPTS = Counter(
"autogpt_auth_attempts_total",
"Total number of authentication attempts",
labelnames=["method", "status"],
)
API_KEY_USAGE = Counter(
"autogpt_api_key_usage_total",
"API key usage by provider",
labelnames=["provider", "block_type", "status"],
)
# Function/operation level metrics with controlled cardinality
GRAPH_OPERATIONS = Counter(
"autogpt_graph_operations_total",
"Graph operations by type",
labelnames=["operation", "status"], # create, update, delete, execute, etc.
)
USER_OPERATIONS = Counter(
"autogpt_user_operations_total",
"User operations by type",
labelnames=["operation", "status"], # login, register, update_profile, etc.
)
RATE_LIMIT_HITS = Counter(
"autogpt_rate_limit_hits_total",
"Number of rate limit hits",
labelnames=["endpoint"], # Removed user_id to prevent cardinality explosion
)
SERVICE_INFO = Info(
"autogpt_service",
"Service information",
)
def instrument_fastapi(
app: FastAPI,
service_name: str,
expose_endpoint: bool = True,
endpoint: str = "/metrics",
include_in_schema: bool = False,
excluded_handlers: Optional[list] = None,
) -> Instrumentator:
"""
Instrument a FastAPI application with Prometheus metrics.
Args:
app: FastAPI application instance
service_name: Name of the service for metrics labeling
expose_endpoint: Whether to expose /metrics endpoint
endpoint: Path for metrics endpoint
include_in_schema: Whether to include metrics endpoint in OpenAPI schema
excluded_handlers: List of paths to exclude from metrics
Returns:
Configured Instrumentator instance
"""
# Set service info
try:
from importlib.metadata import version
service_version = version("autogpt-platform-backend")
except Exception:
service_version = "unknown"
SERVICE_INFO.info(
{
"service": service_name,
"version": service_version,
}
)
# Create instrumentator with default metrics
instrumentator = Instrumentator(
should_group_status_codes=True,
should_ignore_untemplated=True,
should_respect_env_var=True,
should_instrument_requests_inprogress=True,
excluded_handlers=excluded_handlers or ["/health", "/readiness"],
env_var_name="ENABLE_METRICS",
inprogress_name="autogpt_http_requests_inprogress",
inprogress_labels=True,
)
# Add default HTTP metrics
instrumentator.add(
metrics.default(
metric_namespace="autogpt",
metric_subsystem=service_name.replace("-", "_"),
)
)
# Add request size metrics
instrumentator.add(
metrics.request_size(
metric_namespace="autogpt",
metric_subsystem=service_name.replace("-", "_"),
)
)
# Add response size metrics
instrumentator.add(
metrics.response_size(
metric_namespace="autogpt",
metric_subsystem=service_name.replace("-", "_"),
)
)
# Add latency metrics with custom buckets for better granularity
instrumentator.add(
metrics.latency(
metric_namespace="autogpt",
metric_subsystem=service_name.replace("-", "_"),
buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60],
)
)
# Add combined metrics (requests by method and status)
instrumentator.add(
metrics.combined_size(
metric_namespace="autogpt",
metric_subsystem=service_name.replace("-", "_"),
)
)
# Instrument the app
instrumentator.instrument(app)
# Expose metrics endpoint if requested
if expose_endpoint:
instrumentator.expose(
app,
endpoint=endpoint,
include_in_schema=include_in_schema,
tags=["monitoring"] if include_in_schema else None,
)
logger.info(f"Metrics endpoint exposed at {endpoint} for {service_name}")
return instrumentator
def record_graph_execution(graph_id: str, status: str, user_id: str):
"""Record a graph execution event.
Args:
graph_id: Graph identifier (kept for future sampling/debugging)
status: Execution status (success/error/validation_error)
user_id: User identifier (kept for future sampling/debugging)
"""
# Track overall executions without high-cardinality labels
GRAPH_EXECUTIONS.labels(status=status).inc()
# Optionally track per-user executions (implement sampling if needed)
# For now, just track status to avoid cardinality explosion
GRAPH_EXECUTIONS_BY_USER.labels(status=status).inc()
def record_block_execution(block_type: str, status: str, duration: float):
"""Record a block execution event with duration."""
BLOCK_EXECUTIONS.labels(block_type=block_type, status=status).inc()
BLOCK_DURATION.labels(block_type=block_type).observe(duration)
def update_websocket_connections(user_id: str, delta: int):
"""Update the number of active WebSocket connections.
Args:
user_id: User identifier (kept for future sampling/debugging)
delta: Change in connection count (+1 for connect, -1 for disconnect)
"""
# Track total connections without user_id to prevent cardinality explosion
if delta > 0:
WEBSOCKET_CONNECTIONS.inc(delta)
else:
WEBSOCKET_CONNECTIONS.dec(abs(delta))
def record_database_query(operation: str, table: str, duration: float):
"""Record a database query with duration."""
DATABASE_QUERIES.labels(operation=operation, table=table).observe(duration)
def record_rabbitmq_message(queue: str, status: str):
"""Record a RabbitMQ message event."""
RABBITMQ_MESSAGES.labels(queue=queue, status=status).inc()
def record_authentication_attempt(method: str, status: str):
"""Record an authentication attempt."""
AUTHENTICATION_ATTEMPTS.labels(method=method, status=status).inc()
def record_api_key_usage(provider: str, block_type: str, status: str):
"""Record API key usage by provider and block."""
API_KEY_USAGE.labels(provider=provider, block_type=block_type, status=status).inc()
def record_rate_limit_hit(endpoint: str, user_id: str):
"""Record a rate limit hit.
Args:
endpoint: API endpoint that was rate limited
user_id: User identifier (kept for future sampling/debugging)
"""
RATE_LIMIT_HITS.labels(endpoint=endpoint).inc()
def record_graph_operation(operation: str, status: str):
"""Record a graph operation (create, update, delete, execute, etc.)."""
GRAPH_OPERATIONS.labels(operation=operation, status=status).inc()
def record_user_operation(operation: str, status: str):
"""Record a user operation (login, register, etc.)."""
USER_OPERATIONS.labels(operation=operation, status=status).inc()

View File

@@ -63,7 +63,7 @@ except ImportError:
# Cost System
try:
from backend.data.block import BlockCost, BlockCostType
from backend.data.cost import BlockCost, BlockCostType
except ImportError:
from backend.data.block_cost_config import BlockCost, BlockCostType

View File

@@ -8,7 +8,7 @@ from typing import Callable, List, Optional, Type
from pydantic import SecretStr
from backend.data.block import BlockCost, BlockCostType
from backend.data.cost import BlockCost, BlockCostType
from backend.data.model import (
APIKeyCredentials,
Credentials,

View File

@@ -8,8 +8,9 @@ BLOCK_COSTS configuration used by the execution system.
import logging
from typing import List, Type
from backend.data.block import Block, BlockCost
from backend.data.block import Block
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCost
from backend.sdk.registry import AutoRegistry
logger = logging.getLogger(__name__)

View File

@@ -7,7 +7,7 @@ from typing import Any, Callable, List, Optional, Set, Type
from pydantic import BaseModel, SecretStr
from backend.data.block import BlockCost
from backend.data.cost import BlockCost
from backend.data.model import (
APIKeyCredentials,
Credentials,

View File

@@ -6,10 +6,10 @@ import logging
import threading
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
from pydantic import BaseModel
from pydantic import BaseModel, SecretStr
from backend.blocks.basic import Block
from backend.data.model import Credentials
from backend.data.model import APIKeyCredentials, Credentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks._base import BaseWebhooksManager
@@ -17,8 +17,6 @@ from backend.integrations.webhooks._base import BaseWebhooksManager
if TYPE_CHECKING:
from backend.sdk.provider import Provider
logger = logging.getLogger(__name__)
class SDKOAuthCredentials(BaseModel):
"""OAuth credentials configuration for SDK providers."""
@@ -104,8 +102,21 @@ class AutoRegistry:
"""Register an environment variable as an API key for a provider."""
with cls._lock:
cls._api_key_mappings[provider] = env_var_name
# Note: The credential itself is created by ProviderBuilder.with_api_key()
# We only store the mapping here to avoid duplication
# Dynamically check if the env var exists and create credential
import os
api_key = os.getenv(env_var_name)
if api_key:
credential = APIKeyCredentials(
id=f"{provider}-default",
provider=provider,
api_key=SecretStr(api_key),
title=f"Default {provider} credentials",
)
# Check if credential already exists to avoid duplicates
if not any(c.id == credential.id for c in cls._default_credentials):
cls._default_credentials.append(credential)
@classmethod
def get_all_credentials(cls) -> List[Credentials]:
@@ -199,43 +210,3 @@ class AutoRegistry:
webhooks.load_webhook_managers = patched_load
except Exception as e:
logging.warning(f"Failed to patch webhook managers: {e}")
# Patch credentials store to include SDK-registered credentials
try:
import sys
from typing import Any
# Get the module from sys.modules to respect mocking
if "backend.integrations.credentials_store" in sys.modules:
creds_store: Any = sys.modules["backend.integrations.credentials_store"]
else:
import backend.integrations.credentials_store
creds_store: Any = backend.integrations.credentials_store
if hasattr(creds_store, "IntegrationCredentialsStore"):
store_class = creds_store.IntegrationCredentialsStore
if hasattr(store_class, "get_all_creds"):
original_get_all_creds = store_class.get_all_creds
async def patched_get_all_creds(self, user_id: str):
# Get original credentials
original_creds = await original_get_all_creds(self, user_id)
# Add SDK-registered credentials
sdk_creds = cls.get_all_credentials()
# Combine credentials, avoiding duplicates by ID
existing_ids = {c.id for c in original_creds}
for cred in sdk_creds:
if cred.id not in existing_ids:
original_creds.append(cred)
return original_creds
store_class.get_all_creds = patched_get_all_creds
logger.info(
"Successfully patched IntegrationCredentialsStore.get_all_creds"
)
except Exception as e:
logging.warning(f"Failed to patch credentials store: {e}")

View File

@@ -11,45 +11,7 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
return snapshot
@pytest.fixture
def test_user_id() -> str:
"""Test user ID fixture."""
return "test-user-id"
@pytest.fixture
def admin_user_id() -> str:
"""Admin user ID fixture."""
return "admin-user-id"
@pytest.fixture
def target_user_id() -> str:
"""Target user ID fixture."""
return "target-user-id"
@pytest.fixture
def mock_jwt_user(test_user_id):
"""Provide mock JWT payload for regular user testing."""
import fastapi
def override_get_jwt_payload(request: fastapi.Request) -> dict[str, str]:
return {"sub": test_user_id, "role": "user", "email": "test@example.com"}
return {"get_jwt_payload": override_get_jwt_payload, "user_id": test_user_id}
@pytest.fixture
def mock_jwt_admin(admin_user_id):
"""Provide mock JWT payload for admin user testing."""
import fastapi
def override_get_jwt_payload(request: fastapi.Request) -> dict[str, str]:
return {
"sub": admin_user_id,
"role": "admin",
"email": "test-admin@example.com",
}
return {"get_jwt_payload": override_get_jwt_payload, "user_id": admin_user_id}
# Test ID constants
TEST_USER_ID = "test-user-id"
ADMIN_USER_ID = "admin-user-id"
TARGET_USER_ID = "target-user-id"

View File

@@ -88,7 +88,6 @@ async def test_send_graph_execution_result(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
preset_id=None,
status=ExecutionStatus.COMPLETED,
started_at=datetime.now(tz=timezone.utc),
ended_at=datetime.now(tz=timezone.utc),
@@ -102,8 +101,6 @@ async def test_send_graph_execution_result(
"input_1": "some input value :)",
"input_2": "some *other* input value",
},
credential_inputs=None,
nodes_input_masks=None,
outputs={
"the_output": ["some output value"],
"other_output": ["sike there was another output"],

View File

@@ -1,6 +1,5 @@
from fastapi import FastAPI
from backend.monitoring.instrumentation import instrument_fastapi
from backend.server.middleware.security import SecurityHeadersMiddleware
from .routes.v1 import v1_router
@@ -14,12 +13,3 @@ external_app = FastAPI(
external_app.add_middleware(SecurityHeadersMiddleware)
external_app.include_router(v1_router, prefix="/v1")
# Add Prometheus instrumentation
instrument_fastapi(
external_app,
service_name="external-api",
expose_endpoint=True,
endpoint="/metrics",
include_in_schema=True,
)

View File

@@ -1,14 +1,16 @@
from fastapi import HTTPException, Security
from fastapi import Depends, HTTPException, Request
from fastapi.security import APIKeyHeader
from prisma.enums import APIKeyPermission
from backend.data.api_key import APIKeyInfo, has_permission, validate_api_key
from backend.data.api_key import has_permission, validate_api_key
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
api_key_header = APIKeyHeader(name="X-API-Key")
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
async def require_api_key(request: Request):
"""Base middleware for API key authentication"""
api_key = await api_key_header(request)
if api_key is None:
raise HTTPException(status_code=401, detail="Missing API key")
@@ -17,19 +19,18 @@ async def require_api_key(api_key: str | None = Security(api_key_header)) -> API
if not api_key_obj:
raise HTTPException(status_code=401, detail="Invalid API key")
request.state.api_key = api_key_obj
return api_key_obj
def require_permission(permission: APIKeyPermission):
"""Dependency function for checking specific permissions"""
async def check_permission(
api_key: APIKeyInfo = Security(require_api_key),
) -> APIKeyInfo:
async def check_permission(api_key=Depends(require_api_key)):
if not has_permission(api_key, permission):
raise HTTPException(
status_code=403,
detail=f"API key lacks the required permission '{permission}'",
detail=f"API key missing required permission: {permission}",
)
return api_key

Some files were not shown because too many files have changed in this diff Show More