mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-12 00:28:31 -05:00
Compare commits
5 Commits
copilot/fi
...
feat/execu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f652cb978 | ||
|
|
279552a2a3 | ||
|
|
fb6ac1d6ca | ||
|
|
9db15bff02 | ||
|
|
db4b94e0dc |
97
.github/workflows/claude-ci-failure-auto-fix.yml
vendored
97
.github/workflows/claude-ci-failure-auto-fix.yml
vendored
@@ -1,97 +0,0 @@
|
||||
name: Auto Fix CI Failures
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: ["CI"]
|
||||
types:
|
||||
- completed
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
actions: read
|
||||
issues: write
|
||||
id-token: write # Required for OIDC token exchange
|
||||
|
||||
jobs:
|
||||
auto-fix:
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'failure' &&
|
||||
github.event.workflow_run.pull_requests[0] &&
|
||||
!startsWith(github.event.workflow_run.head_branch, 'claude-auto-fix-ci-')
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.workflow_run.head_branch }}
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Setup git identity
|
||||
run: |
|
||||
git config --global user.email "claude[bot]@users.noreply.github.com"
|
||||
git config --global user.name "claude[bot]"
|
||||
|
||||
- name: Create fix branch
|
||||
id: branch
|
||||
run: |
|
||||
BRANCH_NAME="claude-auto-fix-ci-${{ github.event.workflow_run.head_branch }}-${{ github.run_id }}"
|
||||
git checkout -b "$BRANCH_NAME"
|
||||
echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Get CI failure details
|
||||
id: failure_details
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const run = await github.rest.actions.getWorkflowRun({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
run_id: ${{ github.event.workflow_run.id }}
|
||||
});
|
||||
|
||||
const jobs = await github.rest.actions.listJobsForWorkflowRun({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
run_id: ${{ github.event.workflow_run.id }}
|
||||
});
|
||||
|
||||
const failedJobs = jobs.data.jobs.filter(job => job.conclusion === 'failure');
|
||||
|
||||
let errorLogs = [];
|
||||
for (const job of failedJobs) {
|
||||
const logs = await github.rest.actions.downloadJobLogsForWorkflowRun({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
job_id: job.id
|
||||
});
|
||||
errorLogs.push({
|
||||
jobName: job.name,
|
||||
logs: logs.data
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
runUrl: run.data.html_url,
|
||||
failedJobs: failedJobs.map(j => j.name),
|
||||
errorLogs: errorLogs
|
||||
};
|
||||
|
||||
- name: Fix CI failures with Claude
|
||||
id: claude
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
prompt: |
|
||||
/fix-ci
|
||||
Failed CI Run: ${{ fromJSON(steps.failure_details.outputs.result).runUrl }}
|
||||
Failed Jobs: ${{ join(fromJSON(steps.failure_details.outputs.result).failedJobs, ', ') }}
|
||||
PR Number: ${{ github.event.workflow_run.pull_requests[0].number }}
|
||||
Branch Name: ${{ steps.branch.outputs.branch_name }}
|
||||
Base Branch: ${{ github.event.workflow_run.head_branch }}
|
||||
Repository: ${{ github.repository }}
|
||||
|
||||
Error logs:
|
||||
${{ toJSON(fromJSON(steps.failure_details.outputs.result).errorLogs) }}
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: "--allowedTools 'Edit,MultiEdit,Write,Read,Glob,Grep,LS,Bash(git:*),Bash(bun:*),Bash(npm:*),Bash(npx:*),Bash(gh:*)'"
|
||||
379
.github/workflows/claude-dependabot.yml
vendored
379
.github/workflows/claude-dependabot.yml
vendored
@@ -1,379 +0,0 @@
|
||||
# Claude Dependabot PR Review Workflow
|
||||
#
|
||||
# This workflow automatically runs Claude analysis on Dependabot PRs to:
|
||||
# - Identify dependency changes and their versions
|
||||
# - Look up changelogs for updated packages
|
||||
# - Assess breaking changes and security impacts
|
||||
# - Provide actionable recommendations for the development team
|
||||
#
|
||||
# Triggered on: Dependabot PRs (opened, synchronize)
|
||||
# Requirements: ANTHROPIC_API_KEY secret must be configured
|
||||
|
||||
name: Claude Dependabot PR Review
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize]
|
||||
|
||||
jobs:
|
||||
dependabot-review:
|
||||
# Only run on Dependabot PRs
|
||||
if: github.actor == 'dependabot[bot]'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: read
|
||||
issues: read
|
||||
id-token: write
|
||||
actions: read # Required for CI access
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock (matches CI)
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
|
||||
|
||||
# Install Poetry
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
# Add Poetry to PATH
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Check poetry.lock
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry lock
|
||||
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
|
||||
echo "Warning: poetry.lock not up to date, but continuing for setup"
|
||||
git checkout poetry.lock # Reset for clean setup
|
||||
fi
|
||||
|
||||
- name: Install Python dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set pnpm store directory
|
||||
run: |
|
||||
pnpm config set store-dir ~/.pnpm-store
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
# Install Playwright browsers for frontend testing
|
||||
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
|
||||
# - name: Install Playwright browsers
|
||||
# working-directory: autogpt_platform/frontend
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Copy default environment files
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Copy default environment files for development
|
||||
cp .env.default .env
|
||||
cp backend/.env.default backend/.env
|
||||
cp frontend/.env.default frontend/.env
|
||||
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
|
||||
restore-keys: |
|
||||
docker-images-v2-${{ runner.os }}-
|
||||
docker-images-v1-${{ runner.os }}-
|
||||
|
||||
- name: Load or pull Docker images
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
mkdir -p ~/docker-cache
|
||||
|
||||
# Define image list for easy maintenance
|
||||
IMAGES=(
|
||||
"redis:latest"
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
|
||||
echo "Docker cache found, loading images in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
if [ -f ~/docker-cache/${filename}.tar ]; then
|
||||
echo "Loading $image..."
|
||||
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
|
||||
fi
|
||||
done
|
||||
wait
|
||||
echo "All cached images loaded"
|
||||
else
|
||||
echo "No Docker cache found, pulling images in parallel..."
|
||||
# Pull all images in parallel
|
||||
for image in "${IMAGES[@]}"; do
|
||||
docker pull "$image" &
|
||||
done
|
||||
wait
|
||||
|
||||
# Only save cache on main branches (not PRs) to avoid cache pollution
|
||||
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
|
||||
echo "Saving Docker images to cache in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
echo "Saving $image..."
|
||||
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
|
||||
done
|
||||
wait
|
||||
echo "Docker image cache saved"
|
||||
else
|
||||
echo "Skipping cache save for PR/feature branch"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Docker images ready for use"
|
||||
|
||||
# Phase 2: Build migrate service with GitHub Actions cache
|
||||
- name: Build migrate Docker image with cache
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Build the migrate image with buildx for GHA caching
|
||||
docker buildx build \
|
||||
--cache-from type=gha \
|
||||
--cache-to type=gha,mode=max \
|
||||
--target migrate \
|
||||
--tag autogpt_platform-migrate:latest \
|
||||
--load \
|
||||
-f backend/Dockerfile \
|
||||
..
|
||||
|
||||
# Start services using pre-built images
|
||||
- name: Start Docker services for development
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Start essential services (migrate image already built with correct tag)
|
||||
docker compose --profile local up deps --no-build --detach
|
||||
echo "Waiting for services to be ready..."
|
||||
|
||||
# Wait for database to be ready
|
||||
echo "Checking database readiness..."
|
||||
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
|
||||
echo " Waiting for database..."
|
||||
sleep 2
|
||||
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
|
||||
|
||||
# Check migrate service status
|
||||
echo "Checking migration status..."
|
||||
docker compose ps migrate || echo " Migrate service not visible in ps output"
|
||||
|
||||
# Wait for migrate service to complete
|
||||
echo "Waiting for migrations to complete..."
|
||||
timeout 30 bash -c '
|
||||
ATTEMPTS=0
|
||||
while [ $ATTEMPTS -lt 15 ]; do
|
||||
ATTEMPTS=$((ATTEMPTS + 1))
|
||||
|
||||
# Check using docker directly (more reliable than docker compose ps)
|
||||
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
|
||||
|
||||
if [ -z "$CONTAINER_STATUS" ]; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
|
||||
echo "✅ Migrations completed successfully"
|
||||
docker compose logs migrate --tail=5 2>/dev/null || true
|
||||
exit 0
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
|
||||
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
|
||||
echo "❌ Migrations failed with exit code: $EXIT_CODE"
|
||||
echo "Migration logs:"
|
||||
docker compose logs migrate --tail=20 2>/dev/null || true
|
||||
exit 1
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
|
||||
else
|
||||
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
|
||||
fi
|
||||
|
||||
sleep 2
|
||||
done
|
||||
|
||||
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
|
||||
echo "Final container check:"
|
||||
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
|
||||
echo "Migration logs (if available):"
|
||||
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
|
||||
' || echo "⚠️ Migration check completed with warnings, continuing..."
|
||||
|
||||
# Brief wait for other services to stabilize
|
||||
echo "Waiting 5 seconds for other services to stabilize..."
|
||||
sleep 5
|
||||
|
||||
# Verify installations and provide environment info
|
||||
- name: Verify setup and show environment info
|
||||
run: |
|
||||
echo "=== Python Setup ==="
|
||||
python --version
|
||||
poetry --version
|
||||
|
||||
echo "=== Node.js Setup ==="
|
||||
node --version
|
||||
pnpm --version
|
||||
|
||||
echo "=== Additional Tools ==="
|
||||
docker --version
|
||||
docker compose version
|
||||
gh --version || true
|
||||
|
||||
echo "=== Services Status ==="
|
||||
cd autogpt_platform
|
||||
docker compose ps || true
|
||||
|
||||
echo "=== Backend Dependencies ==="
|
||||
cd backend
|
||||
poetry show | head -10 || true
|
||||
|
||||
echo "=== Frontend Dependencies ==="
|
||||
cd ../frontend
|
||||
pnpm list --depth=0 | head -10 || true
|
||||
|
||||
echo "=== Environment Files ==="
|
||||
ls -la ../.env* || true
|
||||
ls -la .env* || true
|
||||
ls -la ../backend/.env* || true
|
||||
|
||||
echo "✅ AutoGPT Platform development environment setup complete!"
|
||||
echo "🚀 Ready for development with Docker services running"
|
||||
echo "📝 Backend server: poetry run serve (port 8000)"
|
||||
echo "🌐 Frontend server: pnpm dev (port 3000)"
|
||||
|
||||
|
||||
- name: Run Claude Dependabot Analysis
|
||||
id: claude_review
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
||||
prompt: |
|
||||
You are Claude, an AI assistant specialized in reviewing Dependabot dependency update PRs.
|
||||
|
||||
Your primary tasks are:
|
||||
1. **Analyze the dependency changes** in this Dependabot PR
|
||||
2. **Look up changelogs** for all updated dependencies to understand what changed
|
||||
3. **Identify breaking changes** and assess potential impact on the AutoGPT codebase
|
||||
4. **Provide actionable recommendations** for the development team
|
||||
|
||||
## Analysis Process:
|
||||
|
||||
1. **Identify Changed Dependencies**:
|
||||
- Use git diff to see what dependencies were updated
|
||||
- Parse package.json, poetry.lock, requirements files, etc.
|
||||
- List all package versions: old → new
|
||||
|
||||
2. **Changelog Research**:
|
||||
- For each updated dependency, look up its changelog/release notes
|
||||
- Use WebFetch to access GitHub releases, NPM package pages, PyPI project pages. The pr should also have some details
|
||||
- Focus on versions between the old and new versions
|
||||
- Identify: breaking changes, deprecations, security fixes, new features
|
||||
|
||||
3. **Breaking Change Assessment**:
|
||||
- Categorize changes: BREAKING, MAJOR, MINOR, PATCH, SECURITY
|
||||
- Assess impact on AutoGPT's usage patterns
|
||||
- Check if AutoGPT uses affected APIs/features
|
||||
- Look for migration guides or upgrade instructions
|
||||
|
||||
4. **Codebase Impact Analysis**:
|
||||
- Search the AutoGPT codebase for usage of changed APIs
|
||||
- Identify files that might be affected by breaking changes
|
||||
- Check test files for deprecated usage patterns
|
||||
- Look for configuration changes needed
|
||||
|
||||
## Output Format:
|
||||
|
||||
Provide a comprehensive review comment with:
|
||||
|
||||
### 🔍 Dependency Analysis Summary
|
||||
- List of updated packages with version changes
|
||||
- Overall risk assessment (LOW/MEDIUM/HIGH)
|
||||
|
||||
### 📋 Detailed Changelog Review
|
||||
For each updated dependency:
|
||||
- **Package**: name (old_version → new_version)
|
||||
- **Changes**: Summary of key changes
|
||||
- **Breaking Changes**: List any breaking changes
|
||||
- **Security Fixes**: Note security improvements
|
||||
- **Migration Notes**: Any upgrade steps needed
|
||||
|
||||
### ⚠️ Impact Assessment
|
||||
- **Breaking Changes Found**: Yes/No with details
|
||||
- **Affected Files**: List AutoGPT files that may need updates
|
||||
- **Test Impact**: Any tests that may need updating
|
||||
- **Configuration Changes**: Required config updates
|
||||
|
||||
### 🛠️ Recommendations
|
||||
- **Action Required**: What the team should do
|
||||
- **Testing Focus**: Areas to test thoroughly
|
||||
- **Follow-up Tasks**: Any additional work needed
|
||||
- **Merge Recommendation**: APPROVE/REVIEW_NEEDED/HOLD
|
||||
|
||||
### 📚 Useful Links
|
||||
- Links to relevant changelogs, migration guides, documentation
|
||||
|
||||
Be thorough but concise. Focus on actionable insights that help the development team make informed decisions about the dependency updates.
|
||||
284
.github/workflows/claude.yml
vendored
284
.github/workflows/claude.yml
vendored
@@ -30,296 +30,18 @@ jobs:
|
||||
github.event.issue.author_association == 'COLLABORATOR'
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
contents: read
|
||||
pull-requests: read
|
||||
issues: read
|
||||
id-token: write
|
||||
actions: read # Required for CI access
|
||||
steps:
|
||||
- name: Checkout code
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock (matches CI)
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
|
||||
|
||||
# Install Poetry
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
# Add Poetry to PATH
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Check poetry.lock
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry lock
|
||||
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
|
||||
echo "Warning: poetry.lock not up to date, but continuing for setup"
|
||||
git checkout poetry.lock # Reset for clean setup
|
||||
fi
|
||||
|
||||
- name: Install Python dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set pnpm store directory
|
||||
run: |
|
||||
pnpm config set store-dir ~/.pnpm-store
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
# Install Playwright browsers for frontend testing
|
||||
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
|
||||
# - name: Install Playwright browsers
|
||||
# working-directory: autogpt_platform/frontend
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Copy default environment files
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Copy default environment files for development
|
||||
cp .env.default .env
|
||||
cp backend/.env.default backend/.env
|
||||
cp frontend/.env.default frontend/.env
|
||||
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
|
||||
restore-keys: |
|
||||
docker-images-v2-${{ runner.os }}-
|
||||
docker-images-v1-${{ runner.os }}-
|
||||
|
||||
- name: Load or pull Docker images
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
mkdir -p ~/docker-cache
|
||||
|
||||
# Define image list for easy maintenance
|
||||
IMAGES=(
|
||||
"redis:latest"
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
|
||||
echo "Docker cache found, loading images in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
if [ -f ~/docker-cache/${filename}.tar ]; then
|
||||
echo "Loading $image..."
|
||||
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
|
||||
fi
|
||||
done
|
||||
wait
|
||||
echo "All cached images loaded"
|
||||
else
|
||||
echo "No Docker cache found, pulling images in parallel..."
|
||||
# Pull all images in parallel
|
||||
for image in "${IMAGES[@]}"; do
|
||||
docker pull "$image" &
|
||||
done
|
||||
wait
|
||||
|
||||
# Only save cache on main branches (not PRs) to avoid cache pollution
|
||||
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
|
||||
echo "Saving Docker images to cache in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
echo "Saving $image..."
|
||||
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
|
||||
done
|
||||
wait
|
||||
echo "Docker image cache saved"
|
||||
else
|
||||
echo "Skipping cache save for PR/feature branch"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Docker images ready for use"
|
||||
|
||||
# Phase 2: Build migrate service with GitHub Actions cache
|
||||
- name: Build migrate Docker image with cache
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Build the migrate image with buildx for GHA caching
|
||||
docker buildx build \
|
||||
--cache-from type=gha \
|
||||
--cache-to type=gha,mode=max \
|
||||
--target migrate \
|
||||
--tag autogpt_platform-migrate:latest \
|
||||
--load \
|
||||
-f backend/Dockerfile \
|
||||
..
|
||||
|
||||
# Start services using pre-built images
|
||||
- name: Start Docker services for development
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Start essential services (migrate image already built with correct tag)
|
||||
docker compose --profile local up deps --no-build --detach
|
||||
echo "Waiting for services to be ready..."
|
||||
|
||||
# Wait for database to be ready
|
||||
echo "Checking database readiness..."
|
||||
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
|
||||
echo " Waiting for database..."
|
||||
sleep 2
|
||||
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
|
||||
|
||||
# Check migrate service status
|
||||
echo "Checking migration status..."
|
||||
docker compose ps migrate || echo " Migrate service not visible in ps output"
|
||||
|
||||
# Wait for migrate service to complete
|
||||
echo "Waiting for migrations to complete..."
|
||||
timeout 30 bash -c '
|
||||
ATTEMPTS=0
|
||||
while [ $ATTEMPTS -lt 15 ]; do
|
||||
ATTEMPTS=$((ATTEMPTS + 1))
|
||||
|
||||
# Check using docker directly (more reliable than docker compose ps)
|
||||
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
|
||||
|
||||
if [ -z "$CONTAINER_STATUS" ]; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
|
||||
echo "✅ Migrations completed successfully"
|
||||
docker compose logs migrate --tail=5 2>/dev/null || true
|
||||
exit 0
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
|
||||
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
|
||||
echo "❌ Migrations failed with exit code: $EXIT_CODE"
|
||||
echo "Migration logs:"
|
||||
docker compose logs migrate --tail=20 2>/dev/null || true
|
||||
exit 1
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
|
||||
else
|
||||
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
|
||||
fi
|
||||
|
||||
sleep 2
|
||||
done
|
||||
|
||||
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
|
||||
echo "Final container check:"
|
||||
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
|
||||
echo "Migration logs (if available):"
|
||||
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
|
||||
' || echo "⚠️ Migration check completed with warnings, continuing..."
|
||||
|
||||
# Brief wait for other services to stabilize
|
||||
echo "Waiting 5 seconds for other services to stabilize..."
|
||||
sleep 5
|
||||
|
||||
# Verify installations and provide environment info
|
||||
- name: Verify setup and show environment info
|
||||
run: |
|
||||
echo "=== Python Setup ==="
|
||||
python --version
|
||||
poetry --version
|
||||
|
||||
echo "=== Node.js Setup ==="
|
||||
node --version
|
||||
pnpm --version
|
||||
|
||||
echo "=== Additional Tools ==="
|
||||
docker --version
|
||||
docker compose version
|
||||
gh --version || true
|
||||
|
||||
echo "=== Services Status ==="
|
||||
cd autogpt_platform
|
||||
docker compose ps || true
|
||||
|
||||
echo "=== Backend Dependencies ==="
|
||||
cd backend
|
||||
poetry show | head -10 || true
|
||||
|
||||
echo "=== Frontend Dependencies ==="
|
||||
cd ../frontend
|
||||
pnpm list --depth=0 | head -10 || true
|
||||
|
||||
echo "=== Environment Files ==="
|
||||
ls -la ../.env* || true
|
||||
ls -la .env* || true
|
||||
ls -la ../backend/.env* || true
|
||||
|
||||
echo "✅ AutoGPT Platform development environment setup complete!"
|
||||
echo "🚀 Ready for development with Docker services running"
|
||||
echo "📝 Backend server: poetry run serve (port 8000)"
|
||||
echo "🌐 Frontend server: pnpm dev (port 3000)"
|
||||
|
||||
- name: Run Claude Code
|
||||
id: claude
|
||||
uses: anthropics/claude-code-action@v1
|
||||
uses: anthropics/claude-code-action@beta
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*), Bash(gh pr edit:*)"
|
||||
--model opus
|
||||
additional_permissions: |
|
||||
actions: read
|
||||
|
||||
@@ -3,7 +3,6 @@ name: AutoGPT Platform - Deploy Prod Environment
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: 'read'
|
||||
@@ -18,8 +17,6 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.ref_name || 'master' }}
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -39,7 +36,7 @@ jobs:
|
||||
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
|
||||
|
||||
|
||||
trigger:
|
||||
needs: migrate
|
||||
runs-on: ubuntu-latest
|
||||
@@ -50,5 +47,4 @@ jobs:
|
||||
token: ${{ secrets.DEPLOY_TOKEN }}
|
||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||
event-type: build_deploy_prod
|
||||
client-payload: |
|
||||
{"ref": "${{ github.ref_name || 'master' }}", "repository": "${{ github.repository }}"}
|
||||
client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'
|
||||
@@ -5,13 +5,6 @@ on:
|
||||
branches: [ dev ]
|
||||
paths:
|
||||
- 'autogpt_platform/**'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
git_ref:
|
||||
description: 'Git ref (branch/tag) of AutoGPT to deploy'
|
||||
required: true
|
||||
default: 'master'
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: 'read'
|
||||
@@ -26,8 +19,6 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.inputs.git_ref || github.ref_name }}
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -57,4 +48,4 @@ jobs:
|
||||
token: ${{ secrets.DEPLOY_TOKEN }}
|
||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||
event-type: build_deploy_dev
|
||||
client-payload: '{"ref": "${{ github.event.inputs.git_ref || github.ref }}", "repository": "${{ github.repository }}"}'
|
||||
client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'
|
||||
|
||||
113
.github/workflows/platform-container-publish.yml
vendored
113
.github/workflows/platform-container-publish.yml
vendored
@@ -1,113 +0,0 @@
|
||||
name: Platform - Container Publishing
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
no_cache:
|
||||
type: boolean
|
||||
description: 'Build from scratch, without using cached layers'
|
||||
default: false
|
||||
registry:
|
||||
type: choice
|
||||
description: 'Container registry to publish to'
|
||||
options:
|
||||
- 'both'
|
||||
- 'ghcr'
|
||||
- 'dockerhub'
|
||||
default: 'both'
|
||||
|
||||
env:
|
||||
GHCR_REGISTRY: ghcr.io
|
||||
GHCR_IMAGE_BASE: ${{ github.repository_owner }}/autogpt-platform
|
||||
DOCKERHUB_IMAGE_BASE: ${{ secrets.DOCKER_USER }}/autogpt-platform
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
build-and-publish:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
component: [backend, frontend]
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
if: inputs.registry == 'both' || inputs.registry == 'ghcr' || github.event_name == 'release'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.GHCR_REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
if: (inputs.registry == 'both' || inputs.registry == 'dockerhub' || github.event_name == 'release') && secrets.DOCKER_USER
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||
|
||||
- name: Extract metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
${{ env.GHCR_REGISTRY }}/${{ env.GHCR_IMAGE_BASE }}-${{ matrix.component }}
|
||||
${{ secrets.DOCKER_USER && format('{0}-{1}', env.DOCKERHUB_IMAGE_BASE, matrix.component) || '' }}
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=pr
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=semver,pattern={{major}}
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
|
||||
- name: Set build context and dockerfile for backend
|
||||
if: matrix.component == 'backend'
|
||||
run: |
|
||||
echo "BUILD_CONTEXT=." >> $GITHUB_ENV
|
||||
echo "DOCKERFILE=autogpt_platform/backend/Dockerfile" >> $GITHUB_ENV
|
||||
echo "BUILD_TARGET=server" >> $GITHUB_ENV
|
||||
|
||||
- name: Set build context and dockerfile for frontend
|
||||
if: matrix.component == 'frontend'
|
||||
run: |
|
||||
echo "BUILD_CONTEXT=." >> $GITHUB_ENV
|
||||
echo "DOCKERFILE=autogpt_platform/frontend/Dockerfile" >> $GITHUB_ENV
|
||||
echo "BUILD_TARGET=prod" >> $GITHUB_ENV
|
||||
|
||||
- name: Build and push container image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: ${{ env.BUILD_CONTEXT }}
|
||||
file: ${{ env.DOCKERFILE }}
|
||||
target: ${{ env.BUILD_TARGET }}
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: ${{ !inputs.no_cache && 'type=gha' || '' }},scope=platform-${{ matrix.component }}
|
||||
cache-to: type=gha,scope=platform-${{ matrix.component }},mode=max
|
||||
|
||||
- name: Generate build summary
|
||||
run: |
|
||||
echo "## 🐳 Container Build Summary" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "**Component:** ${{ matrix.component }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "**Registry:** ${{ inputs.registry || 'both' }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "**Tags:** ${{ steps.meta.outputs.tags }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "### Images Published:" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
echo "${{ steps.meta.outputs.tags }}" | sed 's/,/\n/g' >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
2
.github/workflows/platform-frontend-ci.yml
vendored
2
.github/workflows/platform-frontend-ci.yml
vendored
@@ -160,7 +160,7 @@ jobs:
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
|
||||
docker compose -f ../docker-compose.yml up -d
|
||||
env:
|
||||
DOCKER_BUILDKIT: 1
|
||||
BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
|
||||
|
||||
@@ -61,27 +61,24 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd frontend && pnpm i
|
||||
cd frontend && npm install
|
||||
|
||||
# Start development server
|
||||
pnpm dev
|
||||
npm run dev
|
||||
|
||||
# Run E2E tests
|
||||
pnpm test
|
||||
npm run test
|
||||
|
||||
# Run Storybook for component development
|
||||
pnpm storybook
|
||||
npm run storybook
|
||||
|
||||
# Build production
|
||||
pnpm build
|
||||
npm run build
|
||||
|
||||
# Type checking
|
||||
pnpm types
|
||||
npm run types
|
||||
```
|
||||
|
||||
We have a components library in autogpt_platform/frontend/src/components/atoms that should be used when adding new pages and components.
|
||||
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Backend Architecture
|
||||
|
||||
@@ -1,389 +0,0 @@
|
||||
# AutoGPT Platform Container Publishing
|
||||
|
||||
This document describes the container publishing infrastructure and deployment options for the AutoGPT Platform.
|
||||
|
||||
## Published Container Images
|
||||
|
||||
### GitHub Container Registry (GHCR) - Recommended
|
||||
|
||||
- **Backend**: `ghcr.io/significant-gravitas/autogpt-platform-backend`
|
||||
- **Frontend**: `ghcr.io/significant-gravitas/autogpt-platform-frontend`
|
||||
|
||||
### Docker Hub
|
||||
|
||||
- **Backend**: `significantgravitas/autogpt-platform-backend`
|
||||
- **Frontend**: `significantgravitas/autogpt-platform-frontend`
|
||||
|
||||
## Available Tags
|
||||
|
||||
- `latest` - Latest stable release from master branch
|
||||
- `v1.0.0`, `v1.1.0`, etc. - Specific version releases
|
||||
- `main` - Latest development build (use with caution)
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Using Docker Compose (Recommended)
|
||||
|
||||
```bash
|
||||
# Clone the repository (or just download the compose file)
|
||||
git clone https://github.com/Significant-Gravitas/AutoGPT.git
|
||||
cd AutoGPT/autogpt_platform
|
||||
|
||||
# Deploy with published images
|
||||
./deploy.sh deploy
|
||||
```
|
||||
|
||||
### Manual Docker Run
|
||||
|
||||
```bash
|
||||
# Start dependencies first
|
||||
docker network create autogpt
|
||||
|
||||
# PostgreSQL
|
||||
docker run -d --name postgres --network autogpt \
|
||||
-e POSTGRES_DB=autogpt \
|
||||
-e POSTGRES_USER=autogpt \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-v postgres_data:/var/lib/postgresql/data \
|
||||
postgres:15
|
||||
|
||||
# Redis
|
||||
docker run -d --name redis --network autogpt \
|
||||
-v redis_data:/data \
|
||||
redis:7-alpine redis-server --requirepass password
|
||||
|
||||
# RabbitMQ
|
||||
docker run -d --name rabbitmq --network autogpt \
|
||||
-e RABBITMQ_DEFAULT_USER=autogpt \
|
||||
-e RABBITMQ_DEFAULT_PASS=password \
|
||||
-p 15672:15672 \
|
||||
rabbitmq:3-management
|
||||
|
||||
# Backend
|
||||
docker run -d --name backend --network autogpt \
|
||||
-p 8000:8000 \
|
||||
-e DATABASE_URL=postgresql://autogpt:password@postgres:5432/autogpt \
|
||||
-e REDIS_HOST=redis \
|
||||
-e RABBITMQ_HOST=rabbitmq \
|
||||
ghcr.io/significant-gravitas/autogpt-platform-backend:latest
|
||||
|
||||
# Frontend
|
||||
docker run -d --name frontend --network autogpt \
|
||||
-p 3000:3000 \
|
||||
-e AGPT_SERVER_URL=http://localhost:8000/api \
|
||||
ghcr.io/significant-gravitas/autogpt-platform-frontend:latest
|
||||
```
|
||||
|
||||
## Deployment Scripts
|
||||
|
||||
### Deploy Script
|
||||
|
||||
The included `deploy.sh` script provides a complete deployment solution:
|
||||
|
||||
```bash
|
||||
# Basic deployment
|
||||
./deploy.sh deploy
|
||||
|
||||
# Deploy specific version
|
||||
./deploy.sh -v v1.0.0 deploy
|
||||
|
||||
# Deploy from Docker Hub
|
||||
./deploy.sh -r docker.io deploy
|
||||
|
||||
# Production deployment
|
||||
./deploy.sh -p production deploy
|
||||
|
||||
# Other operations
|
||||
./deploy.sh start # Start services
|
||||
./deploy.sh stop # Stop services
|
||||
./deploy.sh restart # Restart services
|
||||
./deploy.sh update # Update to latest
|
||||
./deploy.sh backup # Create backup
|
||||
./deploy.sh status # Show status
|
||||
./deploy.sh logs # Show logs
|
||||
./deploy.sh cleanup # Remove everything
|
||||
```
|
||||
|
||||
## Platform-Specific Deployment Guides
|
||||
|
||||
### Unraid
|
||||
|
||||
See [Unraid Deployment Guide](../docs/content/platform/deployment/unraid.md)
|
||||
|
||||
Key features:
|
||||
- Community Applications template
|
||||
- Web UI management
|
||||
- Automatic updates
|
||||
- Built-in backup system
|
||||
|
||||
### Home Assistant Add-on
|
||||
|
||||
See [Home Assistant Add-on Guide](../docs/content/platform/deployment/home-assistant.md)
|
||||
|
||||
Key features:
|
||||
- Native Home Assistant integration
|
||||
- Automation services
|
||||
- Entity monitoring
|
||||
- Backup integration
|
||||
|
||||
### Kubernetes
|
||||
|
||||
See [Kubernetes Deployment Guide](../docs/content/platform/deployment/kubernetes.md)
|
||||
|
||||
Key features:
|
||||
- Helm charts
|
||||
- Horizontal scaling
|
||||
- Health checks
|
||||
- Persistent volumes
|
||||
|
||||
## Container Architecture
|
||||
|
||||
### Backend Container
|
||||
|
||||
- **Base Image**: `debian:13-slim`
|
||||
- **Runtime**: Python 3.13 with Poetry
|
||||
- **Services**: REST API, WebSocket, Executor, Scheduler, Database Manager, Notification
|
||||
- **Ports**: 8000-8007 (depending on service)
|
||||
- **Health Check**: `GET /health`
|
||||
|
||||
### Frontend Container
|
||||
|
||||
- **Base Image**: `node:21-alpine`
|
||||
- **Runtime**: Next.js production build
|
||||
- **Port**: 3000
|
||||
- **Health Check**: HTTP 200 on root path
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
### Required Environment Variables
|
||||
|
||||
#### Backend
|
||||
```env
|
||||
DATABASE_URL=postgresql://user:pass@host:5432/db
|
||||
REDIS_HOST=redis
|
||||
RABBITMQ_HOST=rabbitmq
|
||||
JWT_SECRET=your-secret-key
|
||||
```
|
||||
|
||||
#### Frontend
|
||||
```env
|
||||
AGPT_SERVER_URL=http://backend:8000/api
|
||||
SUPABASE_URL=http://auth:8000
|
||||
```
|
||||
|
||||
### Optional Configuration
|
||||
|
||||
```env
|
||||
# Logging
|
||||
LOG_LEVEL=INFO
|
||||
ENABLE_DEBUG=false
|
||||
|
||||
# Performance
|
||||
REDIS_PASSWORD=your-redis-password
|
||||
RABBITMQ_PASSWORD=your-rabbitmq-password
|
||||
|
||||
# Security
|
||||
CORS_ORIGINS=http://localhost:3000
|
||||
```
|
||||
|
||||
## CI/CD Pipeline
|
||||
|
||||
### GitHub Actions Workflow
|
||||
|
||||
The publishing workflow (`.github/workflows/platform-container-publish.yml`) automatically:
|
||||
|
||||
1. **Triggers** on releases and manual dispatch
|
||||
2. **Builds** both backend and frontend containers
|
||||
3. **Tests** container functionality
|
||||
4. **Publishes** to both GHCR and Docker Hub
|
||||
5. **Tags** with version and latest
|
||||
|
||||
### Manual Publishing
|
||||
|
||||
```bash
|
||||
# Build and tag locally
|
||||
docker build -t ghcr.io/significant-gravitas/autogpt-platform-backend:latest \
|
||||
-f autogpt_platform/backend/Dockerfile \
|
||||
--target server .
|
||||
|
||||
docker build -t ghcr.io/significant-gravitas/autogpt-platform-frontend:latest \
|
||||
-f autogpt_platform/frontend/Dockerfile \
|
||||
--target prod .
|
||||
|
||||
# Push to registry
|
||||
docker push ghcr.io/significant-gravitas/autogpt-platform-backend:latest
|
||||
docker push ghcr.io/significant-gravitas/autogpt-platform-frontend:latest
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### Container Security
|
||||
|
||||
1. **Non-root users** - Containers run as non-root
|
||||
2. **Minimal base images** - Using slim/alpine images
|
||||
3. **No secrets in images** - All secrets via environment variables
|
||||
4. **Read-only filesystem** - Where possible
|
||||
5. **Resource limits** - CPU and memory limits set
|
||||
|
||||
### Deployment Security
|
||||
|
||||
1. **Network isolation** - Use dedicated networks
|
||||
2. **TLS encryption** - Enable HTTPS in production
|
||||
3. **Secret management** - Use Docker secrets or external secret stores
|
||||
4. **Regular updates** - Keep images updated
|
||||
5. **Vulnerability scanning** - Regular security scans
|
||||
|
||||
## Monitoring
|
||||
|
||||
### Health Checks
|
||||
|
||||
All containers include health checks:
|
||||
|
||||
```bash
|
||||
# Check container health
|
||||
docker inspect --format='{{.State.Health.Status}}' container_name
|
||||
|
||||
# Manual health check
|
||||
curl http://localhost:8000/health
|
||||
```
|
||||
|
||||
### Metrics
|
||||
|
||||
The backend exposes Prometheus metrics at `/metrics`:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8000/metrics
|
||||
```
|
||||
|
||||
### Logging
|
||||
|
||||
Containers log to stdout/stderr for easy aggregation:
|
||||
|
||||
```bash
|
||||
# View logs
|
||||
docker logs container_name
|
||||
|
||||
# Follow logs
|
||||
docker logs -f container_name
|
||||
|
||||
# Aggregate logs
|
||||
docker compose logs -f
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Container won't start**
|
||||
```bash
|
||||
# Check logs
|
||||
docker logs container_name
|
||||
|
||||
# Check environment
|
||||
docker exec container_name env
|
||||
```
|
||||
|
||||
2. **Database connection failed**
|
||||
```bash
|
||||
# Test connectivity
|
||||
docker exec backend ping postgres
|
||||
|
||||
# Check database status
|
||||
docker exec postgres pg_isready
|
||||
```
|
||||
|
||||
3. **Port conflicts**
|
||||
```bash
|
||||
# Check port usage
|
||||
ss -tuln | grep :3000
|
||||
|
||||
# Use different ports
|
||||
docker run -p 3001:3000 ...
|
||||
```
|
||||
|
||||
### Debug Mode
|
||||
|
||||
Enable debug mode for detailed logging:
|
||||
|
||||
```env
|
||||
LOG_LEVEL=DEBUG
|
||||
ENABLE_DEBUG=true
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Resource Limits
|
||||
|
||||
```yaml
|
||||
# Docker Compose
|
||||
services:
|
||||
backend:
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
memory: 2G
|
||||
cpus: '1.0'
|
||||
reservations:
|
||||
memory: 1G
|
||||
cpus: '0.5'
|
||||
```
|
||||
|
||||
### Scaling
|
||||
|
||||
```bash
|
||||
# Scale backend services
|
||||
docker compose up -d --scale backend=3
|
||||
|
||||
# Or use Docker Swarm
|
||||
docker service scale backend=3
|
||||
```
|
||||
|
||||
## Backup and Recovery
|
||||
|
||||
### Data Backup
|
||||
|
||||
```bash
|
||||
# Database backup
|
||||
docker exec postgres pg_dump -U autogpt autogpt > backup.sql
|
||||
|
||||
# Volume backup
|
||||
docker run --rm -v postgres_data:/data -v $(pwd):/backup \
|
||||
alpine tar czf /backup/postgres_backup.tar.gz /data
|
||||
```
|
||||
|
||||
### Restore
|
||||
|
||||
```bash
|
||||
# Database restore
|
||||
docker exec -i postgres psql -U autogpt autogpt < backup.sql
|
||||
|
||||
# Volume restore
|
||||
docker run --rm -v postgres_data:/data -v $(pwd):/backup \
|
||||
alpine tar xzf /backup/postgres_backup.tar.gz -C /
|
||||
```
|
||||
|
||||
## Support
|
||||
|
||||
- **Documentation**: [Platform Docs](../docs/content/platform/)
|
||||
- **Issues**: [GitHub Issues](https://github.com/Significant-Gravitas/AutoGPT/issues)
|
||||
- **Discord**: [AutoGPT Community](https://discord.gg/autogpt)
|
||||
- **Docker Hub**: [Container Registry](https://hub.docker.com/r/significantgravitas/)
|
||||
|
||||
## Contributing
|
||||
|
||||
To contribute to the container infrastructure:
|
||||
|
||||
1. **Test locally** with `docker build` and `docker run`
|
||||
2. **Update documentation** if making changes
|
||||
3. **Test deployment scripts** on your platform
|
||||
4. **Submit PR** with clear description of changes
|
||||
|
||||
## Roadmap
|
||||
|
||||
- [ ] ARM64 support for Apple Silicon
|
||||
- [ ] Helm charts for Kubernetes
|
||||
- [ ] Official Unraid template
|
||||
- [ ] Home Assistant Add-on store submission
|
||||
- [ ] Multi-stage builds optimization
|
||||
- [ ] Security scanning integration
|
||||
- [ ] Performance benchmarking
|
||||
@@ -2,38 +2,16 @@
|
||||
|
||||
Welcome to the AutoGPT Platform - a powerful system for creating and running AI agents to solve business problems. This platform enables you to harness the power of artificial intelligence to automate tasks, analyze data, and generate insights for your organization.
|
||||
|
||||
## Deployment Options
|
||||
|
||||
### Quick Deploy with Published Containers (Recommended)
|
||||
|
||||
The fastest way to get started is using our pre-built containers:
|
||||
|
||||
```bash
|
||||
# Download and run with published images
|
||||
curl -fsSL https://raw.githubusercontent.com/Significant-Gravitas/AutoGPT/master/autogpt_platform/deploy.sh -o deploy.sh
|
||||
chmod +x deploy.sh
|
||||
./deploy.sh deploy
|
||||
```
|
||||
|
||||
Access the platform at http://localhost:3000 after deployment completes.
|
||||
|
||||
### Platform-Specific Deployments
|
||||
|
||||
- **Unraid**: [Deployment Guide](../docs/content/platform/deployment/unraid.md)
|
||||
- **Home Assistant**: [Add-on Guide](../docs/content/platform/deployment/home-assistant.md)
|
||||
- **Kubernetes**: [K8s Deployment](../docs/content/platform/deployment/kubernetes.md)
|
||||
- **General Containers**: [Container Guide](../docs/content/platform/container-deployment.md)
|
||||
|
||||
## Development Setup
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Docker
|
||||
- Docker Compose V2 (comes with Docker Desktop, or can be installed separately)
|
||||
|
||||
### Running from Source
|
||||
### Running the System
|
||||
|
||||
To run the AutoGPT Platform from source for development:
|
||||
To run the AutoGPT Platform, follow these steps:
|
||||
|
||||
1. Clone this repository to your local machine and navigate to the `autogpt_platform` directory within the repository:
|
||||
|
||||
@@ -179,28 +157,3 @@ If you need to update the API client after making changes to the backend API:
|
||||
```
|
||||
|
||||
This will fetch the latest OpenAPI specification and regenerate the TypeScript client code.
|
||||
|
||||
## Container Deployment
|
||||
|
||||
For production deployments and specific platforms, see our container deployment guides:
|
||||
|
||||
- **[Container Deployment Overview](CONTAINERS.md)** - Complete guide to using published containers
|
||||
- **[Deployment Script](deploy.sh)** - Automated deployment and management tool
|
||||
- **[Published Images](docker-compose.published.yml)** - Docker Compose for published containers
|
||||
|
||||
### Published Container Images
|
||||
|
||||
- **Backend**: `ghcr.io/significant-gravitas/autogpt-platform-backend:latest`
|
||||
- **Frontend**: `ghcr.io/significant-gravitas/autogpt-platform-frontend:latest`
|
||||
|
||||
### Quick Production Deployment
|
||||
|
||||
```bash
|
||||
# Deploy with published containers
|
||||
./deploy.sh deploy
|
||||
|
||||
# Or use the published compose file directly
|
||||
docker compose -f docker-compose.published.yml up -d
|
||||
```
|
||||
|
||||
For detailed deployment instructions, troubleshooting, and platform-specific guides, see the [Container Documentation](CONTAINERS.md).
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
class APIKeyContainer(NamedTuple):
|
||||
"""Container for API key parts."""
|
||||
|
||||
raw: str
|
||||
prefix: str
|
||||
postfix: str
|
||||
hash: str
|
||||
|
||||
|
||||
class APIKeyManager:
|
||||
PREFIX: str = "agpt_"
|
||||
PREFIX_LENGTH: int = 8
|
||||
POSTFIX_LENGTH: int = 8
|
||||
|
||||
def generate_api_key(self) -> APIKeyContainer:
|
||||
"""Generate a new API key with all its parts."""
|
||||
raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
|
||||
return APIKeyContainer(
|
||||
raw=raw_key,
|
||||
prefix=raw_key[: self.PREFIX_LENGTH],
|
||||
postfix=raw_key[-self.POSTFIX_LENGTH :],
|
||||
hash=hashlib.sha256(raw_key.encode()).hexdigest(),
|
||||
)
|
||||
|
||||
def verify_api_key(self, provided_key: str, stored_hash: str) -> bool:
|
||||
"""Verify if a provided API key matches the stored hash."""
|
||||
if not provided_key.startswith(self.PREFIX):
|
||||
return False
|
||||
provided_hash = hashlib.sha256(provided_key.encode()).hexdigest()
|
||||
return secrets.compare_digest(provided_hash, stored_hash)
|
||||
@@ -1,78 +0,0 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
from typing import NamedTuple
|
||||
|
||||
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
|
||||
|
||||
|
||||
class APIKeyContainer(NamedTuple):
|
||||
"""Container for API key parts."""
|
||||
|
||||
key: str
|
||||
head: str
|
||||
tail: str
|
||||
hash: str
|
||||
salt: str
|
||||
|
||||
|
||||
class APIKeySmith:
|
||||
PREFIX: str = "agpt_"
|
||||
HEAD_LENGTH: int = 8
|
||||
TAIL_LENGTH: int = 8
|
||||
|
||||
def generate_key(self) -> APIKeyContainer:
|
||||
"""Generate a new API key with secure hashing."""
|
||||
raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
|
||||
hash, salt = self.hash_key(raw_key)
|
||||
|
||||
return APIKeyContainer(
|
||||
key=raw_key,
|
||||
head=raw_key[: self.HEAD_LENGTH],
|
||||
tail=raw_key[-self.TAIL_LENGTH :],
|
||||
hash=hash,
|
||||
salt=salt,
|
||||
)
|
||||
|
||||
def verify_key(
|
||||
self, provided_key: str, known_hash: str, known_salt: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Verify an API key against a known hash (+ salt).
|
||||
Supports verifying both legacy SHA256 and secure Scrypt hashes.
|
||||
"""
|
||||
if not provided_key.startswith(self.PREFIX):
|
||||
return False
|
||||
|
||||
# Handle legacy SHA256 hashes (migration support)
|
||||
if known_salt is None:
|
||||
legacy_hash = hashlib.sha256(provided_key.encode()).hexdigest()
|
||||
return secrets.compare_digest(legacy_hash, known_hash)
|
||||
|
||||
try:
|
||||
salt_bytes = bytes.fromhex(known_salt)
|
||||
provided_hash = self._hash_key_with_salt(provided_key, salt_bytes)
|
||||
return secrets.compare_digest(provided_hash, known_hash)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
def hash_key(self, raw_key: str) -> tuple[str, str]:
|
||||
"""Migrate a legacy hash to secure hash format."""
|
||||
salt = self._generate_salt()
|
||||
hash = self._hash_key_with_salt(raw_key, salt)
|
||||
return hash, salt.hex()
|
||||
|
||||
def _generate_salt(self) -> bytes:
|
||||
"""Generate a random salt for hashing."""
|
||||
return secrets.token_bytes(32)
|
||||
|
||||
def _hash_key_with_salt(self, raw_key: str, salt: bytes) -> str:
|
||||
"""Hash API key using Scrypt with salt."""
|
||||
kdf = Scrypt(
|
||||
length=32,
|
||||
salt=salt,
|
||||
n=2**14, # CPU/memory cost parameter
|
||||
r=8, # Block size parameter
|
||||
p=1, # Parallelization parameter
|
||||
)
|
||||
key_hash = kdf.derive(raw_key.encode())
|
||||
return key_hash.hex()
|
||||
@@ -1,79 +0,0 @@
|
||||
import hashlib
|
||||
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
|
||||
|
||||
def test_generate_api_key():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
assert key.key.startswith(keysmith.PREFIX)
|
||||
assert key.head == key.key[: keysmith.HEAD_LENGTH]
|
||||
assert key.tail == key.key[-keysmith.TAIL_LENGTH :]
|
||||
assert len(key.hash) == 64 # 32 bytes hex encoded
|
||||
assert len(key.salt) == 64 # 32 bytes hex encoded
|
||||
|
||||
|
||||
def test_verify_new_secure_key():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Test correct key validates
|
||||
assert keysmith.verify_key(key.key, key.hash, key.salt) is True
|
||||
|
||||
# Test wrong key fails
|
||||
wrong_key = f"{keysmith.PREFIX}wrongkey123"
|
||||
assert keysmith.verify_key(wrong_key, key.hash, key.salt) is False
|
||||
|
||||
|
||||
def test_verify_legacy_key():
|
||||
keysmith = APIKeySmith()
|
||||
legacy_key = f"{keysmith.PREFIX}legacykey123"
|
||||
legacy_hash = hashlib.sha256(legacy_key.encode()).hexdigest()
|
||||
|
||||
# Test legacy key validates without salt
|
||||
assert keysmith.verify_key(legacy_key, legacy_hash) is True
|
||||
|
||||
# Test wrong legacy key fails
|
||||
wrong_key = f"{keysmith.PREFIX}wronglegacy"
|
||||
assert keysmith.verify_key(wrong_key, legacy_hash) is False
|
||||
|
||||
|
||||
def test_rehash_existing_key():
|
||||
keysmith = APIKeySmith()
|
||||
legacy_key = f"{keysmith.PREFIX}migratekey123"
|
||||
|
||||
# Migrate the legacy key
|
||||
new_hash, new_salt = keysmith.hash_key(legacy_key)
|
||||
|
||||
# Verify migrated key works
|
||||
assert keysmith.verify_key(legacy_key, new_hash, new_salt) is True
|
||||
|
||||
# Verify different key fails with migrated hash
|
||||
wrong_key = f"{keysmith.PREFIX}wrongkey"
|
||||
assert keysmith.verify_key(wrong_key, new_hash, new_salt) is False
|
||||
|
||||
|
||||
def test_invalid_key_prefix():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Test key without proper prefix fails
|
||||
invalid_key = "invalid_prefix_key"
|
||||
assert keysmith.verify_key(invalid_key, key.hash, key.salt) is False
|
||||
|
||||
|
||||
def test_secure_hash_requires_salt():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Secure hash without salt should fail
|
||||
assert keysmith.verify_key(key.key, key.hash) is False
|
||||
|
||||
|
||||
def test_invalid_salt_format():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Invalid salt format should fail gracefully
|
||||
assert keysmith.verify_key(key.key, key.hash, "invalid_hex") is False
|
||||
76
autogpt_platform/autogpt_libs/poetry.lock
generated
76
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -1002,18 +1002,6 @@ dynamodb = ["boto3 (>=1.9.71)"]
|
||||
redis = ["redis (>=2.10.5)"]
|
||||
test-filesource = ["pyyaml (>=5.3.1)", "watchdog (>=3.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "nodeenv"
|
||||
version = "1.9.1"
|
||||
description = "Node.js virtual environment builder"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"},
|
||||
{file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-api"
|
||||
version = "1.35.0"
|
||||
@@ -1359,27 +1347,6 @@ files = [
|
||||
{file = "pyrfc3339-2.0.1.tar.gz", hash = "sha256:e47843379ea35c1296c3b6c67a948a1a490ae0584edfcbdea0eaffb5dd29960b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyright"
|
||||
version = "1.1.404"
|
||||
description = "Command line wrapper for pyright"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pyright-1.1.404-py3-none-any.whl", hash = "sha256:c7b7ff1fdb7219c643079e4c3e7d4125f0dafcc19d253b47e898d130ea426419"},
|
||||
{file = "pyright-1.1.404.tar.gz", hash = "sha256:455e881a558ca6be9ecca0b30ce08aa78343ecc031d37a198ffa9a7a1abeb63e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
nodeenv = ">=1.6.0"
|
||||
typing-extensions = ">=4.1"
|
||||
|
||||
[package.extras]
|
||||
all = ["nodejs-wheel-binaries", "twine (>=3.4.1)"]
|
||||
dev = ["twine (>=3.4.1)"]
|
||||
nodejs = ["nodejs-wheel-binaries"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.4.1"
|
||||
@@ -1567,31 +1534,31 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.12.11"
|
||||
version = "0.12.9"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "ruff-0.12.11-py3-none-linux_armv6l.whl", hash = "sha256:93fce71e1cac3a8bf9200e63a38ac5c078f3b6baebffb74ba5274fb2ab276065"},
|
||||
{file = "ruff-0.12.11-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b8e33ac7b28c772440afa80cebb972ffd823621ded90404f29e5ab6d1e2d4b93"},
|
||||
{file = "ruff-0.12.11-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d69fb9d4937aa19adb2e9f058bc4fbfe986c2040acb1a4a9747734834eaa0bfd"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:411954eca8464595077a93e580e2918d0a01a19317af0a72132283e28ae21bee"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6a2c0a2e1a450f387bf2c6237c727dd22191ae8c00e448e0672d624b2bbd7fb0"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ca4c3a7f937725fd2413c0e884b5248a19369ab9bdd850b5781348ba283f644"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4d1df0098124006f6a66ecf3581a7f7e754c4df7644b2e6704cd7ca80ff95211"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a8dd5f230efc99a24ace3b77e3555d3fbc0343aeed3fc84c8d89e75ab2ff793"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4dc75533039d0ed04cd33fb8ca9ac9620b99672fe7ff1533b6402206901c34ee"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4fc58f9266d62c6eccc75261a665f26b4ef64840887fc6cbc552ce5b29f96cc8"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:5a0113bd6eafd545146440225fe60b4e9489f59eb5f5f107acd715ba5f0b3d2f"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0d737b4059d66295c3ea5720e6efc152623bb83fde5444209b69cd33a53e2000"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:916fc5defee32dbc1fc1650b576a8fed68f5e8256e2180d4d9855aea43d6aab2"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c984f07d7adb42d3ded5be894fb4007f30f82c87559438b4879fe7aa08c62b39"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e07fbb89f2e9249f219d88331c833860489b49cdf4b032b8e4432e9b13e8a4b9"},
|
||||
{file = "ruff-0.12.11-py3-none-win32.whl", hash = "sha256:c792e8f597c9c756e9bcd4d87cf407a00b60af77078c96f7b6366ea2ce9ba9d3"},
|
||||
{file = "ruff-0.12.11-py3-none-win_amd64.whl", hash = "sha256:a3283325960307915b6deb3576b96919ee89432ebd9c48771ca12ee8afe4a0fd"},
|
||||
{file = "ruff-0.12.11-py3-none-win_arm64.whl", hash = "sha256:bae4d6e6a2676f8fb0f98b74594a048bae1b944aab17e9f5d504062303c6dbea"},
|
||||
{file = "ruff-0.12.11.tar.gz", hash = "sha256:c6b09ae8426a65bbee5425b9d0b82796dbb07cb1af045743c79bfb163001165d"},
|
||||
{file = "ruff-0.12.9-py3-none-linux_armv6l.whl", hash = "sha256:fcebc6c79fcae3f220d05585229463621f5dbf24d79fdc4936d9302e177cfa3e"},
|
||||
{file = "ruff-0.12.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:aed9d15f8c5755c0e74467731a007fcad41f19bcce41cd75f768bbd687f8535f"},
|
||||
{file = "ruff-0.12.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5b15ea354c6ff0d7423814ba6d44be2807644d0c05e9ed60caca87e963e93f70"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d596c2d0393c2502eaabfef723bd74ca35348a8dac4267d18a94910087807c53"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1b15599931a1a7a03c388b9c5df1bfa62be7ede6eb7ef753b272381f39c3d0ff"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3d02faa2977fb6f3f32ddb7828e212b7dd499c59eb896ae6c03ea5c303575756"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:17d5b6b0b3a25259b69ebcba87908496e6830e03acfb929ef9fd4c58675fa2ea"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:72db7521860e246adbb43f6ef464dd2a532ef2ef1f5dd0d470455b8d9f1773e0"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a03242c1522b4e0885af63320ad754d53983c9599157ee33e77d748363c561ce"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fc83e4e9751e6c13b5046d7162f205d0a7bac5840183c5beebf824b08a27340"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:881465ed56ba4dd26a691954650de6ad389a2d1fdb130fe51ff18a25639fe4bb"},
|
||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:43f07a3ccfc62cdb4d3a3348bf0588358a66da756aa113e071b8ca8c3b9826af"},
|
||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:07adb221c54b6bba24387911e5734357f042e5669fa5718920ee728aba3cbadc"},
|
||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f5cd34fabfdea3933ab85d72359f118035882a01bff15bd1d2b15261d85d5f66"},
|
||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:f6be1d2ca0686c54564da8e7ee9e25f93bdd6868263805f8c0b8fc6a449db6d7"},
|
||||
{file = "ruff-0.12.9-py3-none-win32.whl", hash = "sha256:cc7a37bd2509974379d0115cc5608a1a4a6c4bff1b452ea69db83c8855d53f93"},
|
||||
{file = "ruff-0.12.9-py3-none-win_amd64.whl", hash = "sha256:6fb15b1977309741d7d098c8a3cb7a30bc112760a00fb6efb7abc85f00ba5908"},
|
||||
{file = "ruff-0.12.9-py3-none-win_arm64.whl", hash = "sha256:63c8c819739d86b96d500cce885956a1a48ab056bbcbc61b747ad494b2485089"},
|
||||
{file = "ruff-0.12.9.tar.gz", hash = "sha256:fbd94b2e3c623f659962934e52c2bea6fc6da11f667a427a368adaf3af2c866a"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1773,6 +1740,7 @@ files = [
|
||||
{file = "typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76"},
|
||||
{file = "typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36"},
|
||||
]
|
||||
markers = {dev = "python_version < \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "typing-inspection"
|
||||
@@ -1929,4 +1897,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "0c40b63c3c921846cf05ccfb4e685d4959854b29c2c302245f9832e20aac6954"
|
||||
content-hash = "ef7818fba061cea2841c6d7ca4852acde83e4f73b32fca1315e58660002bb0d0"
|
||||
|
||||
@@ -9,7 +9,6 @@ packages = [{ include = "autogpt_libs" }]
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<4.0"
|
||||
colorama = "^0.4.6"
|
||||
cryptography = "^45.0"
|
||||
expiringdict = "^1.2.2"
|
||||
fastapi = "^0.116.1"
|
||||
google-cloud-logging = "^3.12.1"
|
||||
@@ -22,12 +21,11 @@ supabase = "^2.16.0"
|
||||
uvicorn = "^0.35.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pyright = "^1.1.404"
|
||||
ruff = "^0.12.9"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
pytest-mock = "^3.14.1"
|
||||
pytest-cov = "^6.2.1"
|
||||
ruff = "^0.12.11"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -10,7 +12,7 @@ from backend.data.block import (
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.execution import ExecutionStatus, NodesInputMasks
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util.json import validate_with_jsonschema
|
||||
from backend.util.retry import func_retry
|
||||
@@ -31,7 +33,7 @@ class AgentExecutorBlock(Block):
|
||||
input_schema: dict = SchemaField(description="Input schema for the graph")
|
||||
output_schema: dict = SchemaField(description="Output schema for the graph")
|
||||
|
||||
nodes_input_masks: Optional[NodesInputMasks] = SchemaField(
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = SchemaField(
|
||||
default=None, hidden=True
|
||||
)
|
||||
|
||||
|
||||
@@ -1,154 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
from replicate.client import Client as ReplicateClient
|
||||
from replicate.helpers import FileOutput
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.file import MediaFileType
|
||||
|
||||
|
||||
class GeminiImageModel(str, Enum):
|
||||
NANO_BANANA = "google/nano-banana"
|
||||
|
||||
|
||||
class OutputFormat(str, Enum):
|
||||
JPG = "jpg"
|
||||
PNG = "png"
|
||||
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="replicate",
|
||||
api_key=SecretStr("mock-replicate-api-key"),
|
||||
title="Mock Replicate API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class AIImageCustomizerBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REPLICATE], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Replicate API key with permissions for Google Gemini image models",
|
||||
)
|
||||
prompt: str = SchemaField(
|
||||
description="A text description of the image you want to generate",
|
||||
title="Prompt",
|
||||
)
|
||||
model: GeminiImageModel = SchemaField(
|
||||
description="The AI model to use for image generation and editing",
|
||||
default=GeminiImageModel.NANO_BANANA,
|
||||
title="Model",
|
||||
)
|
||||
images: list[MediaFileType] = SchemaField(
|
||||
description="Optional list of input images to reference or modify",
|
||||
default=[],
|
||||
title="Input Images",
|
||||
)
|
||||
output_format: OutputFormat = SchemaField(
|
||||
description="Format of the output image",
|
||||
default=OutputFormat.PNG,
|
||||
title="Output Format",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
image_url: MediaFileType = SchemaField(description="URL of the generated image")
|
||||
error: str = SchemaField(description="Error message if generation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d76bbe4c-930e-4894-8469-b66775511f71",
|
||||
description=(
|
||||
"Generate and edit custom images using Google's Nano-Banana model from Gemini 2.5. "
|
||||
"Provide a prompt and optional reference images to create or modify images."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
input_schema=AIImageCustomizerBlock.Input,
|
||||
output_schema=AIImageCustomizerBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Make the scene more vibrant and colorful",
|
||||
"model": GeminiImageModel.NANO_BANANA,
|
||||
"images": [],
|
||||
"output_format": OutputFormat.JPG,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
("image_url", "https://replicate.delivery/generated-image.jpg"),
|
||||
],
|
||||
test_mock={
|
||||
"run_model": lambda *args, **kwargs: MediaFileType(
|
||||
"https://replicate.delivery/generated-image.jpg"
|
||||
),
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
result = await self.run_model(
|
||||
api_key=credentials.api_key,
|
||||
model_name=input_data.model.value,
|
||||
prompt=input_data.prompt,
|
||||
images=input_data.images,
|
||||
output_format=input_data.output_format.value,
|
||||
)
|
||||
yield "image_url", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
async def run_model(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
images: list[MediaFileType],
|
||||
output_format: str,
|
||||
) -> MediaFileType:
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
|
||||
input_params: dict = {
|
||||
"prompt": prompt,
|
||||
"output_format": output_format,
|
||||
}
|
||||
|
||||
# Add images to input if provided (API expects "image_input" parameter)
|
||||
if images:
|
||||
input_params["image_input"] = [str(img) for img in images]
|
||||
|
||||
output: FileOutput | str = await client.async_run( # type: ignore
|
||||
model_name,
|
||||
input=input_params,
|
||||
wait=False,
|
||||
)
|
||||
|
||||
if isinstance(output, FileOutput):
|
||||
return MediaFileType(output.url)
|
||||
if isinstance(output, str):
|
||||
return MediaFileType(output)
|
||||
|
||||
raise ValueError("No output received from the model")
|
||||
@@ -661,167 +661,6 @@ async def update_field(
|
||||
#################################################################
|
||||
|
||||
|
||||
async def get_table_schema(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
table_id_or_name: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Get the schema for a specific table, including all field definitions.
|
||||
|
||||
Args:
|
||||
credentials: Airtable API credentials
|
||||
base_id: The base ID
|
||||
table_id_or_name: The table ID or name
|
||||
|
||||
Returns:
|
||||
Dict containing table schema with fields information
|
||||
"""
|
||||
# First get all tables to find the right one
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
tables = data.get("tables", [])
|
||||
|
||||
# Find the matching table
|
||||
for table in tables:
|
||||
if table.get("id") == table_id_or_name or table.get("name") == table_id_or_name:
|
||||
return table
|
||||
|
||||
raise ValueError(f"Table '{table_id_or_name}' not found in base '{base_id}'")
|
||||
|
||||
|
||||
def get_empty_value_for_field(field_type: str) -> Any:
|
||||
"""
|
||||
Return the appropriate empty value for a given Airtable field type.
|
||||
|
||||
Args:
|
||||
field_type: The Airtable field type
|
||||
|
||||
Returns:
|
||||
The appropriate empty value for that field type
|
||||
"""
|
||||
# Fields that should be false when empty
|
||||
if field_type == "checkbox":
|
||||
return False
|
||||
|
||||
# Fields that should be empty arrays
|
||||
if field_type in [
|
||||
"multipleSelects",
|
||||
"multipleRecordLinks",
|
||||
"multipleAttachments",
|
||||
"multipleLookupValues",
|
||||
"multipleCollaborators",
|
||||
]:
|
||||
return []
|
||||
|
||||
# Fields that should be 0 when empty (numeric types)
|
||||
if field_type in [
|
||||
"number",
|
||||
"percent",
|
||||
"currency",
|
||||
"rating",
|
||||
"duration",
|
||||
"count",
|
||||
"autoNumber",
|
||||
]:
|
||||
return 0
|
||||
|
||||
# Fields that should be empty strings
|
||||
if field_type in [
|
||||
"singleLineText",
|
||||
"multilineText",
|
||||
"email",
|
||||
"url",
|
||||
"phoneNumber",
|
||||
"richText",
|
||||
"barcode",
|
||||
]:
|
||||
return ""
|
||||
|
||||
# Everything else gets null (dates, single selects, formulas, etc.)
|
||||
return None
|
||||
|
||||
|
||||
async def normalize_records(
|
||||
records: list[dict],
|
||||
table_schema: dict,
|
||||
include_field_metadata: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Normalize Airtable records to include all fields with proper empty values.
|
||||
|
||||
Args:
|
||||
records: List of record objects from Airtable API
|
||||
table_schema: Table schema containing field definitions
|
||||
include_field_metadata: Whether to include field metadata in response
|
||||
|
||||
Returns:
|
||||
Dict with normalized records and optionally field metadata
|
||||
"""
|
||||
fields = table_schema.get("fields", [])
|
||||
|
||||
# Normalize each record
|
||||
normalized_records = []
|
||||
for record in records:
|
||||
normalized = {
|
||||
"id": record.get("id"),
|
||||
"createdTime": record.get("createdTime"),
|
||||
"fields": {},
|
||||
}
|
||||
|
||||
# Add existing fields
|
||||
existing_fields = record.get("fields", {})
|
||||
|
||||
# Add all fields from schema, using empty values for missing ones
|
||||
for field in fields:
|
||||
field_name = field["name"]
|
||||
field_type = field["type"]
|
||||
|
||||
if field_name in existing_fields:
|
||||
# Field exists, use its value
|
||||
normalized["fields"][field_name] = existing_fields[field_name]
|
||||
else:
|
||||
# Field is missing, add appropriate empty value
|
||||
normalized["fields"][field_name] = get_empty_value_for_field(field_type)
|
||||
|
||||
normalized_records.append(normalized)
|
||||
|
||||
# Build result dictionary
|
||||
if include_field_metadata:
|
||||
field_metadata = {}
|
||||
for field in fields:
|
||||
metadata = {"type": field["type"], "id": field["id"]}
|
||||
|
||||
# Add type-specific metadata
|
||||
options = field.get("options", {})
|
||||
if field["type"] == "currency" and "symbol" in options:
|
||||
metadata["symbol"] = options["symbol"]
|
||||
metadata["precision"] = options.get("precision", 2)
|
||||
elif field["type"] == "duration" and "durationFormat" in options:
|
||||
metadata["format"] = options["durationFormat"]
|
||||
elif field["type"] == "percent" and "precision" in options:
|
||||
metadata["precision"] = options["precision"]
|
||||
elif (
|
||||
field["type"] in ["singleSelect", "multipleSelects"]
|
||||
and "choices" in options
|
||||
):
|
||||
metadata["choices"] = [choice["name"] for choice in options["choices"]]
|
||||
elif field["type"] == "rating" and "max" in options:
|
||||
metadata["max"] = options["max"]
|
||||
metadata["icon"] = options.get("icon", "star")
|
||||
metadata["color"] = options.get("color", "yellowBright")
|
||||
|
||||
field_metadata[field["name"]] = metadata
|
||||
|
||||
return {"records": normalized_records, "field_metadata": field_metadata}
|
||||
else:
|
||||
return {"records": normalized_records}
|
||||
|
||||
|
||||
async def list_records(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
@@ -1410,26 +1249,3 @@ async def list_bases(
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
|
||||
async def get_base_tables(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Get all tables for a specific base.
|
||||
|
||||
Args:
|
||||
credentials: Airtable API credentials
|
||||
base_id: The ID of the base
|
||||
|
||||
Returns:
|
||||
list[dict]: List of table objects with their schemas
|
||||
"""
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
return data.get("tables", [])
|
||||
|
||||
@@ -14,13 +14,13 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import create_base, get_base_tables, list_bases
|
||||
from ._api import create_base, list_bases
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableCreateBaseBlock(Block):
|
||||
"""
|
||||
Creates a new base in an Airtable workspace, or returns existing base if one with the same name exists.
|
||||
Creates a new base in an Airtable workspace.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
@@ -31,10 +31,6 @@ class AirtableCreateBaseBlock(Block):
|
||||
description="The workspace ID where the base will be created"
|
||||
)
|
||||
name: str = SchemaField(description="The name of the new base")
|
||||
find_existing: bool = SchemaField(
|
||||
description="If true, return existing base with same name instead of creating duplicate",
|
||||
default=True,
|
||||
)
|
||||
tables: list[dict] = SchemaField(
|
||||
description="At least one table and field must be specified. Array of table objects to create in the base. Each table should have 'name' and 'fields' properties",
|
||||
default=[
|
||||
@@ -54,18 +50,14 @@ class AirtableCreateBaseBlock(Block):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
base_id: str = SchemaField(description="The ID of the created or found base")
|
||||
base_id: str = SchemaField(description="The ID of the created base")
|
||||
tables: list[dict] = SchemaField(description="Array of table objects")
|
||||
table: dict = SchemaField(description="A single table object")
|
||||
was_created: bool = SchemaField(
|
||||
description="True if a new base was created, False if existing was found",
|
||||
default=True,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f59b88a8-54ce-4676-a508-fd614b4e8dce",
|
||||
description="Create or find a base in Airtable",
|
||||
description="Create a new base in Airtable",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
@@ -74,31 +66,6 @@ class AirtableCreateBaseBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# If find_existing is true, check if a base with this name already exists
|
||||
if input_data.find_existing:
|
||||
# List all bases to check for existing one with same name
|
||||
# Note: Airtable API doesn't have a direct search, so we need to list and filter
|
||||
existing_bases = await list_bases(credentials)
|
||||
|
||||
for base in existing_bases.get("bases", []):
|
||||
if base.get("name") == input_data.name:
|
||||
# Base already exists, return it
|
||||
base_id = base.get("id")
|
||||
yield "base_id", base_id
|
||||
yield "was_created", False
|
||||
|
||||
# Get the tables for this base
|
||||
try:
|
||||
tables = await get_base_tables(credentials, base_id)
|
||||
yield "tables", tables
|
||||
for table in tables:
|
||||
yield "table", table
|
||||
except Exception:
|
||||
# If we can't get tables, return empty list
|
||||
yield "tables", []
|
||||
return
|
||||
|
||||
# No existing base found or find_existing is false, create new one
|
||||
data = await create_base(
|
||||
credentials,
|
||||
input_data.workspace_id,
|
||||
@@ -107,7 +74,6 @@ class AirtableCreateBaseBlock(Block):
|
||||
)
|
||||
|
||||
yield "base_id", data.get("id", None)
|
||||
yield "was_created", True
|
||||
yield "tables", data.get("tables", [])
|
||||
for table in data.get("tables", []):
|
||||
yield "table", table
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Airtable record operation blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional, cast
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
@@ -18,9 +18,7 @@ from ._api import (
|
||||
create_record,
|
||||
delete_multiple_records,
|
||||
get_record,
|
||||
get_table_schema,
|
||||
list_records,
|
||||
normalize_records,
|
||||
update_multiple_records,
|
||||
)
|
||||
from ._config import airtable
|
||||
@@ -56,24 +54,12 @@ class AirtableListRecordsBlock(Block):
|
||||
return_fields: list[str] = SchemaField(
|
||||
description="Specific fields to return (comma-separated)", default=[]
|
||||
)
|
||||
normalize_output: bool = SchemaField(
|
||||
description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
|
||||
default=True,
|
||||
)
|
||||
include_field_metadata: bool = SchemaField(
|
||||
description="Include field type and configuration metadata (requires normalize_output=true)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
records: list[dict] = SchemaField(description="Array of record objects")
|
||||
offset: Optional[str] = SchemaField(
|
||||
description="Offset for next page (null if no more records)", default=None
|
||||
)
|
||||
field_metadata: Optional[dict] = SchemaField(
|
||||
description="Field type and configuration metadata (only when include_field_metadata=true)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -87,7 +73,6 @@ class AirtableListRecordsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
data = await list_records(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -103,33 +88,8 @@ class AirtableListRecordsBlock(Block):
|
||||
fields=input_data.return_fields if input_data.return_fields else None,
|
||||
)
|
||||
|
||||
records = data.get("records", [])
|
||||
|
||||
# Normalize output if requested
|
||||
if input_data.normalize_output:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the records
|
||||
normalized_data = await normalize_records(
|
||||
records,
|
||||
table_schema,
|
||||
include_field_metadata=input_data.include_field_metadata,
|
||||
)
|
||||
|
||||
yield "records", normalized_data["records"]
|
||||
yield "offset", data.get("offset", None)
|
||||
|
||||
if (
|
||||
input_data.include_field_metadata
|
||||
and "field_metadata" in normalized_data
|
||||
):
|
||||
yield "field_metadata", normalized_data["field_metadata"]
|
||||
else:
|
||||
yield "records", records
|
||||
yield "offset", data.get("offset", None)
|
||||
yield "records", data.get("records", [])
|
||||
yield "offset", data.get("offset", None)
|
||||
|
||||
|
||||
class AirtableGetRecordBlock(Block):
|
||||
@@ -144,23 +104,11 @@ class AirtableGetRecordBlock(Block):
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name")
|
||||
record_id: str = SchemaField(description="The record ID to retrieve")
|
||||
normalize_output: bool = SchemaField(
|
||||
description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
|
||||
default=True,
|
||||
)
|
||||
include_field_metadata: bool = SchemaField(
|
||||
description="Include field type and configuration metadata (requires normalize_output=true)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
id: str = SchemaField(description="The record ID")
|
||||
fields: dict = SchemaField(description="The record fields")
|
||||
created_time: str = SchemaField(description="The record created time")
|
||||
field_metadata: Optional[dict] = SchemaField(
|
||||
description="Field type and configuration metadata (only when include_field_metadata=true)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -174,7 +122,6 @@ class AirtableGetRecordBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
record = await get_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -182,34 +129,9 @@ class AirtableGetRecordBlock(Block):
|
||||
input_data.record_id,
|
||||
)
|
||||
|
||||
# Normalize output if requested
|
||||
if input_data.normalize_output:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the single record (wrap in list and unwrap result)
|
||||
normalized_data = await normalize_records(
|
||||
[record],
|
||||
table_schema,
|
||||
include_field_metadata=input_data.include_field_metadata,
|
||||
)
|
||||
|
||||
normalized_record = normalized_data["records"][0]
|
||||
yield "id", normalized_record.get("id", None)
|
||||
yield "fields", normalized_record.get("fields", None)
|
||||
yield "created_time", normalized_record.get("createdTime", None)
|
||||
|
||||
if (
|
||||
input_data.include_field_metadata
|
||||
and "field_metadata" in normalized_data
|
||||
):
|
||||
yield "field_metadata", normalized_data["field_metadata"]
|
||||
else:
|
||||
yield "id", record.get("id", None)
|
||||
yield "fields", record.get("fields", None)
|
||||
yield "created_time", record.get("createdTime", None)
|
||||
yield "id", record.get("id", None)
|
||||
yield "fields", record.get("fields", None)
|
||||
yield "created_time", record.get("createdTime", None)
|
||||
|
||||
|
||||
class AirtableCreateRecordsBlock(Block):
|
||||
@@ -226,10 +148,6 @@ class AirtableCreateRecordsBlock(Block):
|
||||
records: list[dict] = SchemaField(
|
||||
description="Array of records to create (each with 'fields' object)"
|
||||
)
|
||||
skip_normalization: bool = SchemaField(
|
||||
description="Skip output normalization to get raw Airtable response (faster but may have missing fields)",
|
||||
default=False,
|
||||
)
|
||||
typecast: bool = SchemaField(
|
||||
description="Automatically convert string values to appropriate types",
|
||||
default=False,
|
||||
@@ -255,7 +173,7 @@ class AirtableCreateRecordsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
# The create_record API expects records in a specific format
|
||||
data = await create_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -264,22 +182,8 @@ class AirtableCreateRecordsBlock(Block):
|
||||
typecast=input_data.typecast if input_data.typecast else None,
|
||||
return_fields_by_field_id=input_data.return_fields_by_field_id,
|
||||
)
|
||||
result_records = cast(list[dict], data.get("records", []))
|
||||
|
||||
# Normalize output unless explicitly disabled
|
||||
if not input_data.skip_normalization and result_records:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the records
|
||||
normalized_data = await normalize_records(
|
||||
result_records, table_schema, include_field_metadata=False
|
||||
)
|
||||
result_records = normalized_data["records"]
|
||||
|
||||
yield "records", result_records
|
||||
yield "records", data.get("records", [])
|
||||
details = data.get("details", None)
|
||||
if details:
|
||||
yield "details", details
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from .text_overlay import BannerbearTextOverlayBlock
|
||||
|
||||
__all__ = ["BannerbearTextOverlayBlock"]
|
||||
@@ -1,8 +0,0 @@
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
bannerbear = (
|
||||
ProviderBuilder("bannerbear")
|
||||
.with_api_key("BANNERBEAR_API_KEY", "Bannerbear API Key")
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
@@ -1,239 +0,0 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import bannerbear
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="bannerbear",
|
||||
api_key=SecretStr("mock-bannerbear-api-key"),
|
||||
title="Mock Bannerbear API Key",
|
||||
)
|
||||
|
||||
|
||||
class TextModification(BlockSchema):
|
||||
name: str = SchemaField(
|
||||
description="The name of the layer to modify in the template"
|
||||
)
|
||||
text: str = SchemaField(description="The text content to add to this layer")
|
||||
color: str = SchemaField(
|
||||
description="Hex color code for the text (e.g., '#FF0000')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
font_family: str = SchemaField(
|
||||
description="Font family to use for the text",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
font_size: int = SchemaField(
|
||||
description="Font size in pixels",
|
||||
default=0,
|
||||
advanced=True,
|
||||
)
|
||||
font_weight: str = SchemaField(
|
||||
description="Font weight (e.g., 'bold', 'normal')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
text_align: str = SchemaField(
|
||||
description="Text alignment (left, center, right)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
|
||||
class BannerbearTextOverlayBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = bannerbear.credentials_field(
|
||||
description="API credentials for Bannerbear"
|
||||
)
|
||||
template_id: str = SchemaField(
|
||||
description="The unique ID of your Bannerbear template"
|
||||
)
|
||||
project_id: str = SchemaField(
|
||||
description="Optional: Project ID (required when using Master API Key)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
text_modifications: List[TextModification] = SchemaField(
|
||||
description="List of text layers to modify in the template"
|
||||
)
|
||||
image_url: str = SchemaField(
|
||||
description="Optional: URL of an image to use in the template",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
image_layer_name: str = SchemaField(
|
||||
description="Optional: Name of the image layer in the template",
|
||||
default="photo",
|
||||
advanced=True,
|
||||
)
|
||||
webhook_url: str = SchemaField(
|
||||
description="Optional: URL to receive webhook notification when image is ready",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
metadata: str = SchemaField(
|
||||
description="Optional: Custom metadata to attach to the image",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the image generation was successfully initiated"
|
||||
)
|
||||
image_url: str = SchemaField(
|
||||
description="URL of the generated image (if synchronous) or placeholder"
|
||||
)
|
||||
uid: str = SchemaField(description="Unique identifier for the generated image")
|
||||
status: str = SchemaField(description="Status of the image generation")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c7d3a5c2-05fc-450e-8dce-3b0e04626009",
|
||||
description="Add text overlay to images using Bannerbear templates. Perfect for creating social media graphics, marketing materials, and dynamic image content.",
|
||||
categories={BlockCategory.PRODUCTIVITY, BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"template_id": "jJWBKNELpQPvbX5R93Gk",
|
||||
"text_modifications": [
|
||||
{
|
||||
"name": "headline",
|
||||
"text": "Amazing Product Launch!",
|
||||
"color": "#FF0000",
|
||||
},
|
||||
{
|
||||
"name": "subtitle",
|
||||
"text": "50% OFF Today Only",
|
||||
},
|
||||
],
|
||||
"credentials": {
|
||||
"provider": "bannerbear",
|
||||
"id": str(uuid.uuid4()),
|
||||
"type": "api_key",
|
||||
},
|
||||
},
|
||||
test_output=[
|
||||
("success", True),
|
||||
("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
|
||||
("uid", "test-uid-123"),
|
||||
("status", "completed"),
|
||||
],
|
||||
test_mock={
|
||||
"_make_api_request": lambda *args, **kwargs: {
|
||||
"uid": "test-uid-123",
|
||||
"status": "completed",
|
||||
"image_url": "https://cdn.bannerbear.com/test-image.jpg",
|
||||
}
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def _make_api_request(self, payload: dict, api_key: str) -> dict:
|
||||
"""Make the actual API request to Bannerbear. This is separated for easy mocking in tests."""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = await Requests().post(
|
||||
"https://sync.api.bannerbear.com/v2/images",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if response.status in [200, 201, 202]:
|
||||
return response.json()
|
||||
else:
|
||||
error_msg = f"API request failed with status {response.status}"
|
||||
if response.text:
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_msg = (
|
||||
f"{error_msg}: {error_data.get('message', response.text)}"
|
||||
)
|
||||
except Exception:
|
||||
error_msg = f"{error_msg}: {response.text}"
|
||||
raise Exception(error_msg)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Build the modifications array
|
||||
modifications = []
|
||||
|
||||
# Add text modifications
|
||||
for text_mod in input_data.text_modifications:
|
||||
mod_data: Dict[str, Any] = {
|
||||
"name": text_mod.name,
|
||||
"text": text_mod.text,
|
||||
}
|
||||
|
||||
# Add optional text styling parameters only if they have values
|
||||
if text_mod.color and text_mod.color.strip():
|
||||
mod_data["color"] = text_mod.color
|
||||
if text_mod.font_family and text_mod.font_family.strip():
|
||||
mod_data["font_family"] = text_mod.font_family
|
||||
if text_mod.font_size and text_mod.font_size > 0:
|
||||
mod_data["font_size"] = text_mod.font_size
|
||||
if text_mod.font_weight and text_mod.font_weight.strip():
|
||||
mod_data["font_weight"] = text_mod.font_weight
|
||||
if text_mod.text_align and text_mod.text_align.strip():
|
||||
mod_data["text_align"] = text_mod.text_align
|
||||
|
||||
modifications.append(mod_data)
|
||||
|
||||
# Add image modification if provided and not empty
|
||||
if input_data.image_url and input_data.image_url.strip():
|
||||
modifications.append(
|
||||
{
|
||||
"name": input_data.image_layer_name,
|
||||
"image_url": input_data.image_url,
|
||||
}
|
||||
)
|
||||
|
||||
# Build the request payload - only include non-empty optional fields
|
||||
payload = {
|
||||
"template": input_data.template_id,
|
||||
"modifications": modifications,
|
||||
}
|
||||
|
||||
# Add project_id if provided (required for Master API keys)
|
||||
if input_data.project_id and input_data.project_id.strip():
|
||||
payload["project_id"] = input_data.project_id
|
||||
|
||||
if input_data.webhook_url and input_data.webhook_url.strip():
|
||||
payload["webhook_url"] = input_data.webhook_url
|
||||
if input_data.metadata and input_data.metadata.strip():
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
# Make the API request using the private method
|
||||
data = await self._make_api_request(
|
||||
payload, credentials.api_key.get_secret_value()
|
||||
)
|
||||
|
||||
# Synchronous request - image should be ready
|
||||
yield "success", True
|
||||
yield "image_url", data.get("image_url", "")
|
||||
yield "uid", data.get("uid", "")
|
||||
yield "status", data.get("status", "completed")
|
||||
@@ -1094,117 +1094,6 @@ class GmailGetThreadBlock(GmailBase):
|
||||
return thread
|
||||
|
||||
|
||||
async def _build_reply_message(
|
||||
service, input_data, graph_exec_id: str, user_id: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Builds a reply MIME message for Gmail threads.
|
||||
|
||||
Returns:
|
||||
tuple: (base64-encoded raw message, threadId)
|
||||
"""
|
||||
# Get parent message for reply context
|
||||
parent = await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.messages()
|
||||
.get(
|
||||
userId="me",
|
||||
id=input_data.parentMessageId,
|
||||
format="metadata",
|
||||
metadataHeaders=[
|
||||
"Subject",
|
||||
"References",
|
||||
"Message-ID",
|
||||
"From",
|
||||
"To",
|
||||
"Cc",
|
||||
"Reply-To",
|
||||
],
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Build headers dictionary, preserving all values for duplicate headers
|
||||
headers = {}
|
||||
for h in parent.get("payload", {}).get("headers", []):
|
||||
name = h["name"].lower()
|
||||
value = h["value"]
|
||||
if name in headers:
|
||||
# For duplicate headers, keep the first occurrence (most relevant for reply context)
|
||||
continue
|
||||
headers[name] = value
|
||||
|
||||
# Determine recipients if not specified
|
||||
if not (input_data.to or input_data.cc or input_data.bcc):
|
||||
if input_data.replyAll:
|
||||
recipients = [parseaddr(headers.get("from", ""))[1]]
|
||||
recipients += [addr for _, addr in getaddresses([headers.get("to", "")])]
|
||||
recipients += [addr for _, addr in getaddresses([headers.get("cc", "")])]
|
||||
# Use dict.fromkeys() for O(n) deduplication while preserving order
|
||||
input_data.to = list(dict.fromkeys(filter(None, recipients)))
|
||||
else:
|
||||
# Check Reply-To header first, fall back to From header
|
||||
reply_to = headers.get("reply-to", "")
|
||||
from_addr = headers.get("from", "")
|
||||
sender = parseaddr(reply_to if reply_to else from_addr)[1]
|
||||
input_data.to = [sender] if sender else []
|
||||
|
||||
# Set subject with Re: prefix if not already present
|
||||
if input_data.subject:
|
||||
subject = input_data.subject
|
||||
else:
|
||||
parent_subject = headers.get("subject", "").strip()
|
||||
# Only add "Re:" if not already present (case-insensitive check)
|
||||
if parent_subject.lower().startswith("re:"):
|
||||
subject = parent_subject
|
||||
else:
|
||||
subject = f"Re: {parent_subject}" if parent_subject else "Re:"
|
||||
|
||||
# Build references header for proper threading
|
||||
references = headers.get("references", "").split()
|
||||
if headers.get("message-id"):
|
||||
references.append(headers["message-id"])
|
||||
|
||||
# Create MIME message
|
||||
msg = MIMEMultipart()
|
||||
if input_data.to:
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
if headers.get("message-id"):
|
||||
msg["In-Reply-To"] = headers["message-id"]
|
||||
if references:
|
||||
msg["References"] = " ".join(references)
|
||||
|
||||
# Use the helper function for consistent content type handling
|
||||
msg.attach(_make_mime_text(input_data.body, input_data.content_type))
|
||||
|
||||
# Handle attachments
|
||||
for attach in input_data.attachments:
|
||||
local_path = await store_media_file(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=attach,
|
||||
return_content=False,
|
||||
)
|
||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
||||
part = MIMEBase("application", "octet-stream")
|
||||
with open(abs_path, "rb") as f:
|
||||
part.set_payload(f.read())
|
||||
encoders.encode_base64(part)
|
||||
part.add_header(
|
||||
"Content-Disposition", f"attachment; filename={Path(abs_path).name}"
|
||||
)
|
||||
msg.attach(part)
|
||||
|
||||
# Encode message
|
||||
raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
|
||||
return raw, input_data.threadId
|
||||
|
||||
|
||||
class GmailReplyBlock(GmailBase):
|
||||
"""
|
||||
Replies to Gmail threads with intelligent content type detection.
|
||||
@@ -1341,144 +1230,91 @@ class GmailReplyBlock(GmailBase):
|
||||
async def _reply(
|
||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||
) -> dict:
|
||||
# Build the reply message using the shared helper
|
||||
raw, thread_id = await _build_reply_message(
|
||||
service, input_data, graph_exec_id, user_id
|
||||
)
|
||||
|
||||
# Send the message
|
||||
return await asyncio.to_thread(
|
||||
parent = await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.messages()
|
||||
.send(userId="me", body={"threadId": thread_id, "raw": raw})
|
||||
.execute()
|
||||
)
|
||||
|
||||
|
||||
class GmailDraftReplyBlock(GmailBase):
|
||||
"""
|
||||
Creates draft replies to Gmail threads with intelligent content type detection.
|
||||
|
||||
Features:
|
||||
- Automatic HTML detection: Draft replies containing HTML tags are formatted as text/html
|
||||
- No hard-wrap for plain text: Plain text draft replies preserve natural line flow
|
||||
- Manual content type override: Use content_type parameter to force specific format
|
||||
- Reply-all functionality: Option to reply to all original recipients
|
||||
- Thread preservation: Maintains proper email threading with headers
|
||||
- Full Unicode/emoji support with UTF-8 encoding
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
[
|
||||
"https://www.googleapis.com/auth/gmail.modify",
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
]
|
||||
)
|
||||
threadId: str = SchemaField(description="Thread ID to reply in")
|
||||
parentMessageId: str = SchemaField(
|
||||
description="ID of the message being replied to"
|
||||
)
|
||||
to: list[str] = SchemaField(description="To recipients", default_factory=list)
|
||||
cc: list[str] = SchemaField(description="CC recipients", default_factory=list)
|
||||
bcc: list[str] = SchemaField(description="BCC recipients", default_factory=list)
|
||||
replyAll: bool = SchemaField(
|
||||
description="Reply to all original recipients", default=False
|
||||
)
|
||||
subject: str = SchemaField(description="Email subject", default="")
|
||||
body: str = SchemaField(description="Email body (plain text or HTML)")
|
||||
content_type: Optional[Literal["auto", "plain", "html"]] = SchemaField(
|
||||
description="Content type: 'auto' (default - detects HTML), 'plain', or 'html'",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
attachments: list[MediaFileType] = SchemaField(
|
||||
description="Files to attach", default_factory=list, advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
draftId: str = SchemaField(description="Created draft ID")
|
||||
messageId: str = SchemaField(description="Draft message ID")
|
||||
threadId: str = SchemaField(description="Thread ID")
|
||||
status: str = SchemaField(description="Draft creation status")
|
||||
error: str = SchemaField(description="Error message if any")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d7a9f3e2-8b4c-4d6f-9e1a-3c5b7f8d2a6e",
|
||||
description="Create draft replies to Gmail threads with automatic HTML detection and proper text formatting. Plain text draft replies maintain natural paragraph flow without 78-character line wrapping. HTML content is automatically detected and formatted correctly.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=GmailDraftReplyBlock.Input,
|
||||
output_schema=GmailDraftReplyBlock.Output,
|
||||
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"threadId": "t1",
|
||||
"parentMessageId": "m1",
|
||||
"body": "Thanks for your message. I'll review and get back to you.",
|
||||
"replyAll": False,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("draftId", "draft1"),
|
||||
("messageId", "m2"),
|
||||
("threadId", "t1"),
|
||||
("status", "draft_created"),
|
||||
],
|
||||
test_mock={
|
||||
"_create_draft_reply": lambda *args, **kwargs: {
|
||||
"id": "draft1",
|
||||
"message": {"id": "m2", "threadId": "t1"},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GoogleCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
service = self._build_service(credentials, **kwargs)
|
||||
draft = await self._create_draft_reply(
|
||||
service,
|
||||
input_data,
|
||||
graph_exec_id,
|
||||
user_id,
|
||||
)
|
||||
yield "draftId", draft["id"]
|
||||
yield "messageId", draft["message"]["id"]
|
||||
yield "threadId", draft["message"].get("threadId", input_data.threadId)
|
||||
yield "status", "draft_created"
|
||||
|
||||
async def _create_draft_reply(
|
||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||
) -> dict:
|
||||
# Build the reply message using the shared helper
|
||||
raw, thread_id = await _build_reply_message(
|
||||
service, input_data, graph_exec_id, user_id
|
||||
)
|
||||
|
||||
# Create draft with proper thread association
|
||||
draft = await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.drafts()
|
||||
.create(
|
||||
.get(
|
||||
userId="me",
|
||||
body={
|
||||
"message": {
|
||||
"threadId": thread_id,
|
||||
"raw": raw,
|
||||
}
|
||||
},
|
||||
id=input_data.parentMessageId,
|
||||
format="metadata",
|
||||
metadataHeaders=[
|
||||
"Subject",
|
||||
"References",
|
||||
"Message-ID",
|
||||
"From",
|
||||
"To",
|
||||
"Cc",
|
||||
"Reply-To",
|
||||
],
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
return draft
|
||||
headers = {
|
||||
h["name"].lower(): h["value"]
|
||||
for h in parent.get("payload", {}).get("headers", [])
|
||||
}
|
||||
if not (input_data.to or input_data.cc or input_data.bcc):
|
||||
if input_data.replyAll:
|
||||
recipients = [parseaddr(headers.get("from", ""))[1]]
|
||||
recipients += [
|
||||
addr for _, addr in getaddresses([headers.get("to", "")])
|
||||
]
|
||||
recipients += [
|
||||
addr for _, addr in getaddresses([headers.get("cc", "")])
|
||||
]
|
||||
dedup: list[str] = []
|
||||
for r in recipients:
|
||||
if r and r not in dedup:
|
||||
dedup.append(r)
|
||||
input_data.to = dedup
|
||||
else:
|
||||
sender = parseaddr(headers.get("reply-to", headers.get("from", "")))[1]
|
||||
input_data.to = [sender] if sender else []
|
||||
subject = input_data.subject or (f"Re: {headers.get('subject', '')}".strip())
|
||||
references = headers.get("references", "").split()
|
||||
if headers.get("message-id"):
|
||||
references.append(headers["message-id"])
|
||||
|
||||
msg = MIMEMultipart()
|
||||
if input_data.to:
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
if headers.get("message-id"):
|
||||
msg["In-Reply-To"] = headers["message-id"]
|
||||
if references:
|
||||
msg["References"] = " ".join(references)
|
||||
# Use the new helper function for consistent content type handling
|
||||
msg.attach(_make_mime_text(input_data.body, input_data.content_type))
|
||||
|
||||
for attach in input_data.attachments:
|
||||
local_path = await store_media_file(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=attach,
|
||||
return_content=False,
|
||||
)
|
||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
||||
part = MIMEBase("application", "octet-stream")
|
||||
with open(abs_path, "rb") as f:
|
||||
part.set_payload(f.read())
|
||||
encoders.encode_base64(part)
|
||||
part.add_header(
|
||||
"Content-Disposition", f"attachment; filename={Path(abs_path).name}"
|
||||
)
|
||||
msg.attach(part)
|
||||
|
||||
raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
|
||||
return await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.messages()
|
||||
.send(userId="me", body={"threadId": input_data.threadId, "raw": raw})
|
||||
.execute()
|
||||
)
|
||||
|
||||
|
||||
class GmailGetProfileBlock(GmailBase):
|
||||
|
||||
@@ -896,7 +896,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
prompt = [json.to_dict(p) for p in input_data.conversation_history]
|
||||
|
||||
def trim_prompt(s: str) -> str:
|
||||
"""Removes indentation up to and including `|` from a multi-line prompt."""
|
||||
lines = s.strip().split("\n")
|
||||
return "\n".join([line.strip().lstrip("|") for line in lines])
|
||||
|
||||
@@ -910,25 +909,24 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
if input_data.expected_format:
|
||||
expected_format = [
|
||||
f"{json.dumps(k)}: {json.dumps(v)}"
|
||||
for k, v in input_data.expected_format.items()
|
||||
f'"{k}": "{v}"' for k, v in input_data.expected_format.items()
|
||||
]
|
||||
if input_data.list_result:
|
||||
format_prompt = (
|
||||
f'"results": [\n {{\n {", ".join(expected_format)}\n }}\n]'
|
||||
)
|
||||
else:
|
||||
format_prompt = ",\n| ".join(expected_format)
|
||||
format_prompt = "\n ".join(expected_format)
|
||||
|
||||
sys_prompt = trim_prompt(
|
||||
f"""
|
||||
|Reply with pure JSON strictly following this JSON format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
|
|
||||
|Ensure the response is valid JSON. DO NOT include any additional text (e.g. markdown code block fences) outside of the JSON.
|
||||
|If you cannot provide all the keys, provide an empty string for the values you cannot answer.
|
||||
|Reply strictly only in the following JSON format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
|
|
||||
|Ensure the response is valid JSON. Do not include any additional text outside of the JSON.
|
||||
|If you cannot provide all the keys, provide an empty string for the values you cannot answer.
|
||||
"""
|
||||
)
|
||||
prompt.append({"role": "system", "content": sys_prompt})
|
||||
@@ -948,7 +946,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
return f"JSON decode error: {e}"
|
||||
|
||||
logger.debug(f"LLM request: {prompt}")
|
||||
error_feedback_message = ""
|
||||
retry_prompt = ""
|
||||
llm_model = input_data.model
|
||||
|
||||
for retry_count in range(input_data.retry):
|
||||
@@ -972,25 +970,8 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
if input_data.expected_format:
|
||||
try:
|
||||
response_obj = json.loads(response_text)
|
||||
except JSONDecodeError as json_error:
|
||||
prompt.append({"role": "assistant", "content": response_text})
|
||||
|
||||
indented_json_error = str(json_error).replace("\n", "\n|")
|
||||
error_feedback_message = trim_prompt(
|
||||
f"""
|
||||
|Your previous response could not be parsed as valid JSON:
|
||||
|
|
||||
|{indented_json_error}
|
||||
|
|
||||
|Please provide a valid JSON response that matches the expected format.
|
||||
"""
|
||||
)
|
||||
prompt.append(
|
||||
{"role": "user", "content": error_feedback_message}
|
||||
)
|
||||
continue
|
||||
response_obj = json.loads(response_text)
|
||||
|
||||
if input_data.list_result and isinstance(response_obj, dict):
|
||||
if "results" in response_obj:
|
||||
@@ -998,7 +979,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
elif len(response_obj) == 1:
|
||||
response_obj = list(response_obj.values())
|
||||
|
||||
validation_errors = "\n".join(
|
||||
response_error = "\n".join(
|
||||
[
|
||||
validation_error
|
||||
for response_item in (
|
||||
@@ -1010,7 +991,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
]
|
||||
)
|
||||
|
||||
if not validation_errors:
|
||||
if not response_error:
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
llm_call_count=retry_count + 1,
|
||||
@@ -1020,16 +1001,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
yield "response", response_obj
|
||||
yield "prompt", self.prompt
|
||||
return
|
||||
|
||||
prompt.append({"role": "assistant", "content": response_text})
|
||||
error_feedback_message = trim_prompt(
|
||||
f"""
|
||||
|Your response did not match the expected format:
|
||||
|
|
||||
|{validation_errors}
|
||||
"""
|
||||
)
|
||||
prompt.append({"role": "user", "content": error_feedback_message})
|
||||
else:
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
@@ -1040,6 +1011,21 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
yield "response", {"response": response_text}
|
||||
yield "prompt", self.prompt
|
||||
return
|
||||
|
||||
retry_prompt = trim_prompt(
|
||||
f"""
|
||||
|This is your previous error response:
|
||||
|--
|
||||
|{response_text}
|
||||
|--
|
||||
|
|
||||
|And this is the error:
|
||||
|--
|
||||
|{response_error}
|
||||
|--
|
||||
"""
|
||||
)
|
||||
prompt.append({"role": "user", "content": retry_prompt})
|
||||
except Exception as e:
|
||||
logger.exception(f"Error calling LLM: {e}")
|
||||
if (
|
||||
@@ -1052,12 +1038,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
logger.debug(
|
||||
f"Reducing max_tokens to {input_data.max_tokens} for next attempt"
|
||||
)
|
||||
# Don't add retry prompt for token limit errors,
|
||||
# just retry with lower maximum output tokens
|
||||
retry_prompt = f"Error calling LLM: {e}"
|
||||
|
||||
error_feedback_message = f"Error calling LLM: {e}"
|
||||
|
||||
raise RuntimeError(error_feedback_message)
|
||||
raise RuntimeError(retry_prompt)
|
||||
|
||||
|
||||
class AITextGeneratorBlock(AIBlockBase):
|
||||
|
||||
@@ -35,19 +35,20 @@ async def execute_graph(
|
||||
logger.info("Input data: %s", input_data)
|
||||
|
||||
# --- Test adding new executions --- #
|
||||
graph_exec = await agent_server.test_execute_graph(
|
||||
response = await agent_server.test_execute_graph(
|
||||
user_id=test_user.id,
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
node_input=input_data,
|
||||
)
|
||||
logger.info("Created execution with ID: %s", graph_exec.id)
|
||||
graph_exec_id = response.graph_exec_id
|
||||
logger.info("Created execution with ID: %s", graph_exec_id)
|
||||
|
||||
# Execution queue should be empty
|
||||
logger.info("Waiting for execution to complete...")
|
||||
result = await wait_execution(test_user.id, graph_exec.id, 30)
|
||||
result = await wait_execution(test_user.id, graph_exec_id, 30)
|
||||
logger.info("Execution completed with %d results", len(result))
|
||||
return graph_exec.id
|
||||
return graph_exec_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
from .graph import NodeModel
|
||||
from .integrations import Webhook # noqa: F401
|
||||
|
||||
# Resolve Webhook forward references
|
||||
# Resolve Webhook <- NodeModel forward reference
|
||||
NodeModel.model_rebuild()
|
||||
LibraryAgentPreset.model_rebuild()
|
||||
|
||||
@@ -1,31 +1,57 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from autogpt_libs.api_key.key_manager import APIKeyManager
|
||||
from prisma.enums import APIKeyPermission, APIKeyStatus
|
||||
from prisma.errors import PrismaError
|
||||
from prisma.models import APIKey as PrismaAPIKey
|
||||
from prisma.types import APIKeyWhereUniqueInput
|
||||
from pydantic import BaseModel, Field
|
||||
from prisma.types import (
|
||||
APIKeyCreateInput,
|
||||
APIKeyUpdateInput,
|
||||
APIKeyWhereInput,
|
||||
APIKeyWhereUniqueInput,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
from backend.data.db import BaseDbModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
keysmith = APIKeySmith()
|
||||
|
||||
|
||||
class APIKeyInfo(BaseModel):
|
||||
id: str
|
||||
# Some basic exceptions
|
||||
class APIKeyError(Exception):
|
||||
"""Base exception for API key operations"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyNotFoundError(APIKeyError):
|
||||
"""Raised when an API key is not found"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyPermissionError(APIKeyError):
|
||||
"""Raised when there are permission issues with API key operations"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyValidationError(APIKeyError):
|
||||
"""Raised when API key validation fails"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKey(BaseDbModel):
|
||||
name: str
|
||||
head: str = Field(
|
||||
description=f"The first {APIKeySmith.HEAD_LENGTH} characters of the key"
|
||||
)
|
||||
tail: str = Field(
|
||||
description=f"The last {APIKeySmith.TAIL_LENGTH} characters of the key"
|
||||
)
|
||||
status: APIKeyStatus
|
||||
permissions: list[APIKeyPermission]
|
||||
prefix: str
|
||||
key: str
|
||||
status: APIKeyStatus = APIKeyStatus.ACTIVE
|
||||
permissions: List[APIKeyPermission]
|
||||
postfix: str
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime] = None
|
||||
revoked_at: Optional[datetime] = None
|
||||
@@ -34,211 +60,266 @@ class APIKeyInfo(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def from_db(api_key: PrismaAPIKey):
|
||||
return APIKeyInfo(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
head=api_key.head,
|
||||
tail=api_key.tail,
|
||||
status=APIKeyStatus(api_key.status),
|
||||
permissions=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
created_at=api_key.createdAt,
|
||||
last_used_at=api_key.lastUsedAt,
|
||||
revoked_at=api_key.revokedAt,
|
||||
description=api_key.description,
|
||||
user_id=api_key.userId,
|
||||
)
|
||||
try:
|
||||
return APIKey(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
prefix=api_key.prefix,
|
||||
postfix=api_key.postfix,
|
||||
key=api_key.key,
|
||||
status=APIKeyStatus(api_key.status),
|
||||
permissions=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
created_at=api_key.createdAt,
|
||||
last_used_at=api_key.lastUsedAt,
|
||||
revoked_at=api_key.revokedAt,
|
||||
description=api_key.description,
|
||||
user_id=api_key.userId,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating APIKey from db: {str(e)}")
|
||||
raise APIKeyError(f"Failed to create API key object: {str(e)}")
|
||||
|
||||
|
||||
class APIKeyInfoWithHash(APIKeyInfo):
|
||||
hash: str
|
||||
salt: str | None = None # None for legacy keys
|
||||
|
||||
def match(self, plaintext_key: str) -> bool:
|
||||
"""Returns whether the given key matches this API key object."""
|
||||
return keysmith.verify_key(plaintext_key, self.hash, self.salt)
|
||||
class APIKeyWithoutHash(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
prefix: str
|
||||
postfix: str
|
||||
status: APIKeyStatus
|
||||
permissions: List[APIKeyPermission]
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime]
|
||||
revoked_at: Optional[datetime]
|
||||
description: Optional[str]
|
||||
user_id: str
|
||||
|
||||
@staticmethod
|
||||
def from_db(api_key: PrismaAPIKey):
|
||||
return APIKeyInfoWithHash(
|
||||
**APIKeyInfo.from_db(api_key).model_dump(),
|
||||
hash=api_key.hash,
|
||||
salt=api_key.salt,
|
||||
)
|
||||
|
||||
def without_hash(self) -> APIKeyInfo:
|
||||
return APIKeyInfo(**self.model_dump(exclude={"hash", "salt"}))
|
||||
try:
|
||||
return APIKeyWithoutHash(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
prefix=api_key.prefix,
|
||||
postfix=api_key.postfix,
|
||||
status=APIKeyStatus(api_key.status),
|
||||
permissions=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
created_at=api_key.createdAt,
|
||||
last_used_at=api_key.lastUsedAt,
|
||||
revoked_at=api_key.revokedAt,
|
||||
description=api_key.description,
|
||||
user_id=api_key.userId,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating APIKeyWithoutHash from db: {str(e)}")
|
||||
raise APIKeyError(f"Failed to create API key object: {str(e)}")
|
||||
|
||||
|
||||
async def create_api_key(
|
||||
async def generate_api_key(
|
||||
name: str,
|
||||
user_id: str,
|
||||
permissions: list[APIKeyPermission],
|
||||
permissions: List[APIKeyPermission],
|
||||
description: Optional[str] = None,
|
||||
) -> tuple[APIKeyInfo, str]:
|
||||
) -> tuple[APIKeyWithoutHash, str]:
|
||||
"""
|
||||
Generate a new API key and store it in the database.
|
||||
Returns the API key object (without hash) and the plain text key.
|
||||
"""
|
||||
generated_key = keysmith.generate_key()
|
||||
try:
|
||||
api_manager = APIKeyManager()
|
||||
key = api_manager.generate_api_key()
|
||||
|
||||
saved_key_obj = await PrismaAPIKey.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"head": generated_key.head,
|
||||
"tail": generated_key.tail,
|
||||
"hash": generated_key.hash,
|
||||
"salt": generated_key.salt,
|
||||
"permissions": [p for p in permissions],
|
||||
"description": description,
|
||||
"userId": user_id,
|
||||
}
|
||||
)
|
||||
api_key = await PrismaAPIKey.prisma().create(
|
||||
data=APIKeyCreateInput(
|
||||
id=str(uuid.uuid4()),
|
||||
name=name,
|
||||
prefix=key.prefix,
|
||||
postfix=key.postfix,
|
||||
key=key.hash,
|
||||
permissions=[p for p in permissions],
|
||||
description=description,
|
||||
userId=user_id,
|
||||
)
|
||||
)
|
||||
|
||||
return APIKeyInfo.from_db(saved_key_obj), generated_key.key
|
||||
api_key_without_hash = APIKeyWithoutHash.from_db(api_key)
|
||||
return api_key_without_hash, key.raw
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while generating API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to generate API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while generating API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to generate API key: {str(e)}")
|
||||
|
||||
|
||||
async def get_active_api_keys_by_head(head: str) -> list[APIKeyInfoWithHash]:
|
||||
results = await PrismaAPIKey.prisma().find_many(
|
||||
where={"head": head, "status": APIKeyStatus.ACTIVE}
|
||||
)
|
||||
return [APIKeyInfoWithHash.from_db(key) for key in results]
|
||||
|
||||
|
||||
async def validate_api_key(plaintext_key: str) -> Optional[APIKeyInfo]:
|
||||
async def validate_api_key(plain_text_key: str) -> Optional[APIKey]:
|
||||
"""
|
||||
Validate an API key and return the API key object if valid and active.
|
||||
Validate an API key and return the API key object if valid.
|
||||
"""
|
||||
try:
|
||||
if not plaintext_key.startswith(APIKeySmith.PREFIX):
|
||||
if not plain_text_key.startswith(APIKeyManager.PREFIX):
|
||||
logger.warning("Invalid API key format")
|
||||
return None
|
||||
|
||||
head = plaintext_key[: APIKeySmith.HEAD_LENGTH]
|
||||
potential_matches = await get_active_api_keys_by_head(head)
|
||||
prefix = plain_text_key[: APIKeyManager.PREFIX_LENGTH]
|
||||
api_manager = APIKeyManager()
|
||||
|
||||
matched_api_key = next(
|
||||
(pm for pm in potential_matches if pm.match(plaintext_key)),
|
||||
None,
|
||||
api_key = await PrismaAPIKey.prisma().find_first(
|
||||
where=APIKeyWhereInput(prefix=prefix, status=(APIKeyStatus.ACTIVE))
|
||||
)
|
||||
if not matched_api_key:
|
||||
# API key not found or invalid
|
||||
|
||||
if not api_key:
|
||||
logger.warning(f"No active API key found with prefix {prefix}")
|
||||
return None
|
||||
|
||||
# Migrate legacy keys to secure format on successful validation
|
||||
if matched_api_key.salt is None:
|
||||
matched_api_key = await _migrate_key_to_secure_hash(
|
||||
plaintext_key, matched_api_key
|
||||
is_valid = api_manager.verify_api_key(plain_text_key, api_key.key)
|
||||
if not is_valid:
|
||||
logger.warning("API key verification failed")
|
||||
return None
|
||||
|
||||
return APIKey.from_db(api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating API key: {str(e)}")
|
||||
raise APIKeyValidationError(f"Failed to validate API key: {str(e)}")
|
||||
|
||||
|
||||
async def revoke_api_key(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if not api_key:
|
||||
raise APIKeyNotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise APIKeyPermissionError(
|
||||
"You do not have permission to revoke this API key."
|
||||
)
|
||||
|
||||
return matched_api_key.without_hash()
|
||||
except Exception as e:
|
||||
logger.error(f"Error while validating API key: {e}")
|
||||
raise RuntimeError("Failed to validate API key") from e
|
||||
|
||||
|
||||
async def _migrate_key_to_secure_hash(
|
||||
plaintext_key: str, key_obj: APIKeyInfoWithHash
|
||||
) -> APIKeyInfoWithHash:
|
||||
"""Replace the SHA256 hash of a legacy API key with a salted Scrypt hash."""
|
||||
try:
|
||||
new_hash, new_salt = keysmith.hash_key(plaintext_key)
|
||||
await PrismaAPIKey.prisma().update(
|
||||
where={"id": key_obj.id}, data={"hash": new_hash, "salt": new_salt}
|
||||
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=where_clause,
|
||||
data=APIKeyUpdateInput(
|
||||
status=APIKeyStatus.REVOKED, revokedAt=datetime.now(timezone.utc)
|
||||
),
|
||||
)
|
||||
logger.info(f"Migrated legacy API key #{key_obj.id} to secure format")
|
||||
# Update the API key object with new values for return
|
||||
key_obj.hash = new_hash
|
||||
key_obj.salt = new_salt
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to migrate legacy API key #{key_obj.id}: {e}")
|
||||
|
||||
return key_obj
|
||||
|
||||
|
||||
async def revoke_api_key(key_id: str, user_id: str) -> APIKeyInfo:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if not api_key:
|
||||
raise NotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise NotAuthorizedError("You do not have permission to revoke this API key.")
|
||||
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where={"id": key_id},
|
||||
data={
|
||||
"status": APIKeyStatus.REVOKED,
|
||||
"revokedAt": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
if not updated_api_key:
|
||||
raise NotFoundError(f"API key #{key_id} vanished while trying to revoke.")
|
||||
|
||||
return APIKeyInfo.from_db(updated_api_key)
|
||||
|
||||
|
||||
async def list_user_api_keys(user_id: str) -> list[APIKeyInfo]:
|
||||
api_keys = await PrismaAPIKey.prisma().find_many(
|
||||
where={"userId": user_id}, order={"createdAt": "desc"}
|
||||
)
|
||||
|
||||
return [APIKeyInfo.from_db(key) for key in api_keys]
|
||||
|
||||
|
||||
async def suspend_api_key(key_id: str, user_id: str) -> APIKeyInfo:
|
||||
selector: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where=selector)
|
||||
|
||||
if not api_key:
|
||||
raise NotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise NotAuthorizedError("You do not have permission to suspend this API key.")
|
||||
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=selector, data={"status": APIKeyStatus.SUSPENDED}
|
||||
)
|
||||
if not updated_api_key:
|
||||
raise NotFoundError(f"API key #{key_id} vanished while trying to suspend.")
|
||||
|
||||
return APIKeyInfo.from_db(updated_api_key)
|
||||
|
||||
|
||||
def has_permission(api_key: APIKeyInfo, required_permission: APIKeyPermission) -> bool:
|
||||
return required_permission in api_key.permissions
|
||||
|
||||
|
||||
async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyInfo]:
|
||||
api_key = await PrismaAPIKey.prisma().find_first(
|
||||
where={"id": key_id, "userId": user_id}
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
if updated_api_key:
|
||||
return APIKeyWithoutHash.from_db(updated_api_key)
|
||||
return None
|
||||
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
|
||||
raise e
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while revoking API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to revoke API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while revoking API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to revoke API key: {str(e)}")
|
||||
|
||||
return APIKeyInfo.from_db(api_key)
|
||||
|
||||
async def list_user_api_keys(user_id: str) -> List[APIKeyWithoutHash]:
|
||||
try:
|
||||
where_clause: APIKeyWhereInput = {"userId": user_id}
|
||||
|
||||
api_keys = await PrismaAPIKey.prisma().find_many(
|
||||
where=where_clause, order={"createdAt": "desc"}
|
||||
)
|
||||
|
||||
return [APIKeyWithoutHash.from_db(key) for key in api_keys]
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while listing API keys: {str(e)}")
|
||||
raise APIKeyError(f"Failed to list API keys: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while listing API keys: {str(e)}")
|
||||
raise APIKeyError(f"Failed to list API keys: {str(e)}")
|
||||
|
||||
|
||||
async def suspend_api_key(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if not api_key:
|
||||
raise APIKeyNotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise APIKeyPermissionError(
|
||||
"You do not have permission to suspend this API key."
|
||||
)
|
||||
|
||||
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=where_clause,
|
||||
data=APIKeyUpdateInput(status=APIKeyStatus.SUSPENDED),
|
||||
)
|
||||
|
||||
if updated_api_key:
|
||||
return APIKeyWithoutHash.from_db(updated_api_key)
|
||||
return None
|
||||
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
|
||||
raise e
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while suspending API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to suspend API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while suspending API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to suspend API key: {str(e)}")
|
||||
|
||||
|
||||
def has_permission(api_key: APIKey, required_permission: APIKeyPermission) -> bool:
|
||||
try:
|
||||
return required_permission in api_key.permissions
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking API key permissions: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_first(
|
||||
where=APIKeyWhereInput(id=key_id, userId=user_id)
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
return APIKeyWithoutHash.from_db(api_key)
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while getting API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to get API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while getting API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to get API key: {str(e)}")
|
||||
|
||||
|
||||
async def update_api_key_permissions(
|
||||
key_id: str, user_id: str, permissions: list[APIKeyPermission]
|
||||
) -> APIKeyInfo:
|
||||
key_id: str, user_id: str, permissions: List[APIKeyPermission]
|
||||
) -> Optional[APIKeyWithoutHash]:
|
||||
"""
|
||||
Update the permissions of an API key.
|
||||
"""
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if api_key is None:
|
||||
raise NotFoundError("No such API key found.")
|
||||
if api_key is None:
|
||||
raise APIKeyNotFoundError("No such API key found.")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise NotAuthorizedError("You do not have permission to update this API key.")
|
||||
if api_key.userId != user_id:
|
||||
raise APIKeyPermissionError(
|
||||
"You do not have permission to update this API key."
|
||||
)
|
||||
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where={"id": key_id},
|
||||
data={"permissions": permissions},
|
||||
)
|
||||
if not updated_api_key:
|
||||
raise NotFoundError(f"API key #{key_id} vanished while trying to update.")
|
||||
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=where_clause,
|
||||
data=APIKeyUpdateInput(permissions=permissions),
|
||||
)
|
||||
|
||||
return APIKeyInfo.from_db(updated_api_key)
|
||||
if updated_api_key:
|
||||
return APIKeyWithoutHash.from_db(updated_api_key)
|
||||
return None
|
||||
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
|
||||
raise e
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while updating API key permissions: {str(e)}")
|
||||
raise APIKeyError(f"Failed to update API key permissions: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while updating API key permissions: {str(e)}")
|
||||
raise APIKeyError(f"Failed to update API key permissions: {str(e)}")
|
||||
|
||||
@@ -8,7 +8,6 @@ from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Generic,
|
||||
Optional,
|
||||
@@ -45,10 +44,9 @@ if TYPE_CHECKING:
|
||||
|
||||
app_config = Config()
|
||||
|
||||
BlockData = tuple[str, Any] # Input & Output data should be a tuple of (name, data).
|
||||
BlockInput = dict[str, Any] # Input: 1 input pin consumes 1 data.
|
||||
BlockOutputEntry = tuple[str, Any] # Output data should be a tuple of (name, value).
|
||||
BlockOutput = AsyncGen[BlockOutputEntry, None] # Output: 1 output pin produces n data.
|
||||
BlockTestOutput = BlockOutputEntry | tuple[str, Callable[[Any], bool]]
|
||||
BlockOutput = AsyncGen[BlockData, None] # Output: 1 output pin produces n data.
|
||||
CompletedBlockOutput = dict[str, list[Any]] # Completed stream, collected as a dict.
|
||||
|
||||
|
||||
@@ -91,45 +89,6 @@ class BlockCategory(Enum):
|
||||
return {"category": self.name, "description": self.value}
|
||||
|
||||
|
||||
class BlockCostType(str, Enum):
|
||||
RUN = "run" # cost X credits per run
|
||||
BYTE = "byte" # cost X credits per byte
|
||||
SECOND = "second" # cost X credits per second
|
||||
|
||||
|
||||
class BlockCost(BaseModel):
|
||||
cost_amount: int
|
||||
cost_filter: BlockInput
|
||||
cost_type: BlockCostType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cost_amount: int,
|
||||
cost_type: BlockCostType = BlockCostType.RUN,
|
||||
cost_filter: Optional[BlockInput] = None,
|
||||
**data: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
cost_amount=cost_amount,
|
||||
cost_filter=cost_filter or {},
|
||||
cost_type=cost_type,
|
||||
**data,
|
||||
)
|
||||
|
||||
|
||||
class BlockInfo(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
inputSchema: dict[str, Any]
|
||||
outputSchema: dict[str, Any]
|
||||
costs: list[BlockCost]
|
||||
description: str
|
||||
categories: list[dict[str, str]]
|
||||
contributors: list[dict[str, Any]]
|
||||
staticOutput: bool
|
||||
uiType: str
|
||||
|
||||
|
||||
class BlockSchema(BaseModel):
|
||||
cached_jsonschema: ClassVar[dict[str, Any]]
|
||||
|
||||
@@ -347,7 +306,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
input_schema: Type[BlockSchemaInputType] = EmptySchema,
|
||||
output_schema: Type[BlockSchemaOutputType] = EmptySchema,
|
||||
test_input: BlockInput | list[BlockInput] | None = None,
|
||||
test_output: BlockTestOutput | list[BlockTestOutput] | None = None,
|
||||
test_output: BlockData | list[BlockData] | None = None,
|
||||
test_mock: dict[str, Any] | None = None,
|
||||
test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
|
||||
disabled: bool = False,
|
||||
@@ -493,24 +452,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
"uiType": self.block_type.value,
|
||||
}
|
||||
|
||||
def get_info(self) -> BlockInfo:
|
||||
from backend.data.credit import get_block_cost
|
||||
|
||||
return BlockInfo(
|
||||
id=self.id,
|
||||
name=self.name,
|
||||
inputSchema=self.input_schema.jsonschema(),
|
||||
outputSchema=self.output_schema.jsonschema(),
|
||||
costs=get_block_cost(self),
|
||||
description=self.description,
|
||||
categories=[category.dict() for category in self.categories],
|
||||
contributors=[
|
||||
contributor.model_dump() for contributor in self.contributors
|
||||
],
|
||||
staticOutput=self.static_output,
|
||||
uiType=self.block_type.value,
|
||||
)
|
||||
|
||||
async def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise ValueError(
|
||||
|
||||
@@ -29,7 +29,8 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
from backend.data.block import Block, BlockCost, BlockCostType
|
||||
from backend.data.block import Block
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
from backend.integrations.credentials_store import (
|
||||
aiml_api_credentials,
|
||||
anthropic_credentials,
|
||||
|
||||
32
autogpt_platform/backend/backend/data/cost.py
Normal file
32
autogpt_platform/backend/backend/data/cost.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import BlockInput
|
||||
|
||||
|
||||
class BlockCostType(str, Enum):
|
||||
RUN = "run" # cost X credits per run
|
||||
BYTE = "byte" # cost X credits per byte
|
||||
SECOND = "second" # cost X credits per second
|
||||
|
||||
|
||||
class BlockCost(BaseModel):
|
||||
cost_amount: int
|
||||
cost_filter: BlockInput
|
||||
cost_type: BlockCostType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cost_amount: int,
|
||||
cost_type: BlockCostType = BlockCostType.RUN,
|
||||
cost_filter: Optional[BlockInput] = None,
|
||||
**data: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
cost_amount=cost_amount,
|
||||
cost_filter=cost_filter or {},
|
||||
cost_type=cost_type,
|
||||
**data,
|
||||
)
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import Any, cast
|
||||
|
||||
import stripe
|
||||
from prisma import Json
|
||||
@@ -23,6 +23,7 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.data.model import (
|
||||
AutoTopUpConfig,
|
||||
RefundRequest,
|
||||
@@ -40,9 +41,6 @@ from backend.util.models import Pagination
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.block import Block, BlockCost
|
||||
|
||||
settings = Settings()
|
||||
stripe.api_key = settings.secrets.stripe_api_key
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -999,14 +997,10 @@ def get_user_credit_model() -> UserCreditBase:
|
||||
return UserCredit()
|
||||
|
||||
|
||||
def get_block_costs() -> dict[str, list["BlockCost"]]:
|
||||
def get_block_costs() -> dict[str, list[BlockCost]]:
|
||||
return {block().id: costs for block, costs in BLOCK_COSTS.items()}
|
||||
|
||||
|
||||
def get_block_cost(block: "Block") -> list["BlockCost"]:
|
||||
return BLOCK_COSTS.get(block.__class__, [])
|
||||
|
||||
|
||||
async def get_stripe_customer_id(user_id: str) -> str:
|
||||
user = await get_user_by_id(user_id)
|
||||
|
||||
|
||||
@@ -11,14 +11,11 @@ from typing import (
|
||||
Generator,
|
||||
Generic,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from prisma import Json
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
from prisma.models import (
|
||||
AgentGraphExecution,
|
||||
@@ -27,6 +24,7 @@ from prisma.models import (
|
||||
AgentNodeExecutionKeyValueData,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionCreateInput,
|
||||
AgentGraphExecutionUpdateManyMutationInput,
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionCreateInput,
|
||||
@@ -62,7 +60,7 @@ from .includes import (
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
graph_execution_include,
|
||||
)
|
||||
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
|
||||
from .model import GraphExecutionStats, NodeExecutionStats
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -89,33 +87,6 @@ class BlockErrorStats(BaseModel):
|
||||
|
||||
|
||||
ExecutionStatus = AgentExecutionStatus
|
||||
NodeInputMask = Mapping[str, JsonValue]
|
||||
NodesInputMasks = Mapping[str, NodeInputMask]
|
||||
|
||||
# dest: source
|
||||
VALID_STATUS_TRANSITIONS = {
|
||||
ExecutionStatus.QUEUED: [
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
],
|
||||
ExecutionStatus.RUNNING: [
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.TERMINATED, # For resuming halted execution
|
||||
],
|
||||
ExecutionStatus.COMPLETED: [
|
||||
ExecutionStatus.RUNNING,
|
||||
],
|
||||
ExecutionStatus.FAILED: [
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
],
|
||||
ExecutionStatus.TERMINATED: [
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class GraphExecutionMeta(BaseDbModel):
|
||||
@@ -123,15 +94,10 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
user_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
inputs: Optional[BlockInput] # no default -> required in the OpenAPI spec
|
||||
credential_inputs: Optional[dict[str, CredentialsMetaInput]]
|
||||
nodes_input_masks: Optional[dict[str, BlockInput]]
|
||||
preset_id: Optional[str]
|
||||
preset_id: Optional[str] = None
|
||||
status: ExecutionStatus
|
||||
started_at: datetime
|
||||
ended_at: datetime
|
||||
is_shared: bool = False
|
||||
share_token: Optional[str] = None
|
||||
|
||||
class Stats(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
@@ -213,18 +179,6 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
user_id=_graph_exec.userId,
|
||||
graph_id=_graph_exec.agentGraphId,
|
||||
graph_version=_graph_exec.agentGraphVersion,
|
||||
inputs=cast(BlockInput | None, _graph_exec.inputs),
|
||||
credential_inputs=(
|
||||
{
|
||||
name: CredentialsMetaInput.model_validate(cmi)
|
||||
for name, cmi in cast(dict, _graph_exec.credentialInputs).items()
|
||||
}
|
||||
if _graph_exec.credentialInputs
|
||||
else None
|
||||
),
|
||||
nodes_input_masks=cast(
|
||||
dict[str, BlockInput] | None, _graph_exec.nodesInputMasks
|
||||
),
|
||||
preset_id=_graph_exec.agentPresetId,
|
||||
status=ExecutionStatus(_graph_exec.executionStatus),
|
||||
started_at=start_time,
|
||||
@@ -248,13 +202,11 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
if stats
|
||||
else None
|
||||
),
|
||||
is_shared=_graph_exec.isShared,
|
||||
share_token=_graph_exec.shareToken,
|
||||
)
|
||||
|
||||
|
||||
class GraphExecution(GraphExecutionMeta):
|
||||
inputs: BlockInput # type: ignore - incompatible override is intentional
|
||||
inputs: BlockInput
|
||||
outputs: CompletedBlockOutput
|
||||
|
||||
@staticmethod
|
||||
@@ -274,18 +226,15 @@ class GraphExecution(GraphExecutionMeta):
|
||||
)
|
||||
|
||||
inputs = {
|
||||
**(
|
||||
graph_exec.inputs
|
||||
or {
|
||||
# fallback: extract inputs from Agent Input Blocks
|
||||
exec.input_data["name"]: exec.input_data.get("value")
|
||||
for exec in complete_node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type == BlockType.INPUT
|
||||
)
|
||||
}
|
||||
),
|
||||
**{
|
||||
# inputs from Agent Input Blocks
|
||||
exec.input_data["name"]: exec.input_data.get("value")
|
||||
for exec in complete_node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type == BlockType.INPUT
|
||||
)
|
||||
},
|
||||
**{
|
||||
# input from webhook-triggered block
|
||||
"payload": exec.input_data["payload"]
|
||||
@@ -303,13 +252,14 @@ class GraphExecution(GraphExecutionMeta):
|
||||
if (
|
||||
block := get_block(exec.block_id)
|
||||
) and block.block_type == BlockType.OUTPUT:
|
||||
outputs[exec.input_data["name"]].append(exec.input_data.get("value"))
|
||||
outputs[exec.input_data["name"]].append(
|
||||
exec.input_data.get("value", None)
|
||||
)
|
||||
|
||||
return GraphExecution(
|
||||
**{
|
||||
field_name: getattr(graph_exec, field_name)
|
||||
for field_name in GraphExecutionMeta.model_fields
|
||||
if field_name != "inputs"
|
||||
},
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
@@ -342,17 +292,13 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
node_executions=node_executions,
|
||||
)
|
||||
|
||||
def to_graph_execution_entry(
|
||||
self,
|
||||
user_context: "UserContext",
|
||||
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
):
|
||||
def to_graph_execution_entry(self, user_context: "UserContext"):
|
||||
return GraphExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
graph_id=self.graph_id,
|
||||
graph_version=self.graph_version or 0,
|
||||
graph_exec_id=self.id,
|
||||
nodes_input_masks=compiled_nodes_input_masks,
|
||||
nodes_input_masks={}, # FIXME: store credentials on AgentGraphExecution
|
||||
user_context=user_context,
|
||||
)
|
||||
|
||||
@@ -369,9 +315,10 @@ class NodeExecutionResult(BaseModel):
|
||||
input_data: BlockInput
|
||||
output_data: CompletedBlockOutput
|
||||
add_time: datetime
|
||||
queue_time: datetime | None
|
||||
start_time: datetime | None
|
||||
end_time: datetime | None
|
||||
queue_time: datetime | None = None
|
||||
start_time: datetime | None = None
|
||||
end_time: datetime | None = None
|
||||
stats: NodeExecutionStats | None = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):
|
||||
@@ -389,7 +336,7 @@ class NodeExecutionResult(BaseModel):
|
||||
else:
|
||||
input_data: BlockInput = defaultdict()
|
||||
for data in _node_exec.Input or []:
|
||||
input_data[data.name] = type_utils.convert(data.data, JsonValue)
|
||||
input_data[data.name] = type_utils.convert(data.data, type[Any])
|
||||
|
||||
output_data: CompletedBlockOutput = defaultdict(list)
|
||||
|
||||
@@ -398,7 +345,7 @@ class NodeExecutionResult(BaseModel):
|
||||
output_data[name].extend(messages)
|
||||
else:
|
||||
for data in _node_exec.Output or []:
|
||||
output_data[data.name].append(type_utils.convert(data.data, JsonValue))
|
||||
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
|
||||
|
||||
graph_execution: AgentGraphExecution | None = _node_exec.GraphExecution
|
||||
if graph_execution:
|
||||
@@ -423,6 +370,7 @@ class NodeExecutionResult(BaseModel):
|
||||
queue_time=_node_exec.queuedTime,
|
||||
start_time=_node_exec.startedTime,
|
||||
end_time=_node_exec.endedTime,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
def to_node_execution_entry(
|
||||
@@ -593,12 +541,9 @@ async def get_graph_execution(
|
||||
async def create_graph_execution(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
starting_nodes_input: list[tuple[str, BlockInput]], # list[(node_id, BlockInput)]
|
||||
inputs: Mapping[str, JsonValue],
|
||||
starting_nodes_input: list[tuple[str, BlockInput]],
|
||||
user_id: str,
|
||||
preset_id: Optional[str] = None,
|
||||
credential_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
preset_id: str | None = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Create a new AgentGraphExecution record.
|
||||
@@ -606,18 +551,11 @@ async def create_graph_execution(
|
||||
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
|
||||
"""
|
||||
result = await AgentGraphExecution.prisma().create(
|
||||
data={
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": graph_version,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"inputs": SafeJson(inputs),
|
||||
"credentialInputs": (
|
||||
SafeJson(credential_inputs) if credential_inputs else Json({})
|
||||
),
|
||||
"nodesInputMasks": (
|
||||
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
|
||||
),
|
||||
"NodeExecutions": {
|
||||
data=AgentGraphExecutionCreateInput(
|
||||
agentGraphId=graph_id,
|
||||
agentGraphVersion=graph_version,
|
||||
executionStatus=ExecutionStatus.QUEUED,
|
||||
NodeExecutions={
|
||||
"create": [
|
||||
AgentNodeExecutionCreateInput(
|
||||
agentNodeId=node_id,
|
||||
@@ -633,9 +571,9 @@ async def create_graph_execution(
|
||||
for node_id, node_input in starting_nodes_input
|
||||
]
|
||||
},
|
||||
"userId": user_id,
|
||||
"agentPresetId": preset_id,
|
||||
},
|
||||
userId=user_id,
|
||||
agentPresetId=preset_id,
|
||||
),
|
||||
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
)
|
||||
|
||||
@@ -646,7 +584,7 @@ async def upsert_execution_input(
|
||||
node_id: str,
|
||||
graph_exec_id: str,
|
||||
input_name: str,
|
||||
input_data: JsonValue,
|
||||
input_data: Any,
|
||||
node_exec_id: str | None = None,
|
||||
) -> tuple[str, BlockInput]:
|
||||
"""
|
||||
@@ -695,7 +633,7 @@ async def upsert_execution_input(
|
||||
)
|
||||
return existing_execution.id, {
|
||||
**{
|
||||
input_data.name: type_utils.convert(input_data.data, JsonValue)
|
||||
input_data.name: type_utils.convert(input_data.data, type[Any])
|
||||
for input_data in existing_execution.Input or []
|
||||
},
|
||||
input_name: input_data,
|
||||
@@ -718,6 +656,42 @@ async def upsert_execution_input(
|
||||
)
|
||||
|
||||
|
||||
async def create_node_execution(
|
||||
node_exec_id: str,
|
||||
node_id: str,
|
||||
graph_exec_id: str,
|
||||
input_name: str,
|
||||
input_data: Any,
|
||||
) -> None:
|
||||
"""Create a new node execution with the first input."""
|
||||
json_input_data = SafeJson(input_data)
|
||||
await AgentNodeExecution.prisma().create(
|
||||
data=AgentNodeExecutionCreateInput(
|
||||
id=node_exec_id,
|
||||
agentNodeId=node_id,
|
||||
agentGraphExecutionId=graph_exec_id,
|
||||
executionStatus=ExecutionStatus.INCOMPLETE,
|
||||
Input={"create": {"name": input_name, "data": json_input_data}},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def add_input_to_node_execution(
|
||||
node_exec_id: str,
|
||||
input_name: str,
|
||||
input_data: Any,
|
||||
) -> None:
|
||||
"""Add an input to an existing node execution."""
|
||||
json_input_data = SafeJson(input_data)
|
||||
await AgentNodeExecutionInputOutput.prisma().create(
|
||||
data=AgentNodeExecutionInputOutputCreateInput(
|
||||
name=input_name,
|
||||
data=json_input_data,
|
||||
referencedByInputExecId=node_exec_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def upsert_execution_output(
|
||||
node_exec_id: str,
|
||||
output_name: str,
|
||||
@@ -756,11 +730,6 @@ async def update_graph_execution_stats(
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> GraphExecution | None:
|
||||
if not status and not stats:
|
||||
raise ValueError(
|
||||
f"Must provide either status or stats to update for execution {graph_exec_id}"
|
||||
)
|
||||
|
||||
update_data: AgentGraphExecutionUpdateManyMutationInput = {}
|
||||
|
||||
if stats:
|
||||
@@ -772,25 +741,20 @@ async def update_graph_execution_stats(
|
||||
if status:
|
||||
update_data["executionStatus"] = status
|
||||
|
||||
where_clause: AgentGraphExecutionWhereInput = {"id": graph_exec_id}
|
||||
|
||||
if status:
|
||||
if allowed_from := VALID_STATUS_TRANSITIONS.get(status, []):
|
||||
# Add OR clause to check if current status is one of the allowed source statuses
|
||||
where_clause["AND"] = [
|
||||
{"id": graph_exec_id},
|
||||
{"OR": [{"executionStatus": s} for s in allowed_from]},
|
||||
]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Status {status} cannot be set via update for execution {graph_exec_id}. "
|
||||
f"This status can only be set at creation or is not a valid target status."
|
||||
)
|
||||
|
||||
await AgentGraphExecution.prisma().update_many(
|
||||
where=where_clause,
|
||||
updated_count = await AgentGraphExecution.prisma().update_many(
|
||||
where={
|
||||
"id": graph_exec_id,
|
||||
"OR": [
|
||||
{"executionStatus": ExecutionStatus.RUNNING},
|
||||
{"executionStatus": ExecutionStatus.QUEUED},
|
||||
# Terminated graph can be resumed.
|
||||
{"executionStatus": ExecutionStatus.TERMINATED},
|
||||
],
|
||||
},
|
||||
data=update_data,
|
||||
)
|
||||
if updated_count == 0:
|
||||
return None
|
||||
|
||||
graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise(
|
||||
where={"id": graph_exec_id},
|
||||
@@ -798,7 +762,6 @@ async def update_graph_execution_stats(
|
||||
[*get_io_block_ids(), *get_webhook_block_ids()]
|
||||
),
|
||||
)
|
||||
|
||||
return GraphExecution.from_db(graph_exec)
|
||||
|
||||
|
||||
@@ -963,7 +926,7 @@ class GraphExecutionEntry(BaseModel):
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None
|
||||
user_context: UserContext
|
||||
|
||||
|
||||
@@ -1025,18 +988,6 @@ class NodeExecutionEvent(NodeExecutionResult):
|
||||
)
|
||||
|
||||
|
||||
class SharedExecutionResponse(BaseModel):
|
||||
"""Public-safe response for shared executions"""
|
||||
|
||||
id: str
|
||||
graph_name: str
|
||||
graph_description: Optional[str]
|
||||
status: ExecutionStatus
|
||||
created_at: datetime
|
||||
outputs: CompletedBlockOutput # Only the final outputs, no intermediate data
|
||||
# Deliberately exclude: user_id, inputs, credentials, node details
|
||||
|
||||
|
||||
ExecutionEvent = Annotated[
|
||||
GraphExecutionEvent | NodeExecutionEvent, Field(discriminator="event_type")
|
||||
]
|
||||
@@ -1214,98 +1165,3 @@ async def get_block_error_stats(
|
||||
)
|
||||
for row in result
|
||||
]
|
||||
|
||||
|
||||
async def update_graph_execution_share_status(
|
||||
execution_id: str,
|
||||
user_id: str,
|
||||
is_shared: bool,
|
||||
share_token: str | None,
|
||||
shared_at: datetime | None,
|
||||
) -> None:
|
||||
"""Update the sharing status of a graph execution."""
|
||||
await AgentGraphExecution.prisma().update(
|
||||
where={"id": execution_id},
|
||||
data={
|
||||
"isShared": is_shared,
|
||||
"shareToken": share_token,
|
||||
"sharedAt": shared_at,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def get_graph_execution_by_share_token(
|
||||
share_token: str,
|
||||
) -> SharedExecutionResponse | None:
|
||||
"""Get a shared execution with limited public-safe data."""
|
||||
execution = await AgentGraphExecution.prisma().find_first(
|
||||
where={
|
||||
"shareToken": share_token,
|
||||
"isShared": True,
|
||||
"isDeleted": False,
|
||||
},
|
||||
include={
|
||||
"AgentGraph": True,
|
||||
"NodeExecutions": {
|
||||
"include": {
|
||||
"Output": True,
|
||||
"Node": {
|
||||
"include": {
|
||||
"AgentBlock": True,
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if not execution:
|
||||
return None
|
||||
|
||||
# Extract outputs from OUTPUT blocks only (consistent with GraphExecution.from_db)
|
||||
outputs: CompletedBlockOutput = defaultdict(list)
|
||||
if execution.NodeExecutions:
|
||||
for node_exec in execution.NodeExecutions:
|
||||
if node_exec.Node and node_exec.Node.agentBlockId:
|
||||
# Get the block definition to check its type
|
||||
block = get_block(node_exec.Node.agentBlockId)
|
||||
|
||||
if block and block.block_type == BlockType.OUTPUT:
|
||||
# For OUTPUT blocks, the data is stored in executionData or Input
|
||||
# The executionData contains the structured input with 'name' and 'value' fields
|
||||
if hasattr(node_exec, "executionData") and node_exec.executionData:
|
||||
exec_data = type_utils.convert(
|
||||
node_exec.executionData, dict[str, Any]
|
||||
)
|
||||
if "name" in exec_data:
|
||||
name = exec_data["name"]
|
||||
value = exec_data.get("value")
|
||||
outputs[name].append(value)
|
||||
elif node_exec.Input:
|
||||
# Build input_data from Input relation
|
||||
input_data = {}
|
||||
for data in node_exec.Input:
|
||||
if data.name and data.data is not None:
|
||||
input_data[data.name] = type_utils.convert(
|
||||
data.data, JsonValue
|
||||
)
|
||||
|
||||
if "name" in input_data:
|
||||
name = input_data["name"]
|
||||
value = input_data.get("value")
|
||||
outputs[name].append(value)
|
||||
|
||||
return SharedExecutionResponse(
|
||||
id=execution.id,
|
||||
graph_name=(
|
||||
execution.AgentGraph.name
|
||||
if (execution.AgentGraph and execution.AgentGraph.name)
|
||||
else "Untitled Agent"
|
||||
),
|
||||
graph_description=(
|
||||
execution.AgentGraph.description if execution.AgentGraph else None
|
||||
),
|
||||
status=ExecutionStatus(execution.executionStatus),
|
||||
created_at=execution.createdAt,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
from prisma.enums import SubmissionStatus
|
||||
@@ -13,7 +12,7 @@ from prisma.types import (
|
||||
AgentNodeLinkCreateInput,
|
||||
StoreListingVersionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from pydantic import Field, JsonValue, create_model
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
@@ -35,7 +34,6 @@ from .db import BaseDbModel, query_raw_with_schema, transaction
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .execution import NodesInputMasks
|
||||
from .integrations import Webhook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -161,8 +159,6 @@ class BaseGraph(BaseDbModel):
|
||||
is_active: bool = True
|
||||
name: str
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
recommended_schedule_cron: str | None = None
|
||||
nodes: list[Node] = []
|
||||
links: list[Link] = []
|
||||
forked_from_id: str | None = None
|
||||
@@ -209,35 +205,6 @@ class BaseGraph(BaseDbModel):
|
||||
None,
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def trigger_setup_info(self) -> "GraphTriggerInfo | None":
|
||||
if not (
|
||||
self.webhook_input_node
|
||||
and (trigger_block := self.webhook_input_node.block).webhook_config
|
||||
):
|
||||
return None
|
||||
|
||||
return GraphTriggerInfo(
|
||||
provider=trigger_block.webhook_config.provider,
|
||||
config_schema={
|
||||
**(json_schema := trigger_block.input_schema.jsonschema()),
|
||||
"properties": {
|
||||
pn: sub_schema
|
||||
for pn, sub_schema in json_schema["properties"].items()
|
||||
if not is_credentials_field_name(pn)
|
||||
},
|
||||
"required": [
|
||||
pn
|
||||
for pn in json_schema.get("required", [])
|
||||
if not is_credentials_field_name(pn)
|
||||
],
|
||||
},
|
||||
credentials_input_name=next(
|
||||
iter(trigger_block.input_schema.get_credentials_fields()), None
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _generate_schema(
|
||||
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
|
||||
@@ -271,14 +238,6 @@ class BaseGraph(BaseDbModel):
|
||||
}
|
||||
|
||||
|
||||
class GraphTriggerInfo(BaseModel):
|
||||
provider: ProviderName
|
||||
config_schema: dict[str, Any] = Field(
|
||||
description="Input schema for the trigger block"
|
||||
)
|
||||
credentials_input_name: Optional[str]
|
||||
|
||||
|
||||
class Graph(BaseGraph):
|
||||
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs
|
||||
|
||||
@@ -383,8 +342,6 @@ class GraphModel(Graph):
|
||||
user_id: str
|
||||
nodes: list[NodeModel] = [] # type: ignore
|
||||
|
||||
created_at: datetime
|
||||
|
||||
@property
|
||||
def starting_nodes(self) -> list[NodeModel]:
|
||||
outbound_nodes = {link.sink_id for link in self.links}
|
||||
@@ -397,10 +354,6 @@ class GraphModel(Graph):
|
||||
if node.id not in outbound_nodes or node.id in input_nodes
|
||||
]
|
||||
|
||||
@property
|
||||
def webhook_input_node(self) -> NodeModel | None: # type: ignore
|
||||
return cast(NodeModel, super().webhook_input_node)
|
||||
|
||||
def meta(self) -> "GraphMeta":
|
||||
"""
|
||||
Returns a GraphMeta object with metadata about the graph.
|
||||
@@ -461,7 +414,7 @@ class GraphModel(Graph):
|
||||
def validate_graph(
|
||||
self,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional["NodesInputMasks"] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
):
|
||||
"""
|
||||
Validate graph structure and raise `ValueError` on issues.
|
||||
@@ -475,7 +428,7 @@ class GraphModel(Graph):
|
||||
def _validate_graph(
|
||||
graph: BaseGraph,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional["NodesInputMasks"] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> None:
|
||||
errors = GraphModel._validate_graph_get_errors(
|
||||
graph, for_run, nodes_input_masks
|
||||
@@ -489,7 +442,7 @@ class GraphModel(Graph):
|
||||
def validate_graph_get_errors(
|
||||
self,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional["NodesInputMasks"] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Validate graph and return structured errors per node.
|
||||
@@ -511,7 +464,7 @@ class GraphModel(Graph):
|
||||
def _validate_graph_get_errors(
|
||||
graph: BaseGraph,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional["NodesInputMasks"] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Validate graph and return structured errors per node.
|
||||
@@ -702,12 +655,9 @@ class GraphModel(Graph):
|
||||
version=graph.version,
|
||||
forked_from_id=graph.forkedFromId,
|
||||
forked_from_version=graph.forkedFromVersion,
|
||||
created_at=graph.createdAt,
|
||||
is_active=graph.isActive,
|
||||
name=graph.name or "",
|
||||
description=graph.description or "",
|
||||
instructions=graph.instructions,
|
||||
recommended_schedule_cron=graph.recommendedScheduleCron,
|
||||
nodes=[NodeModel.from_db(node, for_export) for node in graph.Nodes or []],
|
||||
links=list(
|
||||
{
|
||||
@@ -1095,7 +1045,6 @@ async def __create_graph(tx, graph: Graph, user_id: str):
|
||||
version=graph.version,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
recommendedScheduleCron=graph.recommended_schedule_cron,
|
||||
isActive=graph.is_active,
|
||||
userId=user_id,
|
||||
forkedFromId=graph.forked_from_id,
|
||||
@@ -1154,7 +1103,6 @@ def make_graph_model(creatable_graph: Graph, user_id: str) -> GraphModel:
|
||||
return GraphModel(
|
||||
**creatable_graph.model_dump(exclude={"nodes"}),
|
||||
user_id=user_id,
|
||||
created_at=datetime.now(tz=timezone.utc),
|
||||
nodes=[
|
||||
NodeModel(
|
||||
**creatable_node.model_dump(),
|
||||
|
||||
@@ -59,15 +59,9 @@ def graph_execution_include(
|
||||
}
|
||||
|
||||
|
||||
AGENT_PRESET_INCLUDE: prisma.types.AgentPresetInclude = {
|
||||
"InputPresets": True,
|
||||
"Webhook": True,
|
||||
}
|
||||
|
||||
|
||||
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE},
|
||||
"AgentPresets": {"include": AGENT_PRESET_INCLUDE},
|
||||
"AgentPresets": {"include": {"InputPresets": True}},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ async def generate_activity_status_for_execution(
|
||||
# Check if we have OpenAI API key
|
||||
try:
|
||||
settings = Settings()
|
||||
if not settings.secrets.openai_internal_api_key:
|
||||
if not settings.secrets.openai_api_key:
|
||||
logger.debug(
|
||||
"OpenAI API key not configured, skipping activity status generation"
|
||||
)
|
||||
@@ -115,7 +115,7 @@ async def generate_activity_status_for_execution(
|
||||
|
||||
# Get all node executions for this graph execution
|
||||
node_executions = await db_client.get_node_executions(
|
||||
graph_exec_id, include_exec_data=True
|
||||
graph_exec_id=graph_exec_id, include_exec_data=True
|
||||
)
|
||||
|
||||
# Get graph metadata and full graph structure for name, description, and links
|
||||
@@ -187,7 +187,7 @@ async def generate_activity_status_for_execution(
|
||||
credentials = APIKeyCredentials(
|
||||
id="openai",
|
||||
provider="openai",
|
||||
api_key=SecretStr(settings.secrets.openai_internal_api_key),
|
||||
api_key=SecretStr(settings.secrets.openai_api_key),
|
||||
title="System OpenAI",
|
||||
)
|
||||
|
||||
|
||||
@@ -468,7 +468,7 @@ class TestGenerateActivityStatusForExecution:
|
||||
):
|
||||
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
mock_llm.return_value = (
|
||||
"I analyzed your data and provided the requested insights."
|
||||
)
|
||||
@@ -520,7 +520,7 @@ class TestGenerateActivityStatusForExecution:
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=True,
|
||||
):
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = ""
|
||||
mock_settings.return_value.secrets.openai_api_key = ""
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
@@ -546,7 +546,7 @@ class TestGenerateActivityStatusForExecution:
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=True,
|
||||
):
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
@@ -581,7 +581,7 @@ class TestGenerateActivityStatusForExecution:
|
||||
):
|
||||
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
mock_llm.return_value = "Agent completed execution."
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
@@ -633,7 +633,7 @@ class TestIntegration:
|
||||
):
|
||||
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
|
||||
mock_response = LLMResponse(
|
||||
raw_response={},
|
||||
|
||||
@@ -4,12 +4,13 @@ from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
|
||||
from backend.data import db
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.execution import (
|
||||
add_input_to_node_execution,
|
||||
create_graph_execution,
|
||||
create_node_execution,
|
||||
get_block_error_stats,
|
||||
get_execution_kv_data,
|
||||
get_graph_execution_meta,
|
||||
get_graph_executions,
|
||||
get_latest_node_execution,
|
||||
get_node_execution,
|
||||
get_node_executions,
|
||||
set_execution_kv_data,
|
||||
@@ -17,7 +18,6 @@ from backend.data.execution import (
|
||||
update_graph_execution_stats,
|
||||
update_node_execution_status,
|
||||
update_node_execution_status_batch,
|
||||
upsert_execution_input,
|
||||
upsert_execution_output,
|
||||
)
|
||||
from backend.data.generate_data import get_user_execution_summary_data
|
||||
@@ -105,13 +105,13 @@ class DatabaseManager(AppService):
|
||||
create_graph_execution = _(create_graph_execution)
|
||||
get_node_execution = _(get_node_execution)
|
||||
get_node_executions = _(get_node_executions)
|
||||
get_latest_node_execution = _(get_latest_node_execution)
|
||||
update_node_execution_status = _(update_node_execution_status)
|
||||
update_node_execution_status_batch = _(update_node_execution_status_batch)
|
||||
update_graph_execution_start_time = _(update_graph_execution_start_time)
|
||||
update_graph_execution_stats = _(update_graph_execution_stats)
|
||||
upsert_execution_input = _(upsert_execution_input)
|
||||
upsert_execution_output = _(upsert_execution_output)
|
||||
create_node_execution = _(create_node_execution)
|
||||
add_input_to_node_execution = _(add_input_to_node_execution)
|
||||
get_execution_kv_data = _(get_execution_kv_data)
|
||||
set_execution_kv_data = _(set_execution_kv_data)
|
||||
get_block_error_stats = _(get_block_error_stats)
|
||||
@@ -171,10 +171,12 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
get_graph_executions = _(d.get_graph_executions)
|
||||
get_graph_execution_meta = _(d.get_graph_execution_meta)
|
||||
get_node_executions = _(d.get_node_executions)
|
||||
create_node_execution = _(d.create_node_execution)
|
||||
update_node_execution_status = _(d.update_node_execution_status)
|
||||
update_graph_execution_start_time = _(d.update_graph_execution_start_time)
|
||||
update_graph_execution_stats = _(d.update_graph_execution_stats)
|
||||
upsert_execution_output = _(d.upsert_execution_output)
|
||||
add_input_to_node_execution = _(d.add_input_to_node_execution)
|
||||
|
||||
# Graphs
|
||||
get_graph_metadata = _(d.get_graph_metadata)
|
||||
@@ -189,14 +191,6 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
# User Emails
|
||||
get_user_email_by_id = _(d.get_user_email_by_id)
|
||||
|
||||
# Library
|
||||
list_library_agents = _(d.list_library_agents)
|
||||
add_store_agent_to_library = _(d.add_store_agent_to_library)
|
||||
|
||||
# Store
|
||||
get_store_agents = _(d.get_store_agents)
|
||||
get_store_agent_details = _(d.get_store_agent_details)
|
||||
|
||||
|
||||
class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
@@ -207,16 +201,12 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
|
||||
create_graph_execution = d.create_graph_execution
|
||||
get_connected_output_nodes = d.get_connected_output_nodes
|
||||
get_latest_node_execution = d.get_latest_node_execution
|
||||
get_graph = d.get_graph
|
||||
get_graph_metadata = d.get_graph_metadata
|
||||
get_graph_execution_meta = d.get_graph_execution_meta
|
||||
get_node = d.get_node
|
||||
get_node_execution = d.get_node_execution
|
||||
get_node_executions = d.get_node_executions
|
||||
get_user_integrations = d.get_user_integrations
|
||||
upsert_execution_input = d.upsert_execution_input
|
||||
upsert_execution_output = d.upsert_execution_output
|
||||
update_graph_execution_stats = d.update_graph_execution_stats
|
||||
update_node_execution_status = d.update_node_execution_status
|
||||
update_node_execution_status_batch = d.update_node_execution_status_batch
|
||||
|
||||
154
autogpt_platform/backend/backend/executor/execution_cache.py
Normal file
154
autogpt_platform/backend/backend/executor/execution_cache.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from backend.data.execution import ExecutionStatus, NodeExecutionResult
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def with_lock(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
with self._lock:
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ExecutionCache:
|
||||
def __init__(self, graph_exec_id: str, db_client: "DatabaseManagerClient"):
|
||||
self._lock = threading.RLock()
|
||||
self._graph_exec_id = graph_exec_id
|
||||
self._graph_stats: GraphExecutionStats = GraphExecutionStats()
|
||||
self._node_executions: OrderedDict[str, NodeExecutionResult] = OrderedDict()
|
||||
|
||||
for execution in db_client.get_node_executions(self._graph_exec_id):
|
||||
self._node_executions[execution.node_exec_id] = execution
|
||||
|
||||
@with_lock
|
||||
def get_node_execution(self, node_exec_id: str) -> Optional[NodeExecutionResult]:
|
||||
execution = self._node_executions.get(node_exec_id)
|
||||
return execution.model_copy(deep=True) if execution else None
|
||||
|
||||
@with_lock
|
||||
def get_latest_node_execution(self, node_id: str) -> Optional[NodeExecutionResult]:
|
||||
for execution in reversed(self._node_executions.values()):
|
||||
if (
|
||||
execution.node_id == node_id
|
||||
and execution.status != ExecutionStatus.INCOMPLETE
|
||||
):
|
||||
return execution.model_copy(deep=True)
|
||||
return None
|
||||
|
||||
@with_lock
|
||||
def get_node_executions(
|
||||
self,
|
||||
*,
|
||||
statuses: Optional[list] = None,
|
||||
block_ids: Optional[list] = None,
|
||||
node_id: Optional[str] = None,
|
||||
):
|
||||
results = []
|
||||
for execution in self._node_executions.values():
|
||||
if statuses and execution.status not in statuses:
|
||||
continue
|
||||
if block_ids and execution.block_id not in block_ids:
|
||||
continue
|
||||
if node_id and execution.node_id != node_id:
|
||||
continue
|
||||
results.append(execution.model_copy(deep=True))
|
||||
return results
|
||||
|
||||
@with_lock
|
||||
def update_node_execution_status(
|
||||
self,
|
||||
exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: Optional[dict] = None,
|
||||
stats: Optional[dict] = None,
|
||||
):
|
||||
if exec_id not in self._node_executions:
|
||||
raise RuntimeError(f"Execution {exec_id} not found in cache")
|
||||
|
||||
execution = self._node_executions[exec_id]
|
||||
execution.status = status
|
||||
|
||||
if execution_data:
|
||||
execution.input_data.update(execution_data)
|
||||
|
||||
if stats:
|
||||
execution.stats = execution.stats or NodeExecutionStats()
|
||||
current_stats = execution.stats.model_dump()
|
||||
current_stats.update(stats)
|
||||
execution.stats = NodeExecutionStats.model_validate(current_stats)
|
||||
|
||||
@with_lock
|
||||
def upsert_execution_output(
|
||||
self, node_exec_id: str, output_name: str, output_data: Any
|
||||
) -> NodeExecutionResult:
|
||||
if node_exec_id not in self._node_executions:
|
||||
raise RuntimeError(f"Execution {node_exec_id} not found in cache")
|
||||
|
||||
execution = self._node_executions[node_exec_id]
|
||||
if output_name not in execution.output_data:
|
||||
execution.output_data[output_name] = []
|
||||
execution.output_data[output_name].append(output_data)
|
||||
|
||||
return execution
|
||||
|
||||
@with_lock
|
||||
def update_graph_stats(
|
||||
self, status: Optional[ExecutionStatus] = None, stats: Optional[dict] = None
|
||||
):
|
||||
if status is not None:
|
||||
pass
|
||||
if stats is not None:
|
||||
current_stats = self._graph_stats.model_dump()
|
||||
current_stats.update(stats)
|
||||
self._graph_stats = GraphExecutionStats.model_validate(current_stats)
|
||||
|
||||
@with_lock
|
||||
def update_graph_start_time(self):
|
||||
"""Update graph start time (handled by database persistence)."""
|
||||
pass
|
||||
|
||||
@with_lock
|
||||
def find_incomplete_execution_for_input(
|
||||
self, node_id: str, input_name: str
|
||||
) -> tuple[str, NodeExecutionResult] | None:
|
||||
for exec_id, execution in self._node_executions.items():
|
||||
if (
|
||||
execution.node_id == node_id
|
||||
and execution.status == ExecutionStatus.INCOMPLETE
|
||||
and input_name not in execution.input_data
|
||||
):
|
||||
return exec_id, execution
|
||||
return None
|
||||
|
||||
@with_lock
|
||||
def add_node_execution(
|
||||
self, node_exec_id: str, execution: NodeExecutionResult
|
||||
) -> None:
|
||||
self._node_executions[node_exec_id] = execution
|
||||
|
||||
@with_lock
|
||||
def update_execution_input(
|
||||
self, exec_id: str, input_name: str, input_data: Any
|
||||
) -> dict:
|
||||
if exec_id not in self._node_executions:
|
||||
raise RuntimeError(f"Execution {exec_id} not found in cache")
|
||||
execution = self._node_executions[exec_id]
|
||||
execution.input_data[input_name] = input_data
|
||||
return execution.input_data.copy()
|
||||
|
||||
def finalize(self) -> None:
|
||||
with self._lock:
|
||||
self._node_executions.clear()
|
||||
self._graph_stats = GraphExecutionStats()
|
||||
@@ -0,0 +1,355 @@
|
||||
"""Test execution creation with proper ID generation and persistence."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.executor.execution_data import ExecutionDataClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def execution_client_with_mock_db(event_loop):
|
||||
"""Create an ExecutionDataClient with proper database records."""
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from prisma.models import AgentGraph, AgentGraphExecution, User
|
||||
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
|
||||
# Create test database records to satisfy foreign key constraints
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": "test_user_123",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
)
|
||||
|
||||
await AgentGraph.prisma().create(
|
||||
data={
|
||||
"id": "test_graph_456",
|
||||
"version": 1,
|
||||
"userId": "test_user_123",
|
||||
"name": "Test Graph",
|
||||
"description": "Test graph for execution tests",
|
||||
}
|
||||
)
|
||||
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
|
||||
await AgentGraphExecution.prisma().create(
|
||||
data={
|
||||
"id": "test_graph_exec_id",
|
||||
"userId": "test_user_123",
|
||||
"agentGraphId": "test_graph_456",
|
||||
"agentGraphVersion": 1,
|
||||
"executionStatus": AgentExecutionStatus.RUNNING,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Records might already exist, that's fine
|
||||
pass
|
||||
|
||||
# Mock the graph execution metadata - align with assertions below
|
||||
mock_graph_meta = GraphExecutionMeta(
|
||||
id="test_graph_exec_id",
|
||||
user_id="test_user_123",
|
||||
graph_id="test_graph_456",
|
||||
graph_version=1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
# Create client with ThreadPoolExecutor and graph metadata (constructed inside patch)
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
# Storage for tracking created executions
|
||||
created_executions = []
|
||||
|
||||
async def mock_create_node_execution(
|
||||
node_exec_id, node_id, graph_exec_id, input_name, input_data
|
||||
):
|
||||
"""Mock execution creation that records what was created."""
|
||||
created_executions.append(
|
||||
{
|
||||
"node_exec_id": node_exec_id,
|
||||
"node_id": node_id,
|
||||
"graph_exec_id": graph_exec_id,
|
||||
"input_name": input_name,
|
||||
"input_data": input_data,
|
||||
}
|
||||
)
|
||||
return node_exec_id
|
||||
|
||||
def sync_mock_create_node_execution(
|
||||
node_exec_id, node_id, graph_exec_id, input_name, input_data
|
||||
):
|
||||
"""Mock sync execution creation that records what was created."""
|
||||
created_executions.append(
|
||||
{
|
||||
"node_exec_id": node_exec_id,
|
||||
"node_id": node_id,
|
||||
"graph_exec_id": graph_exec_id,
|
||||
"input_name": input_name,
|
||||
"input_data": input_data,
|
||||
}
|
||||
)
|
||||
return node_exec_id
|
||||
|
||||
# Prepare mock async and sync DB clients
|
||||
async_mock_client = AsyncMock()
|
||||
async_mock_client.create_node_execution = mock_create_node_execution
|
||||
|
||||
sync_mock_client = MagicMock()
|
||||
sync_mock_client.create_node_execution = sync_mock_create_node_execution
|
||||
# Mock graph execution for return values
|
||||
from backend.data.execution import GraphExecutionMeta
|
||||
|
||||
mock_graph_update = GraphExecutionMeta(
|
||||
id="test_graph_exec_id",
|
||||
user_id="test_user_123",
|
||||
graph_id="test_graph_456",
|
||||
graph_version=1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
# No-ops for other sync methods used by the client during tests
|
||||
sync_mock_client.add_input_to_node_execution.side_effect = lambda **kwargs: None
|
||||
sync_mock_client.update_node_execution_status.side_effect = (
|
||||
lambda *args, **kwargs: None
|
||||
)
|
||||
sync_mock_client.upsert_execution_output.side_effect = lambda **kwargs: None
|
||||
sync_mock_client.update_graph_execution_stats.side_effect = (
|
||||
lambda *args, **kwargs: mock_graph_update
|
||||
)
|
||||
sync_mock_client.update_graph_execution_start_time.side_effect = (
|
||||
lambda *args, **kwargs: mock_graph_update
|
||||
)
|
||||
|
||||
thread = threading.Thread(target=event_loop.run_forever, daemon=True)
|
||||
thread.start()
|
||||
with patch(
|
||||
"backend.executor.execution_data.get_database_manager_async_client",
|
||||
return_value=async_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_database_manager_client",
|
||||
return_value=sync_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_execution_event_bus"
|
||||
), patch(
|
||||
"backend.executor.execution_data.non_blocking_persist", lambda func: func
|
||||
):
|
||||
# Now construct the client under the patch so it captures the mocked clients
|
||||
client = ExecutionDataClient(executor, "test_graph_exec_id", mock_graph_meta)
|
||||
# Store the mocks for the test to access if needed
|
||||
setattr(client, "_test_async_client", async_mock_client)
|
||||
setattr(client, "_test_sync_client", sync_mock_client)
|
||||
setattr(client, "_created_executions", created_executions)
|
||||
yield client
|
||||
|
||||
# Cleanup test database records
|
||||
try:
|
||||
await AgentGraphExecution.prisma().delete_many(
|
||||
where={"id": "test_graph_exec_id"}
|
||||
)
|
||||
await AgentGraph.prisma().delete_many(where={"id": "test_graph_456"})
|
||||
await User.prisma().delete_many(where={"id": "test_user_123"})
|
||||
except Exception:
|
||||
# Cleanup may fail if records don't exist
|
||||
pass
|
||||
|
||||
# Cleanup
|
||||
event_loop.call_soon_threadsafe(event_loop.stop)
|
||||
thread.join(timeout=1)
|
||||
|
||||
|
||||
class TestExecutionCreation:
|
||||
"""Test execution creation with proper ID generation and persistence."""
|
||||
|
||||
async def test_execution_creation_with_valid_ids(
|
||||
self, execution_client_with_mock_db
|
||||
):
|
||||
"""Test that execution creation generates and persists valid IDs."""
|
||||
client = execution_client_with_mock_db
|
||||
|
||||
node_id = "test_node_789"
|
||||
input_name = "test_input"
|
||||
input_data = "test_value"
|
||||
block_id = "test_block_abc"
|
||||
|
||||
# This should trigger execution creation since cache is empty
|
||||
exec_id, input_dict = client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name=input_name,
|
||||
input_data=input_data,
|
||||
block_id=block_id,
|
||||
)
|
||||
|
||||
# Verify execution ID is valid UUID
|
||||
try:
|
||||
uuid.UUID(exec_id)
|
||||
except ValueError:
|
||||
pytest.fail(f"Generated execution ID '{exec_id}' is not a valid UUID")
|
||||
|
||||
# Verify execution was created in cache with complete data
|
||||
assert exec_id in client._cache._node_executions
|
||||
cached_execution = client._cache._node_executions[exec_id]
|
||||
|
||||
# Check all required fields have valid values
|
||||
assert cached_execution.user_id == "test_user_123"
|
||||
assert cached_execution.graph_id == "test_graph_456"
|
||||
assert cached_execution.graph_version == 1
|
||||
assert cached_execution.graph_exec_id == "test_graph_exec_id"
|
||||
assert cached_execution.node_exec_id == exec_id
|
||||
assert cached_execution.node_id == node_id
|
||||
assert cached_execution.block_id == block_id
|
||||
assert cached_execution.status == ExecutionStatus.INCOMPLETE
|
||||
assert cached_execution.input_data == {input_name: input_data}
|
||||
assert isinstance(cached_execution.add_time, datetime)
|
||||
|
||||
# Verify execution was persisted to database with our generated ID
|
||||
created_executions = getattr(client, "_created_executions", [])
|
||||
assert len(created_executions) == 1
|
||||
created = created_executions[0]
|
||||
assert created["node_exec_id"] == exec_id # Our generated ID was used
|
||||
assert created["node_id"] == node_id
|
||||
assert created["graph_exec_id"] == "test_graph_exec_id"
|
||||
assert created["input_name"] == input_name
|
||||
assert created["input_data"] == input_data
|
||||
|
||||
# Verify input dict returned correctly
|
||||
assert input_dict == {input_name: input_data}
|
||||
|
||||
async def test_execution_reuse_vs_creation(self, execution_client_with_mock_db):
|
||||
"""Test that execution reuse works and creation only happens when needed."""
|
||||
client = execution_client_with_mock_db
|
||||
|
||||
node_id = "reuse_test_node"
|
||||
block_id = "reuse_test_block"
|
||||
|
||||
# Create first execution
|
||||
exec_id_1, input_dict_1 = client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="input_1",
|
||||
input_data="value_1",
|
||||
block_id=block_id,
|
||||
)
|
||||
|
||||
# This should reuse the existing INCOMPLETE execution
|
||||
exec_id_2, input_dict_2 = client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="input_2",
|
||||
input_data="value_2",
|
||||
block_id=block_id,
|
||||
)
|
||||
|
||||
# Should reuse the same execution
|
||||
assert exec_id_1 == exec_id_2
|
||||
assert input_dict_2 == {"input_1": "value_1", "input_2": "value_2"}
|
||||
|
||||
# Only one execution should be created in database
|
||||
created_executions = getattr(client, "_created_executions", [])
|
||||
assert len(created_executions) == 1
|
||||
|
||||
# Verify cache has the merged inputs
|
||||
cached_execution = client._cache._node_executions[exec_id_1]
|
||||
assert cached_execution.input_data == {
|
||||
"input_1": "value_1",
|
||||
"input_2": "value_2",
|
||||
}
|
||||
|
||||
# Now complete the execution and try to add another input
|
||||
client.update_node_status_and_publish(
|
||||
exec_id=exec_id_1, status=ExecutionStatus.COMPLETED
|
||||
)
|
||||
|
||||
# Verify the execution status was actually updated in the cache
|
||||
updated_execution = client._cache._node_executions[exec_id_1]
|
||||
assert (
|
||||
updated_execution.status == ExecutionStatus.COMPLETED
|
||||
), f"Expected COMPLETED but got {updated_execution.status}"
|
||||
|
||||
# This should create a NEW execution since the first is no longer INCOMPLETE
|
||||
exec_id_3, input_dict_3 = client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="input_3",
|
||||
input_data="value_3",
|
||||
block_id=block_id,
|
||||
)
|
||||
|
||||
# Should be a different execution
|
||||
assert exec_id_3 != exec_id_1
|
||||
assert input_dict_3 == {"input_3": "value_3"}
|
||||
|
||||
# Verify cache behavior: should have two different executions in cache now
|
||||
cached_executions = client._cache._node_executions
|
||||
assert len(cached_executions) == 2
|
||||
assert exec_id_1 in cached_executions
|
||||
assert exec_id_3 in cached_executions
|
||||
|
||||
# First execution should be COMPLETED
|
||||
assert cached_executions[exec_id_1].status == ExecutionStatus.COMPLETED
|
||||
# Third execution should be INCOMPLETE (newly created)
|
||||
assert cached_executions[exec_id_3].status == ExecutionStatus.INCOMPLETE
|
||||
|
||||
async def test_multiple_nodes_get_different_execution_ids(
|
||||
self, execution_client_with_mock_db
|
||||
):
|
||||
"""Test that different nodes get different execution IDs."""
|
||||
client = execution_client_with_mock_db
|
||||
|
||||
# Create executions for different nodes
|
||||
exec_id_a, _ = client.upsert_execution_input(
|
||||
node_id="node_a",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="block_a",
|
||||
)
|
||||
|
||||
exec_id_b, _ = client.upsert_execution_input(
|
||||
node_id="node_b",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="block_b",
|
||||
)
|
||||
|
||||
# Should be different executions with different IDs
|
||||
assert exec_id_a != exec_id_b
|
||||
|
||||
# Both should be valid UUIDs
|
||||
uuid.UUID(exec_id_a)
|
||||
uuid.UUID(exec_id_b)
|
||||
|
||||
# Both should be in cache
|
||||
cached_executions = client._cache._node_executions
|
||||
assert len(cached_executions) == 2
|
||||
assert exec_id_a in cached_executions
|
||||
assert exec_id_b in cached_executions
|
||||
|
||||
# Both should have correct node IDs
|
||||
assert cached_executions[exec_id_a].node_id == "node_a"
|
||||
assert cached_executions[exec_id_b].node_id == "node_b"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
338
autogpt_platform/backend/backend/executor/execution_data.py
Normal file
338
autogpt_platform/backend/backend/executor/execution_data.py
Normal file
@@ -0,0 +1,338 @@
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from concurrent.futures import Executor
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TypeVar, cast
|
||||
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionMeta,
|
||||
NodeExecutionResult,
|
||||
)
|
||||
from backend.data.graph import Node
|
||||
from backend.data.model import GraphExecutionStats
|
||||
from backend.executor.execution_cache import ExecutionCache
|
||||
from backend.util.clients import (
|
||||
get_database_manager_async_client,
|
||||
get_database_manager_client,
|
||||
get_execution_event_bus,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def non_blocking_persist(func: Callable[P, T]) -> Callable[P, None]:
|
||||
from functools import wraps
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
|
||||
# First argument is always self for methods - access through cast for typing
|
||||
self = cast("ExecutionDataClient", args[0])
|
||||
future = self._executor.submit(func, *args, **kwargs)
|
||||
self._pending_tasks.add(future)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ExecutionDataClient:
|
||||
def __init__(
|
||||
self, executor: Executor, graph_exec_id: str, graph_metadata: GraphExecutionMeta
|
||||
):
|
||||
self._executor = executor
|
||||
self._graph_exec_id = graph_exec_id
|
||||
self._cache = ExecutionCache(graph_exec_id, self.db_client_sync)
|
||||
self._pending_tasks = set()
|
||||
self._graph_metadata = graph_metadata
|
||||
self.graph_lock = threading.RLock()
|
||||
|
||||
def finalize_execution(self, timeout: float = 30.0):
|
||||
logger.info(f"Flushing db writes for execution {self._graph_exec_id}")
|
||||
exceptions = []
|
||||
|
||||
# Wait for all pending database operations to complete
|
||||
logger.debug(
|
||||
f"Waiting for {len(self._pending_tasks)} pending database operations"
|
||||
)
|
||||
for future in list(self._pending_tasks):
|
||||
try:
|
||||
future.result(timeout=timeout)
|
||||
except Exception as e:
|
||||
logger.error(f"Background database operation failed: {e}")
|
||||
exceptions.append(e)
|
||||
finally:
|
||||
self._pending_tasks.discard(future)
|
||||
|
||||
self._cache.finalize()
|
||||
|
||||
if exceptions:
|
||||
logger.error(f"Background persistence failed with {len(exceptions)} errors")
|
||||
raise RuntimeError(
|
||||
f"Background persistence failed with {len(exceptions)} errors: {exceptions}"
|
||||
)
|
||||
|
||||
@property
|
||||
def db_client_async(self) -> "DatabaseManagerAsyncClient":
|
||||
return get_database_manager_async_client()
|
||||
|
||||
@property
|
||||
def db_client_sync(self) -> "DatabaseManagerClient":
|
||||
return get_database_manager_client()
|
||||
|
||||
@property
|
||||
def event_bus(self):
|
||||
return get_execution_event_bus()
|
||||
|
||||
async def get_node(self, node_id: str) -> Node:
|
||||
return await self.db_client_async.get_node(node_id)
|
||||
|
||||
def spend_credits(
|
||||
self,
|
||||
user_id: str,
|
||||
cost: int,
|
||||
metadata: UsageTransactionMetadata,
|
||||
) -> int:
|
||||
return self.db_client_sync.spend_credits(
|
||||
user_id=user_id, cost=cost, metadata=metadata
|
||||
)
|
||||
|
||||
def get_graph_execution_meta(
|
||||
self, user_id: str, execution_id: str
|
||||
) -> GraphExecutionMeta | None:
|
||||
return self.db_client_sync.get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=execution_id
|
||||
)
|
||||
|
||||
def get_graph_metadata(
|
||||
self, graph_id: str, graph_version: int | None = None
|
||||
) -> Any:
|
||||
return self.db_client_sync.get_graph_metadata(graph_id, graph_version)
|
||||
|
||||
def get_credits(self, user_id: str) -> int:
|
||||
return self.db_client_sync.get_credits(user_id)
|
||||
|
||||
def get_user_email_by_id(self, user_id: str) -> str | None:
|
||||
return self.db_client_sync.get_user_email_by_id(user_id)
|
||||
|
||||
def get_latest_node_execution(self, node_id: str) -> NodeExecutionResult | None:
|
||||
return self._cache.get_latest_node_execution(node_id)
|
||||
|
||||
def get_node_execution(self, node_exec_id: str) -> NodeExecutionResult | None:
|
||||
return self._cache.get_node_execution(node_exec_id)
|
||||
|
||||
def get_node_executions(
|
||||
self,
|
||||
*,
|
||||
node_id: str | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
block_ids: list[str] | None = None,
|
||||
) -> list[NodeExecutionResult]:
|
||||
return self._cache.get_node_executions(
|
||||
statuses=statuses, block_ids=block_ids, node_id=node_id
|
||||
)
|
||||
|
||||
def update_node_status_and_publish(
|
||||
self,
|
||||
exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: dict | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
):
|
||||
self._cache.update_node_execution_status(exec_id, status, execution_data, stats)
|
||||
self._persist_node_status_to_db(exec_id, status, execution_data, stats)
|
||||
|
||||
def upsert_execution_input(
|
||||
self, node_id: str, input_name: str, input_data: Any, block_id: str
|
||||
) -> tuple[str, dict]:
|
||||
# Validate input parameters to prevent foreign key constraint errors
|
||||
if not node_id or not isinstance(node_id, str):
|
||||
raise ValueError(f"Invalid node_id: {node_id}")
|
||||
if not self._graph_exec_id or not isinstance(self._graph_exec_id, str):
|
||||
raise ValueError(f"Invalid graph_exec_id: {self._graph_exec_id}")
|
||||
if not block_id or not isinstance(block_id, str):
|
||||
raise ValueError(f"Invalid block_id: {block_id}")
|
||||
|
||||
# UPDATE: Try to find an existing incomplete execution for this node and input
|
||||
if result := self._cache.find_incomplete_execution_for_input(
|
||||
node_id, input_name
|
||||
):
|
||||
exec_id, _ = result
|
||||
updated_input_data = self._cache.update_execution_input(
|
||||
exec_id, input_name, input_data
|
||||
)
|
||||
self._persist_add_input_to_db(exec_id, input_name, input_data)
|
||||
return exec_id, updated_input_data
|
||||
|
||||
# CREATE: No suitable execution found, create new one
|
||||
node_exec_id = str(uuid.uuid4())
|
||||
logger.debug(
|
||||
f"Creating new execution {node_exec_id} for node {node_id} "
|
||||
f"in graph execution {self._graph_exec_id}"
|
||||
)
|
||||
|
||||
new_execution = NodeExecutionResult(
|
||||
user_id=self._graph_metadata.user_id,
|
||||
graph_id=self._graph_metadata.graph_id,
|
||||
graph_version=self._graph_metadata.graph_version,
|
||||
graph_exec_id=self._graph_exec_id,
|
||||
node_exec_id=node_exec_id,
|
||||
node_id=node_id,
|
||||
block_id=block_id,
|
||||
status=ExecutionStatus.INCOMPLETE,
|
||||
input_data={input_name: input_data},
|
||||
output_data={},
|
||||
add_time=datetime.now(timezone.utc),
|
||||
)
|
||||
self._cache.add_node_execution(node_exec_id, new_execution)
|
||||
self._persist_new_node_execution_to_db(
|
||||
node_exec_id, node_id, input_name, input_data
|
||||
)
|
||||
|
||||
return node_exec_id, {input_name: input_data}
|
||||
|
||||
def upsert_execution_output(
|
||||
self, node_exec_id: str, output_name: str, output_data: Any
|
||||
):
|
||||
self._cache.upsert_execution_output(node_exec_id, output_name, output_data)
|
||||
self._persist_execution_output_to_db(node_exec_id, output_name, output_data)
|
||||
|
||||
def update_graph_stats_and_publish(
|
||||
self,
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> None:
|
||||
stats_dict = stats.model_dump() if stats else None
|
||||
self._cache.update_graph_stats(status=status, stats=stats_dict)
|
||||
self._persist_graph_stats_to_db(status=status, stats=stats)
|
||||
|
||||
def update_graph_start_time_and_publish(self) -> None:
|
||||
self._cache.update_graph_start_time()
|
||||
self._persist_graph_start_time_to_db()
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_node_status_to_db(
|
||||
self,
|
||||
exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: dict | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
):
|
||||
exec_update = self.db_client_sync.update_node_execution_status(
|
||||
exec_id, status, execution_data, stats
|
||||
)
|
||||
self.event_bus.publish(exec_update)
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_add_input_to_db(
|
||||
self, node_exec_id: str, input_name: str, input_data: Any
|
||||
):
|
||||
self.db_client_sync.add_input_to_node_execution(
|
||||
node_exec_id=node_exec_id,
|
||||
input_name=input_name,
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_execution_output_to_db(
|
||||
self, node_exec_id: str, output_name: str, output_data: Any
|
||||
):
|
||||
self.db_client_sync.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name=output_name,
|
||||
output_data=output_data,
|
||||
)
|
||||
if exec_update := self.get_node_execution(node_exec_id):
|
||||
self.event_bus.publish(exec_update)
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_graph_stats_to_db(
|
||||
self,
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
):
|
||||
graph_update = self.db_client_sync.update_graph_execution_stats(
|
||||
self._graph_exec_id, status, stats
|
||||
)
|
||||
if not graph_update:
|
||||
raise RuntimeError(
|
||||
f"Failed to update graph execution stats for {self._graph_exec_id}"
|
||||
)
|
||||
self.event_bus.publish(graph_update)
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_graph_start_time_to_db(self):
|
||||
graph_update = self.db_client_sync.update_graph_execution_start_time(
|
||||
self._graph_exec_id
|
||||
)
|
||||
if not graph_update:
|
||||
raise RuntimeError(
|
||||
f"Failed to update graph execution start time for {self._graph_exec_id}"
|
||||
)
|
||||
self.event_bus.publish(graph_update)
|
||||
|
||||
async def generate_activity_status(
|
||||
self,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_stats: GraphExecutionStats,
|
||||
user_id: str,
|
||||
execution_status: ExecutionStatus,
|
||||
) -> str | None:
|
||||
from backend.executor.activity_status_generator import (
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
|
||||
return await generate_activity_status_for_execution(
|
||||
graph_exec_id=self._graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_stats=execution_stats,
|
||||
db_client=self.db_client_async,
|
||||
user_id=user_id,
|
||||
execution_status=execution_status,
|
||||
)
|
||||
|
||||
@non_blocking_persist
|
||||
def _send_execution_update(self, execution: NodeExecutionResult):
|
||||
"""Send execution update to event bus."""
|
||||
try:
|
||||
self.event_bus.publish(execution)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send execution update: {e}")
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_new_node_execution_to_db(
|
||||
self, node_exec_id: str, node_id: str, input_name: str, input_data: Any
|
||||
):
|
||||
try:
|
||||
self.db_client_sync.create_node_execution(
|
||||
node_exec_id=node_exec_id,
|
||||
node_id=node_id,
|
||||
graph_exec_id=self._graph_exec_id,
|
||||
input_name=input_name,
|
||||
input_data=input_data,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create node execution {node_exec_id} for node {node_id} "
|
||||
f"in graph execution {self._graph_exec_id}: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def increment_execution_count(self, user_id: str) -> int:
|
||||
r = redis.get_redis()
|
||||
k = f"uec:{user_id}"
|
||||
counter = cast(int, r.incr(k))
|
||||
if counter == 1:
|
||||
r.expire(k, settings.config.execution_counter_expiration_time)
|
||||
return counter
|
||||
668
autogpt_platform/backend/backend/executor/execution_data_test.py
Normal file
668
autogpt_platform/backend/backend/executor/execution_data_test.py
Normal file
@@ -0,0 +1,668 @@
|
||||
"""Test suite for ExecutionDataClient."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.executor.execution_data import ExecutionDataClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def execution_client(event_loop):
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
|
||||
mock_graph_meta = GraphExecutionMeta(
|
||||
id="test_graph_exec_id",
|
||||
user_id="test_user_id",
|
||||
graph_id="test_graph_id",
|
||||
graph_version=1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
# Mock all database operations to prevent connection attempts
|
||||
async_mock_client = AsyncMock()
|
||||
sync_mock_client = MagicMock()
|
||||
|
||||
# Mock all database methods to return None or empty results
|
||||
sync_mock_client.get_node_executions.return_value = []
|
||||
sync_mock_client.create_node_execution.return_value = None
|
||||
sync_mock_client.add_input_to_node_execution.return_value = None
|
||||
sync_mock_client.update_node_execution_status.return_value = None
|
||||
sync_mock_client.upsert_execution_output.return_value = None
|
||||
sync_mock_client.update_graph_execution_stats.return_value = mock_graph_meta
|
||||
sync_mock_client.update_graph_execution_start_time.return_value = mock_graph_meta
|
||||
|
||||
# Mock event bus to prevent connection attempts
|
||||
mock_event_bus = MagicMock()
|
||||
mock_event_bus.publish.return_value = None
|
||||
|
||||
thread = threading.Thread(target=event_loop.run_forever, daemon=True)
|
||||
thread.start()
|
||||
|
||||
with patch(
|
||||
"backend.executor.execution_data.get_database_manager_async_client",
|
||||
return_value=async_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_database_manager_client",
|
||||
return_value=sync_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_execution_event_bus",
|
||||
return_value=mock_event_bus,
|
||||
), patch(
|
||||
"backend.executor.execution_data.non_blocking_persist", lambda func: func
|
||||
):
|
||||
|
||||
client = ExecutionDataClient(executor, "test_graph_exec_id", mock_graph_meta)
|
||||
yield client
|
||||
|
||||
event_loop.call_soon_threadsafe(event_loop.stop)
|
||||
thread.join(timeout=1)
|
||||
|
||||
|
||||
class TestExecutionDataClient:
|
||||
|
||||
async def test_update_node_status_writes_to_cache_immediately(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that node status updates are immediately visible in cache."""
|
||||
# First create an execution to update
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
status = ExecutionStatus.RUNNING
|
||||
execution_data = {"step": "processing"}
|
||||
stats = {"duration": 5.2}
|
||||
|
||||
# Update status of existing execution
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=status,
|
||||
execution_data=execution_data,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
# Verify immediate visibility in cache
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert cached_exec.status == status
|
||||
# execution_data should be merged with existing input_data, not replace it
|
||||
expected_input_data = {"test_input": "test_value", "step": "processing"}
|
||||
assert cached_exec.input_data == expected_input_data
|
||||
|
||||
def test_update_node_status_execution_not_found_raises_error(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that updating non-existent execution raises error instead of creating it."""
|
||||
non_existent_id = "does-not-exist"
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError, match="Execution does-not-exist not found in cache"
|
||||
):
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=non_existent_id, status=ExecutionStatus.COMPLETED
|
||||
)
|
||||
|
||||
async def test_upsert_execution_output_writes_to_cache_immediately(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that output updates are immediately visible in cache."""
|
||||
# First create an execution
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
output_name = "result"
|
||||
output_data = {"answer": 42, "confidence": 0.95}
|
||||
|
||||
# Update to RUNNING status first
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "test"},
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id, output_name=output_name, output_data=output_data
|
||||
)
|
||||
# Check output through the node execution
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert output_name in cached_exec.output_data
|
||||
assert cached_exec.output_data[output_name] == [output_data]
|
||||
|
||||
async def test_get_node_execution_reads_from_cache(self, execution_client):
|
||||
"""Test that get_node_execution returns cached data immediately."""
|
||||
# First create an execution to work with
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Then update its status
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
execution_data={"result": "success"},
|
||||
)
|
||||
|
||||
result = execution_client.get_node_execution(node_exec_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == ExecutionStatus.COMPLETED
|
||||
# execution_data gets merged with existing input_data
|
||||
expected_input_data = {"test_input": "test_value", "result": "success"}
|
||||
assert result.input_data == expected_input_data
|
||||
|
||||
async def test_get_latest_node_execution_reads_from_cache(self, execution_client):
|
||||
"""Test that get_latest_node_execution returns cached data."""
|
||||
node_id = "node-1"
|
||||
|
||||
# First create an execution for this node
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Then update its status to make it non-INCOMPLETE (so it's returned by get_latest)
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"from": "cache"},
|
||||
)
|
||||
|
||||
result = execution_client.get_latest_node_execution(node_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == ExecutionStatus.RUNNING
|
||||
# execution_data gets merged with existing input_data
|
||||
expected_input_data = {"test_input": "test_value", "from": "cache"}
|
||||
assert result.input_data == expected_input_data
|
||||
|
||||
async def test_get_node_executions_sync_filters_correctly(self, execution_client):
|
||||
# Create executions with different statuses
|
||||
executions = [
|
||||
(ExecutionStatus.RUNNING, "block-a"),
|
||||
(ExecutionStatus.COMPLETED, "block-a"),
|
||||
(ExecutionStatus.FAILED, "block-b"),
|
||||
(ExecutionStatus.RUNNING, "block-b"),
|
||||
]
|
||||
|
||||
exec_ids = []
|
||||
for i, (status, block_id) in enumerate(executions):
|
||||
# First create the execution
|
||||
exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id=f"node-{i}",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id=block_id,
|
||||
)
|
||||
exec_ids.append(exec_id)
|
||||
|
||||
# Then update its status and metadata
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=exec_id, status=status, execution_data={"block": block_id}
|
||||
)
|
||||
# Update cached execution with graph_exec_id and block_id for filtering
|
||||
# Note: In real implementation, these would be set during creation
|
||||
# For test purposes, we'll skip this manual update since the filtering
|
||||
# logic should work with the data as created
|
||||
|
||||
# Test status filtering
|
||||
running_execs = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.RUNNING]
|
||||
)
|
||||
assert len(running_execs) == 2
|
||||
assert all(e.status == ExecutionStatus.RUNNING for e in running_execs)
|
||||
|
||||
# Test block_id filtering
|
||||
block_a_execs = execution_client.get_node_executions(block_ids=["block-a"])
|
||||
assert len(block_a_execs) == 2
|
||||
assert all(e.block_id == "block-a" for e in block_a_execs)
|
||||
|
||||
# Test combined filtering
|
||||
running_block_b = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.RUNNING], block_ids=["block-b"]
|
||||
)
|
||||
assert len(running_block_b) == 1
|
||||
assert running_block_b[0].status == ExecutionStatus.RUNNING
|
||||
assert running_block_b[0].block_id == "block-b"
|
||||
|
||||
async def test_write_then_read_consistency(self, execution_client):
|
||||
"""Test critical race condition scenario: immediate read after write."""
|
||||
# First create an execution to work with
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="consistency-test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Write status
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"step": 1},
|
||||
)
|
||||
|
||||
# Write output
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="intermediate",
|
||||
output_data={"progress": 50},
|
||||
)
|
||||
|
||||
# Update status again
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
execution_data={"step": 2},
|
||||
)
|
||||
|
||||
# All changes should be immediately visible
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert cached_exec.status == ExecutionStatus.COMPLETED
|
||||
# execution_data gets merged with existing input_data - step 2 overwrites step 1
|
||||
expected_input_data = {"test_input": "test_value", "step": 2}
|
||||
assert cached_exec.input_data == expected_input_data
|
||||
|
||||
# Output should be visible in execution record
|
||||
assert cached_exec.output_data["intermediate"] == [{"progress": 50}]
|
||||
|
||||
async def test_concurrent_operations_are_thread_safe(self, execution_client):
|
||||
"""Test that concurrent operations don't corrupt cache."""
|
||||
num_threads = 3 # Reduced for simpler test
|
||||
operations_per_thread = 5 # Reduced for simpler test
|
||||
|
||||
# Create all executions upfront
|
||||
created_exec_ids = []
|
||||
for thread_id in range(num_threads):
|
||||
for i in range(operations_per_thread):
|
||||
exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id=f"node-{thread_id}-{i}",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id=f"block-{thread_id}-{i}",
|
||||
)
|
||||
created_exec_ids.append((exec_id, thread_id, i))
|
||||
|
||||
def worker(thread_data):
|
||||
"""Perform multiple operations from a thread."""
|
||||
thread_id, ops = thread_data
|
||||
for i, (exec_id, _, _) in enumerate(ops):
|
||||
# Status updates
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"thread": thread_id, "op": i},
|
||||
)
|
||||
|
||||
# Output updates (use just one exec_id per thread for outputs)
|
||||
if i == 0: # Only add outputs to first execution of each thread
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=exec_id,
|
||||
output_name=f"output_{i}",
|
||||
output_data={"thread": thread_id, "value": i},
|
||||
)
|
||||
|
||||
# Organize executions by thread
|
||||
thread_data = []
|
||||
for tid in range(num_threads):
|
||||
thread_ops = [
|
||||
exec_data for exec_data in created_exec_ids if exec_data[1] == tid
|
||||
]
|
||||
thread_data.append((tid, thread_ops))
|
||||
|
||||
# Start multiple threads
|
||||
threads = []
|
||||
for data in thread_data:
|
||||
thread = threading.Thread(target=worker, args=(data,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Wait for completion
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Verify data integrity
|
||||
expected_executions = num_threads * operations_per_thread
|
||||
all_executions = execution_client.get_node_executions()
|
||||
assert len(all_executions) == expected_executions
|
||||
|
||||
# Verify outputs - only first execution of each thread should have outputs
|
||||
output_count = 0
|
||||
for execution in all_executions:
|
||||
if execution.output_data:
|
||||
output_count += 1
|
||||
assert output_count == num_threads # One output per thread
|
||||
|
||||
async def test_sync_and_async_versions_consistent(self, execution_client):
|
||||
"""Test that sync and async versions of output operations behave the same."""
|
||||
# First create the execution
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="sync-async-test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "test"},
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="sync_result",
|
||||
output_data={"method": "sync"},
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="async_result",
|
||||
output_data={"method": "async"},
|
||||
)
|
||||
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert "sync_result" in cached_exec.output_data
|
||||
assert "async_result" in cached_exec.output_data
|
||||
assert cached_exec.output_data["sync_result"] == [{"method": "sync"}]
|
||||
assert cached_exec.output_data["async_result"] == [{"method": "async"}]
|
||||
|
||||
async def test_finalize_execution_completes_and_clears_cache(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that finalize_execution waits for background tasks and clears cache."""
|
||||
# First create the execution
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="pending-test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Trigger some background operations
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id, status=ExecutionStatus.RUNNING
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id, output_name="test", output_data={"value": 1}
|
||||
)
|
||||
|
||||
# Wait for background tasks - may fail in test environment due to DB issues
|
||||
try:
|
||||
execution_client.finalize_execution(timeout=5.0)
|
||||
except RuntimeError as e:
|
||||
# In test environment, background DB operations may fail, but cache should still be cleared
|
||||
assert "Background persistence failed" in str(e)
|
||||
|
||||
# Cache should be cleared regardless of background task failures
|
||||
all_executions = execution_client.get_node_executions()
|
||||
assert len(all_executions) == 0 # Cache should be cleared
|
||||
|
||||
async def test_manager_usage_pattern(self, execution_client):
|
||||
# Create executions first
|
||||
node_exec_id_1, _ = execution_client.upsert_execution_input(
|
||||
node_id="node-1",
|
||||
input_name="input1",
|
||||
input_data="data1",
|
||||
block_id="block-1",
|
||||
)
|
||||
|
||||
node_exec_id_2, _ = execution_client.upsert_execution_input(
|
||||
node_id="node-2",
|
||||
input_name="input_from_node1",
|
||||
input_data="value1",
|
||||
block_id="block-2",
|
||||
)
|
||||
|
||||
# Simulate manager.py workflow
|
||||
|
||||
# 1. Start execution
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id_1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "data1"},
|
||||
)
|
||||
|
||||
# 2. Node produces output
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id_1,
|
||||
output_name="result",
|
||||
output_data={"output": "value1"},
|
||||
)
|
||||
|
||||
# 3. Complete first node
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id_1, status=ExecutionStatus.COMPLETED
|
||||
)
|
||||
|
||||
# 4. Start second node (would read output from first)
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id_2,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input_from_node1": "value1"},
|
||||
)
|
||||
|
||||
# 5. Manager queries for executions
|
||||
|
||||
all_executions = execution_client.get_node_executions()
|
||||
running_executions = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.RUNNING]
|
||||
)
|
||||
completed_executions = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.COMPLETED]
|
||||
)
|
||||
|
||||
# Verify manager can see all data immediately
|
||||
assert len(all_executions) == 2
|
||||
assert len(running_executions) == 1
|
||||
assert len(completed_executions) == 1
|
||||
|
||||
# Verify output is accessible
|
||||
exec_1 = execution_client.get_node_execution(node_exec_id_1)
|
||||
assert exec_1 is not None
|
||||
assert exec_1.output_data["result"] == [{"output": "value1"}]
|
||||
|
||||
def test_stats_handling_in_update_node_status(self, execution_client):
|
||||
"""Test that stats parameter is properly handled in update_node_status_and_publish."""
|
||||
# Create a fake execution directly in cache to avoid database issues
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from backend.data.execution import NodeExecutionResult
|
||||
|
||||
node_exec_id = "test-stats-exec-id"
|
||||
fake_execution = NodeExecutionResult(
|
||||
user_id="test-user",
|
||||
graph_id="test-graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test-graph-exec",
|
||||
node_exec_id=node_exec_id,
|
||||
node_id="stats-test-node",
|
||||
block_id="test-block",
|
||||
status=ExecutionStatus.INCOMPLETE,
|
||||
input_data={"test_input": "test_value"},
|
||||
output_data={},
|
||||
add_time=datetime.now(timezone.utc),
|
||||
queue_time=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
stats=None,
|
||||
)
|
||||
|
||||
# Add directly to cache
|
||||
execution_client._cache.add_node_execution(node_exec_id, fake_execution)
|
||||
|
||||
stats = {"token_count": 150, "processing_time": 2.5}
|
||||
|
||||
# Update status with stats
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "test"},
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
# Verify execution was updated and stats are stored
|
||||
execution = execution_client.get_node_execution(node_exec_id)
|
||||
assert execution is not None
|
||||
assert execution.status == ExecutionStatus.RUNNING
|
||||
|
||||
# Stats should be stored in proper stats field
|
||||
assert execution.stats is not None
|
||||
stats_dict = execution.stats.model_dump()
|
||||
# Only check the fields we set, ignore defaults
|
||||
assert stats_dict["token_count"] == 150
|
||||
assert stats_dict["processing_time"] == 2.5
|
||||
|
||||
# Update with additional stats
|
||||
additional_stats = {"error_count": 0}
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
stats=additional_stats,
|
||||
)
|
||||
|
||||
# Stats should be merged
|
||||
execution = execution_client.get_node_execution(node_exec_id)
|
||||
assert execution is not None
|
||||
assert execution.status == ExecutionStatus.COMPLETED
|
||||
stats_dict = execution.stats.model_dump()
|
||||
# Check the merged stats
|
||||
assert stats_dict["token_count"] == 150
|
||||
assert stats_dict["processing_time"] == 2.5
|
||||
assert stats_dict["error_count"] == 0
|
||||
|
||||
async def test_upsert_execution_input_scenarios(self, execution_client):
|
||||
"""Test different scenarios of upsert_execution_input - create vs update."""
|
||||
node_id = "test-node"
|
||||
graph_exec_id = (
|
||||
"test_graph_exec_id" # Must match the ExecutionDataClient's scope
|
||||
)
|
||||
|
||||
# Scenario 1: Create new execution when none exists
|
||||
exec_id_1, input_data_1 = execution_client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="first_input",
|
||||
input_data="value1",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Should create new execution
|
||||
execution = execution_client.get_node_execution(exec_id_1)
|
||||
assert execution is not None
|
||||
assert execution.status == ExecutionStatus.INCOMPLETE
|
||||
assert execution.node_id == node_id
|
||||
assert execution.graph_exec_id == graph_exec_id
|
||||
assert input_data_1 == {"first_input": "value1"}
|
||||
|
||||
# Scenario 2: Add input to existing INCOMPLETE execution
|
||||
exec_id_2, input_data_2 = execution_client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="second_input",
|
||||
input_data="value2",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Should use same execution
|
||||
assert exec_id_2 == exec_id_1
|
||||
assert input_data_2 == {"first_input": "value1", "second_input": "value2"}
|
||||
|
||||
# Verify execution has both inputs
|
||||
execution = execution_client.get_node_execution(exec_id_1)
|
||||
assert execution is not None
|
||||
assert execution.input_data == {
|
||||
"first_input": "value1",
|
||||
"second_input": "value2",
|
||||
}
|
||||
|
||||
# Scenario 3: Create new execution when existing is not INCOMPLETE
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=exec_id_1, status=ExecutionStatus.RUNNING
|
||||
)
|
||||
|
||||
exec_id_3, input_data_3 = execution_client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="third_input",
|
||||
input_data="value3",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Should create new execution
|
||||
assert exec_id_3 != exec_id_1
|
||||
execution_3 = execution_client.get_node_execution(exec_id_3)
|
||||
assert execution_3 is not None
|
||||
assert input_data_3 == {"third_input": "value3"}
|
||||
|
||||
# Verify we now have 2 executions
|
||||
all_executions = execution_client.get_node_executions()
|
||||
assert len(all_executions) == 2
|
||||
|
||||
def test_graph_stats_operations(self, execution_client):
|
||||
"""Test graph-level stats and start time operations."""
|
||||
|
||||
# Test update_graph_stats_and_publish
|
||||
from backend.data.model import GraphExecutionStats
|
||||
|
||||
stats = GraphExecutionStats(
|
||||
walltime=10.5, cputime=8.2, node_count=5, node_error_count=1
|
||||
)
|
||||
|
||||
execution_client.update_graph_stats_and_publish(
|
||||
status=ExecutionStatus.RUNNING, stats=stats
|
||||
)
|
||||
|
||||
# Verify stats are stored in cache
|
||||
cached_stats = execution_client._cache._graph_stats
|
||||
assert cached_stats.walltime == 10.5
|
||||
|
||||
execution_client.update_graph_start_time_and_publish()
|
||||
cached_stats = execution_client._cache._graph_stats
|
||||
assert cached_stats.walltime == 10.5
|
||||
|
||||
def test_public_methods_accessible(self, execution_client):
|
||||
"""Test that public methods are accessible."""
|
||||
assert hasattr(execution_client._cache, "update_node_execution_status")
|
||||
assert hasattr(execution_client._cache, "upsert_execution_output")
|
||||
assert hasattr(execution_client._cache, "add_node_execution")
|
||||
assert hasattr(execution_client._cache, "find_incomplete_execution_for_input")
|
||||
assert hasattr(execution_client._cache, "update_execution_input")
|
||||
assert hasattr(execution_client, "upsert_execution_input")
|
||||
assert hasattr(execution_client, "update_node_status_and_publish")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -5,14 +5,31 @@ import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
||||
from typing import Any, Optional, TypeVar, cast
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from redis.asyncio.lock import Lock as RedisLock
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
from backend.data.block import (
|
||||
BlockData,
|
||||
BlockInput,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.execution import (
|
||||
ExecutionQueue,
|
||||
ExecutionStatus,
|
||||
GraphExecutionEntry,
|
||||
NodeExecutionEntry,
|
||||
UserContext,
|
||||
)
|
||||
from backend.data.graph import Link, Node
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
from backend.data.notifications import (
|
||||
AgentRunData,
|
||||
@@ -22,45 +39,14 @@ from backend.data.notifications import (
|
||||
ZeroBalanceData,
|
||||
)
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.activity_status_generator import (
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
from backend.executor.utils import LogMetadata
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.util.exceptions import InsufficientBalanceError, ModerationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerClient, DatabaseManagerAsyncClient
|
||||
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.block import (
|
||||
BlockInput,
|
||||
BlockOutput,
|
||||
BlockOutputEntry,
|
||||
BlockSchema,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.execution import (
|
||||
ExecutionQueue,
|
||||
ExecutionStatus,
|
||||
GraphExecution,
|
||||
GraphExecutionEntry,
|
||||
NodeExecutionEntry,
|
||||
NodeExecutionResult,
|
||||
NodesInputMasks,
|
||||
UserContext,
|
||||
)
|
||||
from backend.data.graph import Link, Node
|
||||
from backend.executor.execution_data import ExecutionDataClient
|
||||
from backend.executor.utils import (
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
CancelExecutionEvent,
|
||||
ExecutionOutputEntry,
|
||||
LogMetadata,
|
||||
NodeExecutionProgress,
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
@@ -69,21 +55,17 @@ from backend.executor.utils import (
|
||||
validate_exec,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.server.v2.AutoMod.manager import automod_manager
|
||||
from backend.util import json
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
get_database_manager_async_client,
|
||||
get_database_manager_client,
|
||||
get_execution_event_bus,
|
||||
get_notification_manager_client,
|
||||
)
|
||||
from backend.util.clients import get_notification_manager_client
|
||||
from backend.util.decorator import (
|
||||
async_error_logged,
|
||||
async_time_measured,
|
||||
error_logged,
|
||||
time_measured,
|
||||
)
|
||||
from backend.util.exceptions import InsufficientBalanceError, ModerationError
|
||||
from backend.util.file import clean_exec_files
|
||||
from backend.util.logging import TruncatedLogger, configure_logging
|
||||
from backend.util.metrics import DiscordChannel
|
||||
@@ -131,14 +113,13 @@ async def execute_node(
|
||||
creds_manager: IntegrationCredentialsManager,
|
||||
data: NodeExecutionEntry,
|
||||
execution_stats: NodeExecutionStats | None = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Execute a node in the graph. This will trigger a block execution on a node,
|
||||
persist the execution result, and return the subsequent node to be executed.
|
||||
|
||||
Args:
|
||||
db_client: The client to send execution updates to the server.
|
||||
creds_manager: The manager to acquire and release credentials.
|
||||
data: The execution data for executing the current node.
|
||||
execution_stats: The execution statistics to be updated.
|
||||
@@ -235,21 +216,20 @@ async def execute_node(
|
||||
|
||||
|
||||
async def _enqueue_next_nodes(
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
execution_data_client: ExecutionDataClient,
|
||||
node: Node,
|
||||
output: BlockOutputEntry,
|
||||
output: BlockData,
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[NodesInputMasks],
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
user_context: UserContext,
|
||||
) -> list[NodeExecutionEntry]:
|
||||
async def add_enqueued_execution(
|
||||
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
|
||||
) -> NodeExecutionEntry:
|
||||
await async_update_node_execution_status(
|
||||
db_client=db_client,
|
||||
execution_data_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.QUEUED,
|
||||
execution_data=data,
|
||||
@@ -282,21 +262,22 @@ async def _enqueue_next_nodes(
|
||||
next_data = parse_execution_output(output, next_output_name)
|
||||
if next_data is None and output_name != next_output_name:
|
||||
return enqueued_executions
|
||||
next_node = await db_client.get_node(next_node_id)
|
||||
next_node = await execution_data_client.get_node(next_node_id)
|
||||
|
||||
# Multiple node can register the same next node, we need this to be atomic
|
||||
# To avoid same execution to be enqueued multiple times,
|
||||
# Or the same input to be consumed multiple times.
|
||||
async with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"):
|
||||
with execution_data_client.graph_lock:
|
||||
# Add output data to the earliest incomplete execution, or create a new one.
|
||||
next_node_exec_id, next_node_input = await db_client.upsert_execution_input(
|
||||
node_id=next_node_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
input_name=next_input_name,
|
||||
input_data=next_data,
|
||||
next_node_exec_id, next_node_input = (
|
||||
execution_data_client.upsert_execution_input(
|
||||
node_id=next_node_id,
|
||||
input_name=next_input_name,
|
||||
input_data=next_data,
|
||||
block_id=next_node.block_id,
|
||||
)
|
||||
)
|
||||
await async_update_node_execution_status(
|
||||
db_client=db_client,
|
||||
execution_data_client.update_node_status_and_publish(
|
||||
exec_id=next_node_exec_id,
|
||||
status=ExecutionStatus.INCOMPLETE,
|
||||
)
|
||||
@@ -308,8 +289,8 @@ async def _enqueue_next_nodes(
|
||||
if link.is_static and link.sink_name not in next_node_input
|
||||
}
|
||||
if static_link_names and (
|
||||
latest_execution := await db_client.get_latest_node_execution(
|
||||
next_node_id, graph_exec_id
|
||||
latest_execution := execution_data_client.get_latest_node_execution(
|
||||
next_node_id
|
||||
)
|
||||
):
|
||||
for name in static_link_names:
|
||||
@@ -348,9 +329,8 @@ async def _enqueue_next_nodes(
|
||||
|
||||
# If link is static, there could be some incomplete executions waiting for it.
|
||||
# Load and complete the input missing input data, and try to re-enqueue them.
|
||||
for iexec in await db_client.get_node_executions(
|
||||
for iexec in execution_data_client.get_node_executions(
|
||||
node_id=next_node_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[ExecutionStatus.INCOMPLETE],
|
||||
):
|
||||
idata = iexec.input_data
|
||||
@@ -414,12 +394,15 @@ class ExecutionProcessor:
|
||||
9. Node executor enqueues the next executed nodes to the node execution queue.
|
||||
"""
|
||||
|
||||
# Current execution data client (scoped to current graph execution)
|
||||
execution_data: ExecutionDataClient
|
||||
|
||||
@async_error_logged(swallow=True)
|
||||
async def on_node_execution(
|
||||
self,
|
||||
node_exec: NodeExecutionEntry,
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
nodes_input_masks: Optional[NodesInputMasks],
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
|
||||
) -> NodeExecutionStats:
|
||||
log_metadata = LogMetadata(
|
||||
@@ -431,8 +414,7 @@ class ExecutionProcessor:
|
||||
node_id=node_exec.node_id,
|
||||
block_name="-",
|
||||
)
|
||||
db_client = get_db_async_client()
|
||||
node = await db_client.get_node(node_exec.node_id)
|
||||
node = await self.execution_data.get_node(node_exec.node_id)
|
||||
execution_stats = NodeExecutionStats()
|
||||
|
||||
timing_info, status = await self._on_node_execution(
|
||||
@@ -440,7 +422,6 @@ class ExecutionProcessor:
|
||||
node_exec=node_exec,
|
||||
node_exec_progress=node_exec_progress,
|
||||
stats=execution_stats,
|
||||
db_client=db_client,
|
||||
log_metadata=log_metadata,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
@@ -464,15 +445,12 @@ class ExecutionProcessor:
|
||||
if node_error and not isinstance(node_error, str):
|
||||
node_stats["error"] = str(node_error) or node_stats.__class__.__name__
|
||||
|
||||
await async_update_node_execution_status(
|
||||
db_client=db_client,
|
||||
self.execution_data.update_node_status_and_publish(
|
||||
exec_id=node_exec.node_exec_id,
|
||||
status=status,
|
||||
stats=node_stats,
|
||||
)
|
||||
await async_update_graph_execution_state(
|
||||
db_client=db_client,
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
self.execution_data.update_graph_stats_and_publish(
|
||||
stats=graph_stats,
|
||||
)
|
||||
|
||||
@@ -485,22 +463,17 @@ class ExecutionProcessor:
|
||||
node_exec: NodeExecutionEntry,
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
stats: NodeExecutionStats,
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> ExecutionStatus:
|
||||
status = ExecutionStatus.RUNNING
|
||||
|
||||
async def persist_output(output_name: str, output_data: Any) -> None:
|
||||
await db_client.upsert_execution_output(
|
||||
self.execution_data.upsert_execution_output(
|
||||
node_exec_id=node_exec.node_exec_id,
|
||||
output_name=output_name,
|
||||
output_data=output_data,
|
||||
)
|
||||
if exec_update := await db_client.get_node_execution(
|
||||
node_exec.node_exec_id
|
||||
):
|
||||
await send_async_execution_update(exec_update)
|
||||
|
||||
node_exec_progress.add_output(
|
||||
ExecutionOutputEntry(
|
||||
@@ -512,8 +485,7 @@ class ExecutionProcessor:
|
||||
|
||||
try:
|
||||
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
|
||||
await async_update_node_execution_status(
|
||||
db_client=db_client,
|
||||
self.execution_data.update_node_status_and_publish(
|
||||
exec_id=node_exec.node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
)
|
||||
@@ -574,6 +546,8 @@ class ExecutionProcessor:
|
||||
self.node_evaluation_thread = threading.Thread(
|
||||
target=self.node_evaluation_loop.run_forever, daemon=True
|
||||
)
|
||||
# single thread executor
|
||||
self.execution_data_executor = ThreadPoolExecutor(max_workers=1)
|
||||
self.node_execution_thread.start()
|
||||
self.node_evaluation_thread.start()
|
||||
logger.info(f"[GraphExecutor] {self.tid} started")
|
||||
@@ -593,9 +567,13 @@ class ExecutionProcessor:
|
||||
node_eid="*",
|
||||
block_name="-",
|
||||
)
|
||||
db_client = get_db_client()
|
||||
|
||||
exec_meta = db_client.get_graph_execution_meta(
|
||||
# Get graph execution metadata first via sync client
|
||||
from backend.util.clients import get_database_manager_client
|
||||
|
||||
db_client_sync = get_database_manager_client()
|
||||
|
||||
exec_meta = db_client_sync.get_graph_execution_meta(
|
||||
user_id=graph_exec.user_id,
|
||||
execution_id=graph_exec.graph_exec_id,
|
||||
)
|
||||
@@ -605,12 +583,15 @@ class ExecutionProcessor:
|
||||
)
|
||||
return
|
||||
|
||||
if exec_meta.status in [ExecutionStatus.QUEUED, ExecutionStatus.INCOMPLETE]:
|
||||
# Create scoped ExecutionDataClient for this graph execution with metadata
|
||||
self.execution_data = ExecutionDataClient(
|
||||
self.execution_data_executor, graph_exec.graph_exec_id, exec_meta
|
||||
)
|
||||
|
||||
if exec_meta.status == ExecutionStatus.QUEUED:
|
||||
log_metadata.info(f"⚙️ Starting graph execution #{graph_exec.graph_exec_id}")
|
||||
exec_meta.status = ExecutionStatus.RUNNING
|
||||
send_execution_update(
|
||||
db_client.update_graph_execution_start_time(graph_exec.graph_exec_id)
|
||||
)
|
||||
self.execution_data.update_graph_start_time_and_publish()
|
||||
elif exec_meta.status == ExecutionStatus.RUNNING:
|
||||
log_metadata.info(
|
||||
f"⚙️ Graph execution #{graph_exec.graph_exec_id} is already running, continuing where it left off."
|
||||
@@ -620,9 +601,7 @@ class ExecutionProcessor:
|
||||
log_metadata.info(
|
||||
f"⚙️ Graph execution #{graph_exec.graph_exec_id} was disturbed, continuing where it left off."
|
||||
)
|
||||
update_graph_execution_state(
|
||||
db_client=db_client,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
self.execution_data.update_graph_stats_and_publish(
|
||||
status=ExecutionStatus.RUNNING,
|
||||
)
|
||||
else:
|
||||
@@ -653,12 +632,10 @@ class ExecutionProcessor:
|
||||
|
||||
# Activity status handling
|
||||
activity_status = asyncio.run_coroutine_threadsafe(
|
||||
generate_activity_status_for_execution(
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
self.execution_data.generate_activity_status(
|
||||
graph_id=graph_exec.graph_id,
|
||||
graph_version=graph_exec.graph_version,
|
||||
execution_stats=exec_stats,
|
||||
db_client=get_db_async_client(),
|
||||
user_id=graph_exec.user_id,
|
||||
execution_status=status,
|
||||
),
|
||||
@@ -673,15 +650,14 @@ class ExecutionProcessor:
|
||||
)
|
||||
|
||||
# Communication handling
|
||||
self._handle_agent_run_notif(db_client, graph_exec, exec_stats)
|
||||
self._handle_agent_run_notif(graph_exec, exec_stats)
|
||||
|
||||
finally:
|
||||
update_graph_execution_state(
|
||||
db_client=db_client,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
self.execution_data.update_graph_stats_and_publish(
|
||||
status=exec_meta.status,
|
||||
stats=exec_stats,
|
||||
)
|
||||
self.execution_data.finalize_execution()
|
||||
|
||||
def _charge_usage(
|
||||
self,
|
||||
@@ -690,7 +666,6 @@ class ExecutionProcessor:
|
||||
) -> tuple[int, int]:
|
||||
total_cost = 0
|
||||
remaining_balance = 0
|
||||
db_client = get_db_client()
|
||||
block = get_block(node_exec.block_id)
|
||||
if not block:
|
||||
logger.error(f"Block {node_exec.block_id} not found.")
|
||||
@@ -700,7 +675,7 @@ class ExecutionProcessor:
|
||||
block=block, input_data=node_exec.inputs
|
||||
)
|
||||
if cost > 0:
|
||||
remaining_balance = db_client.spend_credits(
|
||||
remaining_balance = self.execution_data.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
@@ -718,7 +693,7 @@ class ExecutionProcessor:
|
||||
|
||||
cost, usage_count = execution_usage_cost(execution_count)
|
||||
if cost > 0:
|
||||
remaining_balance = db_client.spend_credits(
|
||||
remaining_balance = self.execution_data.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
@@ -751,7 +726,6 @@ class ExecutionProcessor:
|
||||
"""
|
||||
execution_status: ExecutionStatus = ExecutionStatus.RUNNING
|
||||
error: Exception | None = None
|
||||
db_client = get_db_client()
|
||||
execution_stats_lock = threading.Lock()
|
||||
|
||||
# State holders ----------------------------------------------------
|
||||
@@ -762,7 +736,7 @@ class ExecutionProcessor:
|
||||
execution_queue = ExecutionQueue[NodeExecutionEntry]()
|
||||
|
||||
try:
|
||||
if db_client.get_credits(graph_exec.user_id) <= 0:
|
||||
if self.execution_data.get_credits(graph_exec.user_id) <= 0:
|
||||
raise InsufficientBalanceError(
|
||||
user_id=graph_exec.user_id,
|
||||
message="You have no credits left to run an agent.",
|
||||
@@ -774,7 +748,7 @@ class ExecutionProcessor:
|
||||
try:
|
||||
if moderation_error := asyncio.run_coroutine_threadsafe(
|
||||
automod_manager.moderate_graph_execution_inputs(
|
||||
db_client=get_db_async_client(),
|
||||
db_client=self.execution_data.db_client_async,
|
||||
graph_exec=graph_exec,
|
||||
),
|
||||
self.node_evaluation_loop,
|
||||
@@ -789,16 +763,34 @@ class ExecutionProcessor:
|
||||
# ------------------------------------------------------------
|
||||
# Pre‑populate queue ---------------------------------------
|
||||
# ------------------------------------------------------------
|
||||
for node_exec in db_client.get_node_executions(
|
||||
graph_exec.graph_exec_id,
|
||||
|
||||
queued_executions = self.execution_data.get_node_executions(
|
||||
statuses=[
|
||||
ExecutionStatus.RUNNING,
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.TERMINATED,
|
||||
],
|
||||
):
|
||||
node_entry = node_exec.to_node_execution_entry(graph_exec.user_context)
|
||||
execution_queue.add(node_entry)
|
||||
)
|
||||
log_metadata.info(
|
||||
f"Pre-populating queue with {len(queued_executions)} executions from cache"
|
||||
)
|
||||
|
||||
for i, node_exec in enumerate(queued_executions):
|
||||
log_metadata.info(
|
||||
f" [{i}] {node_exec.node_exec_id}: status={node_exec.status}, node={node_exec.node_id}"
|
||||
)
|
||||
try:
|
||||
node_entry = node_exec.to_node_execution_entry(
|
||||
graph_exec.user_context
|
||||
)
|
||||
execution_queue.add(node_entry)
|
||||
log_metadata.info(" Added to execution queue successfully")
|
||||
except Exception as e:
|
||||
log_metadata.error(f" Failed to add to execution queue: {e}")
|
||||
|
||||
log_metadata.info(
|
||||
f"Execution queue populated with {len(queued_executions)} executions"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Main dispatch / polling loop -----------------------------
|
||||
@@ -818,13 +810,14 @@ class ExecutionProcessor:
|
||||
try:
|
||||
cost, remaining_balance = self._charge_usage(
|
||||
node_exec=queued_node_exec,
|
||||
execution_count=increment_execution_count(graph_exec.user_id),
|
||||
execution_count=self.execution_data.increment_execution_count(
|
||||
graph_exec.user_id
|
||||
),
|
||||
)
|
||||
with execution_stats_lock:
|
||||
execution_stats.cost += cost
|
||||
# Check if we crossed the low balance threshold
|
||||
self._handle_low_balance(
|
||||
db_client=db_client,
|
||||
user_id=graph_exec.user_id,
|
||||
current_balance=remaining_balance,
|
||||
transaction_cost=cost,
|
||||
@@ -832,19 +825,17 @@ class ExecutionProcessor:
|
||||
except InsufficientBalanceError as balance_error:
|
||||
error = balance_error # Set error to trigger FAILED status
|
||||
node_exec_id = queued_node_exec.node_exec_id
|
||||
db_client.upsert_execution_output(
|
||||
self.execution_data.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="error",
|
||||
output_data=str(error),
|
||||
)
|
||||
update_node_execution_status(
|
||||
db_client=db_client,
|
||||
self.execution_data.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.FAILED,
|
||||
)
|
||||
|
||||
self._handle_insufficient_funds_notif(
|
||||
db_client,
|
||||
graph_exec.user_id,
|
||||
graph_exec.graph_id,
|
||||
error,
|
||||
@@ -931,12 +922,13 @@ class ExecutionProcessor:
|
||||
time.sleep(0.1)
|
||||
|
||||
# loop done --------------------------------------------------
|
||||
# Background task finalization moved to finally block
|
||||
|
||||
# Output moderation
|
||||
try:
|
||||
if moderation_error := asyncio.run_coroutine_threadsafe(
|
||||
automod_manager.moderate_graph_execution_outputs(
|
||||
db_client=get_db_async_client(),
|
||||
db_client=self.execution_data.db_client_async,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
user_id=graph_exec.user_id,
|
||||
graph_id=graph_exec.graph_id,
|
||||
@@ -990,7 +982,6 @@ class ExecutionProcessor:
|
||||
error=error,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
log_metadata=log_metadata,
|
||||
db_client=db_client,
|
||||
)
|
||||
|
||||
@error_logged(swallow=True)
|
||||
@@ -1003,7 +994,6 @@ class ExecutionProcessor:
|
||||
error: Exception | None,
|
||||
graph_exec_id: str,
|
||||
log_metadata: LogMetadata,
|
||||
db_client: "DatabaseManagerClient",
|
||||
) -> None:
|
||||
"""
|
||||
Clean up running node executions and evaluations when graph execution ends.
|
||||
@@ -1037,8 +1027,7 @@ class ExecutionProcessor:
|
||||
)
|
||||
|
||||
while queued_execution := execution_queue.get_or_none():
|
||||
update_node_execution_status(
|
||||
db_client=db_client,
|
||||
self.execution_data.update_node_status_and_publish(
|
||||
exec_id=queued_execution.node_exec_id,
|
||||
status=execution_status,
|
||||
stats={"error": str(error)} if error else None,
|
||||
@@ -1053,7 +1042,7 @@ class ExecutionProcessor:
|
||||
node_id: str,
|
||||
graph_exec: GraphExecutionEntry,
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[NodesInputMasks],
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
execution_queue: ExecutionQueue[NodeExecutionEntry],
|
||||
) -> None:
|
||||
"""Process a node's output, update its status, and enqueue next nodes.
|
||||
@@ -1066,12 +1055,10 @@ class ExecutionProcessor:
|
||||
nodes_input_masks: Optional map of node input overrides
|
||||
execution_queue: Queue to add next executions to
|
||||
"""
|
||||
db_client = get_db_async_client()
|
||||
|
||||
log_metadata.debug(f"Enqueue nodes for {node_id}: {output}")
|
||||
|
||||
for next_execution in await _enqueue_next_nodes(
|
||||
db_client=db_client,
|
||||
execution_data_client=self.execution_data,
|
||||
node=output.node,
|
||||
output=output.data,
|
||||
user_id=graph_exec.user_id,
|
||||
@@ -1085,15 +1072,13 @@ class ExecutionProcessor:
|
||||
|
||||
def _handle_agent_run_notif(
|
||||
self,
|
||||
db_client: "DatabaseManagerClient",
|
||||
graph_exec: GraphExecutionEntry,
|
||||
exec_stats: GraphExecutionStats,
|
||||
):
|
||||
metadata = db_client.get_graph_metadata(
|
||||
metadata = self.execution_data.get_graph_metadata(
|
||||
graph_exec.graph_id, graph_exec.graph_version
|
||||
)
|
||||
outputs = db_client.get_node_executions(
|
||||
graph_exec.graph_exec_id,
|
||||
outputs = self.execution_data.get_node_executions(
|
||||
block_ids=[AgentOutputBlock().id],
|
||||
)
|
||||
|
||||
@@ -1122,13 +1107,12 @@ class ExecutionProcessor:
|
||||
|
||||
def _handle_insufficient_funds_notif(
|
||||
self,
|
||||
db_client: "DatabaseManagerClient",
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
e: InsufficientBalanceError,
|
||||
):
|
||||
shortfall = abs(e.amount) - e.balance
|
||||
metadata = db_client.get_graph_metadata(graph_id)
|
||||
metadata = self.execution_data.get_graph_metadata(graph_id)
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
@@ -1147,7 +1131,7 @@ class ExecutionProcessor:
|
||||
)
|
||||
|
||||
try:
|
||||
user_email = db_client.get_user_email_by_id(user_id)
|
||||
user_email = self.execution_data.get_user_email_by_id(user_id)
|
||||
|
||||
alert_message = (
|
||||
f"❌ **Insufficient Funds Alert**\n"
|
||||
@@ -1169,7 +1153,6 @@ class ExecutionProcessor:
|
||||
|
||||
def _handle_low_balance(
|
||||
self,
|
||||
db_client: "DatabaseManagerClient",
|
||||
user_id: str,
|
||||
current_balance: int,
|
||||
transaction_cost: int,
|
||||
@@ -1198,7 +1181,7 @@ class ExecutionProcessor:
|
||||
)
|
||||
|
||||
try:
|
||||
user_email = db_client.get_user_email_by_id(user_id)
|
||||
user_email = self.execution_data.get_user_email_by_id(user_id)
|
||||
alert_message = (
|
||||
f"⚠️ **Low Balance Alert**\n"
|
||||
f"User: {user_email or user_id}\n"
|
||||
@@ -1576,117 +1559,3 @@ class ExecutionManager(AppProcess):
|
||||
)
|
||||
|
||||
logger.info(f"{prefix} ✅ Finished GraphExec cleanup")
|
||||
|
||||
|
||||
# ------- UTILITIES ------- #
|
||||
|
||||
|
||||
def get_db_client() -> "DatabaseManagerClient":
|
||||
return get_database_manager_client()
|
||||
|
||||
|
||||
def get_db_async_client() -> "DatabaseManagerAsyncClient":
|
||||
return get_database_manager_async_client()
|
||||
|
||||
|
||||
@func_retry
|
||||
async def send_async_execution_update(
|
||||
entry: GraphExecution | NodeExecutionResult | None,
|
||||
) -> None:
|
||||
if entry is None:
|
||||
return
|
||||
await get_async_execution_event_bus().publish(entry)
|
||||
|
||||
|
||||
@func_retry
|
||||
def send_execution_update(entry: GraphExecution | NodeExecutionResult | None):
|
||||
if entry is None:
|
||||
return
|
||||
return get_execution_event_bus().publish(entry)
|
||||
|
||||
|
||||
async def async_update_node_execution_status(
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: BlockInput | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
) -> NodeExecutionResult:
|
||||
"""Sets status and fetches+broadcasts the latest state of the node execution"""
|
||||
exec_update = await db_client.update_node_execution_status(
|
||||
exec_id, status, execution_data, stats
|
||||
)
|
||||
await send_async_execution_update(exec_update)
|
||||
return exec_update
|
||||
|
||||
|
||||
def update_node_execution_status(
|
||||
db_client: "DatabaseManagerClient",
|
||||
exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: BlockInput | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
) -> NodeExecutionResult:
|
||||
"""Sets status and fetches+broadcasts the latest state of the node execution"""
|
||||
exec_update = db_client.update_node_execution_status(
|
||||
exec_id, status, execution_data, stats
|
||||
)
|
||||
send_execution_update(exec_update)
|
||||
return exec_update
|
||||
|
||||
|
||||
async def async_update_graph_execution_state(
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
graph_exec_id: str,
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> GraphExecution | None:
|
||||
"""Sets status and fetches+broadcasts the latest state of the graph execution"""
|
||||
graph_update = await db_client.update_graph_execution_stats(
|
||||
graph_exec_id, status, stats
|
||||
)
|
||||
if graph_update:
|
||||
await send_async_execution_update(graph_update)
|
||||
else:
|
||||
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
|
||||
return graph_update
|
||||
|
||||
|
||||
def update_graph_execution_state(
|
||||
db_client: "DatabaseManagerClient",
|
||||
graph_exec_id: str,
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> GraphExecution | None:
|
||||
"""Sets status and fetches+broadcasts the latest state of the graph execution"""
|
||||
graph_update = db_client.update_graph_execution_stats(graph_exec_id, status, stats)
|
||||
if graph_update:
|
||||
send_execution_update(graph_update)
|
||||
else:
|
||||
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
|
||||
return graph_update
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def synchronized(key: str, timeout: int = 60):
|
||||
r = await redis.get_redis_async()
|
||||
lock: RedisLock = r.lock(f"lock:{key}", timeout=timeout)
|
||||
try:
|
||||
await lock.acquire()
|
||||
yield
|
||||
finally:
|
||||
if await lock.locked() and await lock.owned():
|
||||
await lock.release()
|
||||
|
||||
|
||||
def increment_execution_count(user_id: str) -> int:
|
||||
"""
|
||||
Increment the execution count for a given user,
|
||||
this will be used to charge the user for the execution cost.
|
||||
"""
|
||||
r = redis.get_redis()
|
||||
k = f"uec:{user_id}" # User Execution Count global key
|
||||
counter = cast(int, r.incr(k))
|
||||
if counter == 1:
|
||||
r.expire(k, settings.config.execution_counter_expiration_time)
|
||||
return counter
|
||||
|
||||
@@ -32,13 +32,17 @@ async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
|
||||
mock_settings.config.low_balance_threshold = 500 # $5 threshold
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
# Initialize the execution processor and mock its execution_data
|
||||
execution_processor.on_graph_executor_start()
|
||||
|
||||
# Mock the execution_data attribute since it's created in on_graph_execution
|
||||
mock_execution_data = MagicMock()
|
||||
execution_processor.execution_data = mock_execution_data
|
||||
|
||||
mock_execution_data.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
transaction_cost=transaction_cost,
|
||||
@@ -62,6 +66,19 @@ async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
|
||||
assert "$4.00" in discord_message
|
||||
assert "$6.00" in discord_message
|
||||
|
||||
# Cleanup execution processor threads
|
||||
try:
|
||||
execution_processor.node_execution_loop.call_soon_threadsafe(
|
||||
execution_processor.node_execution_loop.stop
|
||||
)
|
||||
execution_processor.node_evaluation_loop.call_soon_threadsafe(
|
||||
execution_processor.node_evaluation_loop.stop
|
||||
)
|
||||
execution_processor.node_execution_thread.join(timeout=1)
|
||||
execution_processor.node_evaluation_thread.join(timeout=1)
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_low_balance_no_notification_when_not_crossing(
|
||||
@@ -90,12 +107,17 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.low_balance_threshold = 500 # $5 threshold
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
# Initialize the execution processor and mock its execution_data
|
||||
execution_processor.on_graph_executor_start()
|
||||
|
||||
# Mock the execution_data attribute since it's created in on_graph_execution
|
||||
mock_execution_data = MagicMock()
|
||||
execution_processor.execution_data = mock_execution_data
|
||||
|
||||
mock_execution_data.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
transaction_cost=transaction_cost,
|
||||
@@ -105,6 +127,19 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
|
||||
mock_queue_notif.assert_not_called()
|
||||
mock_client.discord_system_alert.assert_not_called()
|
||||
|
||||
# Cleanup execution processor threads
|
||||
try:
|
||||
execution_processor.node_execution_loop.call_soon_threadsafe(
|
||||
execution_processor.node_execution_loop.stop
|
||||
)
|
||||
execution_processor.node_evaluation_loop.call_soon_threadsafe(
|
||||
execution_processor.node_evaluation_loop.stop
|
||||
)
|
||||
execution_processor.node_execution_thread.join(timeout=1)
|
||||
execution_processor.node_evaluation_thread.join(timeout=1)
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_low_balance_no_duplicate_when_already_below(
|
||||
@@ -133,12 +168,17 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.low_balance_threshold = 500 # $5 threshold
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
# Initialize the execution processor and mock its execution_data
|
||||
execution_processor.on_graph_executor_start()
|
||||
|
||||
# Mock the execution_data attribute since it's created in on_graph_execution
|
||||
mock_execution_data = MagicMock()
|
||||
execution_processor.execution_data = mock_execution_data
|
||||
|
||||
mock_execution_data.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
transaction_cost=transaction_cost,
|
||||
@@ -147,3 +187,16 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
|
||||
# Verify no notification was sent (user was already below threshold)
|
||||
mock_queue_notif.assert_not_called()
|
||||
mock_client.discord_system_alert.assert_not_called()
|
||||
|
||||
# Cleanup execution processor threads
|
||||
try:
|
||||
execution_processor.node_execution_loop.call_soon_threadsafe(
|
||||
execution_processor.node_execution_loop.stop
|
||||
)
|
||||
execution_processor.node_evaluation_loop.call_soon_threadsafe(
|
||||
execution_processor.node_evaluation_loop.stop
|
||||
)
|
||||
execution_processor.node_execution_thread.join(timeout=1)
|
||||
execution_processor.node_evaluation_thread.join(timeout=1)
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
@@ -35,20 +35,21 @@ async def execute_graph(
|
||||
logger.info(f"Input data: {input_data}")
|
||||
|
||||
# --- Test adding new executions --- #
|
||||
graph_exec = await agent_server.test_execute_graph(
|
||||
response = await agent_server.test_execute_graph(
|
||||
user_id=test_user.id,
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
node_input=input_data,
|
||||
)
|
||||
logger.info(f"Created execution with ID: {graph_exec.id}")
|
||||
graph_exec_id = response.graph_exec_id
|
||||
logger.info(f"Created execution with ID: {graph_exec_id}")
|
||||
|
||||
# Execution queue should be empty
|
||||
logger.info("Waiting for execution to complete...")
|
||||
result = await wait_execution(test_user.id, graph_exec.id, 30)
|
||||
result = await wait_execution(test_user.id, graph_exec_id, 30)
|
||||
logger.info(f"Execution completed with {len(result)} results")
|
||||
assert len(result) == num_execs
|
||||
return graph_exec.id
|
||||
return graph_exec_id
|
||||
|
||||
|
||||
async def assert_sample_graph_executions(
|
||||
@@ -378,7 +379,7 @@ async def test_execute_preset(server: SpinTestServer):
|
||||
|
||||
# Verify execution
|
||||
assert result is not None
|
||||
graph_exec_id = result.id
|
||||
graph_exec_id = result["id"]
|
||||
|
||||
# Wait for execution to complete
|
||||
executions = await wait_execution(test_user.id, graph_exec_id)
|
||||
@@ -467,7 +468,7 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
|
||||
|
||||
# Verify execution
|
||||
assert result is not None, "Result must not be None"
|
||||
graph_exec_id = result.id
|
||||
graph_exec_id = result["id"]
|
||||
|
||||
# Wait for execution to complete
|
||||
executions = await wait_execution(test_user.id, graph_exec_id)
|
||||
|
||||
@@ -191,22 +191,15 @@ class GraphExecutionJobInfo(GraphExecutionJobArgs):
|
||||
id: str
|
||||
name: str
|
||||
next_run_time: str
|
||||
timezone: str = Field(default="UTC", description="Timezone used for scheduling")
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
job_args: GraphExecutionJobArgs, job_obj: JobObj
|
||||
) -> "GraphExecutionJobInfo":
|
||||
# Extract timezone from the trigger if it's a CronTrigger
|
||||
timezone_str = "UTC"
|
||||
if hasattr(job_obj.trigger, "timezone"):
|
||||
timezone_str = str(job_obj.trigger.timezone)
|
||||
|
||||
return GraphExecutionJobInfo(
|
||||
id=job_obj.id,
|
||||
name=job_obj.name,
|
||||
next_run_time=job_obj.next_run_time.isoformat(),
|
||||
timezone=timezone_str,
|
||||
**job_args.model_dump(),
|
||||
)
|
||||
|
||||
@@ -402,7 +395,6 @@ class Scheduler(AppService):
|
||||
input_data: BlockInput,
|
||||
input_credentials: dict[str, CredentialsMetaInput],
|
||||
name: Optional[str] = None,
|
||||
user_timezone: str | None = None,
|
||||
) -> GraphExecutionJobInfo:
|
||||
# Validate the graph before scheduling to prevent runtime failures
|
||||
# We don't need the return value, just want the validation to run
|
||||
@@ -416,18 +408,7 @@ class Scheduler(AppService):
|
||||
)
|
||||
)
|
||||
|
||||
# Use provided timezone or default to UTC
|
||||
# Note: Timezone should be passed from the client to avoid database lookups
|
||||
if not user_timezone:
|
||||
user_timezone = "UTC"
|
||||
logger.warning(
|
||||
f"No timezone provided for user {user_id}, using UTC for scheduling. "
|
||||
f"Client should pass user's timezone for correct scheduling."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Scheduling job for user {user_id} with timezone {user_timezone} (cron: {cron})"
|
||||
)
|
||||
logger.info(f"Scheduling job for user {user_id} in UTC (cron: {cron})")
|
||||
|
||||
job_args = GraphExecutionJobArgs(
|
||||
user_id=user_id,
|
||||
@@ -441,12 +422,12 @@ class Scheduler(AppService):
|
||||
execute_graph,
|
||||
kwargs=job_args.model_dump(),
|
||||
name=name,
|
||||
trigger=CronTrigger.from_crontab(cron, timezone=user_timezone),
|
||||
trigger=CronTrigger.from_crontab(cron, timezone="UTC"),
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
replace_existing=True,
|
||||
)
|
||||
logger.info(
|
||||
f"Added job {job.id} with cron schedule '{cron}' in timezone {user_timezone}, input data: {input_data}"
|
||||
f"Added job {job.id} with cron schedule '{cron}' in UTC, input data: {input_data}"
|
||||
)
|
||||
return GraphExecutionJobInfo.from_db(job_args, job)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from backend.executor.utils import merge_execution_input, parse_execution_output
|
||||
from backend.util.mock import MockObject
|
||||
@@ -277,147 +276,3 @@ def test_merge_execution_input():
|
||||
result = merge_execution_input(data)
|
||||
assert "mixed" in result
|
||||
assert result["mixed"].attr[0]["key"] == "value3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
"""
|
||||
Verify that calling the function with its own output creates the same execution again.
|
||||
"""
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# Mock data
|
||||
graph_id = "test-graph-id"
|
||||
user_id = "test-user-id"
|
||||
inputs = {"test_input": "test_value"}
|
||||
preset_id = "test-preset-id"
|
||||
graph_version = 1
|
||||
graph_credentials_inputs = {
|
||||
"cred_key": CredentialsMetaInput(
|
||||
id="cred-id", provider=ProviderName("test_provider"), type="oauth2"
|
||||
)
|
||||
}
|
||||
nodes_input_masks = {"node1": {"input1": "masked_value"}}
|
||||
|
||||
# Mock the graph object returned by validate_and_construct_node_execution_input
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.version = graph_version
|
||||
|
||||
# Mock the starting nodes input and compiled nodes input masks
|
||||
starting_nodes_input = [
|
||||
("node1", {"input1": "value1"}),
|
||||
("node2", {"input1": "value2"}),
|
||||
]
|
||||
compiled_nodes_input_masks = {"node1": {"input1": "compiled_mask"}}
|
||||
|
||||
# Mock the graph execution object
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
|
||||
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||
|
||||
# Mock user context
|
||||
mock_user_context = {"user_id": user_id, "context": "test_context"}
|
||||
|
||||
# Mock the queue and event bus
|
||||
mock_queue = mocker.AsyncMock()
|
||||
mock_event_bus = mocker.MagicMock()
|
||||
mock_event_bus.publish = mocker.AsyncMock()
|
||||
|
||||
# Setup mocks
|
||||
mock_validate = mocker.patch(
|
||||
"backend.executor.utils.validate_and_construct_node_execution_input"
|
||||
)
|
||||
mock_edb = mocker.patch("backend.executor.utils.execution_db")
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_get_user_context = mocker.patch("backend.executor.utils.get_user_context")
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
|
||||
# Setup mock returns
|
||||
mock_validate.return_value = (
|
||||
mock_graph,
|
||||
starting_nodes_input,
|
||||
compiled_nodes_input_masks,
|
||||
)
|
||||
mock_prisma.is_connected.return_value = True
|
||||
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
||||
mock_edb.update_graph_execution_stats = mocker.AsyncMock(
|
||||
return_value=mock_graph_exec
|
||||
)
|
||||
mock_edb.update_node_execution_status_batch = mocker.AsyncMock()
|
||||
mock_get_user_context.return_value = mock_user_context
|
||||
mock_get_queue.return_value = mock_queue
|
||||
mock_get_event_bus.return_value = mock_event_bus
|
||||
|
||||
# Call the function - first execution
|
||||
result1 = await add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
preset_id=preset_id,
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=graph_credentials_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
|
||||
# Store the parameters used in the first call to create_graph_execution
|
||||
first_call_kwargs = mock_edb.create_graph_execution.call_args[1]
|
||||
|
||||
# Verify the create_graph_execution was called with correct parameters
|
||||
mock_edb.create_graph_execution.assert_called_once_with(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=mock_graph.version,
|
||||
inputs=inputs,
|
||||
credential_inputs=graph_credentials_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
starting_nodes_input=starting_nodes_input,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
# Set up the graph execution mock to have properties we can extract
|
||||
mock_graph_exec.graph_id = graph_id
|
||||
mock_graph_exec.user_id = user_id
|
||||
mock_graph_exec.graph_version = graph_version
|
||||
mock_graph_exec.inputs = inputs
|
||||
mock_graph_exec.credential_inputs = graph_credentials_inputs
|
||||
mock_graph_exec.nodes_input_masks = nodes_input_masks
|
||||
mock_graph_exec.preset_id = preset_id
|
||||
|
||||
# Create a second mock execution for the sanity check
|
||||
mock_graph_exec_2 = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec_2.id = "execution-id-456"
|
||||
mock_graph_exec_2.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||
|
||||
# Reset mocks and set up for second call
|
||||
mock_edb.create_graph_execution.reset_mock()
|
||||
mock_edb.create_graph_execution.return_value = mock_graph_exec_2
|
||||
mock_validate.reset_mock()
|
||||
|
||||
# Sanity check: call add_graph_execution with properties from first result
|
||||
# This should create the same execution parameters
|
||||
result2 = await add_graph_execution(
|
||||
graph_id=mock_graph_exec.graph_id,
|
||||
user_id=mock_graph_exec.user_id,
|
||||
inputs=mock_graph_exec.inputs,
|
||||
preset_id=mock_graph_exec.preset_id,
|
||||
graph_version=mock_graph_exec.graph_version,
|
||||
graph_credentials_inputs=mock_graph_exec.credential_inputs,
|
||||
nodes_input_masks=mock_graph_exec.nodes_input_masks,
|
||||
)
|
||||
|
||||
# Verify that create_graph_execution was called with identical parameters
|
||||
second_call_kwargs = mock_edb.create_graph_execution.call_args[1]
|
||||
|
||||
# The sanity check: both calls should use identical parameters
|
||||
assert first_call_kwargs == second_call_kwargs
|
||||
|
||||
# Both executions should succeed (though they create different objects)
|
||||
assert result1 == mock_graph_exec
|
||||
assert result2 == mock_graph_exec_2
|
||||
@@ -4,27 +4,20 @@ import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future
|
||||
from typing import Any, Mapping, Optional, cast
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, JsonValue, ValidationError
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCostType,
|
||||
BlockInput,
|
||||
BlockOutputEntry,
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCostType
|
||||
from backend.data.db import prisma
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionStats,
|
||||
GraphExecutionWithNodes,
|
||||
NodesInputMasks,
|
||||
UserContext,
|
||||
)
|
||||
from backend.data.graph import GraphModel, Node
|
||||
@@ -246,7 +239,7 @@ def _tokenise(path: str) -> list[tuple[str, str]] | None:
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def parse_execution_output(output: BlockOutputEntry, name: str) -> JsonValue | None:
|
||||
def parse_execution_output(output: BlockData, name: str) -> Any | None:
|
||||
"""
|
||||
Retrieve a nested value out of `output` using the flattened *name*.
|
||||
|
||||
@@ -270,7 +263,7 @@ def parse_execution_output(output: BlockOutputEntry, name: str) -> JsonValue | N
|
||||
if tokens is None:
|
||||
return None
|
||||
|
||||
cur: JsonValue = data
|
||||
cur: Any = data
|
||||
for delim, ident in tokens:
|
||||
if delim == LIST_SPLIT:
|
||||
# list[index]
|
||||
@@ -435,7 +428,7 @@ def validate_exec(
|
||||
async def _validate_node_input_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Checks all credentials for all nodes of the graph and returns structured errors.
|
||||
@@ -515,8 +508,8 @@ async def _validate_node_input_credentials(
|
||||
|
||||
def make_node_credentials_input_map(
|
||||
graph: GraphModel,
|
||||
graph_credentials_input: Mapping[str, CredentialsMetaInput],
|
||||
) -> NodesInputMasks:
|
||||
graph_credentials_input: dict[str, CredentialsMetaInput],
|
||||
) -> dict[str, dict[str, JsonValue]]:
|
||||
"""
|
||||
Maps credentials for an execution to the correct nodes.
|
||||
|
||||
@@ -551,8 +544,8 @@ def make_node_credentials_input_map(
|
||||
async def validate_graph_with_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> Mapping[str, Mapping[str, str]]:
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Validate graph including credentials and return structured errors per node.
|
||||
|
||||
@@ -582,7 +575,7 @@ async def _construct_starting_node_execution_input(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> list[tuple[str, BlockInput]]:
|
||||
"""
|
||||
Validates and prepares the input data for executing a graph.
|
||||
@@ -623,7 +616,7 @@ async def _construct_starting_node_execution_input(
|
||||
|
||||
# Extract request input data, and assign it to the input pin.
|
||||
if block.block_type == BlockType.INPUT:
|
||||
input_name = cast(str | None, node.input_default.get("name"))
|
||||
input_name = node.input_default.get("name")
|
||||
if input_name and input_name in graph_inputs:
|
||||
input_data = {"value": graph_inputs[input_name]}
|
||||
|
||||
@@ -650,9 +643,9 @@ async def validate_and_construct_node_execution_input(
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], dict[str, dict[str, JsonValue]]]:
|
||||
"""
|
||||
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
|
||||
This centralizes the logic used by both scheduler validation and actual execution.
|
||||
@@ -666,9 +659,7 @@ async def validate_and_construct_node_execution_input(
|
||||
nodes_input_masks: Node inputs to use.
|
||||
|
||||
Returns:
|
||||
GraphModel: Full graph object for the given `graph_id`.
|
||||
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
|
||||
dict[str, BlockInput]: Node input masks including all passed-in credentials.
|
||||
tuple[GraphModel, list[tuple[str, BlockInput]]]: Graph model and list of tuples for node execution input.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the graph is not found.
|
||||
@@ -709,11 +700,11 @@ async def validate_and_construct_node_execution_input(
|
||||
|
||||
|
||||
def _merge_nodes_input_masks(
|
||||
overrides_map_1: NodesInputMasks,
|
||||
overrides_map_2: NodesInputMasks,
|
||||
) -> NodesInputMasks:
|
||||
overrides_map_1: dict[str, dict[str, JsonValue]],
|
||||
overrides_map_2: dict[str, dict[str, JsonValue]],
|
||||
) -> dict[str, dict[str, JsonValue]]:
|
||||
"""Perform a per-node merge of input overrides"""
|
||||
result = dict(overrides_map_1).copy()
|
||||
result = overrides_map_1.copy()
|
||||
for node_id, overrides2 in overrides_map_2.items():
|
||||
if node_id in result:
|
||||
result[node_id] = {**result[node_id], **overrides2}
|
||||
@@ -863,8 +854,8 @@ async def add_graph_execution(
|
||||
inputs: Optional[BlockInput] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Adds a graph execution to the queue and returns the execution entry.
|
||||
@@ -888,7 +879,7 @@ async def add_graph_execution(
|
||||
else:
|
||||
edb = get_database_manager_async_client()
|
||||
|
||||
graph, starting_nodes_input, compiled_nodes_input_masks = (
|
||||
graph, starting_nodes_input, nodes_input_masks = (
|
||||
await validate_and_construct_node_execution_input(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
@@ -901,43 +892,37 @@ async def add_graph_execution(
|
||||
graph_exec = None
|
||||
|
||||
try:
|
||||
# Sanity check: running add_graph_execution with the properties of
|
||||
# the graph_exec created here should create the same execution again.
|
||||
graph_exec = await edb.create_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
inputs=inputs or {},
|
||||
credential_inputs=graph_credentials_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
starting_nodes_input=starting_nodes_input,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry(
|
||||
user_context=await get_user_context(user_id),
|
||||
compiled_nodes_input_masks=compiled_nodes_input_masks,
|
||||
)
|
||||
# Fetch user context for the graph execution
|
||||
user_context = await get_user_context(user_id)
|
||||
|
||||
queue = await get_async_execution_queue()
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry(user_context)
|
||||
if nodes_input_masks:
|
||||
graph_exec_entry.nodes_input_masks = nodes_input_masks
|
||||
|
||||
logger.info(
|
||||
f"Created graph execution #{graph_exec.id} for graph "
|
||||
f"#{graph_id} with {len(starting_nodes_input)} starting nodes. "
|
||||
f"Now publishing to execution queue."
|
||||
)
|
||||
|
||||
exec_queue = await get_async_execution_queue()
|
||||
await exec_queue.publish_message(
|
||||
await queue.publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
message=graph_exec_entry.model_dump_json(),
|
||||
exchange=GRAPH_EXECUTION_EXCHANGE,
|
||||
)
|
||||
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
|
||||
|
||||
graph_exec.status = ExecutionStatus.QUEUED
|
||||
await edb.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=graph_exec.status,
|
||||
)
|
||||
await get_async_execution_event_bus().publish(graph_exec)
|
||||
bus = get_async_execution_event_bus()
|
||||
await bus.publish(graph_exec)
|
||||
|
||||
return graph_exec
|
||||
except BaseException as e:
|
||||
@@ -967,7 +952,7 @@ async def add_graph_execution(
|
||||
class ExecutionOutputEntry(BaseModel):
|
||||
node: Node
|
||||
node_exec_id: str
|
||||
data: BlockOutputEntry
|
||||
data: BlockData
|
||||
|
||||
|
||||
class NodeExecutionProgress:
|
||||
|
||||
@@ -7,9 +7,10 @@ from backend.data.graph import set_node_webhook
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
|
||||
from . import get_webhook_manager, supports_webhooks
|
||||
from .utils import setup_webhook_for_block
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import BaseGraph, GraphModel, NodeModel
|
||||
from backend.data.graph import BaseGraph, GraphModel, Node, NodeModel
|
||||
from backend.data.model import Credentials
|
||||
|
||||
from ._base import BaseWebhooksManager
|
||||
@@ -42,19 +43,32 @@ async def _on_graph_activate(graph: "BaseGraph", user_id: str) -> "BaseGraph": .
|
||||
|
||||
async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str):
|
||||
get_credentials = credentials_manager.cached_getter(user_id)
|
||||
updated_nodes = []
|
||||
for new_node in graph.nodes:
|
||||
block_input_schema = cast(BlockSchema, new_node.block.input_schema)
|
||||
|
||||
for creds_field_name in block_input_schema.get_credentials_fields().keys():
|
||||
# Prevent saving graph with non-existent credentials
|
||||
if (
|
||||
creds_meta := new_node.input_default.get(creds_field_name)
|
||||
) and not await get_credentials(creds_meta["id"]):
|
||||
raise ValueError(
|
||||
f"Node #{new_node.id} input '{creds_field_name}' updated with "
|
||||
f"non-existent credentials #{creds_meta['id']}"
|
||||
node_credentials = None
|
||||
if (
|
||||
# Webhook-triggered blocks are only allowed to have 1 credentials input
|
||||
(
|
||||
creds_field_name := next(
|
||||
iter(block_input_schema.get_credentials_fields()), None
|
||||
)
|
||||
)
|
||||
and (creds_meta := new_node.input_default.get(creds_field_name))
|
||||
and not (node_credentials := await get_credentials(creds_meta["id"]))
|
||||
):
|
||||
raise ValueError(
|
||||
f"Node #{new_node.id} input '{creds_field_name}' updated with "
|
||||
f"non-existent credentials #{creds_meta['id']}"
|
||||
)
|
||||
|
||||
updated_node = await on_node_activate(
|
||||
user_id, graph.id, new_node, credentials=node_credentials
|
||||
)
|
||||
updated_nodes.append(updated_node)
|
||||
|
||||
graph.nodes = updated_nodes
|
||||
return graph
|
||||
|
||||
|
||||
@@ -71,14 +85,20 @@ async def on_graph_deactivate(graph: "GraphModel", user_id: str):
|
||||
block_input_schema = cast(BlockSchema, node.block.input_schema)
|
||||
|
||||
node_credentials = None
|
||||
for creds_field_name in block_input_schema.get_credentials_fields().keys():
|
||||
if (creds_meta := node.input_default.get(creds_field_name)) and not (
|
||||
node_credentials := await get_credentials(creds_meta["id"])
|
||||
):
|
||||
logger.warning(
|
||||
f"Node #{node.id} input '{creds_field_name}' referenced "
|
||||
f"non-existent credentials #{creds_meta['id']}"
|
||||
if (
|
||||
# Webhook-triggered blocks are only allowed to have 1 credentials input
|
||||
(
|
||||
creds_field_name := next(
|
||||
iter(block_input_schema.get_credentials_fields()), None
|
||||
)
|
||||
)
|
||||
and (creds_meta := node.input_default.get(creds_field_name))
|
||||
and not (node_credentials := await get_credentials(creds_meta["id"]))
|
||||
):
|
||||
logger.error(
|
||||
f"Node #{node.id} input '{creds_field_name}' referenced non-existent "
|
||||
f"credentials #{creds_meta['id']}"
|
||||
)
|
||||
|
||||
updated_node = await on_node_deactivate(
|
||||
user_id, node, credentials=node_credentials
|
||||
@@ -89,6 +109,32 @@ async def on_graph_deactivate(graph: "GraphModel", user_id: str):
|
||||
return graph
|
||||
|
||||
|
||||
async def on_node_activate(
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
node: "Node",
|
||||
*,
|
||||
credentials: Optional["Credentials"] = None,
|
||||
) -> "Node":
|
||||
"""Hook to be called when the node is activated/created"""
|
||||
|
||||
if node.block.webhook_config:
|
||||
new_webhook, feedback = await setup_webhook_for_block(
|
||||
user_id=user_id,
|
||||
trigger_block=node.block,
|
||||
trigger_config=node.input_default,
|
||||
for_graph_id=graph_id,
|
||||
)
|
||||
if new_webhook:
|
||||
node = await set_node_webhook(node.id, new_webhook.id)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Node #{node.id} does not have everything for a webhook: {feedback}"
|
||||
)
|
||||
|
||||
return node
|
||||
|
||||
|
||||
async def on_node_deactivate(
|
||||
user_id: str,
|
||||
node: "NodeModel",
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional, cast
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Config
|
||||
|
||||
from . import get_webhook_manager, supports_webhooks
|
||||
@@ -12,7 +13,6 @@ if TYPE_CHECKING:
|
||||
from backend.data.block import Block, BlockSchema
|
||||
from backend.data.integrations import Webhook
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
app_config = Config()
|
||||
@@ -20,7 +20,7 @@ credentials_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
# TODO: add test to assert this matches the actual API route
|
||||
def webhook_ingress_url(provider_name: "ProviderName", webhook_id: str) -> str:
|
||||
def webhook_ingress_url(provider_name: ProviderName, webhook_id: str) -> str:
|
||||
return (
|
||||
f"{app_config.platform_base_url}/api/integrations/{provider_name.value}"
|
||||
f"/webhooks/{webhook_id}/ingress"
|
||||
@@ -144,62 +144,3 @@ async def setup_webhook_for_block(
|
||||
)
|
||||
logger.debug(f"Acquired webhook: {webhook}")
|
||||
return webhook, None
|
||||
|
||||
|
||||
async def migrate_legacy_triggered_graphs():
|
||||
from prisma.models import AgentGraph
|
||||
|
||||
from backend.data.graph import AGENT_GRAPH_INCLUDE, GraphModel, set_node_webhook
|
||||
from backend.data.model import is_credentials_field_name
|
||||
from backend.server.v2.library.db import create_preset
|
||||
from backend.server.v2.library.model import LibraryAgentPresetCreatable
|
||||
|
||||
triggered_graphs = [
|
||||
GraphModel.from_db(_graph)
|
||||
for _graph in await AgentGraph.prisma().find_many(
|
||||
where={
|
||||
"isActive": True,
|
||||
"Nodes": {"some": {"NOT": [{"webhookId": None}]}},
|
||||
},
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
)
|
||||
]
|
||||
|
||||
n_migrated_webhooks = 0
|
||||
|
||||
for graph in triggered_graphs:
|
||||
if not ((trigger_node := graph.webhook_input_node) and trigger_node.webhook_id):
|
||||
continue
|
||||
|
||||
# Use trigger node's inputs for the preset
|
||||
preset_credentials = {
|
||||
field_name: creds_meta
|
||||
for field_name, creds_meta in trigger_node.input_default.items()
|
||||
if is_credentials_field_name(field_name)
|
||||
}
|
||||
preset_inputs = {
|
||||
field_name: value
|
||||
for field_name, value in trigger_node.input_default.items()
|
||||
if not is_credentials_field_name(field_name)
|
||||
}
|
||||
|
||||
# Create a triggered preset for the graph
|
||||
await create_preset(
|
||||
graph.user_id,
|
||||
LibraryAgentPresetCreatable(
|
||||
graph_id=graph.id,
|
||||
graph_version=graph.version,
|
||||
inputs=preset_inputs,
|
||||
credentials=preset_credentials,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
webhook_id=trigger_node.webhook_id,
|
||||
is_active=True,
|
||||
),
|
||||
)
|
||||
# Detach webhook from the graph node
|
||||
await set_node_webhook(trigger_node.id, None)
|
||||
|
||||
n_migrated_webhooks += 1
|
||||
|
||||
logger.info(f"Migrated {n_migrated_webhooks} node triggers to triggered presets")
|
||||
|
||||
@@ -1,287 +0,0 @@
|
||||
"""
|
||||
Prometheus instrumentation for FastAPI services.
|
||||
|
||||
This module provides centralized metrics collection and instrumentation
|
||||
for all FastAPI services in the AutoGPT platform.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI
|
||||
from prometheus_client import Counter, Gauge, Histogram, Info
|
||||
from prometheus_fastapi_instrumentator import Instrumentator, metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Custom business metrics with controlled cardinality
|
||||
GRAPH_EXECUTIONS = Counter(
|
||||
"autogpt_graph_executions_total",
|
||||
"Total number of graph executions",
|
||||
labelnames=[
|
||||
"status"
|
||||
], # Removed graph_id and user_id to prevent cardinality explosion
|
||||
)
|
||||
|
||||
GRAPH_EXECUTIONS_BY_USER = Counter(
|
||||
"autogpt_graph_executions_by_user_total",
|
||||
"Total number of graph executions by user (sampled)",
|
||||
labelnames=["status"], # Only status, user_id tracked separately when needed
|
||||
)
|
||||
|
||||
BLOCK_EXECUTIONS = Counter(
|
||||
"autogpt_block_executions_total",
|
||||
"Total number of block executions",
|
||||
labelnames=["block_type", "status"], # block_type is bounded
|
||||
)
|
||||
|
||||
BLOCK_DURATION = Histogram(
|
||||
"autogpt_block_duration_seconds",
|
||||
"Duration of block executions in seconds",
|
||||
labelnames=["block_type"],
|
||||
buckets=[0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60],
|
||||
)
|
||||
|
||||
WEBSOCKET_CONNECTIONS = Gauge(
|
||||
"autogpt_websocket_connections_total",
|
||||
"Total number of active WebSocket connections",
|
||||
# Removed user_id label - track total only to prevent cardinality explosion
|
||||
)
|
||||
|
||||
SCHEDULER_JOBS = Gauge(
|
||||
"autogpt_scheduler_jobs",
|
||||
"Current number of scheduled jobs",
|
||||
labelnames=["job_type", "status"],
|
||||
)
|
||||
|
||||
DATABASE_QUERIES = Histogram(
|
||||
"autogpt_database_query_duration_seconds",
|
||||
"Duration of database queries in seconds",
|
||||
labelnames=["operation", "table"],
|
||||
buckets=[0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5],
|
||||
)
|
||||
|
||||
RABBITMQ_MESSAGES = Counter(
|
||||
"autogpt_rabbitmq_messages_total",
|
||||
"Total number of RabbitMQ messages",
|
||||
labelnames=["queue", "status"],
|
||||
)
|
||||
|
||||
AUTHENTICATION_ATTEMPTS = Counter(
|
||||
"autogpt_auth_attempts_total",
|
||||
"Total number of authentication attempts",
|
||||
labelnames=["method", "status"],
|
||||
)
|
||||
|
||||
API_KEY_USAGE = Counter(
|
||||
"autogpt_api_key_usage_total",
|
||||
"API key usage by provider",
|
||||
labelnames=["provider", "block_type", "status"],
|
||||
)
|
||||
|
||||
# Function/operation level metrics with controlled cardinality
|
||||
GRAPH_OPERATIONS = Counter(
|
||||
"autogpt_graph_operations_total",
|
||||
"Graph operations by type",
|
||||
labelnames=["operation", "status"], # create, update, delete, execute, etc.
|
||||
)
|
||||
|
||||
USER_OPERATIONS = Counter(
|
||||
"autogpt_user_operations_total",
|
||||
"User operations by type",
|
||||
labelnames=["operation", "status"], # login, register, update_profile, etc.
|
||||
)
|
||||
|
||||
RATE_LIMIT_HITS = Counter(
|
||||
"autogpt_rate_limit_hits_total",
|
||||
"Number of rate limit hits",
|
||||
labelnames=["endpoint"], # Removed user_id to prevent cardinality explosion
|
||||
)
|
||||
|
||||
SERVICE_INFO = Info(
|
||||
"autogpt_service",
|
||||
"Service information",
|
||||
)
|
||||
|
||||
|
||||
def instrument_fastapi(
|
||||
app: FastAPI,
|
||||
service_name: str,
|
||||
expose_endpoint: bool = True,
|
||||
endpoint: str = "/metrics",
|
||||
include_in_schema: bool = False,
|
||||
excluded_handlers: Optional[list] = None,
|
||||
) -> Instrumentator:
|
||||
"""
|
||||
Instrument a FastAPI application with Prometheus metrics.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance
|
||||
service_name: Name of the service for metrics labeling
|
||||
expose_endpoint: Whether to expose /metrics endpoint
|
||||
endpoint: Path for metrics endpoint
|
||||
include_in_schema: Whether to include metrics endpoint in OpenAPI schema
|
||||
excluded_handlers: List of paths to exclude from metrics
|
||||
|
||||
Returns:
|
||||
Configured Instrumentator instance
|
||||
"""
|
||||
|
||||
# Set service info
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
service_version = version("autogpt-platform-backend")
|
||||
except Exception:
|
||||
service_version = "unknown"
|
||||
|
||||
SERVICE_INFO.info(
|
||||
{
|
||||
"service": service_name,
|
||||
"version": service_version,
|
||||
}
|
||||
)
|
||||
|
||||
# Create instrumentator with default metrics
|
||||
instrumentator = Instrumentator(
|
||||
should_group_status_codes=True,
|
||||
should_ignore_untemplated=True,
|
||||
should_respect_env_var=True,
|
||||
should_instrument_requests_inprogress=True,
|
||||
excluded_handlers=excluded_handlers or ["/health", "/readiness"],
|
||||
env_var_name="ENABLE_METRICS",
|
||||
inprogress_name="autogpt_http_requests_inprogress",
|
||||
inprogress_labels=True,
|
||||
)
|
||||
|
||||
# Add default HTTP metrics
|
||||
instrumentator.add(
|
||||
metrics.default(
|
||||
metric_namespace="autogpt",
|
||||
metric_subsystem=service_name.replace("-", "_"),
|
||||
)
|
||||
)
|
||||
|
||||
# Add request size metrics
|
||||
instrumentator.add(
|
||||
metrics.request_size(
|
||||
metric_namespace="autogpt",
|
||||
metric_subsystem=service_name.replace("-", "_"),
|
||||
)
|
||||
)
|
||||
|
||||
# Add response size metrics
|
||||
instrumentator.add(
|
||||
metrics.response_size(
|
||||
metric_namespace="autogpt",
|
||||
metric_subsystem=service_name.replace("-", "_"),
|
||||
)
|
||||
)
|
||||
|
||||
# Add latency metrics with custom buckets for better granularity
|
||||
instrumentator.add(
|
||||
metrics.latency(
|
||||
metric_namespace="autogpt",
|
||||
metric_subsystem=service_name.replace("-", "_"),
|
||||
buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60],
|
||||
)
|
||||
)
|
||||
|
||||
# Add combined metrics (requests by method and status)
|
||||
instrumentator.add(
|
||||
metrics.combined_size(
|
||||
metric_namespace="autogpt",
|
||||
metric_subsystem=service_name.replace("-", "_"),
|
||||
)
|
||||
)
|
||||
|
||||
# Instrument the app
|
||||
instrumentator.instrument(app)
|
||||
|
||||
# Expose metrics endpoint if requested
|
||||
if expose_endpoint:
|
||||
instrumentator.expose(
|
||||
app,
|
||||
endpoint=endpoint,
|
||||
include_in_schema=include_in_schema,
|
||||
tags=["monitoring"] if include_in_schema else None,
|
||||
)
|
||||
logger.info(f"Metrics endpoint exposed at {endpoint} for {service_name}")
|
||||
|
||||
return instrumentator
|
||||
|
||||
|
||||
def record_graph_execution(graph_id: str, status: str, user_id: str):
|
||||
"""Record a graph execution event.
|
||||
|
||||
Args:
|
||||
graph_id: Graph identifier (kept for future sampling/debugging)
|
||||
status: Execution status (success/error/validation_error)
|
||||
user_id: User identifier (kept for future sampling/debugging)
|
||||
"""
|
||||
# Track overall executions without high-cardinality labels
|
||||
GRAPH_EXECUTIONS.labels(status=status).inc()
|
||||
|
||||
# Optionally track per-user executions (implement sampling if needed)
|
||||
# For now, just track status to avoid cardinality explosion
|
||||
GRAPH_EXECUTIONS_BY_USER.labels(status=status).inc()
|
||||
|
||||
|
||||
def record_block_execution(block_type: str, status: str, duration: float):
|
||||
"""Record a block execution event with duration."""
|
||||
BLOCK_EXECUTIONS.labels(block_type=block_type, status=status).inc()
|
||||
BLOCK_DURATION.labels(block_type=block_type).observe(duration)
|
||||
|
||||
|
||||
def update_websocket_connections(user_id: str, delta: int):
|
||||
"""Update the number of active WebSocket connections.
|
||||
|
||||
Args:
|
||||
user_id: User identifier (kept for future sampling/debugging)
|
||||
delta: Change in connection count (+1 for connect, -1 for disconnect)
|
||||
"""
|
||||
# Track total connections without user_id to prevent cardinality explosion
|
||||
if delta > 0:
|
||||
WEBSOCKET_CONNECTIONS.inc(delta)
|
||||
else:
|
||||
WEBSOCKET_CONNECTIONS.dec(abs(delta))
|
||||
|
||||
|
||||
def record_database_query(operation: str, table: str, duration: float):
|
||||
"""Record a database query with duration."""
|
||||
DATABASE_QUERIES.labels(operation=operation, table=table).observe(duration)
|
||||
|
||||
|
||||
def record_rabbitmq_message(queue: str, status: str):
|
||||
"""Record a RabbitMQ message event."""
|
||||
RABBITMQ_MESSAGES.labels(queue=queue, status=status).inc()
|
||||
|
||||
|
||||
def record_authentication_attempt(method: str, status: str):
|
||||
"""Record an authentication attempt."""
|
||||
AUTHENTICATION_ATTEMPTS.labels(method=method, status=status).inc()
|
||||
|
||||
|
||||
def record_api_key_usage(provider: str, block_type: str, status: str):
|
||||
"""Record API key usage by provider and block."""
|
||||
API_KEY_USAGE.labels(provider=provider, block_type=block_type, status=status).inc()
|
||||
|
||||
|
||||
def record_rate_limit_hit(endpoint: str, user_id: str):
|
||||
"""Record a rate limit hit.
|
||||
|
||||
Args:
|
||||
endpoint: API endpoint that was rate limited
|
||||
user_id: User identifier (kept for future sampling/debugging)
|
||||
"""
|
||||
RATE_LIMIT_HITS.labels(endpoint=endpoint).inc()
|
||||
|
||||
|
||||
def record_graph_operation(operation: str, status: str):
|
||||
"""Record a graph operation (create, update, delete, execute, etc.)."""
|
||||
GRAPH_OPERATIONS.labels(operation=operation, status=status).inc()
|
||||
|
||||
|
||||
def record_user_operation(operation: str, status: str):
|
||||
"""Record a user operation (login, register, etc.)."""
|
||||
USER_OPERATIONS.labels(operation=operation, status=status).inc()
|
||||
@@ -63,7 +63,7 @@ except ImportError:
|
||||
|
||||
# Cost System
|
||||
try:
|
||||
from backend.data.block import BlockCost, BlockCostType
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
except ImportError:
|
||||
from backend.data.block_cost_config import BlockCost, BlockCostType
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Callable, List, Optional, Type
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import BlockCost, BlockCostType
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
|
||||
@@ -8,8 +8,9 @@ BLOCK_COSTS configuration used by the execution system.
|
||||
import logging
|
||||
from typing import List, Type
|
||||
|
||||
from backend.data.block import Block, BlockCost
|
||||
from backend.data.block import Block
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any, Callable, List, Optional, Set, Type
|
||||
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.data.block import BlockCost
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
|
||||
@@ -6,10 +6,10 @@ import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.blocks.basic import Block
|
||||
from backend.data.model import Credentials
|
||||
from backend.data.model import APIKeyCredentials, Credentials
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
@@ -17,8 +17,6 @@ from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
if TYPE_CHECKING:
|
||||
from backend.sdk.provider import Provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SDKOAuthCredentials(BaseModel):
|
||||
"""OAuth credentials configuration for SDK providers."""
|
||||
@@ -104,8 +102,21 @@ class AutoRegistry:
|
||||
"""Register an environment variable as an API key for a provider."""
|
||||
with cls._lock:
|
||||
cls._api_key_mappings[provider] = env_var_name
|
||||
# Note: The credential itself is created by ProviderBuilder.with_api_key()
|
||||
# We only store the mapping here to avoid duplication
|
||||
|
||||
# Dynamically check if the env var exists and create credential
|
||||
import os
|
||||
|
||||
api_key = os.getenv(env_var_name)
|
||||
if api_key:
|
||||
credential = APIKeyCredentials(
|
||||
id=f"{provider}-default",
|
||||
provider=provider,
|
||||
api_key=SecretStr(api_key),
|
||||
title=f"Default {provider} credentials",
|
||||
)
|
||||
# Check if credential already exists to avoid duplicates
|
||||
if not any(c.id == credential.id for c in cls._default_credentials):
|
||||
cls._default_credentials.append(credential)
|
||||
|
||||
@classmethod
|
||||
def get_all_credentials(cls) -> List[Credentials]:
|
||||
@@ -199,43 +210,3 @@ class AutoRegistry:
|
||||
webhooks.load_webhook_managers = patched_load
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to patch webhook managers: {e}")
|
||||
|
||||
# Patch credentials store to include SDK-registered credentials
|
||||
try:
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
# Get the module from sys.modules to respect mocking
|
||||
if "backend.integrations.credentials_store" in sys.modules:
|
||||
creds_store: Any = sys.modules["backend.integrations.credentials_store"]
|
||||
else:
|
||||
import backend.integrations.credentials_store
|
||||
|
||||
creds_store: Any = backend.integrations.credentials_store
|
||||
|
||||
if hasattr(creds_store, "IntegrationCredentialsStore"):
|
||||
store_class = creds_store.IntegrationCredentialsStore
|
||||
if hasattr(store_class, "get_all_creds"):
|
||||
original_get_all_creds = store_class.get_all_creds
|
||||
|
||||
async def patched_get_all_creds(self, user_id: str):
|
||||
# Get original credentials
|
||||
original_creds = await original_get_all_creds(self, user_id)
|
||||
|
||||
# Add SDK-registered credentials
|
||||
sdk_creds = cls.get_all_credentials()
|
||||
|
||||
# Combine credentials, avoiding duplicates by ID
|
||||
existing_ids = {c.id for c in original_creds}
|
||||
for cred in sdk_creds:
|
||||
if cred.id not in existing_ids:
|
||||
original_creds.append(cred)
|
||||
|
||||
return original_creds
|
||||
|
||||
store_class.get_all_creds = patched_get_all_creds
|
||||
logger.info(
|
||||
"Successfully patched IntegrationCredentialsStore.get_all_creds"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to patch credentials store: {e}")
|
||||
|
||||
@@ -88,7 +88,6 @@ async def test_send_graph_execution_result(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
preset_id=None,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
started_at=datetime.now(tz=timezone.utc),
|
||||
ended_at=datetime.now(tz=timezone.utc),
|
||||
@@ -102,8 +101,6 @@ async def test_send_graph_execution_result(
|
||||
"input_1": "some input value :)",
|
||||
"input_2": "some *other* input value",
|
||||
},
|
||||
credential_inputs=None,
|
||||
nodes_input_masks=None,
|
||||
outputs={
|
||||
"the_output": ["some output value"],
|
||||
"other_output": ["sike there was another output"],
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
|
||||
from .routes.v1 import v1_router
|
||||
@@ -14,12 +13,3 @@ external_app = FastAPI(
|
||||
|
||||
external_app.add_middleware(SecurityHeadersMiddleware)
|
||||
external_app.include_router(v1_router, prefix="/v1")
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
external_app,
|
||||
service_name="external-api",
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=True,
|
||||
)
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from fastapi.security import APIKeyHeader
|
||||
from prisma.enums import APIKeyPermission
|
||||
|
||||
from backend.data.api_key import APIKeyInfo, has_permission, validate_api_key
|
||||
from backend.data.api_key import has_permission, validate_api_key
|
||||
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
api_key_header = APIKeyHeader(name="X-API-Key")
|
||||
|
||||
|
||||
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
|
||||
async def require_api_key(request: Request):
|
||||
"""Base middleware for API key authentication"""
|
||||
api_key = await api_key_header(request)
|
||||
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=401, detail="Missing API key")
|
||||
|
||||
@@ -17,19 +19,18 @@ async def require_api_key(api_key: str | None = Security(api_key_header)) -> API
|
||||
if not api_key_obj:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
request.state.api_key = api_key_obj
|
||||
return api_key_obj
|
||||
|
||||
|
||||
def require_permission(permission: APIKeyPermission):
|
||||
"""Dependency function for checking specific permissions"""
|
||||
|
||||
async def check_permission(
|
||||
api_key: APIKeyInfo = Security(require_api_key),
|
||||
) -> APIKeyInfo:
|
||||
async def check_permission(api_key=Depends(require_api_key)):
|
||||
if not has_permission(api_key, permission):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"API key lacks the required permission '{permission}'",
|
||||
detail=f"API key missing required permission: {permission}",
|
||||
)
|
||||
return api_key
|
||||
|
||||
|
||||
@@ -2,14 +2,14 @@ import logging
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, Any, Optional, Sequence
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import backend.data.block
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.data.api_key import APIKey
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.server.external.middleware import require_permission
|
||||
@@ -47,7 +47,7 @@ class GraphExecutionResult(TypedDict):
|
||||
@v1_router.get(
|
||||
path="/blocks",
|
||||
tags=["blocks"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
|
||||
dependencies=[Depends(require_permission(APIKeyPermission.READ_BLOCK))],
|
||||
)
|
||||
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
blocks = [block() for block in backend.data.block.get_blocks().values()]
|
||||
@@ -57,12 +57,12 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
@v1_router.post(
|
||||
path="/blocks/{block_id}/execute",
|
||||
tags=["blocks"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.EXECUTE_BLOCK))],
|
||||
dependencies=[Depends(require_permission(APIKeyPermission.EXECUTE_BLOCK))],
|
||||
)
|
||||
async def execute_graph_block(
|
||||
block_id: str,
|
||||
data: BlockInput,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_BLOCK)),
|
||||
api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_BLOCK)),
|
||||
) -> CompletedBlockOutput:
|
||||
obj = backend.data.block.get_block(block_id)
|
||||
if not obj:
|
||||
@@ -82,7 +82,7 @@ async def execute_graph(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
|
||||
api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
graph_exec = await add_graph_execution(
|
||||
@@ -104,7 +104,7 @@ async def execute_graph(
|
||||
async def get_graph_execution_results(
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.READ_GRAPH)),
|
||||
api_key: APIKey = Depends(require_permission(APIKeyPermission.READ_GRAPH)),
|
||||
) -> GraphExecutionResult:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=api_key.user_id)
|
||||
if not graph:
|
||||
|
||||
@@ -81,10 +81,6 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
|
||||
# Add noindex header for shared execution pages
|
||||
if "/public/shared" in request.url.path:
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow"
|
||||
|
||||
# Default: Disable caching for all endpoints
|
||||
# Only allow caching for explicitly permitted paths
|
||||
if not self.is_cacheable_path(request.url.path):
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Optional
|
||||
|
||||
import pydantic
|
||||
|
||||
from backend.data.api_key import APIKeyInfo, APIKeyPermission
|
||||
from backend.data.api_key import APIKeyPermission, APIKeyWithoutHash
|
||||
from backend.data.graph import Graph
|
||||
from backend.util.timezone_name import TimeZoneName
|
||||
|
||||
@@ -34,6 +34,10 @@ class WSSubscribeGraphExecutionsRequest(pydantic.BaseModel):
|
||||
graph_id: str
|
||||
|
||||
|
||||
class ExecuteGraphResponse(pydantic.BaseModel):
|
||||
graph_exec_id: str
|
||||
|
||||
|
||||
class CreateGraph(pydantic.BaseModel):
|
||||
graph: Graph
|
||||
|
||||
@@ -45,7 +49,7 @@ class CreateAPIKeyRequest(pydantic.BaseModel):
|
||||
|
||||
|
||||
class CreateAPIKeyResponse(pydantic.BaseModel):
|
||||
api_key: APIKeyInfo
|
||||
api_key: APIKeyWithoutHash
|
||||
plain_text_key: str
|
||||
|
||||
|
||||
|
||||
@@ -12,13 +12,11 @@ from autogpt_libs.auth import add_auth_responses_to_openapi
|
||||
from autogpt_libs.auth import verify_settings as verify_auth_settings
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.routing import APIRoute
|
||||
from prisma.errors import PrismaError
|
||||
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.server.routers.postmark.postmark
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.admin.credit_admin_routes
|
||||
@@ -37,12 +35,10 @@ import backend.util.settings
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.server.external.api import external_app
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.util import json
|
||||
from backend.util.cloud_storage import shutdown_cloud_storage_handler
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
||||
from backend.util.service import UnhealthyServiceError
|
||||
|
||||
@@ -80,8 +76,6 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
|
||||
@@ -143,16 +137,6 @@ app.add_middleware(SecurityHeadersMiddleware)
|
||||
# Add 401 responses to authenticated endpoints in OpenAPI spec
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
app,
|
||||
service_name="rest-api",
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=settings.config.app_env
|
||||
== backend.util.settings.AppEnvironment.LOCAL,
|
||||
)
|
||||
|
||||
|
||||
def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
|
||||
def handler(request: fastapi.Request, exc: Exception):
|
||||
@@ -211,14 +195,10 @@ async def validation_error_handler(
|
||||
)
|
||||
|
||||
|
||||
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
|
||||
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
|
||||
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
|
||||
app.add_exception_handler(RequestValidationError, validation_error_handler)
|
||||
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
|
||||
app.add_exception_handler(ValueError, handle_internal_http_error(400))
|
||||
app.add_exception_handler(Exception, handle_internal_http_error(500))
|
||||
|
||||
app.include_router(backend.server.routers.v1.v1_router, tags=["v1"], prefix="/api")
|
||||
app.include_router(
|
||||
backend.server.v2.store.routes.router, tags=["v2"], prefix="/api/store"
|
||||
@@ -383,7 +363,6 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
preset_id=preset_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs or {},
|
||||
credential_inputs={},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Sequence
|
||||
|
||||
import pydantic
|
||||
@@ -29,9 +27,20 @@ from typing_extensions import Optional, TypedDict
|
||||
import backend.server.integrations.router
|
||||
import backend.server.routers.analytics
|
||||
import backend.server.v2.library.db as library_db
|
||||
from backend.data import api_key as api_key_db
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.api_key import (
|
||||
APIKeyError,
|
||||
APIKeyNotFoundError,
|
||||
APIKeyPermissionError,
|
||||
APIKeyWithoutHash,
|
||||
generate_api_key,
|
||||
get_api_key_by_id,
|
||||
list_user_api_keys,
|
||||
revoke_api_key,
|
||||
suspend_api_key,
|
||||
update_api_key_permissions,
|
||||
)
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
|
||||
from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
@@ -65,15 +74,11 @@ from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||
on_graph_activate,
|
||||
on_graph_deactivate,
|
||||
)
|
||||
from backend.monitoring.instrumentation import (
|
||||
record_block_execution,
|
||||
record_graph_execution,
|
||||
record_graph_operation,
|
||||
)
|
||||
from backend.server.model import (
|
||||
CreateAPIKeyRequest,
|
||||
CreateAPIKeyResponse,
|
||||
CreateGraph,
|
||||
ExecuteGraphResponse,
|
||||
RequestTopUp,
|
||||
SetGraphActiveVersion,
|
||||
TimezoneResponse,
|
||||
@@ -86,6 +91,7 @@ from backend.util.cloud_storage import get_cloud_storage_handler
|
||||
from backend.util.exceptions import GraphValidationError, NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.timezone_utils import (
|
||||
convert_cron_to_utc,
|
||||
convert_utc_time_to_user_timezone,
|
||||
get_user_timezone_or_utc,
|
||||
)
|
||||
@@ -103,7 +109,6 @@ def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_user_credit_model = get_user_credit_model()
|
||||
|
||||
# Define the API routes
|
||||
@@ -172,6 +177,7 @@ async def get_user_timezone_route(
|
||||
summary="Update user timezone",
|
||||
tags=["auth"],
|
||||
dependencies=[Security(requires_user)],
|
||||
response_model=TimezoneResponse,
|
||||
)
|
||||
async def update_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
|
||||
@@ -287,26 +293,10 @@ async def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlock
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
output = defaultdict(list)
|
||||
async for name, data in obj.execute(data):
|
||||
output[name].append(data)
|
||||
|
||||
# Record successful block execution with duration
|
||||
duration = time.time() - start_time
|
||||
block_type = obj.__class__.__name__
|
||||
record_block_execution(
|
||||
block_type=block_type, status="success", duration=duration
|
||||
)
|
||||
|
||||
return output
|
||||
except Exception:
|
||||
# Record failed block execution
|
||||
duration = time.time() - start_time
|
||||
block_type = obj.__class__.__name__
|
||||
record_block_execution(block_type=block_type, status="error", duration=duration)
|
||||
raise
|
||||
output = defaultdict(list)
|
||||
async for name, data in obj.execute(data):
|
||||
output[name].append(data)
|
||||
return output
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -793,7 +783,7 @@ async def execute_graph(
|
||||
],
|
||||
graph_version: Optional[int] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
) -> execution_db.GraphExecutionMeta:
|
||||
) -> ExecuteGraphResponse:
|
||||
current_balance = await _user_credit_model.get_credits(user_id)
|
||||
if current_balance <= 0:
|
||||
raise HTTPException(
|
||||
@@ -802,7 +792,7 @@ async def execute_graph(
|
||||
)
|
||||
|
||||
try:
|
||||
result = await execution_utils.add_graph_execution(
|
||||
graph_exec = await execution_utils.add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
@@ -810,16 +800,8 @@ async def execute_graph(
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=credentials_inputs,
|
||||
)
|
||||
# Record successful graph execution
|
||||
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
|
||||
record_graph_operation(operation="execute", status="success")
|
||||
return result
|
||||
return ExecuteGraphResponse(graph_exec_id=graph_exec.id)
|
||||
except GraphValidationError as e:
|
||||
# Record failed graph execution
|
||||
record_graph_execution(
|
||||
graph_id=graph_id, status="validation_error", user_id=user_id
|
||||
)
|
||||
record_graph_operation(operation="execute", status="validation_error")
|
||||
# Return structured validation errors that the frontend can parse
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -830,11 +812,6 @@ async def execute_graph(
|
||||
"node_errors": e.node_errors,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
# Record any other failures
|
||||
record_graph_execution(graph_id=graph_id, status="error", user_id=user_id)
|
||||
record_graph_operation(operation="execute", status="error")
|
||||
raise
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -959,99 +936,6 @@ async def delete_graph_execution(
|
||||
)
|
||||
|
||||
|
||||
class ShareRequest(pydantic.BaseModel):
|
||||
"""Optional request body for share endpoint."""
|
||||
|
||||
pass # Empty body is fine
|
||||
|
||||
|
||||
class ShareResponse(pydantic.BaseModel):
|
||||
"""Response from share endpoints."""
|
||||
|
||||
share_url: str
|
||||
share_token: str
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/graphs/{graph_id}/executions/{graph_exec_id}/share",
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def enable_execution_sharing(
|
||||
graph_id: Annotated[str, Path],
|
||||
graph_exec_id: Annotated[str, Path],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
_body: ShareRequest = Body(default=ShareRequest()),
|
||||
) -> ShareResponse:
|
||||
"""Enable sharing for a graph execution."""
|
||||
# Verify the execution belongs to the user
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not execution:
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
# Generate a unique share token
|
||||
share_token = str(uuid.uuid4())
|
||||
|
||||
# Update the execution with share info
|
||||
await execution_db.update_graph_execution_share_status(
|
||||
execution_id=graph_exec_id,
|
||||
user_id=user_id,
|
||||
is_shared=True,
|
||||
share_token=share_token,
|
||||
shared_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Return the share URL
|
||||
frontend_url = Settings().config.frontend_base_url or "http://localhost:3000"
|
||||
share_url = f"{frontend_url}/share/{share_token}"
|
||||
|
||||
return ShareResponse(share_url=share_url, share_token=share_token)
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
"/graphs/{graph_id}/executions/{graph_exec_id}/share",
|
||||
status_code=HTTP_204_NO_CONTENT,
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def disable_execution_sharing(
|
||||
graph_id: Annotated[str, Path],
|
||||
graph_exec_id: Annotated[str, Path],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> None:
|
||||
"""Disable sharing for a graph execution."""
|
||||
# Verify the execution belongs to the user
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not execution:
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
# Remove share info
|
||||
await execution_db.update_graph_execution_share_status(
|
||||
execution_id=graph_exec_id,
|
||||
user_id=user_id,
|
||||
is_shared=False,
|
||||
share_token=None,
|
||||
shared_at=None,
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get("/public/shared/{share_token}")
|
||||
async def get_shared_execution(
|
||||
share_token: Annotated[
|
||||
str,
|
||||
Path(regex=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
|
||||
],
|
||||
) -> execution_db.SharedExecutionResponse:
|
||||
"""Get a shared graph execution by share token (no auth required)."""
|
||||
execution = await execution_db.get_graph_execution_by_share_token(share_token)
|
||||
if not execution:
|
||||
raise HTTPException(status_code=404, detail="Shared execution not found")
|
||||
|
||||
return execution
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Schedules ########################
|
||||
########################################################
|
||||
@@ -1063,10 +947,6 @@ class ScheduleCreationRequest(pydantic.BaseModel):
|
||||
cron: str
|
||||
inputs: dict[str, Any]
|
||||
credentials: dict[str, CredentialsMetaInput] = pydantic.Field(default_factory=dict)
|
||||
timezone: Optional[str] = pydantic.Field(
|
||||
default=None,
|
||||
description="User's timezone for scheduling (e.g., 'America/New_York'). If not provided, will use user's saved timezone or UTC.",
|
||||
)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -1091,22 +971,26 @@ async def create_graph_execution_schedule(
|
||||
detail=f"Graph #{graph_id} v{schedule_params.graph_version} not found.",
|
||||
)
|
||||
|
||||
# Use timezone from request if provided, otherwise fetch from user profile
|
||||
if schedule_params.timezone:
|
||||
user_timezone = schedule_params.timezone
|
||||
else:
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
|
||||
|
||||
# Convert cron expression from user timezone to UTC
|
||||
try:
|
||||
utc_cron = convert_cron_to_utc(schedule_params.cron, user_timezone)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid cron expression for timezone {user_timezone}: {e}",
|
||||
)
|
||||
|
||||
result = await get_scheduler_client().add_execution_schedule(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
name=schedule_params.name,
|
||||
cron=schedule_params.cron,
|
||||
cron=utc_cron, # Send UTC cron to scheduler
|
||||
input_data=schedule_params.inputs,
|
||||
input_credentials=schedule_params.credentials,
|
||||
user_timezone=user_timezone,
|
||||
)
|
||||
|
||||
# Convert the next_run_time back to user timezone for display
|
||||
@@ -1128,11 +1012,24 @@ async def list_graph_execution_schedules(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
graph_id: str = Path(),
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
return await get_scheduler_client().get_execution_schedules(
|
||||
schedules = await get_scheduler_client().get_execution_schedules(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
)
|
||||
|
||||
# Get user timezone for conversion
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
|
||||
|
||||
# Convert next_run_time to user timezone for display
|
||||
for schedule in schedules:
|
||||
if schedule.next_run_time:
|
||||
schedule.next_run_time = convert_utc_time_to_user_timezone(
|
||||
schedule.next_run_time, user_timezone
|
||||
)
|
||||
|
||||
return schedules
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/schedules",
|
||||
@@ -1143,7 +1040,20 @@ async def list_graph_execution_schedules(
|
||||
async def list_all_graphs_execution_schedules(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
return await get_scheduler_client().get_execution_schedules(user_id=user_id)
|
||||
schedules = await get_scheduler_client().get_execution_schedules(user_id=user_id)
|
||||
|
||||
# Get user timezone for conversion
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
|
||||
|
||||
# Convert UTC next_run_time to user timezone for display
|
||||
for schedule in schedules:
|
||||
if schedule.next_run_time:
|
||||
schedule.next_run_time = convert_utc_time_to_user_timezone(
|
||||
schedule.next_run_time, user_timezone
|
||||
)
|
||||
|
||||
return schedules
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
@@ -1174,6 +1084,7 @@ async def delete_graph_execution_schedule(
|
||||
@v1_router.post(
|
||||
"/api-keys",
|
||||
summary="Create new API key",
|
||||
response_model=CreateAPIKeyResponse,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
@@ -1181,73 +1092,128 @@ async def create_api_key(
|
||||
request: CreateAPIKeyRequest, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> CreateAPIKeyResponse:
|
||||
"""Create a new API key"""
|
||||
api_key_info, plain_text_key = await api_key_db.create_api_key(
|
||||
name=request.name,
|
||||
user_id=user_id,
|
||||
permissions=request.permissions,
|
||||
description=request.description,
|
||||
)
|
||||
return CreateAPIKeyResponse(api_key=api_key_info, plain_text_key=plain_text_key)
|
||||
try:
|
||||
api_key, plain_text = await generate_api_key(
|
||||
name=request.name,
|
||||
user_id=user_id,
|
||||
permissions=request.permissions,
|
||||
description=request.description,
|
||||
)
|
||||
return CreateAPIKeyResponse(api_key=api_key, plain_text_key=plain_text)
|
||||
except APIKeyError as e:
|
||||
logger.error(
|
||||
"Could not create API key for user %s: %s. Review input and permissions.",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": str(e), "hint": "Verify request payload and try again."},
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/api-keys",
|
||||
summary="List user API keys",
|
||||
response_model=list[APIKeyWithoutHash] | dict[str, str],
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_api_keys(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[api_key_db.APIKeyInfo]:
|
||||
) -> list[APIKeyWithoutHash]:
|
||||
"""List all API keys for the user"""
|
||||
return await api_key_db.list_user_api_keys(user_id)
|
||||
try:
|
||||
return await list_user_api_keys(user_id)
|
||||
except APIKeyError as e:
|
||||
logger.error("Failed to list API keys for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": str(e), "hint": "Check API key service availability."},
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/api-keys/{key_id}",
|
||||
summary="Get specific API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_api_key(
|
||||
key_id: str, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> api_key_db.APIKeyInfo:
|
||||
) -> APIKeyWithoutHash:
|
||||
"""Get a specific API key"""
|
||||
api_key = await api_key_db.get_api_key_by_id(key_id, user_id)
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
return api_key
|
||||
try:
|
||||
api_key = await get_api_key_by_id(key_id, user_id)
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
return api_key
|
||||
except APIKeyError as e:
|
||||
logger.error("Error retrieving API key %s for user %s: %s", key_id, user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": str(e), "hint": "Ensure the key ID is correct."},
|
||||
)
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
"/api-keys/{key_id}",
|
||||
summary="Revoke API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def delete_api_key(
|
||||
key_id: str, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> api_key_db.APIKeyInfo:
|
||||
) -> Optional[APIKeyWithoutHash]:
|
||||
"""Revoke an API key"""
|
||||
return await api_key_db.revoke_api_key(key_id, user_id)
|
||||
try:
|
||||
return await revoke_api_key(key_id, user_id)
|
||||
except APIKeyNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
except APIKeyPermissionError:
|
||||
raise HTTPException(status_code=403, detail="Permission denied")
|
||||
except APIKeyError as e:
|
||||
logger.error("Failed to revoke API key %s for user %s: %s", key_id, user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"message": str(e),
|
||||
"hint": "Verify permissions or try again later.",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/api-keys/{key_id}/suspend",
|
||||
summary="Suspend API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def suspend_key(
|
||||
key_id: str, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> api_key_db.APIKeyInfo:
|
||||
) -> Optional[APIKeyWithoutHash]:
|
||||
"""Suspend an API key"""
|
||||
return await api_key_db.suspend_api_key(key_id, user_id)
|
||||
try:
|
||||
return await suspend_api_key(key_id, user_id)
|
||||
except APIKeyNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
except APIKeyPermissionError:
|
||||
raise HTTPException(status_code=403, detail="Permission denied")
|
||||
except APIKeyError as e:
|
||||
logger.error("Failed to suspend API key %s for user %s: %s", key_id, user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": str(e), "hint": "Check user permissions and retry."},
|
||||
)
|
||||
|
||||
|
||||
@v1_router.put(
|
||||
"/api-keys/{key_id}/permissions",
|
||||
summary="Update key permissions",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
@@ -1255,8 +1221,22 @@ async def update_permissions(
|
||||
key_id: str,
|
||||
request: UpdatePermissionsRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> api_key_db.APIKeyInfo:
|
||||
) -> Optional[APIKeyWithoutHash]:
|
||||
"""Update API key permissions"""
|
||||
return await api_key_db.update_api_key_permissions(
|
||||
key_id, user_id, request.permissions
|
||||
)
|
||||
try:
|
||||
return await update_api_key_permissions(key_id, user_id, request.permissions)
|
||||
except APIKeyNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
except APIKeyPermissionError:
|
||||
raise HTTPException(status_code=403, detail="Permission denied")
|
||||
except APIKeyError as e:
|
||||
logger.error(
|
||||
"Failed to update permissions for API key %s of user %s: %s",
|
||||
key_id,
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": str(e), "hint": "Ensure permissions list is valid."},
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
@@ -266,7 +265,6 @@ def test_get_graphs(
|
||||
name="Test Graph",
|
||||
description="A test graph",
|
||||
user_id=test_user_id,
|
||||
created_at=datetime(2025, 9, 4, 13, 37),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
@@ -301,7 +299,6 @@ def test_get_graph(
|
||||
name="Test Graph",
|
||||
description="A test graph",
|
||||
user_id=test_user_id,
|
||||
created_at=datetime(2025, 9, 4, 13, 37),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
@@ -351,7 +348,6 @@ def test_delete_graph(
|
||||
name="Test Graph",
|
||||
description="A test graph",
|
||||
user_id=test_user_id,
|
||||
created_at=datetime(2025, 9, 4, 13, 37),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
|
||||
@@ -3,9 +3,8 @@ API Key authentication utilities for FastAPI applications.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import secrets
|
||||
from typing import Any, Awaitable, Callable, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.security import APIKeyHeader
|
||||
@@ -13,8 +12,6 @@ from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APIKeyAuthenticator(APIKeyHeader):
|
||||
"""
|
||||
@@ -54,8 +51,7 @@ class APIKeyAuthenticator(APIKeyHeader):
|
||||
header_name (str): The name of the header containing the API key
|
||||
expected_token (Optional[str]): The expected API key value for simple token matching
|
||||
validator (Optional[Callable]): Custom validation function that takes an API key
|
||||
string and returns a truthy value if and only if the passed string is a
|
||||
valid API key. Can be async.
|
||||
string and returns a boolean or object. Can be async.
|
||||
status_if_missing (int): HTTP status code to use for validation errors
|
||||
message_if_invalid (str): Error message to return when validation fails
|
||||
"""
|
||||
@@ -64,9 +60,7 @@ class APIKeyAuthenticator(APIKeyHeader):
|
||||
self,
|
||||
header_name: str,
|
||||
expected_token: Optional[str] = None,
|
||||
validator: Optional[
|
||||
Callable[[str], Any] | Callable[[str], Awaitable[Any]]
|
||||
] = None,
|
||||
validator: Optional[Callable[[str], bool]] = None,
|
||||
status_if_missing: int = HTTP_401_UNAUTHORIZED,
|
||||
message_if_invalid: str = "Invalid API key",
|
||||
):
|
||||
@@ -81,7 +75,7 @@ class APIKeyAuthenticator(APIKeyHeader):
|
||||
self.message_if_invalid = message_if_invalid
|
||||
|
||||
async def __call__(self, request: Request) -> Any:
|
||||
api_key = await super().__call__(request)
|
||||
api_key = await super()(request)
|
||||
if api_key is None:
|
||||
raise HTTPException(
|
||||
status_code=self.status_if_missing, detail="No API key in request"
|
||||
@@ -112,9 +106,4 @@ class APIKeyAuthenticator(APIKeyHeader):
|
||||
f"{self.__class__.__name__}.expected_token is not set; "
|
||||
"either specify it or provide a custom validator"
|
||||
)
|
||||
try:
|
||||
return secrets.compare_digest(api_key, self.expected_token)
|
||||
except TypeError as e:
|
||||
# If value is not an ASCII string, compare_digest raises a TypeError
|
||||
logger.warning(f"{self.model.name} API key check failed: {e}")
|
||||
return False
|
||||
return secrets.compare_digest(api_key, self.expected_token)
|
||||
|
||||
@@ -1,537 +0,0 @@
|
||||
"""
|
||||
Unit tests for APIKeyAuthenticator class.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, Request
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
|
||||
|
||||
from backend.server.utils.api_key_auth import APIKeyAuthenticator
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Create a mock request object."""
|
||||
request = Mock(spec=Request)
|
||||
request.state = Mock()
|
||||
request.headers = {}
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_auth():
|
||||
"""Create a basic APIKeyAuthenticator instance."""
|
||||
return APIKeyAuthenticator(
|
||||
header_name="X-API-Key", expected_token="test-secret-token"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_auth_custom_validator():
|
||||
"""Create APIKeyAuthenticator with custom validator."""
|
||||
|
||||
def custom_validator(api_key: str) -> bool:
|
||||
return api_key == "custom-valid-key"
|
||||
|
||||
return APIKeyAuthenticator(header_name="X-API-Key", validator=custom_validator)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_auth_async_validator():
|
||||
"""Create APIKeyAuthenticator with async custom validator."""
|
||||
|
||||
async def async_validator(api_key: str) -> bool:
|
||||
return api_key == "async-valid-key"
|
||||
|
||||
return APIKeyAuthenticator(header_name="X-API-Key", validator=async_validator)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_auth_object_validator():
|
||||
"""Create APIKeyAuthenticator that returns objects from validator."""
|
||||
|
||||
async def object_validator(api_key: str):
|
||||
if api_key == "user-key":
|
||||
return {"user_id": "123", "permissions": ["read", "write"]}
|
||||
return None
|
||||
|
||||
return APIKeyAuthenticator(header_name="X-API-Key", validator=object_validator)
|
||||
|
||||
|
||||
# ========== Basic Initialization Tests ========== #
|
||||
|
||||
|
||||
def test_init_with_expected_token():
|
||||
"""Test initialization with expected token."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="test-token")
|
||||
|
||||
assert auth.model.name == "X-API-Key"
|
||||
assert auth.expected_token == "test-token"
|
||||
assert auth.custom_validator is None
|
||||
assert auth.status_if_missing == HTTP_401_UNAUTHORIZED
|
||||
assert auth.message_if_invalid == "Invalid API key"
|
||||
|
||||
|
||||
def test_init_with_custom_validator():
|
||||
"""Test initialization with custom validator."""
|
||||
|
||||
def validator(key: str) -> bool:
|
||||
return True
|
||||
|
||||
auth = APIKeyAuthenticator(header_name="Authorization", validator=validator)
|
||||
|
||||
assert auth.model.name == "Authorization"
|
||||
assert auth.expected_token is None
|
||||
assert auth.custom_validator == validator
|
||||
assert auth.status_if_missing == HTTP_401_UNAUTHORIZED
|
||||
assert auth.message_if_invalid == "Invalid API key"
|
||||
|
||||
|
||||
def test_init_with_custom_parameters():
|
||||
"""Test initialization with custom status and message."""
|
||||
auth = APIKeyAuthenticator(
|
||||
header_name="X-Custom-Key",
|
||||
expected_token="token",
|
||||
status_if_missing=HTTP_403_FORBIDDEN,
|
||||
message_if_invalid="Access denied",
|
||||
)
|
||||
|
||||
assert auth.model.name == "X-Custom-Key"
|
||||
assert auth.status_if_missing == HTTP_403_FORBIDDEN
|
||||
assert auth.message_if_invalid == "Access denied"
|
||||
|
||||
|
||||
def test_scheme_name_generation():
|
||||
"""Test that scheme_name is generated correctly."""
|
||||
auth = APIKeyAuthenticator(header_name="X-Custom-Header", expected_token="token")
|
||||
|
||||
assert auth.scheme_name == "APIKeyAuthenticator-X-Custom-Header"
|
||||
|
||||
|
||||
# ========== Authentication Flow Tests ========== #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_missing(api_key_auth, mock_request):
|
||||
"""Test behavior when API key is missing from request."""
|
||||
# Mock the parent class method to return None (no API key)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await api_key_auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "No API key in request"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_valid(api_key_auth, mock_request):
|
||||
"""Test behavior with valid API key."""
|
||||
# Mock the parent class to return the API key
|
||||
with patch.object(
|
||||
api_key_auth.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="test-secret-token",
|
||||
):
|
||||
result = await api_key_auth(mock_request)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_invalid(api_key_auth, mock_request):
|
||||
"""Test behavior with invalid API key."""
|
||||
# Mock the parent class to return an invalid API key
|
||||
with patch.object(
|
||||
api_key_auth.__class__.__bases__[0], "__call__", return_value="invalid-token"
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await api_key_auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
# ========== Custom Validator Tests ========== #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_status_and_message(mock_request):
|
||||
"""Test custom status code and message."""
|
||||
auth = APIKeyAuthenticator(
|
||||
header_name="X-API-Key",
|
||||
expected_token="valid-token",
|
||||
status_if_missing=HTTP_403_FORBIDDEN,
|
||||
message_if_invalid="Access forbidden",
|
||||
)
|
||||
|
||||
# Test missing key
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_403_FORBIDDEN
|
||||
assert exc_info.value.detail == "No API key in request"
|
||||
|
||||
# Test invalid key
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value="invalid-token"
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_403_FORBIDDEN
|
||||
assert exc_info.value.detail == "Access forbidden"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_sync_validator(api_key_auth_custom_validator, mock_request):
|
||||
"""Test with custom synchronous validator."""
|
||||
# Mock the parent class to return the API key
|
||||
with patch.object(
|
||||
api_key_auth_custom_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="custom-valid-key",
|
||||
):
|
||||
result = await api_key_auth_custom_validator(mock_request)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_sync_validator_invalid(
|
||||
api_key_auth_custom_validator, mock_request
|
||||
):
|
||||
"""Test custom synchronous validator with invalid key."""
|
||||
with patch.object(
|
||||
api_key_auth_custom_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="invalid-key",
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await api_key_auth_custom_validator(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_async_validator(api_key_auth_async_validator, mock_request):
|
||||
"""Test with custom async validator."""
|
||||
with patch.object(
|
||||
api_key_auth_async_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="async-valid-key",
|
||||
):
|
||||
result = await api_key_auth_async_validator(mock_request)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_async_validator_invalid(
|
||||
api_key_auth_async_validator, mock_request
|
||||
):
|
||||
"""Test custom async validator with invalid key."""
|
||||
with patch.object(
|
||||
api_key_auth_async_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="invalid-key",
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await api_key_auth_async_validator(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validator_returns_object(api_key_auth_object_validator, mock_request):
|
||||
"""Test validator that returns an object instead of boolean."""
|
||||
with patch.object(
|
||||
api_key_auth_object_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="user-key",
|
||||
):
|
||||
result = await api_key_auth_object_validator(mock_request)
|
||||
|
||||
expected_result = {"user_id": "123", "permissions": ["read", "write"]}
|
||||
assert result == expected_result
|
||||
# Verify the object is stored in request state
|
||||
assert mock_request.state.api_key == expected_result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validator_returns_none(api_key_auth_object_validator, mock_request):
|
||||
"""Test validator that returns None (falsy)."""
|
||||
with patch.object(
|
||||
api_key_auth_object_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="invalid-key",
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await api_key_auth_object_validator(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validator_database_lookup_simulation(mock_request):
|
||||
"""Test simulation of database lookup validator."""
|
||||
# Simulate database records
|
||||
valid_api_keys = {
|
||||
"key123": {"user_id": "user1", "active": True},
|
||||
"key456": {"user_id": "user2", "active": False},
|
||||
}
|
||||
|
||||
async def db_validator(api_key: str):
|
||||
record = valid_api_keys.get(api_key)
|
||||
return record if record and record["active"] else None
|
||||
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", validator=db_validator)
|
||||
|
||||
# Test valid active key
|
||||
with patch.object(auth.__class__.__bases__[0], "__call__", return_value="key123"):
|
||||
result = await auth(mock_request)
|
||||
assert result == {"user_id": "user1", "active": True}
|
||||
assert mock_request.state.api_key == {"user_id": "user1", "active": True}
|
||||
|
||||
# Test inactive key
|
||||
mock_request.state = Mock() # Reset state
|
||||
with patch.object(auth.__class__.__bases__[0], "__call__", return_value="key456"):
|
||||
with pytest.raises(HTTPException):
|
||||
await auth(mock_request)
|
||||
|
||||
# Test non-existent key
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value="nonexistent"
|
||||
):
|
||||
with pytest.raises(HTTPException):
|
||||
await auth(mock_request)
|
||||
|
||||
|
||||
# ========== Default Validator Tests ========== #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_validator_key_valid(api_key_auth):
|
||||
"""Test default validator with valid token."""
|
||||
result = await api_key_auth.default_validator("test-secret-token")
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_validator_key_invalid(api_key_auth):
|
||||
"""Test default validator with invalid token."""
|
||||
result = await api_key_auth.default_validator("wrong-token")
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_validator_missing_expected_token():
|
||||
"""Test default validator when expected_token is not set."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key")
|
||||
|
||||
with pytest.raises(MissingConfigError) as exc_info:
|
||||
await auth.default_validator("any-token")
|
||||
|
||||
assert "expected_token is not set" in str(exc_info.value)
|
||||
assert "either specify it or provide a custom validator" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_validator_uses_constant_time_comparison(api_key_auth):
|
||||
"""
|
||||
Test that default validator uses secrets.compare_digest for timing attack protection
|
||||
"""
|
||||
with patch("secrets.compare_digest") as mock_compare:
|
||||
mock_compare.return_value = True
|
||||
|
||||
await api_key_auth.default_validator("test-token")
|
||||
|
||||
mock_compare.assert_called_once_with("test-token", "test-secret-token")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_empty(mock_request):
|
||||
"""Test behavior with empty string API key."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
with patch.object(auth.__class__.__bases__[0], "__call__", return_value=""):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_whitespace_only(mock_request):
|
||||
"""Test behavior with whitespace-only API key."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=" \t\n "
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_very_long(mock_request):
|
||||
"""Test behavior with extremely long API key (potential DoS protection)."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
# Create a very long API key (10MB)
|
||||
long_api_key = "a" * (10 * 1024 * 1024)
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=long_api_key
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_null_bytes(mock_request):
|
||||
"""Test behavior with API key containing null bytes."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
api_key_with_null = "valid\x00token"
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=api_key_with_null
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_control_characters(mock_request):
|
||||
"""Test behavior with API key containing control characters."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
# API key with various control characters
|
||||
api_key_with_control = "valid\r\n\t\x1b[31mtoken"
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=api_key_with_control
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_unicode_characters(mock_request):
|
||||
"""Test behavior with Unicode characters in API key."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
# API key with Unicode characters
|
||||
unicode_api_key = "validтокен🔑"
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=unicode_api_key
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_unicode_characters_normalization_attack(mock_request):
|
||||
"""Test that Unicode normalization doesn't bypass validation."""
|
||||
# Create auth with composed Unicode character
|
||||
auth = APIKeyAuthenticator(
|
||||
header_name="X-API-Key", expected_token="café" # é is composed
|
||||
)
|
||||
|
||||
# Try with decomposed version (c + a + f + e + ´)
|
||||
decomposed_key = "cafe\u0301" # é as combining character
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=decomposed_key
|
||||
):
|
||||
# Should fail because secrets.compare_digest doesn't normalize
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_binary_data(mock_request):
|
||||
"""Test behavior with binary data in API key."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
# Binary data that might cause encoding issues
|
||||
binary_api_key = bytes([0xFF, 0xFE, 0xFD, 0xFC, 0x80, 0x81]).decode(
|
||||
"latin1", errors="ignore"
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=binary_api_key
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_regex_dos_attack_pattern(mock_request):
|
||||
"""Test behavior with API key of repeated characters (pattern attack)."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
# Pattern that might cause regex DoS in poorly implemented validators
|
||||
repeated_key = "a" * 1000 + "b" * 1000 + "c" * 1000
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=repeated_key
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_keys_with_newline_variations(mock_request):
|
||||
"""Test different newline characters in API key."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
newline_variations = [
|
||||
"valid\ntoken", # Unix newline
|
||||
"valid\r\ntoken", # Windows newline
|
||||
"valid\rtoken", # Mac newline
|
||||
"valid\x85token", # NEL (Next Line)
|
||||
"valid\x0Btoken", # Vertical Tab
|
||||
"valid\x0Ctoken", # Form Feed
|
||||
]
|
||||
|
||||
for api_key in newline_variations:
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=api_key
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
@@ -147,8 +147,10 @@ class AutoModManager:
|
||||
return None
|
||||
|
||||
# Get completed executions and collect outputs
|
||||
completed_executions = await db_client.get_node_executions(
|
||||
graph_exec_id, statuses=[ExecutionStatus.COMPLETED], include_exec_data=True
|
||||
completed_executions = await db_client.get_node_executions( # type: ignore
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[ExecutionStatus.COMPLETED],
|
||||
include_exec_data=True,
|
||||
)
|
||||
|
||||
if not completed_executions:
|
||||
@@ -218,7 +220,7 @@ class AutoModManager:
|
||||
):
|
||||
"""Update node execution statuses for frontend display when moderation fails"""
|
||||
# Import here to avoid circular imports
|
||||
from backend.executor.manager import send_async_execution_update
|
||||
from backend.util.clients import get_async_execution_event_bus
|
||||
|
||||
if moderation_type == "input":
|
||||
# For input moderation, mark queued/running/incomplete nodes as failed
|
||||
@@ -232,8 +234,10 @@ class AutoModManager:
|
||||
target_statuses = [ExecutionStatus.COMPLETED]
|
||||
|
||||
# Get the executions that need to be updated
|
||||
executions_to_update = await db_client.get_node_executions(
|
||||
graph_exec_id, statuses=target_statuses, include_exec_data=True
|
||||
executions_to_update = await db_client.get_node_executions( # type: ignore
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=target_statuses,
|
||||
include_exec_data=True,
|
||||
)
|
||||
|
||||
if not executions_to_update:
|
||||
@@ -276,10 +280,12 @@ class AutoModManager:
|
||||
updated_execs = await asyncio.gather(*exec_updates)
|
||||
|
||||
# Send all websocket updates in parallel
|
||||
event_bus = get_async_execution_event_bus()
|
||||
await asyncio.gather(
|
||||
*[
|
||||
send_async_execution_update(updated_exec)
|
||||
event_bus.publish(updated_exec)
|
||||
for updated_exec in updated_execs
|
||||
if updated_exec is not None
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -7,10 +7,12 @@ import prisma
|
||||
import backend.data.block
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.block import Block, BlockCategory, BlockInfo, BlockSchema
|
||||
from backend.data.block import Block, BlockCategory, BlockSchema
|
||||
from backend.data.credit import get_block_costs
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.v2.builder.model import (
|
||||
BlockCategoryResponse,
|
||||
BlockData,
|
||||
BlockResponse,
|
||||
BlockType,
|
||||
CountResponse,
|
||||
@@ -23,7 +25,7 @@ from backend.util.models import Pagination
|
||||
logger = logging.getLogger(__name__)
|
||||
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
|
||||
_static_counts_cache: dict | None = None
|
||||
_suggested_blocks: list[BlockInfo] | None = None
|
||||
_suggested_blocks: list[BlockData] | None = None
|
||||
|
||||
|
||||
def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse]:
|
||||
@@ -51,7 +53,7 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
|
||||
|
||||
# Append if the category has less than the specified number of blocks
|
||||
if len(categories[category].blocks) < category_blocks:
|
||||
categories[category].blocks.append(block.get_info())
|
||||
categories[category].blocks.append(block.to_dict())
|
||||
|
||||
# Sort categories by name
|
||||
return sorted(categories.values(), key=lambda x: x.name)
|
||||
@@ -107,8 +109,10 @@ def get_blocks(
|
||||
take -= 1
|
||||
blocks.append(block)
|
||||
|
||||
costs = get_block_costs()
|
||||
|
||||
return BlockResponse(
|
||||
blocks=[b.get_info() for b in blocks],
|
||||
blocks=[{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks],
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
@@ -170,9 +174,11 @@ def search_blocks(
|
||||
take -= 1
|
||||
blocks.append(block)
|
||||
|
||||
costs = get_block_costs()
|
||||
|
||||
return SearchBlocksResponse(
|
||||
blocks=BlockResponse(
|
||||
blocks=[b.get_info() for b in blocks],
|
||||
blocks=[{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks],
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
@@ -317,7 +323,7 @@ def _get_all_providers() -> dict[ProviderName, Provider]:
|
||||
return providers
|
||||
|
||||
|
||||
async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
||||
async def get_suggested_blocks(count: int = 5) -> list[BlockData]:
|
||||
global _suggested_blocks
|
||||
|
||||
if _suggested_blocks is not None and len(_suggested_blocks) >= count:
|
||||
@@ -345,7 +351,7 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
||||
|
||||
# Get the top blocks based on execution count
|
||||
# But ignore Input and Output blocks
|
||||
blocks: list[tuple[BlockInfo, int]] = []
|
||||
blocks: list[tuple[BlockData, int]] = []
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
||||
@@ -360,7 +366,7 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
||||
(row["execution_count"] for row in results if row["block_id"] == block.id),
|
||||
0,
|
||||
)
|
||||
blocks.append((block.get_info(), execution_count))
|
||||
blocks.append((block.to_dict(), execution_count))
|
||||
# Sort blocks by execution count
|
||||
blocks.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.model as store_model
|
||||
from backend.data.block import BlockInfo
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.models import Pagination
|
||||
|
||||
@@ -17,27 +16,29 @@ FilterType = Literal[
|
||||
|
||||
BlockType = Literal["all", "input", "action", "output"]
|
||||
|
||||
BlockData = dict[str, Any]
|
||||
|
||||
|
||||
# Suggestions
|
||||
class SuggestionsResponse(BaseModel):
|
||||
otto_suggestions: list[str]
|
||||
recent_searches: list[str]
|
||||
providers: list[ProviderName]
|
||||
top_blocks: list[BlockInfo]
|
||||
top_blocks: list[BlockData]
|
||||
|
||||
|
||||
# All blocks
|
||||
class BlockCategoryResponse(BaseModel):
|
||||
name: str
|
||||
total_blocks: int
|
||||
blocks: list[BlockInfo]
|
||||
blocks: list[BlockData]
|
||||
|
||||
model_config = {"use_enum_values": False} # <== use enum names like "AI"
|
||||
|
||||
|
||||
# Input/Action/Output and see all for block categories
|
||||
class BlockResponse(BaseModel):
|
||||
blocks: list[BlockInfo]
|
||||
blocks: list[BlockData]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
@@ -70,7 +71,7 @@ class SearchBlocksResponse(BaseModel):
|
||||
|
||||
|
||||
class SearchResponse(BaseModel):
|
||||
items: list[BlockInfo | library_model.LibraryAgent | store_model.StoreAgent]
|
||||
items: list[BlockData | library_model.LibraryAgent | store_model.StoreAgent]
|
||||
total_items: dict[FilterType, int]
|
||||
page: int
|
||||
more_pages: bool
|
||||
|
||||
@@ -16,7 +16,7 @@ import backend.server.v2.store.media as store_media
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.db import transaction
|
||||
from backend.data.execution import get_graph_execution
|
||||
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
||||
from backend.data.includes import library_agent_include
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||
@@ -144,92 +144,6 @@ async def list_library_agents(
|
||||
raise store_exceptions.DatabaseError("Failed to fetch library agents") from e
|
||||
|
||||
|
||||
async def list_favorite_library_agents(
|
||||
user_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> library_model.LibraryAgentResponse:
|
||||
"""
|
||||
Retrieves a paginated list of favorite LibraryAgent records for a given user.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user whose favorite LibraryAgents we want to retrieve.
|
||||
page: Current page (1-indexed).
|
||||
page_size: Number of items per page.
|
||||
|
||||
Returns:
|
||||
A LibraryAgentResponse containing the list of favorite agents and pagination details.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If there is an issue fetching from Prisma.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Fetching favorite library agents for user_id={user_id}, "
|
||||
f"page={page}, page_size={page_size}"
|
||||
)
|
||||
|
||||
if page < 1 or page_size < 1:
|
||||
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
|
||||
raise store_exceptions.DatabaseError("Invalid pagination input")
|
||||
|
||||
where_clause: prisma.types.LibraryAgentWhereInput = {
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"isFavorite": True, # Only fetch favorites
|
||||
}
|
||||
|
||||
# Sort favorites by updated date descending
|
||||
order_by: prisma.types.LibraryAgentOrderByInput = {"updatedAt": "desc"}
|
||||
|
||||
try:
|
||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
include=library_agent_include(user_id),
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
agent_count = await prisma.models.LibraryAgent.prisma().count(
|
||||
where=where_clause
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Retrieved {len(library_agents)} favorite library agents for user #{user_id}"
|
||||
)
|
||||
|
||||
# Only pass valid agents to the response
|
||||
valid_library_agents: list[library_model.LibraryAgent] = []
|
||||
|
||||
for agent in library_agents:
|
||||
try:
|
||||
library_agent = library_model.LibraryAgent.from_db(agent)
|
||||
valid_library_agents.append(library_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
logger.error(
|
||||
f"Error parsing LibraryAgent #{agent.id} from DB item: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Return the response with only valid agents
|
||||
return library_model.LibraryAgentResponse(
|
||||
agents=valid_library_agents,
|
||||
pagination=Pagination(
|
||||
total_items=agent_count,
|
||||
total_pages=(agent_count + page_size - 1) // page_size,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error fetching favorite library agents: {e}")
|
||||
raise store_exceptions.DatabaseError(
|
||||
"Failed to fetch favorite library agents"
|
||||
) from e
|
||||
|
||||
|
||||
async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Get a specific agent from the user's library.
|
||||
@@ -703,7 +617,7 @@ async def list_presets(
|
||||
where=query_filter,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
include=AGENT_PRESET_INCLUDE,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
total_items = await prisma.models.AgentPreset.prisma().count(where=query_filter)
|
||||
total_pages = (total_items + page_size - 1) // page_size
|
||||
@@ -748,7 +662,7 @@ async def get_preset(
|
||||
try:
|
||||
preset = await prisma.models.AgentPreset.prisma().find_unique(
|
||||
where={"id": preset_id},
|
||||
include=AGENT_PRESET_INCLUDE,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not preset or preset.userId != user_id or preset.isDeleted:
|
||||
return None
|
||||
@@ -795,12 +709,15 @@ async def create_preset(
|
||||
)
|
||||
for name, data in {
|
||||
**preset.inputs,
|
||||
**preset.credentials,
|
||||
**{
|
||||
key: creds_meta.model_dump(exclude_none=True)
|
||||
for key, creds_meta in preset.credentials.items()
|
||||
},
|
||||
}.items()
|
||||
]
|
||||
},
|
||||
),
|
||||
include=AGENT_PRESET_INCLUDE,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
return library_model.LibraryAgentPreset.from_db(new_preset)
|
||||
except prisma.errors.PrismaError as e:
|
||||
@@ -830,25 +747,6 @@ async def create_preset_from_graph_execution(
|
||||
if not graph_execution:
|
||||
raise NotFoundError(f"Graph execution #{graph_exec_id} not found")
|
||||
|
||||
# Sanity check: credential inputs must be available if required for this preset
|
||||
if graph_execution.credential_inputs is None:
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=graph_execution.graph_id,
|
||||
version=graph_execution.graph_version,
|
||||
user_id=graph_execution.user_id,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
if not graph:
|
||||
raise NotFoundError(
|
||||
f"Graph #{graph_execution.graph_id} not found or accessible"
|
||||
)
|
||||
elif len(graph.aggregate_credentials_inputs()) > 0:
|
||||
raise ValueError(
|
||||
f"Graph execution #{graph_exec_id} can't be turned into a preset "
|
||||
"because it was run before this feature existed "
|
||||
"and so the input credentials were not saved."
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Creating preset for user #{user_id} from graph execution #{graph_exec_id}",
|
||||
)
|
||||
@@ -856,7 +754,7 @@ async def create_preset_from_graph_execution(
|
||||
user_id=user_id,
|
||||
preset=library_model.LibraryAgentPresetCreatable(
|
||||
inputs=graph_execution.inputs,
|
||||
credentials=graph_execution.credential_inputs or {},
|
||||
credentials={}, # FIXME
|
||||
graph_id=graph_execution.graph_id,
|
||||
graph_version=graph_execution.graph_version,
|
||||
name=create_request.name,
|
||||
@@ -936,7 +834,7 @@ async def update_preset(
|
||||
updated = await prisma.models.AgentPreset.prisma(tx).update(
|
||||
where={"id": preset_id},
|
||||
data=update_data,
|
||||
include=AGENT_PRESET_INCLUDE,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not updated:
|
||||
raise RuntimeError(f"AgentPreset #{preset_id} vanished while updating")
|
||||
@@ -951,7 +849,7 @@ async def set_preset_webhook(
|
||||
) -> library_model.LibraryAgentPreset:
|
||||
current = await prisma.models.AgentPreset.prisma().find_unique(
|
||||
where={"id": preset_id},
|
||||
include=AGENT_PRESET_INCLUDE,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not current or current.userId != user_id:
|
||||
raise NotFoundError(f"Preset #{preset_id} not found")
|
||||
@@ -963,7 +861,7 @@ async def set_preset_webhook(
|
||||
if webhook_id
|
||||
else {"Webhook": {"disconnect": True}}
|
||||
),
|
||||
include=AGENT_PRESET_INCLUDE,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not updated:
|
||||
raise RuntimeError(f"AgentPreset #{preset_id} vanished while updating")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
@@ -9,11 +9,9 @@ import pydantic
|
||||
import backend.data.block as block_model
|
||||
import backend.data.graph as graph_model
|
||||
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.integrations import Webhook
|
||||
|
||||
|
||||
class LibraryAgentStatus(str, Enum):
|
||||
COMPLETED = "COMPLETED" # All runs completed
|
||||
@@ -22,6 +20,14 @@ class LibraryAgentStatus(str, Enum):
|
||||
ERROR = "ERROR" # Agent is in an error state
|
||||
|
||||
|
||||
class LibraryAgentTriggerInfo(pydantic.BaseModel):
|
||||
provider: ProviderName
|
||||
config_schema: dict[str, Any] = pydantic.Field(
|
||||
description="Input schema for the trigger block"
|
||||
)
|
||||
credentials_input_name: Optional[str]
|
||||
|
||||
|
||||
class LibraryAgent(pydantic.BaseModel):
|
||||
"""
|
||||
Represents an agent in the library, including metadata for display and
|
||||
@@ -43,7 +49,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
|
||||
name: str
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
|
||||
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
|
||||
output_schema: dict[str, Any]
|
||||
@@ -54,7 +59,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
has_external_trigger: bool = pydantic.Field(
|
||||
description="Whether the agent has an external trigger (e.g. webhook) node"
|
||||
)
|
||||
trigger_setup_info: Optional[graph_model.GraphTriggerInfo] = None
|
||||
trigger_setup_info: Optional[LibraryAgentTriggerInfo] = None
|
||||
|
||||
# Indicates whether there's a new output (based on recent runs)
|
||||
new_output: bool
|
||||
@@ -65,12 +70,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
# Indicates if this agent is the latest version
|
||||
is_latest_version: bool
|
||||
|
||||
# Whether the agent is marked as favorite by the user
|
||||
is_favorite: bool
|
||||
|
||||
# Recommended schedule cron (from marketplace agents)
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
agent: prisma.models.LibraryAgent,
|
||||
@@ -127,19 +126,39 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
updated_at=updated_at,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
instructions=graph.instructions,
|
||||
input_schema=graph.input_schema,
|
||||
output_schema=graph.output_schema,
|
||||
credentials_input_schema=(
|
||||
graph.credentials_input_schema if sub_graphs is not None else None
|
||||
),
|
||||
has_external_trigger=graph.has_external_trigger,
|
||||
trigger_setup_info=graph.trigger_setup_info,
|
||||
trigger_setup_info=(
|
||||
LibraryAgentTriggerInfo(
|
||||
provider=trigger_block.webhook_config.provider,
|
||||
config_schema={
|
||||
**(json_schema := trigger_block.input_schema.jsonschema()),
|
||||
"properties": {
|
||||
pn: sub_schema
|
||||
for pn, sub_schema in json_schema["properties"].items()
|
||||
if not is_credentials_field_name(pn)
|
||||
},
|
||||
"required": [
|
||||
pn
|
||||
for pn in json_schema.get("required", [])
|
||||
if not is_credentials_field_name(pn)
|
||||
],
|
||||
},
|
||||
credentials_input_name=next(
|
||||
iter(trigger_block.input_schema.get_credentials_fields()), None
|
||||
),
|
||||
)
|
||||
if graph.webhook_input_node
|
||||
and (trigger_block := graph.webhook_input_node.block).webhook_config
|
||||
else None
|
||||
),
|
||||
new_output=new_output,
|
||||
can_access_graph=can_access_graph,
|
||||
is_latest_version=is_latest_version,
|
||||
is_favorite=agent.isFavorite,
|
||||
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
|
||||
)
|
||||
|
||||
|
||||
@@ -263,21 +282,12 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
webhook: "Webhook | None"
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, preset: prisma.models.AgentPreset) -> "LibraryAgentPreset":
|
||||
from backend.data.integrations import Webhook
|
||||
|
||||
if preset.InputPresets is None:
|
||||
raise ValueError("InputPresets must be included in AgentPreset query")
|
||||
if preset.webhookId and preset.Webhook is None:
|
||||
raise ValueError(
|
||||
"Webhook must be included in AgentPreset query when webhookId is set"
|
||||
)
|
||||
|
||||
input_data: block_model.BlockInput = {}
|
||||
input_credentials: dict[str, CredentialsMetaInput] = {}
|
||||
@@ -293,7 +303,6 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
||||
return cls(
|
||||
id=preset.id,
|
||||
user_id=preset.userId,
|
||||
created_at=preset.createdAt,
|
||||
updated_at=preset.updatedAt,
|
||||
graph_id=preset.agentGraphId,
|
||||
graph_version=preset.agentGraphVersion,
|
||||
@@ -303,7 +312,6 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
||||
inputs=input_data,
|
||||
credentials=input_credentials,
|
||||
webhook_id=preset.webhookId,
|
||||
webhook=Webhook.from_db(preset.Webhook) if preset.Webhook else None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -79,54 +79,6 @@ async def list_library_agents(
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/favorites",
|
||||
summary="List Favorite Library Agents",
|
||||
responses={
|
||||
500: {"description": "Server error", "content": {"application/json": {}}},
|
||||
},
|
||||
)
|
||||
async def list_favorite_library_agents(
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
page: int = Query(
|
||||
1,
|
||||
ge=1,
|
||||
description="Page number to retrieve (must be >= 1)",
|
||||
),
|
||||
page_size: int = Query(
|
||||
15,
|
||||
ge=1,
|
||||
description="Number of agents per page (must be >= 1)",
|
||||
),
|
||||
) -> library_model.LibraryAgentResponse:
|
||||
"""
|
||||
Get all favorite agents in the user's library.
|
||||
|
||||
Args:
|
||||
user_id: ID of the authenticated user.
|
||||
page: Page number to retrieve.
|
||||
page_size: Number of agents per page.
|
||||
|
||||
Returns:
|
||||
A LibraryAgentResponse containing favorite agents and pagination metadata.
|
||||
|
||||
Raises:
|
||||
HTTPException: If a server/database error occurs.
|
||||
"""
|
||||
try:
|
||||
return await library_db.list_favorite_library_agents(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not list favorite library agents for user #{user_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/{library_agent_id}", summary="Get Library Agent")
|
||||
async def get_library_agent(
|
||||
library_agent_id: str,
|
||||
|
||||
@@ -6,10 +6,8 @@ from fastapi import APIRouter, Body, HTTPException, Query, Security, status
|
||||
|
||||
import backend.server.v2.library.db as db
|
||||
import backend.server.v2.library.model as models
|
||||
from backend.data.execution import GraphExecutionMeta
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.integrations import get_webhook
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
@@ -371,41 +369,48 @@ async def execute_preset(
|
||||
preset_id: str,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
inputs: dict[str, Any] = Body(..., embed=True, default_factory=dict),
|
||||
credential_inputs: dict[str, CredentialsMetaInput] = Body(
|
||||
..., embed=True, default_factory=dict
|
||||
),
|
||||
) -> GraphExecutionMeta:
|
||||
) -> dict[str, Any]: # FIXME: add proper return type
|
||||
"""
|
||||
Execute a preset given graph parameters, returning the execution ID on success.
|
||||
|
||||
Args:
|
||||
preset_id: ID of the preset to execute.
|
||||
user_id: ID of the authenticated user.
|
||||
inputs: Optionally, inputs to override the preset for execution.
|
||||
credential_inputs: Optionally, credentials to override the preset for execution.
|
||||
preset_id (str): ID of the preset to execute.
|
||||
user_id (str): ID of the authenticated user.
|
||||
inputs (dict[str, Any]): Optionally, additional input data for the graph execution.
|
||||
|
||||
Returns:
|
||||
GraphExecutionMeta: Object representing the created execution.
|
||||
{id: graph_exec_id}: A response containing the execution ID.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the preset is not found or an error occurs while executing the preset.
|
||||
"""
|
||||
preset = await db.get_preset(user_id, preset_id)
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Preset #{preset_id} not found",
|
||||
try:
|
||||
preset = await db.get_preset(user_id, preset_id)
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Preset #{preset_id} not found",
|
||||
)
|
||||
|
||||
# Merge input overrides with preset inputs
|
||||
merged_node_input = preset.inputs | inputs
|
||||
|
||||
execution = await add_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=preset.graph_id,
|
||||
graph_version=preset.graph_version,
|
||||
preset_id=preset_id,
|
||||
inputs=merged_node_input,
|
||||
)
|
||||
|
||||
# Merge input overrides with preset inputs
|
||||
merged_node_input = preset.inputs | inputs
|
||||
merged_credential_inputs = preset.credentials | credential_inputs
|
||||
logger.debug(f"Execution added: {execution} with input: {merged_node_input}")
|
||||
|
||||
return await add_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=preset.graph_id,
|
||||
graph_version=preset.graph_version,
|
||||
preset_id=preset_id,
|
||||
inputs=merged_node_input,
|
||||
graph_credentials_inputs=merged_credential_inputs,
|
||||
)
|
||||
return {"id": execution.id}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Preset execution failed for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
@@ -50,11 +50,9 @@ async def test_get_library_agents_success(
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
recommended_schedule_cron=None,
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
is_latest_version=True,
|
||||
is_favorite=False,
|
||||
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
),
|
||||
library_model.LibraryAgent(
|
||||
@@ -71,11 +69,9 @@ async def test_get_library_agents_success(
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
recommended_schedule_cron=None,
|
||||
new_output=False,
|
||||
can_access_graph=False,
|
||||
is_latest_version=True,
|
||||
is_favorite=False,
|
||||
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
),
|
||||
],
|
||||
@@ -123,76 +119,6 @@ def test_get_library_agents_error(mocker: pytest_mock.MockFixture, test_user_id:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_favorite_library_agents_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mocked_value = library_model.LibraryAgentResponse(
|
||||
agents=[
|
||||
library_model.LibraryAgent(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
name="Favorite Agent 1",
|
||||
description="Test Favorite Description 1",
|
||||
image_url=None,
|
||||
creator_name="Test Creator",
|
||||
creator_image_url="",
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
output_schema={"type": "object", "properties": {}},
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
recommended_schedule_cron=None,
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
is_latest_version=True,
|
||||
is_favorite=True,
|
||||
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
),
|
||||
],
|
||||
pagination=Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=15
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.list_favorite_library_agents"
|
||||
)
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/agents/favorites")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = library_model.LibraryAgentResponse.model_validate(response.json())
|
||||
assert len(data.agents) == 1
|
||||
assert data.agents[0].is_favorite is True
|
||||
assert data.agents[0].name == "Favorite Agent 1"
|
||||
|
||||
mock_db_call.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
page=1,
|
||||
page_size=15,
|
||||
)
|
||||
|
||||
|
||||
def test_get_favorite_library_agents_error(
|
||||
mocker: pytest_mock.MockFixture, test_user_id: str
|
||||
):
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.list_favorite_library_agents"
|
||||
)
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
response = client.get("/agents/favorites")
|
||||
assert response.status_code == 500
|
||||
mock_db_call.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
page=1,
|
||||
page_size=15,
|
||||
)
|
||||
|
||||
|
||||
def test_add_agent_to_library_success(
|
||||
mocker: pytest_mock.MockFixture, test_user_id: str
|
||||
):
|
||||
@@ -213,7 +139,6 @@ def test_add_agent_to_library_success(
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
is_latest_version=True,
|
||||
is_favorite=False,
|
||||
updated_at=FIXED_NOW,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
@@ -10,7 +9,6 @@ import prisma.types
|
||||
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.server.v2.store.model
|
||||
from backend.data.db import transaction
|
||||
from backend.data.graph import (
|
||||
GraphMeta,
|
||||
GraphModel,
|
||||
@@ -72,7 +70,7 @@ async def get_store_agents(
|
||||
)
|
||||
sanitized_query = sanitize_query(search_query)
|
||||
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
where_clause = {}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
if creators:
|
||||
@@ -96,13 +94,15 @@ async def get_store_agents(
|
||||
|
||||
try:
|
||||
agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
where=prisma.types.StoreAgentWhereInput(**where_clause),
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
|
||||
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
||||
total = await prisma.models.StoreAgent.prisma().count(
|
||||
where=prisma.types.StoreAgentWhereInput(**where_clause)
|
||||
)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
|
||||
@@ -183,36 +183,6 @@ async def get_store_agent_details(
|
||||
store_listing.hasApprovedVersion if store_listing else False
|
||||
)
|
||||
|
||||
if active_version_id:
|
||||
agent_by_active = await prisma.models.StoreAgent.prisma().find_first(
|
||||
where={"storeListingVersionId": active_version_id}
|
||||
)
|
||||
if agent_by_active:
|
||||
agent = agent_by_active
|
||||
elif store_listing:
|
||||
latest_approved = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_first(
|
||||
where={
|
||||
"storeListingId": store_listing.id,
|
||||
"submissionStatus": prisma.enums.SubmissionStatus.APPROVED,
|
||||
},
|
||||
order=[{"version": "desc"}],
|
||||
)
|
||||
)
|
||||
if latest_approved:
|
||||
agent_latest = await prisma.models.StoreAgent.prisma().find_first(
|
||||
where={"storeListingVersionId": latest_approved.id}
|
||||
)
|
||||
if agent_latest:
|
||||
agent = agent_latest
|
||||
|
||||
if store_listing and store_listing.ActiveVersion:
|
||||
recommended_schedule_cron = (
|
||||
store_listing.ActiveVersion.recommendedScheduleCron
|
||||
)
|
||||
else:
|
||||
recommended_schedule_cron = None
|
||||
|
||||
logger.debug(f"Found agent details for {username}/{agent_name}")
|
||||
return backend.server.v2.store.model.StoreAgentDetails(
|
||||
store_listing_version_id=agent.storeListingVersionId,
|
||||
@@ -220,8 +190,8 @@ async def get_store_agent_details(
|
||||
agent_name=agent.agent_name,
|
||||
agent_video=agent.agent_video or "",
|
||||
agent_image=agent.agent_image,
|
||||
creator=agent.creator_username or "",
|
||||
creator_avatar=agent.creator_avatar or "",
|
||||
creator=agent.creator_username,
|
||||
creator_avatar=agent.creator_avatar,
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
categories=agent.categories,
|
||||
@@ -231,7 +201,6 @@ async def get_store_agent_details(
|
||||
last_updated=agent.updated_at,
|
||||
active_version_id=active_version_id,
|
||||
has_approved_version=has_approved_version,
|
||||
recommended_schedule_cron=recommended_schedule_cron,
|
||||
)
|
||||
except backend.server.v2.store.exceptions.AgentNotFoundError:
|
||||
raise
|
||||
@@ -294,8 +263,8 @@ async def get_store_agent_by_version_id(
|
||||
agent_name=agent.agent_name,
|
||||
agent_video=agent.agent_video or "",
|
||||
agent_image=agent.agent_image,
|
||||
creator=agent.creator_username or "",
|
||||
creator_avatar=agent.creator_avatar or "",
|
||||
creator=agent.creator_username,
|
||||
creator_avatar=agent.creator_avatar,
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
categories=agent.categories,
|
||||
@@ -499,7 +468,6 @@ async def get_store_submissions(
|
||||
sub_heading=sub.sub_heading,
|
||||
slug=sub.slug,
|
||||
description=sub.description,
|
||||
instructions=getattr(sub, "instructions", None),
|
||||
image_urls=sub.image_urls or [],
|
||||
date_submitted=sub.date_submitted or datetime.now(tz=timezone.utc),
|
||||
status=sub.status,
|
||||
@@ -591,11 +559,9 @@ async def create_store_submission(
|
||||
video_url: str | None = None,
|
||||
image_urls: list[str] = [],
|
||||
description: str = "",
|
||||
instructions: str | None = None,
|
||||
sub_heading: str = "",
|
||||
categories: list[str] = [],
|
||||
changes_summary: str | None = "Initial Submission",
|
||||
recommended_schedule_cron: str | None = None,
|
||||
) -> backend.server.v2.store.model.StoreSubmission:
|
||||
"""
|
||||
Create the first (and only) store listing and thus submission as a normal user
|
||||
@@ -663,7 +629,6 @@ async def create_store_submission(
|
||||
video_url=video_url,
|
||||
image_urls=image_urls,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
sub_heading=sub_heading,
|
||||
categories=categories,
|
||||
changes_summary=changes_summary,
|
||||
@@ -685,13 +650,11 @@ async def create_store_submission(
|
||||
videoUrl=video_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
submittedAt=datetime.now(tz=timezone.utc),
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
)
|
||||
]
|
||||
},
|
||||
@@ -716,7 +679,6 @@ async def create_store_submission(
|
||||
slug=slug,
|
||||
sub_heading=sub_heading,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
image_urls=image_urls,
|
||||
date_submitted=listing.createdAt,
|
||||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
@@ -748,8 +710,6 @@ async def edit_store_submission(
|
||||
sub_heading: str = "",
|
||||
categories: list[str] = [],
|
||||
changes_summary: str | None = "Update submission",
|
||||
recommended_schedule_cron: str | None = None,
|
||||
instructions: str | None = None,
|
||||
) -> backend.server.v2.store.model.StoreSubmission:
|
||||
"""
|
||||
Edit an existing store listing submission.
|
||||
@@ -829,8 +789,6 @@ async def edit_store_submission(
|
||||
sub_heading=sub_heading,
|
||||
categories=categories,
|
||||
changes_summary=changes_summary,
|
||||
recommended_schedule_cron=recommended_schedule_cron,
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
# For PENDING submissions, we can update the existing version
|
||||
@@ -846,8 +804,6 @@ async def edit_store_submission(
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
instructions=instructions,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -866,7 +822,6 @@ async def edit_store_submission(
|
||||
sub_heading=sub_heading,
|
||||
slug=current_version.StoreListing.slug,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
image_urls=image_urls,
|
||||
date_submitted=updated_version.submittedAt or updated_version.createdAt,
|
||||
status=updated_version.submissionStatus,
|
||||
@@ -908,11 +863,9 @@ async def create_store_version(
|
||||
video_url: str | None = None,
|
||||
image_urls: list[str] = [],
|
||||
description: str = "",
|
||||
instructions: str | None = None,
|
||||
sub_heading: str = "",
|
||||
categories: list[str] = [],
|
||||
changes_summary: str | None = "Initial submission",
|
||||
recommended_schedule_cron: str | None = None,
|
||||
) -> backend.server.v2.store.model.StoreSubmission:
|
||||
"""
|
||||
Create a new version for an existing store listing
|
||||
@@ -977,13 +930,11 @@ async def create_store_version(
|
||||
videoUrl=video_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
submittedAt=datetime.now(),
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
storeListingId=store_listing_id,
|
||||
)
|
||||
)
|
||||
@@ -999,7 +950,6 @@ async def create_store_version(
|
||||
slug=listing.slug,
|
||||
sub_heading=sub_heading,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
image_urls=image_urls,
|
||||
date_submitted=datetime.now(),
|
||||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
@@ -1176,20 +1126,7 @@ async def get_my_agents(
|
||||
try:
|
||||
search_filter: prisma.types.LibraryAgentWhereInput = {
|
||||
"userId": user_id,
|
||||
"AgentGraph": {
|
||||
"is": {
|
||||
"StoreListings": {
|
||||
"none": {
|
||||
"isDeleted": False,
|
||||
"Versions": {
|
||||
"some": {
|
||||
"isAvailable": True,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"AgentGraph": {"is": {"StoreListings": {"none": {"isDeleted": False}}}},
|
||||
"isArchived": False,
|
||||
"isDeleted": False,
|
||||
}
|
||||
@@ -1213,7 +1150,6 @@ async def get_my_agents(
|
||||
last_edited=graph.updatedAt or graph.createdAt,
|
||||
description=graph.description or "",
|
||||
agent_image=library_agent.imageUrl,
|
||||
recommended_schedule_cron=graph.recommendedScheduleCron,
|
||||
)
|
||||
for library_agent in library_agents
|
||||
if (graph := library_agent.AgentGraph)
|
||||
@@ -1264,103 +1200,40 @@ async def get_agent(store_listing_version_id: str) -> GraphModel:
|
||||
#####################################################
|
||||
|
||||
|
||||
async def _approve_sub_agent(
|
||||
tx,
|
||||
sub_graph: prisma.models.AgentGraph,
|
||||
main_agent_name: str,
|
||||
main_agent_version: int,
|
||||
main_agent_user_id: str,
|
||||
) -> None:
|
||||
"""Approve a single sub-agent by creating/updating store listings as needed"""
|
||||
heading = f"Sub-agent of {main_agent_name} v{main_agent_version}"
|
||||
async def _get_missing_sub_store_listing(
|
||||
graph: prisma.models.AgentGraph,
|
||||
) -> list[prisma.models.AgentGraph]:
|
||||
"""
|
||||
Agent graph can have sub-graphs, and those sub-graphs also need to be store listed.
|
||||
This method fetches the sub-graphs, and returns the ones not listed in the store.
|
||||
"""
|
||||
sub_graphs = await get_sub_graphs(graph)
|
||||
if not sub_graphs:
|
||||
return []
|
||||
|
||||
# Find existing listing for this sub-agent
|
||||
listing = await prisma.models.StoreListing.prisma(tx).find_first(
|
||||
where={"agentGraphId": sub_graph.id, "isDeleted": False},
|
||||
include={"Versions": True},
|
||||
)
|
||||
|
||||
# Early return: Create new listing if none exists
|
||||
if not listing:
|
||||
await prisma.models.StoreListing.prisma(tx).create(
|
||||
data=prisma.types.StoreListingCreateInput(
|
||||
slug=f"sub-agent-{sub_graph.id[:8]}",
|
||||
agentGraphId=sub_graph.id,
|
||||
agentGraphVersion=sub_graph.version,
|
||||
owningUserId=main_agent_user_id,
|
||||
hasApprovedVersion=True,
|
||||
Versions={
|
||||
"create": [
|
||||
_create_sub_agent_version_data(
|
||||
sub_graph, heading, main_agent_name
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Find version matching this sub-graph
|
||||
matching_version = next(
|
||||
(
|
||||
v
|
||||
for v in listing.Versions or []
|
||||
if v.agentGraphId == sub_graph.id
|
||||
and v.agentGraphVersion == sub_graph.version
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# Early return: Approve existing version if found and not already approved
|
||||
if matching_version:
|
||||
if matching_version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
|
||||
return # Already approved, nothing to do
|
||||
|
||||
await prisma.models.StoreListingVersion.prisma(tx).update(
|
||||
where={"id": matching_version.id},
|
||||
data={
|
||||
# Fetch all the sub-graphs that are listed, and return the ones missing.
|
||||
store_listed_sub_graphs = {
|
||||
(listing.agentGraphId, listing.agentGraphVersion)
|
||||
for listing in await prisma.models.StoreListingVersion.prisma().find_many(
|
||||
where={
|
||||
"OR": [
|
||||
{
|
||||
"agentGraphId": sub_graph.id,
|
||||
"agentGraphVersion": sub_graph.version,
|
||||
}
|
||||
for sub_graph in sub_graphs
|
||||
],
|
||||
"submissionStatus": prisma.enums.SubmissionStatus.APPROVED,
|
||||
"reviewedAt": datetime.now(tz=timezone.utc),
|
||||
},
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": listing.id}, data={"hasApprovedVersion": True}
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
# Create new version if no matching version found
|
||||
next_version = max((v.version for v in listing.Versions or []), default=0) + 1
|
||||
await prisma.models.StoreListingVersion.prisma(tx).create(
|
||||
data={
|
||||
**_create_sub_agent_version_data(sub_graph, heading, main_agent_name),
|
||||
"version": next_version,
|
||||
"storeListingId": listing.id,
|
||||
}
|
||||
)
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": listing.id}, data={"hasApprovedVersion": True}
|
||||
)
|
||||
|
||||
|
||||
def _create_sub_agent_version_data(
|
||||
sub_graph: prisma.models.AgentGraph, heading: str, main_agent_name: str
|
||||
) -> prisma.types.StoreListingVersionCreateInput:
|
||||
"""Create store listing version data for a sub-agent"""
|
||||
return prisma.types.StoreListingVersionCreateInput(
|
||||
agentGraphId=sub_graph.id,
|
||||
agentGraphVersion=sub_graph.version,
|
||||
name=sub_graph.name or heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
|
||||
subHeading=heading,
|
||||
description=(
|
||||
f"{heading}: {sub_graph.description}" if sub_graph.description else heading
|
||||
),
|
||||
changesSummary=f"Auto-approved as sub-agent of {main_agent_name}",
|
||||
isAvailable=False,
|
||||
submittedAt=datetime.now(tz=timezone.utc),
|
||||
imageUrls=[], # Sub-agents don't need images
|
||||
categories=[], # Sub-agents don't need categories
|
||||
)
|
||||
return [
|
||||
sub_graph
|
||||
for sub_graph in sub_graphs
|
||||
if (sub_graph.id, sub_graph.version) not in store_listed_sub_graphs
|
||||
]
|
||||
|
||||
|
||||
async def review_store_submission(
|
||||
@@ -1398,46 +1271,33 @@ async def review_store_submission(
|
||||
|
||||
# If approving, update the listing to indicate it has an approved version
|
||||
if is_approved and store_listing_version.AgentGraph:
|
||||
async with transaction() as tx:
|
||||
# Handle sub-agent approvals in transaction
|
||||
await asyncio.gather(
|
||||
*[
|
||||
_approve_sub_agent(
|
||||
tx,
|
||||
sub_graph,
|
||||
store_listing_version.name,
|
||||
store_listing_version.agentGraphVersion,
|
||||
store_listing_version.StoreListing.owningUserId,
|
||||
)
|
||||
for sub_graph in await get_sub_graphs(
|
||||
store_listing_version.AgentGraph
|
||||
)
|
||||
]
|
||||
)
|
||||
heading = f"Sub-graph of {store_listing_version.name}v{store_listing_version.agentGraphVersion}"
|
||||
|
||||
# Update the AgentGraph with store listing data
|
||||
await prisma.models.AgentGraph.prisma().update(
|
||||
where={
|
||||
"graphVersionId": {
|
||||
"id": store_listing_version.agentGraphId,
|
||||
"version": store_listing_version.agentGraphVersion,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"name": store_listing_version.name,
|
||||
"description": store_listing_version.description,
|
||||
"recommendedScheduleCron": store_listing_version.recommendedScheduleCron,
|
||||
"instructions": store_listing_version.instructions,
|
||||
},
|
||||
sub_store_listing_versions = [
|
||||
prisma.types.StoreListingVersionCreateWithoutRelationsInput(
|
||||
agentGraphId=sub_graph.id,
|
||||
agentGraphVersion=sub_graph.version,
|
||||
name=sub_graph.name or heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
|
||||
subHeading=heading,
|
||||
description=f"{heading}: {sub_graph.description}",
|
||||
changesSummary=f"This listing is added as a {heading} / #{store_listing_version.agentGraphId}.",
|
||||
isAvailable=False, # Hide sub-graphs from the store by default.
|
||||
submittedAt=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
for sub_graph in await _get_missing_sub_store_listing(
|
||||
store_listing_version.AgentGraph
|
||||
)
|
||||
]
|
||||
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": store_listing_version.StoreListing.id},
|
||||
data={
|
||||
"hasApprovedVersion": True,
|
||||
"ActiveVersion": {"connect": {"id": store_listing_version_id}},
|
||||
},
|
||||
)
|
||||
await prisma.models.StoreListing.prisma().update(
|
||||
where={"id": store_listing_version.StoreListing.id},
|
||||
data={
|
||||
"hasApprovedVersion": True,
|
||||
"ActiveVersion": {"connect": {"id": store_listing_version_id}},
|
||||
"Versions": {"create": sub_store_listing_versions},
|
||||
},
|
||||
)
|
||||
|
||||
# If rejecting an approved agent, update the StoreListing accordingly
|
||||
if is_rejecting_approved:
|
||||
@@ -1593,7 +1453,6 @@ async def review_store_submission(
|
||||
else ""
|
||||
),
|
||||
description=submission.description,
|
||||
instructions=submission.instructions,
|
||||
image_urls=submission.imageUrls or [],
|
||||
date_submitted=submission.submittedAt or submission.createdAt,
|
||||
status=submission.submissionStatus,
|
||||
@@ -1729,7 +1588,6 @@ async def get_admin_listings_with_versions(
|
||||
sub_heading=version.subHeading,
|
||||
slug=listing.slug,
|
||||
description=version.description,
|
||||
instructions=version.instructions,
|
||||
image_urls=version.imageUrls or [],
|
||||
date_submitted=version.submittedAt or version.createdAt,
|
||||
status=version.submissionStatus,
|
||||
|
||||
@@ -41,7 +41,6 @@ async def test_get_store_agents(mocker):
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -83,53 +82,16 @@ async def test_get_store_agent_details(mocker):
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
)
|
||||
|
||||
# Mock active version agent (what we want to return for active version)
|
||||
mock_active_agent = prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
storeListingVersionId="active-version-id",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent Active",
|
||||
agent_video="active_video.mp4",
|
||||
agent_image=["active_image.jpg"],
|
||||
featured=False,
|
||||
creator_username="creator",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test heading active",
|
||||
description="Test description active",
|
||||
categories=["test"],
|
||||
runs=15,
|
||||
rating=4.8,
|
||||
versions=["1.0", "2.0"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=True,
|
||||
)
|
||||
|
||||
# Create a mock StoreListing result
|
||||
mock_store_listing = mocker.MagicMock()
|
||||
mock_store_listing.activeVersionId = "active-version-id"
|
||||
mock_store_listing.hasApprovedVersion = True
|
||||
mock_store_listing.ActiveVersion = mocker.MagicMock()
|
||||
mock_store_listing.ActiveVersion.recommendedScheduleCron = None
|
||||
|
||||
# Mock StoreAgent prisma call - need to handle multiple calls
|
||||
# Mock StoreAgent prisma call
|
||||
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
|
||||
|
||||
# Set up side_effect to return different results for different calls
|
||||
def mock_find_first_side_effect(*args, **kwargs):
|
||||
where_clause = kwargs.get("where", {})
|
||||
if "storeListingVersionId" in where_clause:
|
||||
# Second call for active version
|
||||
return mock_active_agent
|
||||
else:
|
||||
# First call for initial lookup
|
||||
return mock_agent
|
||||
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(
|
||||
side_effect=mock_find_first_side_effect
|
||||
)
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
# Mock Profile prisma call
|
||||
mock_profile = mocker.MagicMock()
|
||||
@@ -139,7 +101,7 @@ async def test_get_store_agent_details(mocker):
|
||||
return_value=mock_profile
|
||||
)
|
||||
|
||||
# Mock StoreListing prisma call
|
||||
# Mock StoreListing prisma call - this is what was missing
|
||||
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_store_listing
|
||||
@@ -148,25 +110,16 @@ async def test_get_store_agent_details(mocker):
|
||||
# Call function
|
||||
result = await db.get_store_agent_details("creator", "test-agent")
|
||||
|
||||
# Verify results - should use active version data
|
||||
# Verify results
|
||||
assert result.slug == "test-agent"
|
||||
assert result.agent_name == "Test Agent Active" # From active version
|
||||
assert result.agent_name == "Test Agent"
|
||||
assert result.active_version_id == "active-version-id"
|
||||
assert result.has_approved_version is True
|
||||
assert (
|
||||
result.store_listing_version_id == "active-version-id"
|
||||
) # Should be active version ID
|
||||
|
||||
# Verify mocks called correctly - now expecting 2 calls
|
||||
assert mock_store_agent.return_value.find_first.call_count == 2
|
||||
|
||||
# Check the specific calls
|
||||
calls = mock_store_agent.return_value.find_first.call_args_list
|
||||
assert calls[0] == mocker.call(
|
||||
# Verify mocks called correctly
|
||||
mock_store_agent.return_value.find_first.assert_called_once_with(
|
||||
where={"creator_username": "creator", "slug": "test-agent"}
|
||||
)
|
||||
assert calls[1] == mocker.call(where={"storeListingVersionId": "active-version-id"})
|
||||
|
||||
mock_store_listing_db.return_value.find_first.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ class MyAgent(pydantic.BaseModel):
|
||||
agent_image: str | None = None
|
||||
description: str
|
||||
last_edited: datetime.datetime
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
|
||||
class MyAgentsResponse(pydantic.BaseModel):
|
||||
@@ -49,13 +48,11 @@ class StoreAgentDetails(pydantic.BaseModel):
|
||||
creator_avatar: str
|
||||
sub_heading: str
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
categories: list[str]
|
||||
runs: int
|
||||
rating: float
|
||||
versions: list[str]
|
||||
last_updated: datetime.datetime
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
active_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
@@ -104,7 +101,6 @@ class StoreSubmission(pydantic.BaseModel):
|
||||
sub_heading: str
|
||||
slug: str
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
image_urls: list[str]
|
||||
date_submitted: datetime.datetime
|
||||
status: prisma.enums.SubmissionStatus
|
||||
@@ -159,10 +155,8 @@ class StoreSubmissionRequest(pydantic.BaseModel):
|
||||
video_url: str | None = None
|
||||
image_urls: list[str] = []
|
||||
description: str = ""
|
||||
instructions: str | None = None
|
||||
categories: list[str] = []
|
||||
changes_summary: str | None = None
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
|
||||
class StoreSubmissionEditRequest(pydantic.BaseModel):
|
||||
@@ -171,10 +165,8 @@ class StoreSubmissionEditRequest(pydantic.BaseModel):
|
||||
video_url: str | None = None
|
||||
image_urls: list[str] = []
|
||||
description: str = ""
|
||||
instructions: str | None = None
|
||||
categories: list[str] = []
|
||||
changes_summary: str | None = None
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
|
||||
class ProfileDetails(pydantic.BaseModel):
|
||||
|
||||
@@ -532,11 +532,9 @@ async def create_submission(
|
||||
video_url=submission_request.video_url,
|
||||
image_urls=submission_request.image_urls,
|
||||
description=submission_request.description,
|
||||
instructions=submission_request.instructions,
|
||||
sub_heading=submission_request.sub_heading,
|
||||
categories=submission_request.categories,
|
||||
changes_summary=submission_request.changes_summary or "Initial Submission",
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst creating store submission")
|
||||
@@ -579,11 +577,9 @@ async def edit_submission(
|
||||
video_url=submission_request.video_url,
|
||||
image_urls=submission_request.image_urls,
|
||||
description=submission_request.description,
|
||||
instructions=submission_request.instructions,
|
||||
sub_heading=submission_request.sub_heading,
|
||||
categories=submission_request.categories,
|
||||
changes_summary=submission_request.changes_summary,
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -11,10 +11,6 @@ from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from backend.data.execution import AsyncRedisExecutionEventBus
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.monitoring.instrumentation import (
|
||||
instrument_fastapi,
|
||||
update_websocket_connections,
|
||||
)
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.model import (
|
||||
WSMessage,
|
||||
@@ -42,15 +38,6 @@ docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
|
||||
app = FastAPI(lifespan=lifespan, docs_url=docs_url)
|
||||
_connection_manager = None
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
app,
|
||||
service_name="websocket-server",
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=settings.config.app_env == AppEnvironment.LOCAL,
|
||||
)
|
||||
|
||||
|
||||
def get_connection_manager():
|
||||
global _connection_manager
|
||||
@@ -229,10 +216,6 @@ async def websocket_router(
|
||||
if not user_id:
|
||||
return
|
||||
await manager.connect_socket(websocket)
|
||||
|
||||
# Track WebSocket connection
|
||||
update_websocket_connections(user_id, 1)
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
@@ -303,8 +286,6 @@ async def websocket_router(
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect_socket(websocket)
|
||||
logger.debug("WebSocket client disconnected")
|
||||
finally:
|
||||
update_websocket_connections(user_id, -1)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
|
||||
@@ -251,14 +251,14 @@ async def block_autogen_agent():
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||
input_data = {"input": "Write me a block that writes a string into a file."}
|
||||
graph_exec = await server.agent_server.test_execute_graph(
|
||||
response = await server.agent_server.test_execute_graph(
|
||||
graph_id=test_graph.id,
|
||||
user_id=test_user.id,
|
||||
node_input=input_data,
|
||||
)
|
||||
print(graph_exec)
|
||||
print(response)
|
||||
result = await wait_execution(
|
||||
graph_exec_id=graph_exec.id,
|
||||
graph_exec_id=response.graph_exec_id,
|
||||
timeout=1200,
|
||||
user_id=test_user.id,
|
||||
)
|
||||
|
||||
@@ -155,13 +155,13 @@ async def reddit_marketing_agent():
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||
input_data = {"subreddit": "AutoGPT"}
|
||||
graph_exec = await server.agent_server.test_execute_graph(
|
||||
response = await server.agent_server.test_execute_graph(
|
||||
graph_id=test_graph.id,
|
||||
user_id=test_user.id,
|
||||
node_input=input_data,
|
||||
)
|
||||
print(graph_exec)
|
||||
result = await wait_execution(test_user.id, graph_exec.id, 120)
|
||||
print(response)
|
||||
result = await wait_execution(test_user.id, response.graph_exec_id, 120)
|
||||
print(result)
|
||||
|
||||
|
||||
|
||||
@@ -88,12 +88,12 @@ async def sample_agent():
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), test_user.id)
|
||||
input_data = {"input_1": "Hello", "input_2": "World"}
|
||||
graph_exec = await server.agent_server.test_execute_graph(
|
||||
response = await server.agent_server.test_execute_graph(
|
||||
graph_id=test_graph.id,
|
||||
user_id=test_user.id,
|
||||
node_input=input_data,
|
||||
)
|
||||
await wait_execution(test_user.id, graph_exec.id, 10)
|
||||
await wait_execution(test_user.id, response.graph_exec_id, 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
from typing import Mapping
|
||||
|
||||
|
||||
class MissingConfigError(Exception):
|
||||
"""The attempted operation requires configuration which is not available"""
|
||||
|
||||
@@ -72,7 +69,7 @@ class GraphValidationError(ValueError):
|
||||
"""Structured validation error for graph validation failures"""
|
||||
|
||||
def __init__(
|
||||
self, message: str, node_errors: Mapping[str, Mapping[str, str]] | None = None
|
||||
self, message: str, node_errors: dict[str, dict[str, str]] | None = None
|
||||
):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
@@ -33,7 +33,7 @@ def sentry_init():
|
||||
)
|
||||
|
||||
|
||||
def sentry_capture_error(error: BaseException):
|
||||
def sentry_capture_error(error: Exception):
|
||||
sentry_sdk.capture_exception(error)
|
||||
sentry_sdk.flush()
|
||||
|
||||
|
||||
@@ -76,14 +76,6 @@ class AppProcess(ABC):
|
||||
logger.warning(
|
||||
f"[{self.service_name}] Termination request: {type(e).__name__}; {e} executing cleanup."
|
||||
)
|
||||
# Send error to Sentry before cleanup
|
||||
if not isinstance(e, (KeyboardInterrupt, SystemExit)):
|
||||
try:
|
||||
from backend.util.metrics import sentry_capture_error
|
||||
|
||||
sentry_capture_error(e)
|
||||
except Exception:
|
||||
pass # Silently ignore if Sentry isn't available
|
||||
finally:
|
||||
self.cleanup()
|
||||
logger.info(f"[{self.service_name}] Terminated.")
|
||||
|
||||
@@ -479,9 +479,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
)
|
||||
|
||||
openai_api_key: str = Field(default="", description="OpenAI API key")
|
||||
openai_internal_api_key: str = Field(
|
||||
default="", description="OpenAI Internal API key"
|
||||
)
|
||||
aiml_api_key: str = Field(default="", description="'AI/ML API' key")
|
||||
anthropic_api_key: str = Field(default="", description="Anthropic API key")
|
||||
groq_api_key: str = Field(default="", description="Groq API key")
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
-- These changes are part of improvements to our API key system.
|
||||
-- See https://github.com/Significant-Gravitas/AutoGPT/pull/10796 for context.
|
||||
|
||||
-- Add 'salt' column for Scrypt hashing
|
||||
ALTER TABLE "APIKey" ADD COLUMN "salt" TEXT;
|
||||
|
||||
-- Rename columns for clarity
|
||||
ALTER TABLE "APIKey" RENAME COLUMN "key" TO "hash";
|
||||
ALTER TABLE "APIKey" RENAME COLUMN "prefix" TO "head";
|
||||
ALTER TABLE "APIKey" RENAME COLUMN "postfix" TO "tail";
|
||||
@@ -1,5 +0,0 @@
|
||||
-- Add 'credentialInputs', 'inputs', and 'nodesInputMasks' columns to the AgentGraphExecution table
|
||||
ALTER TABLE "AgentGraphExecution"
|
||||
ADD COLUMN "credentialInputs" JSONB,
|
||||
ADD COLUMN "inputs" JSONB,
|
||||
ADD COLUMN "nodesInputMasks" JSONB;
|
||||
@@ -1,53 +0,0 @@
|
||||
-- Update StoreAgent view to include is_available field and fix creator field nullability
|
||||
|
||||
BEGIN;
|
||||
|
||||
-- Drop and recreate the StoreAgent view with isAvailable field
|
||||
DROP VIEW IF EXISTS "StoreAgent";
|
||||
|
||||
CREATE OR REPLACE VIEW "StoreAgent" AS
|
||||
WITH agent_versions AS (
|
||||
SELECT
|
||||
"storeListingId",
|
||||
array_agg(DISTINCT version::text ORDER BY version::text) AS versions
|
||||
FROM "StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
GROUP BY "storeListingId"
|
||||
)
|
||||
SELECT
|
||||
sl.id AS listing_id,
|
||||
slv.id AS "storeListingVersionId",
|
||||
slv."createdAt" AS updated_at,
|
||||
sl.slug,
|
||||
COALESCE(slv.name, '') AS agent_name,
|
||||
slv."videoUrl" AS agent_video,
|
||||
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
|
||||
slv."isFeatured" AS featured,
|
||||
p.username AS creator_username, -- Allow NULL for malformed sub-agents
|
||||
p."avatarUrl" AS creator_avatar, -- Allow NULL for malformed sub-agents
|
||||
slv."subHeading" AS sub_heading,
|
||||
slv.description,
|
||||
slv.categories,
|
||||
COALESCE(ar.run_count, 0::bigint) AS runs,
|
||||
COALESCE(rs.avg_rating, 0.0)::double precision AS rating,
|
||||
COALESCE(av.versions, ARRAY[slv.version::text]) AS versions,
|
||||
slv."isAvailable" AS is_available -- Add isAvailable field to filter sub-agents
|
||||
FROM "StoreListing" sl
|
||||
INNER JOIN "StoreListingVersion" slv
|
||||
ON slv."storeListingId" = sl.id
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
JOIN "AgentGraph" a
|
||||
ON slv."agentGraphId" = a.id
|
||||
AND slv."agentGraphVersion" = a.version
|
||||
LEFT JOIN "Profile" p
|
||||
ON sl."owningUserId" = p."userId"
|
||||
LEFT JOIN "mv_review_stats" rs
|
||||
ON sl.id = rs."storeListingId"
|
||||
LEFT JOIN "mv_agent_run_counts" ar
|
||||
ON a.id = ar."agentGraphId"
|
||||
LEFT JOIN agent_versions av
|
||||
ON sl.id = av."storeListingId"
|
||||
WHERE sl."isDeleted" = false
|
||||
AND sl."hasApprovedVersion" = true;
|
||||
|
||||
COMMIT;
|
||||
@@ -1,3 +0,0 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "StoreListingVersion" ADD COLUMN "recommendedScheduleCron" TEXT;
|
||||
ALTER TABLE "AgentGraph" ADD COLUMN "recommendedScheduleCron" TEXT;
|
||||
@@ -1,66 +0,0 @@
|
||||
-- Fixes the refresh function+job introduced in 20250604130249_optimise_store_agent_and_creator_views
|
||||
-- by improving the function to accept a schema parameter and updating the cron job to use it.
|
||||
-- This resolves the issue where pg_cron jobs fail because they run in 'public' schema
|
||||
-- but the materialized views exist in 'platform' schema.
|
||||
|
||||
|
||||
-- Create parameterized refresh function that accepts schema name
|
||||
CREATE OR REPLACE FUNCTION refresh_store_materialized_views()
|
||||
RETURNS void
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
DECLARE
|
||||
target_schema text := current_schema(); -- Use the current schema where the function is called
|
||||
BEGIN
|
||||
-- Use CONCURRENTLY for better performance during refresh
|
||||
REFRESH MATERIALIZED VIEW CONCURRENTLY "mv_agent_run_counts";
|
||||
REFRESH MATERIALIZED VIEW CONCURRENTLY "mv_review_stats";
|
||||
RAISE NOTICE 'Materialized views refreshed in schema % at %', target_schema, NOW();
|
||||
EXCEPTION
|
||||
WHEN OTHERS THEN
|
||||
-- Fallback to non-concurrent refresh if concurrent fails
|
||||
REFRESH MATERIALIZED VIEW "mv_agent_run_counts";
|
||||
REFRESH MATERIALIZED VIEW "mv_review_stats";
|
||||
RAISE NOTICE 'Materialized views refreshed (non-concurrent) in schema % at %. Concurrent refresh failed due to: %', target_schema, NOW(), SQLERRM;
|
||||
END;
|
||||
$$;
|
||||
|
||||
-- Initial refresh + test of the function to ensure it works
|
||||
SELECT refresh_store_materialized_views();
|
||||
|
||||
-- Re-create the cron job to use the improved function
|
||||
DO $$
|
||||
DECLARE
|
||||
has_pg_cron BOOLEAN;
|
||||
current_schema_name text := current_schema();
|
||||
old_job_name text;
|
||||
job_name text;
|
||||
BEGIN
|
||||
-- Check if pg_cron extension exists
|
||||
SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pg_cron') INTO has_pg_cron;
|
||||
|
||||
IF has_pg_cron THEN
|
||||
old_job_name := format('refresh-store-views-%s', current_schema_name);
|
||||
job_name := format('refresh-store-views_%s', current_schema_name);
|
||||
|
||||
-- Try to unschedule existing job (ignore errors if it doesn't exist)
|
||||
BEGIN
|
||||
PERFORM cron.unschedule(old_job_name);
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
NULL;
|
||||
END;
|
||||
|
||||
-- Schedule the new job with explicit schema parameter
|
||||
PERFORM cron.schedule(
|
||||
job_name,
|
||||
'*/15 * * * *',
|
||||
format('SET search_path TO %I; SELECT refresh_store_materialized_views();', current_schema_name)
|
||||
);
|
||||
RAISE NOTICE 'Scheduled job %; runs every 15 minutes for schema %', job_name, current_schema_name;
|
||||
ELSE
|
||||
RAISE WARNING '⚠️ Automatic refresh NOT configured - pg_cron is not available';
|
||||
RAISE WARNING '⚠️ You must manually refresh views with: SELECT refresh_store_materialized_views();';
|
||||
RAISE WARNING '⚠️ Or install pg_cron for automatic refresh in production';
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -1,3 +0,0 @@
|
||||
-- Re-create foreign key CreditTransaction <- User with ON DELETE NO ACTION
|
||||
ALTER TABLE "CreditTransaction" DROP CONSTRAINT "CreditTransaction_userId_fkey";
|
||||
ALTER TABLE "CreditTransaction" ADD CONSTRAINT "CreditTransaction_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE NO ACTION ON UPDATE CASCADE;
|
||||
@@ -1,22 +0,0 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- A unique constraint covering the columns `[shareToken]` on the table `AgentGraphExecution` will be added. If there are existing duplicate values, this will fail.
|
||||
|
||||
*/
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentGraphExecution" ADD COLUMN "isShared" BOOLEAN NOT NULL DEFAULT false,
|
||||
ADD COLUMN "shareToken" TEXT,
|
||||
ADD COLUMN "sharedAt" TIMESTAMP(3);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "AgentGraphExecution_shareToken_key" ON "AgentGraphExecution"("shareToken");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "AgentGraphExecution_shareToken_idx" ON "AgentGraphExecution"("shareToken");
|
||||
|
||||
-- RenameIndex
|
||||
ALTER INDEX "APIKey_key_key" RENAME TO "APIKey_hash_key";
|
||||
|
||||
-- RenameIndex
|
||||
ALTER INDEX "APIKey_prefix_name_idx" RENAME TO "APIKey_head_name_idx";
|
||||
@@ -1,53 +0,0 @@
|
||||
-- Add instructions field to AgentGraph and StoreListingVersion tables and update StoreSubmission view
|
||||
|
||||
BEGIN;
|
||||
|
||||
-- AddColumn
|
||||
ALTER TABLE "AgentGraph" ADD COLUMN "instructions" TEXT;
|
||||
|
||||
-- AddColumn
|
||||
ALTER TABLE "StoreListingVersion" ADD COLUMN "instructions" TEXT;
|
||||
|
||||
-- Drop the existing view
|
||||
DROP VIEW IF EXISTS "StoreSubmission";
|
||||
|
||||
-- Recreate the view with the new instructions field
|
||||
CREATE VIEW "StoreSubmission" AS
|
||||
SELECT
|
||||
sl.id AS listing_id,
|
||||
sl."owningUserId" AS user_id,
|
||||
slv."agentGraphId" AS agent_id,
|
||||
slv.version AS agent_version,
|
||||
sl.slug,
|
||||
COALESCE(slv.name, '') AS name,
|
||||
slv."subHeading" AS sub_heading,
|
||||
slv.description,
|
||||
slv.instructions,
|
||||
slv."imageUrls" AS image_urls,
|
||||
slv."submittedAt" AS date_submitted,
|
||||
slv."submissionStatus" AS status,
|
||||
COALESCE(ar.run_count, 0::bigint) AS runs,
|
||||
COALESCE(avg(sr.score::numeric), 0.0)::double precision AS rating,
|
||||
slv.id AS store_listing_version_id,
|
||||
slv."reviewerId" AS reviewer_id,
|
||||
slv."reviewComments" AS review_comments,
|
||||
slv."internalComments" AS internal_comments,
|
||||
slv."reviewedAt" AS reviewed_at,
|
||||
slv."changesSummary" AS changes_summary,
|
||||
slv."videoUrl" AS video_url,
|
||||
slv.categories
|
||||
FROM "StoreListing" sl
|
||||
JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
|
||||
LEFT JOIN "StoreListingReview" sr ON sr."storeListingVersionId" = slv.id
|
||||
LEFT JOIN (
|
||||
SELECT "AgentGraphExecution"."agentGraphId", count(*) AS run_count
|
||||
FROM "AgentGraphExecution"
|
||||
GROUP BY "AgentGraphExecution"."agentGraphId"
|
||||
) ar ON ar."agentGraphId" = slv."agentGraphId"
|
||||
WHERE sl."isDeleted" = false
|
||||
GROUP BY sl.id, sl."owningUserId", slv.id, slv."agentGraphId", slv.version, sl.slug, slv.name,
|
||||
slv."subHeading", slv.description, slv.instructions, slv."imageUrls", slv."submittedAt",
|
||||
slv."submissionStatus", slv."reviewerId", slv."reviewComments", slv."internalComments",
|
||||
slv."reviewedAt", slv."changesSummary", slv."videoUrl", slv.categories, ar.run_count;
|
||||
|
||||
COMMIT;
|
||||
101
autogpt_platform/backend/poetry.lock
generated
101
autogpt_platform/backend/poetry.lock
generated
@@ -403,7 +403,6 @@ develop = true
|
||||
|
||||
[package.dependencies]
|
||||
colorama = "^0.4.6"
|
||||
cryptography = "^45.0"
|
||||
expiringdict = "^1.2.2"
|
||||
fastapi = "^0.116.1"
|
||||
google-cloud-logging = "^3.12.1"
|
||||
@@ -899,62 +898,52 @@ pytz = ">2021.1"
|
||||
|
||||
[[package]]
|
||||
name = "cryptography"
|
||||
version = "45.0.7"
|
||||
version = "43.0.3"
|
||||
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
|
||||
optional = false
|
||||
python-versions = "!=3.9.0,!=3.9.1,>=3.7"
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "cryptography-45.0.7-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:3be4f21c6245930688bd9e162829480de027f8bf962ede33d4f8ba7d67a00cee"},
|
||||
{file = "cryptography-45.0.7-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:67285f8a611b0ebc0857ced2081e30302909f571a46bfa7a3cc0ad303fe015c6"},
|
||||
{file = "cryptography-45.0.7-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:577470e39e60a6cd7780793202e63536026d9b8641de011ed9d8174da9ca5339"},
|
||||
{file = "cryptography-45.0.7-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:4bd3e5c4b9682bc112d634f2c6ccc6736ed3635fc3319ac2bb11d768cc5a00d8"},
|
||||
{file = "cryptography-45.0.7-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:465ccac9d70115cd4de7186e60cfe989de73f7bb23e8a7aa45af18f7412e75bf"},
|
||||
{file = "cryptography-45.0.7-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:16ede8a4f7929b4b7ff3642eba2bf79aa1d71f24ab6ee443935c0d269b6bc513"},
|
||||
{file = "cryptography-45.0.7-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:8978132287a9d3ad6b54fcd1e08548033cc09dc6aacacb6c004c73c3eb5d3ac3"},
|
||||
{file = "cryptography-45.0.7-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:b6a0e535baec27b528cb07a119f321ac024592388c5681a5ced167ae98e9fff3"},
|
||||
{file = "cryptography-45.0.7-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:a24ee598d10befaec178efdff6054bc4d7e883f615bfbcd08126a0f4931c83a6"},
|
||||
{file = "cryptography-45.0.7-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:fa26fa54c0a9384c27fcdc905a2fb7d60ac6e47d14bc2692145f2b3b1e2cfdbd"},
|
||||
{file = "cryptography-45.0.7-cp311-abi3-win32.whl", hash = "sha256:bef32a5e327bd8e5af915d3416ffefdbe65ed975b646b3805be81b23580b57b8"},
|
||||
{file = "cryptography-45.0.7-cp311-abi3-win_amd64.whl", hash = "sha256:3808e6b2e5f0b46d981c24d79648e5c25c35e59902ea4391a0dcb3e667bf7443"},
|
||||
{file = "cryptography-45.0.7-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:bfb4c801f65dd61cedfc61a83732327fafbac55a47282e6f26f073ca7a41c3b2"},
|
||||
{file = "cryptography-45.0.7-cp37-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:81823935e2f8d476707e85a78a405953a03ef7b7b4f55f93f7c2d9680e5e0691"},
|
||||
{file = "cryptography-45.0.7-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3994c809c17fc570c2af12c9b840d7cea85a9fd3e5c0e0491f4fa3c029216d59"},
|
||||
{file = "cryptography-45.0.7-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:dad43797959a74103cb59c5dac71409f9c27d34c8a05921341fb64ea8ccb1dd4"},
|
||||
{file = "cryptography-45.0.7-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:ce7a453385e4c4693985b4a4a3533e041558851eae061a58a5405363b098fcd3"},
|
||||
{file = "cryptography-45.0.7-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:b04f85ac3a90c227b6e5890acb0edbaf3140938dbecf07bff618bf3638578cf1"},
|
||||
{file = "cryptography-45.0.7-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:48c41a44ef8b8c2e80ca4527ee81daa4c527df3ecbc9423c41a420a9559d0e27"},
|
||||
{file = "cryptography-45.0.7-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:f3df7b3d0f91b88b2106031fd995802a2e9ae13e02c36c1fc075b43f420f3a17"},
|
||||
{file = "cryptography-45.0.7-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:dd342f085542f6eb894ca00ef70236ea46070c8a13824c6bde0dfdcd36065b9b"},
|
||||
{file = "cryptography-45.0.7-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:1993a1bb7e4eccfb922b6cd414f072e08ff5816702a0bdb8941c247a6b1b287c"},
|
||||
{file = "cryptography-45.0.7-cp37-abi3-win32.whl", hash = "sha256:18fcf70f243fe07252dcb1b268a687f2358025ce32f9f88028ca5c364b123ef5"},
|
||||
{file = "cryptography-45.0.7-cp37-abi3-win_amd64.whl", hash = "sha256:7285a89df4900ed3bfaad5679b1e668cb4b38a8de1ccbfc84b05f34512da0a90"},
|
||||
{file = "cryptography-45.0.7-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:de58755d723e86175756f463f2f0bddd45cc36fbd62601228a3f8761c9f58252"},
|
||||
{file = "cryptography-45.0.7-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:a20e442e917889d1a6b3c570c9e3fa2fdc398c20868abcea268ea33c024c4083"},
|
||||
{file = "cryptography-45.0.7-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:258e0dff86d1d891169b5af222d362468a9570e2532923088658aa866eb11130"},
|
||||
{file = "cryptography-45.0.7-pp310-pypy310_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d97cf502abe2ab9eff8bd5e4aca274da8d06dd3ef08b759a8d6143f4ad65d4b4"},
|
||||
{file = "cryptography-45.0.7-pp310-pypy310_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:c987dad82e8c65ebc985f5dae5e74a3beda9d0a2a4daf8a1115f3772b59e5141"},
|
||||
{file = "cryptography-45.0.7-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:c13b1e3afd29a5b3b2656257f14669ca8fa8d7956d509926f0b130b600b50ab7"},
|
||||
{file = "cryptography-45.0.7-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4a862753b36620af6fc54209264f92c716367f2f0ff4624952276a6bbd18cbde"},
|
||||
{file = "cryptography-45.0.7-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:06ce84dc14df0bf6ea84666f958e6080cdb6fe1231be2a51f3fc1267d9f3fb34"},
|
||||
{file = "cryptography-45.0.7-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d0c5c6bac22b177bf8da7435d9d27a6834ee130309749d162b26c3105c0795a9"},
|
||||
{file = "cryptography-45.0.7-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:2f641b64acc00811da98df63df7d59fd4706c0df449da71cb7ac39a0732b40ae"},
|
||||
{file = "cryptography-45.0.7-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:f5414a788ecc6ee6bc58560e85ca624258a55ca434884445440a810796ea0e0b"},
|
||||
{file = "cryptography-45.0.7-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:1f3d56f73595376f4244646dd5c5870c14c196949807be39e79e7bd9bac3da63"},
|
||||
{file = "cryptography-45.0.7.tar.gz", hash = "sha256:4b1654dfc64ea479c242508eb8c724044f1e964a47d1d1cacc5132292d851971"},
|
||||
{file = "cryptography-43.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:bf7a1932ac4176486eab36a19ed4c0492da5d97123f1406cf15e41b05e787d2e"},
|
||||
{file = "cryptography-43.0.3-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63efa177ff54aec6e1c0aefaa1a241232dcd37413835a9b674b6e3f0ae2bfd3e"},
|
||||
{file = "cryptography-43.0.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e1ce50266f4f70bf41a2c6dc4358afadae90e2a1e5342d3c08883df1675374f"},
|
||||
{file = "cryptography-43.0.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:443c4a81bb10daed9a8f334365fe52542771f25aedaf889fd323a853ce7377d6"},
|
||||
{file = "cryptography-43.0.3-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:74f57f24754fe349223792466a709f8e0c093205ff0dca557af51072ff47ab18"},
|
||||
{file = "cryptography-43.0.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9762ea51a8fc2a88b70cf2995e5675b38d93bf36bd67d91721c309df184f49bd"},
|
||||
{file = "cryptography-43.0.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:81ef806b1fef6b06dcebad789f988d3b37ccaee225695cf3e07648eee0fc6b73"},
|
||||
{file = "cryptography-43.0.3-cp37-abi3-win32.whl", hash = "sha256:cbeb489927bd7af4aa98d4b261af9a5bc025bd87f0e3547e11584be9e9427be2"},
|
||||
{file = "cryptography-43.0.3-cp37-abi3-win_amd64.whl", hash = "sha256:f46304d6f0c6ab8e52770addfa2fc41e6629495548862279641972b6215451cd"},
|
||||
{file = "cryptography-43.0.3-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:8ac43ae87929a5982f5948ceda07001ee5e83227fd69cf55b109144938d96984"},
|
||||
{file = "cryptography-43.0.3-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:846da004a5804145a5f441b8530b4bf35afbf7da70f82409f151695b127213d5"},
|
||||
{file = "cryptography-43.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f996e7268af62598f2fc1204afa98a3b5712313a55c4c9d434aef49cadc91d4"},
|
||||
{file = "cryptography-43.0.3-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:f7b178f11ed3664fd0e995a47ed2b5ff0a12d893e41dd0494f406d1cf555cab7"},
|
||||
{file = "cryptography-43.0.3-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:c2e6fc39c4ab499049df3bdf567f768a723a5e8464816e8f009f121a5a9f4405"},
|
||||
{file = "cryptography-43.0.3-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e1be4655c7ef6e1bbe6b5d0403526601323420bcf414598955968c9ef3eb7d16"},
|
||||
{file = "cryptography-43.0.3-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:df6b6c6d742395dd77a23ea3728ab62f98379eff8fb61be2744d4679ab678f73"},
|
||||
{file = "cryptography-43.0.3-cp39-abi3-win32.whl", hash = "sha256:d56e96520b1020449bbace2b78b603442e7e378a9b3bd68de65c782db1507995"},
|
||||
{file = "cryptography-43.0.3-cp39-abi3-win_amd64.whl", hash = "sha256:0c580952eef9bf68c4747774cde7ec1d85a6e61de97281f2dba83c7d2c806362"},
|
||||
{file = "cryptography-43.0.3-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d03b5621a135bffecad2c73e9f4deb1a0f977b9a8ffe6f8e002bf6c9d07b918c"},
|
||||
{file = "cryptography-43.0.3-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:a2a431ee15799d6db9fe80c82b055bae5a752bef645bba795e8e52687c69efe3"},
|
||||
{file = "cryptography-43.0.3-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:281c945d0e28c92ca5e5930664c1cefd85efe80e5c0d2bc58dd63383fda29f83"},
|
||||
{file = "cryptography-43.0.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:f18c716be16bc1fea8e95def49edf46b82fccaa88587a45f8dc0ff6ab5d8e0a7"},
|
||||
{file = "cryptography-43.0.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4a02ded6cd4f0a5562a8887df8b3bd14e822a90f97ac5e544c162899bc467664"},
|
||||
{file = "cryptography-43.0.3-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:53a583b6637ab4c4e3591a15bc9db855b8d9dee9a669b550f311480acab6eb08"},
|
||||
{file = "cryptography-43.0.3-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1ec0bcf7e17c0c5669d881b1cd38c4972fade441b27bda1051665faaa89bdcaa"},
|
||||
{file = "cryptography-43.0.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2ce6fae5bdad59577b44e4dfed356944fbf1d925269114c28be377692643b4ff"},
|
||||
{file = "cryptography-43.0.3.tar.gz", hash = "sha256:315b9001266a492a6ff443b61238f956b214dbec9910a081ba5b6646a055a805"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cffi = {version = ">=1.14", markers = "platform_python_implementation != \"PyPy\""}
|
||||
cffi = {version = ">=1.12", markers = "platform_python_implementation != \"PyPy\""}
|
||||
|
||||
[package.extras]
|
||||
docs = ["sphinx (>=5.3.0)", "sphinx-inline-tabs ; python_full_version >= \"3.8.0\"", "sphinx-rtd-theme (>=3.0.0) ; python_full_version >= \"3.8.0\""]
|
||||
docstest = ["pyenchant (>=3)", "readme-renderer (>=30.0)", "sphinxcontrib-spelling (>=7.3.1)"]
|
||||
nox = ["nox (>=2024.4.15)", "nox[uv] (>=2024.3.2) ; python_full_version >= \"3.8.0\""]
|
||||
pep8test = ["check-sdist ; python_full_version >= \"3.8.0\"", "click (>=8.0.1)", "mypy (>=1.4)", "ruff (>=0.3.6)"]
|
||||
sdist = ["build (>=1.0.0)"]
|
||||
docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=1.1.1)"]
|
||||
docstest = ["pyenchant (>=1.6.11)", "readme-renderer", "sphinxcontrib-spelling (>=4.0.1)"]
|
||||
nox = ["nox"]
|
||||
pep8test = ["check-sdist", "click", "mypy", "ruff"]
|
||||
sdist = ["build"]
|
||||
ssh = ["bcrypt (>=3.1.5)"]
|
||||
test = ["certifi (>=2024)", "cryptography-vectors (==45.0.7)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
|
||||
test = ["certifi", "cryptography-vectors (==43.0.3)", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"]
|
||||
test-randomorder = ["pytest-randomly"]
|
||||
|
||||
[[package]]
|
||||
@@ -4145,22 +4134,6 @@ files = [
|
||||
[package.extras]
|
||||
twisted = ["twisted"]
|
||||
|
||||
[[package]]
|
||||
name = "prometheus-fastapi-instrumentator"
|
||||
version = "7.1.0"
|
||||
description = "Instrument your FastAPI app with Prometheus metrics"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "prometheus_fastapi_instrumentator-7.1.0-py3-none-any.whl", hash = "sha256:978130f3c0bb7b8ebcc90d35516a6fe13e02d2eb358c8f83887cdef7020c31e9"},
|
||||
{file = "prometheus_fastapi_instrumentator-7.1.0.tar.gz", hash = "sha256:be7cd61eeea4e5912aeccb4261c6631b3f227d8924542d79eaf5af3f439cbe5e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
prometheus-client = ">=0.8.0,<1.0.0"
|
||||
starlette = ">=0.30.0,<1.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "propcache"
|
||||
version = "0.3.2"
|
||||
@@ -7159,4 +7132,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "2c7e9370f500039b99868376021627c5a120e0ee31c5c5e6de39db2c3d82f414"
|
||||
content-hash = "892daa57d7126d9a9d5308005b07328a39b8c4cd7fe198f9b5ab10f957787c48"
|
||||
|
||||
@@ -17,7 +17,7 @@ apscheduler = "^3.11.0"
|
||||
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
||||
bleach = { extras = ["css"], version = "^6.2.0" }
|
||||
click = "^8.2.0"
|
||||
cryptography = "^45.0"
|
||||
cryptography = "^43.0"
|
||||
discord-py = "^2.5.2"
|
||||
e2b-code-interpreter = "^1.5.2"
|
||||
fastapi = "^0.116.1"
|
||||
@@ -45,7 +45,6 @@ postmarker = "^1.0"
|
||||
praw = "~7.8.1"
|
||||
prisma = "^0.15.0"
|
||||
prometheus-client = "^0.22.1"
|
||||
prometheus-fastapi-instrumentator = "^7.0.0"
|
||||
psutil = "^7.0.0"
|
||||
psycopg2-binary = "^2.9.10"
|
||||
pydantic = { extras = ["email"], version = "^2.11.7" }
|
||||
|
||||
@@ -36,7 +36,7 @@ model User {
|
||||
notifyOnAgentApproved Boolean @default(true)
|
||||
notifyOnAgentRejected Boolean @default(true)
|
||||
|
||||
timezone String @default("not-set")
|
||||
timezone String @default("not-set")
|
||||
|
||||
// Relations
|
||||
|
||||
@@ -110,8 +110,6 @@ model AgentGraph {
|
||||
|
||||
name String?
|
||||
description String?
|
||||
instructions String?
|
||||
recommendedScheduleCron String?
|
||||
|
||||
isActive Boolean @default(true)
|
||||
|
||||
@@ -356,31 +354,20 @@ model AgentGraphExecution {
|
||||
agentGraphVersion Int @default(1)
|
||||
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version], onDelete: Cascade)
|
||||
|
||||
agentPresetId String?
|
||||
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
|
||||
|
||||
inputs Json?
|
||||
credentialInputs Json?
|
||||
nodesInputMasks Json?
|
||||
|
||||
NodeExecutions AgentNodeExecution[]
|
||||
|
||||
// Link to User model -- Executed by this user
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
stats Json?
|
||||
|
||||
// Sharing fields
|
||||
isShared Boolean @default(false)
|
||||
shareToken String? @unique
|
||||
sharedAt DateTime?
|
||||
stats Json?
|
||||
agentPresetId String?
|
||||
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
|
||||
|
||||
@@index([agentGraphId, agentGraphVersion])
|
||||
@@index([userId])
|
||||
@@index([createdAt])
|
||||
@@index([agentPresetId])
|
||||
@@index([shareToken])
|
||||
}
|
||||
|
||||
// This model describes the execution of an AgentNode.
|
||||
@@ -535,7 +522,7 @@ model CreditTransaction {
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
userId String
|
||||
User User? @relation(fields: [userId], references: [id], onDelete: NoAction)
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
amount Int
|
||||
type CreditTransactionType
|
||||
@@ -638,15 +625,14 @@ view StoreAgent {
|
||||
agent_image String[]
|
||||
|
||||
featured Boolean @default(false)
|
||||
creator_username String?
|
||||
creator_avatar String?
|
||||
creator_username String
|
||||
creator_avatar String
|
||||
sub_heading String
|
||||
description String
|
||||
categories String[]
|
||||
runs Int
|
||||
rating Float
|
||||
versions String[]
|
||||
is_available Boolean @default(true)
|
||||
|
||||
// Materialized views used (refreshed every 15 minutes via pg_cron):
|
||||
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
|
||||
@@ -764,7 +750,6 @@ model StoreListingVersion {
|
||||
videoUrl String?
|
||||
imageUrls String[]
|
||||
description String
|
||||
instructions String?
|
||||
categories String[]
|
||||
|
||||
isFeatured Boolean @default(false)
|
||||
@@ -794,8 +779,6 @@ model StoreListingVersion {
|
||||
reviewComments String? // Comments visible to creator
|
||||
reviewedAt DateTime?
|
||||
|
||||
recommendedScheduleCron String? // cron expression like "0 9 * * *"
|
||||
|
||||
// Reviews for this specific version
|
||||
Reviews StoreListingReview[]
|
||||
|
||||
@@ -839,13 +822,11 @@ enum APIKeyPermission {
|
||||
}
|
||||
|
||||
model APIKey {
|
||||
id String @id @default(uuid())
|
||||
name String
|
||||
head String // First few chars for identification
|
||||
tail String
|
||||
hash String @unique
|
||||
salt String? // null for legacy unsalted keys
|
||||
|
||||
id String @id @default(uuid())
|
||||
name String
|
||||
prefix String // First 8 chars for identification
|
||||
postfix String
|
||||
key String @unique // Hashed key
|
||||
status APIKeyStatus @default(ACTIVE)
|
||||
permissions APIKeyPermission[]
|
||||
|
||||
@@ -859,7 +840,7 @@ model APIKey {
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@index([head, name])
|
||||
@@index([prefix, name])
|
||||
@@index([userId, status])
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
"creator_avatar": "avatar1.jpg",
|
||||
"sub_heading": "Test agent subheading",
|
||||
"description": "Test agent description",
|
||||
"instructions": null,
|
||||
"categories": [
|
||||
"category1",
|
||||
"category2"
|
||||
@@ -23,7 +22,6 @@
|
||||
"1.1.0"
|
||||
],
|
||||
"last_updated": "2023-01-01T00:00:00",
|
||||
"recommended_schedule_cron": null,
|
||||
"active_version_id": null,
|
||||
"has_approved_version": false
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
{
|
||||
"created_at": "2025-09-04T13:37:00",
|
||||
"credentials_input_schema": {
|
||||
"properties": {},
|
||||
"title": "TestGraphCredentialsInputSchema",
|
||||
@@ -15,7 +14,6 @@
|
||||
"required": [],
|
||||
"type": "object"
|
||||
},
|
||||
"instructions": null,
|
||||
"is_active": true,
|
||||
"links": [],
|
||||
"name": "Test Graph",
|
||||
@@ -25,9 +23,7 @@
|
||||
"required": [],
|
||||
"type": "object"
|
||||
},
|
||||
"recommended_schedule_cron": null,
|
||||
"sub_graphs": [],
|
||||
"trigger_setup_info": null,
|
||||
"user_id": "test-user-id",
|
||||
"version": 1
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user