mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-12 00:28:31 -05:00
Compare commits
5 Commits
fix/databa
...
feat/execu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f652cb978 | ||
|
|
279552a2a3 | ||
|
|
fb6ac1d6ca | ||
|
|
9db15bff02 | ||
|
|
db4b94e0dc |
97
.github/workflows/claude-ci-failure-auto-fix.yml
vendored
97
.github/workflows/claude-ci-failure-auto-fix.yml
vendored
@@ -1,97 +0,0 @@
|
||||
name: Auto Fix CI Failures
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: ["CI"]
|
||||
types:
|
||||
- completed
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
actions: read
|
||||
issues: write
|
||||
id-token: write # Required for OIDC token exchange
|
||||
|
||||
jobs:
|
||||
auto-fix:
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'failure' &&
|
||||
github.event.workflow_run.pull_requests[0] &&
|
||||
!startsWith(github.event.workflow_run.head_branch, 'claude-auto-fix-ci-')
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.workflow_run.head_branch }}
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Setup git identity
|
||||
run: |
|
||||
git config --global user.email "claude[bot]@users.noreply.github.com"
|
||||
git config --global user.name "claude[bot]"
|
||||
|
||||
- name: Create fix branch
|
||||
id: branch
|
||||
run: |
|
||||
BRANCH_NAME="claude-auto-fix-ci-${{ github.event.workflow_run.head_branch }}-${{ github.run_id }}"
|
||||
git checkout -b "$BRANCH_NAME"
|
||||
echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Get CI failure details
|
||||
id: failure_details
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const run = await github.rest.actions.getWorkflowRun({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
run_id: ${{ github.event.workflow_run.id }}
|
||||
});
|
||||
|
||||
const jobs = await github.rest.actions.listJobsForWorkflowRun({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
run_id: ${{ github.event.workflow_run.id }}
|
||||
});
|
||||
|
||||
const failedJobs = jobs.data.jobs.filter(job => job.conclusion === 'failure');
|
||||
|
||||
let errorLogs = [];
|
||||
for (const job of failedJobs) {
|
||||
const logs = await github.rest.actions.downloadJobLogsForWorkflowRun({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
job_id: job.id
|
||||
});
|
||||
errorLogs.push({
|
||||
jobName: job.name,
|
||||
logs: logs.data
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
runUrl: run.data.html_url,
|
||||
failedJobs: failedJobs.map(j => j.name),
|
||||
errorLogs: errorLogs
|
||||
};
|
||||
|
||||
- name: Fix CI failures with Claude
|
||||
id: claude
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
prompt: |
|
||||
/fix-ci
|
||||
Failed CI Run: ${{ fromJSON(steps.failure_details.outputs.result).runUrl }}
|
||||
Failed Jobs: ${{ join(fromJSON(steps.failure_details.outputs.result).failedJobs, ', ') }}
|
||||
PR Number: ${{ github.event.workflow_run.pull_requests[0].number }}
|
||||
Branch Name: ${{ steps.branch.outputs.branch_name }}
|
||||
Base Branch: ${{ github.event.workflow_run.head_branch }}
|
||||
Repository: ${{ github.repository }}
|
||||
|
||||
Error logs:
|
||||
${{ toJSON(fromJSON(steps.failure_details.outputs.result).errorLogs) }}
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: "--allowedTools 'Edit,MultiEdit,Write,Read,Glob,Grep,LS,Bash(git:*),Bash(bun:*),Bash(npm:*),Bash(npx:*),Bash(gh:*)'"
|
||||
379
.github/workflows/claude-dependabot.yml
vendored
379
.github/workflows/claude-dependabot.yml
vendored
@@ -1,379 +0,0 @@
|
||||
# Claude Dependabot PR Review Workflow
|
||||
#
|
||||
# This workflow automatically runs Claude analysis on Dependabot PRs to:
|
||||
# - Identify dependency changes and their versions
|
||||
# - Look up changelogs for updated packages
|
||||
# - Assess breaking changes and security impacts
|
||||
# - Provide actionable recommendations for the development team
|
||||
#
|
||||
# Triggered on: Dependabot PRs (opened, synchronize)
|
||||
# Requirements: ANTHROPIC_API_KEY secret must be configured
|
||||
|
||||
name: Claude Dependabot PR Review
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize]
|
||||
|
||||
jobs:
|
||||
dependabot-review:
|
||||
# Only run on Dependabot PRs
|
||||
if: github.actor == 'dependabot[bot]'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: read
|
||||
issues: read
|
||||
id-token: write
|
||||
actions: read # Required for CI access
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock (matches CI)
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
|
||||
|
||||
# Install Poetry
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
# Add Poetry to PATH
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Check poetry.lock
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry lock
|
||||
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
|
||||
echo "Warning: poetry.lock not up to date, but continuing for setup"
|
||||
git checkout poetry.lock # Reset for clean setup
|
||||
fi
|
||||
|
||||
- name: Install Python dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set pnpm store directory
|
||||
run: |
|
||||
pnpm config set store-dir ~/.pnpm-store
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
# Install Playwright browsers for frontend testing
|
||||
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
|
||||
# - name: Install Playwright browsers
|
||||
# working-directory: autogpt_platform/frontend
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Copy default environment files
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Copy default environment files for development
|
||||
cp .env.default .env
|
||||
cp backend/.env.default backend/.env
|
||||
cp frontend/.env.default frontend/.env
|
||||
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
|
||||
restore-keys: |
|
||||
docker-images-v2-${{ runner.os }}-
|
||||
docker-images-v1-${{ runner.os }}-
|
||||
|
||||
- name: Load or pull Docker images
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
mkdir -p ~/docker-cache
|
||||
|
||||
# Define image list for easy maintenance
|
||||
IMAGES=(
|
||||
"redis:latest"
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
|
||||
echo "Docker cache found, loading images in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
if [ -f ~/docker-cache/${filename}.tar ]; then
|
||||
echo "Loading $image..."
|
||||
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
|
||||
fi
|
||||
done
|
||||
wait
|
||||
echo "All cached images loaded"
|
||||
else
|
||||
echo "No Docker cache found, pulling images in parallel..."
|
||||
# Pull all images in parallel
|
||||
for image in "${IMAGES[@]}"; do
|
||||
docker pull "$image" &
|
||||
done
|
||||
wait
|
||||
|
||||
# Only save cache on main branches (not PRs) to avoid cache pollution
|
||||
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
|
||||
echo "Saving Docker images to cache in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
echo "Saving $image..."
|
||||
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
|
||||
done
|
||||
wait
|
||||
echo "Docker image cache saved"
|
||||
else
|
||||
echo "Skipping cache save for PR/feature branch"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Docker images ready for use"
|
||||
|
||||
# Phase 2: Build migrate service with GitHub Actions cache
|
||||
- name: Build migrate Docker image with cache
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Build the migrate image with buildx for GHA caching
|
||||
docker buildx build \
|
||||
--cache-from type=gha \
|
||||
--cache-to type=gha,mode=max \
|
||||
--target migrate \
|
||||
--tag autogpt_platform-migrate:latest \
|
||||
--load \
|
||||
-f backend/Dockerfile \
|
||||
..
|
||||
|
||||
# Start services using pre-built images
|
||||
- name: Start Docker services for development
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Start essential services (migrate image already built with correct tag)
|
||||
docker compose --profile local up deps --no-build --detach
|
||||
echo "Waiting for services to be ready..."
|
||||
|
||||
# Wait for database to be ready
|
||||
echo "Checking database readiness..."
|
||||
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
|
||||
echo " Waiting for database..."
|
||||
sleep 2
|
||||
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
|
||||
|
||||
# Check migrate service status
|
||||
echo "Checking migration status..."
|
||||
docker compose ps migrate || echo " Migrate service not visible in ps output"
|
||||
|
||||
# Wait for migrate service to complete
|
||||
echo "Waiting for migrations to complete..."
|
||||
timeout 30 bash -c '
|
||||
ATTEMPTS=0
|
||||
while [ $ATTEMPTS -lt 15 ]; do
|
||||
ATTEMPTS=$((ATTEMPTS + 1))
|
||||
|
||||
# Check using docker directly (more reliable than docker compose ps)
|
||||
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
|
||||
|
||||
if [ -z "$CONTAINER_STATUS" ]; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
|
||||
echo "✅ Migrations completed successfully"
|
||||
docker compose logs migrate --tail=5 2>/dev/null || true
|
||||
exit 0
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
|
||||
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
|
||||
echo "❌ Migrations failed with exit code: $EXIT_CODE"
|
||||
echo "Migration logs:"
|
||||
docker compose logs migrate --tail=20 2>/dev/null || true
|
||||
exit 1
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
|
||||
else
|
||||
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
|
||||
fi
|
||||
|
||||
sleep 2
|
||||
done
|
||||
|
||||
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
|
||||
echo "Final container check:"
|
||||
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
|
||||
echo "Migration logs (if available):"
|
||||
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
|
||||
' || echo "⚠️ Migration check completed with warnings, continuing..."
|
||||
|
||||
# Brief wait for other services to stabilize
|
||||
echo "Waiting 5 seconds for other services to stabilize..."
|
||||
sleep 5
|
||||
|
||||
# Verify installations and provide environment info
|
||||
- name: Verify setup and show environment info
|
||||
run: |
|
||||
echo "=== Python Setup ==="
|
||||
python --version
|
||||
poetry --version
|
||||
|
||||
echo "=== Node.js Setup ==="
|
||||
node --version
|
||||
pnpm --version
|
||||
|
||||
echo "=== Additional Tools ==="
|
||||
docker --version
|
||||
docker compose version
|
||||
gh --version || true
|
||||
|
||||
echo "=== Services Status ==="
|
||||
cd autogpt_platform
|
||||
docker compose ps || true
|
||||
|
||||
echo "=== Backend Dependencies ==="
|
||||
cd backend
|
||||
poetry show | head -10 || true
|
||||
|
||||
echo "=== Frontend Dependencies ==="
|
||||
cd ../frontend
|
||||
pnpm list --depth=0 | head -10 || true
|
||||
|
||||
echo "=== Environment Files ==="
|
||||
ls -la ../.env* || true
|
||||
ls -la .env* || true
|
||||
ls -la ../backend/.env* || true
|
||||
|
||||
echo "✅ AutoGPT Platform development environment setup complete!"
|
||||
echo "🚀 Ready for development with Docker services running"
|
||||
echo "📝 Backend server: poetry run serve (port 8000)"
|
||||
echo "🌐 Frontend server: pnpm dev (port 3000)"
|
||||
|
||||
|
||||
- name: Run Claude Dependabot Analysis
|
||||
id: claude_review
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
||||
prompt: |
|
||||
You are Claude, an AI assistant specialized in reviewing Dependabot dependency update PRs.
|
||||
|
||||
Your primary tasks are:
|
||||
1. **Analyze the dependency changes** in this Dependabot PR
|
||||
2. **Look up changelogs** for all updated dependencies to understand what changed
|
||||
3. **Identify breaking changes** and assess potential impact on the AutoGPT codebase
|
||||
4. **Provide actionable recommendations** for the development team
|
||||
|
||||
## Analysis Process:
|
||||
|
||||
1. **Identify Changed Dependencies**:
|
||||
- Use git diff to see what dependencies were updated
|
||||
- Parse package.json, poetry.lock, requirements files, etc.
|
||||
- List all package versions: old → new
|
||||
|
||||
2. **Changelog Research**:
|
||||
- For each updated dependency, look up its changelog/release notes
|
||||
- Use WebFetch to access GitHub releases, NPM package pages, PyPI project pages. The pr should also have some details
|
||||
- Focus on versions between the old and new versions
|
||||
- Identify: breaking changes, deprecations, security fixes, new features
|
||||
|
||||
3. **Breaking Change Assessment**:
|
||||
- Categorize changes: BREAKING, MAJOR, MINOR, PATCH, SECURITY
|
||||
- Assess impact on AutoGPT's usage patterns
|
||||
- Check if AutoGPT uses affected APIs/features
|
||||
- Look for migration guides or upgrade instructions
|
||||
|
||||
4. **Codebase Impact Analysis**:
|
||||
- Search the AutoGPT codebase for usage of changed APIs
|
||||
- Identify files that might be affected by breaking changes
|
||||
- Check test files for deprecated usage patterns
|
||||
- Look for configuration changes needed
|
||||
|
||||
## Output Format:
|
||||
|
||||
Provide a comprehensive review comment with:
|
||||
|
||||
### 🔍 Dependency Analysis Summary
|
||||
- List of updated packages with version changes
|
||||
- Overall risk assessment (LOW/MEDIUM/HIGH)
|
||||
|
||||
### 📋 Detailed Changelog Review
|
||||
For each updated dependency:
|
||||
- **Package**: name (old_version → new_version)
|
||||
- **Changes**: Summary of key changes
|
||||
- **Breaking Changes**: List any breaking changes
|
||||
- **Security Fixes**: Note security improvements
|
||||
- **Migration Notes**: Any upgrade steps needed
|
||||
|
||||
### ⚠️ Impact Assessment
|
||||
- **Breaking Changes Found**: Yes/No with details
|
||||
- **Affected Files**: List AutoGPT files that may need updates
|
||||
- **Test Impact**: Any tests that may need updating
|
||||
- **Configuration Changes**: Required config updates
|
||||
|
||||
### 🛠️ Recommendations
|
||||
- **Action Required**: What the team should do
|
||||
- **Testing Focus**: Areas to test thoroughly
|
||||
- **Follow-up Tasks**: Any additional work needed
|
||||
- **Merge Recommendation**: APPROVE/REVIEW_NEEDED/HOLD
|
||||
|
||||
### 📚 Useful Links
|
||||
- Links to relevant changelogs, migration guides, documentation
|
||||
|
||||
Be thorough but concise. Focus on actionable insights that help the development team make informed decisions about the dependency updates.
|
||||
284
.github/workflows/claude.yml
vendored
284
.github/workflows/claude.yml
vendored
@@ -30,296 +30,18 @@ jobs:
|
||||
github.event.issue.author_association == 'COLLABORATOR'
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
contents: read
|
||||
pull-requests: read
|
||||
issues: read
|
||||
id-token: write
|
||||
actions: read # Required for CI access
|
||||
steps:
|
||||
- name: Checkout code
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock (matches CI)
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
|
||||
|
||||
# Install Poetry
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
# Add Poetry to PATH
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Check poetry.lock
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry lock
|
||||
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
|
||||
echo "Warning: poetry.lock not up to date, but continuing for setup"
|
||||
git checkout poetry.lock # Reset for clean setup
|
||||
fi
|
||||
|
||||
- name: Install Python dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set pnpm store directory
|
||||
run: |
|
||||
pnpm config set store-dir ~/.pnpm-store
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
# Install Playwright browsers for frontend testing
|
||||
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
|
||||
# - name: Install Playwright browsers
|
||||
# working-directory: autogpt_platform/frontend
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Copy default environment files
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Copy default environment files for development
|
||||
cp .env.default .env
|
||||
cp backend/.env.default backend/.env
|
||||
cp frontend/.env.default frontend/.env
|
||||
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
|
||||
restore-keys: |
|
||||
docker-images-v2-${{ runner.os }}-
|
||||
docker-images-v1-${{ runner.os }}-
|
||||
|
||||
- name: Load or pull Docker images
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
mkdir -p ~/docker-cache
|
||||
|
||||
# Define image list for easy maintenance
|
||||
IMAGES=(
|
||||
"redis:latest"
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
|
||||
echo "Docker cache found, loading images in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
if [ -f ~/docker-cache/${filename}.tar ]; then
|
||||
echo "Loading $image..."
|
||||
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
|
||||
fi
|
||||
done
|
||||
wait
|
||||
echo "All cached images loaded"
|
||||
else
|
||||
echo "No Docker cache found, pulling images in parallel..."
|
||||
# Pull all images in parallel
|
||||
for image in "${IMAGES[@]}"; do
|
||||
docker pull "$image" &
|
||||
done
|
||||
wait
|
||||
|
||||
# Only save cache on main branches (not PRs) to avoid cache pollution
|
||||
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
|
||||
echo "Saving Docker images to cache in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
echo "Saving $image..."
|
||||
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
|
||||
done
|
||||
wait
|
||||
echo "Docker image cache saved"
|
||||
else
|
||||
echo "Skipping cache save for PR/feature branch"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Docker images ready for use"
|
||||
|
||||
# Phase 2: Build migrate service with GitHub Actions cache
|
||||
- name: Build migrate Docker image with cache
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Build the migrate image with buildx for GHA caching
|
||||
docker buildx build \
|
||||
--cache-from type=gha \
|
||||
--cache-to type=gha,mode=max \
|
||||
--target migrate \
|
||||
--tag autogpt_platform-migrate:latest \
|
||||
--load \
|
||||
-f backend/Dockerfile \
|
||||
..
|
||||
|
||||
# Start services using pre-built images
|
||||
- name: Start Docker services for development
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Start essential services (migrate image already built with correct tag)
|
||||
docker compose --profile local up deps --no-build --detach
|
||||
echo "Waiting for services to be ready..."
|
||||
|
||||
# Wait for database to be ready
|
||||
echo "Checking database readiness..."
|
||||
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
|
||||
echo " Waiting for database..."
|
||||
sleep 2
|
||||
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
|
||||
|
||||
# Check migrate service status
|
||||
echo "Checking migration status..."
|
||||
docker compose ps migrate || echo " Migrate service not visible in ps output"
|
||||
|
||||
# Wait for migrate service to complete
|
||||
echo "Waiting for migrations to complete..."
|
||||
timeout 30 bash -c '
|
||||
ATTEMPTS=0
|
||||
while [ $ATTEMPTS -lt 15 ]; do
|
||||
ATTEMPTS=$((ATTEMPTS + 1))
|
||||
|
||||
# Check using docker directly (more reliable than docker compose ps)
|
||||
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
|
||||
|
||||
if [ -z "$CONTAINER_STATUS" ]; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
|
||||
echo "✅ Migrations completed successfully"
|
||||
docker compose logs migrate --tail=5 2>/dev/null || true
|
||||
exit 0
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
|
||||
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
|
||||
echo "❌ Migrations failed with exit code: $EXIT_CODE"
|
||||
echo "Migration logs:"
|
||||
docker compose logs migrate --tail=20 2>/dev/null || true
|
||||
exit 1
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
|
||||
else
|
||||
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
|
||||
fi
|
||||
|
||||
sleep 2
|
||||
done
|
||||
|
||||
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
|
||||
echo "Final container check:"
|
||||
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
|
||||
echo "Migration logs (if available):"
|
||||
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
|
||||
' || echo "⚠️ Migration check completed with warnings, continuing..."
|
||||
|
||||
# Brief wait for other services to stabilize
|
||||
echo "Waiting 5 seconds for other services to stabilize..."
|
||||
sleep 5
|
||||
|
||||
# Verify installations and provide environment info
|
||||
- name: Verify setup and show environment info
|
||||
run: |
|
||||
echo "=== Python Setup ==="
|
||||
python --version
|
||||
poetry --version
|
||||
|
||||
echo "=== Node.js Setup ==="
|
||||
node --version
|
||||
pnpm --version
|
||||
|
||||
echo "=== Additional Tools ==="
|
||||
docker --version
|
||||
docker compose version
|
||||
gh --version || true
|
||||
|
||||
echo "=== Services Status ==="
|
||||
cd autogpt_platform
|
||||
docker compose ps || true
|
||||
|
||||
echo "=== Backend Dependencies ==="
|
||||
cd backend
|
||||
poetry show | head -10 || true
|
||||
|
||||
echo "=== Frontend Dependencies ==="
|
||||
cd ../frontend
|
||||
pnpm list --depth=0 | head -10 || true
|
||||
|
||||
echo "=== Environment Files ==="
|
||||
ls -la ../.env* || true
|
||||
ls -la .env* || true
|
||||
ls -la ../backend/.env* || true
|
||||
|
||||
echo "✅ AutoGPT Platform development environment setup complete!"
|
||||
echo "🚀 Ready for development with Docker services running"
|
||||
echo "📝 Backend server: poetry run serve (port 8000)"
|
||||
echo "🌐 Frontend server: pnpm dev (port 3000)"
|
||||
|
||||
- name: Run Claude Code
|
||||
id: claude
|
||||
uses: anthropics/claude-code-action@v1
|
||||
uses: anthropics/claude-code-action@beta
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*), Bash(gh pr edit:*)"
|
||||
--model opus
|
||||
additional_permissions: |
|
||||
actions: read
|
||||
|
||||
@@ -3,7 +3,6 @@ name: AutoGPT Platform - Deploy Prod Environment
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: 'read'
|
||||
@@ -18,8 +17,6 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.ref_name || 'master' }}
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -39,7 +36,7 @@ jobs:
|
||||
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
|
||||
|
||||
|
||||
trigger:
|
||||
needs: migrate
|
||||
runs-on: ubuntu-latest
|
||||
@@ -50,5 +47,4 @@ jobs:
|
||||
token: ${{ secrets.DEPLOY_TOKEN }}
|
||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||
event-type: build_deploy_prod
|
||||
client-payload: |
|
||||
{"ref": "${{ github.ref_name || 'master' }}", "repository": "${{ github.repository }}"}
|
||||
client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'
|
||||
@@ -5,13 +5,6 @@ on:
|
||||
branches: [ dev ]
|
||||
paths:
|
||||
- 'autogpt_platform/**'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
git_ref:
|
||||
description: 'Git ref (branch/tag) of AutoGPT to deploy'
|
||||
required: true
|
||||
default: 'master'
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: 'read'
|
||||
@@ -26,8 +19,6 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.inputs.git_ref || github.ref_name }}
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -57,4 +48,4 @@ jobs:
|
||||
token: ${{ secrets.DEPLOY_TOKEN }}
|
||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||
event-type: build_deploy_dev
|
||||
client-payload: '{"ref": "${{ github.event.inputs.git_ref || github.ref }}", "repository": "${{ github.repository }}"}'
|
||||
client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'
|
||||
|
||||
5
.github/workflows/platform-backend-ci.yml
vendored
5
.github/workflows/platform-backend-ci.yml
vendored
@@ -37,7 +37,9 @@ jobs:
|
||||
|
||||
services:
|
||||
redis:
|
||||
image: redis:latest
|
||||
image: bitnami/redis:6.2
|
||||
env:
|
||||
REDIS_PASSWORD: testpassword
|
||||
ports:
|
||||
- 6379:6379
|
||||
rabbitmq:
|
||||
@@ -202,6 +204,7 @@ jobs:
|
||||
JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
REDIS_HOST: "localhost"
|
||||
REDIS_PORT: "6379"
|
||||
REDIS_PASSWORD: "testpassword"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
env:
|
||||
|
||||
2
.github/workflows/platform-frontend-ci.yml
vendored
2
.github/workflows/platform-frontend-ci.yml
vendored
@@ -160,7 +160,7 @@ jobs:
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
|
||||
docker compose -f ../docker-compose.yml up -d
|
||||
env:
|
||||
DOCKER_BUILDKIT: 1
|
||||
BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
|
||||
|
||||
@@ -61,27 +61,24 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd frontend && pnpm i
|
||||
cd frontend && npm install
|
||||
|
||||
# Start development server
|
||||
pnpm dev
|
||||
npm run dev
|
||||
|
||||
# Run E2E tests
|
||||
pnpm test
|
||||
npm run test
|
||||
|
||||
# Run Storybook for component development
|
||||
pnpm storybook
|
||||
npm run storybook
|
||||
|
||||
# Build production
|
||||
pnpm build
|
||||
npm run build
|
||||
|
||||
# Type checking
|
||||
pnpm types
|
||||
npm run types
|
||||
```
|
||||
|
||||
We have a components library in autogpt_platform/frontend/src/components/atoms that should be used when adding new pages and components.
|
||||
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Backend Architecture
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
class APIKeyContainer(NamedTuple):
|
||||
"""Container for API key parts."""
|
||||
|
||||
raw: str
|
||||
prefix: str
|
||||
postfix: str
|
||||
hash: str
|
||||
|
||||
|
||||
class APIKeyManager:
|
||||
PREFIX: str = "agpt_"
|
||||
PREFIX_LENGTH: int = 8
|
||||
POSTFIX_LENGTH: int = 8
|
||||
|
||||
def generate_api_key(self) -> APIKeyContainer:
|
||||
"""Generate a new API key with all its parts."""
|
||||
raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
|
||||
return APIKeyContainer(
|
||||
raw=raw_key,
|
||||
prefix=raw_key[: self.PREFIX_LENGTH],
|
||||
postfix=raw_key[-self.POSTFIX_LENGTH :],
|
||||
hash=hashlib.sha256(raw_key.encode()).hexdigest(),
|
||||
)
|
||||
|
||||
def verify_api_key(self, provided_key: str, stored_hash: str) -> bool:
|
||||
"""Verify if a provided API key matches the stored hash."""
|
||||
if not provided_key.startswith(self.PREFIX):
|
||||
return False
|
||||
provided_hash = hashlib.sha256(provided_key.encode()).hexdigest()
|
||||
return secrets.compare_digest(provided_hash, stored_hash)
|
||||
@@ -1,78 +0,0 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
from typing import NamedTuple
|
||||
|
||||
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
|
||||
|
||||
|
||||
class APIKeyContainer(NamedTuple):
|
||||
"""Container for API key parts."""
|
||||
|
||||
key: str
|
||||
head: str
|
||||
tail: str
|
||||
hash: str
|
||||
salt: str
|
||||
|
||||
|
||||
class APIKeySmith:
|
||||
PREFIX: str = "agpt_"
|
||||
HEAD_LENGTH: int = 8
|
||||
TAIL_LENGTH: int = 8
|
||||
|
||||
def generate_key(self) -> APIKeyContainer:
|
||||
"""Generate a new API key with secure hashing."""
|
||||
raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
|
||||
hash, salt = self.hash_key(raw_key)
|
||||
|
||||
return APIKeyContainer(
|
||||
key=raw_key,
|
||||
head=raw_key[: self.HEAD_LENGTH],
|
||||
tail=raw_key[-self.TAIL_LENGTH :],
|
||||
hash=hash,
|
||||
salt=salt,
|
||||
)
|
||||
|
||||
def verify_key(
|
||||
self, provided_key: str, known_hash: str, known_salt: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Verify an API key against a known hash (+ salt).
|
||||
Supports verifying both legacy SHA256 and secure Scrypt hashes.
|
||||
"""
|
||||
if not provided_key.startswith(self.PREFIX):
|
||||
return False
|
||||
|
||||
# Handle legacy SHA256 hashes (migration support)
|
||||
if known_salt is None:
|
||||
legacy_hash = hashlib.sha256(provided_key.encode()).hexdigest()
|
||||
return secrets.compare_digest(legacy_hash, known_hash)
|
||||
|
||||
try:
|
||||
salt_bytes = bytes.fromhex(known_salt)
|
||||
provided_hash = self._hash_key_with_salt(provided_key, salt_bytes)
|
||||
return secrets.compare_digest(provided_hash, known_hash)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
def hash_key(self, raw_key: str) -> tuple[str, str]:
|
||||
"""Migrate a legacy hash to secure hash format."""
|
||||
salt = self._generate_salt()
|
||||
hash = self._hash_key_with_salt(raw_key, salt)
|
||||
return hash, salt.hex()
|
||||
|
||||
def _generate_salt(self) -> bytes:
|
||||
"""Generate a random salt for hashing."""
|
||||
return secrets.token_bytes(32)
|
||||
|
||||
def _hash_key_with_salt(self, raw_key: str, salt: bytes) -> str:
|
||||
"""Hash API key using Scrypt with salt."""
|
||||
kdf = Scrypt(
|
||||
length=32,
|
||||
salt=salt,
|
||||
n=2**14, # CPU/memory cost parameter
|
||||
r=8, # Block size parameter
|
||||
p=1, # Parallelization parameter
|
||||
)
|
||||
key_hash = kdf.derive(raw_key.encode())
|
||||
return key_hash.hex()
|
||||
@@ -1,79 +0,0 @@
|
||||
import hashlib
|
||||
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
|
||||
|
||||
def test_generate_api_key():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
assert key.key.startswith(keysmith.PREFIX)
|
||||
assert key.head == key.key[: keysmith.HEAD_LENGTH]
|
||||
assert key.tail == key.key[-keysmith.TAIL_LENGTH :]
|
||||
assert len(key.hash) == 64 # 32 bytes hex encoded
|
||||
assert len(key.salt) == 64 # 32 bytes hex encoded
|
||||
|
||||
|
||||
def test_verify_new_secure_key():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Test correct key validates
|
||||
assert keysmith.verify_key(key.key, key.hash, key.salt) is True
|
||||
|
||||
# Test wrong key fails
|
||||
wrong_key = f"{keysmith.PREFIX}wrongkey123"
|
||||
assert keysmith.verify_key(wrong_key, key.hash, key.salt) is False
|
||||
|
||||
|
||||
def test_verify_legacy_key():
|
||||
keysmith = APIKeySmith()
|
||||
legacy_key = f"{keysmith.PREFIX}legacykey123"
|
||||
legacy_hash = hashlib.sha256(legacy_key.encode()).hexdigest()
|
||||
|
||||
# Test legacy key validates without salt
|
||||
assert keysmith.verify_key(legacy_key, legacy_hash) is True
|
||||
|
||||
# Test wrong legacy key fails
|
||||
wrong_key = f"{keysmith.PREFIX}wronglegacy"
|
||||
assert keysmith.verify_key(wrong_key, legacy_hash) is False
|
||||
|
||||
|
||||
def test_rehash_existing_key():
|
||||
keysmith = APIKeySmith()
|
||||
legacy_key = f"{keysmith.PREFIX}migratekey123"
|
||||
|
||||
# Migrate the legacy key
|
||||
new_hash, new_salt = keysmith.hash_key(legacy_key)
|
||||
|
||||
# Verify migrated key works
|
||||
assert keysmith.verify_key(legacy_key, new_hash, new_salt) is True
|
||||
|
||||
# Verify different key fails with migrated hash
|
||||
wrong_key = f"{keysmith.PREFIX}wrongkey"
|
||||
assert keysmith.verify_key(wrong_key, new_hash, new_salt) is False
|
||||
|
||||
|
||||
def test_invalid_key_prefix():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Test key without proper prefix fails
|
||||
invalid_key = "invalid_prefix_key"
|
||||
assert keysmith.verify_key(invalid_key, key.hash, key.salt) is False
|
||||
|
||||
|
||||
def test_secure_hash_requires_salt():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Secure hash without salt should fail
|
||||
assert keysmith.verify_key(key.key, key.hash) is False
|
||||
|
||||
|
||||
def test_invalid_salt_format():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Invalid salt format should fail gracefully
|
||||
assert keysmith.verify_key(key.key, key.hash, "invalid_hex") is False
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
@@ -15,8 +13,8 @@ class RateLimitSettings(BaseSettings):
|
||||
default="6379", description="Redis port", validation_alias="REDIS_PORT"
|
||||
)
|
||||
|
||||
redis_password: Optional[str] = Field(
|
||||
default=None,
|
||||
redis_password: str = Field(
|
||||
default="password",
|
||||
description="Redis password",
|
||||
validation_alias="REDIS_PASSWORD",
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ class RateLimiter:
|
||||
self,
|
||||
redis_host: str = RATE_LIMIT_SETTINGS.redis_host,
|
||||
redis_port: str = RATE_LIMIT_SETTINGS.redis_port,
|
||||
redis_password: str | None = RATE_LIMIT_SETTINGS.redis_password,
|
||||
redis_password: str = RATE_LIMIT_SETTINGS.redis_password,
|
||||
requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute,
|
||||
):
|
||||
self.redis = Redis(
|
||||
|
||||
@@ -1,68 +1,90 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
R_co = TypeVar("R_co", covariant=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _make_hashable_key(
|
||||
args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[Any, ...]:
|
||||
"""
|
||||
Convert args and kwargs into a hashable cache key.
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
pass
|
||||
|
||||
Handles unhashable types like dict, list, set by converting them to
|
||||
their sorted string representations.
|
||||
"""
|
||||
|
||||
def make_hashable(obj: Any) -> Any:
|
||||
"""Recursively convert an object to a hashable representation."""
|
||||
if isinstance(obj, dict):
|
||||
# Sort dict items to ensure consistent ordering
|
||||
return (
|
||||
"__dict__",
|
||||
tuple(sorted((k, make_hashable(v)) for k, v in obj.items())),
|
||||
)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return ("__list__", tuple(make_hashable(item) for item in obj))
|
||||
elif isinstance(obj, set):
|
||||
return ("__set__", tuple(sorted(make_hashable(item) for item in obj)))
|
||||
elif hasattr(obj, "__dict__"):
|
||||
# Handle objects with __dict__ attribute
|
||||
return ("__obj__", obj.__class__.__name__, make_hashable(obj.__dict__))
|
||||
else:
|
||||
# For basic hashable types (str, int, bool, None, etc.)
|
||||
try:
|
||||
hash(obj)
|
||||
return obj
|
||||
except TypeError:
|
||||
# Fallback: convert to string representation
|
||||
return ("__str__", str(obj))
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
|
||||
pass
|
||||
|
||||
hashable_args = tuple(make_hashable(arg) for arg in args)
|
||||
hashable_kwargs = tuple(sorted((k, make_hashable(v)) for k, v in kwargs.items()))
|
||||
return (hashable_args, hashable_kwargs)
|
||||
|
||||
def thread_cached(
|
||||
func: Callable[P, R] | Callable[P, Awaitable[R]],
|
||||
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
|
||||
thread_local = threading.local()
|
||||
|
||||
def _clear():
|
||||
if hasattr(thread_local, "cache"):
|
||||
del thread_local.cache
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
if key not in cache:
|
||||
cache[key] = await cast(Callable[P, Awaitable[R]], func)(
|
||||
*args, **kwargs
|
||||
)
|
||||
return cache[key]
|
||||
|
||||
setattr(async_wrapper, "clear_cache", _clear)
|
||||
return async_wrapper
|
||||
|
||||
else:
|
||||
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
if key not in cache:
|
||||
cache[key] = func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
setattr(sync_wrapper, "clear_cache", _clear)
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
def clear_thread_cache(func: Callable) -> None:
|
||||
if clear := getattr(func, "clear_cache", None):
|
||||
clear()
|
||||
|
||||
|
||||
FuncT = TypeVar("FuncT")
|
||||
|
||||
|
||||
R_co = TypeVar("R_co", covariant=True)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class CachedFunction(Protocol[P, R_co]):
|
||||
"""Protocol for cached functions with cache management methods."""
|
||||
class AsyncCachedFunction(Protocol[P, R_co]):
|
||||
"""Protocol for async functions with cache management methods."""
|
||||
|
||||
def cache_clear(self) -> None:
|
||||
"""Clear all cached entries."""
|
||||
@@ -72,169 +94,101 @@ class CachedFunction(Protocol[P, R_co]):
|
||||
"""Get cache statistics."""
|
||||
return {}
|
||||
|
||||
def cache_delete(self, *args: P.args, **kwargs: P.kwargs) -> bool:
|
||||
"""Delete a specific cache entry by its arguments. Returns True if entry existed."""
|
||||
return False
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
|
||||
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
|
||||
"""Call the cached function."""
|
||||
return None # type: ignore
|
||||
|
||||
|
||||
def cached(
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
ttl_seconds: int | None = None,
|
||||
) -> Callable[[Callable], CachedFunction]:
|
||||
def async_ttl_cache(
|
||||
maxsize: int = 128, ttl_seconds: int | None = None
|
||||
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
|
||||
"""
|
||||
Thundering herd safe cache decorator for both sync and async functions.
|
||||
TTL (Time To Live) cache decorator for async functions.
|
||||
|
||||
Uses double-checked locking to prevent multiple threads/coroutines from
|
||||
executing the expensive operation simultaneously during cache misses.
|
||||
Similar to functools.lru_cache but works with async functions and includes optional TTL.
|
||||
|
||||
Args:
|
||||
func: The function to cache (when used without parentheses)
|
||||
maxsize: Maximum number of cached entries
|
||||
ttl_seconds: Time to live in seconds. If None, entries never expire
|
||||
ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)
|
||||
|
||||
Returns:
|
||||
Decorated function or decorator
|
||||
Decorator function
|
||||
|
||||
Example:
|
||||
@cache() # Default: maxsize=128, no TTL
|
||||
def expensive_sync_operation(param: str) -> dict:
|
||||
# With TTL
|
||||
@async_ttl_cache(maxsize=1000, ttl_seconds=300)
|
||||
async def api_call(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
@cache() # Works with async too
|
||||
async def expensive_async_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
@cache(maxsize=1000, ttl_seconds=300) # Custom maxsize and TTL
|
||||
def another_operation(param: str) -> dict:
|
||||
# Without TTL (permanent cache like lru_cache)
|
||||
@async_ttl_cache(maxsize=1000)
|
||||
async def expensive_computation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
"""
|
||||
|
||||
def decorator(target_func):
|
||||
# Cache storage and locks
|
||||
cache_storage = {}
|
||||
def decorator(
|
||||
async_func: Callable[P, Awaitable[R]],
|
||||
) -> AsyncCachedFunction[P, R]:
|
||||
# Cache storage - use union type to handle both cases
|
||||
cache_storage: dict[tuple, R | Tuple[R, float]] = {}
|
||||
|
||||
if inspect.iscoroutinefunction(target_func):
|
||||
# Async function with asyncio.Lock
|
||||
cache_lock = asyncio.Lock()
|
||||
@wraps(async_func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# Create cache key from arguments
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
current_time = time.time()
|
||||
|
||||
@wraps(target_func)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
current_time = time.time()
|
||||
|
||||
# Fast path: check cache without lock
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
logger.debug(f"Cache hit for {target_func.__name__}")
|
||||
return cache_storage[key]
|
||||
else:
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.debug(f"Cache hit for {target_func.__name__}")
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
async with cache_lock:
|
||||
# Double-check: another coroutine might have populated cache
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
return cache_storage[key]
|
||||
# Check if we have a valid cached entry
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
# No TTL - return cached result directly
|
||||
logger.debug(
|
||||
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
return cast(R, cache_storage[key])
|
||||
else:
|
||||
# With TTL - check expiration
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.debug(
|
||||
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
return cast(R, result)
|
||||
else:
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
return result
|
||||
# Expired entry
|
||||
del cache_storage[key]
|
||||
logger.debug(
|
||||
f"Cache entry expired for {async_func.__name__}"
|
||||
)
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {target_func.__name__}")
|
||||
result = await target_func(*args, **kwargs)
|
||||
# Cache miss or expired - fetch fresh data
|
||||
logger.debug(
|
||||
f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
result = await async_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
# Store in cache
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
|
||||
# Cleanup if needed
|
||||
if len(cache_storage) > maxsize:
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = (
|
||||
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
)
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
# Simple cleanup when cache gets too large
|
||||
if len(cache_storage) > maxsize:
|
||||
# Remove oldest entries (simple FIFO cleanup)
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
logger.debug(
|
||||
f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
|
||||
)
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
wrapper = async_wrapper
|
||||
|
||||
else:
|
||||
# Sync function with threading.Lock
|
||||
cache_lock = threading.Lock()
|
||||
|
||||
@wraps(target_func)
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
current_time = time.time()
|
||||
|
||||
# Fast path: check cache without lock
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
logger.debug(f"Cache hit for {target_func.__name__}")
|
||||
return cache_storage[key]
|
||||
else:
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.debug(f"Cache hit for {target_func.__name__}")
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
with cache_lock:
|
||||
# Double-check: another thread might have populated cache
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
return cache_storage[key]
|
||||
else:
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {target_func.__name__}")
|
||||
result = target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
|
||||
# Cleanup if needed
|
||||
if len(cache_storage) > maxsize:
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = (
|
||||
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
)
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
|
||||
return result
|
||||
|
||||
wrapper = sync_wrapper
|
||||
|
||||
# Add cache management methods
|
||||
# Add cache management methods (similar to functools.lru_cache)
|
||||
def cache_clear() -> None:
|
||||
cache_storage.clear()
|
||||
|
||||
@@ -245,84 +199,68 @@ def cached(
|
||||
"ttl_seconds": ttl_seconds,
|
||||
}
|
||||
|
||||
def cache_delete(*args, **kwargs) -> bool:
|
||||
"""Delete a specific cache entry. Returns True if entry existed."""
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if key in cache_storage:
|
||||
del cache_storage[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
# Attach methods to wrapper
|
||||
setattr(wrapper, "cache_clear", cache_clear)
|
||||
setattr(wrapper, "cache_info", cache_info)
|
||||
setattr(wrapper, "cache_delete", cache_delete)
|
||||
|
||||
return cast(CachedFunction, wrapper)
|
||||
return cast(AsyncCachedFunction[P, R], wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def thread_cached(func):
|
||||
"""
|
||||
Thread-local cache decorator for both sync and async functions.
|
||||
@overload
|
||||
def async_cache(
|
||||
func: Callable[P, Awaitable[R]],
|
||||
) -> AsyncCachedFunction[P, R]:
|
||||
pass
|
||||
|
||||
Each thread gets its own cache, which is useful for request-scoped caching
|
||||
in web applications where you want to cache within a single request but
|
||||
not across requests.
|
||||
|
||||
@overload
|
||||
def async_cache(
|
||||
func: None = None,
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
|
||||
pass
|
||||
|
||||
|
||||
def async_cache(
|
||||
func: Callable[P, Awaitable[R]] | None = None,
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
) -> (
|
||||
AsyncCachedFunction[P, R]
|
||||
| Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]
|
||||
):
|
||||
"""
|
||||
Process-level cache decorator for async functions (no TTL).
|
||||
|
||||
Similar to functools.lru_cache but works with async functions.
|
||||
This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.
|
||||
|
||||
Args:
|
||||
func: The function to cache
|
||||
func: The async function to cache (when used without parentheses)
|
||||
maxsize: Maximum number of cached entries
|
||||
|
||||
Returns:
|
||||
Decorated function with thread-local caching
|
||||
Decorated function or decorator
|
||||
|
||||
Example:
|
||||
@thread_cached
|
||||
def expensive_operation(param: str) -> dict:
|
||||
# Without parentheses (uses default maxsize=128)
|
||||
@async_cache
|
||||
async def get_data(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
@thread_cached # Works with async too
|
||||
async def expensive_async_operation(param: str) -> dict:
|
||||
# With parentheses and custom maxsize
|
||||
@async_cache(maxsize=1000)
|
||||
async def expensive_computation(param: str) -> dict:
|
||||
# Expensive computation here
|
||||
return {"result": param}
|
||||
"""
|
||||
thread_local = threading.local()
|
||||
|
||||
def _clear():
|
||||
if hasattr(thread_local, "cache"):
|
||||
del thread_local.cache
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if key not in cache:
|
||||
cache[key] = await func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
setattr(async_wrapper, "clear_cache", _clear)
|
||||
return async_wrapper
|
||||
|
||||
if func is None:
|
||||
# Called with parentheses @async_cache() or @async_cache(maxsize=...)
|
||||
return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if key not in cache:
|
||||
cache[key] = func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
setattr(sync_wrapper, "clear_cache", _clear)
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
def clear_thread_cache(func: Callable) -> None:
|
||||
"""Clear thread-local cache for a function."""
|
||||
if clear := getattr(func, "clear_cache", None):
|
||||
clear()
|
||||
# Called without parentheses @async_cache
|
||||
decorator = async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
|
||||
return decorator(func)
|
||||
|
||||
@@ -16,7 +16,12 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from autogpt_libs.utils.cache import cached, clear_thread_cache, thread_cached
|
||||
from autogpt_libs.utils.cache import (
|
||||
async_cache,
|
||||
async_ttl_cache,
|
||||
clear_thread_cache,
|
||||
thread_cached,
|
||||
)
|
||||
|
||||
|
||||
class TestThreadCached:
|
||||
@@ -325,202 +330,102 @@ class TestThreadCached:
|
||||
assert mock.call_count == 2
|
||||
|
||||
|
||||
class TestCache:
|
||||
"""Tests for the unified @cache decorator (works for both sync and async)."""
|
||||
|
||||
def test_basic_sync_caching(self):
|
||||
"""Test basic sync caching functionality."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
def expensive_sync_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x + y
|
||||
|
||||
# First call
|
||||
result1 = expensive_sync_function(1, 2)
|
||||
assert result1 == 3
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same args - should use cache
|
||||
result2 = expensive_sync_function(1, 2)
|
||||
assert result2 == 3
|
||||
assert call_count == 1
|
||||
|
||||
# Different args - should call function again
|
||||
result3 = expensive_sync_function(2, 3)
|
||||
assert result3 == 5
|
||||
assert call_count == 2
|
||||
class TestAsyncTTLCache:
|
||||
"""Tests for the @async_ttl_cache decorator."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_async_caching(self):
|
||||
"""Test basic async caching functionality."""
|
||||
async def test_basic_caching(self):
|
||||
"""Test basic caching functionality."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
async def expensive_async_function(x: int, y: int = 0) -> int:
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def cached_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
return x + y
|
||||
|
||||
# First call
|
||||
result1 = await expensive_async_function(1, 2)
|
||||
result1 = await cached_function(1, 2)
|
||||
assert result1 == 3
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same args - should use cache
|
||||
result2 = await expensive_async_function(1, 2)
|
||||
result2 = await cached_function(1, 2)
|
||||
assert result2 == 3
|
||||
assert call_count == 1
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Different args - should call function again
|
||||
result3 = await expensive_async_function(2, 3)
|
||||
result3 = await cached_function(2, 3)
|
||||
assert result3 == 5
|
||||
assert call_count == 2
|
||||
|
||||
def test_sync_thundering_herd_protection(self):
|
||||
"""Test that concurrent sync calls don't cause thundering herd."""
|
||||
call_count = 0
|
||||
results = []
|
||||
|
||||
@cached()
|
||||
def slow_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
time.sleep(0.1) # Simulate expensive operation
|
||||
return x * x
|
||||
|
||||
def worker():
|
||||
result = slow_function(5)
|
||||
results.append(result)
|
||||
|
||||
# Launch multiple concurrent threads
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
futures = [executor.submit(worker) for _ in range(5)]
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
# All results should be the same
|
||||
assert all(result == 25 for result in results)
|
||||
# Only one thread should have executed the expensive operation
|
||||
assert call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_thundering_herd_protection(self):
|
||||
"""Test that concurrent async calls don't cause thundering herd."""
|
||||
async def test_ttl_expiration(self):
|
||||
"""Test that cache entries expire after TTL."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
async def slow_async_function(x: int) -> int:
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
async def short_lived_cache(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.1) # Simulate expensive operation
|
||||
return x * x
|
||||
|
||||
# Launch concurrent coroutines
|
||||
tasks = [slow_async_function(7) for _ in range(5)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All results should be the same
|
||||
assert all(result == 49 for result in results)
|
||||
# Only one coroutine should have executed the expensive operation
|
||||
assert call_count == 1
|
||||
|
||||
def test_ttl_functionality(self):
|
||||
"""Test TTL functionality with sync function."""
|
||||
call_count = 0
|
||||
|
||||
@cached(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
def ttl_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 3
|
||||
return x * 2
|
||||
|
||||
# First call
|
||||
result1 = ttl_function(3)
|
||||
assert result1 == 9
|
||||
result1 = await short_lived_cache(5)
|
||||
assert result1 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Second call immediately - should use cache
|
||||
result2 = ttl_function(3)
|
||||
assert result2 == 9
|
||||
assert call_count == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
time.sleep(1.1)
|
||||
|
||||
# Third call after expiration - should call function again
|
||||
result3 = ttl_function(3)
|
||||
assert result3 == 9
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_ttl_functionality(self):
|
||||
"""Test TTL functionality with async function."""
|
||||
call_count = 0
|
||||
|
||||
@cached(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
async def async_ttl_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return x * 4
|
||||
|
||||
# First call
|
||||
result1 = await async_ttl_function(3)
|
||||
assert result1 == 12
|
||||
assert call_count == 1
|
||||
|
||||
# Second call immediately - should use cache
|
||||
result2 = await async_ttl_function(3)
|
||||
assert result2 == 12
|
||||
result2 = await short_lived_cache(5)
|
||||
assert result2 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# Third call after expiration - should call function again
|
||||
result3 = await async_ttl_function(3)
|
||||
assert result3 == 12
|
||||
result3 = await short_lived_cache(5)
|
||||
assert result3 == 10
|
||||
assert call_count == 2
|
||||
|
||||
def test_cache_info(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_info(self):
|
||||
"""Test cache info functionality."""
|
||||
|
||||
@cached(maxsize=10, ttl_seconds=60)
|
||||
def info_test_function(x: int) -> int:
|
||||
@async_ttl_cache(maxsize=5, ttl_seconds=300)
|
||||
async def info_test_function(x: int) -> int:
|
||||
return x * 3
|
||||
|
||||
# Check initial cache info
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
assert info["maxsize"] == 10
|
||||
assert info["ttl_seconds"] == 60
|
||||
assert info["maxsize"] == 5
|
||||
assert info["ttl_seconds"] == 300
|
||||
|
||||
# Add an entry
|
||||
info_test_function(1)
|
||||
await info_test_function(1)
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 1
|
||||
|
||||
def test_cache_clear(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_clear(self):
|
||||
"""Test cache clearing functionality."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
def clearable_function(x: int) -> int:
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def clearable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 4
|
||||
|
||||
# First call
|
||||
result1 = clearable_function(2)
|
||||
result1 = await clearable_function(2)
|
||||
assert result1 == 8
|
||||
assert call_count == 1
|
||||
|
||||
# Second call - should use cache
|
||||
result2 = clearable_function(2)
|
||||
result2 = await clearable_function(2)
|
||||
assert result2 == 8
|
||||
assert call_count == 1
|
||||
|
||||
@@ -528,149 +433,273 @@ class TestCache:
|
||||
clearable_function.cache_clear()
|
||||
|
||||
# Third call after clear - should call function again
|
||||
result3 = clearable_function(2)
|
||||
result3 = await clearable_function(2)
|
||||
assert result3 == 8
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_cache_clear(self):
|
||||
"""Test cache clearing functionality with async function."""
|
||||
async def test_maxsize_cleanup(self):
|
||||
"""Test that cache cleans up when maxsize is exceeded."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
async def async_clearable_function(x: int) -> int:
|
||||
@async_ttl_cache(maxsize=3, ttl_seconds=60)
|
||||
async def size_limited_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return x * 5
|
||||
return x**2
|
||||
|
||||
# First call
|
||||
result1 = await async_clearable_function(2)
|
||||
# Fill cache to maxsize
|
||||
await size_limited_function(1) # call_count: 1
|
||||
await size_limited_function(2) # call_count: 2
|
||||
await size_limited_function(3) # call_count: 3
|
||||
|
||||
info = size_limited_function.cache_info()
|
||||
assert info["size"] == 3
|
||||
|
||||
# Add one more entry - should trigger cleanup
|
||||
await size_limited_function(4) # call_count: 4
|
||||
|
||||
# Cache size should be reduced (cleanup removes oldest entries)
|
||||
info = size_limited_function.cache_info()
|
||||
assert info["size"] is not None and info["size"] <= 3 # Should be cleaned up
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_argument_variations(self):
|
||||
"""Test caching with different argument patterns."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"{a}-{b}-{c}"
|
||||
|
||||
# Different ways to call with same logical arguments
|
||||
result1 = await arg_test_function(1, "test", c=200)
|
||||
assert call_count == 1
|
||||
|
||||
# Same arguments, same order - should use cache
|
||||
result2 = await arg_test_function(1, "test", c=200)
|
||||
assert call_count == 1
|
||||
assert result1 == result2
|
||||
|
||||
# Different arguments - should call function
|
||||
result3 = await arg_test_function(2, "test", c=200)
|
||||
assert call_count == 2
|
||||
assert result1 != result3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_handling(self):
|
||||
"""Test that exceptions are not cached."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def exception_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if x < 0:
|
||||
raise ValueError("Negative value not allowed")
|
||||
return x * 2
|
||||
|
||||
# Successful call - should be cached
|
||||
result1 = await exception_function(5)
|
||||
assert result1 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Second call - should use cache
|
||||
result2 = await async_clearable_function(2)
|
||||
# Same successful call - should use cache
|
||||
result2 = await exception_function(5)
|
||||
assert result2 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Clear cache
|
||||
async_clearable_function.cache_clear()
|
||||
|
||||
# Third call after clear - should call function again
|
||||
result3 = await async_clearable_function(2)
|
||||
assert result3 == 10
|
||||
# Exception call - should not be cached
|
||||
with pytest.raises(ValueError):
|
||||
await exception_function(-1)
|
||||
assert call_count == 2
|
||||
|
||||
# Same exception call - should call again (not cached)
|
||||
with pytest.raises(ValueError):
|
||||
await exception_function(-1)
|
||||
assert call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_function_returns_results_not_coroutines(self):
|
||||
"""Test that cached async functions return actual results, not coroutines."""
|
||||
async def test_concurrent_calls(self):
|
||||
"""Test caching behavior with concurrent calls."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
async def async_result_function(x: int) -> str:
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def concurrent_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return f"result_{x}"
|
||||
await asyncio.sleep(0.05) # Simulate work
|
||||
return x * x
|
||||
|
||||
# Launch concurrent calls with same arguments
|
||||
tasks = [concurrent_function(3) for _ in range(5)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All results should be the same
|
||||
assert all(result == 9 for result in results)
|
||||
|
||||
# Note: Due to race conditions, call_count might be up to 5 for concurrent calls
|
||||
# This tests that the cache doesn't break under concurrent access
|
||||
assert 1 <= call_count <= 5
|
||||
|
||||
|
||||
class TestAsyncCache:
|
||||
"""Tests for the @async_cache decorator (no TTL)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_caching_no_ttl(self):
|
||||
"""Test basic caching functionality without TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_cache(maxsize=10)
|
||||
async def cached_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
return x + y
|
||||
|
||||
# First call
|
||||
result1 = await async_result_function(1)
|
||||
assert result1 == "result_1"
|
||||
assert isinstance(result1, str) # Should be string, not coroutine
|
||||
result1 = await cached_function(1, 2)
|
||||
assert result1 == 3
|
||||
assert call_count == 1
|
||||
|
||||
# Second call - should return cached result (string), not coroutine
|
||||
result2 = await async_result_function(1)
|
||||
assert result2 == "result_1"
|
||||
assert isinstance(result2, str) # Should be string, not coroutine
|
||||
assert call_count == 1 # Function should not be called again
|
||||
# Second call with same args - should use cache
|
||||
result2 = await cached_function(1, 2)
|
||||
assert result2 == 3
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Verify results are identical
|
||||
assert result1 is result2 # Should be same cached object
|
||||
# Third call after some time - should still use cache (no TTL)
|
||||
await asyncio.sleep(0.05)
|
||||
result3 = await cached_function(1, 2)
|
||||
assert result3 == 3
|
||||
assert call_count == 1 # Still no additional call
|
||||
|
||||
def test_cache_delete(self):
|
||||
"""Test selective cache deletion functionality."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
def deletable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 6
|
||||
|
||||
# First call for x=1
|
||||
result1 = deletable_function(1)
|
||||
assert result1 == 6
|
||||
assert call_count == 1
|
||||
|
||||
# First call for x=2
|
||||
result2 = deletable_function(2)
|
||||
assert result2 == 12
|
||||
# Different args - should call function again
|
||||
result4 = await cached_function(2, 3)
|
||||
assert result4 == 5
|
||||
assert call_count == 2
|
||||
|
||||
# Second calls - should use cache
|
||||
assert deletable_function(1) == 6
|
||||
assert deletable_function(2) == 12
|
||||
assert call_count == 2
|
||||
|
||||
# Delete specific entry for x=1
|
||||
was_deleted = deletable_function.cache_delete(1)
|
||||
assert was_deleted is True
|
||||
|
||||
# Call with x=1 should execute function again
|
||||
result3 = deletable_function(1)
|
||||
assert result3 == 6
|
||||
assert call_count == 3
|
||||
|
||||
# Call with x=2 should still use cache
|
||||
assert deletable_function(2) == 12
|
||||
assert call_count == 3
|
||||
|
||||
# Try to delete non-existent entry
|
||||
was_deleted = deletable_function.cache_delete(99)
|
||||
assert was_deleted is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_cache_delete(self):
|
||||
"""Test selective cache deletion functionality with async function."""
|
||||
async def test_no_ttl_vs_ttl_behavior(self):
|
||||
"""Test the difference between TTL and no-TTL caching."""
|
||||
ttl_call_count = 0
|
||||
no_ttl_call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
async def ttl_function(x: int) -> int:
|
||||
nonlocal ttl_call_count
|
||||
ttl_call_count += 1
|
||||
return x * 2
|
||||
|
||||
@async_cache(maxsize=10) # No TTL
|
||||
async def no_ttl_function(x: int) -> int:
|
||||
nonlocal no_ttl_call_count
|
||||
no_ttl_call_count += 1
|
||||
return x * 2
|
||||
|
||||
# First calls
|
||||
await ttl_function(5)
|
||||
await no_ttl_function(5)
|
||||
assert ttl_call_count == 1
|
||||
assert no_ttl_call_count == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# Second calls after TTL expiry
|
||||
await ttl_function(5) # Should call function again (TTL expired)
|
||||
await no_ttl_function(5) # Should use cache (no TTL)
|
||||
assert ttl_call_count == 2 # TTL function called again
|
||||
assert no_ttl_call_count == 1 # No-TTL function still cached
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_cache_info(self):
|
||||
"""Test cache info for no-TTL cache."""
|
||||
|
||||
@async_cache(maxsize=5)
|
||||
async def info_test_function(x: int) -> int:
|
||||
return x * 3
|
||||
|
||||
# Check initial cache info
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
assert info["maxsize"] == 5
|
||||
assert info["ttl_seconds"] is None # No TTL
|
||||
|
||||
# Add an entry
|
||||
await info_test_function(1)
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 1
|
||||
|
||||
|
||||
class TestTTLOptional:
|
||||
"""Tests for optional TTL functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ttl_none_behavior(self):
|
||||
"""Test that ttl_seconds=None works like no TTL."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
async def async_deletable_function(x: int) -> int:
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=None)
|
||||
async def no_ttl_via_none(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return x * 7
|
||||
return x**2
|
||||
|
||||
# First call for x=1
|
||||
result1 = await async_deletable_function(1)
|
||||
assert result1 == 7
|
||||
# First call
|
||||
result1 = await no_ttl_via_none(3)
|
||||
assert result1 == 9
|
||||
assert call_count == 1
|
||||
|
||||
# First call for x=2
|
||||
result2 = await async_deletable_function(2)
|
||||
assert result2 == 14
|
||||
assert call_count == 2
|
||||
# Wait (would expire if there was TTL)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Second calls - should use cache
|
||||
assert await async_deletable_function(1) == 7
|
||||
assert await async_deletable_function(2) == 14
|
||||
assert call_count == 2
|
||||
# Second call - should still use cache
|
||||
result2 = await no_ttl_via_none(3)
|
||||
assert result2 == 9
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Delete specific entry for x=1
|
||||
was_deleted = async_deletable_function.cache_delete(1)
|
||||
assert was_deleted is True
|
||||
# Check cache info
|
||||
info = no_ttl_via_none.cache_info()
|
||||
assert info["ttl_seconds"] is None
|
||||
|
||||
# Call with x=1 should execute function again
|
||||
result3 = await async_deletable_function(1)
|
||||
assert result3 == 7
|
||||
assert call_count == 3
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_options_comparison(self):
|
||||
"""Test different cache options work as expected."""
|
||||
ttl_calls = 0
|
||||
no_ttl_calls = 0
|
||||
|
||||
# Call with x=2 should still use cache
|
||||
assert await async_deletable_function(2) == 14
|
||||
assert call_count == 3
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # With TTL
|
||||
async def ttl_function(x: int) -> int:
|
||||
nonlocal ttl_calls
|
||||
ttl_calls += 1
|
||||
return x * 10
|
||||
|
||||
# Try to delete non-existent entry
|
||||
was_deleted = async_deletable_function.cache_delete(99)
|
||||
assert was_deleted is False
|
||||
@async_cache(maxsize=10) # Process-level cache (no TTL)
|
||||
async def process_function(x: int) -> int:
|
||||
nonlocal no_ttl_calls
|
||||
no_ttl_calls += 1
|
||||
return x * 10
|
||||
|
||||
# Both should cache initially
|
||||
await ttl_function(3)
|
||||
await process_function(3)
|
||||
assert ttl_calls == 1
|
||||
assert no_ttl_calls == 1
|
||||
|
||||
# Immediate second calls - both should use cache
|
||||
await ttl_function(3)
|
||||
await process_function(3)
|
||||
assert ttl_calls == 1
|
||||
assert no_ttl_calls == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# After TTL expiry
|
||||
await ttl_function(3) # Should call function again
|
||||
await process_function(3) # Should still use cache
|
||||
assert ttl_calls == 2 # TTL cache expired, called again
|
||||
assert no_ttl_calls == 1 # Process cache never expires
|
||||
|
||||
76
autogpt_platform/autogpt_libs/poetry.lock
generated
76
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -1002,18 +1002,6 @@ dynamodb = ["boto3 (>=1.9.71)"]
|
||||
redis = ["redis (>=2.10.5)"]
|
||||
test-filesource = ["pyyaml (>=5.3.1)", "watchdog (>=3.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "nodeenv"
|
||||
version = "1.9.1"
|
||||
description = "Node.js virtual environment builder"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"},
|
||||
{file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-api"
|
||||
version = "1.35.0"
|
||||
@@ -1359,27 +1347,6 @@ files = [
|
||||
{file = "pyrfc3339-2.0.1.tar.gz", hash = "sha256:e47843379ea35c1296c3b6c67a948a1a490ae0584edfcbdea0eaffb5dd29960b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyright"
|
||||
version = "1.1.404"
|
||||
description = "Command line wrapper for pyright"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pyright-1.1.404-py3-none-any.whl", hash = "sha256:c7b7ff1fdb7219c643079e4c3e7d4125f0dafcc19d253b47e898d130ea426419"},
|
||||
{file = "pyright-1.1.404.tar.gz", hash = "sha256:455e881a558ca6be9ecca0b30ce08aa78343ecc031d37a198ffa9a7a1abeb63e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
nodeenv = ">=1.6.0"
|
||||
typing-extensions = ">=4.1"
|
||||
|
||||
[package.extras]
|
||||
all = ["nodejs-wheel-binaries", "twine (>=3.4.1)"]
|
||||
dev = ["twine (>=3.4.1)"]
|
||||
nodejs = ["nodejs-wheel-binaries"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.4.1"
|
||||
@@ -1567,31 +1534,31 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.12.11"
|
||||
version = "0.12.9"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "ruff-0.12.11-py3-none-linux_armv6l.whl", hash = "sha256:93fce71e1cac3a8bf9200e63a38ac5c078f3b6baebffb74ba5274fb2ab276065"},
|
||||
{file = "ruff-0.12.11-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b8e33ac7b28c772440afa80cebb972ffd823621ded90404f29e5ab6d1e2d4b93"},
|
||||
{file = "ruff-0.12.11-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d69fb9d4937aa19adb2e9f058bc4fbfe986c2040acb1a4a9747734834eaa0bfd"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:411954eca8464595077a93e580e2918d0a01a19317af0a72132283e28ae21bee"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6a2c0a2e1a450f387bf2c6237c727dd22191ae8c00e448e0672d624b2bbd7fb0"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ca4c3a7f937725fd2413c0e884b5248a19369ab9bdd850b5781348ba283f644"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4d1df0098124006f6a66ecf3581a7f7e754c4df7644b2e6704cd7ca80ff95211"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a8dd5f230efc99a24ace3b77e3555d3fbc0343aeed3fc84c8d89e75ab2ff793"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4dc75533039d0ed04cd33fb8ca9ac9620b99672fe7ff1533b6402206901c34ee"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4fc58f9266d62c6eccc75261a665f26b4ef64840887fc6cbc552ce5b29f96cc8"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:5a0113bd6eafd545146440225fe60b4e9489f59eb5f5f107acd715ba5f0b3d2f"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0d737b4059d66295c3ea5720e6efc152623bb83fde5444209b69cd33a53e2000"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:916fc5defee32dbc1fc1650b576a8fed68f5e8256e2180d4d9855aea43d6aab2"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c984f07d7adb42d3ded5be894fb4007f30f82c87559438b4879fe7aa08c62b39"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e07fbb89f2e9249f219d88331c833860489b49cdf4b032b8e4432e9b13e8a4b9"},
|
||||
{file = "ruff-0.12.11-py3-none-win32.whl", hash = "sha256:c792e8f597c9c756e9bcd4d87cf407a00b60af77078c96f7b6366ea2ce9ba9d3"},
|
||||
{file = "ruff-0.12.11-py3-none-win_amd64.whl", hash = "sha256:a3283325960307915b6deb3576b96919ee89432ebd9c48771ca12ee8afe4a0fd"},
|
||||
{file = "ruff-0.12.11-py3-none-win_arm64.whl", hash = "sha256:bae4d6e6a2676f8fb0f98b74594a048bae1b944aab17e9f5d504062303c6dbea"},
|
||||
{file = "ruff-0.12.11.tar.gz", hash = "sha256:c6b09ae8426a65bbee5425b9d0b82796dbb07cb1af045743c79bfb163001165d"},
|
||||
{file = "ruff-0.12.9-py3-none-linux_armv6l.whl", hash = "sha256:fcebc6c79fcae3f220d05585229463621f5dbf24d79fdc4936d9302e177cfa3e"},
|
||||
{file = "ruff-0.12.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:aed9d15f8c5755c0e74467731a007fcad41f19bcce41cd75f768bbd687f8535f"},
|
||||
{file = "ruff-0.12.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5b15ea354c6ff0d7423814ba6d44be2807644d0c05e9ed60caca87e963e93f70"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d596c2d0393c2502eaabfef723bd74ca35348a8dac4267d18a94910087807c53"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1b15599931a1a7a03c388b9c5df1bfa62be7ede6eb7ef753b272381f39c3d0ff"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3d02faa2977fb6f3f32ddb7828e212b7dd499c59eb896ae6c03ea5c303575756"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:17d5b6b0b3a25259b69ebcba87908496e6830e03acfb929ef9fd4c58675fa2ea"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:72db7521860e246adbb43f6ef464dd2a532ef2ef1f5dd0d470455b8d9f1773e0"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a03242c1522b4e0885af63320ad754d53983c9599157ee33e77d748363c561ce"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fc83e4e9751e6c13b5046d7162f205d0a7bac5840183c5beebf824b08a27340"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:881465ed56ba4dd26a691954650de6ad389a2d1fdb130fe51ff18a25639fe4bb"},
|
||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:43f07a3ccfc62cdb4d3a3348bf0588358a66da756aa113e071b8ca8c3b9826af"},
|
||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:07adb221c54b6bba24387911e5734357f042e5669fa5718920ee728aba3cbadc"},
|
||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f5cd34fabfdea3933ab85d72359f118035882a01bff15bd1d2b15261d85d5f66"},
|
||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:f6be1d2ca0686c54564da8e7ee9e25f93bdd6868263805f8c0b8fc6a449db6d7"},
|
||||
{file = "ruff-0.12.9-py3-none-win32.whl", hash = "sha256:cc7a37bd2509974379d0115cc5608a1a4a6c4bff1b452ea69db83c8855d53f93"},
|
||||
{file = "ruff-0.12.9-py3-none-win_amd64.whl", hash = "sha256:6fb15b1977309741d7d098c8a3cb7a30bc112760a00fb6efb7abc85f00ba5908"},
|
||||
{file = "ruff-0.12.9-py3-none-win_arm64.whl", hash = "sha256:63c8c819739d86b96d500cce885956a1a48ab056bbcbc61b747ad494b2485089"},
|
||||
{file = "ruff-0.12.9.tar.gz", hash = "sha256:fbd94b2e3c623f659962934e52c2bea6fc6da11f667a427a368adaf3af2c866a"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1773,6 +1740,7 @@ files = [
|
||||
{file = "typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76"},
|
||||
{file = "typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36"},
|
||||
]
|
||||
markers = {dev = "python_version < \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "typing-inspection"
|
||||
@@ -1929,4 +1897,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "0c40b63c3c921846cf05ccfb4e685d4959854b29c2c302245f9832e20aac6954"
|
||||
content-hash = "ef7818fba061cea2841c6d7ca4852acde83e4f73b32fca1315e58660002bb0d0"
|
||||
|
||||
@@ -9,7 +9,6 @@ packages = [{ include = "autogpt_libs" }]
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<4.0"
|
||||
colorama = "^0.4.6"
|
||||
cryptography = "^45.0"
|
||||
expiringdict = "^1.2.2"
|
||||
fastapi = "^0.116.1"
|
||||
google-cloud-logging = "^3.12.1"
|
||||
@@ -22,12 +21,11 @@ supabase = "^2.16.0"
|
||||
uvicorn = "^0.35.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pyright = "^1.1.404"
|
||||
ruff = "^0.12.9"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
pytest-mock = "^3.14.1"
|
||||
pytest-cov = "^6.2.1"
|
||||
ruff = "^0.12.11"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -21,7 +21,7 @@ PRISMA_SCHEMA="postgres/schema.prisma"
|
||||
# Redis Configuration
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
# REDIS_PASSWORD=
|
||||
REDIS_PASSWORD=password
|
||||
|
||||
# RabbitMQ Credentials
|
||||
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
|
||||
@@ -66,11 +66,6 @@ NVIDIA_API_KEY=
|
||||
GITHUB_CLIENT_ID=
|
||||
GITHUB_CLIENT_SECRET=
|
||||
|
||||
# Notion OAuth App server credentials - https://developers.notion.com/docs/authorization
|
||||
# Configure a public integration
|
||||
NOTION_CLIENT_ID=
|
||||
NOTION_CLIENT_SECRET=
|
||||
|
||||
# Google OAuth App server credentials - https://console.cloud.google.com/apis/credentials, and enable gmail api and set scopes
|
||||
# https://console.cloud.google.com/apis/credentials/consent ?project=<your_project_id>
|
||||
# You'll need to add/enable the following scopes (minimum):
|
||||
|
||||
10
autogpt_platform/backend/.gitignore
vendored
10
autogpt_platform/backend/.gitignore
vendored
@@ -9,12 +9,4 @@ secrets/*
|
||||
!secrets/.gitkeep
|
||||
|
||||
*.ignore.*
|
||||
*.ign.*
|
||||
|
||||
# Load test results and reports
|
||||
load-tests/*_RESULTS.md
|
||||
load-tests/*_REPORT.md
|
||||
load-tests/results/
|
||||
load-tests/*.json
|
||||
load-tests/*.log
|
||||
load-tests/node_modules/*
|
||||
*.ign.*
|
||||
@@ -9,15 +9,8 @@ WORKDIR /app
|
||||
|
||||
RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy
|
||||
|
||||
# Install Node.js repository key and setup
|
||||
# Update package list and install Python and build dependencies
|
||||
RUN apt-get update --allow-releaseinfo-change --fix-missing \
|
||||
&& apt-get install -y curl ca-certificates gnupg \
|
||||
&& mkdir -p /etc/apt/keyrings \
|
||||
&& curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg \
|
||||
&& echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" | tee /etc/apt/sources.list.d/nodesource.list
|
||||
|
||||
# Update package list and install Python, Node.js, and build dependencies
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y \
|
||||
python3.13 \
|
||||
python3.13-dev \
|
||||
@@ -27,9 +20,7 @@ RUN apt-get update \
|
||||
libpq5 \
|
||||
libz-dev \
|
||||
libssl-dev \
|
||||
postgresql-client \
|
||||
nodejs \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
postgresql-client
|
||||
|
||||
ENV POETRY_HOME=/opt/poetry
|
||||
ENV POETRY_NO_INTERACTION=1
|
||||
@@ -63,18 +54,13 @@ ENV PATH=/opt/poetry/bin:$PATH
|
||||
# Install Python without upgrading system-managed packages
|
||||
RUN apt-get update && apt-get install -y \
|
||||
python3.13 \
|
||||
python3-pip \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
python3-pip
|
||||
|
||||
# Copy only necessary files from builder
|
||||
COPY --from=builder /app /app
|
||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
||||
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
||||
# Copy Node.js installation for Prisma
|
||||
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||
COPY --from=builder /usr/bin/npm /usr/bin/npm
|
||||
COPY --from=builder /usr/bin/npx /usr/bin/npx
|
||||
# Copy Prisma binaries
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
@@ -5,8 +6,6 @@ import re
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
from autogpt_libs.utils.cache import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -16,7 +15,7 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@cached()
|
||||
@functools.cache
|
||||
def load_all_blocks() -> dict[str, type["Block"]]:
|
||||
from backend.data.block import Block
|
||||
from backend.util.settings import Config
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -10,7 +12,7 @@ from backend.data.block import (
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.execution import ExecutionStatus, NodesInputMasks
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util.json import validate_with_jsonschema
|
||||
from backend.util.retry import func_retry
|
||||
@@ -31,7 +33,7 @@ class AgentExecutorBlock(Block):
|
||||
input_schema: dict = SchemaField(description="Input schema for the graph")
|
||||
output_schema: dict = SchemaField(description="Output schema for the graph")
|
||||
|
||||
nodes_input_masks: Optional[NodesInputMasks] = SchemaField(
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = SchemaField(
|
||||
default=None, hidden=True
|
||||
)
|
||||
|
||||
|
||||
@@ -1,154 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
from replicate.client import Client as ReplicateClient
|
||||
from replicate.helpers import FileOutput
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.file import MediaFileType
|
||||
|
||||
|
||||
class GeminiImageModel(str, Enum):
|
||||
NANO_BANANA = "google/nano-banana"
|
||||
|
||||
|
||||
class OutputFormat(str, Enum):
|
||||
JPG = "jpg"
|
||||
PNG = "png"
|
||||
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="replicate",
|
||||
api_key=SecretStr("mock-replicate-api-key"),
|
||||
title="Mock Replicate API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class AIImageCustomizerBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REPLICATE], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Replicate API key with permissions for Google Gemini image models",
|
||||
)
|
||||
prompt: str = SchemaField(
|
||||
description="A text description of the image you want to generate",
|
||||
title="Prompt",
|
||||
)
|
||||
model: GeminiImageModel = SchemaField(
|
||||
description="The AI model to use for image generation and editing",
|
||||
default=GeminiImageModel.NANO_BANANA,
|
||||
title="Model",
|
||||
)
|
||||
images: list[MediaFileType] = SchemaField(
|
||||
description="Optional list of input images to reference or modify",
|
||||
default=[],
|
||||
title="Input Images",
|
||||
)
|
||||
output_format: OutputFormat = SchemaField(
|
||||
description="Format of the output image",
|
||||
default=OutputFormat.PNG,
|
||||
title="Output Format",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
image_url: MediaFileType = SchemaField(description="URL of the generated image")
|
||||
error: str = SchemaField(description="Error message if generation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d76bbe4c-930e-4894-8469-b66775511f71",
|
||||
description=(
|
||||
"Generate and edit custom images using Google's Nano-Banana model from Gemini 2.5. "
|
||||
"Provide a prompt and optional reference images to create or modify images."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
input_schema=AIImageCustomizerBlock.Input,
|
||||
output_schema=AIImageCustomizerBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Make the scene more vibrant and colorful",
|
||||
"model": GeminiImageModel.NANO_BANANA,
|
||||
"images": [],
|
||||
"output_format": OutputFormat.JPG,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
("image_url", "https://replicate.delivery/generated-image.jpg"),
|
||||
],
|
||||
test_mock={
|
||||
"run_model": lambda *args, **kwargs: MediaFileType(
|
||||
"https://replicate.delivery/generated-image.jpg"
|
||||
),
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
result = await self.run_model(
|
||||
api_key=credentials.api_key,
|
||||
model_name=input_data.model.value,
|
||||
prompt=input_data.prompt,
|
||||
images=input_data.images,
|
||||
output_format=input_data.output_format.value,
|
||||
)
|
||||
yield "image_url", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
async def run_model(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
images: list[MediaFileType],
|
||||
output_format: str,
|
||||
) -> MediaFileType:
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
|
||||
input_params: dict = {
|
||||
"prompt": prompt,
|
||||
"output_format": output_format,
|
||||
}
|
||||
|
||||
# Add images to input if provided (API expects "image_input" parameter)
|
||||
if images:
|
||||
input_params["image_input"] = [str(img) for img in images]
|
||||
|
||||
output: FileOutput | str = await client.async_run( # type: ignore
|
||||
model_name,
|
||||
input=input_params,
|
||||
wait=False,
|
||||
)
|
||||
|
||||
if isinstance(output, FileOutput):
|
||||
return MediaFileType(output.url)
|
||||
if isinstance(output, str):
|
||||
return MediaFileType(output)
|
||||
|
||||
raise ValueError("No output received from the model")
|
||||
@@ -661,167 +661,6 @@ async def update_field(
|
||||
#################################################################
|
||||
|
||||
|
||||
async def get_table_schema(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
table_id_or_name: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Get the schema for a specific table, including all field definitions.
|
||||
|
||||
Args:
|
||||
credentials: Airtable API credentials
|
||||
base_id: The base ID
|
||||
table_id_or_name: The table ID or name
|
||||
|
||||
Returns:
|
||||
Dict containing table schema with fields information
|
||||
"""
|
||||
# First get all tables to find the right one
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
tables = data.get("tables", [])
|
||||
|
||||
# Find the matching table
|
||||
for table in tables:
|
||||
if table.get("id") == table_id_or_name or table.get("name") == table_id_or_name:
|
||||
return table
|
||||
|
||||
raise ValueError(f"Table '{table_id_or_name}' not found in base '{base_id}'")
|
||||
|
||||
|
||||
def get_empty_value_for_field(field_type: str) -> Any:
|
||||
"""
|
||||
Return the appropriate empty value for a given Airtable field type.
|
||||
|
||||
Args:
|
||||
field_type: The Airtable field type
|
||||
|
||||
Returns:
|
||||
The appropriate empty value for that field type
|
||||
"""
|
||||
# Fields that should be false when empty
|
||||
if field_type == "checkbox":
|
||||
return False
|
||||
|
||||
# Fields that should be empty arrays
|
||||
if field_type in [
|
||||
"multipleSelects",
|
||||
"multipleRecordLinks",
|
||||
"multipleAttachments",
|
||||
"multipleLookupValues",
|
||||
"multipleCollaborators",
|
||||
]:
|
||||
return []
|
||||
|
||||
# Fields that should be 0 when empty (numeric types)
|
||||
if field_type in [
|
||||
"number",
|
||||
"percent",
|
||||
"currency",
|
||||
"rating",
|
||||
"duration",
|
||||
"count",
|
||||
"autoNumber",
|
||||
]:
|
||||
return 0
|
||||
|
||||
# Fields that should be empty strings
|
||||
if field_type in [
|
||||
"singleLineText",
|
||||
"multilineText",
|
||||
"email",
|
||||
"url",
|
||||
"phoneNumber",
|
||||
"richText",
|
||||
"barcode",
|
||||
]:
|
||||
return ""
|
||||
|
||||
# Everything else gets null (dates, single selects, formulas, etc.)
|
||||
return None
|
||||
|
||||
|
||||
async def normalize_records(
|
||||
records: list[dict],
|
||||
table_schema: dict,
|
||||
include_field_metadata: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Normalize Airtable records to include all fields with proper empty values.
|
||||
|
||||
Args:
|
||||
records: List of record objects from Airtable API
|
||||
table_schema: Table schema containing field definitions
|
||||
include_field_metadata: Whether to include field metadata in response
|
||||
|
||||
Returns:
|
||||
Dict with normalized records and optionally field metadata
|
||||
"""
|
||||
fields = table_schema.get("fields", [])
|
||||
|
||||
# Normalize each record
|
||||
normalized_records = []
|
||||
for record in records:
|
||||
normalized = {
|
||||
"id": record.get("id"),
|
||||
"createdTime": record.get("createdTime"),
|
||||
"fields": {},
|
||||
}
|
||||
|
||||
# Add existing fields
|
||||
existing_fields = record.get("fields", {})
|
||||
|
||||
# Add all fields from schema, using empty values for missing ones
|
||||
for field in fields:
|
||||
field_name = field["name"]
|
||||
field_type = field["type"]
|
||||
|
||||
if field_name in existing_fields:
|
||||
# Field exists, use its value
|
||||
normalized["fields"][field_name] = existing_fields[field_name]
|
||||
else:
|
||||
# Field is missing, add appropriate empty value
|
||||
normalized["fields"][field_name] = get_empty_value_for_field(field_type)
|
||||
|
||||
normalized_records.append(normalized)
|
||||
|
||||
# Build result dictionary
|
||||
if include_field_metadata:
|
||||
field_metadata = {}
|
||||
for field in fields:
|
||||
metadata = {"type": field["type"], "id": field["id"]}
|
||||
|
||||
# Add type-specific metadata
|
||||
options = field.get("options", {})
|
||||
if field["type"] == "currency" and "symbol" in options:
|
||||
metadata["symbol"] = options["symbol"]
|
||||
metadata["precision"] = options.get("precision", 2)
|
||||
elif field["type"] == "duration" and "durationFormat" in options:
|
||||
metadata["format"] = options["durationFormat"]
|
||||
elif field["type"] == "percent" and "precision" in options:
|
||||
metadata["precision"] = options["precision"]
|
||||
elif (
|
||||
field["type"] in ["singleSelect", "multipleSelects"]
|
||||
and "choices" in options
|
||||
):
|
||||
metadata["choices"] = [choice["name"] for choice in options["choices"]]
|
||||
elif field["type"] == "rating" and "max" in options:
|
||||
metadata["max"] = options["max"]
|
||||
metadata["icon"] = options.get("icon", "star")
|
||||
metadata["color"] = options.get("color", "yellowBright")
|
||||
|
||||
field_metadata[field["name"]] = metadata
|
||||
|
||||
return {"records": normalized_records, "field_metadata": field_metadata}
|
||||
else:
|
||||
return {"records": normalized_records}
|
||||
|
||||
|
||||
async def list_records(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
@@ -1410,26 +1249,3 @@ async def list_bases(
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
|
||||
async def get_base_tables(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Get all tables for a specific base.
|
||||
|
||||
Args:
|
||||
credentials: Airtable API credentials
|
||||
base_id: The ID of the base
|
||||
|
||||
Returns:
|
||||
list[dict]: List of table objects with their schemas
|
||||
"""
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
return data.get("tables", [])
|
||||
|
||||
@@ -14,13 +14,13 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import create_base, get_base_tables, list_bases
|
||||
from ._api import create_base, list_bases
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableCreateBaseBlock(Block):
|
||||
"""
|
||||
Creates a new base in an Airtable workspace, or returns existing base if one with the same name exists.
|
||||
Creates a new base in an Airtable workspace.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
@@ -31,10 +31,6 @@ class AirtableCreateBaseBlock(Block):
|
||||
description="The workspace ID where the base will be created"
|
||||
)
|
||||
name: str = SchemaField(description="The name of the new base")
|
||||
find_existing: bool = SchemaField(
|
||||
description="If true, return existing base with same name instead of creating duplicate",
|
||||
default=True,
|
||||
)
|
||||
tables: list[dict] = SchemaField(
|
||||
description="At least one table and field must be specified. Array of table objects to create in the base. Each table should have 'name' and 'fields' properties",
|
||||
default=[
|
||||
@@ -54,18 +50,14 @@ class AirtableCreateBaseBlock(Block):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
base_id: str = SchemaField(description="The ID of the created or found base")
|
||||
base_id: str = SchemaField(description="The ID of the created base")
|
||||
tables: list[dict] = SchemaField(description="Array of table objects")
|
||||
table: dict = SchemaField(description="A single table object")
|
||||
was_created: bool = SchemaField(
|
||||
description="True if a new base was created, False if existing was found",
|
||||
default=True,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f59b88a8-54ce-4676-a508-fd614b4e8dce",
|
||||
description="Create or find a base in Airtable",
|
||||
description="Create a new base in Airtable",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
@@ -74,31 +66,6 @@ class AirtableCreateBaseBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# If find_existing is true, check if a base with this name already exists
|
||||
if input_data.find_existing:
|
||||
# List all bases to check for existing one with same name
|
||||
# Note: Airtable API doesn't have a direct search, so we need to list and filter
|
||||
existing_bases = await list_bases(credentials)
|
||||
|
||||
for base in existing_bases.get("bases", []):
|
||||
if base.get("name") == input_data.name:
|
||||
# Base already exists, return it
|
||||
base_id = base.get("id")
|
||||
yield "base_id", base_id
|
||||
yield "was_created", False
|
||||
|
||||
# Get the tables for this base
|
||||
try:
|
||||
tables = await get_base_tables(credentials, base_id)
|
||||
yield "tables", tables
|
||||
for table in tables:
|
||||
yield "table", table
|
||||
except Exception:
|
||||
# If we can't get tables, return empty list
|
||||
yield "tables", []
|
||||
return
|
||||
|
||||
# No existing base found or find_existing is false, create new one
|
||||
data = await create_base(
|
||||
credentials,
|
||||
input_data.workspace_id,
|
||||
@@ -107,7 +74,6 @@ class AirtableCreateBaseBlock(Block):
|
||||
)
|
||||
|
||||
yield "base_id", data.get("id", None)
|
||||
yield "was_created", True
|
||||
yield "tables", data.get("tables", [])
|
||||
for table in data.get("tables", []):
|
||||
yield "table", table
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Airtable record operation blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional, cast
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
@@ -18,9 +18,7 @@ from ._api import (
|
||||
create_record,
|
||||
delete_multiple_records,
|
||||
get_record,
|
||||
get_table_schema,
|
||||
list_records,
|
||||
normalize_records,
|
||||
update_multiple_records,
|
||||
)
|
||||
from ._config import airtable
|
||||
@@ -56,24 +54,12 @@ class AirtableListRecordsBlock(Block):
|
||||
return_fields: list[str] = SchemaField(
|
||||
description="Specific fields to return (comma-separated)", default=[]
|
||||
)
|
||||
normalize_output: bool = SchemaField(
|
||||
description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
|
||||
default=True,
|
||||
)
|
||||
include_field_metadata: bool = SchemaField(
|
||||
description="Include field type and configuration metadata (requires normalize_output=true)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
records: list[dict] = SchemaField(description="Array of record objects")
|
||||
offset: Optional[str] = SchemaField(
|
||||
description="Offset for next page (null if no more records)", default=None
|
||||
)
|
||||
field_metadata: Optional[dict] = SchemaField(
|
||||
description="Field type and configuration metadata (only when include_field_metadata=true)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -87,7 +73,6 @@ class AirtableListRecordsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
data = await list_records(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -103,33 +88,8 @@ class AirtableListRecordsBlock(Block):
|
||||
fields=input_data.return_fields if input_data.return_fields else None,
|
||||
)
|
||||
|
||||
records = data.get("records", [])
|
||||
|
||||
# Normalize output if requested
|
||||
if input_data.normalize_output:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the records
|
||||
normalized_data = await normalize_records(
|
||||
records,
|
||||
table_schema,
|
||||
include_field_metadata=input_data.include_field_metadata,
|
||||
)
|
||||
|
||||
yield "records", normalized_data["records"]
|
||||
yield "offset", data.get("offset", None)
|
||||
|
||||
if (
|
||||
input_data.include_field_metadata
|
||||
and "field_metadata" in normalized_data
|
||||
):
|
||||
yield "field_metadata", normalized_data["field_metadata"]
|
||||
else:
|
||||
yield "records", records
|
||||
yield "offset", data.get("offset", None)
|
||||
yield "records", data.get("records", [])
|
||||
yield "offset", data.get("offset", None)
|
||||
|
||||
|
||||
class AirtableGetRecordBlock(Block):
|
||||
@@ -144,23 +104,11 @@ class AirtableGetRecordBlock(Block):
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name")
|
||||
record_id: str = SchemaField(description="The record ID to retrieve")
|
||||
normalize_output: bool = SchemaField(
|
||||
description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
|
||||
default=True,
|
||||
)
|
||||
include_field_metadata: bool = SchemaField(
|
||||
description="Include field type and configuration metadata (requires normalize_output=true)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
id: str = SchemaField(description="The record ID")
|
||||
fields: dict = SchemaField(description="The record fields")
|
||||
created_time: str = SchemaField(description="The record created time")
|
||||
field_metadata: Optional[dict] = SchemaField(
|
||||
description="Field type and configuration metadata (only when include_field_metadata=true)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -174,7 +122,6 @@ class AirtableGetRecordBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
record = await get_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -182,34 +129,9 @@ class AirtableGetRecordBlock(Block):
|
||||
input_data.record_id,
|
||||
)
|
||||
|
||||
# Normalize output if requested
|
||||
if input_data.normalize_output:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the single record (wrap in list and unwrap result)
|
||||
normalized_data = await normalize_records(
|
||||
[record],
|
||||
table_schema,
|
||||
include_field_metadata=input_data.include_field_metadata,
|
||||
)
|
||||
|
||||
normalized_record = normalized_data["records"][0]
|
||||
yield "id", normalized_record.get("id", None)
|
||||
yield "fields", normalized_record.get("fields", None)
|
||||
yield "created_time", normalized_record.get("createdTime", None)
|
||||
|
||||
if (
|
||||
input_data.include_field_metadata
|
||||
and "field_metadata" in normalized_data
|
||||
):
|
||||
yield "field_metadata", normalized_data["field_metadata"]
|
||||
else:
|
||||
yield "id", record.get("id", None)
|
||||
yield "fields", record.get("fields", None)
|
||||
yield "created_time", record.get("createdTime", None)
|
||||
yield "id", record.get("id", None)
|
||||
yield "fields", record.get("fields", None)
|
||||
yield "created_time", record.get("createdTime", None)
|
||||
|
||||
|
||||
class AirtableCreateRecordsBlock(Block):
|
||||
@@ -226,10 +148,6 @@ class AirtableCreateRecordsBlock(Block):
|
||||
records: list[dict] = SchemaField(
|
||||
description="Array of records to create (each with 'fields' object)"
|
||||
)
|
||||
skip_normalization: bool = SchemaField(
|
||||
description="Skip output normalization to get raw Airtable response (faster but may have missing fields)",
|
||||
default=False,
|
||||
)
|
||||
typecast: bool = SchemaField(
|
||||
description="Automatically convert string values to appropriate types",
|
||||
default=False,
|
||||
@@ -255,7 +173,7 @@ class AirtableCreateRecordsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
# The create_record API expects records in a specific format
|
||||
data = await create_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -264,22 +182,8 @@ class AirtableCreateRecordsBlock(Block):
|
||||
typecast=input_data.typecast if input_data.typecast else None,
|
||||
return_fields_by_field_id=input_data.return_fields_by_field_id,
|
||||
)
|
||||
result_records = cast(list[dict], data.get("records", []))
|
||||
|
||||
# Normalize output unless explicitly disabled
|
||||
if not input_data.skip_normalization and result_records:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the records
|
||||
normalized_data = await normalize_records(
|
||||
result_records, table_schema, include_field_metadata=False
|
||||
)
|
||||
result_records = normalized_data["records"]
|
||||
|
||||
yield "records", result_records
|
||||
yield "records", data.get("records", [])
|
||||
details = data.get("details", None)
|
||||
if details:
|
||||
yield "details", details
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from .text_overlay import BannerbearTextOverlayBlock
|
||||
|
||||
__all__ = ["BannerbearTextOverlayBlock"]
|
||||
@@ -1,8 +0,0 @@
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
bannerbear = (
|
||||
ProviderBuilder("bannerbear")
|
||||
.with_api_key("BANNERBEAR_API_KEY", "Bannerbear API Key")
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
@@ -1,239 +0,0 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import bannerbear
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="bannerbear",
|
||||
api_key=SecretStr("mock-bannerbear-api-key"),
|
||||
title="Mock Bannerbear API Key",
|
||||
)
|
||||
|
||||
|
||||
class TextModification(BlockSchema):
|
||||
name: str = SchemaField(
|
||||
description="The name of the layer to modify in the template"
|
||||
)
|
||||
text: str = SchemaField(description="The text content to add to this layer")
|
||||
color: str = SchemaField(
|
||||
description="Hex color code for the text (e.g., '#FF0000')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
font_family: str = SchemaField(
|
||||
description="Font family to use for the text",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
font_size: int = SchemaField(
|
||||
description="Font size in pixels",
|
||||
default=0,
|
||||
advanced=True,
|
||||
)
|
||||
font_weight: str = SchemaField(
|
||||
description="Font weight (e.g., 'bold', 'normal')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
text_align: str = SchemaField(
|
||||
description="Text alignment (left, center, right)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
|
||||
class BannerbearTextOverlayBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = bannerbear.credentials_field(
|
||||
description="API credentials for Bannerbear"
|
||||
)
|
||||
template_id: str = SchemaField(
|
||||
description="The unique ID of your Bannerbear template"
|
||||
)
|
||||
project_id: str = SchemaField(
|
||||
description="Optional: Project ID (required when using Master API Key)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
text_modifications: List[TextModification] = SchemaField(
|
||||
description="List of text layers to modify in the template"
|
||||
)
|
||||
image_url: str = SchemaField(
|
||||
description="Optional: URL of an image to use in the template",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
image_layer_name: str = SchemaField(
|
||||
description="Optional: Name of the image layer in the template",
|
||||
default="photo",
|
||||
advanced=True,
|
||||
)
|
||||
webhook_url: str = SchemaField(
|
||||
description="Optional: URL to receive webhook notification when image is ready",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
metadata: str = SchemaField(
|
||||
description="Optional: Custom metadata to attach to the image",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the image generation was successfully initiated"
|
||||
)
|
||||
image_url: str = SchemaField(
|
||||
description="URL of the generated image (if synchronous) or placeholder"
|
||||
)
|
||||
uid: str = SchemaField(description="Unique identifier for the generated image")
|
||||
status: str = SchemaField(description="Status of the image generation")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c7d3a5c2-05fc-450e-8dce-3b0e04626009",
|
||||
description="Add text overlay to images using Bannerbear templates. Perfect for creating social media graphics, marketing materials, and dynamic image content.",
|
||||
categories={BlockCategory.PRODUCTIVITY, BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"template_id": "jJWBKNELpQPvbX5R93Gk",
|
||||
"text_modifications": [
|
||||
{
|
||||
"name": "headline",
|
||||
"text": "Amazing Product Launch!",
|
||||
"color": "#FF0000",
|
||||
},
|
||||
{
|
||||
"name": "subtitle",
|
||||
"text": "50% OFF Today Only",
|
||||
},
|
||||
],
|
||||
"credentials": {
|
||||
"provider": "bannerbear",
|
||||
"id": str(uuid.uuid4()),
|
||||
"type": "api_key",
|
||||
},
|
||||
},
|
||||
test_output=[
|
||||
("success", True),
|
||||
("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
|
||||
("uid", "test-uid-123"),
|
||||
("status", "completed"),
|
||||
],
|
||||
test_mock={
|
||||
"_make_api_request": lambda *args, **kwargs: {
|
||||
"uid": "test-uid-123",
|
||||
"status": "completed",
|
||||
"image_url": "https://cdn.bannerbear.com/test-image.jpg",
|
||||
}
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def _make_api_request(self, payload: dict, api_key: str) -> dict:
|
||||
"""Make the actual API request to Bannerbear. This is separated for easy mocking in tests."""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = await Requests().post(
|
||||
"https://sync.api.bannerbear.com/v2/images",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if response.status in [200, 201, 202]:
|
||||
return response.json()
|
||||
else:
|
||||
error_msg = f"API request failed with status {response.status}"
|
||||
if response.text:
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_msg = (
|
||||
f"{error_msg}: {error_data.get('message', response.text)}"
|
||||
)
|
||||
except Exception:
|
||||
error_msg = f"{error_msg}: {response.text}"
|
||||
raise Exception(error_msg)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Build the modifications array
|
||||
modifications = []
|
||||
|
||||
# Add text modifications
|
||||
for text_mod in input_data.text_modifications:
|
||||
mod_data: Dict[str, Any] = {
|
||||
"name": text_mod.name,
|
||||
"text": text_mod.text,
|
||||
}
|
||||
|
||||
# Add optional text styling parameters only if they have values
|
||||
if text_mod.color and text_mod.color.strip():
|
||||
mod_data["color"] = text_mod.color
|
||||
if text_mod.font_family and text_mod.font_family.strip():
|
||||
mod_data["font_family"] = text_mod.font_family
|
||||
if text_mod.font_size and text_mod.font_size > 0:
|
||||
mod_data["font_size"] = text_mod.font_size
|
||||
if text_mod.font_weight and text_mod.font_weight.strip():
|
||||
mod_data["font_weight"] = text_mod.font_weight
|
||||
if text_mod.text_align and text_mod.text_align.strip():
|
||||
mod_data["text_align"] = text_mod.text_align
|
||||
|
||||
modifications.append(mod_data)
|
||||
|
||||
# Add image modification if provided and not empty
|
||||
if input_data.image_url and input_data.image_url.strip():
|
||||
modifications.append(
|
||||
{
|
||||
"name": input_data.image_layer_name,
|
||||
"image_url": input_data.image_url,
|
||||
}
|
||||
)
|
||||
|
||||
# Build the request payload - only include non-empty optional fields
|
||||
payload = {
|
||||
"template": input_data.template_id,
|
||||
"modifications": modifications,
|
||||
}
|
||||
|
||||
# Add project_id if provided (required for Master API keys)
|
||||
if input_data.project_id and input_data.project_id.strip():
|
||||
payload["project_id"] = input_data.project_id
|
||||
|
||||
if input_data.webhook_url and input_data.webhook_url.strip():
|
||||
payload["webhook_url"] = input_data.webhook_url
|
||||
if input_data.metadata and input_data.metadata.strip():
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
# Make the API request using the private method
|
||||
data = await self._make_api_request(
|
||||
payload, credentials.api_key.get_secret_value()
|
||||
)
|
||||
|
||||
# Synchronous request - image should be ready
|
||||
yield "success", True
|
||||
yield "image_url", data.get("image_url", "")
|
||||
yield "uid", data.get("uid", "")
|
||||
yield "status", data.get("status", "completed")
|
||||
@@ -113,7 +113,6 @@ class DataForSeoClient:
|
||||
include_serp_info: bool = False,
|
||||
include_clickstream_data: bool = False,
|
||||
limit: int = 100,
|
||||
depth: Optional[int] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get related keywords from DataForSEO Labs.
|
||||
@@ -126,7 +125,6 @@ class DataForSeoClient:
|
||||
include_serp_info: Include SERP data
|
||||
include_clickstream_data: Include clickstream metrics
|
||||
limit: Maximum number of results (up to 3000)
|
||||
depth: Keyword search depth (0-4), controls number of returned keywords
|
||||
|
||||
Returns:
|
||||
API response with related keywords
|
||||
@@ -150,8 +148,6 @@ class DataForSeoClient:
|
||||
task_data["include_clickstream_data"] = include_clickstream_data
|
||||
if limit is not None:
|
||||
task_data["limit"] = limit
|
||||
if depth is not None:
|
||||
task_data["depth"] = depth
|
||||
|
||||
payload = [task_data]
|
||||
|
||||
|
||||
@@ -78,12 +78,6 @@ class DataForSeoRelatedKeywordsBlock(Block):
|
||||
ge=1,
|
||||
le=3000,
|
||||
)
|
||||
depth: int = SchemaField(
|
||||
description="Keyword search depth (0-4). Controls the number of returned keywords: 0=1 keyword, 1=~8 keywords, 2=~72 keywords, 3=~584 keywords, 4=~4680 keywords",
|
||||
default=1,
|
||||
ge=0,
|
||||
le=4,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
related_keywords: List[RelatedKeyword] = SchemaField(
|
||||
@@ -160,7 +154,6 @@ class DataForSeoRelatedKeywordsBlock(Block):
|
||||
include_serp_info=input_data.include_serp_info,
|
||||
include_clickstream_data=input_data.include_clickstream_data,
|
||||
limit=input_data.limit,
|
||||
depth=input_data.depth,
|
||||
)
|
||||
|
||||
async def run(
|
||||
|
||||
@@ -1094,117 +1094,6 @@ class GmailGetThreadBlock(GmailBase):
|
||||
return thread
|
||||
|
||||
|
||||
async def _build_reply_message(
|
||||
service, input_data, graph_exec_id: str, user_id: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Builds a reply MIME message for Gmail threads.
|
||||
|
||||
Returns:
|
||||
tuple: (base64-encoded raw message, threadId)
|
||||
"""
|
||||
# Get parent message for reply context
|
||||
parent = await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.messages()
|
||||
.get(
|
||||
userId="me",
|
||||
id=input_data.parentMessageId,
|
||||
format="metadata",
|
||||
metadataHeaders=[
|
||||
"Subject",
|
||||
"References",
|
||||
"Message-ID",
|
||||
"From",
|
||||
"To",
|
||||
"Cc",
|
||||
"Reply-To",
|
||||
],
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Build headers dictionary, preserving all values for duplicate headers
|
||||
headers = {}
|
||||
for h in parent.get("payload", {}).get("headers", []):
|
||||
name = h["name"].lower()
|
||||
value = h["value"]
|
||||
if name in headers:
|
||||
# For duplicate headers, keep the first occurrence (most relevant for reply context)
|
||||
continue
|
||||
headers[name] = value
|
||||
|
||||
# Determine recipients if not specified
|
||||
if not (input_data.to or input_data.cc or input_data.bcc):
|
||||
if input_data.replyAll:
|
||||
recipients = [parseaddr(headers.get("from", ""))[1]]
|
||||
recipients += [addr for _, addr in getaddresses([headers.get("to", "")])]
|
||||
recipients += [addr for _, addr in getaddresses([headers.get("cc", "")])]
|
||||
# Use dict.fromkeys() for O(n) deduplication while preserving order
|
||||
input_data.to = list(dict.fromkeys(filter(None, recipients)))
|
||||
else:
|
||||
# Check Reply-To header first, fall back to From header
|
||||
reply_to = headers.get("reply-to", "")
|
||||
from_addr = headers.get("from", "")
|
||||
sender = parseaddr(reply_to if reply_to else from_addr)[1]
|
||||
input_data.to = [sender] if sender else []
|
||||
|
||||
# Set subject with Re: prefix if not already present
|
||||
if input_data.subject:
|
||||
subject = input_data.subject
|
||||
else:
|
||||
parent_subject = headers.get("subject", "").strip()
|
||||
# Only add "Re:" if not already present (case-insensitive check)
|
||||
if parent_subject.lower().startswith("re:"):
|
||||
subject = parent_subject
|
||||
else:
|
||||
subject = f"Re: {parent_subject}" if parent_subject else "Re:"
|
||||
|
||||
# Build references header for proper threading
|
||||
references = headers.get("references", "").split()
|
||||
if headers.get("message-id"):
|
||||
references.append(headers["message-id"])
|
||||
|
||||
# Create MIME message
|
||||
msg = MIMEMultipart()
|
||||
if input_data.to:
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
if headers.get("message-id"):
|
||||
msg["In-Reply-To"] = headers["message-id"]
|
||||
if references:
|
||||
msg["References"] = " ".join(references)
|
||||
|
||||
# Use the helper function for consistent content type handling
|
||||
msg.attach(_make_mime_text(input_data.body, input_data.content_type))
|
||||
|
||||
# Handle attachments
|
||||
for attach in input_data.attachments:
|
||||
local_path = await store_media_file(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=attach,
|
||||
return_content=False,
|
||||
)
|
||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
||||
part = MIMEBase("application", "octet-stream")
|
||||
with open(abs_path, "rb") as f:
|
||||
part.set_payload(f.read())
|
||||
encoders.encode_base64(part)
|
||||
part.add_header(
|
||||
"Content-Disposition", f"attachment; filename={Path(abs_path).name}"
|
||||
)
|
||||
msg.attach(part)
|
||||
|
||||
# Encode message
|
||||
raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
|
||||
return raw, input_data.threadId
|
||||
|
||||
|
||||
class GmailReplyBlock(GmailBase):
|
||||
"""
|
||||
Replies to Gmail threads with intelligent content type detection.
|
||||
@@ -1341,144 +1230,91 @@ class GmailReplyBlock(GmailBase):
|
||||
async def _reply(
|
||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||
) -> dict:
|
||||
# Build the reply message using the shared helper
|
||||
raw, thread_id = await _build_reply_message(
|
||||
service, input_data, graph_exec_id, user_id
|
||||
)
|
||||
|
||||
# Send the message
|
||||
return await asyncio.to_thread(
|
||||
parent = await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.messages()
|
||||
.send(userId="me", body={"threadId": thread_id, "raw": raw})
|
||||
.execute()
|
||||
)
|
||||
|
||||
|
||||
class GmailDraftReplyBlock(GmailBase):
|
||||
"""
|
||||
Creates draft replies to Gmail threads with intelligent content type detection.
|
||||
|
||||
Features:
|
||||
- Automatic HTML detection: Draft replies containing HTML tags are formatted as text/html
|
||||
- No hard-wrap for plain text: Plain text draft replies preserve natural line flow
|
||||
- Manual content type override: Use content_type parameter to force specific format
|
||||
- Reply-all functionality: Option to reply to all original recipients
|
||||
- Thread preservation: Maintains proper email threading with headers
|
||||
- Full Unicode/emoji support with UTF-8 encoding
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
[
|
||||
"https://www.googleapis.com/auth/gmail.modify",
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
]
|
||||
)
|
||||
threadId: str = SchemaField(description="Thread ID to reply in")
|
||||
parentMessageId: str = SchemaField(
|
||||
description="ID of the message being replied to"
|
||||
)
|
||||
to: list[str] = SchemaField(description="To recipients", default_factory=list)
|
||||
cc: list[str] = SchemaField(description="CC recipients", default_factory=list)
|
||||
bcc: list[str] = SchemaField(description="BCC recipients", default_factory=list)
|
||||
replyAll: bool = SchemaField(
|
||||
description="Reply to all original recipients", default=False
|
||||
)
|
||||
subject: str = SchemaField(description="Email subject", default="")
|
||||
body: str = SchemaField(description="Email body (plain text or HTML)")
|
||||
content_type: Optional[Literal["auto", "plain", "html"]] = SchemaField(
|
||||
description="Content type: 'auto' (default - detects HTML), 'plain', or 'html'",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
attachments: list[MediaFileType] = SchemaField(
|
||||
description="Files to attach", default_factory=list, advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
draftId: str = SchemaField(description="Created draft ID")
|
||||
messageId: str = SchemaField(description="Draft message ID")
|
||||
threadId: str = SchemaField(description="Thread ID")
|
||||
status: str = SchemaField(description="Draft creation status")
|
||||
error: str = SchemaField(description="Error message if any")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d7a9f3e2-8b4c-4d6f-9e1a-3c5b7f8d2a6e",
|
||||
description="Create draft replies to Gmail threads with automatic HTML detection and proper text formatting. Plain text draft replies maintain natural paragraph flow without 78-character line wrapping. HTML content is automatically detected and formatted correctly.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=GmailDraftReplyBlock.Input,
|
||||
output_schema=GmailDraftReplyBlock.Output,
|
||||
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"threadId": "t1",
|
||||
"parentMessageId": "m1",
|
||||
"body": "Thanks for your message. I'll review and get back to you.",
|
||||
"replyAll": False,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("draftId", "draft1"),
|
||||
("messageId", "m2"),
|
||||
("threadId", "t1"),
|
||||
("status", "draft_created"),
|
||||
],
|
||||
test_mock={
|
||||
"_create_draft_reply": lambda *args, **kwargs: {
|
||||
"id": "draft1",
|
||||
"message": {"id": "m2", "threadId": "t1"},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GoogleCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
service = self._build_service(credentials, **kwargs)
|
||||
draft = await self._create_draft_reply(
|
||||
service,
|
||||
input_data,
|
||||
graph_exec_id,
|
||||
user_id,
|
||||
)
|
||||
yield "draftId", draft["id"]
|
||||
yield "messageId", draft["message"]["id"]
|
||||
yield "threadId", draft["message"].get("threadId", input_data.threadId)
|
||||
yield "status", "draft_created"
|
||||
|
||||
async def _create_draft_reply(
|
||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||
) -> dict:
|
||||
# Build the reply message using the shared helper
|
||||
raw, thread_id = await _build_reply_message(
|
||||
service, input_data, graph_exec_id, user_id
|
||||
)
|
||||
|
||||
# Create draft with proper thread association
|
||||
draft = await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.drafts()
|
||||
.create(
|
||||
.get(
|
||||
userId="me",
|
||||
body={
|
||||
"message": {
|
||||
"threadId": thread_id,
|
||||
"raw": raw,
|
||||
}
|
||||
},
|
||||
id=input_data.parentMessageId,
|
||||
format="metadata",
|
||||
metadataHeaders=[
|
||||
"Subject",
|
||||
"References",
|
||||
"Message-ID",
|
||||
"From",
|
||||
"To",
|
||||
"Cc",
|
||||
"Reply-To",
|
||||
],
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
return draft
|
||||
headers = {
|
||||
h["name"].lower(): h["value"]
|
||||
for h in parent.get("payload", {}).get("headers", [])
|
||||
}
|
||||
if not (input_data.to or input_data.cc or input_data.bcc):
|
||||
if input_data.replyAll:
|
||||
recipients = [parseaddr(headers.get("from", ""))[1]]
|
||||
recipients += [
|
||||
addr for _, addr in getaddresses([headers.get("to", "")])
|
||||
]
|
||||
recipients += [
|
||||
addr for _, addr in getaddresses([headers.get("cc", "")])
|
||||
]
|
||||
dedup: list[str] = []
|
||||
for r in recipients:
|
||||
if r and r not in dedup:
|
||||
dedup.append(r)
|
||||
input_data.to = dedup
|
||||
else:
|
||||
sender = parseaddr(headers.get("reply-to", headers.get("from", "")))[1]
|
||||
input_data.to = [sender] if sender else []
|
||||
subject = input_data.subject or (f"Re: {headers.get('subject', '')}".strip())
|
||||
references = headers.get("references", "").split()
|
||||
if headers.get("message-id"):
|
||||
references.append(headers["message-id"])
|
||||
|
||||
msg = MIMEMultipart()
|
||||
if input_data.to:
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
if headers.get("message-id"):
|
||||
msg["In-Reply-To"] = headers["message-id"]
|
||||
if references:
|
||||
msg["References"] = " ".join(references)
|
||||
# Use the new helper function for consistent content type handling
|
||||
msg.attach(_make_mime_text(input_data.body, input_data.content_type))
|
||||
|
||||
for attach in input_data.attachments:
|
||||
local_path = await store_media_file(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=attach,
|
||||
return_content=False,
|
||||
)
|
||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
||||
part = MIMEBase("application", "octet-stream")
|
||||
with open(abs_path, "rb") as f:
|
||||
part.set_payload(f.read())
|
||||
encoders.encode_base64(part)
|
||||
part.add_header(
|
||||
"Content-Disposition", f"attachment; filename={Path(abs_path).name}"
|
||||
)
|
||||
msg.attach(part)
|
||||
|
||||
raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
|
||||
return await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.messages()
|
||||
.send(userId="me", body={"threadId": input_data.threadId, "raw": raw})
|
||||
.execute()
|
||||
)
|
||||
|
||||
|
||||
class GmailGetProfileBlock(GmailBase):
|
||||
|
||||
@@ -10,6 +10,7 @@ from backend.util.settings import Config
|
||||
from backend.util.text import TextFormatter
|
||||
from backend.util.type import LongTextType, MediaFileType, ShortTextType
|
||||
|
||||
formatter = TextFormatter()
|
||||
config = Config()
|
||||
|
||||
|
||||
@@ -131,11 +132,6 @@ class AgentOutputBlock(Block):
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
escape_html: bool = SchemaField(
|
||||
default=False,
|
||||
advanced=True,
|
||||
description="Whether to escape special characters in the inserted values to be HTML-safe. Enable for HTML output, disable for plain text.",
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to treat the output as advanced.",
|
||||
default=False,
|
||||
@@ -197,7 +193,6 @@ class AgentOutputBlock(Block):
|
||||
"""
|
||||
if input_data.format:
|
||||
try:
|
||||
formatter = TextFormatter(autoescape=input_data.escape_html)
|
||||
yield "output", formatter.format_string(
|
||||
input_data.format, {input_data.name: input_data.value}
|
||||
)
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
# This file contains a lot of prompt block strings that would trigger "line too long"
|
||||
# flake8: noqa: E501
|
||||
import ast
|
||||
import logging
|
||||
import re
|
||||
import secrets
|
||||
from abc import ABC
|
||||
from enum import Enum, EnumMeta
|
||||
from json import JSONDecodeError
|
||||
@@ -31,7 +27,7 @@ from backend.util.prompt import compress_prompt, estimate_token_count
|
||||
from backend.util.text import TextFormatter
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||
fmt = TextFormatter(autoescape=False)
|
||||
fmt = TextFormatter()
|
||||
|
||||
LLMProviderName = Literal[
|
||||
ProviderName.AIML_API,
|
||||
@@ -208,13 +204,13 @@ MODEL_METADATA = {
|
||||
"anthropic", 200000, 32000
|
||||
), # claude-opus-4-1-20250805
|
||||
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 32000
|
||||
"anthropic", 200000, 8192
|
||||
), # claude-4-opus-20250514
|
||||
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000
|
||||
"anthropic", 200000, 8192
|
||||
), # claude-4-sonnet-20250514
|
||||
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000
|
||||
"anthropic", 200000, 8192
|
||||
), # claude-3-7-sonnet-20250219
|
||||
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 8192
|
||||
@@ -386,9 +382,7 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
||||
return None
|
||||
|
||||
|
||||
def get_parallel_tool_calls_param(
|
||||
llm_model: LlmModel, parallel_tool_calls: bool | None
|
||||
):
|
||||
def get_parallel_tool_calls_param(llm_model: LlmModel, parallel_tool_calls):
|
||||
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
||||
if llm_model.startswith("o") or parallel_tool_calls is None:
|
||||
return openai.NOT_GIVEN
|
||||
@@ -399,8 +393,8 @@ async def llm_call(
|
||||
credentials: APIKeyCredentials,
|
||||
llm_model: LlmModel,
|
||||
prompt: list[dict],
|
||||
json_format: bool,
|
||||
max_tokens: int | None,
|
||||
force_json_output: bool = False,
|
||||
tools: list[dict] | None = None,
|
||||
ollama_host: str = "localhost:11434",
|
||||
parallel_tool_calls=None,
|
||||
@@ -413,7 +407,7 @@ async def llm_call(
|
||||
credentials: The API key credentials to use.
|
||||
llm_model: The LLM model to use.
|
||||
prompt: The prompt to send to the LLM.
|
||||
force_json_output: Whether the response should be in JSON format.
|
||||
json_format: Whether the response should be in JSON format.
|
||||
max_tokens: The maximum number of tokens to generate in the chat completion.
|
||||
tools: The tools to use in the chat completion.
|
||||
ollama_host: The host for ollama to use.
|
||||
@@ -452,7 +446,7 @@ async def llm_call(
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
if force_json_output:
|
||||
if json_format:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
response = await oai_client.chat.completions.create(
|
||||
@@ -565,7 +559,7 @@ async def llm_call(
|
||||
raise ValueError("Groq does not support tools.")
|
||||
|
||||
client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = {"type": "json_object"} if force_json_output else None
|
||||
response_format = {"type": "json_object"} if json_format else None
|
||||
response = await client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
@@ -723,7 +717,7 @@ async def llm_call(
|
||||
)
|
||||
|
||||
response_format = None
|
||||
if force_json_output:
|
||||
if json_format:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
parallel_tool_calls_param = get_parallel_tool_calls_param(
|
||||
@@ -786,17 +780,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
force_json_output: bool = SchemaField(
|
||||
title="Restrict LLM to pure JSON output",
|
||||
default=False,
|
||||
description=(
|
||||
"Whether to force the LLM to produce a JSON-only response. "
|
||||
"This can increase the block's reliability, "
|
||||
"but may also reduce the quality of the response "
|
||||
"because it prohibits the LLM from reasoning "
|
||||
"before providing its JSON response."
|
||||
),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
sys_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
@@ -865,18 +848,17 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
"llm_call": lambda *args, **kwargs: LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[""],
|
||||
response=(
|
||||
'<json_output id="test123456">{\n'
|
||||
' "key1": "key1Value",\n'
|
||||
' "key2": "key2Value"\n'
|
||||
"}</json_output>"
|
||||
response=json.dumps(
|
||||
{
|
||||
"key1": "key1Value",
|
||||
"key2": "key2Value",
|
||||
}
|
||||
),
|
||||
tool_calls=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
reasoning=None,
|
||||
),
|
||||
"get_collision_proof_output_tag_id": lambda *args: "test123456",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -885,9 +867,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
credentials: APIKeyCredentials,
|
||||
llm_model: LlmModel,
|
||||
prompt: list[dict],
|
||||
json_format: bool,
|
||||
compress_prompt_to_fit: bool,
|
||||
max_tokens: int | None,
|
||||
force_json_output: bool = False,
|
||||
compress_prompt_to_fit: bool = True,
|
||||
tools: list[dict] | None = None,
|
||||
ollama_host: str = "localhost:11434",
|
||||
) -> LLMResponse:
|
||||
@@ -900,8 +882,8 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
credentials=credentials,
|
||||
llm_model=llm_model,
|
||||
prompt=prompt,
|
||||
json_format=json_format,
|
||||
max_tokens=max_tokens,
|
||||
force_json_output=force_json_output,
|
||||
tools=tools,
|
||||
ollama_host=ollama_host,
|
||||
compress_prompt_to_fit=compress_prompt_to_fit,
|
||||
@@ -913,6 +895,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
logger.debug(f"Calling LLM with input data: {input_data}")
|
||||
prompt = [json.to_dict(p) for p in input_data.conversation_history]
|
||||
|
||||
def trim_prompt(s: str) -> str:
|
||||
lines = s.strip().split("\n")
|
||||
return "\n".join([line.strip().lstrip("|") for line in lines])
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
input_data.prompt = fmt.format_string(input_data.prompt, values)
|
||||
@@ -921,15 +907,27 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
if input_data.sys_prompt:
|
||||
prompt.append({"role": "system", "content": input_data.sys_prompt})
|
||||
|
||||
# Use a one-time unique tag to prevent collisions with user/LLM content
|
||||
output_tag_id = self.get_collision_proof_output_tag_id()
|
||||
output_tag_start = f'<json_output id="{output_tag_id}">'
|
||||
if input_data.expected_format:
|
||||
sys_prompt = self.response_format_instructions(
|
||||
input_data.expected_format,
|
||||
list_mode=input_data.list_result,
|
||||
pure_json_mode=input_data.force_json_output,
|
||||
output_tag_start=output_tag_start,
|
||||
expected_format = [
|
||||
f'"{k}": "{v}"' for k, v in input_data.expected_format.items()
|
||||
]
|
||||
if input_data.list_result:
|
||||
format_prompt = (
|
||||
f'"results": [\n {{\n {", ".join(expected_format)}\n }}\n]'
|
||||
)
|
||||
else:
|
||||
format_prompt = "\n ".join(expected_format)
|
||||
|
||||
sys_prompt = trim_prompt(
|
||||
f"""
|
||||
|Reply strictly only in the following JSON format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
|
|
||||
|Ensure the response is valid JSON. Do not include any additional text outside of the JSON.
|
||||
|If you cannot provide all the keys, provide an empty string for the values you cannot answer.
|
||||
"""
|
||||
)
|
||||
prompt.append({"role": "system", "content": sys_prompt})
|
||||
|
||||
@@ -947,21 +945,18 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
except JSONDecodeError as e:
|
||||
return f"JSON decode error: {e}"
|
||||
|
||||
error_feedback_message = ""
|
||||
logger.debug(f"LLM request: {prompt}")
|
||||
retry_prompt = ""
|
||||
llm_model = input_data.model
|
||||
|
||||
for retry_count in range(input_data.retry):
|
||||
logger.debug(f"LLM request: {prompt}")
|
||||
try:
|
||||
llm_response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=llm_model,
|
||||
prompt=prompt,
|
||||
compress_prompt_to_fit=input_data.compress_prompt_to_fit,
|
||||
force_json_output=(
|
||||
input_data.force_json_output
|
||||
and bool(input_data.expected_format)
|
||||
),
|
||||
json_format=bool(input_data.expected_format),
|
||||
ollama_host=input_data.ollama_host,
|
||||
max_tokens=input_data.max_tokens,
|
||||
)
|
||||
@@ -975,55 +970,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
if input_data.expected_format:
|
||||
try:
|
||||
response_obj = self.get_json_from_response(
|
||||
response_text,
|
||||
pure_json_mode=input_data.force_json_output,
|
||||
output_tag_start=output_tag_start,
|
||||
)
|
||||
except (ValueError, JSONDecodeError) as parse_error:
|
||||
censored_response = re.sub(r"[A-Za-z0-9]", "*", response_text)
|
||||
response_snippet = (
|
||||
f"{censored_response[:50]}...{censored_response[-30:]}"
|
||||
)
|
||||
logger.warning(
|
||||
f"Error getting JSON from LLM response: {parse_error}\n\n"
|
||||
f"Response start+end: `{response_snippet}`"
|
||||
)
|
||||
prompt.append({"role": "assistant", "content": response_text})
|
||||
|
||||
error_feedback_message = self.invalid_response_feedback(
|
||||
parse_error,
|
||||
was_parseable=False,
|
||||
list_mode=input_data.list_result,
|
||||
pure_json_mode=input_data.force_json_output,
|
||||
output_tag_start=output_tag_start,
|
||||
)
|
||||
prompt.append(
|
||||
{"role": "user", "content": error_feedback_message}
|
||||
)
|
||||
continue
|
||||
response_obj = json.loads(response_text)
|
||||
|
||||
# Handle object response for `force_json_output`+`list_result`
|
||||
if input_data.list_result and isinstance(response_obj, dict):
|
||||
if "results" in response_obj and isinstance(
|
||||
response_obj["results"], list
|
||||
):
|
||||
response_obj = response_obj["results"]
|
||||
else:
|
||||
error_feedback_message = (
|
||||
"Expected an array of objects in the 'results' key, "
|
||||
f"but got: {response_obj}"
|
||||
)
|
||||
prompt.append(
|
||||
{"role": "assistant", "content": response_text}
|
||||
)
|
||||
prompt.append(
|
||||
{"role": "user", "content": error_feedback_message}
|
||||
)
|
||||
continue
|
||||
if "results" in response_obj:
|
||||
response_obj = response_obj.get("results", [])
|
||||
elif len(response_obj) == 1:
|
||||
response_obj = list(response_obj.values())
|
||||
|
||||
validation_errors = "\n".join(
|
||||
response_error = "\n".join(
|
||||
[
|
||||
validation_error
|
||||
for response_item in (
|
||||
@@ -1035,7 +991,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
]
|
||||
)
|
||||
|
||||
if not validation_errors:
|
||||
if not response_error:
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
llm_call_count=retry_count + 1,
|
||||
@@ -1045,16 +1001,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
yield "response", response_obj
|
||||
yield "prompt", self.prompt
|
||||
return
|
||||
|
||||
prompt.append({"role": "assistant", "content": response_text})
|
||||
error_feedback_message = self.invalid_response_feedback(
|
||||
validation_errors,
|
||||
was_parseable=True,
|
||||
list_mode=input_data.list_result,
|
||||
pure_json_mode=input_data.force_json_output,
|
||||
output_tag_start=output_tag_start,
|
||||
)
|
||||
prompt.append({"role": "user", "content": error_feedback_message})
|
||||
else:
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
@@ -1065,6 +1011,21 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
yield "response", {"response": response_text}
|
||||
yield "prompt", self.prompt
|
||||
return
|
||||
|
||||
retry_prompt = trim_prompt(
|
||||
f"""
|
||||
|This is your previous error response:
|
||||
|--
|
||||
|{response_text}
|
||||
|--
|
||||
|
|
||||
|And this is the error:
|
||||
|--
|
||||
|{response_error}
|
||||
|--
|
||||
"""
|
||||
)
|
||||
prompt.append({"role": "user", "content": retry_prompt})
|
||||
except Exception as e:
|
||||
logger.exception(f"Error calling LLM: {e}")
|
||||
if (
|
||||
@@ -1077,133 +1038,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
logger.debug(
|
||||
f"Reducing max_tokens to {input_data.max_tokens} for next attempt"
|
||||
)
|
||||
# Don't add retry prompt for token limit errors,
|
||||
# just retry with lower maximum output tokens
|
||||
retry_prompt = f"Error calling LLM: {e}"
|
||||
|
||||
error_feedback_message = f"Error calling LLM: {e}"
|
||||
|
||||
raise RuntimeError(error_feedback_message)
|
||||
|
||||
def response_format_instructions(
|
||||
self,
|
||||
expected_object_format: dict[str, str],
|
||||
*,
|
||||
list_mode: bool,
|
||||
pure_json_mode: bool,
|
||||
output_tag_start: str,
|
||||
) -> str:
|
||||
expected_output_format = json.dumps(expected_object_format, indent=2)
|
||||
output_type = "object" if not list_mode else "array"
|
||||
outer_output_type = "object" if pure_json_mode else output_type
|
||||
|
||||
if output_type == "array":
|
||||
indented_obj_format = expected_output_format.replace("\n", "\n ")
|
||||
expected_output_format = f"[\n {indented_obj_format},\n ...\n]"
|
||||
if pure_json_mode:
|
||||
indented_list_format = expected_output_format.replace("\n", "\n ")
|
||||
expected_output_format = (
|
||||
"{\n"
|
||||
' "reasoning": "... (optional)",\n' # for better performance
|
||||
f' "results": {indented_list_format}\n'
|
||||
"}"
|
||||
)
|
||||
|
||||
# Preserve indentation in prompt
|
||||
expected_output_format = expected_output_format.replace("\n", "\n|")
|
||||
|
||||
# Prepare prompt
|
||||
if not pure_json_mode:
|
||||
expected_output_format = (
|
||||
f"{output_tag_start}\n{expected_output_format}\n</json_output>"
|
||||
)
|
||||
|
||||
instructions = f"""
|
||||
|In your response you MUST include a valid JSON {outer_output_type} strictly following this format:
|
||||
|{expected_output_format}
|
||||
|
|
||||
|If you cannot provide all the keys, you MUST provide an empty string for the values you cannot answer.
|
||||
""".strip()
|
||||
|
||||
if not pure_json_mode:
|
||||
instructions += f"""
|
||||
|
|
||||
|You MUST enclose your final JSON answer in {output_tag_start}...</json_output> tags, even if the user specifies a different tag.
|
||||
|There MUST be exactly ONE {output_tag_start}...</json_output> block in your response, which MUST ONLY contain the JSON {outer_output_type} and nothing else. Other text outside this block is allowed.
|
||||
""".strip()
|
||||
|
||||
return trim_prompt(instructions)
|
||||
|
||||
def invalid_response_feedback(
|
||||
self,
|
||||
error,
|
||||
*,
|
||||
was_parseable: bool,
|
||||
list_mode: bool,
|
||||
pure_json_mode: bool,
|
||||
output_tag_start: str,
|
||||
) -> str:
|
||||
outer_output_type = "object" if not list_mode or pure_json_mode else "array"
|
||||
|
||||
if was_parseable:
|
||||
complaint = f"Your previous response did not match the expected {outer_output_type} format."
|
||||
else:
|
||||
complaint = f"Your previous response did not contain a parseable JSON {outer_output_type}."
|
||||
|
||||
indented_parse_error = str(error).replace("\n", "\n|")
|
||||
|
||||
instruction = (
|
||||
f"Please provide a {output_tag_start}...</json_output> block containing a"
|
||||
if not pure_json_mode
|
||||
else "Please provide a"
|
||||
) + f" valid JSON {outer_output_type} that matches the expected format."
|
||||
|
||||
return trim_prompt(
|
||||
f"""
|
||||
|{complaint}
|
||||
|
|
||||
|{indented_parse_error}
|
||||
|
|
||||
|{instruction}
|
||||
"""
|
||||
)
|
||||
|
||||
def get_json_from_response(
|
||||
self, response_text: str, *, pure_json_mode: bool, output_tag_start: str
|
||||
) -> dict[str, Any] | list[dict[str, Any]]:
|
||||
if pure_json_mode:
|
||||
# Handle pure JSON responses
|
||||
try:
|
||||
return json.loads(response_text)
|
||||
except JSONDecodeError as first_parse_error:
|
||||
# If that didn't work, try finding the { and } to deal with possible ```json fences etc.
|
||||
json_start = response_text.find("{")
|
||||
json_end = response_text.rfind("}")
|
||||
try:
|
||||
return json.loads(response_text[json_start : json_end + 1])
|
||||
except JSONDecodeError:
|
||||
# Raise the original error, as it's more likely to be relevant
|
||||
raise first_parse_error from None
|
||||
|
||||
if output_tag_start not in response_text:
|
||||
raise ValueError(
|
||||
"Response does not contain the expected "
|
||||
f"{output_tag_start}...</json_output> block."
|
||||
)
|
||||
json_output = (
|
||||
response_text.split(output_tag_start, 1)[1]
|
||||
.rsplit("</json_output>", 1)[0]
|
||||
.strip()
|
||||
)
|
||||
return json.loads(json_output)
|
||||
|
||||
def get_collision_proof_output_tag_id(self) -> str:
|
||||
return secrets.token_hex(8)
|
||||
|
||||
|
||||
def trim_prompt(s: str) -> str:
|
||||
"""Removes indentation up to and including `|` from a multi-line prompt."""
|
||||
lines = s.strip().split("\n")
|
||||
return "\n".join([line.strip().lstrip("|") for line in lines])
|
||||
raise RuntimeError(retry_prompt)
|
||||
|
||||
|
||||
class AITextGeneratorBlock(AIBlockBase):
|
||||
|
||||
@@ -1,536 +0,0 @@
|
||||
"""
|
||||
Notion API helper functions and client for making authenticated requests.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.util.request import Requests
|
||||
|
||||
NOTION_VERSION = "2022-06-28"
|
||||
|
||||
|
||||
class NotionAPIException(Exception):
|
||||
"""Exception raised for Notion API errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class NotionClient:
|
||||
"""Client for interacting with the Notion API."""
|
||||
|
||||
def __init__(self, credentials: OAuth2Credentials):
|
||||
self.credentials = credentials
|
||||
self.headers = {
|
||||
"Authorization": credentials.auth_header(),
|
||||
"Notion-Version": NOTION_VERSION,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
self.requests = Requests()
|
||||
|
||||
async def get_page(self, page_id: str) -> dict:
|
||||
"""
|
||||
Fetch a page by ID.
|
||||
|
||||
Args:
|
||||
page_id: The ID of the page to fetch.
|
||||
|
||||
Returns:
|
||||
The page object from Notion API.
|
||||
"""
|
||||
url = f"https://api.notion.com/v1/pages/{page_id}"
|
||||
response = await self.requests.get(url, headers=self.headers)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to fetch page: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def get_blocks(self, block_id: str, recursive: bool = True) -> List[dict]:
|
||||
"""
|
||||
Fetch all blocks from a page or block.
|
||||
|
||||
Args:
|
||||
block_id: The ID of the page or block to fetch children from.
|
||||
recursive: Whether to fetch nested blocks recursively.
|
||||
|
||||
Returns:
|
||||
List of block objects.
|
||||
"""
|
||||
blocks = []
|
||||
cursor = None
|
||||
|
||||
while True:
|
||||
url = f"https://api.notion.com/v1/blocks/{block_id}/children"
|
||||
params = {"page_size": 100}
|
||||
if cursor:
|
||||
params["start_cursor"] = cursor
|
||||
|
||||
response = await self.requests.get(url, headers=self.headers, params=params)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to fetch blocks: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
current_blocks = data.get("results", [])
|
||||
|
||||
# If recursive, fetch children for blocks that have them
|
||||
if recursive:
|
||||
for block in current_blocks:
|
||||
if block.get("has_children"):
|
||||
block["children"] = await self.get_blocks(
|
||||
block["id"], recursive=True
|
||||
)
|
||||
|
||||
blocks.extend(current_blocks)
|
||||
|
||||
if not data.get("has_more"):
|
||||
break
|
||||
cursor = data.get("next_cursor")
|
||||
|
||||
return blocks
|
||||
|
||||
async def query_database(
|
||||
self,
|
||||
database_id: str,
|
||||
filter_obj: Optional[dict] = None,
|
||||
sorts: Optional[List[dict]] = None,
|
||||
page_size: int = 100,
|
||||
) -> dict:
|
||||
"""
|
||||
Query a database with optional filters and sorts.
|
||||
|
||||
Args:
|
||||
database_id: The ID of the database to query.
|
||||
filter_obj: Optional filter object for the query.
|
||||
sorts: Optional list of sort objects.
|
||||
page_size: Number of results per page.
|
||||
|
||||
Returns:
|
||||
Query results including pages and pagination info.
|
||||
"""
|
||||
url = f"https://api.notion.com/v1/databases/{database_id}/query"
|
||||
|
||||
payload: Dict[str, Any] = {"page_size": page_size}
|
||||
if filter_obj:
|
||||
payload["filter"] = filter_obj
|
||||
if sorts:
|
||||
payload["sorts"] = sorts
|
||||
|
||||
response = await self.requests.post(url, headers=self.headers, json=payload)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to query database: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def create_page(
|
||||
self,
|
||||
parent: dict,
|
||||
properties: dict,
|
||||
children: Optional[List[dict]] = None,
|
||||
icon: Optional[dict] = None,
|
||||
cover: Optional[dict] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Create a new page.
|
||||
|
||||
Args:
|
||||
parent: Parent object (page_id or database_id).
|
||||
properties: Page properties.
|
||||
children: Optional list of block children.
|
||||
icon: Optional icon object.
|
||||
cover: Optional cover object.
|
||||
|
||||
Returns:
|
||||
The created page object.
|
||||
"""
|
||||
url = "https://api.notion.com/v1/pages"
|
||||
|
||||
payload: Dict[str, Any] = {"parent": parent, "properties": properties}
|
||||
|
||||
if children:
|
||||
payload["children"] = children
|
||||
if icon:
|
||||
payload["icon"] = icon
|
||||
if cover:
|
||||
payload["cover"] = cover
|
||||
|
||||
response = await self.requests.post(url, headers=self.headers, json=payload)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to create page: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def update_page(self, page_id: str, properties: dict) -> dict:
|
||||
"""
|
||||
Update a page's properties.
|
||||
|
||||
Args:
|
||||
page_id: The ID of the page to update.
|
||||
properties: Properties to update.
|
||||
|
||||
Returns:
|
||||
The updated page object.
|
||||
"""
|
||||
url = f"https://api.notion.com/v1/pages/{page_id}"
|
||||
|
||||
response = await self.requests.patch(
|
||||
url, headers=self.headers, json={"properties": properties}
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to update page: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def append_blocks(self, block_id: str, children: List[dict]) -> dict:
|
||||
"""
|
||||
Append blocks to a page or block.
|
||||
|
||||
Args:
|
||||
block_id: The ID of the page or block to append to.
|
||||
children: List of block objects to append.
|
||||
|
||||
Returns:
|
||||
Response with the created blocks.
|
||||
"""
|
||||
url = f"https://api.notion.com/v1/blocks/{block_id}/children"
|
||||
|
||||
response = await self.requests.patch(
|
||||
url, headers=self.headers, json={"children": children}
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to append blocks: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str = "",
|
||||
filter_obj: Optional[dict] = None,
|
||||
sort: Optional[dict] = None,
|
||||
page_size: int = 100,
|
||||
) -> dict:
|
||||
"""
|
||||
Search for pages and databases.
|
||||
|
||||
Args:
|
||||
query: Search query text.
|
||||
filter_obj: Optional filter object.
|
||||
sort: Optional sort object.
|
||||
page_size: Number of results per page.
|
||||
|
||||
Returns:
|
||||
Search results.
|
||||
"""
|
||||
url = "https://api.notion.com/v1/search"
|
||||
|
||||
payload: Dict[str, Any] = {"page_size": page_size}
|
||||
if query:
|
||||
payload["query"] = query
|
||||
if filter_obj:
|
||||
payload["filter"] = filter_obj
|
||||
if sort:
|
||||
payload["sort"] = sort
|
||||
|
||||
response = await self.requests.post(url, headers=self.headers, json=payload)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Search failed: {response.status} - {response.text()}", response.status
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
|
||||
# Conversion helper functions
|
||||
|
||||
|
||||
def parse_rich_text(rich_text_array: List[dict]) -> str:
|
||||
"""
|
||||
Extract plain text from a Notion rich text array.
|
||||
|
||||
Args:
|
||||
rich_text_array: Array of rich text objects from Notion.
|
||||
|
||||
Returns:
|
||||
Plain text string.
|
||||
"""
|
||||
if not rich_text_array:
|
||||
return ""
|
||||
|
||||
text_parts = []
|
||||
for text_obj in rich_text_array:
|
||||
if "plain_text" in text_obj:
|
||||
text_parts.append(text_obj["plain_text"])
|
||||
|
||||
return "".join(text_parts)
|
||||
|
||||
|
||||
def rich_text_to_markdown(rich_text_array: List[dict]) -> str:
|
||||
"""
|
||||
Convert Notion rich text array to markdown with formatting.
|
||||
|
||||
Args:
|
||||
rich_text_array: Array of rich text objects from Notion.
|
||||
|
||||
Returns:
|
||||
Markdown formatted string.
|
||||
"""
|
||||
if not rich_text_array:
|
||||
return ""
|
||||
|
||||
markdown_parts = []
|
||||
|
||||
for text_obj in rich_text_array:
|
||||
text = text_obj.get("plain_text", "")
|
||||
annotations = text_obj.get("annotations", {})
|
||||
|
||||
# Apply formatting based on annotations
|
||||
if annotations.get("code"):
|
||||
text = f"`{text}`"
|
||||
else:
|
||||
if annotations.get("bold"):
|
||||
text = f"**{text}**"
|
||||
if annotations.get("italic"):
|
||||
text = f"*{text}*"
|
||||
if annotations.get("strikethrough"):
|
||||
text = f"~~{text}~~"
|
||||
if annotations.get("underline"):
|
||||
text = f"<u>{text}</u>"
|
||||
|
||||
# Handle links
|
||||
if text_obj.get("href"):
|
||||
text = f"[{text}]({text_obj['href']})"
|
||||
|
||||
markdown_parts.append(text)
|
||||
|
||||
return "".join(markdown_parts)
|
||||
|
||||
|
||||
def block_to_markdown(block: dict, indent_level: int = 0) -> str:
|
||||
"""
|
||||
Convert a single Notion block to markdown.
|
||||
|
||||
Args:
|
||||
block: Block object from Notion API.
|
||||
indent_level: Current indentation level for nested blocks.
|
||||
|
||||
Returns:
|
||||
Markdown string representation of the block.
|
||||
"""
|
||||
block_type = block.get("type")
|
||||
indent = " " * indent_level
|
||||
markdown_lines = []
|
||||
|
||||
# Handle different block types
|
||||
if block_type == "paragraph":
|
||||
text = rich_text_to_markdown(block["paragraph"].get("rich_text", []))
|
||||
if text:
|
||||
markdown_lines.append(f"{indent}{text}")
|
||||
|
||||
elif block_type == "heading_1":
|
||||
text = parse_rich_text(block["heading_1"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}# {text}")
|
||||
|
||||
elif block_type == "heading_2":
|
||||
text = parse_rich_text(block["heading_2"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}## {text}")
|
||||
|
||||
elif block_type == "heading_3":
|
||||
text = parse_rich_text(block["heading_3"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}### {text}")
|
||||
|
||||
elif block_type == "bulleted_list_item":
|
||||
text = rich_text_to_markdown(block["bulleted_list_item"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}- {text}")
|
||||
|
||||
elif block_type == "numbered_list_item":
|
||||
text = rich_text_to_markdown(block["numbered_list_item"].get("rich_text", []))
|
||||
# Note: This is simplified - proper numbering would need context
|
||||
markdown_lines.append(f"{indent}1. {text}")
|
||||
|
||||
elif block_type == "to_do":
|
||||
text = rich_text_to_markdown(block["to_do"].get("rich_text", []))
|
||||
checked = "x" if block["to_do"].get("checked") else " "
|
||||
markdown_lines.append(f"{indent}- [{checked}] {text}")
|
||||
|
||||
elif block_type == "toggle":
|
||||
text = rich_text_to_markdown(block["toggle"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}<details>")
|
||||
markdown_lines.append(f"{indent}<summary>{text}</summary>")
|
||||
markdown_lines.append(f"{indent}")
|
||||
# Process children if they exist
|
||||
if block.get("children"):
|
||||
for child in block["children"]:
|
||||
child_markdown = block_to_markdown(child, indent_level + 1)
|
||||
if child_markdown:
|
||||
markdown_lines.append(child_markdown)
|
||||
markdown_lines.append(f"{indent}</details>")
|
||||
|
||||
elif block_type == "code":
|
||||
code = parse_rich_text(block["code"].get("rich_text", []))
|
||||
language = block["code"].get("language", "")
|
||||
markdown_lines.append(f"{indent}```{language}")
|
||||
markdown_lines.append(f"{indent}{code}")
|
||||
markdown_lines.append(f"{indent}```")
|
||||
|
||||
elif block_type == "quote":
|
||||
text = rich_text_to_markdown(block["quote"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}> {text}")
|
||||
|
||||
elif block_type == "divider":
|
||||
markdown_lines.append(f"{indent}---")
|
||||
|
||||
elif block_type == "image":
|
||||
image = block["image"]
|
||||
url = image.get("external", {}).get("url") or image.get("file", {}).get(
|
||||
"url", ""
|
||||
)
|
||||
caption = parse_rich_text(image.get("caption", []))
|
||||
alt_text = caption if caption else "Image"
|
||||
markdown_lines.append(f"{indent}")
|
||||
if caption:
|
||||
markdown_lines.append(f"{indent}*{caption}*")
|
||||
|
||||
elif block_type == "video":
|
||||
video = block["video"]
|
||||
url = video.get("external", {}).get("url") or video.get("file", {}).get(
|
||||
"url", ""
|
||||
)
|
||||
caption = parse_rich_text(video.get("caption", []))
|
||||
markdown_lines.append(f"{indent}[Video]({url})")
|
||||
if caption:
|
||||
markdown_lines.append(f"{indent}*{caption}*")
|
||||
|
||||
elif block_type == "file":
|
||||
file = block["file"]
|
||||
url = file.get("external", {}).get("url") or file.get("file", {}).get("url", "")
|
||||
caption = parse_rich_text(file.get("caption", []))
|
||||
name = caption if caption else "File"
|
||||
markdown_lines.append(f"{indent}[{name}]({url})")
|
||||
|
||||
elif block_type == "bookmark":
|
||||
url = block["bookmark"].get("url", "")
|
||||
caption = parse_rich_text(block["bookmark"].get("caption", []))
|
||||
markdown_lines.append(f"{indent}[{caption if caption else url}]({url})")
|
||||
|
||||
elif block_type == "equation":
|
||||
expression = block["equation"].get("expression", "")
|
||||
markdown_lines.append(f"{indent}$${expression}$$")
|
||||
|
||||
elif block_type == "callout":
|
||||
text = rich_text_to_markdown(block["callout"].get("rich_text", []))
|
||||
icon = block["callout"].get("icon", {})
|
||||
if icon.get("emoji"):
|
||||
markdown_lines.append(f"{indent}> {icon['emoji']} {text}")
|
||||
else:
|
||||
markdown_lines.append(f"{indent}> ℹ️ {text}")
|
||||
|
||||
elif block_type == "child_page":
|
||||
title = block["child_page"].get("title", "Untitled")
|
||||
markdown_lines.append(f"{indent}📄 [{title}](notion://page/{block['id']})")
|
||||
|
||||
elif block_type == "child_database":
|
||||
title = block["child_database"].get("title", "Untitled Database")
|
||||
markdown_lines.append(f"{indent}🗂️ [{title}](notion://database/{block['id']})")
|
||||
|
||||
elif block_type == "table":
|
||||
# Tables are complex - for now just indicate there's a table
|
||||
markdown_lines.append(
|
||||
f"{indent}[Table with {block['table'].get('table_width', 0)} columns]"
|
||||
)
|
||||
|
||||
elif block_type == "column_list":
|
||||
# Process columns
|
||||
if block.get("children"):
|
||||
markdown_lines.append(f"{indent}<div style='display: flex'>")
|
||||
for column in block["children"]:
|
||||
markdown_lines.append(f"{indent}<div style='flex: 1'>")
|
||||
if column.get("children"):
|
||||
for child in column["children"]:
|
||||
child_markdown = block_to_markdown(child, indent_level + 1)
|
||||
if child_markdown:
|
||||
markdown_lines.append(child_markdown)
|
||||
markdown_lines.append(f"{indent}</div>")
|
||||
markdown_lines.append(f"{indent}</div>")
|
||||
|
||||
# Handle children for blocks that haven't been processed yet
|
||||
elif block.get("children") and block_type not in ["toggle", "column_list"]:
|
||||
for child in block["children"]:
|
||||
child_markdown = block_to_markdown(child, indent_level)
|
||||
if child_markdown:
|
||||
markdown_lines.append(child_markdown)
|
||||
|
||||
return "\n".join(markdown_lines) if markdown_lines else ""
|
||||
|
||||
|
||||
def blocks_to_markdown(blocks: List[dict]) -> str:
|
||||
"""
|
||||
Convert a list of Notion blocks to a markdown document.
|
||||
|
||||
Args:
|
||||
blocks: List of block objects from Notion API.
|
||||
|
||||
Returns:
|
||||
Complete markdown document as a string.
|
||||
"""
|
||||
markdown_parts = []
|
||||
|
||||
for i, block in enumerate(blocks):
|
||||
markdown = block_to_markdown(block)
|
||||
if markdown:
|
||||
markdown_parts.append(markdown)
|
||||
# Add spacing between top-level blocks (except lists)
|
||||
if i < len(blocks) - 1:
|
||||
next_type = blocks[i + 1].get("type", "")
|
||||
current_type = block.get("type", "")
|
||||
# Don't add extra spacing between list items
|
||||
list_types = {"bulleted_list_item", "numbered_list_item", "to_do"}
|
||||
if not (current_type in list_types and next_type in list_types):
|
||||
markdown_parts.append("")
|
||||
|
||||
return "\n".join(markdown_parts)
|
||||
|
||||
|
||||
def extract_page_title(page: dict) -> str:
|
||||
"""
|
||||
Extract the title from a Notion page object.
|
||||
|
||||
Args:
|
||||
page: Page object from Notion API.
|
||||
|
||||
Returns:
|
||||
Page title as a string.
|
||||
"""
|
||||
properties = page.get("properties", {})
|
||||
|
||||
# Find the title property (it has type "title")
|
||||
for prop_name, prop_value in properties.items():
|
||||
if prop_value.get("type") == "title":
|
||||
return parse_rich_text(prop_value.get("title", []))
|
||||
|
||||
return "Untitled"
|
||||
@@ -1,42 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import CredentialsField, CredentialsMetaInput, OAuth2Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Secrets
|
||||
|
||||
secrets = Secrets()
|
||||
NOTION_OAUTH_IS_CONFIGURED = bool(
|
||||
secrets.notion_client_id and secrets.notion_client_secret
|
||||
)
|
||||
|
||||
NotionCredentials = OAuth2Credentials
|
||||
NotionCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.NOTION], Literal["oauth2"]
|
||||
]
|
||||
|
||||
|
||||
def NotionCredentialsField() -> NotionCredentialsInput:
|
||||
"""Creates a Notion OAuth2 credentials field."""
|
||||
return CredentialsField(
|
||||
description="Connect your Notion account. Ensure the pages/databases are shared with the integration."
|
||||
)
|
||||
|
||||
|
||||
# Test credentials for Notion OAuth2
|
||||
TEST_CREDENTIALS = OAuth2Credentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="notion",
|
||||
access_token=SecretStr("test_access_token"),
|
||||
title="Mock Notion OAuth",
|
||||
scopes=["read_content", "insert_content", "update_content"],
|
||||
username="testuser",
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
@@ -1,360 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import OAuth2Credentials, SchemaField
|
||||
|
||||
from ._api import NotionClient
|
||||
from ._auth import (
|
||||
NOTION_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
NotionCredentialsField,
|
||||
NotionCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class NotionCreatePageBlock(Block):
|
||||
"""Create a new page in Notion with content."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: NotionCredentialsInput = NotionCredentialsField()
|
||||
parent_page_id: Optional[str] = SchemaField(
|
||||
description="Parent page ID to create the page under. Either this OR parent_database_id is required.",
|
||||
default=None,
|
||||
)
|
||||
parent_database_id: Optional[str] = SchemaField(
|
||||
description="Parent database ID to create the page in. Either this OR parent_page_id is required.",
|
||||
default=None,
|
||||
)
|
||||
title: str = SchemaField(
|
||||
description="Title of the new page",
|
||||
)
|
||||
content: Optional[str] = SchemaField(
|
||||
description="Content for the page. Can be plain text or markdown - will be converted to Notion blocks.",
|
||||
default=None,
|
||||
)
|
||||
properties: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="Additional properties for database pages (e.g., {'Status': 'In Progress', 'Priority': 'High'})",
|
||||
default=None,
|
||||
)
|
||||
icon_emoji: Optional[str] = SchemaField(
|
||||
description="Emoji to use as the page icon (e.g., '📄', '🚀')", default=None
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_parent(self):
|
||||
"""Ensure either parent_page_id or parent_database_id is provided."""
|
||||
if not self.parent_page_id and not self.parent_database_id:
|
||||
raise ValueError(
|
||||
"Either parent_page_id or parent_database_id must be provided"
|
||||
)
|
||||
if self.parent_page_id and self.parent_database_id:
|
||||
raise ValueError(
|
||||
"Only one of parent_page_id or parent_database_id should be provided, not both"
|
||||
)
|
||||
return self
|
||||
|
||||
class Output(BlockSchema):
|
||||
page_id: str = SchemaField(description="ID of the created page.")
|
||||
page_url: str = SchemaField(description="URL of the created page.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c15febe0-66ce-4c6f-aebd-5ab351653804",
|
||||
description="Create a new page in Notion. Requires EITHER a parent_page_id OR parent_database_id. Supports markdown content.",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=NotionCreatePageBlock.Input,
|
||||
output_schema=NotionCreatePageBlock.Output,
|
||||
disabled=not NOTION_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"parent_page_id": "00000000-0000-0000-0000-000000000000",
|
||||
"title": "Test Page",
|
||||
"content": "This is test content.",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
("page_id", "12345678-1234-1234-1234-123456789012"),
|
||||
(
|
||||
"page_url",
|
||||
"https://notion.so/Test-Page-12345678123412341234123456789012",
|
||||
),
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"create_page": lambda *args, **kwargs: (
|
||||
"12345678-1234-1234-1234-123456789012",
|
||||
"https://notion.so/Test-Page-12345678123412341234123456789012",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _markdown_to_blocks(content: str) -> List[dict]:
|
||||
"""Convert markdown content to Notion block objects."""
|
||||
if not content:
|
||||
return []
|
||||
|
||||
blocks = []
|
||||
lines = content.split("\n")
|
||||
i = 0
|
||||
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
|
||||
# Skip empty lines
|
||||
if not line.strip():
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Headings
|
||||
if line.startswith("### "):
|
||||
blocks.append(
|
||||
{
|
||||
"type": "heading_3",
|
||||
"heading_3": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": line[4:].strip()}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
elif line.startswith("## "):
|
||||
blocks.append(
|
||||
{
|
||||
"type": "heading_2",
|
||||
"heading_2": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": line[3:].strip()}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
elif line.startswith("# "):
|
||||
blocks.append(
|
||||
{
|
||||
"type": "heading_1",
|
||||
"heading_1": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": line[2:].strip()}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Bullet points
|
||||
elif line.strip().startswith("- "):
|
||||
blocks.append(
|
||||
{
|
||||
"type": "bulleted_list_item",
|
||||
"bulleted_list_item": {
|
||||
"rich_text": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": line.strip()[2:].strip()},
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Numbered list
|
||||
elif line.strip() and line.strip()[0].isdigit() and ". " in line:
|
||||
content_start = line.find(". ") + 2
|
||||
blocks.append(
|
||||
{
|
||||
"type": "numbered_list_item",
|
||||
"numbered_list_item": {
|
||||
"rich_text": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": line[content_start:].strip()},
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Code block
|
||||
elif line.strip().startswith("```"):
|
||||
code_lines = []
|
||||
language = line[3:].strip() or "plain text"
|
||||
i += 1
|
||||
while i < len(lines) and not lines[i].strip().startswith("```"):
|
||||
code_lines.append(lines[i])
|
||||
i += 1
|
||||
blocks.append(
|
||||
{
|
||||
"type": "code",
|
||||
"code": {
|
||||
"rich_text": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": "\n".join(code_lines)},
|
||||
}
|
||||
],
|
||||
"language": language,
|
||||
},
|
||||
}
|
||||
)
|
||||
# Quote
|
||||
elif line.strip().startswith("> "):
|
||||
blocks.append(
|
||||
{
|
||||
"type": "quote",
|
||||
"quote": {
|
||||
"rich_text": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": line.strip()[2:].strip()},
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Horizontal rule
|
||||
elif line.strip() in ["---", "***", "___"]:
|
||||
blocks.append({"type": "divider", "divider": {}})
|
||||
# Regular paragraph
|
||||
else:
|
||||
# Parse for basic markdown formatting
|
||||
text_content = line.strip()
|
||||
rich_text = []
|
||||
|
||||
# Simple bold/italic parsing (this is simplified)
|
||||
if "**" in text_content or "*" in text_content:
|
||||
# For now, just pass as plain text
|
||||
# A full implementation would parse and create proper annotations
|
||||
rich_text = [{"type": "text", "text": {"content": text_content}}]
|
||||
else:
|
||||
rich_text = [{"type": "text", "text": {"content": text_content}}]
|
||||
|
||||
blocks.append(
|
||||
{"type": "paragraph", "paragraph": {"rich_text": rich_text}}
|
||||
)
|
||||
|
||||
i += 1
|
||||
|
||||
return blocks
|
||||
|
||||
@staticmethod
|
||||
def _build_properties(
|
||||
title: str, additional_properties: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Build properties object for page creation."""
|
||||
properties: Dict[str, Any] = {
|
||||
"title": {"title": [{"type": "text", "text": {"content": title}}]}
|
||||
}
|
||||
|
||||
if additional_properties:
|
||||
for key, value in additional_properties.items():
|
||||
if key.lower() == "title":
|
||||
continue # Skip title as we already have it
|
||||
|
||||
# Try to intelligently map property types
|
||||
if isinstance(value, bool):
|
||||
properties[key] = {"checkbox": value}
|
||||
elif isinstance(value, (int, float)):
|
||||
properties[key] = {"number": value}
|
||||
elif isinstance(value, list):
|
||||
# Assume multi-select
|
||||
properties[key] = {
|
||||
"multi_select": [{"name": str(item)} for item in value]
|
||||
}
|
||||
elif isinstance(value, str):
|
||||
# Could be select, rich_text, or other types
|
||||
# For simplicity, try common patterns
|
||||
if key.lower() in ["status", "priority", "type", "category"]:
|
||||
properties[key] = {"select": {"name": value}}
|
||||
elif key.lower() in ["url", "link"]:
|
||||
properties[key] = {"url": value}
|
||||
elif key.lower() in ["email"]:
|
||||
properties[key] = {"email": value}
|
||||
else:
|
||||
properties[key] = {
|
||||
"rich_text": [{"type": "text", "text": {"content": value}}]
|
||||
}
|
||||
|
||||
return properties
|
||||
|
||||
@staticmethod
|
||||
async def create_page(
|
||||
credentials: OAuth2Credentials,
|
||||
title: str,
|
||||
parent_page_id: Optional[str] = None,
|
||||
parent_database_id: Optional[str] = None,
|
||||
content: Optional[str] = None,
|
||||
properties: Optional[Dict[str, Any]] = None,
|
||||
icon_emoji: Optional[str] = None,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Create a new Notion page.
|
||||
|
||||
Returns:
|
||||
Tuple of (page_id, page_url)
|
||||
"""
|
||||
if not parent_page_id and not parent_database_id:
|
||||
raise ValueError(
|
||||
"Either parent_page_id or parent_database_id must be provided"
|
||||
)
|
||||
if parent_page_id and parent_database_id:
|
||||
raise ValueError(
|
||||
"Only one of parent_page_id or parent_database_id should be provided, not both"
|
||||
)
|
||||
|
||||
client = NotionClient(credentials)
|
||||
|
||||
# Build parent object
|
||||
if parent_page_id:
|
||||
parent = {"type": "page_id", "page_id": parent_page_id}
|
||||
else:
|
||||
parent = {"type": "database_id", "database_id": parent_database_id}
|
||||
|
||||
# Build properties
|
||||
page_properties = NotionCreatePageBlock._build_properties(title, properties)
|
||||
|
||||
# Convert content to blocks if provided
|
||||
children = None
|
||||
if content:
|
||||
children = NotionCreatePageBlock._markdown_to_blocks(content)
|
||||
|
||||
# Build icon if provided
|
||||
icon = None
|
||||
if icon_emoji:
|
||||
icon = {"type": "emoji", "emoji": icon_emoji}
|
||||
|
||||
# Create the page
|
||||
result = await client.create_page(
|
||||
parent=parent, properties=page_properties, children=children, icon=icon
|
||||
)
|
||||
|
||||
page_id = result.get("id", "")
|
||||
page_url = result.get("url", "")
|
||||
|
||||
if not page_id or not page_url:
|
||||
raise ValueError("Failed to get page ID or URL from Notion response")
|
||||
|
||||
return page_id, page_url
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
page_id, page_url = await self.create_page(
|
||||
credentials,
|
||||
input_data.title,
|
||||
input_data.parent_page_id,
|
||||
input_data.parent_database_id,
|
||||
input_data.content,
|
||||
input_data.properties,
|
||||
input_data.icon_emoji,
|
||||
)
|
||||
yield "page_id", page_id
|
||||
yield "page_url", page_url
|
||||
except Exception as e:
|
||||
yield "error", str(e) if str(e) else "Unknown error"
|
||||
@@ -1,285 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import OAuth2Credentials, SchemaField
|
||||
|
||||
from ._api import NotionClient, parse_rich_text
|
||||
from ._auth import (
|
||||
NOTION_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
NotionCredentialsField,
|
||||
NotionCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class NotionReadDatabaseBlock(Block):
|
||||
"""Query a Notion database and retrieve entries with their properties."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: NotionCredentialsInput = NotionCredentialsField()
|
||||
database_id: str = SchemaField(
|
||||
description="Notion database ID. Must be accessible by the connected integration.",
|
||||
)
|
||||
filter_property: Optional[str] = SchemaField(
|
||||
description="Property name to filter by (e.g., 'Status', 'Priority')",
|
||||
default=None,
|
||||
)
|
||||
filter_value: Optional[str] = SchemaField(
|
||||
description="Value to filter for in the specified property", default=None
|
||||
)
|
||||
sort_property: Optional[str] = SchemaField(
|
||||
description="Property name to sort by", default=None
|
||||
)
|
||||
sort_direction: Optional[str] = SchemaField(
|
||||
description="Sort direction: 'ascending' or 'descending'",
|
||||
default="ascending",
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of entries to retrieve",
|
||||
default=100,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
entries: List[Dict[str, Any]] = SchemaField(
|
||||
description="List of database entries with their properties."
|
||||
)
|
||||
entry: Dict[str, Any] = SchemaField(
|
||||
description="Individual database entry (yields one per entry found)."
|
||||
)
|
||||
entry_ids: List[str] = SchemaField(
|
||||
description="List of entry IDs for batch operations."
|
||||
)
|
||||
entry_id: str = SchemaField(
|
||||
description="Individual entry ID (yields one per entry found)."
|
||||
)
|
||||
count: int = SchemaField(description="Number of entries retrieved.")
|
||||
database_title: str = SchemaField(description="Title of the database.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fcd53135-88c9-4ba3-be50-cc6936286e6c",
|
||||
description="Query a Notion database with optional filtering and sorting, returning structured entries.",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=NotionReadDatabaseBlock.Input,
|
||||
output_schema=NotionReadDatabaseBlock.Output,
|
||||
disabled=not NOTION_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"database_id": "00000000-0000-0000-0000-000000000000",
|
||||
"limit": 10,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"entries",
|
||||
[{"Name": "Test Entry", "Status": "Active", "_id": "test-123"}],
|
||||
),
|
||||
("entry_ids", ["test-123"]),
|
||||
(
|
||||
"entry",
|
||||
{"Name": "Test Entry", "Status": "Active", "_id": "test-123"},
|
||||
),
|
||||
("entry_id", "test-123"),
|
||||
("count", 1),
|
||||
("database_title", "Test Database"),
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"query_database": lambda *args, **kwargs: (
|
||||
[{"Name": "Test Entry", "Status": "Active", "_id": "test-123"}],
|
||||
1,
|
||||
"Test Database",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_property_value(prop: dict) -> Any:
|
||||
"""Parse a Notion property value into a simple Python type."""
|
||||
prop_type = prop.get("type")
|
||||
|
||||
if prop_type == "title":
|
||||
return parse_rich_text(prop.get("title", []))
|
||||
elif prop_type == "rich_text":
|
||||
return parse_rich_text(prop.get("rich_text", []))
|
||||
elif prop_type == "number":
|
||||
return prop.get("number")
|
||||
elif prop_type == "select":
|
||||
select = prop.get("select")
|
||||
return select.get("name") if select else None
|
||||
elif prop_type == "multi_select":
|
||||
return [item.get("name") for item in prop.get("multi_select", [])]
|
||||
elif prop_type == "date":
|
||||
date = prop.get("date")
|
||||
if date:
|
||||
return date.get("start")
|
||||
return None
|
||||
elif prop_type == "checkbox":
|
||||
return prop.get("checkbox", False)
|
||||
elif prop_type == "url":
|
||||
return prop.get("url")
|
||||
elif prop_type == "email":
|
||||
return prop.get("email")
|
||||
elif prop_type == "phone_number":
|
||||
return prop.get("phone_number")
|
||||
elif prop_type == "people":
|
||||
return [
|
||||
person.get("name", person.get("id"))
|
||||
for person in prop.get("people", [])
|
||||
]
|
||||
elif prop_type == "files":
|
||||
files = prop.get("files", [])
|
||||
return [
|
||||
f.get(
|
||||
"name",
|
||||
f.get("external", {}).get("url", f.get("file", {}).get("url")),
|
||||
)
|
||||
for f in files
|
||||
]
|
||||
elif prop_type == "relation":
|
||||
return [rel.get("id") for rel in prop.get("relation", [])]
|
||||
elif prop_type == "formula":
|
||||
formula = prop.get("formula", {})
|
||||
return formula.get(formula.get("type"))
|
||||
elif prop_type == "rollup":
|
||||
rollup = prop.get("rollup", {})
|
||||
return rollup.get(rollup.get("type"))
|
||||
elif prop_type == "created_time":
|
||||
return prop.get("created_time")
|
||||
elif prop_type == "created_by":
|
||||
return prop.get("created_by", {}).get(
|
||||
"name", prop.get("created_by", {}).get("id")
|
||||
)
|
||||
elif prop_type == "last_edited_time":
|
||||
return prop.get("last_edited_time")
|
||||
elif prop_type == "last_edited_by":
|
||||
return prop.get("last_edited_by", {}).get(
|
||||
"name", prop.get("last_edited_by", {}).get("id")
|
||||
)
|
||||
else:
|
||||
# Return the raw value for unknown types
|
||||
return prop
|
||||
|
||||
@staticmethod
|
||||
def _build_filter(property_name: str, value: str) -> dict:
|
||||
"""Build a simple filter object for a property."""
|
||||
# This is a simplified filter - in reality, you'd need to know the property type
|
||||
# For now, we'll try common filter types
|
||||
return {
|
||||
"or": [
|
||||
{"property": property_name, "rich_text": {"contains": value}},
|
||||
{"property": property_name, "title": {"contains": value}},
|
||||
{"property": property_name, "select": {"equals": value}},
|
||||
{"property": property_name, "multi_select": {"contains": value}},
|
||||
]
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def query_database(
|
||||
credentials: OAuth2Credentials,
|
||||
database_id: str,
|
||||
filter_property: Optional[str] = None,
|
||||
filter_value: Optional[str] = None,
|
||||
sort_property: Optional[str] = None,
|
||||
sort_direction: str = "ascending",
|
||||
limit: int = 100,
|
||||
) -> tuple[List[Dict[str, Any]], int, str]:
|
||||
"""
|
||||
Query a Notion database and parse the results.
|
||||
|
||||
Returns:
|
||||
Tuple of (entries_list, count, database_title)
|
||||
"""
|
||||
client = NotionClient(credentials)
|
||||
|
||||
# Build filter if specified
|
||||
filter_obj = None
|
||||
if filter_property and filter_value:
|
||||
filter_obj = NotionReadDatabaseBlock._build_filter(
|
||||
filter_property, filter_value
|
||||
)
|
||||
|
||||
# Build sorts if specified
|
||||
sorts = None
|
||||
if sort_property:
|
||||
sorts = [{"property": sort_property, "direction": sort_direction}]
|
||||
|
||||
# Query the database
|
||||
result = await client.query_database(
|
||||
database_id, filter_obj=filter_obj, sorts=sorts, page_size=limit
|
||||
)
|
||||
|
||||
# Parse the entries
|
||||
entries = []
|
||||
for page in result.get("results", []):
|
||||
entry = {}
|
||||
properties = page.get("properties", {})
|
||||
|
||||
for prop_name, prop_value in properties.items():
|
||||
entry[prop_name] = NotionReadDatabaseBlock._parse_property_value(
|
||||
prop_value
|
||||
)
|
||||
|
||||
# Add metadata
|
||||
entry["_id"] = page.get("id")
|
||||
entry["_url"] = page.get("url")
|
||||
entry["_created_time"] = page.get("created_time")
|
||||
entry["_last_edited_time"] = page.get("last_edited_time")
|
||||
|
||||
entries.append(entry)
|
||||
|
||||
# Get database title (we need to make a separate call for this)
|
||||
try:
|
||||
database_url = f"https://api.notion.com/v1/databases/{database_id}"
|
||||
db_response = await client.requests.get(
|
||||
database_url, headers=client.headers
|
||||
)
|
||||
if db_response.ok:
|
||||
db_data = db_response.json()
|
||||
db_title = parse_rich_text(db_data.get("title", []))
|
||||
else:
|
||||
db_title = "Unknown Database"
|
||||
except Exception:
|
||||
db_title = "Unknown Database"
|
||||
|
||||
return entries, len(entries), db_title
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
entries, count, db_title = await self.query_database(
|
||||
credentials,
|
||||
input_data.database_id,
|
||||
input_data.filter_property,
|
||||
input_data.filter_value,
|
||||
input_data.sort_property,
|
||||
input_data.sort_direction or "ascending",
|
||||
input_data.limit,
|
||||
)
|
||||
# Yield the complete list for batch operations
|
||||
yield "entries", entries
|
||||
|
||||
# Extract and yield IDs as a list for batch operations
|
||||
entry_ids = [entry["_id"] for entry in entries if "_id" in entry]
|
||||
yield "entry_ids", entry_ids
|
||||
|
||||
# Yield each individual entry and its ID for single connections
|
||||
for entry in entries:
|
||||
yield "entry", entry
|
||||
if "_id" in entry:
|
||||
yield "entry_id", entry["_id"]
|
||||
|
||||
yield "count", count
|
||||
yield "database_title", db_title
|
||||
except Exception as e:
|
||||
yield "error", str(e) if str(e) else "Unknown error"
|
||||
@@ -1,64 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import OAuth2Credentials, SchemaField
|
||||
|
||||
from ._api import NotionClient
|
||||
from ._auth import (
|
||||
NOTION_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
NotionCredentialsField,
|
||||
NotionCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class NotionReadPageBlock(Block):
|
||||
"""Read a Notion page by ID and return its raw JSON."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: NotionCredentialsInput = NotionCredentialsField()
|
||||
page_id: str = SchemaField(
|
||||
description="Notion page ID. Must be accessible by the connected integration. You can get this from the page URL notion.so/A-Page-586edd711467478da59fe3ce29a1ffab would be 586edd711467478da59fe35e29a1ffab",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
page: dict = SchemaField(description="Raw Notion page JSON.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5246cc1d-34b7-452b-8fc5-3fb25fd8f542",
|
||||
description="Read a Notion page by its ID and return its raw JSON.",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=NotionReadPageBlock.Input,
|
||||
output_schema=NotionReadPageBlock.Output,
|
||||
disabled=not NOTION_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"page_id": "00000000-0000-0000-0000-000000000000",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[("page", dict)],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"get_page": lambda *args, **kwargs: {"object": "page", "id": "mocked"}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_page(credentials: OAuth2Credentials, page_id: str) -> dict:
|
||||
client = NotionClient(credentials)
|
||||
return await client.get_page(page_id)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
page = await self.get_page(credentials, input_data.page_id)
|
||||
yield "page", page
|
||||
except Exception as e:
|
||||
yield "error", str(e) if str(e) else "Unknown error"
|
||||
@@ -1,109 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import OAuth2Credentials, SchemaField
|
||||
|
||||
from ._api import NotionClient, blocks_to_markdown, extract_page_title
|
||||
from ._auth import (
|
||||
NOTION_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
NotionCredentialsField,
|
||||
NotionCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class NotionReadPageMarkdownBlock(Block):
|
||||
"""Read a Notion page and convert it to clean Markdown format."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: NotionCredentialsInput = NotionCredentialsField()
|
||||
page_id: str = SchemaField(
|
||||
description="Notion page ID. Must be accessible by the connected integration. You can get this from the page URL notion.so/A-Page-586edd711467478da59fe35e29a1ffab would be 586edd711467478da59fe35e29a1ffab",
|
||||
)
|
||||
include_title: bool = SchemaField(
|
||||
description="Whether to include the page title as a header in the markdown",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
markdown: str = SchemaField(description="Page content in Markdown format.")
|
||||
title: str = SchemaField(description="Page title.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d1312c4d-fae2-4e70-893d-f4d07cce1d4e",
|
||||
description="Read a Notion page and convert it to Markdown format with proper formatting for headings, lists, links, and rich text.",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=NotionReadPageMarkdownBlock.Input,
|
||||
output_schema=NotionReadPageMarkdownBlock.Output,
|
||||
disabled=not NOTION_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"page_id": "00000000-0000-0000-0000-000000000000",
|
||||
"include_title": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
("markdown", "# Test Page\n\nThis is test content."),
|
||||
("title", "Test Page"),
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"get_page_markdown": lambda *args, **kwargs: (
|
||||
"# Test Page\n\nThis is test content.",
|
||||
"Test Page",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_page_markdown(
|
||||
credentials: OAuth2Credentials, page_id: str, include_title: bool = True
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Get a Notion page and convert it to markdown.
|
||||
|
||||
Args:
|
||||
credentials: OAuth2 credentials for Notion.
|
||||
page_id: The ID of the page to fetch.
|
||||
include_title: Whether to include the page title in the markdown.
|
||||
|
||||
Returns:
|
||||
Tuple of (markdown_content, title)
|
||||
"""
|
||||
client = NotionClient(credentials)
|
||||
|
||||
# Get page metadata
|
||||
page = await client.get_page(page_id)
|
||||
title = extract_page_title(page)
|
||||
|
||||
# Get all blocks from the page
|
||||
blocks = await client.get_blocks(page_id, recursive=True)
|
||||
|
||||
# Convert blocks to markdown
|
||||
content_markdown = blocks_to_markdown(blocks)
|
||||
|
||||
# Combine title and content if requested
|
||||
if include_title and title:
|
||||
full_markdown = f"# {title}\n\n{content_markdown}"
|
||||
else:
|
||||
full_markdown = content_markdown
|
||||
|
||||
return full_markdown, title
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
markdown, title = await self.get_page_markdown(
|
||||
credentials, input_data.page_id, input_data.include_title
|
||||
)
|
||||
yield "markdown", markdown
|
||||
yield "title", title
|
||||
except Exception as e:
|
||||
yield "error", str(e) if str(e) else "Unknown error"
|
||||
@@ -1,225 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import OAuth2Credentials, SchemaField
|
||||
|
||||
from ._api import NotionClient, extract_page_title, parse_rich_text
|
||||
from ._auth import (
|
||||
NOTION_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
NotionCredentialsField,
|
||||
NotionCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class NotionSearchResult(BaseModel):
|
||||
"""Typed model for Notion search results."""
|
||||
|
||||
id: str
|
||||
type: str # 'page' or 'database'
|
||||
title: str
|
||||
url: str
|
||||
created_time: Optional[str] = None
|
||||
last_edited_time: Optional[str] = None
|
||||
parent_type: Optional[str] = None # 'page', 'database', or 'workspace'
|
||||
parent_id: Optional[str] = None
|
||||
icon: Optional[str] = None # emoji icon if present
|
||||
is_inline: Optional[bool] = None # for databases only
|
||||
|
||||
|
||||
class NotionSearchBlock(Block):
|
||||
"""Search across your Notion workspace for pages and databases."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: NotionCredentialsInput = NotionCredentialsField()
|
||||
query: str = SchemaField(
|
||||
description="Search query text. Leave empty to get all accessible pages/databases.",
|
||||
default="",
|
||||
)
|
||||
filter_type: Optional[str] = SchemaField(
|
||||
description="Filter results by type: 'page' or 'database'. Leave empty for both.",
|
||||
default=None,
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of results to return", default=20, ge=1, le=100
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
results: List[NotionSearchResult] = SchemaField(
|
||||
description="List of search results with title, type, URL, and metadata."
|
||||
)
|
||||
result: NotionSearchResult = SchemaField(
|
||||
description="Individual search result (yields one per result found)."
|
||||
)
|
||||
result_ids: List[str] = SchemaField(
|
||||
description="List of IDs from search results for batch operations."
|
||||
)
|
||||
count: int = SchemaField(description="Number of results found.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="313515dd-9848-46ea-9cd6-3c627c892c56",
|
||||
description="Search your Notion workspace for pages and databases by text query.",
|
||||
categories={BlockCategory.PRODUCTIVITY, BlockCategory.SEARCH},
|
||||
input_schema=NotionSearchBlock.Input,
|
||||
output_schema=NotionSearchBlock.Output,
|
||||
disabled=not NOTION_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"query": "project",
|
||||
"limit": 5,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"results",
|
||||
[
|
||||
NotionSearchResult(
|
||||
id="123",
|
||||
type="page",
|
||||
title="Project Plan",
|
||||
url="https://notion.so/Project-Plan-123",
|
||||
)
|
||||
],
|
||||
),
|
||||
("result_ids", ["123"]),
|
||||
(
|
||||
"result",
|
||||
NotionSearchResult(
|
||||
id="123",
|
||||
type="page",
|
||||
title="Project Plan",
|
||||
url="https://notion.so/Project-Plan-123",
|
||||
),
|
||||
),
|
||||
("count", 1),
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"search_workspace": lambda *args, **kwargs: (
|
||||
[
|
||||
NotionSearchResult(
|
||||
id="123",
|
||||
type="page",
|
||||
title="Project Plan",
|
||||
url="https://notion.so/Project-Plan-123",
|
||||
)
|
||||
],
|
||||
1,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def search_workspace(
|
||||
credentials: OAuth2Credentials,
|
||||
query: str = "",
|
||||
filter_type: Optional[str] = None,
|
||||
limit: int = 20,
|
||||
) -> tuple[List[NotionSearchResult], int]:
|
||||
"""
|
||||
Search the Notion workspace.
|
||||
|
||||
Returns:
|
||||
Tuple of (results_list, count)
|
||||
"""
|
||||
client = NotionClient(credentials)
|
||||
|
||||
# Build filter if type is specified
|
||||
filter_obj = None
|
||||
if filter_type:
|
||||
filter_obj = {"property": "object", "value": filter_type}
|
||||
|
||||
# Execute search
|
||||
response = await client.search(
|
||||
query=query, filter_obj=filter_obj, page_size=limit
|
||||
)
|
||||
|
||||
# Parse results
|
||||
results = []
|
||||
for item in response.get("results", []):
|
||||
result_data = {
|
||||
"id": item.get("id", ""),
|
||||
"type": item.get("object", ""),
|
||||
"url": item.get("url", ""),
|
||||
"created_time": item.get("created_time"),
|
||||
"last_edited_time": item.get("last_edited_time"),
|
||||
"title": "", # Will be set below
|
||||
}
|
||||
|
||||
# Extract title based on type
|
||||
if item.get("object") == "page":
|
||||
# For pages, get the title from properties
|
||||
result_data["title"] = extract_page_title(item)
|
||||
|
||||
# Add parent info
|
||||
parent = item.get("parent", {})
|
||||
if parent.get("type") == "page_id":
|
||||
result_data["parent_type"] = "page"
|
||||
result_data["parent_id"] = parent.get("page_id")
|
||||
elif parent.get("type") == "database_id":
|
||||
result_data["parent_type"] = "database"
|
||||
result_data["parent_id"] = parent.get("database_id")
|
||||
elif parent.get("type") == "workspace":
|
||||
result_data["parent_type"] = "workspace"
|
||||
|
||||
# Add icon if present
|
||||
icon = item.get("icon")
|
||||
if icon and icon.get("type") == "emoji":
|
||||
result_data["icon"] = icon.get("emoji")
|
||||
|
||||
elif item.get("object") == "database":
|
||||
# For databases, get title from the title array
|
||||
result_data["title"] = parse_rich_text(item.get("title", []))
|
||||
|
||||
# Add database-specific metadata
|
||||
result_data["is_inline"] = item.get("is_inline", False)
|
||||
|
||||
# Add parent info
|
||||
parent = item.get("parent", {})
|
||||
if parent.get("type") == "page_id":
|
||||
result_data["parent_type"] = "page"
|
||||
result_data["parent_id"] = parent.get("page_id")
|
||||
elif parent.get("type") == "workspace":
|
||||
result_data["parent_type"] = "workspace"
|
||||
|
||||
# Add icon if present
|
||||
icon = item.get("icon")
|
||||
if icon and icon.get("type") == "emoji":
|
||||
result_data["icon"] = icon.get("emoji")
|
||||
|
||||
results.append(NotionSearchResult(**result_data))
|
||||
|
||||
return results, len(results)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
results, count = await self.search_workspace(
|
||||
credentials, input_data.query, input_data.filter_type, input_data.limit
|
||||
)
|
||||
|
||||
# Yield the complete list for batch operations
|
||||
yield "results", results
|
||||
|
||||
# Extract and yield IDs as a list for batch operations
|
||||
result_ids = [r.id for r in results]
|
||||
yield "result_ids", result_ids
|
||||
|
||||
# Yield each individual result for single connections
|
||||
for result in results:
|
||||
yield "result", result
|
||||
|
||||
yield "count", count
|
||||
except Exception as e:
|
||||
yield "error", str(e) if str(e) else "Unknown error"
|
||||
@@ -523,6 +523,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
json_format=False,
|
||||
max_tokens=input_data.max_tokens,
|
||||
tools=tool_functions,
|
||||
ollama_host=input_data.ollama_host,
|
||||
|
||||
@@ -30,6 +30,7 @@ class TestLLMStatsTracking:
|
||||
credentials=llm.TEST_CREDENTIALS,
|
||||
llm_model=llm.LlmModel.GPT4O,
|
||||
prompt=[{"role": "user", "content": "Hello"}],
|
||||
json_format=False,
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
@@ -41,8 +42,6 @@ class TestLLMStatsTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_ai_structured_response_block_tracks_stats(self):
|
||||
"""Test that AIStructuredResponseGeneratorBlock correctly tracks stats."""
|
||||
from unittest.mock import patch
|
||||
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
@@ -52,7 +51,7 @@ class TestLLMStatsTracking:
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
|
||||
response='{"key1": "value1", "key2": "value2"}',
|
||||
tool_calls=None,
|
||||
prompt_tokens=15,
|
||||
completion_tokens=25,
|
||||
@@ -70,12 +69,10 @@ class TestLLMStatsTracking:
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
# Mock secrets.token_hex to return consistent ID
|
||||
with patch("secrets.token_hex", return_value="test123456"):
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Check stats
|
||||
assert block.execution_stats.input_token_count == 15
|
||||
@@ -146,7 +143,7 @@ class TestLLMStatsTracking:
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='<json_output id="test123456">{"wrong": "format"}</json_output>',
|
||||
response='{"wrong": "format"}',
|
||||
tool_calls=None,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=15,
|
||||
@@ -157,7 +154,7 @@ class TestLLMStatsTracking:
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
|
||||
response='{"key1": "value1", "key2": "value2"}',
|
||||
tool_calls=None,
|
||||
prompt_tokens=20,
|
||||
completion_tokens=25,
|
||||
@@ -176,12 +173,10 @@ class TestLLMStatsTracking:
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
# Mock secrets.token_hex to return consistent ID
|
||||
with patch("secrets.token_hex", return_value="test123456"):
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Check stats - should accumulate both calls
|
||||
# For 2 attempts: attempt 1 (failed) + attempt 2 (success) = 2 total
|
||||
@@ -274,8 +269,7 @@ class TestLLMStatsTracking:
|
||||
mock_response.choices = [
|
||||
MagicMock(
|
||||
message=MagicMock(
|
||||
content='<json_output id="test123456">{"summary": "Test chunk summary"}</json_output>',
|
||||
tool_calls=None,
|
||||
content='{"summary": "Test chunk summary"}', tool_calls=None
|
||||
)
|
||||
)
|
||||
]
|
||||
@@ -283,7 +277,7 @@ class TestLLMStatsTracking:
|
||||
mock_response.choices = [
|
||||
MagicMock(
|
||||
message=MagicMock(
|
||||
content='<json_output id="test123456">{"final_summary": "Test final summary"}</json_output>',
|
||||
content='{"final_summary": "Test final summary"}',
|
||||
tool_calls=None,
|
||||
)
|
||||
)
|
||||
@@ -304,13 +298,11 @@ class TestLLMStatsTracking:
|
||||
max_tokens=1000, # Large enough to avoid chunking
|
||||
)
|
||||
|
||||
# Mock secrets.token_hex to return consistent ID
|
||||
with patch("secrets.token_hex", return_value="test123456"):
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
print(f"Actual calls made: {call_count}")
|
||||
print(f"Block stats: {block.execution_stats}")
|
||||
@@ -465,7 +457,7 @@ class TestLLMStatsTracking:
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='<json_output id="test123456">{"result": "test"}</json_output>',
|
||||
response='{"result": "test"}',
|
||||
tool_calls=None,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=20,
|
||||
@@ -484,12 +476,10 @@ class TestLLMStatsTracking:
|
||||
|
||||
# Run the block
|
||||
outputs = {}
|
||||
# Mock secrets.token_hex to return consistent ID
|
||||
with patch("secrets.token_hex", return_value="test123456"):
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Block finished - now grab and assert stats
|
||||
assert block.execution_stats is not None
|
||||
|
||||
@@ -35,19 +35,20 @@ async def execute_graph(
|
||||
logger.info("Input data: %s", input_data)
|
||||
|
||||
# --- Test adding new executions --- #
|
||||
graph_exec = await agent_server.test_execute_graph(
|
||||
response = await agent_server.test_execute_graph(
|
||||
user_id=test_user.id,
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
node_input=input_data,
|
||||
)
|
||||
logger.info("Created execution with ID: %s", graph_exec.id)
|
||||
graph_exec_id = response.graph_exec_id
|
||||
logger.info("Created execution with ID: %s", graph_exec_id)
|
||||
|
||||
# Execution queue should be empty
|
||||
logger.info("Waiting for execution to complete...")
|
||||
result = await wait_execution(test_user.id, graph_exec.id, 30)
|
||||
result = await wait_execution(test_user.id, graph_exec_id, 30)
|
||||
logger.info("Execution completed with %d results", len(result))
|
||||
return graph_exec.id
|
||||
return graph_exec_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
|
||||
@@ -172,11 +172,6 @@ class FillTextTemplateBlock(Block):
|
||||
format: str = SchemaField(
|
||||
description="Template to format the text using `values`. Use Jinja2 syntax."
|
||||
)
|
||||
escape_html: bool = SchemaField(
|
||||
default=False,
|
||||
advanced=True,
|
||||
description="Whether to escape special characters in the inserted values to be HTML-safe. Enable for HTML output, disable for plain text.",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: str = SchemaField(description="Formatted text")
|
||||
@@ -210,7 +205,6 @@ class FillTextTemplateBlock(Block):
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
formatter = text.TextFormatter(autoescape=input_data.escape_html)
|
||||
yield "output", formatter.format_string(input_data.format, input_data.values)
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
from .graph import NodeModel
|
||||
from .integrations import Webhook # noqa: F401
|
||||
|
||||
# Resolve Webhook forward references
|
||||
# Resolve Webhook <- NodeModel forward reference
|
||||
NodeModel.model_rebuild()
|
||||
LibraryAgentPreset.model_rebuild()
|
||||
|
||||
@@ -1,31 +1,57 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from autogpt_libs.api_key.key_manager import APIKeyManager
|
||||
from prisma.enums import APIKeyPermission, APIKeyStatus
|
||||
from prisma.errors import PrismaError
|
||||
from prisma.models import APIKey as PrismaAPIKey
|
||||
from prisma.types import APIKeyWhereUniqueInput
|
||||
from pydantic import BaseModel, Field
|
||||
from prisma.types import (
|
||||
APIKeyCreateInput,
|
||||
APIKeyUpdateInput,
|
||||
APIKeyWhereInput,
|
||||
APIKeyWhereUniqueInput,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
from backend.data.db import BaseDbModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
keysmith = APIKeySmith()
|
||||
|
||||
|
||||
class APIKeyInfo(BaseModel):
|
||||
id: str
|
||||
# Some basic exceptions
|
||||
class APIKeyError(Exception):
|
||||
"""Base exception for API key operations"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyNotFoundError(APIKeyError):
|
||||
"""Raised when an API key is not found"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyPermissionError(APIKeyError):
|
||||
"""Raised when there are permission issues with API key operations"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyValidationError(APIKeyError):
|
||||
"""Raised when API key validation fails"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKey(BaseDbModel):
|
||||
name: str
|
||||
head: str = Field(
|
||||
description=f"The first {APIKeySmith.HEAD_LENGTH} characters of the key"
|
||||
)
|
||||
tail: str = Field(
|
||||
description=f"The last {APIKeySmith.TAIL_LENGTH} characters of the key"
|
||||
)
|
||||
status: APIKeyStatus
|
||||
permissions: list[APIKeyPermission]
|
||||
prefix: str
|
||||
key: str
|
||||
status: APIKeyStatus = APIKeyStatus.ACTIVE
|
||||
permissions: List[APIKeyPermission]
|
||||
postfix: str
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime] = None
|
||||
revoked_at: Optional[datetime] = None
|
||||
@@ -34,211 +60,266 @@ class APIKeyInfo(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def from_db(api_key: PrismaAPIKey):
|
||||
return APIKeyInfo(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
head=api_key.head,
|
||||
tail=api_key.tail,
|
||||
status=APIKeyStatus(api_key.status),
|
||||
permissions=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
created_at=api_key.createdAt,
|
||||
last_used_at=api_key.lastUsedAt,
|
||||
revoked_at=api_key.revokedAt,
|
||||
description=api_key.description,
|
||||
user_id=api_key.userId,
|
||||
)
|
||||
try:
|
||||
return APIKey(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
prefix=api_key.prefix,
|
||||
postfix=api_key.postfix,
|
||||
key=api_key.key,
|
||||
status=APIKeyStatus(api_key.status),
|
||||
permissions=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
created_at=api_key.createdAt,
|
||||
last_used_at=api_key.lastUsedAt,
|
||||
revoked_at=api_key.revokedAt,
|
||||
description=api_key.description,
|
||||
user_id=api_key.userId,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating APIKey from db: {str(e)}")
|
||||
raise APIKeyError(f"Failed to create API key object: {str(e)}")
|
||||
|
||||
|
||||
class APIKeyInfoWithHash(APIKeyInfo):
|
||||
hash: str
|
||||
salt: str | None = None # None for legacy keys
|
||||
|
||||
def match(self, plaintext_key: str) -> bool:
|
||||
"""Returns whether the given key matches this API key object."""
|
||||
return keysmith.verify_key(plaintext_key, self.hash, self.salt)
|
||||
class APIKeyWithoutHash(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
prefix: str
|
||||
postfix: str
|
||||
status: APIKeyStatus
|
||||
permissions: List[APIKeyPermission]
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime]
|
||||
revoked_at: Optional[datetime]
|
||||
description: Optional[str]
|
||||
user_id: str
|
||||
|
||||
@staticmethod
|
||||
def from_db(api_key: PrismaAPIKey):
|
||||
return APIKeyInfoWithHash(
|
||||
**APIKeyInfo.from_db(api_key).model_dump(),
|
||||
hash=api_key.hash,
|
||||
salt=api_key.salt,
|
||||
)
|
||||
|
||||
def without_hash(self) -> APIKeyInfo:
|
||||
return APIKeyInfo(**self.model_dump(exclude={"hash", "salt"}))
|
||||
try:
|
||||
return APIKeyWithoutHash(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
prefix=api_key.prefix,
|
||||
postfix=api_key.postfix,
|
||||
status=APIKeyStatus(api_key.status),
|
||||
permissions=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
created_at=api_key.createdAt,
|
||||
last_used_at=api_key.lastUsedAt,
|
||||
revoked_at=api_key.revokedAt,
|
||||
description=api_key.description,
|
||||
user_id=api_key.userId,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating APIKeyWithoutHash from db: {str(e)}")
|
||||
raise APIKeyError(f"Failed to create API key object: {str(e)}")
|
||||
|
||||
|
||||
async def create_api_key(
|
||||
async def generate_api_key(
|
||||
name: str,
|
||||
user_id: str,
|
||||
permissions: list[APIKeyPermission],
|
||||
permissions: List[APIKeyPermission],
|
||||
description: Optional[str] = None,
|
||||
) -> tuple[APIKeyInfo, str]:
|
||||
) -> tuple[APIKeyWithoutHash, str]:
|
||||
"""
|
||||
Generate a new API key and store it in the database.
|
||||
Returns the API key object (without hash) and the plain text key.
|
||||
"""
|
||||
generated_key = keysmith.generate_key()
|
||||
try:
|
||||
api_manager = APIKeyManager()
|
||||
key = api_manager.generate_api_key()
|
||||
|
||||
saved_key_obj = await PrismaAPIKey.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"head": generated_key.head,
|
||||
"tail": generated_key.tail,
|
||||
"hash": generated_key.hash,
|
||||
"salt": generated_key.salt,
|
||||
"permissions": [p for p in permissions],
|
||||
"description": description,
|
||||
"userId": user_id,
|
||||
}
|
||||
)
|
||||
api_key = await PrismaAPIKey.prisma().create(
|
||||
data=APIKeyCreateInput(
|
||||
id=str(uuid.uuid4()),
|
||||
name=name,
|
||||
prefix=key.prefix,
|
||||
postfix=key.postfix,
|
||||
key=key.hash,
|
||||
permissions=[p for p in permissions],
|
||||
description=description,
|
||||
userId=user_id,
|
||||
)
|
||||
)
|
||||
|
||||
return APIKeyInfo.from_db(saved_key_obj), generated_key.key
|
||||
api_key_without_hash = APIKeyWithoutHash.from_db(api_key)
|
||||
return api_key_without_hash, key.raw
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while generating API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to generate API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while generating API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to generate API key: {str(e)}")
|
||||
|
||||
|
||||
async def get_active_api_keys_by_head(head: str) -> list[APIKeyInfoWithHash]:
|
||||
results = await PrismaAPIKey.prisma().find_many(
|
||||
where={"head": head, "status": APIKeyStatus.ACTIVE}
|
||||
)
|
||||
return [APIKeyInfoWithHash.from_db(key) for key in results]
|
||||
|
||||
|
||||
async def validate_api_key(plaintext_key: str) -> Optional[APIKeyInfo]:
|
||||
async def validate_api_key(plain_text_key: str) -> Optional[APIKey]:
|
||||
"""
|
||||
Validate an API key and return the API key object if valid and active.
|
||||
Validate an API key and return the API key object if valid.
|
||||
"""
|
||||
try:
|
||||
if not plaintext_key.startswith(APIKeySmith.PREFIX):
|
||||
if not plain_text_key.startswith(APIKeyManager.PREFIX):
|
||||
logger.warning("Invalid API key format")
|
||||
return None
|
||||
|
||||
head = plaintext_key[: APIKeySmith.HEAD_LENGTH]
|
||||
potential_matches = await get_active_api_keys_by_head(head)
|
||||
prefix = plain_text_key[: APIKeyManager.PREFIX_LENGTH]
|
||||
api_manager = APIKeyManager()
|
||||
|
||||
matched_api_key = next(
|
||||
(pm for pm in potential_matches if pm.match(plaintext_key)),
|
||||
None,
|
||||
api_key = await PrismaAPIKey.prisma().find_first(
|
||||
where=APIKeyWhereInput(prefix=prefix, status=(APIKeyStatus.ACTIVE))
|
||||
)
|
||||
if not matched_api_key:
|
||||
# API key not found or invalid
|
||||
|
||||
if not api_key:
|
||||
logger.warning(f"No active API key found with prefix {prefix}")
|
||||
return None
|
||||
|
||||
# Migrate legacy keys to secure format on successful validation
|
||||
if matched_api_key.salt is None:
|
||||
matched_api_key = await _migrate_key_to_secure_hash(
|
||||
plaintext_key, matched_api_key
|
||||
is_valid = api_manager.verify_api_key(plain_text_key, api_key.key)
|
||||
if not is_valid:
|
||||
logger.warning("API key verification failed")
|
||||
return None
|
||||
|
||||
return APIKey.from_db(api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating API key: {str(e)}")
|
||||
raise APIKeyValidationError(f"Failed to validate API key: {str(e)}")
|
||||
|
||||
|
||||
async def revoke_api_key(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if not api_key:
|
||||
raise APIKeyNotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise APIKeyPermissionError(
|
||||
"You do not have permission to revoke this API key."
|
||||
)
|
||||
|
||||
return matched_api_key.without_hash()
|
||||
except Exception as e:
|
||||
logger.error(f"Error while validating API key: {e}")
|
||||
raise RuntimeError("Failed to validate API key") from e
|
||||
|
||||
|
||||
async def _migrate_key_to_secure_hash(
|
||||
plaintext_key: str, key_obj: APIKeyInfoWithHash
|
||||
) -> APIKeyInfoWithHash:
|
||||
"""Replace the SHA256 hash of a legacy API key with a salted Scrypt hash."""
|
||||
try:
|
||||
new_hash, new_salt = keysmith.hash_key(plaintext_key)
|
||||
await PrismaAPIKey.prisma().update(
|
||||
where={"id": key_obj.id}, data={"hash": new_hash, "salt": new_salt}
|
||||
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=where_clause,
|
||||
data=APIKeyUpdateInput(
|
||||
status=APIKeyStatus.REVOKED, revokedAt=datetime.now(timezone.utc)
|
||||
),
|
||||
)
|
||||
logger.info(f"Migrated legacy API key #{key_obj.id} to secure format")
|
||||
# Update the API key object with new values for return
|
||||
key_obj.hash = new_hash
|
||||
key_obj.salt = new_salt
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to migrate legacy API key #{key_obj.id}: {e}")
|
||||
|
||||
return key_obj
|
||||
|
||||
|
||||
async def revoke_api_key(key_id: str, user_id: str) -> APIKeyInfo:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if not api_key:
|
||||
raise NotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise NotAuthorizedError("You do not have permission to revoke this API key.")
|
||||
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where={"id": key_id},
|
||||
data={
|
||||
"status": APIKeyStatus.REVOKED,
|
||||
"revokedAt": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
if not updated_api_key:
|
||||
raise NotFoundError(f"API key #{key_id} vanished while trying to revoke.")
|
||||
|
||||
return APIKeyInfo.from_db(updated_api_key)
|
||||
|
||||
|
||||
async def list_user_api_keys(user_id: str) -> list[APIKeyInfo]:
|
||||
api_keys = await PrismaAPIKey.prisma().find_many(
|
||||
where={"userId": user_id}, order={"createdAt": "desc"}
|
||||
)
|
||||
|
||||
return [APIKeyInfo.from_db(key) for key in api_keys]
|
||||
|
||||
|
||||
async def suspend_api_key(key_id: str, user_id: str) -> APIKeyInfo:
|
||||
selector: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where=selector)
|
||||
|
||||
if not api_key:
|
||||
raise NotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise NotAuthorizedError("You do not have permission to suspend this API key.")
|
||||
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=selector, data={"status": APIKeyStatus.SUSPENDED}
|
||||
)
|
||||
if not updated_api_key:
|
||||
raise NotFoundError(f"API key #{key_id} vanished while trying to suspend.")
|
||||
|
||||
return APIKeyInfo.from_db(updated_api_key)
|
||||
|
||||
|
||||
def has_permission(api_key: APIKeyInfo, required_permission: APIKeyPermission) -> bool:
|
||||
return required_permission in api_key.permissions
|
||||
|
||||
|
||||
async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyInfo]:
|
||||
api_key = await PrismaAPIKey.prisma().find_first(
|
||||
where={"id": key_id, "userId": user_id}
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
if updated_api_key:
|
||||
return APIKeyWithoutHash.from_db(updated_api_key)
|
||||
return None
|
||||
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
|
||||
raise e
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while revoking API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to revoke API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while revoking API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to revoke API key: {str(e)}")
|
||||
|
||||
return APIKeyInfo.from_db(api_key)
|
||||
|
||||
async def list_user_api_keys(user_id: str) -> List[APIKeyWithoutHash]:
|
||||
try:
|
||||
where_clause: APIKeyWhereInput = {"userId": user_id}
|
||||
|
||||
api_keys = await PrismaAPIKey.prisma().find_many(
|
||||
where=where_clause, order={"createdAt": "desc"}
|
||||
)
|
||||
|
||||
return [APIKeyWithoutHash.from_db(key) for key in api_keys]
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while listing API keys: {str(e)}")
|
||||
raise APIKeyError(f"Failed to list API keys: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while listing API keys: {str(e)}")
|
||||
raise APIKeyError(f"Failed to list API keys: {str(e)}")
|
||||
|
||||
|
||||
async def suspend_api_key(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if not api_key:
|
||||
raise APIKeyNotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise APIKeyPermissionError(
|
||||
"You do not have permission to suspend this API key."
|
||||
)
|
||||
|
||||
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=where_clause,
|
||||
data=APIKeyUpdateInput(status=APIKeyStatus.SUSPENDED),
|
||||
)
|
||||
|
||||
if updated_api_key:
|
||||
return APIKeyWithoutHash.from_db(updated_api_key)
|
||||
return None
|
||||
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
|
||||
raise e
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while suspending API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to suspend API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while suspending API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to suspend API key: {str(e)}")
|
||||
|
||||
|
||||
def has_permission(api_key: APIKey, required_permission: APIKeyPermission) -> bool:
|
||||
try:
|
||||
return required_permission in api_key.permissions
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking API key permissions: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_first(
|
||||
where=APIKeyWhereInput(id=key_id, userId=user_id)
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
return APIKeyWithoutHash.from_db(api_key)
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while getting API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to get API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while getting API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to get API key: {str(e)}")
|
||||
|
||||
|
||||
async def update_api_key_permissions(
|
||||
key_id: str, user_id: str, permissions: list[APIKeyPermission]
|
||||
) -> APIKeyInfo:
|
||||
key_id: str, user_id: str, permissions: List[APIKeyPermission]
|
||||
) -> Optional[APIKeyWithoutHash]:
|
||||
"""
|
||||
Update the permissions of an API key.
|
||||
"""
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if api_key is None:
|
||||
raise NotFoundError("No such API key found.")
|
||||
if api_key is None:
|
||||
raise APIKeyNotFoundError("No such API key found.")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise NotAuthorizedError("You do not have permission to update this API key.")
|
||||
if api_key.userId != user_id:
|
||||
raise APIKeyPermissionError(
|
||||
"You do not have permission to update this API key."
|
||||
)
|
||||
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where={"id": key_id},
|
||||
data={"permissions": permissions},
|
||||
)
|
||||
if not updated_api_key:
|
||||
raise NotFoundError(f"API key #{key_id} vanished while trying to update.")
|
||||
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=where_clause,
|
||||
data=APIKeyUpdateInput(permissions=permissions),
|
||||
)
|
||||
|
||||
return APIKeyInfo.from_db(updated_api_key)
|
||||
if updated_api_key:
|
||||
return APIKeyWithoutHash.from_db(updated_api_key)
|
||||
return None
|
||||
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
|
||||
raise e
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while updating API key permissions: {str(e)}")
|
||||
raise APIKeyError(f"Failed to update API key permissions: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while updating API key permissions: {str(e)}")
|
||||
raise APIKeyError(f"Failed to update API key permissions: {str(e)}")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
@@ -7,7 +8,6 @@ from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Generic,
|
||||
Optional,
|
||||
@@ -20,7 +20,6 @@ from typing import (
|
||||
|
||||
import jsonref
|
||||
import jsonschema
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from prisma.models import AgentBlock
|
||||
from prisma.types import AgentBlockCreateInput
|
||||
from pydantic import BaseModel
|
||||
@@ -45,10 +44,9 @@ if TYPE_CHECKING:
|
||||
|
||||
app_config = Config()
|
||||
|
||||
BlockData = tuple[str, Any] # Input & Output data should be a tuple of (name, data).
|
||||
BlockInput = dict[str, Any] # Input: 1 input pin consumes 1 data.
|
||||
BlockOutputEntry = tuple[str, Any] # Output data should be a tuple of (name, value).
|
||||
BlockOutput = AsyncGen[BlockOutputEntry, None] # Output: 1 output pin produces n data.
|
||||
BlockTestOutput = BlockOutputEntry | tuple[str, Callable[[Any], bool]]
|
||||
BlockOutput = AsyncGen[BlockData, None] # Output: 1 output pin produces n data.
|
||||
CompletedBlockOutput = dict[str, list[Any]] # Completed stream, collected as a dict.
|
||||
|
||||
|
||||
@@ -91,45 +89,6 @@ class BlockCategory(Enum):
|
||||
return {"category": self.name, "description": self.value}
|
||||
|
||||
|
||||
class BlockCostType(str, Enum):
|
||||
RUN = "run" # cost X credits per run
|
||||
BYTE = "byte" # cost X credits per byte
|
||||
SECOND = "second" # cost X credits per second
|
||||
|
||||
|
||||
class BlockCost(BaseModel):
|
||||
cost_amount: int
|
||||
cost_filter: BlockInput
|
||||
cost_type: BlockCostType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cost_amount: int,
|
||||
cost_type: BlockCostType = BlockCostType.RUN,
|
||||
cost_filter: Optional[BlockInput] = None,
|
||||
**data: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
cost_amount=cost_amount,
|
||||
cost_filter=cost_filter or {},
|
||||
cost_type=cost_type,
|
||||
**data,
|
||||
)
|
||||
|
||||
|
||||
class BlockInfo(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
inputSchema: dict[str, Any]
|
||||
outputSchema: dict[str, Any]
|
||||
costs: list[BlockCost]
|
||||
description: str
|
||||
categories: list[dict[str, str]]
|
||||
contributors: list[dict[str, Any]]
|
||||
staticOutput: bool
|
||||
uiType: str
|
||||
|
||||
|
||||
class BlockSchema(BaseModel):
|
||||
cached_jsonschema: ClassVar[dict[str, Any]]
|
||||
|
||||
@@ -347,7 +306,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
input_schema: Type[BlockSchemaInputType] = EmptySchema,
|
||||
output_schema: Type[BlockSchemaOutputType] = EmptySchema,
|
||||
test_input: BlockInput | list[BlockInput] | None = None,
|
||||
test_output: BlockTestOutput | list[BlockTestOutput] | None = None,
|
||||
test_output: BlockData | list[BlockData] | None = None,
|
||||
test_mock: dict[str, Any] | None = None,
|
||||
test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
|
||||
disabled: bool = False,
|
||||
@@ -493,24 +452,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
"uiType": self.block_type.value,
|
||||
}
|
||||
|
||||
def get_info(self) -> BlockInfo:
|
||||
from backend.data.credit import get_block_cost
|
||||
|
||||
return BlockInfo(
|
||||
id=self.id,
|
||||
name=self.name,
|
||||
inputSchema=self.input_schema.jsonschema(),
|
||||
outputSchema=self.output_schema.jsonschema(),
|
||||
costs=get_block_cost(self),
|
||||
description=self.description,
|
||||
categories=[category.dict() for category in self.categories],
|
||||
contributors=[
|
||||
contributor.model_dump() for contributor in self.contributors
|
||||
],
|
||||
staticOutput=self.static_output,
|
||||
uiType=self.block_type.value,
|
||||
)
|
||||
|
||||
async def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise ValueError(
|
||||
@@ -722,7 +663,7 @@ def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
|
||||
return cls() if cls else None
|
||||
|
||||
|
||||
@cached()
|
||||
@functools.cache
|
||||
def get_webhook_block_ids() -> Sequence[str]:
|
||||
return [
|
||||
id
|
||||
@@ -731,7 +672,7 @@ def get_webhook_block_ids() -> Sequence[str]:
|
||||
]
|
||||
|
||||
|
||||
@cached()
|
||||
@functools.cache
|
||||
def get_io_block_ids() -> Sequence[str]:
|
||||
return [
|
||||
id
|
||||
|
||||
@@ -29,7 +29,8 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
from backend.data.block import Block, BlockCost, BlockCostType
|
||||
from backend.data.block import Block
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
from backend.integrations.credentials_store import (
|
||||
aiml_api_credentials,
|
||||
anthropic_credentials,
|
||||
|
||||
32
autogpt_platform/backend/backend/data/cost.py
Normal file
32
autogpt_platform/backend/backend/data/cost.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import BlockInput
|
||||
|
||||
|
||||
class BlockCostType(str, Enum):
|
||||
RUN = "run" # cost X credits per run
|
||||
BYTE = "byte" # cost X credits per byte
|
||||
SECOND = "second" # cost X credits per second
|
||||
|
||||
|
||||
class BlockCost(BaseModel):
|
||||
cost_amount: int
|
||||
cost_filter: BlockInput
|
||||
cost_type: BlockCostType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cost_amount: int,
|
||||
cost_type: BlockCostType = BlockCostType.RUN,
|
||||
cost_filter: Optional[BlockInput] = None,
|
||||
**data: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
cost_amount=cost_amount,
|
||||
cost_filter=cost_filter or {},
|
||||
cost_type=cost_type,
|
||||
**data,
|
||||
)
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import Any, cast
|
||||
|
||||
import stripe
|
||||
from prisma import Json
|
||||
@@ -23,6 +23,7 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.data.model import (
|
||||
AutoTopUpConfig,
|
||||
RefundRequest,
|
||||
@@ -40,9 +41,6 @@ from backend.util.models import Pagination
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.block import Block, BlockCost
|
||||
|
||||
settings = Settings()
|
||||
stripe.api_key = settings.secrets.stripe_api_key
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -999,14 +997,10 @@ def get_user_credit_model() -> UserCreditBase:
|
||||
return UserCredit()
|
||||
|
||||
|
||||
def get_block_costs() -> dict[str, list["BlockCost"]]:
|
||||
def get_block_costs() -> dict[str, list[BlockCost]]:
|
||||
return {block().id: costs for block, costs in BLOCK_COSTS.items()}
|
||||
|
||||
|
||||
def get_block_cost(block: "Block") -> list["BlockCost"]:
|
||||
return BLOCK_COSTS.get(block.__class__, [])
|
||||
|
||||
|
||||
async def get_stripe_customer_id(user_id: str) -> str:
|
||||
user = await get_user_by_id(user_id)
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ async def disconnect():
|
||||
|
||||
|
||||
# Transaction timeout constant (in milliseconds)
|
||||
TRANSACTION_TIMEOUT = 30000 # 30 seconds - Increased from 15s to prevent timeout errors during graph creation under load
|
||||
TRANSACTION_TIMEOUT = 15000 # 15 seconds - Increased from 5s to prevent timeout errors
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
||||
@@ -11,14 +11,11 @@ from typing import (
|
||||
Generator,
|
||||
Generic,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from prisma import Json
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
from prisma.models import (
|
||||
AgentGraphExecution,
|
||||
@@ -27,6 +24,7 @@ from prisma.models import (
|
||||
AgentNodeExecutionKeyValueData,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionCreateInput,
|
||||
AgentGraphExecutionUpdateManyMutationInput,
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionCreateInput,
|
||||
@@ -62,7 +60,7 @@ from .includes import (
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
graph_execution_include,
|
||||
)
|
||||
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
|
||||
from .model import GraphExecutionStats, NodeExecutionStats
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -89,33 +87,6 @@ class BlockErrorStats(BaseModel):
|
||||
|
||||
|
||||
ExecutionStatus = AgentExecutionStatus
|
||||
NodeInputMask = Mapping[str, JsonValue]
|
||||
NodesInputMasks = Mapping[str, NodeInputMask]
|
||||
|
||||
# dest: source
|
||||
VALID_STATUS_TRANSITIONS = {
|
||||
ExecutionStatus.QUEUED: [
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
],
|
||||
ExecutionStatus.RUNNING: [
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.TERMINATED, # For resuming halted execution
|
||||
],
|
||||
ExecutionStatus.COMPLETED: [
|
||||
ExecutionStatus.RUNNING,
|
||||
],
|
||||
ExecutionStatus.FAILED: [
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
],
|
||||
ExecutionStatus.TERMINATED: [
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class GraphExecutionMeta(BaseDbModel):
|
||||
@@ -123,15 +94,10 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
user_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
inputs: Optional[BlockInput] # no default -> required in the OpenAPI spec
|
||||
credential_inputs: Optional[dict[str, CredentialsMetaInput]]
|
||||
nodes_input_masks: Optional[dict[str, BlockInput]]
|
||||
preset_id: Optional[str]
|
||||
preset_id: Optional[str] = None
|
||||
status: ExecutionStatus
|
||||
started_at: datetime
|
||||
ended_at: datetime
|
||||
is_shared: bool = False
|
||||
share_token: Optional[str] = None
|
||||
|
||||
class Stats(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
@@ -213,18 +179,6 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
user_id=_graph_exec.userId,
|
||||
graph_id=_graph_exec.agentGraphId,
|
||||
graph_version=_graph_exec.agentGraphVersion,
|
||||
inputs=cast(BlockInput | None, _graph_exec.inputs),
|
||||
credential_inputs=(
|
||||
{
|
||||
name: CredentialsMetaInput.model_validate(cmi)
|
||||
for name, cmi in cast(dict, _graph_exec.credentialInputs).items()
|
||||
}
|
||||
if _graph_exec.credentialInputs
|
||||
else None
|
||||
),
|
||||
nodes_input_masks=cast(
|
||||
dict[str, BlockInput] | None, _graph_exec.nodesInputMasks
|
||||
),
|
||||
preset_id=_graph_exec.agentPresetId,
|
||||
status=ExecutionStatus(_graph_exec.executionStatus),
|
||||
started_at=start_time,
|
||||
@@ -248,13 +202,11 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
if stats
|
||||
else None
|
||||
),
|
||||
is_shared=_graph_exec.isShared,
|
||||
share_token=_graph_exec.shareToken,
|
||||
)
|
||||
|
||||
|
||||
class GraphExecution(GraphExecutionMeta):
|
||||
inputs: BlockInput # type: ignore - incompatible override is intentional
|
||||
inputs: BlockInput
|
||||
outputs: CompletedBlockOutput
|
||||
|
||||
@staticmethod
|
||||
@@ -274,18 +226,15 @@ class GraphExecution(GraphExecutionMeta):
|
||||
)
|
||||
|
||||
inputs = {
|
||||
**(
|
||||
graph_exec.inputs
|
||||
or {
|
||||
# fallback: extract inputs from Agent Input Blocks
|
||||
exec.input_data["name"]: exec.input_data.get("value")
|
||||
for exec in complete_node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type == BlockType.INPUT
|
||||
)
|
||||
}
|
||||
),
|
||||
**{
|
||||
# inputs from Agent Input Blocks
|
||||
exec.input_data["name"]: exec.input_data.get("value")
|
||||
for exec in complete_node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type == BlockType.INPUT
|
||||
)
|
||||
},
|
||||
**{
|
||||
# input from webhook-triggered block
|
||||
"payload": exec.input_data["payload"]
|
||||
@@ -303,13 +252,14 @@ class GraphExecution(GraphExecutionMeta):
|
||||
if (
|
||||
block := get_block(exec.block_id)
|
||||
) and block.block_type == BlockType.OUTPUT:
|
||||
outputs[exec.input_data["name"]].append(exec.input_data.get("value"))
|
||||
outputs[exec.input_data["name"]].append(
|
||||
exec.input_data.get("value", None)
|
||||
)
|
||||
|
||||
return GraphExecution(
|
||||
**{
|
||||
field_name: getattr(graph_exec, field_name)
|
||||
for field_name in GraphExecutionMeta.model_fields
|
||||
if field_name != "inputs"
|
||||
},
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
@@ -342,17 +292,13 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
node_executions=node_executions,
|
||||
)
|
||||
|
||||
def to_graph_execution_entry(
|
||||
self,
|
||||
user_context: "UserContext",
|
||||
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
):
|
||||
def to_graph_execution_entry(self, user_context: "UserContext"):
|
||||
return GraphExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
graph_id=self.graph_id,
|
||||
graph_version=self.graph_version or 0,
|
||||
graph_exec_id=self.id,
|
||||
nodes_input_masks=compiled_nodes_input_masks,
|
||||
nodes_input_masks={}, # FIXME: store credentials on AgentGraphExecution
|
||||
user_context=user_context,
|
||||
)
|
||||
|
||||
@@ -369,9 +315,10 @@ class NodeExecutionResult(BaseModel):
|
||||
input_data: BlockInput
|
||||
output_data: CompletedBlockOutput
|
||||
add_time: datetime
|
||||
queue_time: datetime | None
|
||||
start_time: datetime | None
|
||||
end_time: datetime | None
|
||||
queue_time: datetime | None = None
|
||||
start_time: datetime | None = None
|
||||
end_time: datetime | None = None
|
||||
stats: NodeExecutionStats | None = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):
|
||||
@@ -389,7 +336,7 @@ class NodeExecutionResult(BaseModel):
|
||||
else:
|
||||
input_data: BlockInput = defaultdict()
|
||||
for data in _node_exec.Input or []:
|
||||
input_data[data.name] = type_utils.convert(data.data, JsonValue)
|
||||
input_data[data.name] = type_utils.convert(data.data, type[Any])
|
||||
|
||||
output_data: CompletedBlockOutput = defaultdict(list)
|
||||
|
||||
@@ -398,7 +345,7 @@ class NodeExecutionResult(BaseModel):
|
||||
output_data[name].extend(messages)
|
||||
else:
|
||||
for data in _node_exec.Output or []:
|
||||
output_data[data.name].append(type_utils.convert(data.data, JsonValue))
|
||||
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
|
||||
|
||||
graph_execution: AgentGraphExecution | None = _node_exec.GraphExecution
|
||||
if graph_execution:
|
||||
@@ -423,6 +370,7 @@ class NodeExecutionResult(BaseModel):
|
||||
queue_time=_node_exec.queuedTime,
|
||||
start_time=_node_exec.startedTime,
|
||||
end_time=_node_exec.endedTime,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
def to_node_execution_entry(
|
||||
@@ -593,12 +541,9 @@ async def get_graph_execution(
|
||||
async def create_graph_execution(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
starting_nodes_input: list[tuple[str, BlockInput]], # list[(node_id, BlockInput)]
|
||||
inputs: Mapping[str, JsonValue],
|
||||
starting_nodes_input: list[tuple[str, BlockInput]],
|
||||
user_id: str,
|
||||
preset_id: Optional[str] = None,
|
||||
credential_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
preset_id: str | None = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Create a new AgentGraphExecution record.
|
||||
@@ -606,18 +551,11 @@ async def create_graph_execution(
|
||||
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
|
||||
"""
|
||||
result = await AgentGraphExecution.prisma().create(
|
||||
data={
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": graph_version,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"inputs": SafeJson(inputs),
|
||||
"credentialInputs": (
|
||||
SafeJson(credential_inputs) if credential_inputs else Json({})
|
||||
),
|
||||
"nodesInputMasks": (
|
||||
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
|
||||
),
|
||||
"NodeExecutions": {
|
||||
data=AgentGraphExecutionCreateInput(
|
||||
agentGraphId=graph_id,
|
||||
agentGraphVersion=graph_version,
|
||||
executionStatus=ExecutionStatus.QUEUED,
|
||||
NodeExecutions={
|
||||
"create": [
|
||||
AgentNodeExecutionCreateInput(
|
||||
agentNodeId=node_id,
|
||||
@@ -633,9 +571,9 @@ async def create_graph_execution(
|
||||
for node_id, node_input in starting_nodes_input
|
||||
]
|
||||
},
|
||||
"userId": user_id,
|
||||
"agentPresetId": preset_id,
|
||||
},
|
||||
userId=user_id,
|
||||
agentPresetId=preset_id,
|
||||
),
|
||||
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
)
|
||||
|
||||
@@ -646,7 +584,7 @@ async def upsert_execution_input(
|
||||
node_id: str,
|
||||
graph_exec_id: str,
|
||||
input_name: str,
|
||||
input_data: JsonValue,
|
||||
input_data: Any,
|
||||
node_exec_id: str | None = None,
|
||||
) -> tuple[str, BlockInput]:
|
||||
"""
|
||||
@@ -695,7 +633,7 @@ async def upsert_execution_input(
|
||||
)
|
||||
return existing_execution.id, {
|
||||
**{
|
||||
input_data.name: type_utils.convert(input_data.data, JsonValue)
|
||||
input_data.name: type_utils.convert(input_data.data, type[Any])
|
||||
for input_data in existing_execution.Input or []
|
||||
},
|
||||
input_name: input_data,
|
||||
@@ -718,6 +656,42 @@ async def upsert_execution_input(
|
||||
)
|
||||
|
||||
|
||||
async def create_node_execution(
|
||||
node_exec_id: str,
|
||||
node_id: str,
|
||||
graph_exec_id: str,
|
||||
input_name: str,
|
||||
input_data: Any,
|
||||
) -> None:
|
||||
"""Create a new node execution with the first input."""
|
||||
json_input_data = SafeJson(input_data)
|
||||
await AgentNodeExecution.prisma().create(
|
||||
data=AgentNodeExecutionCreateInput(
|
||||
id=node_exec_id,
|
||||
agentNodeId=node_id,
|
||||
agentGraphExecutionId=graph_exec_id,
|
||||
executionStatus=ExecutionStatus.INCOMPLETE,
|
||||
Input={"create": {"name": input_name, "data": json_input_data}},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def add_input_to_node_execution(
|
||||
node_exec_id: str,
|
||||
input_name: str,
|
||||
input_data: Any,
|
||||
) -> None:
|
||||
"""Add an input to an existing node execution."""
|
||||
json_input_data = SafeJson(input_data)
|
||||
await AgentNodeExecutionInputOutput.prisma().create(
|
||||
data=AgentNodeExecutionInputOutputCreateInput(
|
||||
name=input_name,
|
||||
data=json_input_data,
|
||||
referencedByInputExecId=node_exec_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def upsert_execution_output(
|
||||
node_exec_id: str,
|
||||
output_name: str,
|
||||
@@ -756,11 +730,6 @@ async def update_graph_execution_stats(
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> GraphExecution | None:
|
||||
if not status and not stats:
|
||||
raise ValueError(
|
||||
f"Must provide either status or stats to update for execution {graph_exec_id}"
|
||||
)
|
||||
|
||||
update_data: AgentGraphExecutionUpdateManyMutationInput = {}
|
||||
|
||||
if stats:
|
||||
@@ -772,25 +741,20 @@ async def update_graph_execution_stats(
|
||||
if status:
|
||||
update_data["executionStatus"] = status
|
||||
|
||||
where_clause: AgentGraphExecutionWhereInput = {"id": graph_exec_id}
|
||||
|
||||
if status:
|
||||
if allowed_from := VALID_STATUS_TRANSITIONS.get(status, []):
|
||||
# Add OR clause to check if current status is one of the allowed source statuses
|
||||
where_clause["AND"] = [
|
||||
{"id": graph_exec_id},
|
||||
{"OR": [{"executionStatus": s} for s in allowed_from]},
|
||||
]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Status {status} cannot be set via update for execution {graph_exec_id}. "
|
||||
f"This status can only be set at creation or is not a valid target status."
|
||||
)
|
||||
|
||||
await AgentGraphExecution.prisma().update_many(
|
||||
where=where_clause,
|
||||
updated_count = await AgentGraphExecution.prisma().update_many(
|
||||
where={
|
||||
"id": graph_exec_id,
|
||||
"OR": [
|
||||
{"executionStatus": ExecutionStatus.RUNNING},
|
||||
{"executionStatus": ExecutionStatus.QUEUED},
|
||||
# Terminated graph can be resumed.
|
||||
{"executionStatus": ExecutionStatus.TERMINATED},
|
||||
],
|
||||
},
|
||||
data=update_data,
|
||||
)
|
||||
if updated_count == 0:
|
||||
return None
|
||||
|
||||
graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise(
|
||||
where={"id": graph_exec_id},
|
||||
@@ -798,7 +762,6 @@ async def update_graph_execution_stats(
|
||||
[*get_io_block_ids(), *get_webhook_block_ids()]
|
||||
),
|
||||
)
|
||||
|
||||
return GraphExecution.from_db(graph_exec)
|
||||
|
||||
|
||||
@@ -963,7 +926,7 @@ class GraphExecutionEntry(BaseModel):
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None
|
||||
user_context: UserContext
|
||||
|
||||
|
||||
@@ -1025,18 +988,6 @@ class NodeExecutionEvent(NodeExecutionResult):
|
||||
)
|
||||
|
||||
|
||||
class SharedExecutionResponse(BaseModel):
|
||||
"""Public-safe response for shared executions"""
|
||||
|
||||
id: str
|
||||
graph_name: str
|
||||
graph_description: Optional[str]
|
||||
status: ExecutionStatus
|
||||
created_at: datetime
|
||||
outputs: CompletedBlockOutput # Only the final outputs, no intermediate data
|
||||
# Deliberately exclude: user_id, inputs, credentials, node details
|
||||
|
||||
|
||||
ExecutionEvent = Annotated[
|
||||
GraphExecutionEvent | NodeExecutionEvent, Field(discriminator="event_type")
|
||||
]
|
||||
@@ -1214,98 +1165,3 @@ async def get_block_error_stats(
|
||||
)
|
||||
for row in result
|
||||
]
|
||||
|
||||
|
||||
async def update_graph_execution_share_status(
|
||||
execution_id: str,
|
||||
user_id: str,
|
||||
is_shared: bool,
|
||||
share_token: str | None,
|
||||
shared_at: datetime | None,
|
||||
) -> None:
|
||||
"""Update the sharing status of a graph execution."""
|
||||
await AgentGraphExecution.prisma().update(
|
||||
where={"id": execution_id},
|
||||
data={
|
||||
"isShared": is_shared,
|
||||
"shareToken": share_token,
|
||||
"sharedAt": shared_at,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def get_graph_execution_by_share_token(
|
||||
share_token: str,
|
||||
) -> SharedExecutionResponse | None:
|
||||
"""Get a shared execution with limited public-safe data."""
|
||||
execution = await AgentGraphExecution.prisma().find_first(
|
||||
where={
|
||||
"shareToken": share_token,
|
||||
"isShared": True,
|
||||
"isDeleted": False,
|
||||
},
|
||||
include={
|
||||
"AgentGraph": True,
|
||||
"NodeExecutions": {
|
||||
"include": {
|
||||
"Output": True,
|
||||
"Node": {
|
||||
"include": {
|
||||
"AgentBlock": True,
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if not execution:
|
||||
return None
|
||||
|
||||
# Extract outputs from OUTPUT blocks only (consistent with GraphExecution.from_db)
|
||||
outputs: CompletedBlockOutput = defaultdict(list)
|
||||
if execution.NodeExecutions:
|
||||
for node_exec in execution.NodeExecutions:
|
||||
if node_exec.Node and node_exec.Node.agentBlockId:
|
||||
# Get the block definition to check its type
|
||||
block = get_block(node_exec.Node.agentBlockId)
|
||||
|
||||
if block and block.block_type == BlockType.OUTPUT:
|
||||
# For OUTPUT blocks, the data is stored in executionData or Input
|
||||
# The executionData contains the structured input with 'name' and 'value' fields
|
||||
if hasattr(node_exec, "executionData") and node_exec.executionData:
|
||||
exec_data = type_utils.convert(
|
||||
node_exec.executionData, dict[str, Any]
|
||||
)
|
||||
if "name" in exec_data:
|
||||
name = exec_data["name"]
|
||||
value = exec_data.get("value")
|
||||
outputs[name].append(value)
|
||||
elif node_exec.Input:
|
||||
# Build input_data from Input relation
|
||||
input_data = {}
|
||||
for data in node_exec.Input:
|
||||
if data.name and data.data is not None:
|
||||
input_data[data.name] = type_utils.convert(
|
||||
data.data, JsonValue
|
||||
)
|
||||
|
||||
if "name" in input_data:
|
||||
name = input_data["name"]
|
||||
value = input_data.get("value")
|
||||
outputs[name].append(value)
|
||||
|
||||
return SharedExecutionResponse(
|
||||
id=execution.id,
|
||||
graph_name=(
|
||||
execution.AgentGraph.name
|
||||
if (execution.AgentGraph and execution.AgentGraph.name)
|
||||
else "Untitled Agent"
|
||||
),
|
||||
graph_description=(
|
||||
execution.AgentGraph.description if execution.AgentGraph else None
|
||||
),
|
||||
status=ExecutionStatus(execution.executionStatus),
|
||||
created_at=execution.createdAt,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
from prisma.enums import SubmissionStatus
|
||||
@@ -13,7 +12,7 @@ from prisma.types import (
|
||||
AgentNodeLinkCreateInput,
|
||||
StoreListingVersionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from pydantic import Field, JsonValue, create_model
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
@@ -29,14 +28,12 @@ from backend.data.model import (
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import type as type_utils
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
|
||||
from .db import BaseDbModel, query_raw_with_schema, transaction
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .execution import NodesInputMasks
|
||||
from .integrations import Webhook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -162,8 +159,6 @@ class BaseGraph(BaseDbModel):
|
||||
is_active: bool = True
|
||||
name: str
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
recommended_schedule_cron: str | None = None
|
||||
nodes: list[Node] = []
|
||||
links: list[Link] = []
|
||||
forked_from_id: str | None = None
|
||||
@@ -210,35 +205,6 @@ class BaseGraph(BaseDbModel):
|
||||
None,
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def trigger_setup_info(self) -> "GraphTriggerInfo | None":
|
||||
if not (
|
||||
self.webhook_input_node
|
||||
and (trigger_block := self.webhook_input_node.block).webhook_config
|
||||
):
|
||||
return None
|
||||
|
||||
return GraphTriggerInfo(
|
||||
provider=trigger_block.webhook_config.provider,
|
||||
config_schema={
|
||||
**(json_schema := trigger_block.input_schema.jsonschema()),
|
||||
"properties": {
|
||||
pn: sub_schema
|
||||
for pn, sub_schema in json_schema["properties"].items()
|
||||
if not is_credentials_field_name(pn)
|
||||
},
|
||||
"required": [
|
||||
pn
|
||||
for pn in json_schema.get("required", [])
|
||||
if not is_credentials_field_name(pn)
|
||||
],
|
||||
},
|
||||
credentials_input_name=next(
|
||||
iter(trigger_block.input_schema.get_credentials_fields()), None
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _generate_schema(
|
||||
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
|
||||
@@ -272,14 +238,6 @@ class BaseGraph(BaseDbModel):
|
||||
}
|
||||
|
||||
|
||||
class GraphTriggerInfo(BaseModel):
|
||||
provider: ProviderName
|
||||
config_schema: dict[str, Any] = Field(
|
||||
description="Input schema for the trigger block"
|
||||
)
|
||||
credentials_input_name: Optional[str]
|
||||
|
||||
|
||||
class Graph(BaseGraph):
|
||||
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs
|
||||
|
||||
@@ -384,8 +342,6 @@ class GraphModel(Graph):
|
||||
user_id: str
|
||||
nodes: list[NodeModel] = [] # type: ignore
|
||||
|
||||
created_at: datetime
|
||||
|
||||
@property
|
||||
def starting_nodes(self) -> list[NodeModel]:
|
||||
outbound_nodes = {link.sink_id for link in self.links}
|
||||
@@ -398,10 +354,6 @@ class GraphModel(Graph):
|
||||
if node.id not in outbound_nodes or node.id in input_nodes
|
||||
]
|
||||
|
||||
@property
|
||||
def webhook_input_node(self) -> NodeModel | None: # type: ignore
|
||||
return cast(NodeModel, super().webhook_input_node)
|
||||
|
||||
def meta(self) -> "GraphMeta":
|
||||
"""
|
||||
Returns a GraphMeta object with metadata about the graph.
|
||||
@@ -462,7 +414,7 @@ class GraphModel(Graph):
|
||||
def validate_graph(
|
||||
self,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional["NodesInputMasks"] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
):
|
||||
"""
|
||||
Validate graph structure and raise `ValueError` on issues.
|
||||
@@ -476,7 +428,7 @@ class GraphModel(Graph):
|
||||
def _validate_graph(
|
||||
graph: BaseGraph,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional["NodesInputMasks"] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> None:
|
||||
errors = GraphModel._validate_graph_get_errors(
|
||||
graph, for_run, nodes_input_masks
|
||||
@@ -490,7 +442,7 @@ class GraphModel(Graph):
|
||||
def validate_graph_get_errors(
|
||||
self,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional["NodesInputMasks"] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Validate graph and return structured errors per node.
|
||||
@@ -512,7 +464,7 @@ class GraphModel(Graph):
|
||||
def _validate_graph_get_errors(
|
||||
graph: BaseGraph,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional["NodesInputMasks"] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Validate graph and return structured errors per node.
|
||||
@@ -703,12 +655,9 @@ class GraphModel(Graph):
|
||||
version=graph.version,
|
||||
forked_from_id=graph.forkedFromId,
|
||||
forked_from_version=graph.forkedFromVersion,
|
||||
created_at=graph.createdAt,
|
||||
is_active=graph.isActive,
|
||||
name=graph.name or "",
|
||||
description=graph.description or "",
|
||||
instructions=graph.instructions,
|
||||
recommended_schedule_cron=graph.recommendedScheduleCron,
|
||||
nodes=[NodeModel.from_db(node, for_export) for node in graph.Nodes or []],
|
||||
links=list(
|
||||
{
|
||||
@@ -747,13 +696,6 @@ class GraphMeta(Graph):
|
||||
return GraphMeta(**graph.model_dump())
|
||||
|
||||
|
||||
class GraphsPaginated(BaseModel):
|
||||
"""Response schema for paginated graphs."""
|
||||
|
||||
graphs: list[GraphMeta]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
# --------------------- CRUD functions --------------------- #
|
||||
|
||||
|
||||
@@ -782,42 +724,31 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
|
||||
return NodeModel.from_db(node)
|
||||
|
||||
|
||||
async def list_graphs_paginated(
|
||||
async def list_graphs(
|
||||
user_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 25,
|
||||
filter_by: Literal["active"] | None = "active",
|
||||
) -> GraphsPaginated:
|
||||
) -> list[GraphMeta]:
|
||||
"""
|
||||
Retrieves paginated graph metadata objects.
|
||||
Retrieves graph metadata objects.
|
||||
Default behaviour is to get all currently active graphs.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user that owns the graphs.
|
||||
page: Page number (1-based).
|
||||
page_size: Number of graphs per page.
|
||||
filter_by: An optional filter to either select graphs.
|
||||
user_id: The ID of the user that owns the graph.
|
||||
|
||||
Returns:
|
||||
GraphsPaginated: Paginated list of graph metadata.
|
||||
list[GraphMeta]: A list of objects representing the retrieved graphs.
|
||||
"""
|
||||
where_clause: AgentGraphWhereInput = {"userId": user_id}
|
||||
|
||||
if filter_by == "active":
|
||||
where_clause["isActive"] = True
|
||||
|
||||
# Get total count
|
||||
total_count = await AgentGraph.prisma().count(where=where_clause)
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
# Get paginated results
|
||||
offset = (page - 1) * page_size
|
||||
graphs = await AgentGraph.prisma().find_many(
|
||||
where=where_clause,
|
||||
distinct=["id"],
|
||||
order={"version": "desc"},
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
skip=offset,
|
||||
take=page_size,
|
||||
)
|
||||
|
||||
graph_models: list[GraphMeta] = []
|
||||
@@ -831,15 +762,7 @@ async def list_graphs_paginated(
|
||||
logger.error(f"Error processing graph {graph.id}: {e}")
|
||||
continue
|
||||
|
||||
return GraphsPaginated(
|
||||
graphs=graph_models,
|
||||
pagination=Pagination(
|
||||
total_items=total_count,
|
||||
total_pages=total_pages,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
return graph_models
|
||||
|
||||
|
||||
async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph | None:
|
||||
@@ -1122,7 +1045,6 @@ async def __create_graph(tx, graph: Graph, user_id: str):
|
||||
version=graph.version,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
recommendedScheduleCron=graph.recommended_schedule_cron,
|
||||
isActive=graph.is_active,
|
||||
userId=user_id,
|
||||
forkedFromId=graph.forked_from_id,
|
||||
@@ -1181,7 +1103,6 @@ def make_graph_model(creatable_graph: Graph, user_id: str) -> GraphModel:
|
||||
return GraphModel(
|
||||
**creatable_graph.model_dump(exclude={"nodes"}),
|
||||
user_id=user_id,
|
||||
created_at=datetime.now(tz=timezone.utc),
|
||||
nodes=[
|
||||
NodeModel(
|
||||
**creatable_node.model_dump(),
|
||||
|
||||
@@ -59,15 +59,9 @@ def graph_execution_include(
|
||||
}
|
||||
|
||||
|
||||
AGENT_PRESET_INCLUDE: prisma.types.AgentPresetInclude = {
|
||||
"InputPresets": True,
|
||||
"Webhook": True,
|
||||
}
|
||||
|
||||
|
||||
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE},
|
||||
"AgentPresets": {"include": AGENT_PRESET_INCLUDE},
|
||||
"AgentPresets": {"include": {"InputPresets": True}},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import logging
|
||||
import os
|
||||
from functools import cache
|
||||
|
||||
from autogpt_libs.utils.cache import cached, thread_cached
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from dotenv import load_dotenv
|
||||
from redis import Redis
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
@@ -12,7 +13,7 @@ load_dotenv()
|
||||
|
||||
HOST = os.getenv("REDIS_HOST", "localhost")
|
||||
PORT = int(os.getenv("REDIS_PORT", "6379"))
|
||||
PASSWORD = os.getenv("REDIS_PASSWORD", None)
|
||||
PASSWORD = os.getenv("REDIS_PASSWORD", "password")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,7 +35,7 @@ def disconnect():
|
||||
get_redis().close()
|
||||
|
||||
|
||||
@cached()
|
||||
@cache
|
||||
def get_redis() -> Redis:
|
||||
return connect()
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Optional, cast
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from autogpt_libs.auth.models import DEFAULT_USER_ID
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from fastapi import HTTPException
|
||||
from prisma.enums import NotificationType
|
||||
from prisma.models import User as PrismaUser
|
||||
@@ -24,11 +23,7 @@ from backend.util.settings import Settings
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
# Cache decorator alias for consistent user lookup caching
|
||||
cache_user_lookup = cached(maxsize=1000, ttl_seconds=300)
|
||||
|
||||
|
||||
@cache_user_lookup
|
||||
async def get_or_create_user(user_data: dict) -> User:
|
||||
try:
|
||||
user_id = user_data.get("sub")
|
||||
@@ -54,7 +49,6 @@ async def get_or_create_user(user_data: dict) -> User:
|
||||
raise DatabaseError(f"Failed to get or create user {user_data}: {e}") from e
|
||||
|
||||
|
||||
@cache_user_lookup
|
||||
async def get_user_by_id(user_id: str) -> User:
|
||||
user = await prisma.user.find_unique(where={"id": user_id})
|
||||
if not user:
|
||||
@@ -70,7 +64,6 @@ async def get_user_email_by_id(user_id: str) -> Optional[str]:
|
||||
raise DatabaseError(f"Failed to get user email for user {user_id}: {e}") from e
|
||||
|
||||
|
||||
@cache_user_lookup
|
||||
async def get_user_by_email(email: str) -> Optional[User]:
|
||||
try:
|
||||
user = await prisma.user.find_unique(where={"email": email})
|
||||
@@ -81,17 +74,7 @@ async def get_user_by_email(email: str) -> Optional[User]:
|
||||
|
||||
async def update_user_email(user_id: str, email: str):
|
||||
try:
|
||||
# Get old email first for cache invalidation
|
||||
old_user = await prisma.user.find_unique(where={"id": user_id})
|
||||
old_email = old_user.email if old_user else None
|
||||
|
||||
await prisma.user.update(where={"id": user_id}, data={"email": email})
|
||||
|
||||
# Selectively invalidate only the specific user entries
|
||||
get_user_by_id.cache_delete(user_id)
|
||||
if old_email:
|
||||
get_user_by_email.cache_delete(old_email)
|
||||
get_user_by_email.cache_delete(email)
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to update user email for user {user_id}: {e}"
|
||||
@@ -131,8 +114,6 @@ async def update_user_integrations(user_id: str, data: UserIntegrations):
|
||||
where={"id": user_id},
|
||||
data={"integrations": encrypted_data},
|
||||
)
|
||||
# Invalidate cache for this user
|
||||
get_user_by_id.cache_delete(user_id)
|
||||
|
||||
|
||||
async def migrate_and_encrypt_user_integrations():
|
||||
@@ -304,10 +285,6 @@ async def update_user_notification_preference(
|
||||
)
|
||||
if not user:
|
||||
raise ValueError(f"User not found with ID: {user_id}")
|
||||
|
||||
# Invalidate cache for this user since notification preferences are part of user data
|
||||
get_user_by_id.cache_delete(user_id)
|
||||
|
||||
preferences: dict[NotificationType, bool] = {
|
||||
NotificationType.AGENT_RUN: user.notifyOnAgentRun or True,
|
||||
NotificationType.ZERO_BALANCE: user.notifyOnZeroBalance or True,
|
||||
@@ -346,8 +323,6 @@ async def set_user_email_verification(user_id: str, verified: bool) -> None:
|
||||
where={"id": user_id},
|
||||
data={"emailVerified": verified},
|
||||
)
|
||||
# Invalidate cache for this user
|
||||
get_user_by_id.cache_delete(user_id)
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to set email verification status for user {user_id}: {e}"
|
||||
@@ -432,10 +407,6 @@ async def update_user_timezone(user_id: str, timezone: str) -> User:
|
||||
)
|
||||
if not user:
|
||||
raise ValueError(f"User not found with ID: {user_id}")
|
||||
|
||||
# Invalidate cache for this user
|
||||
get_user_by_id.cache_delete(user_id)
|
||||
|
||||
return User.from_db(user)
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to update timezone for user {user_id}: {e}") from e
|
||||
|
||||
@@ -107,7 +107,7 @@ async def generate_activity_status_for_execution(
|
||||
# Check if we have OpenAI API key
|
||||
try:
|
||||
settings = Settings()
|
||||
if not settings.secrets.openai_internal_api_key:
|
||||
if not settings.secrets.openai_api_key:
|
||||
logger.debug(
|
||||
"OpenAI API key not configured, skipping activity status generation"
|
||||
)
|
||||
@@ -115,7 +115,7 @@ async def generate_activity_status_for_execution(
|
||||
|
||||
# Get all node executions for this graph execution
|
||||
node_executions = await db_client.get_node_executions(
|
||||
graph_exec_id, include_exec_data=True
|
||||
graph_exec_id=graph_exec_id, include_exec_data=True
|
||||
)
|
||||
|
||||
# Get graph metadata and full graph structure for name, description, and links
|
||||
@@ -187,7 +187,7 @@ async def generate_activity_status_for_execution(
|
||||
credentials = APIKeyCredentials(
|
||||
id="openai",
|
||||
provider="openai",
|
||||
api_key=SecretStr(settings.secrets.openai_internal_api_key),
|
||||
api_key=SecretStr(settings.secrets.openai_api_key),
|
||||
title="System OpenAI",
|
||||
)
|
||||
|
||||
@@ -423,6 +423,7 @@ async def _call_llm_direct(
|
||||
credentials=credentials,
|
||||
llm_model=LlmModel.GPT4O_MINI,
|
||||
prompt=prompt,
|
||||
json_format=False,
|
||||
max_tokens=150,
|
||||
compress_prompt_to_fit=True,
|
||||
)
|
||||
|
||||
@@ -468,7 +468,7 @@ class TestGenerateActivityStatusForExecution:
|
||||
):
|
||||
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
mock_llm.return_value = (
|
||||
"I analyzed your data and provided the requested insights."
|
||||
)
|
||||
@@ -520,7 +520,7 @@ class TestGenerateActivityStatusForExecution:
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=True,
|
||||
):
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = ""
|
||||
mock_settings.return_value.secrets.openai_api_key = ""
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
@@ -546,7 +546,7 @@ class TestGenerateActivityStatusForExecution:
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=True,
|
||||
):
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
@@ -581,7 +581,7 @@ class TestGenerateActivityStatusForExecution:
|
||||
):
|
||||
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
mock_llm.return_value = "Agent completed execution."
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
@@ -633,7 +633,7 @@ class TestIntegration:
|
||||
):
|
||||
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
|
||||
mock_response = LLMResponse(
|
||||
raw_response={},
|
||||
|
||||
@@ -4,12 +4,13 @@ from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
|
||||
from backend.data import db
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.execution import (
|
||||
add_input_to_node_execution,
|
||||
create_graph_execution,
|
||||
create_node_execution,
|
||||
get_block_error_stats,
|
||||
get_execution_kv_data,
|
||||
get_graph_execution_meta,
|
||||
get_graph_executions,
|
||||
get_latest_node_execution,
|
||||
get_node_execution,
|
||||
get_node_executions,
|
||||
set_execution_kv_data,
|
||||
@@ -17,7 +18,6 @@ from backend.data.execution import (
|
||||
update_graph_execution_stats,
|
||||
update_node_execution_status,
|
||||
update_node_execution_status_batch,
|
||||
upsert_execution_input,
|
||||
upsert_execution_output,
|
||||
)
|
||||
from backend.data.generate_data import get_user_execution_summary_data
|
||||
@@ -85,16 +85,6 @@ class DatabaseManager(AppService):
|
||||
async def health_check(self) -> str:
|
||||
if not db.is_connected():
|
||||
raise UnhealthyServiceError("Database is not connected")
|
||||
|
||||
try:
|
||||
# Test actual database connectivity by executing a simple query
|
||||
# This will fail if Prisma query engine is not responding
|
||||
result = await db.query_raw_with_schema("SELECT 1 as health_check")
|
||||
if not result or result[0].get("health_check") != 1:
|
||||
raise UnhealthyServiceError("Database query test failed")
|
||||
except Exception as e:
|
||||
raise UnhealthyServiceError(f"Database health check failed: {e}")
|
||||
|
||||
return await super().health_check()
|
||||
|
||||
@classmethod
|
||||
@@ -115,13 +105,13 @@ class DatabaseManager(AppService):
|
||||
create_graph_execution = _(create_graph_execution)
|
||||
get_node_execution = _(get_node_execution)
|
||||
get_node_executions = _(get_node_executions)
|
||||
get_latest_node_execution = _(get_latest_node_execution)
|
||||
update_node_execution_status = _(update_node_execution_status)
|
||||
update_node_execution_status_batch = _(update_node_execution_status_batch)
|
||||
update_graph_execution_start_time = _(update_graph_execution_start_time)
|
||||
update_graph_execution_stats = _(update_graph_execution_stats)
|
||||
upsert_execution_input = _(upsert_execution_input)
|
||||
upsert_execution_output = _(upsert_execution_output)
|
||||
create_node_execution = _(create_node_execution)
|
||||
add_input_to_node_execution = _(add_input_to_node_execution)
|
||||
get_execution_kv_data = _(get_execution_kv_data)
|
||||
set_execution_kv_data = _(set_execution_kv_data)
|
||||
get_block_error_stats = _(get_block_error_stats)
|
||||
@@ -181,10 +171,12 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
get_graph_executions = _(d.get_graph_executions)
|
||||
get_graph_execution_meta = _(d.get_graph_execution_meta)
|
||||
get_node_executions = _(d.get_node_executions)
|
||||
create_node_execution = _(d.create_node_execution)
|
||||
update_node_execution_status = _(d.update_node_execution_status)
|
||||
update_graph_execution_start_time = _(d.update_graph_execution_start_time)
|
||||
update_graph_execution_stats = _(d.update_graph_execution_stats)
|
||||
upsert_execution_output = _(d.upsert_execution_output)
|
||||
add_input_to_node_execution = _(d.add_input_to_node_execution)
|
||||
|
||||
# Graphs
|
||||
get_graph_metadata = _(d.get_graph_metadata)
|
||||
@@ -199,14 +191,6 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
# User Emails
|
||||
get_user_email_by_id = _(d.get_user_email_by_id)
|
||||
|
||||
# Library
|
||||
list_library_agents = _(d.list_library_agents)
|
||||
add_store_agent_to_library = _(d.add_store_agent_to_library)
|
||||
|
||||
# Store
|
||||
get_store_agents = _(d.get_store_agents)
|
||||
get_store_agent_details = _(d.get_store_agent_details)
|
||||
|
||||
|
||||
class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
@@ -217,16 +201,12 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
|
||||
create_graph_execution = d.create_graph_execution
|
||||
get_connected_output_nodes = d.get_connected_output_nodes
|
||||
get_latest_node_execution = d.get_latest_node_execution
|
||||
get_graph = d.get_graph
|
||||
get_graph_metadata = d.get_graph_metadata
|
||||
get_graph_execution_meta = d.get_graph_execution_meta
|
||||
get_node = d.get_node
|
||||
get_node_execution = d.get_node_execution
|
||||
get_node_executions = d.get_node_executions
|
||||
get_user_integrations = d.get_user_integrations
|
||||
upsert_execution_input = d.upsert_execution_input
|
||||
upsert_execution_output = d.upsert_execution_output
|
||||
update_graph_execution_stats = d.update_graph_execution_stats
|
||||
update_node_execution_status = d.update_node_execution_status
|
||||
update_node_execution_status_batch = d.update_node_execution_status_batch
|
||||
|
||||
154
autogpt_platform/backend/backend/executor/execution_cache.py
Normal file
154
autogpt_platform/backend/backend/executor/execution_cache.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from backend.data.execution import ExecutionStatus, NodeExecutionResult
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def with_lock(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
with self._lock:
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ExecutionCache:
|
||||
def __init__(self, graph_exec_id: str, db_client: "DatabaseManagerClient"):
|
||||
self._lock = threading.RLock()
|
||||
self._graph_exec_id = graph_exec_id
|
||||
self._graph_stats: GraphExecutionStats = GraphExecutionStats()
|
||||
self._node_executions: OrderedDict[str, NodeExecutionResult] = OrderedDict()
|
||||
|
||||
for execution in db_client.get_node_executions(self._graph_exec_id):
|
||||
self._node_executions[execution.node_exec_id] = execution
|
||||
|
||||
@with_lock
|
||||
def get_node_execution(self, node_exec_id: str) -> Optional[NodeExecutionResult]:
|
||||
execution = self._node_executions.get(node_exec_id)
|
||||
return execution.model_copy(deep=True) if execution else None
|
||||
|
||||
@with_lock
|
||||
def get_latest_node_execution(self, node_id: str) -> Optional[NodeExecutionResult]:
|
||||
for execution in reversed(self._node_executions.values()):
|
||||
if (
|
||||
execution.node_id == node_id
|
||||
and execution.status != ExecutionStatus.INCOMPLETE
|
||||
):
|
||||
return execution.model_copy(deep=True)
|
||||
return None
|
||||
|
||||
@with_lock
|
||||
def get_node_executions(
|
||||
self,
|
||||
*,
|
||||
statuses: Optional[list] = None,
|
||||
block_ids: Optional[list] = None,
|
||||
node_id: Optional[str] = None,
|
||||
):
|
||||
results = []
|
||||
for execution in self._node_executions.values():
|
||||
if statuses and execution.status not in statuses:
|
||||
continue
|
||||
if block_ids and execution.block_id not in block_ids:
|
||||
continue
|
||||
if node_id and execution.node_id != node_id:
|
||||
continue
|
||||
results.append(execution.model_copy(deep=True))
|
||||
return results
|
||||
|
||||
@with_lock
|
||||
def update_node_execution_status(
|
||||
self,
|
||||
exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: Optional[dict] = None,
|
||||
stats: Optional[dict] = None,
|
||||
):
|
||||
if exec_id not in self._node_executions:
|
||||
raise RuntimeError(f"Execution {exec_id} not found in cache")
|
||||
|
||||
execution = self._node_executions[exec_id]
|
||||
execution.status = status
|
||||
|
||||
if execution_data:
|
||||
execution.input_data.update(execution_data)
|
||||
|
||||
if stats:
|
||||
execution.stats = execution.stats or NodeExecutionStats()
|
||||
current_stats = execution.stats.model_dump()
|
||||
current_stats.update(stats)
|
||||
execution.stats = NodeExecutionStats.model_validate(current_stats)
|
||||
|
||||
@with_lock
|
||||
def upsert_execution_output(
|
||||
self, node_exec_id: str, output_name: str, output_data: Any
|
||||
) -> NodeExecutionResult:
|
||||
if node_exec_id not in self._node_executions:
|
||||
raise RuntimeError(f"Execution {node_exec_id} not found in cache")
|
||||
|
||||
execution = self._node_executions[node_exec_id]
|
||||
if output_name not in execution.output_data:
|
||||
execution.output_data[output_name] = []
|
||||
execution.output_data[output_name].append(output_data)
|
||||
|
||||
return execution
|
||||
|
||||
@with_lock
|
||||
def update_graph_stats(
|
||||
self, status: Optional[ExecutionStatus] = None, stats: Optional[dict] = None
|
||||
):
|
||||
if status is not None:
|
||||
pass
|
||||
if stats is not None:
|
||||
current_stats = self._graph_stats.model_dump()
|
||||
current_stats.update(stats)
|
||||
self._graph_stats = GraphExecutionStats.model_validate(current_stats)
|
||||
|
||||
@with_lock
|
||||
def update_graph_start_time(self):
|
||||
"""Update graph start time (handled by database persistence)."""
|
||||
pass
|
||||
|
||||
@with_lock
|
||||
def find_incomplete_execution_for_input(
|
||||
self, node_id: str, input_name: str
|
||||
) -> tuple[str, NodeExecutionResult] | None:
|
||||
for exec_id, execution in self._node_executions.items():
|
||||
if (
|
||||
execution.node_id == node_id
|
||||
and execution.status == ExecutionStatus.INCOMPLETE
|
||||
and input_name not in execution.input_data
|
||||
):
|
||||
return exec_id, execution
|
||||
return None
|
||||
|
||||
@with_lock
|
||||
def add_node_execution(
|
||||
self, node_exec_id: str, execution: NodeExecutionResult
|
||||
) -> None:
|
||||
self._node_executions[node_exec_id] = execution
|
||||
|
||||
@with_lock
|
||||
def update_execution_input(
|
||||
self, exec_id: str, input_name: str, input_data: Any
|
||||
) -> dict:
|
||||
if exec_id not in self._node_executions:
|
||||
raise RuntimeError(f"Execution {exec_id} not found in cache")
|
||||
execution = self._node_executions[exec_id]
|
||||
execution.input_data[input_name] = input_data
|
||||
return execution.input_data.copy()
|
||||
|
||||
def finalize(self) -> None:
|
||||
with self._lock:
|
||||
self._node_executions.clear()
|
||||
self._graph_stats = GraphExecutionStats()
|
||||
@@ -0,0 +1,355 @@
|
||||
"""Test execution creation with proper ID generation and persistence."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.executor.execution_data import ExecutionDataClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def execution_client_with_mock_db(event_loop):
|
||||
"""Create an ExecutionDataClient with proper database records."""
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from prisma.models import AgentGraph, AgentGraphExecution, User
|
||||
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
|
||||
# Create test database records to satisfy foreign key constraints
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": "test_user_123",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
)
|
||||
|
||||
await AgentGraph.prisma().create(
|
||||
data={
|
||||
"id": "test_graph_456",
|
||||
"version": 1,
|
||||
"userId": "test_user_123",
|
||||
"name": "Test Graph",
|
||||
"description": "Test graph for execution tests",
|
||||
}
|
||||
)
|
||||
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
|
||||
await AgentGraphExecution.prisma().create(
|
||||
data={
|
||||
"id": "test_graph_exec_id",
|
||||
"userId": "test_user_123",
|
||||
"agentGraphId": "test_graph_456",
|
||||
"agentGraphVersion": 1,
|
||||
"executionStatus": AgentExecutionStatus.RUNNING,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Records might already exist, that's fine
|
||||
pass
|
||||
|
||||
# Mock the graph execution metadata - align with assertions below
|
||||
mock_graph_meta = GraphExecutionMeta(
|
||||
id="test_graph_exec_id",
|
||||
user_id="test_user_123",
|
||||
graph_id="test_graph_456",
|
||||
graph_version=1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
# Create client with ThreadPoolExecutor and graph metadata (constructed inside patch)
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
# Storage for tracking created executions
|
||||
created_executions = []
|
||||
|
||||
async def mock_create_node_execution(
|
||||
node_exec_id, node_id, graph_exec_id, input_name, input_data
|
||||
):
|
||||
"""Mock execution creation that records what was created."""
|
||||
created_executions.append(
|
||||
{
|
||||
"node_exec_id": node_exec_id,
|
||||
"node_id": node_id,
|
||||
"graph_exec_id": graph_exec_id,
|
||||
"input_name": input_name,
|
||||
"input_data": input_data,
|
||||
}
|
||||
)
|
||||
return node_exec_id
|
||||
|
||||
def sync_mock_create_node_execution(
|
||||
node_exec_id, node_id, graph_exec_id, input_name, input_data
|
||||
):
|
||||
"""Mock sync execution creation that records what was created."""
|
||||
created_executions.append(
|
||||
{
|
||||
"node_exec_id": node_exec_id,
|
||||
"node_id": node_id,
|
||||
"graph_exec_id": graph_exec_id,
|
||||
"input_name": input_name,
|
||||
"input_data": input_data,
|
||||
}
|
||||
)
|
||||
return node_exec_id
|
||||
|
||||
# Prepare mock async and sync DB clients
|
||||
async_mock_client = AsyncMock()
|
||||
async_mock_client.create_node_execution = mock_create_node_execution
|
||||
|
||||
sync_mock_client = MagicMock()
|
||||
sync_mock_client.create_node_execution = sync_mock_create_node_execution
|
||||
# Mock graph execution for return values
|
||||
from backend.data.execution import GraphExecutionMeta
|
||||
|
||||
mock_graph_update = GraphExecutionMeta(
|
||||
id="test_graph_exec_id",
|
||||
user_id="test_user_123",
|
||||
graph_id="test_graph_456",
|
||||
graph_version=1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
# No-ops for other sync methods used by the client during tests
|
||||
sync_mock_client.add_input_to_node_execution.side_effect = lambda **kwargs: None
|
||||
sync_mock_client.update_node_execution_status.side_effect = (
|
||||
lambda *args, **kwargs: None
|
||||
)
|
||||
sync_mock_client.upsert_execution_output.side_effect = lambda **kwargs: None
|
||||
sync_mock_client.update_graph_execution_stats.side_effect = (
|
||||
lambda *args, **kwargs: mock_graph_update
|
||||
)
|
||||
sync_mock_client.update_graph_execution_start_time.side_effect = (
|
||||
lambda *args, **kwargs: mock_graph_update
|
||||
)
|
||||
|
||||
thread = threading.Thread(target=event_loop.run_forever, daemon=True)
|
||||
thread.start()
|
||||
with patch(
|
||||
"backend.executor.execution_data.get_database_manager_async_client",
|
||||
return_value=async_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_database_manager_client",
|
||||
return_value=sync_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_execution_event_bus"
|
||||
), patch(
|
||||
"backend.executor.execution_data.non_blocking_persist", lambda func: func
|
||||
):
|
||||
# Now construct the client under the patch so it captures the mocked clients
|
||||
client = ExecutionDataClient(executor, "test_graph_exec_id", mock_graph_meta)
|
||||
# Store the mocks for the test to access if needed
|
||||
setattr(client, "_test_async_client", async_mock_client)
|
||||
setattr(client, "_test_sync_client", sync_mock_client)
|
||||
setattr(client, "_created_executions", created_executions)
|
||||
yield client
|
||||
|
||||
# Cleanup test database records
|
||||
try:
|
||||
await AgentGraphExecution.prisma().delete_many(
|
||||
where={"id": "test_graph_exec_id"}
|
||||
)
|
||||
await AgentGraph.prisma().delete_many(where={"id": "test_graph_456"})
|
||||
await User.prisma().delete_many(where={"id": "test_user_123"})
|
||||
except Exception:
|
||||
# Cleanup may fail if records don't exist
|
||||
pass
|
||||
|
||||
# Cleanup
|
||||
event_loop.call_soon_threadsafe(event_loop.stop)
|
||||
thread.join(timeout=1)
|
||||
|
||||
|
||||
class TestExecutionCreation:
|
||||
"""Test execution creation with proper ID generation and persistence."""
|
||||
|
||||
async def test_execution_creation_with_valid_ids(
|
||||
self, execution_client_with_mock_db
|
||||
):
|
||||
"""Test that execution creation generates and persists valid IDs."""
|
||||
client = execution_client_with_mock_db
|
||||
|
||||
node_id = "test_node_789"
|
||||
input_name = "test_input"
|
||||
input_data = "test_value"
|
||||
block_id = "test_block_abc"
|
||||
|
||||
# This should trigger execution creation since cache is empty
|
||||
exec_id, input_dict = client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name=input_name,
|
||||
input_data=input_data,
|
||||
block_id=block_id,
|
||||
)
|
||||
|
||||
# Verify execution ID is valid UUID
|
||||
try:
|
||||
uuid.UUID(exec_id)
|
||||
except ValueError:
|
||||
pytest.fail(f"Generated execution ID '{exec_id}' is not a valid UUID")
|
||||
|
||||
# Verify execution was created in cache with complete data
|
||||
assert exec_id in client._cache._node_executions
|
||||
cached_execution = client._cache._node_executions[exec_id]
|
||||
|
||||
# Check all required fields have valid values
|
||||
assert cached_execution.user_id == "test_user_123"
|
||||
assert cached_execution.graph_id == "test_graph_456"
|
||||
assert cached_execution.graph_version == 1
|
||||
assert cached_execution.graph_exec_id == "test_graph_exec_id"
|
||||
assert cached_execution.node_exec_id == exec_id
|
||||
assert cached_execution.node_id == node_id
|
||||
assert cached_execution.block_id == block_id
|
||||
assert cached_execution.status == ExecutionStatus.INCOMPLETE
|
||||
assert cached_execution.input_data == {input_name: input_data}
|
||||
assert isinstance(cached_execution.add_time, datetime)
|
||||
|
||||
# Verify execution was persisted to database with our generated ID
|
||||
created_executions = getattr(client, "_created_executions", [])
|
||||
assert len(created_executions) == 1
|
||||
created = created_executions[0]
|
||||
assert created["node_exec_id"] == exec_id # Our generated ID was used
|
||||
assert created["node_id"] == node_id
|
||||
assert created["graph_exec_id"] == "test_graph_exec_id"
|
||||
assert created["input_name"] == input_name
|
||||
assert created["input_data"] == input_data
|
||||
|
||||
# Verify input dict returned correctly
|
||||
assert input_dict == {input_name: input_data}
|
||||
|
||||
async def test_execution_reuse_vs_creation(self, execution_client_with_mock_db):
|
||||
"""Test that execution reuse works and creation only happens when needed."""
|
||||
client = execution_client_with_mock_db
|
||||
|
||||
node_id = "reuse_test_node"
|
||||
block_id = "reuse_test_block"
|
||||
|
||||
# Create first execution
|
||||
exec_id_1, input_dict_1 = client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="input_1",
|
||||
input_data="value_1",
|
||||
block_id=block_id,
|
||||
)
|
||||
|
||||
# This should reuse the existing INCOMPLETE execution
|
||||
exec_id_2, input_dict_2 = client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="input_2",
|
||||
input_data="value_2",
|
||||
block_id=block_id,
|
||||
)
|
||||
|
||||
# Should reuse the same execution
|
||||
assert exec_id_1 == exec_id_2
|
||||
assert input_dict_2 == {"input_1": "value_1", "input_2": "value_2"}
|
||||
|
||||
# Only one execution should be created in database
|
||||
created_executions = getattr(client, "_created_executions", [])
|
||||
assert len(created_executions) == 1
|
||||
|
||||
# Verify cache has the merged inputs
|
||||
cached_execution = client._cache._node_executions[exec_id_1]
|
||||
assert cached_execution.input_data == {
|
||||
"input_1": "value_1",
|
||||
"input_2": "value_2",
|
||||
}
|
||||
|
||||
# Now complete the execution and try to add another input
|
||||
client.update_node_status_and_publish(
|
||||
exec_id=exec_id_1, status=ExecutionStatus.COMPLETED
|
||||
)
|
||||
|
||||
# Verify the execution status was actually updated in the cache
|
||||
updated_execution = client._cache._node_executions[exec_id_1]
|
||||
assert (
|
||||
updated_execution.status == ExecutionStatus.COMPLETED
|
||||
), f"Expected COMPLETED but got {updated_execution.status}"
|
||||
|
||||
# This should create a NEW execution since the first is no longer INCOMPLETE
|
||||
exec_id_3, input_dict_3 = client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="input_3",
|
||||
input_data="value_3",
|
||||
block_id=block_id,
|
||||
)
|
||||
|
||||
# Should be a different execution
|
||||
assert exec_id_3 != exec_id_1
|
||||
assert input_dict_3 == {"input_3": "value_3"}
|
||||
|
||||
# Verify cache behavior: should have two different executions in cache now
|
||||
cached_executions = client._cache._node_executions
|
||||
assert len(cached_executions) == 2
|
||||
assert exec_id_1 in cached_executions
|
||||
assert exec_id_3 in cached_executions
|
||||
|
||||
# First execution should be COMPLETED
|
||||
assert cached_executions[exec_id_1].status == ExecutionStatus.COMPLETED
|
||||
# Third execution should be INCOMPLETE (newly created)
|
||||
assert cached_executions[exec_id_3].status == ExecutionStatus.INCOMPLETE
|
||||
|
||||
async def test_multiple_nodes_get_different_execution_ids(
|
||||
self, execution_client_with_mock_db
|
||||
):
|
||||
"""Test that different nodes get different execution IDs."""
|
||||
client = execution_client_with_mock_db
|
||||
|
||||
# Create executions for different nodes
|
||||
exec_id_a, _ = client.upsert_execution_input(
|
||||
node_id="node_a",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="block_a",
|
||||
)
|
||||
|
||||
exec_id_b, _ = client.upsert_execution_input(
|
||||
node_id="node_b",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="block_b",
|
||||
)
|
||||
|
||||
# Should be different executions with different IDs
|
||||
assert exec_id_a != exec_id_b
|
||||
|
||||
# Both should be valid UUIDs
|
||||
uuid.UUID(exec_id_a)
|
||||
uuid.UUID(exec_id_b)
|
||||
|
||||
# Both should be in cache
|
||||
cached_executions = client._cache._node_executions
|
||||
assert len(cached_executions) == 2
|
||||
assert exec_id_a in cached_executions
|
||||
assert exec_id_b in cached_executions
|
||||
|
||||
# Both should have correct node IDs
|
||||
assert cached_executions[exec_id_a].node_id == "node_a"
|
||||
assert cached_executions[exec_id_b].node_id == "node_b"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
338
autogpt_platform/backend/backend/executor/execution_data.py
Normal file
338
autogpt_platform/backend/backend/executor/execution_data.py
Normal file
@@ -0,0 +1,338 @@
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from concurrent.futures import Executor
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TypeVar, cast
|
||||
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionMeta,
|
||||
NodeExecutionResult,
|
||||
)
|
||||
from backend.data.graph import Node
|
||||
from backend.data.model import GraphExecutionStats
|
||||
from backend.executor.execution_cache import ExecutionCache
|
||||
from backend.util.clients import (
|
||||
get_database_manager_async_client,
|
||||
get_database_manager_client,
|
||||
get_execution_event_bus,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def non_blocking_persist(func: Callable[P, T]) -> Callable[P, None]:
|
||||
from functools import wraps
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
|
||||
# First argument is always self for methods - access through cast for typing
|
||||
self = cast("ExecutionDataClient", args[0])
|
||||
future = self._executor.submit(func, *args, **kwargs)
|
||||
self._pending_tasks.add(future)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ExecutionDataClient:
|
||||
def __init__(
|
||||
self, executor: Executor, graph_exec_id: str, graph_metadata: GraphExecutionMeta
|
||||
):
|
||||
self._executor = executor
|
||||
self._graph_exec_id = graph_exec_id
|
||||
self._cache = ExecutionCache(graph_exec_id, self.db_client_sync)
|
||||
self._pending_tasks = set()
|
||||
self._graph_metadata = graph_metadata
|
||||
self.graph_lock = threading.RLock()
|
||||
|
||||
def finalize_execution(self, timeout: float = 30.0):
|
||||
logger.info(f"Flushing db writes for execution {self._graph_exec_id}")
|
||||
exceptions = []
|
||||
|
||||
# Wait for all pending database operations to complete
|
||||
logger.debug(
|
||||
f"Waiting for {len(self._pending_tasks)} pending database operations"
|
||||
)
|
||||
for future in list(self._pending_tasks):
|
||||
try:
|
||||
future.result(timeout=timeout)
|
||||
except Exception as e:
|
||||
logger.error(f"Background database operation failed: {e}")
|
||||
exceptions.append(e)
|
||||
finally:
|
||||
self._pending_tasks.discard(future)
|
||||
|
||||
self._cache.finalize()
|
||||
|
||||
if exceptions:
|
||||
logger.error(f"Background persistence failed with {len(exceptions)} errors")
|
||||
raise RuntimeError(
|
||||
f"Background persistence failed with {len(exceptions)} errors: {exceptions}"
|
||||
)
|
||||
|
||||
@property
|
||||
def db_client_async(self) -> "DatabaseManagerAsyncClient":
|
||||
return get_database_manager_async_client()
|
||||
|
||||
@property
|
||||
def db_client_sync(self) -> "DatabaseManagerClient":
|
||||
return get_database_manager_client()
|
||||
|
||||
@property
|
||||
def event_bus(self):
|
||||
return get_execution_event_bus()
|
||||
|
||||
async def get_node(self, node_id: str) -> Node:
|
||||
return await self.db_client_async.get_node(node_id)
|
||||
|
||||
def spend_credits(
|
||||
self,
|
||||
user_id: str,
|
||||
cost: int,
|
||||
metadata: UsageTransactionMetadata,
|
||||
) -> int:
|
||||
return self.db_client_sync.spend_credits(
|
||||
user_id=user_id, cost=cost, metadata=metadata
|
||||
)
|
||||
|
||||
def get_graph_execution_meta(
|
||||
self, user_id: str, execution_id: str
|
||||
) -> GraphExecutionMeta | None:
|
||||
return self.db_client_sync.get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=execution_id
|
||||
)
|
||||
|
||||
def get_graph_metadata(
|
||||
self, graph_id: str, graph_version: int | None = None
|
||||
) -> Any:
|
||||
return self.db_client_sync.get_graph_metadata(graph_id, graph_version)
|
||||
|
||||
def get_credits(self, user_id: str) -> int:
|
||||
return self.db_client_sync.get_credits(user_id)
|
||||
|
||||
def get_user_email_by_id(self, user_id: str) -> str | None:
|
||||
return self.db_client_sync.get_user_email_by_id(user_id)
|
||||
|
||||
def get_latest_node_execution(self, node_id: str) -> NodeExecutionResult | None:
|
||||
return self._cache.get_latest_node_execution(node_id)
|
||||
|
||||
def get_node_execution(self, node_exec_id: str) -> NodeExecutionResult | None:
|
||||
return self._cache.get_node_execution(node_exec_id)
|
||||
|
||||
def get_node_executions(
|
||||
self,
|
||||
*,
|
||||
node_id: str | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
block_ids: list[str] | None = None,
|
||||
) -> list[NodeExecutionResult]:
|
||||
return self._cache.get_node_executions(
|
||||
statuses=statuses, block_ids=block_ids, node_id=node_id
|
||||
)
|
||||
|
||||
def update_node_status_and_publish(
|
||||
self,
|
||||
exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: dict | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
):
|
||||
self._cache.update_node_execution_status(exec_id, status, execution_data, stats)
|
||||
self._persist_node_status_to_db(exec_id, status, execution_data, stats)
|
||||
|
||||
def upsert_execution_input(
|
||||
self, node_id: str, input_name: str, input_data: Any, block_id: str
|
||||
) -> tuple[str, dict]:
|
||||
# Validate input parameters to prevent foreign key constraint errors
|
||||
if not node_id or not isinstance(node_id, str):
|
||||
raise ValueError(f"Invalid node_id: {node_id}")
|
||||
if not self._graph_exec_id or not isinstance(self._graph_exec_id, str):
|
||||
raise ValueError(f"Invalid graph_exec_id: {self._graph_exec_id}")
|
||||
if not block_id or not isinstance(block_id, str):
|
||||
raise ValueError(f"Invalid block_id: {block_id}")
|
||||
|
||||
# UPDATE: Try to find an existing incomplete execution for this node and input
|
||||
if result := self._cache.find_incomplete_execution_for_input(
|
||||
node_id, input_name
|
||||
):
|
||||
exec_id, _ = result
|
||||
updated_input_data = self._cache.update_execution_input(
|
||||
exec_id, input_name, input_data
|
||||
)
|
||||
self._persist_add_input_to_db(exec_id, input_name, input_data)
|
||||
return exec_id, updated_input_data
|
||||
|
||||
# CREATE: No suitable execution found, create new one
|
||||
node_exec_id = str(uuid.uuid4())
|
||||
logger.debug(
|
||||
f"Creating new execution {node_exec_id} for node {node_id} "
|
||||
f"in graph execution {self._graph_exec_id}"
|
||||
)
|
||||
|
||||
new_execution = NodeExecutionResult(
|
||||
user_id=self._graph_metadata.user_id,
|
||||
graph_id=self._graph_metadata.graph_id,
|
||||
graph_version=self._graph_metadata.graph_version,
|
||||
graph_exec_id=self._graph_exec_id,
|
||||
node_exec_id=node_exec_id,
|
||||
node_id=node_id,
|
||||
block_id=block_id,
|
||||
status=ExecutionStatus.INCOMPLETE,
|
||||
input_data={input_name: input_data},
|
||||
output_data={},
|
||||
add_time=datetime.now(timezone.utc),
|
||||
)
|
||||
self._cache.add_node_execution(node_exec_id, new_execution)
|
||||
self._persist_new_node_execution_to_db(
|
||||
node_exec_id, node_id, input_name, input_data
|
||||
)
|
||||
|
||||
return node_exec_id, {input_name: input_data}
|
||||
|
||||
def upsert_execution_output(
|
||||
self, node_exec_id: str, output_name: str, output_data: Any
|
||||
):
|
||||
self._cache.upsert_execution_output(node_exec_id, output_name, output_data)
|
||||
self._persist_execution_output_to_db(node_exec_id, output_name, output_data)
|
||||
|
||||
def update_graph_stats_and_publish(
|
||||
self,
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> None:
|
||||
stats_dict = stats.model_dump() if stats else None
|
||||
self._cache.update_graph_stats(status=status, stats=stats_dict)
|
||||
self._persist_graph_stats_to_db(status=status, stats=stats)
|
||||
|
||||
def update_graph_start_time_and_publish(self) -> None:
|
||||
self._cache.update_graph_start_time()
|
||||
self._persist_graph_start_time_to_db()
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_node_status_to_db(
|
||||
self,
|
||||
exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: dict | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
):
|
||||
exec_update = self.db_client_sync.update_node_execution_status(
|
||||
exec_id, status, execution_data, stats
|
||||
)
|
||||
self.event_bus.publish(exec_update)
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_add_input_to_db(
|
||||
self, node_exec_id: str, input_name: str, input_data: Any
|
||||
):
|
||||
self.db_client_sync.add_input_to_node_execution(
|
||||
node_exec_id=node_exec_id,
|
||||
input_name=input_name,
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_execution_output_to_db(
|
||||
self, node_exec_id: str, output_name: str, output_data: Any
|
||||
):
|
||||
self.db_client_sync.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name=output_name,
|
||||
output_data=output_data,
|
||||
)
|
||||
if exec_update := self.get_node_execution(node_exec_id):
|
||||
self.event_bus.publish(exec_update)
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_graph_stats_to_db(
|
||||
self,
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
):
|
||||
graph_update = self.db_client_sync.update_graph_execution_stats(
|
||||
self._graph_exec_id, status, stats
|
||||
)
|
||||
if not graph_update:
|
||||
raise RuntimeError(
|
||||
f"Failed to update graph execution stats for {self._graph_exec_id}"
|
||||
)
|
||||
self.event_bus.publish(graph_update)
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_graph_start_time_to_db(self):
|
||||
graph_update = self.db_client_sync.update_graph_execution_start_time(
|
||||
self._graph_exec_id
|
||||
)
|
||||
if not graph_update:
|
||||
raise RuntimeError(
|
||||
f"Failed to update graph execution start time for {self._graph_exec_id}"
|
||||
)
|
||||
self.event_bus.publish(graph_update)
|
||||
|
||||
async def generate_activity_status(
|
||||
self,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_stats: GraphExecutionStats,
|
||||
user_id: str,
|
||||
execution_status: ExecutionStatus,
|
||||
) -> str | None:
|
||||
from backend.executor.activity_status_generator import (
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
|
||||
return await generate_activity_status_for_execution(
|
||||
graph_exec_id=self._graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_stats=execution_stats,
|
||||
db_client=self.db_client_async,
|
||||
user_id=user_id,
|
||||
execution_status=execution_status,
|
||||
)
|
||||
|
||||
@non_blocking_persist
|
||||
def _send_execution_update(self, execution: NodeExecutionResult):
|
||||
"""Send execution update to event bus."""
|
||||
try:
|
||||
self.event_bus.publish(execution)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send execution update: {e}")
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_new_node_execution_to_db(
|
||||
self, node_exec_id: str, node_id: str, input_name: str, input_data: Any
|
||||
):
|
||||
try:
|
||||
self.db_client_sync.create_node_execution(
|
||||
node_exec_id=node_exec_id,
|
||||
node_id=node_id,
|
||||
graph_exec_id=self._graph_exec_id,
|
||||
input_name=input_name,
|
||||
input_data=input_data,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create node execution {node_exec_id} for node {node_id} "
|
||||
f"in graph execution {self._graph_exec_id}: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def increment_execution_count(self, user_id: str) -> int:
|
||||
r = redis.get_redis()
|
||||
k = f"uec:{user_id}"
|
||||
counter = cast(int, r.incr(k))
|
||||
if counter == 1:
|
||||
r.expire(k, settings.config.execution_counter_expiration_time)
|
||||
return counter
|
||||
668
autogpt_platform/backend/backend/executor/execution_data_test.py
Normal file
668
autogpt_platform/backend/backend/executor/execution_data_test.py
Normal file
@@ -0,0 +1,668 @@
|
||||
"""Test suite for ExecutionDataClient."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.executor.execution_data import ExecutionDataClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def execution_client(event_loop):
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
|
||||
mock_graph_meta = GraphExecutionMeta(
|
||||
id="test_graph_exec_id",
|
||||
user_id="test_user_id",
|
||||
graph_id="test_graph_id",
|
||||
graph_version=1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
# Mock all database operations to prevent connection attempts
|
||||
async_mock_client = AsyncMock()
|
||||
sync_mock_client = MagicMock()
|
||||
|
||||
# Mock all database methods to return None or empty results
|
||||
sync_mock_client.get_node_executions.return_value = []
|
||||
sync_mock_client.create_node_execution.return_value = None
|
||||
sync_mock_client.add_input_to_node_execution.return_value = None
|
||||
sync_mock_client.update_node_execution_status.return_value = None
|
||||
sync_mock_client.upsert_execution_output.return_value = None
|
||||
sync_mock_client.update_graph_execution_stats.return_value = mock_graph_meta
|
||||
sync_mock_client.update_graph_execution_start_time.return_value = mock_graph_meta
|
||||
|
||||
# Mock event bus to prevent connection attempts
|
||||
mock_event_bus = MagicMock()
|
||||
mock_event_bus.publish.return_value = None
|
||||
|
||||
thread = threading.Thread(target=event_loop.run_forever, daemon=True)
|
||||
thread.start()
|
||||
|
||||
with patch(
|
||||
"backend.executor.execution_data.get_database_manager_async_client",
|
||||
return_value=async_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_database_manager_client",
|
||||
return_value=sync_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_execution_event_bus",
|
||||
return_value=mock_event_bus,
|
||||
), patch(
|
||||
"backend.executor.execution_data.non_blocking_persist", lambda func: func
|
||||
):
|
||||
|
||||
client = ExecutionDataClient(executor, "test_graph_exec_id", mock_graph_meta)
|
||||
yield client
|
||||
|
||||
event_loop.call_soon_threadsafe(event_loop.stop)
|
||||
thread.join(timeout=1)
|
||||
|
||||
|
||||
class TestExecutionDataClient:
|
||||
|
||||
async def test_update_node_status_writes_to_cache_immediately(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that node status updates are immediately visible in cache."""
|
||||
# First create an execution to update
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
status = ExecutionStatus.RUNNING
|
||||
execution_data = {"step": "processing"}
|
||||
stats = {"duration": 5.2}
|
||||
|
||||
# Update status of existing execution
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=status,
|
||||
execution_data=execution_data,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
# Verify immediate visibility in cache
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert cached_exec.status == status
|
||||
# execution_data should be merged with existing input_data, not replace it
|
||||
expected_input_data = {"test_input": "test_value", "step": "processing"}
|
||||
assert cached_exec.input_data == expected_input_data
|
||||
|
||||
def test_update_node_status_execution_not_found_raises_error(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that updating non-existent execution raises error instead of creating it."""
|
||||
non_existent_id = "does-not-exist"
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError, match="Execution does-not-exist not found in cache"
|
||||
):
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=non_existent_id, status=ExecutionStatus.COMPLETED
|
||||
)
|
||||
|
||||
async def test_upsert_execution_output_writes_to_cache_immediately(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that output updates are immediately visible in cache."""
|
||||
# First create an execution
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
output_name = "result"
|
||||
output_data = {"answer": 42, "confidence": 0.95}
|
||||
|
||||
# Update to RUNNING status first
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "test"},
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id, output_name=output_name, output_data=output_data
|
||||
)
|
||||
# Check output through the node execution
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert output_name in cached_exec.output_data
|
||||
assert cached_exec.output_data[output_name] == [output_data]
|
||||
|
||||
async def test_get_node_execution_reads_from_cache(self, execution_client):
|
||||
"""Test that get_node_execution returns cached data immediately."""
|
||||
# First create an execution to work with
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Then update its status
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
execution_data={"result": "success"},
|
||||
)
|
||||
|
||||
result = execution_client.get_node_execution(node_exec_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == ExecutionStatus.COMPLETED
|
||||
# execution_data gets merged with existing input_data
|
||||
expected_input_data = {"test_input": "test_value", "result": "success"}
|
||||
assert result.input_data == expected_input_data
|
||||
|
||||
async def test_get_latest_node_execution_reads_from_cache(self, execution_client):
|
||||
"""Test that get_latest_node_execution returns cached data."""
|
||||
node_id = "node-1"
|
||||
|
||||
# First create an execution for this node
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Then update its status to make it non-INCOMPLETE (so it's returned by get_latest)
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"from": "cache"},
|
||||
)
|
||||
|
||||
result = execution_client.get_latest_node_execution(node_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == ExecutionStatus.RUNNING
|
||||
# execution_data gets merged with existing input_data
|
||||
expected_input_data = {"test_input": "test_value", "from": "cache"}
|
||||
assert result.input_data == expected_input_data
|
||||
|
||||
async def test_get_node_executions_sync_filters_correctly(self, execution_client):
|
||||
# Create executions with different statuses
|
||||
executions = [
|
||||
(ExecutionStatus.RUNNING, "block-a"),
|
||||
(ExecutionStatus.COMPLETED, "block-a"),
|
||||
(ExecutionStatus.FAILED, "block-b"),
|
||||
(ExecutionStatus.RUNNING, "block-b"),
|
||||
]
|
||||
|
||||
exec_ids = []
|
||||
for i, (status, block_id) in enumerate(executions):
|
||||
# First create the execution
|
||||
exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id=f"node-{i}",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id=block_id,
|
||||
)
|
||||
exec_ids.append(exec_id)
|
||||
|
||||
# Then update its status and metadata
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=exec_id, status=status, execution_data={"block": block_id}
|
||||
)
|
||||
# Update cached execution with graph_exec_id and block_id for filtering
|
||||
# Note: In real implementation, these would be set during creation
|
||||
# For test purposes, we'll skip this manual update since the filtering
|
||||
# logic should work with the data as created
|
||||
|
||||
# Test status filtering
|
||||
running_execs = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.RUNNING]
|
||||
)
|
||||
assert len(running_execs) == 2
|
||||
assert all(e.status == ExecutionStatus.RUNNING for e in running_execs)
|
||||
|
||||
# Test block_id filtering
|
||||
block_a_execs = execution_client.get_node_executions(block_ids=["block-a"])
|
||||
assert len(block_a_execs) == 2
|
||||
assert all(e.block_id == "block-a" for e in block_a_execs)
|
||||
|
||||
# Test combined filtering
|
||||
running_block_b = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.RUNNING], block_ids=["block-b"]
|
||||
)
|
||||
assert len(running_block_b) == 1
|
||||
assert running_block_b[0].status == ExecutionStatus.RUNNING
|
||||
assert running_block_b[0].block_id == "block-b"
|
||||
|
||||
async def test_write_then_read_consistency(self, execution_client):
|
||||
"""Test critical race condition scenario: immediate read after write."""
|
||||
# First create an execution to work with
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="consistency-test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Write status
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"step": 1},
|
||||
)
|
||||
|
||||
# Write output
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="intermediate",
|
||||
output_data={"progress": 50},
|
||||
)
|
||||
|
||||
# Update status again
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
execution_data={"step": 2},
|
||||
)
|
||||
|
||||
# All changes should be immediately visible
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert cached_exec.status == ExecutionStatus.COMPLETED
|
||||
# execution_data gets merged with existing input_data - step 2 overwrites step 1
|
||||
expected_input_data = {"test_input": "test_value", "step": 2}
|
||||
assert cached_exec.input_data == expected_input_data
|
||||
|
||||
# Output should be visible in execution record
|
||||
assert cached_exec.output_data["intermediate"] == [{"progress": 50}]
|
||||
|
||||
async def test_concurrent_operations_are_thread_safe(self, execution_client):
|
||||
"""Test that concurrent operations don't corrupt cache."""
|
||||
num_threads = 3 # Reduced for simpler test
|
||||
operations_per_thread = 5 # Reduced for simpler test
|
||||
|
||||
# Create all executions upfront
|
||||
created_exec_ids = []
|
||||
for thread_id in range(num_threads):
|
||||
for i in range(operations_per_thread):
|
||||
exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id=f"node-{thread_id}-{i}",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id=f"block-{thread_id}-{i}",
|
||||
)
|
||||
created_exec_ids.append((exec_id, thread_id, i))
|
||||
|
||||
def worker(thread_data):
|
||||
"""Perform multiple operations from a thread."""
|
||||
thread_id, ops = thread_data
|
||||
for i, (exec_id, _, _) in enumerate(ops):
|
||||
# Status updates
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"thread": thread_id, "op": i},
|
||||
)
|
||||
|
||||
# Output updates (use just one exec_id per thread for outputs)
|
||||
if i == 0: # Only add outputs to first execution of each thread
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=exec_id,
|
||||
output_name=f"output_{i}",
|
||||
output_data={"thread": thread_id, "value": i},
|
||||
)
|
||||
|
||||
# Organize executions by thread
|
||||
thread_data = []
|
||||
for tid in range(num_threads):
|
||||
thread_ops = [
|
||||
exec_data for exec_data in created_exec_ids if exec_data[1] == tid
|
||||
]
|
||||
thread_data.append((tid, thread_ops))
|
||||
|
||||
# Start multiple threads
|
||||
threads = []
|
||||
for data in thread_data:
|
||||
thread = threading.Thread(target=worker, args=(data,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Wait for completion
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Verify data integrity
|
||||
expected_executions = num_threads * operations_per_thread
|
||||
all_executions = execution_client.get_node_executions()
|
||||
assert len(all_executions) == expected_executions
|
||||
|
||||
# Verify outputs - only first execution of each thread should have outputs
|
||||
output_count = 0
|
||||
for execution in all_executions:
|
||||
if execution.output_data:
|
||||
output_count += 1
|
||||
assert output_count == num_threads # One output per thread
|
||||
|
||||
async def test_sync_and_async_versions_consistent(self, execution_client):
|
||||
"""Test that sync and async versions of output operations behave the same."""
|
||||
# First create the execution
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="sync-async-test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "test"},
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="sync_result",
|
||||
output_data={"method": "sync"},
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="async_result",
|
||||
output_data={"method": "async"},
|
||||
)
|
||||
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert "sync_result" in cached_exec.output_data
|
||||
assert "async_result" in cached_exec.output_data
|
||||
assert cached_exec.output_data["sync_result"] == [{"method": "sync"}]
|
||||
assert cached_exec.output_data["async_result"] == [{"method": "async"}]
|
||||
|
||||
async def test_finalize_execution_completes_and_clears_cache(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that finalize_execution waits for background tasks and clears cache."""
|
||||
# First create the execution
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="pending-test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Trigger some background operations
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id, status=ExecutionStatus.RUNNING
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id, output_name="test", output_data={"value": 1}
|
||||
)
|
||||
|
||||
# Wait for background tasks - may fail in test environment due to DB issues
|
||||
try:
|
||||
execution_client.finalize_execution(timeout=5.0)
|
||||
except RuntimeError as e:
|
||||
# In test environment, background DB operations may fail, but cache should still be cleared
|
||||
assert "Background persistence failed" in str(e)
|
||||
|
||||
# Cache should be cleared regardless of background task failures
|
||||
all_executions = execution_client.get_node_executions()
|
||||
assert len(all_executions) == 0 # Cache should be cleared
|
||||
|
||||
async def test_manager_usage_pattern(self, execution_client):
|
||||
# Create executions first
|
||||
node_exec_id_1, _ = execution_client.upsert_execution_input(
|
||||
node_id="node-1",
|
||||
input_name="input1",
|
||||
input_data="data1",
|
||||
block_id="block-1",
|
||||
)
|
||||
|
||||
node_exec_id_2, _ = execution_client.upsert_execution_input(
|
||||
node_id="node-2",
|
||||
input_name="input_from_node1",
|
||||
input_data="value1",
|
||||
block_id="block-2",
|
||||
)
|
||||
|
||||
# Simulate manager.py workflow
|
||||
|
||||
# 1. Start execution
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id_1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "data1"},
|
||||
)
|
||||
|
||||
# 2. Node produces output
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id_1,
|
||||
output_name="result",
|
||||
output_data={"output": "value1"},
|
||||
)
|
||||
|
||||
# 3. Complete first node
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id_1, status=ExecutionStatus.COMPLETED
|
||||
)
|
||||
|
||||
# 4. Start second node (would read output from first)
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id_2,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input_from_node1": "value1"},
|
||||
)
|
||||
|
||||
# 5. Manager queries for executions
|
||||
|
||||
all_executions = execution_client.get_node_executions()
|
||||
running_executions = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.RUNNING]
|
||||
)
|
||||
completed_executions = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.COMPLETED]
|
||||
)
|
||||
|
||||
# Verify manager can see all data immediately
|
||||
assert len(all_executions) == 2
|
||||
assert len(running_executions) == 1
|
||||
assert len(completed_executions) == 1
|
||||
|
||||
# Verify output is accessible
|
||||
exec_1 = execution_client.get_node_execution(node_exec_id_1)
|
||||
assert exec_1 is not None
|
||||
assert exec_1.output_data["result"] == [{"output": "value1"}]
|
||||
|
||||
def test_stats_handling_in_update_node_status(self, execution_client):
|
||||
"""Test that stats parameter is properly handled in update_node_status_and_publish."""
|
||||
# Create a fake execution directly in cache to avoid database issues
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from backend.data.execution import NodeExecutionResult
|
||||
|
||||
node_exec_id = "test-stats-exec-id"
|
||||
fake_execution = NodeExecutionResult(
|
||||
user_id="test-user",
|
||||
graph_id="test-graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test-graph-exec",
|
||||
node_exec_id=node_exec_id,
|
||||
node_id="stats-test-node",
|
||||
block_id="test-block",
|
||||
status=ExecutionStatus.INCOMPLETE,
|
||||
input_data={"test_input": "test_value"},
|
||||
output_data={},
|
||||
add_time=datetime.now(timezone.utc),
|
||||
queue_time=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
stats=None,
|
||||
)
|
||||
|
||||
# Add directly to cache
|
||||
execution_client._cache.add_node_execution(node_exec_id, fake_execution)
|
||||
|
||||
stats = {"token_count": 150, "processing_time": 2.5}
|
||||
|
||||
# Update status with stats
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "test"},
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
# Verify execution was updated and stats are stored
|
||||
execution = execution_client.get_node_execution(node_exec_id)
|
||||
assert execution is not None
|
||||
assert execution.status == ExecutionStatus.RUNNING
|
||||
|
||||
# Stats should be stored in proper stats field
|
||||
assert execution.stats is not None
|
||||
stats_dict = execution.stats.model_dump()
|
||||
# Only check the fields we set, ignore defaults
|
||||
assert stats_dict["token_count"] == 150
|
||||
assert stats_dict["processing_time"] == 2.5
|
||||
|
||||
# Update with additional stats
|
||||
additional_stats = {"error_count": 0}
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
stats=additional_stats,
|
||||
)
|
||||
|
||||
# Stats should be merged
|
||||
execution = execution_client.get_node_execution(node_exec_id)
|
||||
assert execution is not None
|
||||
assert execution.status == ExecutionStatus.COMPLETED
|
||||
stats_dict = execution.stats.model_dump()
|
||||
# Check the merged stats
|
||||
assert stats_dict["token_count"] == 150
|
||||
assert stats_dict["processing_time"] == 2.5
|
||||
assert stats_dict["error_count"] == 0
|
||||
|
||||
async def test_upsert_execution_input_scenarios(self, execution_client):
|
||||
"""Test different scenarios of upsert_execution_input - create vs update."""
|
||||
node_id = "test-node"
|
||||
graph_exec_id = (
|
||||
"test_graph_exec_id" # Must match the ExecutionDataClient's scope
|
||||
)
|
||||
|
||||
# Scenario 1: Create new execution when none exists
|
||||
exec_id_1, input_data_1 = execution_client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="first_input",
|
||||
input_data="value1",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Should create new execution
|
||||
execution = execution_client.get_node_execution(exec_id_1)
|
||||
assert execution is not None
|
||||
assert execution.status == ExecutionStatus.INCOMPLETE
|
||||
assert execution.node_id == node_id
|
||||
assert execution.graph_exec_id == graph_exec_id
|
||||
assert input_data_1 == {"first_input": "value1"}
|
||||
|
||||
# Scenario 2: Add input to existing INCOMPLETE execution
|
||||
exec_id_2, input_data_2 = execution_client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="second_input",
|
||||
input_data="value2",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Should use same execution
|
||||
assert exec_id_2 == exec_id_1
|
||||
assert input_data_2 == {"first_input": "value1", "second_input": "value2"}
|
||||
|
||||
# Verify execution has both inputs
|
||||
execution = execution_client.get_node_execution(exec_id_1)
|
||||
assert execution is not None
|
||||
assert execution.input_data == {
|
||||
"first_input": "value1",
|
||||
"second_input": "value2",
|
||||
}
|
||||
|
||||
# Scenario 3: Create new execution when existing is not INCOMPLETE
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=exec_id_1, status=ExecutionStatus.RUNNING
|
||||
)
|
||||
|
||||
exec_id_3, input_data_3 = execution_client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="third_input",
|
||||
input_data="value3",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Should create new execution
|
||||
assert exec_id_3 != exec_id_1
|
||||
execution_3 = execution_client.get_node_execution(exec_id_3)
|
||||
assert execution_3 is not None
|
||||
assert input_data_3 == {"third_input": "value3"}
|
||||
|
||||
# Verify we now have 2 executions
|
||||
all_executions = execution_client.get_node_executions()
|
||||
assert len(all_executions) == 2
|
||||
|
||||
def test_graph_stats_operations(self, execution_client):
|
||||
"""Test graph-level stats and start time operations."""
|
||||
|
||||
# Test update_graph_stats_and_publish
|
||||
from backend.data.model import GraphExecutionStats
|
||||
|
||||
stats = GraphExecutionStats(
|
||||
walltime=10.5, cputime=8.2, node_count=5, node_error_count=1
|
||||
)
|
||||
|
||||
execution_client.update_graph_stats_and_publish(
|
||||
status=ExecutionStatus.RUNNING, stats=stats
|
||||
)
|
||||
|
||||
# Verify stats are stored in cache
|
||||
cached_stats = execution_client._cache._graph_stats
|
||||
assert cached_stats.walltime == 10.5
|
||||
|
||||
execution_client.update_graph_start_time_and_publish()
|
||||
cached_stats = execution_client._cache._graph_stats
|
||||
assert cached_stats.walltime == 10.5
|
||||
|
||||
def test_public_methods_accessible(self, execution_client):
|
||||
"""Test that public methods are accessible."""
|
||||
assert hasattr(execution_client._cache, "update_node_execution_status")
|
||||
assert hasattr(execution_client._cache, "upsert_execution_output")
|
||||
assert hasattr(execution_client._cache, "add_node_execution")
|
||||
assert hasattr(execution_client._cache, "find_incomplete_execution_for_input")
|
||||
assert hasattr(execution_client._cache, "update_execution_input")
|
||||
assert hasattr(execution_client, "upsert_execution_input")
|
||||
assert hasattr(execution_client, "update_node_status_and_publish")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -5,14 +5,31 @@ import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
||||
from typing import Any, Optional, TypeVar, cast
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from redis.asyncio.lock import Lock as RedisLock
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
from backend.data.block import (
|
||||
BlockData,
|
||||
BlockInput,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.execution import (
|
||||
ExecutionQueue,
|
||||
ExecutionStatus,
|
||||
GraphExecutionEntry,
|
||||
NodeExecutionEntry,
|
||||
UserContext,
|
||||
)
|
||||
from backend.data.graph import Link, Node
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
from backend.data.notifications import (
|
||||
AgentRunData,
|
||||
@@ -22,45 +39,14 @@ from backend.data.notifications import (
|
||||
ZeroBalanceData,
|
||||
)
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.activity_status_generator import (
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
from backend.executor.utils import LogMetadata
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.util.exceptions import InsufficientBalanceError, ModerationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerClient, DatabaseManagerAsyncClient
|
||||
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.block import (
|
||||
BlockInput,
|
||||
BlockOutput,
|
||||
BlockOutputEntry,
|
||||
BlockSchema,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.execution import (
|
||||
ExecutionQueue,
|
||||
ExecutionStatus,
|
||||
GraphExecution,
|
||||
GraphExecutionEntry,
|
||||
NodeExecutionEntry,
|
||||
NodeExecutionResult,
|
||||
NodesInputMasks,
|
||||
UserContext,
|
||||
)
|
||||
from backend.data.graph import Link, Node
|
||||
from backend.executor.execution_data import ExecutionDataClient
|
||||
from backend.executor.utils import (
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
CancelExecutionEvent,
|
||||
ExecutionOutputEntry,
|
||||
LogMetadata,
|
||||
NodeExecutionProgress,
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
@@ -69,21 +55,17 @@ from backend.executor.utils import (
|
||||
validate_exec,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.server.v2.AutoMod.manager import automod_manager
|
||||
from backend.util import json
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
get_database_manager_async_client,
|
||||
get_database_manager_client,
|
||||
get_execution_event_bus,
|
||||
get_notification_manager_client,
|
||||
)
|
||||
from backend.util.clients import get_notification_manager_client
|
||||
from backend.util.decorator import (
|
||||
async_error_logged,
|
||||
async_time_measured,
|
||||
error_logged,
|
||||
time_measured,
|
||||
)
|
||||
from backend.util.exceptions import InsufficientBalanceError, ModerationError
|
||||
from backend.util.file import clean_exec_files
|
||||
from backend.util.logging import TruncatedLogger, configure_logging
|
||||
from backend.util.metrics import DiscordChannel
|
||||
@@ -131,14 +113,13 @@ async def execute_node(
|
||||
creds_manager: IntegrationCredentialsManager,
|
||||
data: NodeExecutionEntry,
|
||||
execution_stats: NodeExecutionStats | None = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Execute a node in the graph. This will trigger a block execution on a node,
|
||||
persist the execution result, and return the subsequent node to be executed.
|
||||
|
||||
Args:
|
||||
db_client: The client to send execution updates to the server.
|
||||
creds_manager: The manager to acquire and release credentials.
|
||||
data: The execution data for executing the current node.
|
||||
execution_stats: The execution statistics to be updated.
|
||||
@@ -235,21 +216,20 @@ async def execute_node(
|
||||
|
||||
|
||||
async def _enqueue_next_nodes(
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
execution_data_client: ExecutionDataClient,
|
||||
node: Node,
|
||||
output: BlockOutputEntry,
|
||||
output: BlockData,
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[NodesInputMasks],
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
user_context: UserContext,
|
||||
) -> list[NodeExecutionEntry]:
|
||||
async def add_enqueued_execution(
|
||||
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
|
||||
) -> NodeExecutionEntry:
|
||||
await async_update_node_execution_status(
|
||||
db_client=db_client,
|
||||
execution_data_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.QUEUED,
|
||||
execution_data=data,
|
||||
@@ -282,21 +262,22 @@ async def _enqueue_next_nodes(
|
||||
next_data = parse_execution_output(output, next_output_name)
|
||||
if next_data is None and output_name != next_output_name:
|
||||
return enqueued_executions
|
||||
next_node = await db_client.get_node(next_node_id)
|
||||
next_node = await execution_data_client.get_node(next_node_id)
|
||||
|
||||
# Multiple node can register the same next node, we need this to be atomic
|
||||
# To avoid same execution to be enqueued multiple times,
|
||||
# Or the same input to be consumed multiple times.
|
||||
async with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"):
|
||||
with execution_data_client.graph_lock:
|
||||
# Add output data to the earliest incomplete execution, or create a new one.
|
||||
next_node_exec_id, next_node_input = await db_client.upsert_execution_input(
|
||||
node_id=next_node_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
input_name=next_input_name,
|
||||
input_data=next_data,
|
||||
next_node_exec_id, next_node_input = (
|
||||
execution_data_client.upsert_execution_input(
|
||||
node_id=next_node_id,
|
||||
input_name=next_input_name,
|
||||
input_data=next_data,
|
||||
block_id=next_node.block_id,
|
||||
)
|
||||
)
|
||||
await async_update_node_execution_status(
|
||||
db_client=db_client,
|
||||
execution_data_client.update_node_status_and_publish(
|
||||
exec_id=next_node_exec_id,
|
||||
status=ExecutionStatus.INCOMPLETE,
|
||||
)
|
||||
@@ -308,8 +289,8 @@ async def _enqueue_next_nodes(
|
||||
if link.is_static and link.sink_name not in next_node_input
|
||||
}
|
||||
if static_link_names and (
|
||||
latest_execution := await db_client.get_latest_node_execution(
|
||||
next_node_id, graph_exec_id
|
||||
latest_execution := execution_data_client.get_latest_node_execution(
|
||||
next_node_id
|
||||
)
|
||||
):
|
||||
for name in static_link_names:
|
||||
@@ -348,9 +329,8 @@ async def _enqueue_next_nodes(
|
||||
|
||||
# If link is static, there could be some incomplete executions waiting for it.
|
||||
# Load and complete the input missing input data, and try to re-enqueue them.
|
||||
for iexec in await db_client.get_node_executions(
|
||||
for iexec in execution_data_client.get_node_executions(
|
||||
node_id=next_node_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[ExecutionStatus.INCOMPLETE],
|
||||
):
|
||||
idata = iexec.input_data
|
||||
@@ -414,12 +394,15 @@ class ExecutionProcessor:
|
||||
9. Node executor enqueues the next executed nodes to the node execution queue.
|
||||
"""
|
||||
|
||||
# Current execution data client (scoped to current graph execution)
|
||||
execution_data: ExecutionDataClient
|
||||
|
||||
@async_error_logged(swallow=True)
|
||||
async def on_node_execution(
|
||||
self,
|
||||
node_exec: NodeExecutionEntry,
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
nodes_input_masks: Optional[NodesInputMasks],
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
|
||||
) -> NodeExecutionStats:
|
||||
log_metadata = LogMetadata(
|
||||
@@ -431,8 +414,7 @@ class ExecutionProcessor:
|
||||
node_id=node_exec.node_id,
|
||||
block_name="-",
|
||||
)
|
||||
db_client = get_db_async_client()
|
||||
node = await db_client.get_node(node_exec.node_id)
|
||||
node = await self.execution_data.get_node(node_exec.node_id)
|
||||
execution_stats = NodeExecutionStats()
|
||||
|
||||
timing_info, status = await self._on_node_execution(
|
||||
@@ -440,7 +422,6 @@ class ExecutionProcessor:
|
||||
node_exec=node_exec,
|
||||
node_exec_progress=node_exec_progress,
|
||||
stats=execution_stats,
|
||||
db_client=db_client,
|
||||
log_metadata=log_metadata,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
@@ -464,15 +445,12 @@ class ExecutionProcessor:
|
||||
if node_error and not isinstance(node_error, str):
|
||||
node_stats["error"] = str(node_error) or node_stats.__class__.__name__
|
||||
|
||||
await async_update_node_execution_status(
|
||||
db_client=db_client,
|
||||
self.execution_data.update_node_status_and_publish(
|
||||
exec_id=node_exec.node_exec_id,
|
||||
status=status,
|
||||
stats=node_stats,
|
||||
)
|
||||
await async_update_graph_execution_state(
|
||||
db_client=db_client,
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
self.execution_data.update_graph_stats_and_publish(
|
||||
stats=graph_stats,
|
||||
)
|
||||
|
||||
@@ -485,22 +463,17 @@ class ExecutionProcessor:
|
||||
node_exec: NodeExecutionEntry,
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
stats: NodeExecutionStats,
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> ExecutionStatus:
|
||||
status = ExecutionStatus.RUNNING
|
||||
|
||||
async def persist_output(output_name: str, output_data: Any) -> None:
|
||||
await db_client.upsert_execution_output(
|
||||
self.execution_data.upsert_execution_output(
|
||||
node_exec_id=node_exec.node_exec_id,
|
||||
output_name=output_name,
|
||||
output_data=output_data,
|
||||
)
|
||||
if exec_update := await db_client.get_node_execution(
|
||||
node_exec.node_exec_id
|
||||
):
|
||||
await send_async_execution_update(exec_update)
|
||||
|
||||
node_exec_progress.add_output(
|
||||
ExecutionOutputEntry(
|
||||
@@ -512,8 +485,7 @@ class ExecutionProcessor:
|
||||
|
||||
try:
|
||||
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
|
||||
await async_update_node_execution_status(
|
||||
db_client=db_client,
|
||||
self.execution_data.update_node_status_and_publish(
|
||||
exec_id=node_exec.node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
)
|
||||
@@ -574,6 +546,8 @@ class ExecutionProcessor:
|
||||
self.node_evaluation_thread = threading.Thread(
|
||||
target=self.node_evaluation_loop.run_forever, daemon=True
|
||||
)
|
||||
# single thread executor
|
||||
self.execution_data_executor = ThreadPoolExecutor(max_workers=1)
|
||||
self.node_execution_thread.start()
|
||||
self.node_evaluation_thread.start()
|
||||
logger.info(f"[GraphExecutor] {self.tid} started")
|
||||
@@ -593,9 +567,13 @@ class ExecutionProcessor:
|
||||
node_eid="*",
|
||||
block_name="-",
|
||||
)
|
||||
db_client = get_db_client()
|
||||
|
||||
exec_meta = db_client.get_graph_execution_meta(
|
||||
# Get graph execution metadata first via sync client
|
||||
from backend.util.clients import get_database_manager_client
|
||||
|
||||
db_client_sync = get_database_manager_client()
|
||||
|
||||
exec_meta = db_client_sync.get_graph_execution_meta(
|
||||
user_id=graph_exec.user_id,
|
||||
execution_id=graph_exec.graph_exec_id,
|
||||
)
|
||||
@@ -605,12 +583,15 @@ class ExecutionProcessor:
|
||||
)
|
||||
return
|
||||
|
||||
if exec_meta.status in [ExecutionStatus.QUEUED, ExecutionStatus.INCOMPLETE]:
|
||||
# Create scoped ExecutionDataClient for this graph execution with metadata
|
||||
self.execution_data = ExecutionDataClient(
|
||||
self.execution_data_executor, graph_exec.graph_exec_id, exec_meta
|
||||
)
|
||||
|
||||
if exec_meta.status == ExecutionStatus.QUEUED:
|
||||
log_metadata.info(f"⚙️ Starting graph execution #{graph_exec.graph_exec_id}")
|
||||
exec_meta.status = ExecutionStatus.RUNNING
|
||||
send_execution_update(
|
||||
db_client.update_graph_execution_start_time(graph_exec.graph_exec_id)
|
||||
)
|
||||
self.execution_data.update_graph_start_time_and_publish()
|
||||
elif exec_meta.status == ExecutionStatus.RUNNING:
|
||||
log_metadata.info(
|
||||
f"⚙️ Graph execution #{graph_exec.graph_exec_id} is already running, continuing where it left off."
|
||||
@@ -620,9 +601,7 @@ class ExecutionProcessor:
|
||||
log_metadata.info(
|
||||
f"⚙️ Graph execution #{graph_exec.graph_exec_id} was disturbed, continuing where it left off."
|
||||
)
|
||||
update_graph_execution_state(
|
||||
db_client=db_client,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
self.execution_data.update_graph_stats_and_publish(
|
||||
status=ExecutionStatus.RUNNING,
|
||||
)
|
||||
else:
|
||||
@@ -653,12 +632,10 @@ class ExecutionProcessor:
|
||||
|
||||
# Activity status handling
|
||||
activity_status = asyncio.run_coroutine_threadsafe(
|
||||
generate_activity_status_for_execution(
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
self.execution_data.generate_activity_status(
|
||||
graph_id=graph_exec.graph_id,
|
||||
graph_version=graph_exec.graph_version,
|
||||
execution_stats=exec_stats,
|
||||
db_client=get_db_async_client(),
|
||||
user_id=graph_exec.user_id,
|
||||
execution_status=status,
|
||||
),
|
||||
@@ -673,15 +650,14 @@ class ExecutionProcessor:
|
||||
)
|
||||
|
||||
# Communication handling
|
||||
self._handle_agent_run_notif(db_client, graph_exec, exec_stats)
|
||||
self._handle_agent_run_notif(graph_exec, exec_stats)
|
||||
|
||||
finally:
|
||||
update_graph_execution_state(
|
||||
db_client=db_client,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
self.execution_data.update_graph_stats_and_publish(
|
||||
status=exec_meta.status,
|
||||
stats=exec_stats,
|
||||
)
|
||||
self.execution_data.finalize_execution()
|
||||
|
||||
def _charge_usage(
|
||||
self,
|
||||
@@ -690,7 +666,6 @@ class ExecutionProcessor:
|
||||
) -> tuple[int, int]:
|
||||
total_cost = 0
|
||||
remaining_balance = 0
|
||||
db_client = get_db_client()
|
||||
block = get_block(node_exec.block_id)
|
||||
if not block:
|
||||
logger.error(f"Block {node_exec.block_id} not found.")
|
||||
@@ -700,7 +675,7 @@ class ExecutionProcessor:
|
||||
block=block, input_data=node_exec.inputs
|
||||
)
|
||||
if cost > 0:
|
||||
remaining_balance = db_client.spend_credits(
|
||||
remaining_balance = self.execution_data.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
@@ -718,7 +693,7 @@ class ExecutionProcessor:
|
||||
|
||||
cost, usage_count = execution_usage_cost(execution_count)
|
||||
if cost > 0:
|
||||
remaining_balance = db_client.spend_credits(
|
||||
remaining_balance = self.execution_data.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
@@ -751,7 +726,6 @@ class ExecutionProcessor:
|
||||
"""
|
||||
execution_status: ExecutionStatus = ExecutionStatus.RUNNING
|
||||
error: Exception | None = None
|
||||
db_client = get_db_client()
|
||||
execution_stats_lock = threading.Lock()
|
||||
|
||||
# State holders ----------------------------------------------------
|
||||
@@ -762,7 +736,7 @@ class ExecutionProcessor:
|
||||
execution_queue = ExecutionQueue[NodeExecutionEntry]()
|
||||
|
||||
try:
|
||||
if db_client.get_credits(graph_exec.user_id) <= 0:
|
||||
if self.execution_data.get_credits(graph_exec.user_id) <= 0:
|
||||
raise InsufficientBalanceError(
|
||||
user_id=graph_exec.user_id,
|
||||
message="You have no credits left to run an agent.",
|
||||
@@ -774,7 +748,7 @@ class ExecutionProcessor:
|
||||
try:
|
||||
if moderation_error := asyncio.run_coroutine_threadsafe(
|
||||
automod_manager.moderate_graph_execution_inputs(
|
||||
db_client=get_db_async_client(),
|
||||
db_client=self.execution_data.db_client_async,
|
||||
graph_exec=graph_exec,
|
||||
),
|
||||
self.node_evaluation_loop,
|
||||
@@ -789,16 +763,34 @@ class ExecutionProcessor:
|
||||
# ------------------------------------------------------------
|
||||
# Pre‑populate queue ---------------------------------------
|
||||
# ------------------------------------------------------------
|
||||
for node_exec in db_client.get_node_executions(
|
||||
graph_exec.graph_exec_id,
|
||||
|
||||
queued_executions = self.execution_data.get_node_executions(
|
||||
statuses=[
|
||||
ExecutionStatus.RUNNING,
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.TERMINATED,
|
||||
],
|
||||
):
|
||||
node_entry = node_exec.to_node_execution_entry(graph_exec.user_context)
|
||||
execution_queue.add(node_entry)
|
||||
)
|
||||
log_metadata.info(
|
||||
f"Pre-populating queue with {len(queued_executions)} executions from cache"
|
||||
)
|
||||
|
||||
for i, node_exec in enumerate(queued_executions):
|
||||
log_metadata.info(
|
||||
f" [{i}] {node_exec.node_exec_id}: status={node_exec.status}, node={node_exec.node_id}"
|
||||
)
|
||||
try:
|
||||
node_entry = node_exec.to_node_execution_entry(
|
||||
graph_exec.user_context
|
||||
)
|
||||
execution_queue.add(node_entry)
|
||||
log_metadata.info(" Added to execution queue successfully")
|
||||
except Exception as e:
|
||||
log_metadata.error(f" Failed to add to execution queue: {e}")
|
||||
|
||||
log_metadata.info(
|
||||
f"Execution queue populated with {len(queued_executions)} executions"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Main dispatch / polling loop -----------------------------
|
||||
@@ -818,13 +810,14 @@ class ExecutionProcessor:
|
||||
try:
|
||||
cost, remaining_balance = self._charge_usage(
|
||||
node_exec=queued_node_exec,
|
||||
execution_count=increment_execution_count(graph_exec.user_id),
|
||||
execution_count=self.execution_data.increment_execution_count(
|
||||
graph_exec.user_id
|
||||
),
|
||||
)
|
||||
with execution_stats_lock:
|
||||
execution_stats.cost += cost
|
||||
# Check if we crossed the low balance threshold
|
||||
self._handle_low_balance(
|
||||
db_client=db_client,
|
||||
user_id=graph_exec.user_id,
|
||||
current_balance=remaining_balance,
|
||||
transaction_cost=cost,
|
||||
@@ -832,19 +825,17 @@ class ExecutionProcessor:
|
||||
except InsufficientBalanceError as balance_error:
|
||||
error = balance_error # Set error to trigger FAILED status
|
||||
node_exec_id = queued_node_exec.node_exec_id
|
||||
db_client.upsert_execution_output(
|
||||
self.execution_data.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="error",
|
||||
output_data=str(error),
|
||||
)
|
||||
update_node_execution_status(
|
||||
db_client=db_client,
|
||||
self.execution_data.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.FAILED,
|
||||
)
|
||||
|
||||
self._handle_insufficient_funds_notif(
|
||||
db_client,
|
||||
graph_exec.user_id,
|
||||
graph_exec.graph_id,
|
||||
error,
|
||||
@@ -931,12 +922,13 @@ class ExecutionProcessor:
|
||||
time.sleep(0.1)
|
||||
|
||||
# loop done --------------------------------------------------
|
||||
# Background task finalization moved to finally block
|
||||
|
||||
# Output moderation
|
||||
try:
|
||||
if moderation_error := asyncio.run_coroutine_threadsafe(
|
||||
automod_manager.moderate_graph_execution_outputs(
|
||||
db_client=get_db_async_client(),
|
||||
db_client=self.execution_data.db_client_async,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
user_id=graph_exec.user_id,
|
||||
graph_id=graph_exec.graph_id,
|
||||
@@ -990,7 +982,6 @@ class ExecutionProcessor:
|
||||
error=error,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
log_metadata=log_metadata,
|
||||
db_client=db_client,
|
||||
)
|
||||
|
||||
@error_logged(swallow=True)
|
||||
@@ -1003,7 +994,6 @@ class ExecutionProcessor:
|
||||
error: Exception | None,
|
||||
graph_exec_id: str,
|
||||
log_metadata: LogMetadata,
|
||||
db_client: "DatabaseManagerClient",
|
||||
) -> None:
|
||||
"""
|
||||
Clean up running node executions and evaluations when graph execution ends.
|
||||
@@ -1037,8 +1027,7 @@ class ExecutionProcessor:
|
||||
)
|
||||
|
||||
while queued_execution := execution_queue.get_or_none():
|
||||
update_node_execution_status(
|
||||
db_client=db_client,
|
||||
self.execution_data.update_node_status_and_publish(
|
||||
exec_id=queued_execution.node_exec_id,
|
||||
status=execution_status,
|
||||
stats={"error": str(error)} if error else None,
|
||||
@@ -1053,7 +1042,7 @@ class ExecutionProcessor:
|
||||
node_id: str,
|
||||
graph_exec: GraphExecutionEntry,
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[NodesInputMasks],
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
execution_queue: ExecutionQueue[NodeExecutionEntry],
|
||||
) -> None:
|
||||
"""Process a node's output, update its status, and enqueue next nodes.
|
||||
@@ -1066,12 +1055,10 @@ class ExecutionProcessor:
|
||||
nodes_input_masks: Optional map of node input overrides
|
||||
execution_queue: Queue to add next executions to
|
||||
"""
|
||||
db_client = get_db_async_client()
|
||||
|
||||
log_metadata.debug(f"Enqueue nodes for {node_id}: {output}")
|
||||
|
||||
for next_execution in await _enqueue_next_nodes(
|
||||
db_client=db_client,
|
||||
execution_data_client=self.execution_data,
|
||||
node=output.node,
|
||||
output=output.data,
|
||||
user_id=graph_exec.user_id,
|
||||
@@ -1085,15 +1072,13 @@ class ExecutionProcessor:
|
||||
|
||||
def _handle_agent_run_notif(
|
||||
self,
|
||||
db_client: "DatabaseManagerClient",
|
||||
graph_exec: GraphExecutionEntry,
|
||||
exec_stats: GraphExecutionStats,
|
||||
):
|
||||
metadata = db_client.get_graph_metadata(
|
||||
metadata = self.execution_data.get_graph_metadata(
|
||||
graph_exec.graph_id, graph_exec.graph_version
|
||||
)
|
||||
outputs = db_client.get_node_executions(
|
||||
graph_exec.graph_exec_id,
|
||||
outputs = self.execution_data.get_node_executions(
|
||||
block_ids=[AgentOutputBlock().id],
|
||||
)
|
||||
|
||||
@@ -1122,13 +1107,12 @@ class ExecutionProcessor:
|
||||
|
||||
def _handle_insufficient_funds_notif(
|
||||
self,
|
||||
db_client: "DatabaseManagerClient",
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
e: InsufficientBalanceError,
|
||||
):
|
||||
shortfall = abs(e.amount) - e.balance
|
||||
metadata = db_client.get_graph_metadata(graph_id)
|
||||
metadata = self.execution_data.get_graph_metadata(graph_id)
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
@@ -1147,7 +1131,7 @@ class ExecutionProcessor:
|
||||
)
|
||||
|
||||
try:
|
||||
user_email = db_client.get_user_email_by_id(user_id)
|
||||
user_email = self.execution_data.get_user_email_by_id(user_id)
|
||||
|
||||
alert_message = (
|
||||
f"❌ **Insufficient Funds Alert**\n"
|
||||
@@ -1169,7 +1153,6 @@ class ExecutionProcessor:
|
||||
|
||||
def _handle_low_balance(
|
||||
self,
|
||||
db_client: "DatabaseManagerClient",
|
||||
user_id: str,
|
||||
current_balance: int,
|
||||
transaction_cost: int,
|
||||
@@ -1198,7 +1181,7 @@ class ExecutionProcessor:
|
||||
)
|
||||
|
||||
try:
|
||||
user_email = db_client.get_user_email_by_id(user_id)
|
||||
user_email = self.execution_data.get_user_email_by_id(user_id)
|
||||
alert_message = (
|
||||
f"⚠️ **Low Balance Alert**\n"
|
||||
f"User: {user_email or user_id}\n"
|
||||
@@ -1576,117 +1559,3 @@ class ExecutionManager(AppProcess):
|
||||
)
|
||||
|
||||
logger.info(f"{prefix} ✅ Finished GraphExec cleanup")
|
||||
|
||||
|
||||
# ------- UTILITIES ------- #
|
||||
|
||||
|
||||
def get_db_client() -> "DatabaseManagerClient":
|
||||
return get_database_manager_client()
|
||||
|
||||
|
||||
def get_db_async_client() -> "DatabaseManagerAsyncClient":
|
||||
return get_database_manager_async_client()
|
||||
|
||||
|
||||
@func_retry
|
||||
async def send_async_execution_update(
|
||||
entry: GraphExecution | NodeExecutionResult | None,
|
||||
) -> None:
|
||||
if entry is None:
|
||||
return
|
||||
await get_async_execution_event_bus().publish(entry)
|
||||
|
||||
|
||||
@func_retry
|
||||
def send_execution_update(entry: GraphExecution | NodeExecutionResult | None):
|
||||
if entry is None:
|
||||
return
|
||||
return get_execution_event_bus().publish(entry)
|
||||
|
||||
|
||||
async def async_update_node_execution_status(
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: BlockInput | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
) -> NodeExecutionResult:
|
||||
"""Sets status and fetches+broadcasts the latest state of the node execution"""
|
||||
exec_update = await db_client.update_node_execution_status(
|
||||
exec_id, status, execution_data, stats
|
||||
)
|
||||
await send_async_execution_update(exec_update)
|
||||
return exec_update
|
||||
|
||||
|
||||
def update_node_execution_status(
|
||||
db_client: "DatabaseManagerClient",
|
||||
exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: BlockInput | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
) -> NodeExecutionResult:
|
||||
"""Sets status and fetches+broadcasts the latest state of the node execution"""
|
||||
exec_update = db_client.update_node_execution_status(
|
||||
exec_id, status, execution_data, stats
|
||||
)
|
||||
send_execution_update(exec_update)
|
||||
return exec_update
|
||||
|
||||
|
||||
async def async_update_graph_execution_state(
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
graph_exec_id: str,
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> GraphExecution | None:
|
||||
"""Sets status and fetches+broadcasts the latest state of the graph execution"""
|
||||
graph_update = await db_client.update_graph_execution_stats(
|
||||
graph_exec_id, status, stats
|
||||
)
|
||||
if graph_update:
|
||||
await send_async_execution_update(graph_update)
|
||||
else:
|
||||
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
|
||||
return graph_update
|
||||
|
||||
|
||||
def update_graph_execution_state(
|
||||
db_client: "DatabaseManagerClient",
|
||||
graph_exec_id: str,
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> GraphExecution | None:
|
||||
"""Sets status and fetches+broadcasts the latest state of the graph execution"""
|
||||
graph_update = db_client.update_graph_execution_stats(graph_exec_id, status, stats)
|
||||
if graph_update:
|
||||
send_execution_update(graph_update)
|
||||
else:
|
||||
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
|
||||
return graph_update
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def synchronized(key: str, timeout: int = 60):
|
||||
r = await redis.get_redis_async()
|
||||
lock: RedisLock = r.lock(f"lock:{key}", timeout=timeout)
|
||||
try:
|
||||
await lock.acquire()
|
||||
yield
|
||||
finally:
|
||||
if await lock.locked() and await lock.owned():
|
||||
await lock.release()
|
||||
|
||||
|
||||
def increment_execution_count(user_id: str) -> int:
|
||||
"""
|
||||
Increment the execution count for a given user,
|
||||
this will be used to charge the user for the execution cost.
|
||||
"""
|
||||
r = redis.get_redis()
|
||||
k = f"uec:{user_id}" # User Execution Count global key
|
||||
counter = cast(int, r.incr(k))
|
||||
if counter == 1:
|
||||
r.expire(k, settings.config.execution_counter_expiration_time)
|
||||
return counter
|
||||
|
||||
@@ -32,13 +32,17 @@ async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
|
||||
mock_settings.config.low_balance_threshold = 500 # $5 threshold
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
# Initialize the execution processor and mock its execution_data
|
||||
execution_processor.on_graph_executor_start()
|
||||
|
||||
# Mock the execution_data attribute since it's created in on_graph_execution
|
||||
mock_execution_data = MagicMock()
|
||||
execution_processor.execution_data = mock_execution_data
|
||||
|
||||
mock_execution_data.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
transaction_cost=transaction_cost,
|
||||
@@ -62,6 +66,19 @@ async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
|
||||
assert "$4.00" in discord_message
|
||||
assert "$6.00" in discord_message
|
||||
|
||||
# Cleanup execution processor threads
|
||||
try:
|
||||
execution_processor.node_execution_loop.call_soon_threadsafe(
|
||||
execution_processor.node_execution_loop.stop
|
||||
)
|
||||
execution_processor.node_evaluation_loop.call_soon_threadsafe(
|
||||
execution_processor.node_evaluation_loop.stop
|
||||
)
|
||||
execution_processor.node_execution_thread.join(timeout=1)
|
||||
execution_processor.node_evaluation_thread.join(timeout=1)
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_low_balance_no_notification_when_not_crossing(
|
||||
@@ -90,12 +107,17 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.low_balance_threshold = 500 # $5 threshold
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
# Initialize the execution processor and mock its execution_data
|
||||
execution_processor.on_graph_executor_start()
|
||||
|
||||
# Mock the execution_data attribute since it's created in on_graph_execution
|
||||
mock_execution_data = MagicMock()
|
||||
execution_processor.execution_data = mock_execution_data
|
||||
|
||||
mock_execution_data.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
transaction_cost=transaction_cost,
|
||||
@@ -105,6 +127,19 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
|
||||
mock_queue_notif.assert_not_called()
|
||||
mock_client.discord_system_alert.assert_not_called()
|
||||
|
||||
# Cleanup execution processor threads
|
||||
try:
|
||||
execution_processor.node_execution_loop.call_soon_threadsafe(
|
||||
execution_processor.node_execution_loop.stop
|
||||
)
|
||||
execution_processor.node_evaluation_loop.call_soon_threadsafe(
|
||||
execution_processor.node_evaluation_loop.stop
|
||||
)
|
||||
execution_processor.node_execution_thread.join(timeout=1)
|
||||
execution_processor.node_evaluation_thread.join(timeout=1)
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_low_balance_no_duplicate_when_already_below(
|
||||
@@ -133,12 +168,17 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.low_balance_threshold = 500 # $5 threshold
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
# Initialize the execution processor and mock its execution_data
|
||||
execution_processor.on_graph_executor_start()
|
||||
|
||||
# Mock the execution_data attribute since it's created in on_graph_execution
|
||||
mock_execution_data = MagicMock()
|
||||
execution_processor.execution_data = mock_execution_data
|
||||
|
||||
mock_execution_data.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
transaction_cost=transaction_cost,
|
||||
@@ -147,3 +187,16 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
|
||||
# Verify no notification was sent (user was already below threshold)
|
||||
mock_queue_notif.assert_not_called()
|
||||
mock_client.discord_system_alert.assert_not_called()
|
||||
|
||||
# Cleanup execution processor threads
|
||||
try:
|
||||
execution_processor.node_execution_loop.call_soon_threadsafe(
|
||||
execution_processor.node_execution_loop.stop
|
||||
)
|
||||
execution_processor.node_evaluation_loop.call_soon_threadsafe(
|
||||
execution_processor.node_evaluation_loop.stop
|
||||
)
|
||||
execution_processor.node_execution_thread.join(timeout=1)
|
||||
execution_processor.node_evaluation_thread.join(timeout=1)
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
@@ -35,20 +35,21 @@ async def execute_graph(
|
||||
logger.info(f"Input data: {input_data}")
|
||||
|
||||
# --- Test adding new executions --- #
|
||||
graph_exec = await agent_server.test_execute_graph(
|
||||
response = await agent_server.test_execute_graph(
|
||||
user_id=test_user.id,
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
node_input=input_data,
|
||||
)
|
||||
logger.info(f"Created execution with ID: {graph_exec.id}")
|
||||
graph_exec_id = response.graph_exec_id
|
||||
logger.info(f"Created execution with ID: {graph_exec_id}")
|
||||
|
||||
# Execution queue should be empty
|
||||
logger.info("Waiting for execution to complete...")
|
||||
result = await wait_execution(test_user.id, graph_exec.id, 30)
|
||||
result = await wait_execution(test_user.id, graph_exec_id, 30)
|
||||
logger.info(f"Execution completed with {len(result)} results")
|
||||
assert len(result) == num_execs
|
||||
return graph_exec.id
|
||||
return graph_exec_id
|
||||
|
||||
|
||||
async def assert_sample_graph_executions(
|
||||
@@ -378,7 +379,7 @@ async def test_execute_preset(server: SpinTestServer):
|
||||
|
||||
# Verify execution
|
||||
assert result is not None
|
||||
graph_exec_id = result.id
|
||||
graph_exec_id = result["id"]
|
||||
|
||||
# Wait for execution to complete
|
||||
executions = await wait_execution(test_user.id, graph_exec_id)
|
||||
@@ -467,7 +468,7 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
|
||||
|
||||
# Verify execution
|
||||
assert result is not None, "Result must not be None"
|
||||
graph_exec_id = result.id
|
||||
graph_exec_id = result["id"]
|
||||
|
||||
# Wait for execution to complete
|
||||
executions = await wait_execution(test_user.id, graph_exec_id)
|
||||
|
||||
@@ -191,22 +191,15 @@ class GraphExecutionJobInfo(GraphExecutionJobArgs):
|
||||
id: str
|
||||
name: str
|
||||
next_run_time: str
|
||||
timezone: str = Field(default="UTC", description="Timezone used for scheduling")
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
job_args: GraphExecutionJobArgs, job_obj: JobObj
|
||||
) -> "GraphExecutionJobInfo":
|
||||
# Extract timezone from the trigger if it's a CronTrigger
|
||||
timezone_str = "UTC"
|
||||
if hasattr(job_obj.trigger, "timezone"):
|
||||
timezone_str = str(job_obj.trigger.timezone)
|
||||
|
||||
return GraphExecutionJobInfo(
|
||||
id=job_obj.id,
|
||||
name=job_obj.name,
|
||||
next_run_time=job_obj.next_run_time.isoformat(),
|
||||
timezone=timezone_str,
|
||||
**job_args.model_dump(),
|
||||
)
|
||||
|
||||
@@ -402,7 +395,6 @@ class Scheduler(AppService):
|
||||
input_data: BlockInput,
|
||||
input_credentials: dict[str, CredentialsMetaInput],
|
||||
name: Optional[str] = None,
|
||||
user_timezone: str | None = None,
|
||||
) -> GraphExecutionJobInfo:
|
||||
# Validate the graph before scheduling to prevent runtime failures
|
||||
# We don't need the return value, just want the validation to run
|
||||
@@ -416,18 +408,7 @@ class Scheduler(AppService):
|
||||
)
|
||||
)
|
||||
|
||||
# Use provided timezone or default to UTC
|
||||
# Note: Timezone should be passed from the client to avoid database lookups
|
||||
if not user_timezone:
|
||||
user_timezone = "UTC"
|
||||
logger.warning(
|
||||
f"No timezone provided for user {user_id}, using UTC for scheduling. "
|
||||
f"Client should pass user's timezone for correct scheduling."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Scheduling job for user {user_id} with timezone {user_timezone} (cron: {cron})"
|
||||
)
|
||||
logger.info(f"Scheduling job for user {user_id} in UTC (cron: {cron})")
|
||||
|
||||
job_args = GraphExecutionJobArgs(
|
||||
user_id=user_id,
|
||||
@@ -441,12 +422,12 @@ class Scheduler(AppService):
|
||||
execute_graph,
|
||||
kwargs=job_args.model_dump(),
|
||||
name=name,
|
||||
trigger=CronTrigger.from_crontab(cron, timezone=user_timezone),
|
||||
trigger=CronTrigger.from_crontab(cron, timezone="UTC"),
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
replace_existing=True,
|
||||
)
|
||||
logger.info(
|
||||
f"Added job {job.id} with cron schedule '{cron}' in timezone {user_timezone}, input data: {input_data}"
|
||||
f"Added job {job.id} with cron schedule '{cron}' in UTC, input data: {input_data}"
|
||||
)
|
||||
return GraphExecutionJobInfo.from_db(job_args, job)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from backend.executor.utils import merge_execution_input, parse_execution_output
|
||||
from backend.util.mock import MockObject
|
||||
@@ -277,147 +276,3 @@ def test_merge_execution_input():
|
||||
result = merge_execution_input(data)
|
||||
assert "mixed" in result
|
||||
assert result["mixed"].attr[0]["key"] == "value3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
"""
|
||||
Verify that calling the function with its own output creates the same execution again.
|
||||
"""
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# Mock data
|
||||
graph_id = "test-graph-id"
|
||||
user_id = "test-user-id"
|
||||
inputs = {"test_input": "test_value"}
|
||||
preset_id = "test-preset-id"
|
||||
graph_version = 1
|
||||
graph_credentials_inputs = {
|
||||
"cred_key": CredentialsMetaInput(
|
||||
id="cred-id", provider=ProviderName("test_provider"), type="oauth2"
|
||||
)
|
||||
}
|
||||
nodes_input_masks = {"node1": {"input1": "masked_value"}}
|
||||
|
||||
# Mock the graph object returned by validate_and_construct_node_execution_input
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.version = graph_version
|
||||
|
||||
# Mock the starting nodes input and compiled nodes input masks
|
||||
starting_nodes_input = [
|
||||
("node1", {"input1": "value1"}),
|
||||
("node2", {"input1": "value2"}),
|
||||
]
|
||||
compiled_nodes_input_masks = {"node1": {"input1": "compiled_mask"}}
|
||||
|
||||
# Mock the graph execution object
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
|
||||
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||
|
||||
# Mock user context
|
||||
mock_user_context = {"user_id": user_id, "context": "test_context"}
|
||||
|
||||
# Mock the queue and event bus
|
||||
mock_queue = mocker.AsyncMock()
|
||||
mock_event_bus = mocker.MagicMock()
|
||||
mock_event_bus.publish = mocker.AsyncMock()
|
||||
|
||||
# Setup mocks
|
||||
mock_validate = mocker.patch(
|
||||
"backend.executor.utils.validate_and_construct_node_execution_input"
|
||||
)
|
||||
mock_edb = mocker.patch("backend.executor.utils.execution_db")
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_get_user_context = mocker.patch("backend.executor.utils.get_user_context")
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
|
||||
# Setup mock returns
|
||||
mock_validate.return_value = (
|
||||
mock_graph,
|
||||
starting_nodes_input,
|
||||
compiled_nodes_input_masks,
|
||||
)
|
||||
mock_prisma.is_connected.return_value = True
|
||||
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
||||
mock_edb.update_graph_execution_stats = mocker.AsyncMock(
|
||||
return_value=mock_graph_exec
|
||||
)
|
||||
mock_edb.update_node_execution_status_batch = mocker.AsyncMock()
|
||||
mock_get_user_context.return_value = mock_user_context
|
||||
mock_get_queue.return_value = mock_queue
|
||||
mock_get_event_bus.return_value = mock_event_bus
|
||||
|
||||
# Call the function - first execution
|
||||
result1 = await add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
preset_id=preset_id,
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=graph_credentials_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
|
||||
# Store the parameters used in the first call to create_graph_execution
|
||||
first_call_kwargs = mock_edb.create_graph_execution.call_args[1]
|
||||
|
||||
# Verify the create_graph_execution was called with correct parameters
|
||||
mock_edb.create_graph_execution.assert_called_once_with(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=mock_graph.version,
|
||||
inputs=inputs,
|
||||
credential_inputs=graph_credentials_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
starting_nodes_input=starting_nodes_input,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
# Set up the graph execution mock to have properties we can extract
|
||||
mock_graph_exec.graph_id = graph_id
|
||||
mock_graph_exec.user_id = user_id
|
||||
mock_graph_exec.graph_version = graph_version
|
||||
mock_graph_exec.inputs = inputs
|
||||
mock_graph_exec.credential_inputs = graph_credentials_inputs
|
||||
mock_graph_exec.nodes_input_masks = nodes_input_masks
|
||||
mock_graph_exec.preset_id = preset_id
|
||||
|
||||
# Create a second mock execution for the sanity check
|
||||
mock_graph_exec_2 = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec_2.id = "execution-id-456"
|
||||
mock_graph_exec_2.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||
|
||||
# Reset mocks and set up for second call
|
||||
mock_edb.create_graph_execution.reset_mock()
|
||||
mock_edb.create_graph_execution.return_value = mock_graph_exec_2
|
||||
mock_validate.reset_mock()
|
||||
|
||||
# Sanity check: call add_graph_execution with properties from first result
|
||||
# This should create the same execution parameters
|
||||
result2 = await add_graph_execution(
|
||||
graph_id=mock_graph_exec.graph_id,
|
||||
user_id=mock_graph_exec.user_id,
|
||||
inputs=mock_graph_exec.inputs,
|
||||
preset_id=mock_graph_exec.preset_id,
|
||||
graph_version=mock_graph_exec.graph_version,
|
||||
graph_credentials_inputs=mock_graph_exec.credential_inputs,
|
||||
nodes_input_masks=mock_graph_exec.nodes_input_masks,
|
||||
)
|
||||
|
||||
# Verify that create_graph_execution was called with identical parameters
|
||||
second_call_kwargs = mock_edb.create_graph_execution.call_args[1]
|
||||
|
||||
# The sanity check: both calls should use identical parameters
|
||||
assert first_call_kwargs == second_call_kwargs
|
||||
|
||||
# Both executions should succeed (though they create different objects)
|
||||
assert result1 == mock_graph_exec
|
||||
assert result2 == mock_graph_exec_2
|
||||
@@ -4,27 +4,20 @@ import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future
|
||||
from typing import Any, Mapping, Optional, cast
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, JsonValue, ValidationError
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCostType,
|
||||
BlockInput,
|
||||
BlockOutputEntry,
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCostType
|
||||
from backend.data.db import prisma
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionStats,
|
||||
GraphExecutionWithNodes,
|
||||
NodesInputMasks,
|
||||
UserContext,
|
||||
)
|
||||
from backend.data.graph import GraphModel, Node
|
||||
@@ -246,7 +239,7 @@ def _tokenise(path: str) -> list[tuple[str, str]] | None:
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def parse_execution_output(output: BlockOutputEntry, name: str) -> JsonValue | None:
|
||||
def parse_execution_output(output: BlockData, name: str) -> Any | None:
|
||||
"""
|
||||
Retrieve a nested value out of `output` using the flattened *name*.
|
||||
|
||||
@@ -270,7 +263,7 @@ def parse_execution_output(output: BlockOutputEntry, name: str) -> JsonValue | N
|
||||
if tokens is None:
|
||||
return None
|
||||
|
||||
cur: JsonValue = data
|
||||
cur: Any = data
|
||||
for delim, ident in tokens:
|
||||
if delim == LIST_SPLIT:
|
||||
# list[index]
|
||||
@@ -435,7 +428,7 @@ def validate_exec(
|
||||
async def _validate_node_input_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Checks all credentials for all nodes of the graph and returns structured errors.
|
||||
@@ -515,8 +508,8 @@ async def _validate_node_input_credentials(
|
||||
|
||||
def make_node_credentials_input_map(
|
||||
graph: GraphModel,
|
||||
graph_credentials_input: Mapping[str, CredentialsMetaInput],
|
||||
) -> NodesInputMasks:
|
||||
graph_credentials_input: dict[str, CredentialsMetaInput],
|
||||
) -> dict[str, dict[str, JsonValue]]:
|
||||
"""
|
||||
Maps credentials for an execution to the correct nodes.
|
||||
|
||||
@@ -551,8 +544,8 @@ def make_node_credentials_input_map(
|
||||
async def validate_graph_with_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> Mapping[str, Mapping[str, str]]:
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Validate graph including credentials and return structured errors per node.
|
||||
|
||||
@@ -582,7 +575,7 @@ async def _construct_starting_node_execution_input(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> list[tuple[str, BlockInput]]:
|
||||
"""
|
||||
Validates and prepares the input data for executing a graph.
|
||||
@@ -623,7 +616,7 @@ async def _construct_starting_node_execution_input(
|
||||
|
||||
# Extract request input data, and assign it to the input pin.
|
||||
if block.block_type == BlockType.INPUT:
|
||||
input_name = cast(str | None, node.input_default.get("name"))
|
||||
input_name = node.input_default.get("name")
|
||||
if input_name and input_name in graph_inputs:
|
||||
input_data = {"value": graph_inputs[input_name]}
|
||||
|
||||
@@ -650,9 +643,9 @@ async def validate_and_construct_node_execution_input(
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], dict[str, dict[str, JsonValue]]]:
|
||||
"""
|
||||
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
|
||||
This centralizes the logic used by both scheduler validation and actual execution.
|
||||
@@ -666,9 +659,7 @@ async def validate_and_construct_node_execution_input(
|
||||
nodes_input_masks: Node inputs to use.
|
||||
|
||||
Returns:
|
||||
GraphModel: Full graph object for the given `graph_id`.
|
||||
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
|
||||
dict[str, BlockInput]: Node input masks including all passed-in credentials.
|
||||
tuple[GraphModel, list[tuple[str, BlockInput]]]: Graph model and list of tuples for node execution input.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the graph is not found.
|
||||
@@ -709,11 +700,11 @@ async def validate_and_construct_node_execution_input(
|
||||
|
||||
|
||||
def _merge_nodes_input_masks(
|
||||
overrides_map_1: NodesInputMasks,
|
||||
overrides_map_2: NodesInputMasks,
|
||||
) -> NodesInputMasks:
|
||||
overrides_map_1: dict[str, dict[str, JsonValue]],
|
||||
overrides_map_2: dict[str, dict[str, JsonValue]],
|
||||
) -> dict[str, dict[str, JsonValue]]:
|
||||
"""Perform a per-node merge of input overrides"""
|
||||
result = dict(overrides_map_1).copy()
|
||||
result = overrides_map_1.copy()
|
||||
for node_id, overrides2 in overrides_map_2.items():
|
||||
if node_id in result:
|
||||
result[node_id] = {**result[node_id], **overrides2}
|
||||
@@ -863,8 +854,8 @@ async def add_graph_execution(
|
||||
inputs: Optional[BlockInput] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Adds a graph execution to the queue and returns the execution entry.
|
||||
@@ -888,7 +879,7 @@ async def add_graph_execution(
|
||||
else:
|
||||
edb = get_database_manager_async_client()
|
||||
|
||||
graph, starting_nodes_input, compiled_nodes_input_masks = (
|
||||
graph, starting_nodes_input, nodes_input_masks = (
|
||||
await validate_and_construct_node_execution_input(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
@@ -901,43 +892,37 @@ async def add_graph_execution(
|
||||
graph_exec = None
|
||||
|
||||
try:
|
||||
# Sanity check: running add_graph_execution with the properties of
|
||||
# the graph_exec created here should create the same execution again.
|
||||
graph_exec = await edb.create_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
inputs=inputs or {},
|
||||
credential_inputs=graph_credentials_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
starting_nodes_input=starting_nodes_input,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry(
|
||||
user_context=await get_user_context(user_id),
|
||||
compiled_nodes_input_masks=compiled_nodes_input_masks,
|
||||
)
|
||||
# Fetch user context for the graph execution
|
||||
user_context = await get_user_context(user_id)
|
||||
|
||||
queue = await get_async_execution_queue()
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry(user_context)
|
||||
if nodes_input_masks:
|
||||
graph_exec_entry.nodes_input_masks = nodes_input_masks
|
||||
|
||||
logger.info(
|
||||
f"Created graph execution #{graph_exec.id} for graph "
|
||||
f"#{graph_id} with {len(starting_nodes_input)} starting nodes. "
|
||||
f"Now publishing to execution queue."
|
||||
)
|
||||
|
||||
exec_queue = await get_async_execution_queue()
|
||||
await exec_queue.publish_message(
|
||||
await queue.publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
message=graph_exec_entry.model_dump_json(),
|
||||
exchange=GRAPH_EXECUTION_EXCHANGE,
|
||||
)
|
||||
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
|
||||
|
||||
graph_exec.status = ExecutionStatus.QUEUED
|
||||
await edb.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=graph_exec.status,
|
||||
)
|
||||
await get_async_execution_event_bus().publish(graph_exec)
|
||||
bus = get_async_execution_event_bus()
|
||||
await bus.publish(graph_exec)
|
||||
|
||||
return graph_exec
|
||||
except BaseException as e:
|
||||
@@ -967,7 +952,7 @@ async def add_graph_execution(
|
||||
class ExecutionOutputEntry(BaseModel):
|
||||
node: Node
|
||||
node_exec_id: str
|
||||
data: BlockOutputEntry
|
||||
data: BlockData
|
||||
|
||||
|
||||
class NodeExecutionProgress:
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import functools
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from autogpt_libs.utils.cache import cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..providers import ProviderName
|
||||
from ._base import BaseWebhooksManager
|
||||
|
||||
|
||||
# --8<-- [start:load_webhook_managers]
|
||||
@cached()
|
||||
@functools.cache
|
||||
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
|
||||
webhook_managers = {}
|
||||
|
||||
|
||||
@@ -7,9 +7,10 @@ from backend.data.graph import set_node_webhook
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
|
||||
from . import get_webhook_manager, supports_webhooks
|
||||
from .utils import setup_webhook_for_block
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import BaseGraph, GraphModel, NodeModel
|
||||
from backend.data.graph import BaseGraph, GraphModel, Node, NodeModel
|
||||
from backend.data.model import Credentials
|
||||
|
||||
from ._base import BaseWebhooksManager
|
||||
@@ -42,19 +43,32 @@ async def _on_graph_activate(graph: "BaseGraph", user_id: str) -> "BaseGraph": .
|
||||
|
||||
async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str):
|
||||
get_credentials = credentials_manager.cached_getter(user_id)
|
||||
updated_nodes = []
|
||||
for new_node in graph.nodes:
|
||||
block_input_schema = cast(BlockSchema, new_node.block.input_schema)
|
||||
|
||||
for creds_field_name in block_input_schema.get_credentials_fields().keys():
|
||||
# Prevent saving graph with non-existent credentials
|
||||
if (
|
||||
creds_meta := new_node.input_default.get(creds_field_name)
|
||||
) and not await get_credentials(creds_meta["id"]):
|
||||
raise ValueError(
|
||||
f"Node #{new_node.id} input '{creds_field_name}' updated with "
|
||||
f"non-existent credentials #{creds_meta['id']}"
|
||||
node_credentials = None
|
||||
if (
|
||||
# Webhook-triggered blocks are only allowed to have 1 credentials input
|
||||
(
|
||||
creds_field_name := next(
|
||||
iter(block_input_schema.get_credentials_fields()), None
|
||||
)
|
||||
)
|
||||
and (creds_meta := new_node.input_default.get(creds_field_name))
|
||||
and not (node_credentials := await get_credentials(creds_meta["id"]))
|
||||
):
|
||||
raise ValueError(
|
||||
f"Node #{new_node.id} input '{creds_field_name}' updated with "
|
||||
f"non-existent credentials #{creds_meta['id']}"
|
||||
)
|
||||
|
||||
updated_node = await on_node_activate(
|
||||
user_id, graph.id, new_node, credentials=node_credentials
|
||||
)
|
||||
updated_nodes.append(updated_node)
|
||||
|
||||
graph.nodes = updated_nodes
|
||||
return graph
|
||||
|
||||
|
||||
@@ -71,14 +85,20 @@ async def on_graph_deactivate(graph: "GraphModel", user_id: str):
|
||||
block_input_schema = cast(BlockSchema, node.block.input_schema)
|
||||
|
||||
node_credentials = None
|
||||
for creds_field_name in block_input_schema.get_credentials_fields().keys():
|
||||
if (creds_meta := node.input_default.get(creds_field_name)) and not (
|
||||
node_credentials := await get_credentials(creds_meta["id"])
|
||||
):
|
||||
logger.warning(
|
||||
f"Node #{node.id} input '{creds_field_name}' referenced "
|
||||
f"non-existent credentials #{creds_meta['id']}"
|
||||
if (
|
||||
# Webhook-triggered blocks are only allowed to have 1 credentials input
|
||||
(
|
||||
creds_field_name := next(
|
||||
iter(block_input_schema.get_credentials_fields()), None
|
||||
)
|
||||
)
|
||||
and (creds_meta := node.input_default.get(creds_field_name))
|
||||
and not (node_credentials := await get_credentials(creds_meta["id"]))
|
||||
):
|
||||
logger.error(
|
||||
f"Node #{node.id} input '{creds_field_name}' referenced non-existent "
|
||||
f"credentials #{creds_meta['id']}"
|
||||
)
|
||||
|
||||
updated_node = await on_node_deactivate(
|
||||
user_id, node, credentials=node_credentials
|
||||
@@ -89,6 +109,32 @@ async def on_graph_deactivate(graph: "GraphModel", user_id: str):
|
||||
return graph
|
||||
|
||||
|
||||
async def on_node_activate(
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
node: "Node",
|
||||
*,
|
||||
credentials: Optional["Credentials"] = None,
|
||||
) -> "Node":
|
||||
"""Hook to be called when the node is activated/created"""
|
||||
|
||||
if node.block.webhook_config:
|
||||
new_webhook, feedback = await setup_webhook_for_block(
|
||||
user_id=user_id,
|
||||
trigger_block=node.block,
|
||||
trigger_config=node.input_default,
|
||||
for_graph_id=graph_id,
|
||||
)
|
||||
if new_webhook:
|
||||
node = await set_node_webhook(node.id, new_webhook.id)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Node #{node.id} does not have everything for a webhook: {feedback}"
|
||||
)
|
||||
|
||||
return node
|
||||
|
||||
|
||||
async def on_node_deactivate(
|
||||
user_id: str,
|
||||
node: "NodeModel",
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional, cast
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Config
|
||||
|
||||
from . import get_webhook_manager, supports_webhooks
|
||||
@@ -12,7 +13,6 @@ if TYPE_CHECKING:
|
||||
from backend.data.block import Block, BlockSchema
|
||||
from backend.data.integrations import Webhook
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
app_config = Config()
|
||||
@@ -20,7 +20,7 @@ credentials_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
# TODO: add test to assert this matches the actual API route
|
||||
def webhook_ingress_url(provider_name: "ProviderName", webhook_id: str) -> str:
|
||||
def webhook_ingress_url(provider_name: ProviderName, webhook_id: str) -> str:
|
||||
return (
|
||||
f"{app_config.platform_base_url}/api/integrations/{provider_name.value}"
|
||||
f"/webhooks/{webhook_id}/ingress"
|
||||
@@ -144,69 +144,3 @@ async def setup_webhook_for_block(
|
||||
)
|
||||
logger.debug(f"Acquired webhook: {webhook}")
|
||||
return webhook, None
|
||||
|
||||
|
||||
async def migrate_legacy_triggered_graphs():
|
||||
from prisma.models import AgentGraph
|
||||
|
||||
from backend.data.graph import AGENT_GRAPH_INCLUDE, GraphModel, set_node_webhook
|
||||
from backend.data.model import is_credentials_field_name
|
||||
from backend.server.v2.library.db import create_preset
|
||||
from backend.server.v2.library.model import LibraryAgentPresetCreatable
|
||||
|
||||
triggered_graphs = [
|
||||
GraphModel.from_db(_graph)
|
||||
for _graph in await AgentGraph.prisma().find_many(
|
||||
where={
|
||||
"isActive": True,
|
||||
"Nodes": {"some": {"NOT": [{"webhookId": None}]}},
|
||||
},
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
)
|
||||
]
|
||||
|
||||
n_migrated_webhooks = 0
|
||||
|
||||
for graph in triggered_graphs:
|
||||
try:
|
||||
if not (
|
||||
(trigger_node := graph.webhook_input_node) and trigger_node.webhook_id
|
||||
):
|
||||
continue
|
||||
|
||||
# Use trigger node's inputs for the preset
|
||||
preset_credentials = {
|
||||
field_name: creds_meta
|
||||
for field_name, creds_meta in trigger_node.input_default.items()
|
||||
if is_credentials_field_name(field_name)
|
||||
}
|
||||
preset_inputs = {
|
||||
field_name: value
|
||||
for field_name, value in trigger_node.input_default.items()
|
||||
if not is_credentials_field_name(field_name)
|
||||
}
|
||||
|
||||
# Create a triggered preset for the graph
|
||||
await create_preset(
|
||||
graph.user_id,
|
||||
LibraryAgentPresetCreatable(
|
||||
graph_id=graph.id,
|
||||
graph_version=graph.version,
|
||||
inputs=preset_inputs,
|
||||
credentials=preset_credentials,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
webhook_id=trigger_node.webhook_id,
|
||||
is_active=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Detach webhook from the graph node
|
||||
await set_node_webhook(trigger_node.id, None)
|
||||
|
||||
n_migrated_webhooks += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to migrate graph #{graph.id} trigger to preset: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Migrated {n_migrated_webhooks} node triggers to triggered presets")
|
||||
|
||||
@@ -1,287 +0,0 @@
|
||||
"""
|
||||
Prometheus instrumentation for FastAPI services.
|
||||
|
||||
This module provides centralized metrics collection and instrumentation
|
||||
for all FastAPI services in the AutoGPT platform.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI
|
||||
from prometheus_client import Counter, Gauge, Histogram, Info
|
||||
from prometheus_fastapi_instrumentator import Instrumentator, metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Custom business metrics with controlled cardinality
|
||||
GRAPH_EXECUTIONS = Counter(
|
||||
"autogpt_graph_executions_total",
|
||||
"Total number of graph executions",
|
||||
labelnames=[
|
||||
"status"
|
||||
], # Removed graph_id and user_id to prevent cardinality explosion
|
||||
)
|
||||
|
||||
GRAPH_EXECUTIONS_BY_USER = Counter(
|
||||
"autogpt_graph_executions_by_user_total",
|
||||
"Total number of graph executions by user (sampled)",
|
||||
labelnames=["status"], # Only status, user_id tracked separately when needed
|
||||
)
|
||||
|
||||
BLOCK_EXECUTIONS = Counter(
|
||||
"autogpt_block_executions_total",
|
||||
"Total number of block executions",
|
||||
labelnames=["block_type", "status"], # block_type is bounded
|
||||
)
|
||||
|
||||
BLOCK_DURATION = Histogram(
|
||||
"autogpt_block_duration_seconds",
|
||||
"Duration of block executions in seconds",
|
||||
labelnames=["block_type"],
|
||||
buckets=[0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60],
|
||||
)
|
||||
|
||||
WEBSOCKET_CONNECTIONS = Gauge(
|
||||
"autogpt_websocket_connections_total",
|
||||
"Total number of active WebSocket connections",
|
||||
# Removed user_id label - track total only to prevent cardinality explosion
|
||||
)
|
||||
|
||||
SCHEDULER_JOBS = Gauge(
|
||||
"autogpt_scheduler_jobs",
|
||||
"Current number of scheduled jobs",
|
||||
labelnames=["job_type", "status"],
|
||||
)
|
||||
|
||||
DATABASE_QUERIES = Histogram(
|
||||
"autogpt_database_query_duration_seconds",
|
||||
"Duration of database queries in seconds",
|
||||
labelnames=["operation", "table"],
|
||||
buckets=[0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5],
|
||||
)
|
||||
|
||||
RABBITMQ_MESSAGES = Counter(
|
||||
"autogpt_rabbitmq_messages_total",
|
||||
"Total number of RabbitMQ messages",
|
||||
labelnames=["queue", "status"],
|
||||
)
|
||||
|
||||
AUTHENTICATION_ATTEMPTS = Counter(
|
||||
"autogpt_auth_attempts_total",
|
||||
"Total number of authentication attempts",
|
||||
labelnames=["method", "status"],
|
||||
)
|
||||
|
||||
API_KEY_USAGE = Counter(
|
||||
"autogpt_api_key_usage_total",
|
||||
"API key usage by provider",
|
||||
labelnames=["provider", "block_type", "status"],
|
||||
)
|
||||
|
||||
# Function/operation level metrics with controlled cardinality
|
||||
GRAPH_OPERATIONS = Counter(
|
||||
"autogpt_graph_operations_total",
|
||||
"Graph operations by type",
|
||||
labelnames=["operation", "status"], # create, update, delete, execute, etc.
|
||||
)
|
||||
|
||||
USER_OPERATIONS = Counter(
|
||||
"autogpt_user_operations_total",
|
||||
"User operations by type",
|
||||
labelnames=["operation", "status"], # login, register, update_profile, etc.
|
||||
)
|
||||
|
||||
RATE_LIMIT_HITS = Counter(
|
||||
"autogpt_rate_limit_hits_total",
|
||||
"Number of rate limit hits",
|
||||
labelnames=["endpoint"], # Removed user_id to prevent cardinality explosion
|
||||
)
|
||||
|
||||
SERVICE_INFO = Info(
|
||||
"autogpt_service",
|
||||
"Service information",
|
||||
)
|
||||
|
||||
|
||||
def instrument_fastapi(
|
||||
app: FastAPI,
|
||||
service_name: str,
|
||||
expose_endpoint: bool = True,
|
||||
endpoint: str = "/metrics",
|
||||
include_in_schema: bool = False,
|
||||
excluded_handlers: Optional[list] = None,
|
||||
) -> Instrumentator:
|
||||
"""
|
||||
Instrument a FastAPI application with Prometheus metrics.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance
|
||||
service_name: Name of the service for metrics labeling
|
||||
expose_endpoint: Whether to expose /metrics endpoint
|
||||
endpoint: Path for metrics endpoint
|
||||
include_in_schema: Whether to include metrics endpoint in OpenAPI schema
|
||||
excluded_handlers: List of paths to exclude from metrics
|
||||
|
||||
Returns:
|
||||
Configured Instrumentator instance
|
||||
"""
|
||||
|
||||
# Set service info
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
service_version = version("autogpt-platform-backend")
|
||||
except Exception:
|
||||
service_version = "unknown"
|
||||
|
||||
SERVICE_INFO.info(
|
||||
{
|
||||
"service": service_name,
|
||||
"version": service_version,
|
||||
}
|
||||
)
|
||||
|
||||
# Create instrumentator with default metrics
|
||||
instrumentator = Instrumentator(
|
||||
should_group_status_codes=True,
|
||||
should_ignore_untemplated=True,
|
||||
should_respect_env_var=True,
|
||||
should_instrument_requests_inprogress=True,
|
||||
excluded_handlers=excluded_handlers or ["/health", "/readiness"],
|
||||
env_var_name="ENABLE_METRICS",
|
||||
inprogress_name="autogpt_http_requests_inprogress",
|
||||
inprogress_labels=True,
|
||||
)
|
||||
|
||||
# Add default HTTP metrics
|
||||
instrumentator.add(
|
||||
metrics.default(
|
||||
metric_namespace="autogpt",
|
||||
metric_subsystem=service_name.replace("-", "_"),
|
||||
)
|
||||
)
|
||||
|
||||
# Add request size metrics
|
||||
instrumentator.add(
|
||||
metrics.request_size(
|
||||
metric_namespace="autogpt",
|
||||
metric_subsystem=service_name.replace("-", "_"),
|
||||
)
|
||||
)
|
||||
|
||||
# Add response size metrics
|
||||
instrumentator.add(
|
||||
metrics.response_size(
|
||||
metric_namespace="autogpt",
|
||||
metric_subsystem=service_name.replace("-", "_"),
|
||||
)
|
||||
)
|
||||
|
||||
# Add latency metrics with custom buckets for better granularity
|
||||
instrumentator.add(
|
||||
metrics.latency(
|
||||
metric_namespace="autogpt",
|
||||
metric_subsystem=service_name.replace("-", "_"),
|
||||
buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60],
|
||||
)
|
||||
)
|
||||
|
||||
# Add combined metrics (requests by method and status)
|
||||
instrumentator.add(
|
||||
metrics.combined_size(
|
||||
metric_namespace="autogpt",
|
||||
metric_subsystem=service_name.replace("-", "_"),
|
||||
)
|
||||
)
|
||||
|
||||
# Instrument the app
|
||||
instrumentator.instrument(app)
|
||||
|
||||
# Expose metrics endpoint if requested
|
||||
if expose_endpoint:
|
||||
instrumentator.expose(
|
||||
app,
|
||||
endpoint=endpoint,
|
||||
include_in_schema=include_in_schema,
|
||||
tags=["monitoring"] if include_in_schema else None,
|
||||
)
|
||||
logger.info(f"Metrics endpoint exposed at {endpoint} for {service_name}")
|
||||
|
||||
return instrumentator
|
||||
|
||||
|
||||
def record_graph_execution(graph_id: str, status: str, user_id: str):
|
||||
"""Record a graph execution event.
|
||||
|
||||
Args:
|
||||
graph_id: Graph identifier (kept for future sampling/debugging)
|
||||
status: Execution status (success/error/validation_error)
|
||||
user_id: User identifier (kept for future sampling/debugging)
|
||||
"""
|
||||
# Track overall executions without high-cardinality labels
|
||||
GRAPH_EXECUTIONS.labels(status=status).inc()
|
||||
|
||||
# Optionally track per-user executions (implement sampling if needed)
|
||||
# For now, just track status to avoid cardinality explosion
|
||||
GRAPH_EXECUTIONS_BY_USER.labels(status=status).inc()
|
||||
|
||||
|
||||
def record_block_execution(block_type: str, status: str, duration: float):
|
||||
"""Record a block execution event with duration."""
|
||||
BLOCK_EXECUTIONS.labels(block_type=block_type, status=status).inc()
|
||||
BLOCK_DURATION.labels(block_type=block_type).observe(duration)
|
||||
|
||||
|
||||
def update_websocket_connections(user_id: str, delta: int):
|
||||
"""Update the number of active WebSocket connections.
|
||||
|
||||
Args:
|
||||
user_id: User identifier (kept for future sampling/debugging)
|
||||
delta: Change in connection count (+1 for connect, -1 for disconnect)
|
||||
"""
|
||||
# Track total connections without user_id to prevent cardinality explosion
|
||||
if delta > 0:
|
||||
WEBSOCKET_CONNECTIONS.inc(delta)
|
||||
else:
|
||||
WEBSOCKET_CONNECTIONS.dec(abs(delta))
|
||||
|
||||
|
||||
def record_database_query(operation: str, table: str, duration: float):
|
||||
"""Record a database query with duration."""
|
||||
DATABASE_QUERIES.labels(operation=operation, table=table).observe(duration)
|
||||
|
||||
|
||||
def record_rabbitmq_message(queue: str, status: str):
|
||||
"""Record a RabbitMQ message event."""
|
||||
RABBITMQ_MESSAGES.labels(queue=queue, status=status).inc()
|
||||
|
||||
|
||||
def record_authentication_attempt(method: str, status: str):
|
||||
"""Record an authentication attempt."""
|
||||
AUTHENTICATION_ATTEMPTS.labels(method=method, status=status).inc()
|
||||
|
||||
|
||||
def record_api_key_usage(provider: str, block_type: str, status: str):
|
||||
"""Record API key usage by provider and block."""
|
||||
API_KEY_USAGE.labels(provider=provider, block_type=block_type, status=status).inc()
|
||||
|
||||
|
||||
def record_rate_limit_hit(endpoint: str, user_id: str):
|
||||
"""Record a rate limit hit.
|
||||
|
||||
Args:
|
||||
endpoint: API endpoint that was rate limited
|
||||
user_id: User identifier (kept for future sampling/debugging)
|
||||
"""
|
||||
RATE_LIMIT_HITS.labels(endpoint=endpoint).inc()
|
||||
|
||||
|
||||
def record_graph_operation(operation: str, status: str):
|
||||
"""Record a graph operation (create, update, delete, execute, etc.)."""
|
||||
GRAPH_OPERATIONS.labels(operation=operation, status=status).inc()
|
||||
|
||||
|
||||
def record_user_operation(operation: str, status: str):
|
||||
"""Record a user operation (login, register, etc.)."""
|
||||
USER_OPERATIONS.labels(operation=operation, status=status).inc()
|
||||
@@ -63,7 +63,7 @@ except ImportError:
|
||||
|
||||
# Cost System
|
||||
try:
|
||||
from backend.data.block import BlockCost, BlockCostType
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
except ImportError:
|
||||
from backend.data.block_cost_config import BlockCost, BlockCostType
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Callable, List, Optional, Type
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import BlockCost, BlockCostType
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
|
||||
@@ -8,8 +8,9 @@ BLOCK_COSTS configuration used by the execution system.
|
||||
import logging
|
||||
from typing import List, Type
|
||||
|
||||
from backend.data.block import Block, BlockCost
|
||||
from backend.data.block import Block
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any, Callable, List, Optional, Set, Type
|
||||
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.data.block import BlockCost
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
|
||||
@@ -6,10 +6,10 @@ import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.blocks.basic import Block
|
||||
from backend.data.model import Credentials
|
||||
from backend.data.model import APIKeyCredentials, Credentials
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
@@ -17,8 +17,6 @@ from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
if TYPE_CHECKING:
|
||||
from backend.sdk.provider import Provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SDKOAuthCredentials(BaseModel):
|
||||
"""OAuth credentials configuration for SDK providers."""
|
||||
@@ -104,8 +102,21 @@ class AutoRegistry:
|
||||
"""Register an environment variable as an API key for a provider."""
|
||||
with cls._lock:
|
||||
cls._api_key_mappings[provider] = env_var_name
|
||||
# Note: The credential itself is created by ProviderBuilder.with_api_key()
|
||||
# We only store the mapping here to avoid duplication
|
||||
|
||||
# Dynamically check if the env var exists and create credential
|
||||
import os
|
||||
|
||||
api_key = os.getenv(env_var_name)
|
||||
if api_key:
|
||||
credential = APIKeyCredentials(
|
||||
id=f"{provider}-default",
|
||||
provider=provider,
|
||||
api_key=SecretStr(api_key),
|
||||
title=f"Default {provider} credentials",
|
||||
)
|
||||
# Check if credential already exists to avoid duplicates
|
||||
if not any(c.id == credential.id for c in cls._default_credentials):
|
||||
cls._default_credentials.append(credential)
|
||||
|
||||
@classmethod
|
||||
def get_all_credentials(cls) -> List[Credentials]:
|
||||
@@ -199,43 +210,3 @@ class AutoRegistry:
|
||||
webhooks.load_webhook_managers = patched_load
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to patch webhook managers: {e}")
|
||||
|
||||
# Patch credentials store to include SDK-registered credentials
|
||||
try:
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
# Get the module from sys.modules to respect mocking
|
||||
if "backend.integrations.credentials_store" in sys.modules:
|
||||
creds_store: Any = sys.modules["backend.integrations.credentials_store"]
|
||||
else:
|
||||
import backend.integrations.credentials_store
|
||||
|
||||
creds_store: Any = backend.integrations.credentials_store
|
||||
|
||||
if hasattr(creds_store, "IntegrationCredentialsStore"):
|
||||
store_class = creds_store.IntegrationCredentialsStore
|
||||
if hasattr(store_class, "get_all_creds"):
|
||||
original_get_all_creds = store_class.get_all_creds
|
||||
|
||||
async def patched_get_all_creds(self, user_id: str):
|
||||
# Get original credentials
|
||||
original_creds = await original_get_all_creds(self, user_id)
|
||||
|
||||
# Add SDK-registered credentials
|
||||
sdk_creds = cls.get_all_credentials()
|
||||
|
||||
# Combine credentials, avoiding duplicates by ID
|
||||
existing_ids = {c.id for c in original_creds}
|
||||
for cred in sdk_creds:
|
||||
if cred.id not in existing_ids:
|
||||
original_creds.append(cred)
|
||||
|
||||
return original_creds
|
||||
|
||||
store_class.get_all_creds = patched_get_all_creds
|
||||
logger.info(
|
||||
"Successfully patched IntegrationCredentialsStore.get_all_creds"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to patch credentials store: {e}")
|
||||
|
||||
@@ -88,7 +88,6 @@ async def test_send_graph_execution_result(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
preset_id=None,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
started_at=datetime.now(tz=timezone.utc),
|
||||
ended_at=datetime.now(tz=timezone.utc),
|
||||
@@ -102,8 +101,6 @@ async def test_send_graph_execution_result(
|
||||
"input_1": "some input value :)",
|
||||
"input_2": "some *other* input value",
|
||||
},
|
||||
credential_inputs=None,
|
||||
nodes_input_masks=None,
|
||||
outputs={
|
||||
"the_output": ["some output value"],
|
||||
"other_output": ["sike there was another output"],
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
|
||||
from .routes.v1 import v1_router
|
||||
@@ -14,12 +13,3 @@ external_app = FastAPI(
|
||||
|
||||
external_app.add_middleware(SecurityHeadersMiddleware)
|
||||
external_app.include_router(v1_router, prefix="/v1")
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
external_app,
|
||||
service_name="external-api",
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=True,
|
||||
)
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from fastapi.security import APIKeyHeader
|
||||
from prisma.enums import APIKeyPermission
|
||||
|
||||
from backend.data.api_key import APIKeyInfo, has_permission, validate_api_key
|
||||
from backend.data.api_key import has_permission, validate_api_key
|
||||
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
api_key_header = APIKeyHeader(name="X-API-Key")
|
||||
|
||||
|
||||
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
|
||||
async def require_api_key(request: Request):
|
||||
"""Base middleware for API key authentication"""
|
||||
api_key = await api_key_header(request)
|
||||
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=401, detail="Missing API key")
|
||||
|
||||
@@ -17,19 +19,18 @@ async def require_api_key(api_key: str | None = Security(api_key_header)) -> API
|
||||
if not api_key_obj:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
request.state.api_key = api_key_obj
|
||||
return api_key_obj
|
||||
|
||||
|
||||
def require_permission(permission: APIKeyPermission):
|
||||
"""Dependency function for checking specific permissions"""
|
||||
|
||||
async def check_permission(
|
||||
api_key: APIKeyInfo = Security(require_api_key),
|
||||
) -> APIKeyInfo:
|
||||
async def check_permission(api_key=Depends(require_api_key)):
|
||||
if not has_permission(api_key, permission):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"API key lacks the required permission '{permission}'",
|
||||
detail=f"API key missing required permission: {permission}",
|
||||
)
|
||||
return api_key
|
||||
|
||||
|
||||
@@ -2,14 +2,14 @@ import logging
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, Any, Optional, Sequence
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import backend.data.block
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.data.api_key import APIKey
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.server.external.middleware import require_permission
|
||||
@@ -47,9 +47,9 @@ class GraphExecutionResult(TypedDict):
|
||||
@v1_router.get(
|
||||
path="/blocks",
|
||||
tags=["blocks"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
|
||||
dependencies=[Depends(require_permission(APIKeyPermission.READ_BLOCK))],
|
||||
)
|
||||
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
blocks = [block() for block in backend.data.block.get_blocks().values()]
|
||||
return [b.to_dict() for b in blocks if not b.disabled]
|
||||
|
||||
@@ -57,12 +57,12 @@ async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
@v1_router.post(
|
||||
path="/blocks/{block_id}/execute",
|
||||
tags=["blocks"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.EXECUTE_BLOCK))],
|
||||
dependencies=[Depends(require_permission(APIKeyPermission.EXECUTE_BLOCK))],
|
||||
)
|
||||
async def execute_graph_block(
|
||||
block_id: str,
|
||||
data: BlockInput,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_BLOCK)),
|
||||
api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_BLOCK)),
|
||||
) -> CompletedBlockOutput:
|
||||
obj = backend.data.block.get_block(block_id)
|
||||
if not obj:
|
||||
@@ -82,7 +82,7 @@ async def execute_graph(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
|
||||
api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
graph_exec = await add_graph_execution(
|
||||
@@ -104,7 +104,7 @@ async def execute_graph(
|
||||
async def get_graph_execution_results(
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.READ_GRAPH)),
|
||||
api_key: APIKey = Depends(require_permission(APIKeyPermission.READ_GRAPH)),
|
||||
) -> GraphExecutionResult:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=api_key.user_id)
|
||||
if not graph:
|
||||
|
||||
@@ -81,10 +81,6 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
|
||||
# Add noindex header for shared execution pages
|
||||
if "/public/shared" in request.url.path:
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow"
|
||||
|
||||
# Default: Disable caching for all endpoints
|
||||
# Only allow caching for explicitly permitted paths
|
||||
if not self.is_cacheable_path(request.url.path):
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Optional
|
||||
|
||||
import pydantic
|
||||
|
||||
from backend.data.api_key import APIKeyInfo, APIKeyPermission
|
||||
from backend.data.api_key import APIKeyPermission, APIKeyWithoutHash
|
||||
from backend.data.graph import Graph
|
||||
from backend.util.timezone_name import TimeZoneName
|
||||
|
||||
@@ -34,6 +34,10 @@ class WSSubscribeGraphExecutionsRequest(pydantic.BaseModel):
|
||||
graph_id: str
|
||||
|
||||
|
||||
class ExecuteGraphResponse(pydantic.BaseModel):
|
||||
graph_exec_id: str
|
||||
|
||||
|
||||
class CreateGraph(pydantic.BaseModel):
|
||||
graph: Graph
|
||||
|
||||
@@ -45,7 +49,7 @@ class CreateAPIKeyRequest(pydantic.BaseModel):
|
||||
|
||||
|
||||
class CreateAPIKeyResponse(pydantic.BaseModel):
|
||||
api_key: APIKeyInfo
|
||||
api_key: APIKeyWithoutHash
|
||||
plain_text_key: str
|
||||
|
||||
|
||||
|
||||
@@ -12,13 +12,11 @@ from autogpt_libs.auth import add_auth_responses_to_openapi
|
||||
from autogpt_libs.auth import verify_settings as verify_auth_settings
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.routing import APIRoute
|
||||
from prisma.errors import PrismaError
|
||||
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.server.routers.postmark.postmark
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.admin.credit_admin_routes
|
||||
@@ -37,12 +35,10 @@ import backend.util.settings
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.server.external.api import external_app
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.util import json
|
||||
from backend.util.cloud_storage import shutdown_cloud_storage_handler
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
||||
from backend.util.service import UnhealthyServiceError
|
||||
|
||||
@@ -80,8 +76,6 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
|
||||
@@ -143,16 +137,6 @@ app.add_middleware(SecurityHeadersMiddleware)
|
||||
# Add 401 responses to authenticated endpoints in OpenAPI spec
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
app,
|
||||
service_name="rest-api",
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=settings.config.app_env
|
||||
== backend.util.settings.AppEnvironment.LOCAL,
|
||||
)
|
||||
|
||||
|
||||
def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
|
||||
def handler(request: fastapi.Request, exc: Exception):
|
||||
@@ -211,14 +195,10 @@ async def validation_error_handler(
|
||||
)
|
||||
|
||||
|
||||
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
|
||||
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
|
||||
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
|
||||
app.add_exception_handler(RequestValidationError, validation_error_handler)
|
||||
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
|
||||
app.add_exception_handler(ValueError, handle_internal_http_error(400))
|
||||
app.add_exception_handler(Exception, handle_internal_http_error(500))
|
||||
|
||||
app.include_router(backend.server.routers.v1.v1_router, tags=["v1"], prefix="/api")
|
||||
app.include_router(
|
||||
backend.server.v2.store.routes.router, tags=["v2"], prefix="/api/store"
|
||||
@@ -383,7 +363,6 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
preset_id=preset_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs or {},
|
||||
credential_inputs={},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Sequence
|
||||
|
||||
import pydantic
|
||||
import stripe
|
||||
from autogpt_libs.auth import get_user_id, requires_user
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
@@ -30,19 +27,30 @@ from typing_extensions import Optional, TypedDict
|
||||
import backend.server.integrations.router
|
||||
import backend.server.routers.analytics
|
||||
import backend.server.v2.library.db as library_db
|
||||
from backend.data import api_key as api_key_db
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.api_key import (
|
||||
APIKeyError,
|
||||
APIKeyNotFoundError,
|
||||
APIKeyPermissionError,
|
||||
APIKeyWithoutHash,
|
||||
generate_api_key,
|
||||
get_api_key_by_id,
|
||||
list_user_api_keys,
|
||||
revoke_api_key,
|
||||
suspend_api_key,
|
||||
update_api_key_permissions,
|
||||
)
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
|
||||
from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
RefundRequest,
|
||||
TransactionHistory,
|
||||
get_auto_top_up,
|
||||
get_block_costs,
|
||||
get_user_credit_model,
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.execution import UserContext
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
@@ -66,15 +74,11 @@ from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||
on_graph_activate,
|
||||
on_graph_deactivate,
|
||||
)
|
||||
from backend.monitoring.instrumentation import (
|
||||
record_block_execution,
|
||||
record_graph_execution,
|
||||
record_graph_operation,
|
||||
)
|
||||
from backend.server.model import (
|
||||
CreateAPIKeyRequest,
|
||||
CreateAPIKeyResponse,
|
||||
CreateGraph,
|
||||
ExecuteGraphResponse,
|
||||
RequestTopUp,
|
||||
SetGraphActiveVersion,
|
||||
TimezoneResponse,
|
||||
@@ -87,6 +91,7 @@ from backend.util.cloud_storage import get_cloud_storage_handler
|
||||
from backend.util.exceptions import GraphValidationError, NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.timezone_utils import (
|
||||
convert_cron_to_utc,
|
||||
convert_utc_time_to_user_timezone,
|
||||
get_user_timezone_or_utc,
|
||||
)
|
||||
@@ -104,7 +109,6 @@ def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_user_credit_model = get_user_credit_model()
|
||||
|
||||
# Define the API routes
|
||||
@@ -173,6 +177,7 @@ async def get_user_timezone_route(
|
||||
summary="Update user timezone",
|
||||
tags=["auth"],
|
||||
dependencies=[Security(requires_user)],
|
||||
response_model=TimezoneResponse,
|
||||
)
|
||||
async def update_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
|
||||
@@ -263,37 +268,18 @@ async def is_onboarding_enabled():
|
||||
########################################################
|
||||
|
||||
|
||||
@cached()
|
||||
def _get_cached_blocks() -> Sequence[dict[Any, Any]]:
|
||||
"""
|
||||
Get cached blocks with thundering herd protection.
|
||||
|
||||
Uses sync_cache decorator to prevent multiple concurrent requests
|
||||
from all executing the expensive block loading operation.
|
||||
"""
|
||||
from backend.data.credit import get_block_cost
|
||||
|
||||
block_classes = get_blocks()
|
||||
result = []
|
||||
|
||||
for block_class in block_classes.values():
|
||||
block_instance = block_class()
|
||||
if not block_instance.disabled:
|
||||
# Get costs for this specific block class without creating another instance
|
||||
costs = get_block_cost(block_instance)
|
||||
result.append({**block_instance.to_dict(), "costs": costs})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/blocks",
|
||||
summary="List available blocks",
|
||||
tags=["blocks"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
return _get_cached_blocks()
|
||||
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
blocks = [block() for block in get_blocks().values()]
|
||||
costs = get_block_costs()
|
||||
return [
|
||||
{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks if not b.disabled
|
||||
]
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -302,45 +288,15 @@ async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
tags=["blocks"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def execute_graph_block(
|
||||
block_id: str, data: BlockInput, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> CompletedBlockOutput:
|
||||
async def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput:
|
||||
obj = get_block(block_id)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||
|
||||
# Get user context for block execution
|
||||
user = await get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found.")
|
||||
|
||||
user_context = UserContext(timezone=user.timezone)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
output = defaultdict(list)
|
||||
async for name, data in obj.execute(
|
||||
data,
|
||||
user_context=user_context,
|
||||
user_id=user_id,
|
||||
# Note: graph_exec_id and graph_id are not available for direct block execution
|
||||
):
|
||||
output[name].append(data)
|
||||
|
||||
# Record successful block execution with duration
|
||||
duration = time.time() - start_time
|
||||
block_type = obj.__class__.__name__
|
||||
record_block_execution(
|
||||
block_type=block_type, status="success", duration=duration
|
||||
)
|
||||
|
||||
return output
|
||||
except Exception:
|
||||
# Record failed block execution
|
||||
duration = time.time() - start_time
|
||||
block_type = obj.__class__.__name__
|
||||
record_block_execution(block_type=block_type, status="error", duration=duration)
|
||||
raise
|
||||
output = defaultdict(list)
|
||||
async for name, data in obj.execute(data):
|
||||
output[name].append(data)
|
||||
return output
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -633,13 +589,7 @@ class DeleteGraphResponse(TypedDict):
|
||||
async def list_graphs(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> Sequence[graph_db.GraphMeta]:
|
||||
paginated_result = await graph_db.list_graphs_paginated(
|
||||
user_id=user_id,
|
||||
page=1,
|
||||
page_size=250,
|
||||
filter_by="active",
|
||||
)
|
||||
return paginated_result.graphs
|
||||
return await graph_db.list_graphs(filter_by="active", user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -833,7 +783,7 @@ async def execute_graph(
|
||||
],
|
||||
graph_version: Optional[int] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
) -> execution_db.GraphExecutionMeta:
|
||||
) -> ExecuteGraphResponse:
|
||||
current_balance = await _user_credit_model.get_credits(user_id)
|
||||
if current_balance <= 0:
|
||||
raise HTTPException(
|
||||
@@ -842,7 +792,7 @@ async def execute_graph(
|
||||
)
|
||||
|
||||
try:
|
||||
result = await execution_utils.add_graph_execution(
|
||||
graph_exec = await execution_utils.add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
@@ -850,16 +800,8 @@ async def execute_graph(
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=credentials_inputs,
|
||||
)
|
||||
# Record successful graph execution
|
||||
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
|
||||
record_graph_operation(operation="execute", status="success")
|
||||
return result
|
||||
return ExecuteGraphResponse(graph_exec_id=graph_exec.id)
|
||||
except GraphValidationError as e:
|
||||
# Record failed graph execution
|
||||
record_graph_execution(
|
||||
graph_id=graph_id, status="validation_error", user_id=user_id
|
||||
)
|
||||
record_graph_operation(operation="execute", status="validation_error")
|
||||
# Return structured validation errors that the frontend can parse
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -870,11 +812,6 @@ async def execute_graph(
|
||||
"node_errors": e.node_errors,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
# Record any other failures
|
||||
record_graph_execution(graph_id=graph_id, status="error", user_id=user_id)
|
||||
record_graph_operation(operation="execute", status="error")
|
||||
raise
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -928,12 +865,7 @@ async def _stop_graph_run(
|
||||
async def list_graphs_executions(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[execution_db.GraphExecutionMeta]:
|
||||
paginated_result = await execution_db.get_graph_executions_paginated(
|
||||
user_id=user_id,
|
||||
page=1,
|
||||
page_size=250,
|
||||
)
|
||||
return paginated_result.executions
|
||||
return await execution_db.get_graph_executions(user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -1004,99 +936,6 @@ async def delete_graph_execution(
|
||||
)
|
||||
|
||||
|
||||
class ShareRequest(pydantic.BaseModel):
|
||||
"""Optional request body for share endpoint."""
|
||||
|
||||
pass # Empty body is fine
|
||||
|
||||
|
||||
class ShareResponse(pydantic.BaseModel):
|
||||
"""Response from share endpoints."""
|
||||
|
||||
share_url: str
|
||||
share_token: str
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/graphs/{graph_id}/executions/{graph_exec_id}/share",
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def enable_execution_sharing(
|
||||
graph_id: Annotated[str, Path],
|
||||
graph_exec_id: Annotated[str, Path],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
_body: ShareRequest = Body(default=ShareRequest()),
|
||||
) -> ShareResponse:
|
||||
"""Enable sharing for a graph execution."""
|
||||
# Verify the execution belongs to the user
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not execution:
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
# Generate a unique share token
|
||||
share_token = str(uuid.uuid4())
|
||||
|
||||
# Update the execution with share info
|
||||
await execution_db.update_graph_execution_share_status(
|
||||
execution_id=graph_exec_id,
|
||||
user_id=user_id,
|
||||
is_shared=True,
|
||||
share_token=share_token,
|
||||
shared_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Return the share URL
|
||||
frontend_url = Settings().config.frontend_base_url or "http://localhost:3000"
|
||||
share_url = f"{frontend_url}/share/{share_token}"
|
||||
|
||||
return ShareResponse(share_url=share_url, share_token=share_token)
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
"/graphs/{graph_id}/executions/{graph_exec_id}/share",
|
||||
status_code=HTTP_204_NO_CONTENT,
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def disable_execution_sharing(
|
||||
graph_id: Annotated[str, Path],
|
||||
graph_exec_id: Annotated[str, Path],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> None:
|
||||
"""Disable sharing for a graph execution."""
|
||||
# Verify the execution belongs to the user
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not execution:
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
# Remove share info
|
||||
await execution_db.update_graph_execution_share_status(
|
||||
execution_id=graph_exec_id,
|
||||
user_id=user_id,
|
||||
is_shared=False,
|
||||
share_token=None,
|
||||
shared_at=None,
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get("/public/shared/{share_token}")
|
||||
async def get_shared_execution(
|
||||
share_token: Annotated[
|
||||
str,
|
||||
Path(regex=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
|
||||
],
|
||||
) -> execution_db.SharedExecutionResponse:
|
||||
"""Get a shared graph execution by share token (no auth required)."""
|
||||
execution = await execution_db.get_graph_execution_by_share_token(share_token)
|
||||
if not execution:
|
||||
raise HTTPException(status_code=404, detail="Shared execution not found")
|
||||
|
||||
return execution
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Schedules ########################
|
||||
########################################################
|
||||
@@ -1108,10 +947,6 @@ class ScheduleCreationRequest(pydantic.BaseModel):
|
||||
cron: str
|
||||
inputs: dict[str, Any]
|
||||
credentials: dict[str, CredentialsMetaInput] = pydantic.Field(default_factory=dict)
|
||||
timezone: Optional[str] = pydantic.Field(
|
||||
default=None,
|
||||
description="User's timezone for scheduling (e.g., 'America/New_York'). If not provided, will use user's saved timezone or UTC.",
|
||||
)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -1136,22 +971,26 @@ async def create_graph_execution_schedule(
|
||||
detail=f"Graph #{graph_id} v{schedule_params.graph_version} not found.",
|
||||
)
|
||||
|
||||
# Use timezone from request if provided, otherwise fetch from user profile
|
||||
if schedule_params.timezone:
|
||||
user_timezone = schedule_params.timezone
|
||||
else:
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
|
||||
|
||||
# Convert cron expression from user timezone to UTC
|
||||
try:
|
||||
utc_cron = convert_cron_to_utc(schedule_params.cron, user_timezone)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid cron expression for timezone {user_timezone}: {e}",
|
||||
)
|
||||
|
||||
result = await get_scheduler_client().add_execution_schedule(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
name=schedule_params.name,
|
||||
cron=schedule_params.cron,
|
||||
cron=utc_cron, # Send UTC cron to scheduler
|
||||
input_data=schedule_params.inputs,
|
||||
input_credentials=schedule_params.credentials,
|
||||
user_timezone=user_timezone,
|
||||
)
|
||||
|
||||
# Convert the next_run_time back to user timezone for display
|
||||
@@ -1173,11 +1012,24 @@ async def list_graph_execution_schedules(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
graph_id: str = Path(),
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
return await get_scheduler_client().get_execution_schedules(
|
||||
schedules = await get_scheduler_client().get_execution_schedules(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
)
|
||||
|
||||
# Get user timezone for conversion
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
|
||||
|
||||
# Convert next_run_time to user timezone for display
|
||||
for schedule in schedules:
|
||||
if schedule.next_run_time:
|
||||
schedule.next_run_time = convert_utc_time_to_user_timezone(
|
||||
schedule.next_run_time, user_timezone
|
||||
)
|
||||
|
||||
return schedules
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/schedules",
|
||||
@@ -1188,7 +1040,20 @@ async def list_graph_execution_schedules(
|
||||
async def list_all_graphs_execution_schedules(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
return await get_scheduler_client().get_execution_schedules(user_id=user_id)
|
||||
schedules = await get_scheduler_client().get_execution_schedules(user_id=user_id)
|
||||
|
||||
# Get user timezone for conversion
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
|
||||
|
||||
# Convert UTC next_run_time to user timezone for display
|
||||
for schedule in schedules:
|
||||
if schedule.next_run_time:
|
||||
schedule.next_run_time = convert_utc_time_to_user_timezone(
|
||||
schedule.next_run_time, user_timezone
|
||||
)
|
||||
|
||||
return schedules
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
@@ -1219,6 +1084,7 @@ async def delete_graph_execution_schedule(
|
||||
@v1_router.post(
|
||||
"/api-keys",
|
||||
summary="Create new API key",
|
||||
response_model=CreateAPIKeyResponse,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
@@ -1226,73 +1092,128 @@ async def create_api_key(
|
||||
request: CreateAPIKeyRequest, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> CreateAPIKeyResponse:
|
||||
"""Create a new API key"""
|
||||
api_key_info, plain_text_key = await api_key_db.create_api_key(
|
||||
name=request.name,
|
||||
user_id=user_id,
|
||||
permissions=request.permissions,
|
||||
description=request.description,
|
||||
)
|
||||
return CreateAPIKeyResponse(api_key=api_key_info, plain_text_key=plain_text_key)
|
||||
try:
|
||||
api_key, plain_text = await generate_api_key(
|
||||
name=request.name,
|
||||
user_id=user_id,
|
||||
permissions=request.permissions,
|
||||
description=request.description,
|
||||
)
|
||||
return CreateAPIKeyResponse(api_key=api_key, plain_text_key=plain_text)
|
||||
except APIKeyError as e:
|
||||
logger.error(
|
||||
"Could not create API key for user %s: %s. Review input and permissions.",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": str(e), "hint": "Verify request payload and try again."},
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/api-keys",
|
||||
summary="List user API keys",
|
||||
response_model=list[APIKeyWithoutHash] | dict[str, str],
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_api_keys(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[api_key_db.APIKeyInfo]:
|
||||
) -> list[APIKeyWithoutHash]:
|
||||
"""List all API keys for the user"""
|
||||
return await api_key_db.list_user_api_keys(user_id)
|
||||
try:
|
||||
return await list_user_api_keys(user_id)
|
||||
except APIKeyError as e:
|
||||
logger.error("Failed to list API keys for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": str(e), "hint": "Check API key service availability."},
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/api-keys/{key_id}",
|
||||
summary="Get specific API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_api_key(
|
||||
key_id: str, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> api_key_db.APIKeyInfo:
|
||||
) -> APIKeyWithoutHash:
|
||||
"""Get a specific API key"""
|
||||
api_key = await api_key_db.get_api_key_by_id(key_id, user_id)
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
return api_key
|
||||
try:
|
||||
api_key = await get_api_key_by_id(key_id, user_id)
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
return api_key
|
||||
except APIKeyError as e:
|
||||
logger.error("Error retrieving API key %s for user %s: %s", key_id, user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": str(e), "hint": "Ensure the key ID is correct."},
|
||||
)
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
"/api-keys/{key_id}",
|
||||
summary="Revoke API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def delete_api_key(
|
||||
key_id: str, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> api_key_db.APIKeyInfo:
|
||||
) -> Optional[APIKeyWithoutHash]:
|
||||
"""Revoke an API key"""
|
||||
return await api_key_db.revoke_api_key(key_id, user_id)
|
||||
try:
|
||||
return await revoke_api_key(key_id, user_id)
|
||||
except APIKeyNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
except APIKeyPermissionError:
|
||||
raise HTTPException(status_code=403, detail="Permission denied")
|
||||
except APIKeyError as e:
|
||||
logger.error("Failed to revoke API key %s for user %s: %s", key_id, user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"message": str(e),
|
||||
"hint": "Verify permissions or try again later.",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/api-keys/{key_id}/suspend",
|
||||
summary="Suspend API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def suspend_key(
|
||||
key_id: str, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> api_key_db.APIKeyInfo:
|
||||
) -> Optional[APIKeyWithoutHash]:
|
||||
"""Suspend an API key"""
|
||||
return await api_key_db.suspend_api_key(key_id, user_id)
|
||||
try:
|
||||
return await suspend_api_key(key_id, user_id)
|
||||
except APIKeyNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
except APIKeyPermissionError:
|
||||
raise HTTPException(status_code=403, detail="Permission denied")
|
||||
except APIKeyError as e:
|
||||
logger.error("Failed to suspend API key %s for user %s: %s", key_id, user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": str(e), "hint": "Check user permissions and retry."},
|
||||
)
|
||||
|
||||
|
||||
@v1_router.put(
|
||||
"/api-keys/{key_id}/permissions",
|
||||
summary="Update key permissions",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
@@ -1300,8 +1221,22 @@ async def update_permissions(
|
||||
key_id: str,
|
||||
request: UpdatePermissionsRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> api_key_db.APIKeyInfo:
|
||||
) -> Optional[APIKeyWithoutHash]:
|
||||
"""Update API key permissions"""
|
||||
return await api_key_db.update_api_key_permissions(
|
||||
key_id, user_id, request.permissions
|
||||
)
|
||||
try:
|
||||
return await update_api_key_permissions(key_id, user_id, request.permissions)
|
||||
except APIKeyNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
except APIKeyPermissionError:
|
||||
raise HTTPException(status_code=403, detail="Permission denied")
|
||||
except APIKeyError as e:
|
||||
logger.error(
|
||||
"Failed to update permissions for API key %s of user %s: %s",
|
||||
key_id,
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": str(e), "hint": "Ensure permissions list is valid."},
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
@@ -110,8 +109,8 @@ def test_get_graph_blocks(
|
||||
|
||||
# Mock block costs
|
||||
mocker.patch(
|
||||
"backend.data.credit.get_block_cost",
|
||||
return_value=[{"cost": 10, "type": "credit"}],
|
||||
"backend.server.routers.v1.get_block_costs",
|
||||
return_value={"test-block": [{"cost": 10, "type": "credit"}]},
|
||||
)
|
||||
|
||||
response = client.get("/blocks")
|
||||
@@ -147,15 +146,6 @@ def test_execute_graph_block(
|
||||
return_value=mock_block,
|
||||
)
|
||||
|
||||
# Mock user for user_context
|
||||
mock_user = Mock()
|
||||
mock_user.timezone = "UTC"
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_user_by_id",
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"input_name": "test_input",
|
||||
"input_value": "test_value",
|
||||
@@ -275,12 +265,11 @@ def test_get_graphs(
|
||||
name="Test Graph",
|
||||
description="A test graph",
|
||||
user_id=test_user_id,
|
||||
created_at=datetime(2025, 9, 4, 13, 37),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.graph.list_graphs_paginated",
|
||||
return_value=Mock(graphs=[mock_graph]),
|
||||
"backend.server.routers.v1.graph_db.list_graphs",
|
||||
return_value=[mock_graph],
|
||||
)
|
||||
|
||||
response = client.get("/graphs")
|
||||
@@ -310,7 +299,6 @@ def test_get_graph(
|
||||
name="Test Graph",
|
||||
description="A test graph",
|
||||
user_id=test_user_id,
|
||||
created_at=datetime(2025, 9, 4, 13, 37),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
@@ -360,7 +348,6 @@ def test_delete_graph(
|
||||
name="Test Graph",
|
||||
description="A test graph",
|
||||
user_id=test_user_id,
|
||||
created_at=datetime(2025, 9, 4, 13, 37),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
|
||||
@@ -3,9 +3,8 @@ API Key authentication utilities for FastAPI applications.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import secrets
|
||||
from typing import Any, Awaitable, Callable, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.security import APIKeyHeader
|
||||
@@ -13,8 +12,6 @@ from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APIKeyAuthenticator(APIKeyHeader):
|
||||
"""
|
||||
@@ -54,8 +51,7 @@ class APIKeyAuthenticator(APIKeyHeader):
|
||||
header_name (str): The name of the header containing the API key
|
||||
expected_token (Optional[str]): The expected API key value for simple token matching
|
||||
validator (Optional[Callable]): Custom validation function that takes an API key
|
||||
string and returns a truthy value if and only if the passed string is a
|
||||
valid API key. Can be async.
|
||||
string and returns a boolean or object. Can be async.
|
||||
status_if_missing (int): HTTP status code to use for validation errors
|
||||
message_if_invalid (str): Error message to return when validation fails
|
||||
"""
|
||||
@@ -64,9 +60,7 @@ class APIKeyAuthenticator(APIKeyHeader):
|
||||
self,
|
||||
header_name: str,
|
||||
expected_token: Optional[str] = None,
|
||||
validator: Optional[
|
||||
Callable[[str], Any] | Callable[[str], Awaitable[Any]]
|
||||
] = None,
|
||||
validator: Optional[Callable[[str], bool]] = None,
|
||||
status_if_missing: int = HTTP_401_UNAUTHORIZED,
|
||||
message_if_invalid: str = "Invalid API key",
|
||||
):
|
||||
@@ -81,7 +75,7 @@ class APIKeyAuthenticator(APIKeyHeader):
|
||||
self.message_if_invalid = message_if_invalid
|
||||
|
||||
async def __call__(self, request: Request) -> Any:
|
||||
api_key = await super().__call__(request)
|
||||
api_key = await super()(request)
|
||||
if api_key is None:
|
||||
raise HTTPException(
|
||||
status_code=self.status_if_missing, detail="No API key in request"
|
||||
@@ -112,9 +106,4 @@ class APIKeyAuthenticator(APIKeyHeader):
|
||||
f"{self.__class__.__name__}.expected_token is not set; "
|
||||
"either specify it or provide a custom validator"
|
||||
)
|
||||
try:
|
||||
return secrets.compare_digest(api_key, self.expected_token)
|
||||
except TypeError as e:
|
||||
# If value is not an ASCII string, compare_digest raises a TypeError
|
||||
logger.warning(f"{self.model.name} API key check failed: {e}")
|
||||
return False
|
||||
return secrets.compare_digest(api_key, self.expected_token)
|
||||
|
||||
@@ -1,537 +0,0 @@
|
||||
"""
|
||||
Unit tests for APIKeyAuthenticator class.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, Request
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
|
||||
|
||||
from backend.server.utils.api_key_auth import APIKeyAuthenticator
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Create a mock request object."""
|
||||
request = Mock(spec=Request)
|
||||
request.state = Mock()
|
||||
request.headers = {}
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_auth():
|
||||
"""Create a basic APIKeyAuthenticator instance."""
|
||||
return APIKeyAuthenticator(
|
||||
header_name="X-API-Key", expected_token="test-secret-token"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_auth_custom_validator():
|
||||
"""Create APIKeyAuthenticator with custom validator."""
|
||||
|
||||
def custom_validator(api_key: str) -> bool:
|
||||
return api_key == "custom-valid-key"
|
||||
|
||||
return APIKeyAuthenticator(header_name="X-API-Key", validator=custom_validator)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_auth_async_validator():
|
||||
"""Create APIKeyAuthenticator with async custom validator."""
|
||||
|
||||
async def async_validator(api_key: str) -> bool:
|
||||
return api_key == "async-valid-key"
|
||||
|
||||
return APIKeyAuthenticator(header_name="X-API-Key", validator=async_validator)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_auth_object_validator():
|
||||
"""Create APIKeyAuthenticator that returns objects from validator."""
|
||||
|
||||
async def object_validator(api_key: str):
|
||||
if api_key == "user-key":
|
||||
return {"user_id": "123", "permissions": ["read", "write"]}
|
||||
return None
|
||||
|
||||
return APIKeyAuthenticator(header_name="X-API-Key", validator=object_validator)
|
||||
|
||||
|
||||
# ========== Basic Initialization Tests ========== #
|
||||
|
||||
|
||||
def test_init_with_expected_token():
|
||||
"""Test initialization with expected token."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="test-token")
|
||||
|
||||
assert auth.model.name == "X-API-Key"
|
||||
assert auth.expected_token == "test-token"
|
||||
assert auth.custom_validator is None
|
||||
assert auth.status_if_missing == HTTP_401_UNAUTHORIZED
|
||||
assert auth.message_if_invalid == "Invalid API key"
|
||||
|
||||
|
||||
def test_init_with_custom_validator():
|
||||
"""Test initialization with custom validator."""
|
||||
|
||||
def validator(key: str) -> bool:
|
||||
return True
|
||||
|
||||
auth = APIKeyAuthenticator(header_name="Authorization", validator=validator)
|
||||
|
||||
assert auth.model.name == "Authorization"
|
||||
assert auth.expected_token is None
|
||||
assert auth.custom_validator == validator
|
||||
assert auth.status_if_missing == HTTP_401_UNAUTHORIZED
|
||||
assert auth.message_if_invalid == "Invalid API key"
|
||||
|
||||
|
||||
def test_init_with_custom_parameters():
|
||||
"""Test initialization with custom status and message."""
|
||||
auth = APIKeyAuthenticator(
|
||||
header_name="X-Custom-Key",
|
||||
expected_token="token",
|
||||
status_if_missing=HTTP_403_FORBIDDEN,
|
||||
message_if_invalid="Access denied",
|
||||
)
|
||||
|
||||
assert auth.model.name == "X-Custom-Key"
|
||||
assert auth.status_if_missing == HTTP_403_FORBIDDEN
|
||||
assert auth.message_if_invalid == "Access denied"
|
||||
|
||||
|
||||
def test_scheme_name_generation():
|
||||
"""Test that scheme_name is generated correctly."""
|
||||
auth = APIKeyAuthenticator(header_name="X-Custom-Header", expected_token="token")
|
||||
|
||||
assert auth.scheme_name == "APIKeyAuthenticator-X-Custom-Header"
|
||||
|
||||
|
||||
# ========== Authentication Flow Tests ========== #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_missing(api_key_auth, mock_request):
|
||||
"""Test behavior when API key is missing from request."""
|
||||
# Mock the parent class method to return None (no API key)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await api_key_auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "No API key in request"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_valid(api_key_auth, mock_request):
|
||||
"""Test behavior with valid API key."""
|
||||
# Mock the parent class to return the API key
|
||||
with patch.object(
|
||||
api_key_auth.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="test-secret-token",
|
||||
):
|
||||
result = await api_key_auth(mock_request)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_invalid(api_key_auth, mock_request):
|
||||
"""Test behavior with invalid API key."""
|
||||
# Mock the parent class to return an invalid API key
|
||||
with patch.object(
|
||||
api_key_auth.__class__.__bases__[0], "__call__", return_value="invalid-token"
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await api_key_auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
# ========== Custom Validator Tests ========== #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_status_and_message(mock_request):
|
||||
"""Test custom status code and message."""
|
||||
auth = APIKeyAuthenticator(
|
||||
header_name="X-API-Key",
|
||||
expected_token="valid-token",
|
||||
status_if_missing=HTTP_403_FORBIDDEN,
|
||||
message_if_invalid="Access forbidden",
|
||||
)
|
||||
|
||||
# Test missing key
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_403_FORBIDDEN
|
||||
assert exc_info.value.detail == "No API key in request"
|
||||
|
||||
# Test invalid key
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value="invalid-token"
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_403_FORBIDDEN
|
||||
assert exc_info.value.detail == "Access forbidden"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_sync_validator(api_key_auth_custom_validator, mock_request):
|
||||
"""Test with custom synchronous validator."""
|
||||
# Mock the parent class to return the API key
|
||||
with patch.object(
|
||||
api_key_auth_custom_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="custom-valid-key",
|
||||
):
|
||||
result = await api_key_auth_custom_validator(mock_request)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_sync_validator_invalid(
|
||||
api_key_auth_custom_validator, mock_request
|
||||
):
|
||||
"""Test custom synchronous validator with invalid key."""
|
||||
with patch.object(
|
||||
api_key_auth_custom_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="invalid-key",
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await api_key_auth_custom_validator(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_async_validator(api_key_auth_async_validator, mock_request):
|
||||
"""Test with custom async validator."""
|
||||
with patch.object(
|
||||
api_key_auth_async_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="async-valid-key",
|
||||
):
|
||||
result = await api_key_auth_async_validator(mock_request)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_async_validator_invalid(
|
||||
api_key_auth_async_validator, mock_request
|
||||
):
|
||||
"""Test custom async validator with invalid key."""
|
||||
with patch.object(
|
||||
api_key_auth_async_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="invalid-key",
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await api_key_auth_async_validator(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validator_returns_object(api_key_auth_object_validator, mock_request):
|
||||
"""Test validator that returns an object instead of boolean."""
|
||||
with patch.object(
|
||||
api_key_auth_object_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="user-key",
|
||||
):
|
||||
result = await api_key_auth_object_validator(mock_request)
|
||||
|
||||
expected_result = {"user_id": "123", "permissions": ["read", "write"]}
|
||||
assert result == expected_result
|
||||
# Verify the object is stored in request state
|
||||
assert mock_request.state.api_key == expected_result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validator_returns_none(api_key_auth_object_validator, mock_request):
|
||||
"""Test validator that returns None (falsy)."""
|
||||
with patch.object(
|
||||
api_key_auth_object_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="invalid-key",
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await api_key_auth_object_validator(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validator_database_lookup_simulation(mock_request):
|
||||
"""Test simulation of database lookup validator."""
|
||||
# Simulate database records
|
||||
valid_api_keys = {
|
||||
"key123": {"user_id": "user1", "active": True},
|
||||
"key456": {"user_id": "user2", "active": False},
|
||||
}
|
||||
|
||||
async def db_validator(api_key: str):
|
||||
record = valid_api_keys.get(api_key)
|
||||
return record if record and record["active"] else None
|
||||
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", validator=db_validator)
|
||||
|
||||
# Test valid active key
|
||||
with patch.object(auth.__class__.__bases__[0], "__call__", return_value="key123"):
|
||||
result = await auth(mock_request)
|
||||
assert result == {"user_id": "user1", "active": True}
|
||||
assert mock_request.state.api_key == {"user_id": "user1", "active": True}
|
||||
|
||||
# Test inactive key
|
||||
mock_request.state = Mock() # Reset state
|
||||
with patch.object(auth.__class__.__bases__[0], "__call__", return_value="key456"):
|
||||
with pytest.raises(HTTPException):
|
||||
await auth(mock_request)
|
||||
|
||||
# Test non-existent key
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value="nonexistent"
|
||||
):
|
||||
with pytest.raises(HTTPException):
|
||||
await auth(mock_request)
|
||||
|
||||
|
||||
# ========== Default Validator Tests ========== #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_validator_key_valid(api_key_auth):
|
||||
"""Test default validator with valid token."""
|
||||
result = await api_key_auth.default_validator("test-secret-token")
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_validator_key_invalid(api_key_auth):
|
||||
"""Test default validator with invalid token."""
|
||||
result = await api_key_auth.default_validator("wrong-token")
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_validator_missing_expected_token():
|
||||
"""Test default validator when expected_token is not set."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key")
|
||||
|
||||
with pytest.raises(MissingConfigError) as exc_info:
|
||||
await auth.default_validator("any-token")
|
||||
|
||||
assert "expected_token is not set" in str(exc_info.value)
|
||||
assert "either specify it or provide a custom validator" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_validator_uses_constant_time_comparison(api_key_auth):
|
||||
"""
|
||||
Test that default validator uses secrets.compare_digest for timing attack protection
|
||||
"""
|
||||
with patch("secrets.compare_digest") as mock_compare:
|
||||
mock_compare.return_value = True
|
||||
|
||||
await api_key_auth.default_validator("test-token")
|
||||
|
||||
mock_compare.assert_called_once_with("test-token", "test-secret-token")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_empty(mock_request):
|
||||
"""Test behavior with empty string API key."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
with patch.object(auth.__class__.__bases__[0], "__call__", return_value=""):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_whitespace_only(mock_request):
|
||||
"""Test behavior with whitespace-only API key."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=" \t\n "
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_very_long(mock_request):
|
||||
"""Test behavior with extremely long API key (potential DoS protection)."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
# Create a very long API key (10MB)
|
||||
long_api_key = "a" * (10 * 1024 * 1024)
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=long_api_key
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_null_bytes(mock_request):
|
||||
"""Test behavior with API key containing null bytes."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
api_key_with_null = "valid\x00token"
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=api_key_with_null
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_control_characters(mock_request):
|
||||
"""Test behavior with API key containing control characters."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
# API key with various control characters
|
||||
api_key_with_control = "valid\r\n\t\x1b[31mtoken"
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=api_key_with_control
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_unicode_characters(mock_request):
|
||||
"""Test behavior with Unicode characters in API key."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
# API key with Unicode characters
|
||||
unicode_api_key = "validтокен🔑"
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=unicode_api_key
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_unicode_characters_normalization_attack(mock_request):
|
||||
"""Test that Unicode normalization doesn't bypass validation."""
|
||||
# Create auth with composed Unicode character
|
||||
auth = APIKeyAuthenticator(
|
||||
header_name="X-API-Key", expected_token="café" # é is composed
|
||||
)
|
||||
|
||||
# Try with decomposed version (c + a + f + e + ´)
|
||||
decomposed_key = "cafe\u0301" # é as combining character
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=decomposed_key
|
||||
):
|
||||
# Should fail because secrets.compare_digest doesn't normalize
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_binary_data(mock_request):
|
||||
"""Test behavior with binary data in API key."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
# Binary data that might cause encoding issues
|
||||
binary_api_key = bytes([0xFF, 0xFE, 0xFD, 0xFC, 0x80, 0x81]).decode(
|
||||
"latin1", errors="ignore"
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=binary_api_key
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_regex_dos_attack_pattern(mock_request):
|
||||
"""Test behavior with API key of repeated characters (pattern attack)."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
# Pattern that might cause regex DoS in poorly implemented validators
|
||||
repeated_key = "a" * 1000 + "b" * 1000 + "c" * 1000
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=repeated_key
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_keys_with_newline_variations(mock_request):
|
||||
"""Test different newline characters in API key."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
newline_variations = [
|
||||
"valid\ntoken", # Unix newline
|
||||
"valid\r\ntoken", # Windows newline
|
||||
"valid\rtoken", # Mac newline
|
||||
"valid\x85token", # NEL (Next Line)
|
||||
"valid\x0Btoken", # Vertical Tab
|
||||
"valid\x0Ctoken", # Form Feed
|
||||
]
|
||||
|
||||
for api_key in newline_variations:
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=api_key
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
@@ -147,8 +147,10 @@ class AutoModManager:
|
||||
return None
|
||||
|
||||
# Get completed executions and collect outputs
|
||||
completed_executions = await db_client.get_node_executions(
|
||||
graph_exec_id, statuses=[ExecutionStatus.COMPLETED], include_exec_data=True
|
||||
completed_executions = await db_client.get_node_executions( # type: ignore
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[ExecutionStatus.COMPLETED],
|
||||
include_exec_data=True,
|
||||
)
|
||||
|
||||
if not completed_executions:
|
||||
@@ -218,7 +220,7 @@ class AutoModManager:
|
||||
):
|
||||
"""Update node execution statuses for frontend display when moderation fails"""
|
||||
# Import here to avoid circular imports
|
||||
from backend.executor.manager import send_async_execution_update
|
||||
from backend.util.clients import get_async_execution_event_bus
|
||||
|
||||
if moderation_type == "input":
|
||||
# For input moderation, mark queued/running/incomplete nodes as failed
|
||||
@@ -232,8 +234,10 @@ class AutoModManager:
|
||||
target_statuses = [ExecutionStatus.COMPLETED]
|
||||
|
||||
# Get the executions that need to be updated
|
||||
executions_to_update = await db_client.get_node_executions(
|
||||
graph_exec_id, statuses=target_statuses, include_exec_data=True
|
||||
executions_to_update = await db_client.get_node_executions( # type: ignore
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=target_statuses,
|
||||
include_exec_data=True,
|
||||
)
|
||||
|
||||
if not executions_to_update:
|
||||
@@ -276,10 +280,12 @@ class AutoModManager:
|
||||
updated_execs = await asyncio.gather(*exec_updates)
|
||||
|
||||
# Send all websocket updates in parallel
|
||||
event_bus = get_async_execution_event_bus()
|
||||
await asyncio.gather(
|
||||
*[
|
||||
send_async_execution_update(updated_exec)
|
||||
event_bus.publish(updated_exec)
|
||||
for updated_exec in updated_execs
|
||||
if updated_exec is not None
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
import functools
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import prisma
|
||||
from autogpt_libs.utils.cache import cached
|
||||
|
||||
import backend.data.block
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.block import Block, BlockCategory, BlockInfo, BlockSchema
|
||||
from backend.data.block import Block, BlockCategory, BlockSchema
|
||||
from backend.data.credit import get_block_costs
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.v2.builder.model import (
|
||||
BlockCategoryResponse,
|
||||
BlockData,
|
||||
BlockResponse,
|
||||
BlockType,
|
||||
CountResponse,
|
||||
@@ -23,7 +25,7 @@ from backend.util.models import Pagination
|
||||
logger = logging.getLogger(__name__)
|
||||
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
|
||||
_static_counts_cache: dict | None = None
|
||||
_suggested_blocks: list[BlockInfo] | None = None
|
||||
_suggested_blocks: list[BlockData] | None = None
|
||||
|
||||
|
||||
def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse]:
|
||||
@@ -51,7 +53,7 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
|
||||
|
||||
# Append if the category has less than the specified number of blocks
|
||||
if len(categories[category].blocks) < category_blocks:
|
||||
categories[category].blocks.append(block.get_info())
|
||||
categories[category].blocks.append(block.to_dict())
|
||||
|
||||
# Sort categories by name
|
||||
return sorted(categories.values(), key=lambda x: x.name)
|
||||
@@ -107,8 +109,10 @@ def get_blocks(
|
||||
take -= 1
|
||||
blocks.append(block)
|
||||
|
||||
costs = get_block_costs()
|
||||
|
||||
return BlockResponse(
|
||||
blocks=[b.get_info() for b in blocks],
|
||||
blocks=[{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks],
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
@@ -170,9 +174,11 @@ def search_blocks(
|
||||
take -= 1
|
||||
blocks.append(block)
|
||||
|
||||
costs = get_block_costs()
|
||||
|
||||
return SearchBlocksResponse(
|
||||
blocks=BlockResponse(
|
||||
blocks=[b.get_info() for b in blocks],
|
||||
blocks=[{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks],
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
@@ -296,7 +302,7 @@ def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@cached()
|
||||
@functools.cache
|
||||
def _get_all_providers() -> dict[ProviderName, Provider]:
|
||||
providers: dict[ProviderName, Provider] = {}
|
||||
|
||||
@@ -317,7 +323,7 @@ def _get_all_providers() -> dict[ProviderName, Provider]:
|
||||
return providers
|
||||
|
||||
|
||||
async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
||||
async def get_suggested_blocks(count: int = 5) -> list[BlockData]:
|
||||
global _suggested_blocks
|
||||
|
||||
if _suggested_blocks is not None and len(_suggested_blocks) >= count:
|
||||
@@ -345,7 +351,7 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
||||
|
||||
# Get the top blocks based on execution count
|
||||
# But ignore Input and Output blocks
|
||||
blocks: list[tuple[BlockInfo, int]] = []
|
||||
blocks: list[tuple[BlockData, int]] = []
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
||||
@@ -360,7 +366,7 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
||||
(row["execution_count"] for row in results if row["block_id"] == block.id),
|
||||
0,
|
||||
)
|
||||
blocks.append((block.get_info(), execution_count))
|
||||
blocks.append((block.to_dict(), execution_count))
|
||||
# Sort blocks by execution count
|
||||
blocks.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.model as store_model
|
||||
from backend.data.block import BlockInfo
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.models import Pagination
|
||||
|
||||
@@ -17,27 +16,29 @@ FilterType = Literal[
|
||||
|
||||
BlockType = Literal["all", "input", "action", "output"]
|
||||
|
||||
BlockData = dict[str, Any]
|
||||
|
||||
|
||||
# Suggestions
|
||||
class SuggestionsResponse(BaseModel):
|
||||
otto_suggestions: list[str]
|
||||
recent_searches: list[str]
|
||||
providers: list[ProviderName]
|
||||
top_blocks: list[BlockInfo]
|
||||
top_blocks: list[BlockData]
|
||||
|
||||
|
||||
# All blocks
|
||||
class BlockCategoryResponse(BaseModel):
|
||||
name: str
|
||||
total_blocks: int
|
||||
blocks: list[BlockInfo]
|
||||
blocks: list[BlockData]
|
||||
|
||||
model_config = {"use_enum_values": False} # <== use enum names like "AI"
|
||||
|
||||
|
||||
# Input/Action/Output and see all for block categories
|
||||
class BlockResponse(BaseModel):
|
||||
blocks: list[BlockInfo]
|
||||
blocks: list[BlockData]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
@@ -70,7 +71,7 @@ class SearchBlocksResponse(BaseModel):
|
||||
|
||||
|
||||
class SearchResponse(BaseModel):
|
||||
items: list[BlockInfo | library_model.LibraryAgent | store_model.StoreAgent]
|
||||
items: list[BlockData | library_model.LibraryAgent | store_model.StoreAgent]
|
||||
total_items: dict[FilterType, int]
|
||||
page: int
|
||||
more_pages: bool
|
||||
|
||||
@@ -16,7 +16,7 @@ import backend.server.v2.store.media as store_media
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.db import transaction
|
||||
from backend.data.execution import get_graph_execution
|
||||
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
||||
from backend.data.includes import library_agent_include
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||
@@ -144,92 +144,6 @@ async def list_library_agents(
|
||||
raise store_exceptions.DatabaseError("Failed to fetch library agents") from e
|
||||
|
||||
|
||||
async def list_favorite_library_agents(
|
||||
user_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> library_model.LibraryAgentResponse:
|
||||
"""
|
||||
Retrieves a paginated list of favorite LibraryAgent records for a given user.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user whose favorite LibraryAgents we want to retrieve.
|
||||
page: Current page (1-indexed).
|
||||
page_size: Number of items per page.
|
||||
|
||||
Returns:
|
||||
A LibraryAgentResponse containing the list of favorite agents and pagination details.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If there is an issue fetching from Prisma.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Fetching favorite library agents for user_id={user_id}, "
|
||||
f"page={page}, page_size={page_size}"
|
||||
)
|
||||
|
||||
if page < 1 or page_size < 1:
|
||||
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
|
||||
raise store_exceptions.DatabaseError("Invalid pagination input")
|
||||
|
||||
where_clause: prisma.types.LibraryAgentWhereInput = {
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"isFavorite": True, # Only fetch favorites
|
||||
}
|
||||
|
||||
# Sort favorites by updated date descending
|
||||
order_by: prisma.types.LibraryAgentOrderByInput = {"updatedAt": "desc"}
|
||||
|
||||
try:
|
||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
include=library_agent_include(user_id),
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
agent_count = await prisma.models.LibraryAgent.prisma().count(
|
||||
where=where_clause
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Retrieved {len(library_agents)} favorite library agents for user #{user_id}"
|
||||
)
|
||||
|
||||
# Only pass valid agents to the response
|
||||
valid_library_agents: list[library_model.LibraryAgent] = []
|
||||
|
||||
for agent in library_agents:
|
||||
try:
|
||||
library_agent = library_model.LibraryAgent.from_db(agent)
|
||||
valid_library_agents.append(library_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
logger.error(
|
||||
f"Error parsing LibraryAgent #{agent.id} from DB item: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Return the response with only valid agents
|
||||
return library_model.LibraryAgentResponse(
|
||||
agents=valid_library_agents,
|
||||
pagination=Pagination(
|
||||
total_items=agent_count,
|
||||
total_pages=(agent_count + page_size - 1) // page_size,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error fetching favorite library agents: {e}")
|
||||
raise store_exceptions.DatabaseError(
|
||||
"Failed to fetch favorite library agents"
|
||||
) from e
|
||||
|
||||
|
||||
async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Get a specific agent from the user's library.
|
||||
@@ -703,7 +617,7 @@ async def list_presets(
|
||||
where=query_filter,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
include=AGENT_PRESET_INCLUDE,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
total_items = await prisma.models.AgentPreset.prisma().count(where=query_filter)
|
||||
total_pages = (total_items + page_size - 1) // page_size
|
||||
@@ -748,7 +662,7 @@ async def get_preset(
|
||||
try:
|
||||
preset = await prisma.models.AgentPreset.prisma().find_unique(
|
||||
where={"id": preset_id},
|
||||
include=AGENT_PRESET_INCLUDE,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not preset or preset.userId != user_id or preset.isDeleted:
|
||||
return None
|
||||
@@ -795,12 +709,15 @@ async def create_preset(
|
||||
)
|
||||
for name, data in {
|
||||
**preset.inputs,
|
||||
**preset.credentials,
|
||||
**{
|
||||
key: creds_meta.model_dump(exclude_none=True)
|
||||
for key, creds_meta in preset.credentials.items()
|
||||
},
|
||||
}.items()
|
||||
]
|
||||
},
|
||||
),
|
||||
include=AGENT_PRESET_INCLUDE,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
return library_model.LibraryAgentPreset.from_db(new_preset)
|
||||
except prisma.errors.PrismaError as e:
|
||||
@@ -830,25 +747,6 @@ async def create_preset_from_graph_execution(
|
||||
if not graph_execution:
|
||||
raise NotFoundError(f"Graph execution #{graph_exec_id} not found")
|
||||
|
||||
# Sanity check: credential inputs must be available if required for this preset
|
||||
if graph_execution.credential_inputs is None:
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=graph_execution.graph_id,
|
||||
version=graph_execution.graph_version,
|
||||
user_id=graph_execution.user_id,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
if not graph:
|
||||
raise NotFoundError(
|
||||
f"Graph #{graph_execution.graph_id} not found or accessible"
|
||||
)
|
||||
elif len(graph.aggregate_credentials_inputs()) > 0:
|
||||
raise ValueError(
|
||||
f"Graph execution #{graph_exec_id} can't be turned into a preset "
|
||||
"because it was run before this feature existed "
|
||||
"and so the input credentials were not saved."
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Creating preset for user #{user_id} from graph execution #{graph_exec_id}",
|
||||
)
|
||||
@@ -856,7 +754,7 @@ async def create_preset_from_graph_execution(
|
||||
user_id=user_id,
|
||||
preset=library_model.LibraryAgentPresetCreatable(
|
||||
inputs=graph_execution.inputs,
|
||||
credentials=graph_execution.credential_inputs or {},
|
||||
credentials={}, # FIXME
|
||||
graph_id=graph_execution.graph_id,
|
||||
graph_version=graph_execution.graph_version,
|
||||
name=create_request.name,
|
||||
@@ -936,7 +834,7 @@ async def update_preset(
|
||||
updated = await prisma.models.AgentPreset.prisma(tx).update(
|
||||
where={"id": preset_id},
|
||||
data=update_data,
|
||||
include=AGENT_PRESET_INCLUDE,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not updated:
|
||||
raise RuntimeError(f"AgentPreset #{preset_id} vanished while updating")
|
||||
@@ -951,7 +849,7 @@ async def set_preset_webhook(
|
||||
) -> library_model.LibraryAgentPreset:
|
||||
current = await prisma.models.AgentPreset.prisma().find_unique(
|
||||
where={"id": preset_id},
|
||||
include=AGENT_PRESET_INCLUDE,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not current or current.userId != user_id:
|
||||
raise NotFoundError(f"Preset #{preset_id} not found")
|
||||
@@ -963,7 +861,7 @@ async def set_preset_webhook(
|
||||
if webhook_id
|
||||
else {"Webhook": {"disconnect": True}}
|
||||
),
|
||||
include=AGENT_PRESET_INCLUDE,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not updated:
|
||||
raise RuntimeError(f"AgentPreset #{preset_id} vanished while updating")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
@@ -9,11 +9,9 @@ import pydantic
|
||||
import backend.data.block as block_model
|
||||
import backend.data.graph as graph_model
|
||||
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.integrations import Webhook
|
||||
|
||||
|
||||
class LibraryAgentStatus(str, Enum):
|
||||
COMPLETED = "COMPLETED" # All runs completed
|
||||
@@ -22,6 +20,14 @@ class LibraryAgentStatus(str, Enum):
|
||||
ERROR = "ERROR" # Agent is in an error state
|
||||
|
||||
|
||||
class LibraryAgentTriggerInfo(pydantic.BaseModel):
|
||||
provider: ProviderName
|
||||
config_schema: dict[str, Any] = pydantic.Field(
|
||||
description="Input schema for the trigger block"
|
||||
)
|
||||
credentials_input_name: Optional[str]
|
||||
|
||||
|
||||
class LibraryAgent(pydantic.BaseModel):
|
||||
"""
|
||||
Represents an agent in the library, including metadata for display and
|
||||
@@ -43,7 +49,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
|
||||
name: str
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
|
||||
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
|
||||
output_schema: dict[str, Any]
|
||||
@@ -54,7 +59,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
has_external_trigger: bool = pydantic.Field(
|
||||
description="Whether the agent has an external trigger (e.g. webhook) node"
|
||||
)
|
||||
trigger_setup_info: Optional[graph_model.GraphTriggerInfo] = None
|
||||
trigger_setup_info: Optional[LibraryAgentTriggerInfo] = None
|
||||
|
||||
# Indicates whether there's a new output (based on recent runs)
|
||||
new_output: bool
|
||||
@@ -65,12 +70,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
# Indicates if this agent is the latest version
|
||||
is_latest_version: bool
|
||||
|
||||
# Whether the agent is marked as favorite by the user
|
||||
is_favorite: bool
|
||||
|
||||
# Recommended schedule cron (from marketplace agents)
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
agent: prisma.models.LibraryAgent,
|
||||
@@ -127,19 +126,39 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
updated_at=updated_at,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
instructions=graph.instructions,
|
||||
input_schema=graph.input_schema,
|
||||
output_schema=graph.output_schema,
|
||||
credentials_input_schema=(
|
||||
graph.credentials_input_schema if sub_graphs is not None else None
|
||||
),
|
||||
has_external_trigger=graph.has_external_trigger,
|
||||
trigger_setup_info=graph.trigger_setup_info,
|
||||
trigger_setup_info=(
|
||||
LibraryAgentTriggerInfo(
|
||||
provider=trigger_block.webhook_config.provider,
|
||||
config_schema={
|
||||
**(json_schema := trigger_block.input_schema.jsonschema()),
|
||||
"properties": {
|
||||
pn: sub_schema
|
||||
for pn, sub_schema in json_schema["properties"].items()
|
||||
if not is_credentials_field_name(pn)
|
||||
},
|
||||
"required": [
|
||||
pn
|
||||
for pn in json_schema.get("required", [])
|
||||
if not is_credentials_field_name(pn)
|
||||
],
|
||||
},
|
||||
credentials_input_name=next(
|
||||
iter(trigger_block.input_schema.get_credentials_fields()), None
|
||||
),
|
||||
)
|
||||
if graph.webhook_input_node
|
||||
and (trigger_block := graph.webhook_input_node.block).webhook_config
|
||||
else None
|
||||
),
|
||||
new_output=new_output,
|
||||
can_access_graph=can_access_graph,
|
||||
is_latest_version=is_latest_version,
|
||||
is_favorite=agent.isFavorite,
|
||||
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
|
||||
)
|
||||
|
||||
|
||||
@@ -263,21 +282,12 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
webhook: "Webhook | None"
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, preset: prisma.models.AgentPreset) -> "LibraryAgentPreset":
|
||||
from backend.data.integrations import Webhook
|
||||
|
||||
if preset.InputPresets is None:
|
||||
raise ValueError("InputPresets must be included in AgentPreset query")
|
||||
if preset.webhookId and preset.Webhook is None:
|
||||
raise ValueError(
|
||||
"Webhook must be included in AgentPreset query when webhookId is set"
|
||||
)
|
||||
|
||||
input_data: block_model.BlockInput = {}
|
||||
input_credentials: dict[str, CredentialsMetaInput] = {}
|
||||
@@ -293,7 +303,6 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
||||
return cls(
|
||||
id=preset.id,
|
||||
user_id=preset.userId,
|
||||
created_at=preset.createdAt,
|
||||
updated_at=preset.updatedAt,
|
||||
graph_id=preset.agentGraphId,
|
||||
graph_version=preset.agentGraphVersion,
|
||||
@@ -303,7 +312,6 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
||||
inputs=input_data,
|
||||
credentials=input_credentials,
|
||||
webhook_id=preset.webhookId,
|
||||
webhook=Webhook.from_db(preset.Webhook) if preset.Webhook else None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -79,54 +79,6 @@ async def list_library_agents(
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/favorites",
|
||||
summary="List Favorite Library Agents",
|
||||
responses={
|
||||
500: {"description": "Server error", "content": {"application/json": {}}},
|
||||
},
|
||||
)
|
||||
async def list_favorite_library_agents(
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
page: int = Query(
|
||||
1,
|
||||
ge=1,
|
||||
description="Page number to retrieve (must be >= 1)",
|
||||
),
|
||||
page_size: int = Query(
|
||||
15,
|
||||
ge=1,
|
||||
description="Number of agents per page (must be >= 1)",
|
||||
),
|
||||
) -> library_model.LibraryAgentResponse:
|
||||
"""
|
||||
Get all favorite agents in the user's library.
|
||||
|
||||
Args:
|
||||
user_id: ID of the authenticated user.
|
||||
page: Page number to retrieve.
|
||||
page_size: Number of agents per page.
|
||||
|
||||
Returns:
|
||||
A LibraryAgentResponse containing favorite agents and pagination metadata.
|
||||
|
||||
Raises:
|
||||
HTTPException: If a server/database error occurs.
|
||||
"""
|
||||
try:
|
||||
return await library_db.list_favorite_library_agents(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not list favorite library agents for user #{user_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/{library_agent_id}", summary="Get Library Agent")
|
||||
async def get_library_agent(
|
||||
library_agent_id: str,
|
||||
|
||||
@@ -6,10 +6,8 @@ from fastapi import APIRouter, Body, HTTPException, Query, Security, status
|
||||
|
||||
import backend.server.v2.library.db as db
|
||||
import backend.server.v2.library.model as models
|
||||
from backend.data.execution import GraphExecutionMeta
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.integrations import get_webhook
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
@@ -371,41 +369,48 @@ async def execute_preset(
|
||||
preset_id: str,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
inputs: dict[str, Any] = Body(..., embed=True, default_factory=dict),
|
||||
credential_inputs: dict[str, CredentialsMetaInput] = Body(
|
||||
..., embed=True, default_factory=dict
|
||||
),
|
||||
) -> GraphExecutionMeta:
|
||||
) -> dict[str, Any]: # FIXME: add proper return type
|
||||
"""
|
||||
Execute a preset given graph parameters, returning the execution ID on success.
|
||||
|
||||
Args:
|
||||
preset_id: ID of the preset to execute.
|
||||
user_id: ID of the authenticated user.
|
||||
inputs: Optionally, inputs to override the preset for execution.
|
||||
credential_inputs: Optionally, credentials to override the preset for execution.
|
||||
preset_id (str): ID of the preset to execute.
|
||||
user_id (str): ID of the authenticated user.
|
||||
inputs (dict[str, Any]): Optionally, additional input data for the graph execution.
|
||||
|
||||
Returns:
|
||||
GraphExecutionMeta: Object representing the created execution.
|
||||
{id: graph_exec_id}: A response containing the execution ID.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the preset is not found or an error occurs while executing the preset.
|
||||
"""
|
||||
preset = await db.get_preset(user_id, preset_id)
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Preset #{preset_id} not found",
|
||||
try:
|
||||
preset = await db.get_preset(user_id, preset_id)
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Preset #{preset_id} not found",
|
||||
)
|
||||
|
||||
# Merge input overrides with preset inputs
|
||||
merged_node_input = preset.inputs | inputs
|
||||
|
||||
execution = await add_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=preset.graph_id,
|
||||
graph_version=preset.graph_version,
|
||||
preset_id=preset_id,
|
||||
inputs=merged_node_input,
|
||||
)
|
||||
|
||||
# Merge input overrides with preset inputs
|
||||
merged_node_input = preset.inputs | inputs
|
||||
merged_credential_inputs = preset.credentials | credential_inputs
|
||||
logger.debug(f"Execution added: {execution} with input: {merged_node_input}")
|
||||
|
||||
return await add_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=preset.graph_id,
|
||||
graph_version=preset.graph_version,
|
||||
preset_id=preset_id,
|
||||
inputs=merged_node_input,
|
||||
graph_credentials_inputs=merged_credential_inputs,
|
||||
)
|
||||
return {"id": execution.id}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Preset execution failed for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
@@ -50,11 +50,9 @@ async def test_get_library_agents_success(
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
recommended_schedule_cron=None,
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
is_latest_version=True,
|
||||
is_favorite=False,
|
||||
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
),
|
||||
library_model.LibraryAgent(
|
||||
@@ -71,11 +69,9 @@ async def test_get_library_agents_success(
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
recommended_schedule_cron=None,
|
||||
new_output=False,
|
||||
can_access_graph=False,
|
||||
is_latest_version=True,
|
||||
is_favorite=False,
|
||||
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
),
|
||||
],
|
||||
@@ -123,76 +119,6 @@ def test_get_library_agents_error(mocker: pytest_mock.MockFixture, test_user_id:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_favorite_library_agents_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mocked_value = library_model.LibraryAgentResponse(
|
||||
agents=[
|
||||
library_model.LibraryAgent(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
name="Favorite Agent 1",
|
||||
description="Test Favorite Description 1",
|
||||
image_url=None,
|
||||
creator_name="Test Creator",
|
||||
creator_image_url="",
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
output_schema={"type": "object", "properties": {}},
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
recommended_schedule_cron=None,
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
is_latest_version=True,
|
||||
is_favorite=True,
|
||||
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
),
|
||||
],
|
||||
pagination=Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=15
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.list_favorite_library_agents"
|
||||
)
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/agents/favorites")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = library_model.LibraryAgentResponse.model_validate(response.json())
|
||||
assert len(data.agents) == 1
|
||||
assert data.agents[0].is_favorite is True
|
||||
assert data.agents[0].name == "Favorite Agent 1"
|
||||
|
||||
mock_db_call.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
page=1,
|
||||
page_size=15,
|
||||
)
|
||||
|
||||
|
||||
def test_get_favorite_library_agents_error(
|
||||
mocker: pytest_mock.MockFixture, test_user_id: str
|
||||
):
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.list_favorite_library_agents"
|
||||
)
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
response = client.get("/agents/favorites")
|
||||
assert response.status_code == 500
|
||||
mock_db_call.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
page=1,
|
||||
page_size=15,
|
||||
)
|
||||
|
||||
|
||||
def test_add_agent_to_library_success(
|
||||
mocker: pytest_mock.MockFixture, test_user_id: str
|
||||
):
|
||||
@@ -213,7 +139,6 @@ def test_add_agent_to_library_success(
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
is_latest_version=True,
|
||||
is_favorite=False,
|
||||
updated_at=FIXED_NOW,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
@@ -10,7 +9,6 @@ import prisma.types
|
||||
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.server.v2.store.model
|
||||
from backend.data.db import transaction
|
||||
from backend.data.graph import (
|
||||
GraphMeta,
|
||||
GraphModel,
|
||||
@@ -72,7 +70,7 @@ async def get_store_agents(
|
||||
)
|
||||
sanitized_query = sanitize_query(search_query)
|
||||
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
where_clause = {}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
if creators:
|
||||
@@ -96,13 +94,15 @@ async def get_store_agents(
|
||||
|
||||
try:
|
||||
agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
where=prisma.types.StoreAgentWhereInput(**where_clause),
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
|
||||
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
||||
total = await prisma.models.StoreAgent.prisma().count(
|
||||
where=prisma.types.StoreAgentWhereInput(**where_clause)
|
||||
)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
|
||||
@@ -183,36 +183,6 @@ async def get_store_agent_details(
|
||||
store_listing.hasApprovedVersion if store_listing else False
|
||||
)
|
||||
|
||||
if active_version_id:
|
||||
agent_by_active = await prisma.models.StoreAgent.prisma().find_first(
|
||||
where={"storeListingVersionId": active_version_id}
|
||||
)
|
||||
if agent_by_active:
|
||||
agent = agent_by_active
|
||||
elif store_listing:
|
||||
latest_approved = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_first(
|
||||
where={
|
||||
"storeListingId": store_listing.id,
|
||||
"submissionStatus": prisma.enums.SubmissionStatus.APPROVED,
|
||||
},
|
||||
order=[{"version": "desc"}],
|
||||
)
|
||||
)
|
||||
if latest_approved:
|
||||
agent_latest = await prisma.models.StoreAgent.prisma().find_first(
|
||||
where={"storeListingVersionId": latest_approved.id}
|
||||
)
|
||||
if agent_latest:
|
||||
agent = agent_latest
|
||||
|
||||
if store_listing and store_listing.ActiveVersion:
|
||||
recommended_schedule_cron = (
|
||||
store_listing.ActiveVersion.recommendedScheduleCron
|
||||
)
|
||||
else:
|
||||
recommended_schedule_cron = None
|
||||
|
||||
logger.debug(f"Found agent details for {username}/{agent_name}")
|
||||
return backend.server.v2.store.model.StoreAgentDetails(
|
||||
store_listing_version_id=agent.storeListingVersionId,
|
||||
@@ -220,8 +190,8 @@ async def get_store_agent_details(
|
||||
agent_name=agent.agent_name,
|
||||
agent_video=agent.agent_video or "",
|
||||
agent_image=agent.agent_image,
|
||||
creator=agent.creator_username or "",
|
||||
creator_avatar=agent.creator_avatar or "",
|
||||
creator=agent.creator_username,
|
||||
creator_avatar=agent.creator_avatar,
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
categories=agent.categories,
|
||||
@@ -231,7 +201,6 @@ async def get_store_agent_details(
|
||||
last_updated=agent.updated_at,
|
||||
active_version_id=active_version_id,
|
||||
has_approved_version=has_approved_version,
|
||||
recommended_schedule_cron=recommended_schedule_cron,
|
||||
)
|
||||
except backend.server.v2.store.exceptions.AgentNotFoundError:
|
||||
raise
|
||||
@@ -294,8 +263,8 @@ async def get_store_agent_by_version_id(
|
||||
agent_name=agent.agent_name,
|
||||
agent_video=agent.agent_video or "",
|
||||
agent_image=agent.agent_image,
|
||||
creator=agent.creator_username or "",
|
||||
creator_avatar=agent.creator_avatar or "",
|
||||
creator=agent.creator_username,
|
||||
creator_avatar=agent.creator_avatar,
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
categories=agent.categories,
|
||||
@@ -499,7 +468,6 @@ async def get_store_submissions(
|
||||
sub_heading=sub.sub_heading,
|
||||
slug=sub.slug,
|
||||
description=sub.description,
|
||||
instructions=getattr(sub, "instructions", None),
|
||||
image_urls=sub.image_urls or [],
|
||||
date_submitted=sub.date_submitted or datetime.now(tz=timezone.utc),
|
||||
status=sub.status,
|
||||
@@ -591,11 +559,9 @@ async def create_store_submission(
|
||||
video_url: str | None = None,
|
||||
image_urls: list[str] = [],
|
||||
description: str = "",
|
||||
instructions: str | None = None,
|
||||
sub_heading: str = "",
|
||||
categories: list[str] = [],
|
||||
changes_summary: str | None = "Initial Submission",
|
||||
recommended_schedule_cron: str | None = None,
|
||||
) -> backend.server.v2.store.model.StoreSubmission:
|
||||
"""
|
||||
Create the first (and only) store listing and thus submission as a normal user
|
||||
@@ -663,7 +629,6 @@ async def create_store_submission(
|
||||
video_url=video_url,
|
||||
image_urls=image_urls,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
sub_heading=sub_heading,
|
||||
categories=categories,
|
||||
changes_summary=changes_summary,
|
||||
@@ -685,13 +650,11 @@ async def create_store_submission(
|
||||
videoUrl=video_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
submittedAt=datetime.now(tz=timezone.utc),
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
)
|
||||
]
|
||||
},
|
||||
@@ -716,7 +679,6 @@ async def create_store_submission(
|
||||
slug=slug,
|
||||
sub_heading=sub_heading,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
image_urls=image_urls,
|
||||
date_submitted=listing.createdAt,
|
||||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
@@ -748,8 +710,6 @@ async def edit_store_submission(
|
||||
sub_heading: str = "",
|
||||
categories: list[str] = [],
|
||||
changes_summary: str | None = "Update submission",
|
||||
recommended_schedule_cron: str | None = None,
|
||||
instructions: str | None = None,
|
||||
) -> backend.server.v2.store.model.StoreSubmission:
|
||||
"""
|
||||
Edit an existing store listing submission.
|
||||
@@ -829,8 +789,6 @@ async def edit_store_submission(
|
||||
sub_heading=sub_heading,
|
||||
categories=categories,
|
||||
changes_summary=changes_summary,
|
||||
recommended_schedule_cron=recommended_schedule_cron,
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
# For PENDING submissions, we can update the existing version
|
||||
@@ -846,8 +804,6 @@ async def edit_store_submission(
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
instructions=instructions,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -866,7 +822,6 @@ async def edit_store_submission(
|
||||
sub_heading=sub_heading,
|
||||
slug=current_version.StoreListing.slug,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
image_urls=image_urls,
|
||||
date_submitted=updated_version.submittedAt or updated_version.createdAt,
|
||||
status=updated_version.submissionStatus,
|
||||
@@ -908,11 +863,9 @@ async def create_store_version(
|
||||
video_url: str | None = None,
|
||||
image_urls: list[str] = [],
|
||||
description: str = "",
|
||||
instructions: str | None = None,
|
||||
sub_heading: str = "",
|
||||
categories: list[str] = [],
|
||||
changes_summary: str | None = "Initial submission",
|
||||
recommended_schedule_cron: str | None = None,
|
||||
) -> backend.server.v2.store.model.StoreSubmission:
|
||||
"""
|
||||
Create a new version for an existing store listing
|
||||
@@ -977,13 +930,11 @@ async def create_store_version(
|
||||
videoUrl=video_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
submittedAt=datetime.now(),
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
storeListingId=store_listing_id,
|
||||
)
|
||||
)
|
||||
@@ -999,7 +950,6 @@ async def create_store_version(
|
||||
slug=listing.slug,
|
||||
sub_heading=sub_heading,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
image_urls=image_urls,
|
||||
date_submitted=datetime.now(),
|
||||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
@@ -1176,20 +1126,7 @@ async def get_my_agents(
|
||||
try:
|
||||
search_filter: prisma.types.LibraryAgentWhereInput = {
|
||||
"userId": user_id,
|
||||
"AgentGraph": {
|
||||
"is": {
|
||||
"StoreListings": {
|
||||
"none": {
|
||||
"isDeleted": False,
|
||||
"Versions": {
|
||||
"some": {
|
||||
"isAvailable": True,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"AgentGraph": {"is": {"StoreListings": {"none": {"isDeleted": False}}}},
|
||||
"isArchived": False,
|
||||
"isDeleted": False,
|
||||
}
|
||||
@@ -1213,7 +1150,6 @@ async def get_my_agents(
|
||||
last_edited=graph.updatedAt or graph.createdAt,
|
||||
description=graph.description or "",
|
||||
agent_image=library_agent.imageUrl,
|
||||
recommended_schedule_cron=graph.recommendedScheduleCron,
|
||||
)
|
||||
for library_agent in library_agents
|
||||
if (graph := library_agent.AgentGraph)
|
||||
@@ -1264,103 +1200,40 @@ async def get_agent(store_listing_version_id: str) -> GraphModel:
|
||||
#####################################################
|
||||
|
||||
|
||||
async def _approve_sub_agent(
|
||||
tx,
|
||||
sub_graph: prisma.models.AgentGraph,
|
||||
main_agent_name: str,
|
||||
main_agent_version: int,
|
||||
main_agent_user_id: str,
|
||||
) -> None:
|
||||
"""Approve a single sub-agent by creating/updating store listings as needed"""
|
||||
heading = f"Sub-agent of {main_agent_name} v{main_agent_version}"
|
||||
async def _get_missing_sub_store_listing(
|
||||
graph: prisma.models.AgentGraph,
|
||||
) -> list[prisma.models.AgentGraph]:
|
||||
"""
|
||||
Agent graph can have sub-graphs, and those sub-graphs also need to be store listed.
|
||||
This method fetches the sub-graphs, and returns the ones not listed in the store.
|
||||
"""
|
||||
sub_graphs = await get_sub_graphs(graph)
|
||||
if not sub_graphs:
|
||||
return []
|
||||
|
||||
# Find existing listing for this sub-agent
|
||||
listing = await prisma.models.StoreListing.prisma(tx).find_first(
|
||||
where={"agentGraphId": sub_graph.id, "isDeleted": False},
|
||||
include={"Versions": True},
|
||||
)
|
||||
|
||||
# Early return: Create new listing if none exists
|
||||
if not listing:
|
||||
await prisma.models.StoreListing.prisma(tx).create(
|
||||
data=prisma.types.StoreListingCreateInput(
|
||||
slug=f"sub-agent-{sub_graph.id[:8]}",
|
||||
agentGraphId=sub_graph.id,
|
||||
agentGraphVersion=sub_graph.version,
|
||||
owningUserId=main_agent_user_id,
|
||||
hasApprovedVersion=True,
|
||||
Versions={
|
||||
"create": [
|
||||
_create_sub_agent_version_data(
|
||||
sub_graph, heading, main_agent_name
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Find version matching this sub-graph
|
||||
matching_version = next(
|
||||
(
|
||||
v
|
||||
for v in listing.Versions or []
|
||||
if v.agentGraphId == sub_graph.id
|
||||
and v.agentGraphVersion == sub_graph.version
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# Early return: Approve existing version if found and not already approved
|
||||
if matching_version:
|
||||
if matching_version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
|
||||
return # Already approved, nothing to do
|
||||
|
||||
await prisma.models.StoreListingVersion.prisma(tx).update(
|
||||
where={"id": matching_version.id},
|
||||
data={
|
||||
# Fetch all the sub-graphs that are listed, and return the ones missing.
|
||||
store_listed_sub_graphs = {
|
||||
(listing.agentGraphId, listing.agentGraphVersion)
|
||||
for listing in await prisma.models.StoreListingVersion.prisma().find_many(
|
||||
where={
|
||||
"OR": [
|
||||
{
|
||||
"agentGraphId": sub_graph.id,
|
||||
"agentGraphVersion": sub_graph.version,
|
||||
}
|
||||
for sub_graph in sub_graphs
|
||||
],
|
||||
"submissionStatus": prisma.enums.SubmissionStatus.APPROVED,
|
||||
"reviewedAt": datetime.now(tz=timezone.utc),
|
||||
},
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": listing.id}, data={"hasApprovedVersion": True}
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
# Create new version if no matching version found
|
||||
next_version = max((v.version for v in listing.Versions or []), default=0) + 1
|
||||
await prisma.models.StoreListingVersion.prisma(tx).create(
|
||||
data={
|
||||
**_create_sub_agent_version_data(sub_graph, heading, main_agent_name),
|
||||
"version": next_version,
|
||||
"storeListingId": listing.id,
|
||||
}
|
||||
)
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": listing.id}, data={"hasApprovedVersion": True}
|
||||
)
|
||||
|
||||
|
||||
def _create_sub_agent_version_data(
|
||||
sub_graph: prisma.models.AgentGraph, heading: str, main_agent_name: str
|
||||
) -> prisma.types.StoreListingVersionCreateInput:
|
||||
"""Create store listing version data for a sub-agent"""
|
||||
return prisma.types.StoreListingVersionCreateInput(
|
||||
agentGraphId=sub_graph.id,
|
||||
agentGraphVersion=sub_graph.version,
|
||||
name=sub_graph.name or heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
|
||||
subHeading=heading,
|
||||
description=(
|
||||
f"{heading}: {sub_graph.description}" if sub_graph.description else heading
|
||||
),
|
||||
changesSummary=f"Auto-approved as sub-agent of {main_agent_name}",
|
||||
isAvailable=False,
|
||||
submittedAt=datetime.now(tz=timezone.utc),
|
||||
imageUrls=[], # Sub-agents don't need images
|
||||
categories=[], # Sub-agents don't need categories
|
||||
)
|
||||
return [
|
||||
sub_graph
|
||||
for sub_graph in sub_graphs
|
||||
if (sub_graph.id, sub_graph.version) not in store_listed_sub_graphs
|
||||
]
|
||||
|
||||
|
||||
async def review_store_submission(
|
||||
@@ -1398,46 +1271,33 @@ async def review_store_submission(
|
||||
|
||||
# If approving, update the listing to indicate it has an approved version
|
||||
if is_approved and store_listing_version.AgentGraph:
|
||||
async with transaction() as tx:
|
||||
# Handle sub-agent approvals in transaction
|
||||
await asyncio.gather(
|
||||
*[
|
||||
_approve_sub_agent(
|
||||
tx,
|
||||
sub_graph,
|
||||
store_listing_version.name,
|
||||
store_listing_version.agentGraphVersion,
|
||||
store_listing_version.StoreListing.owningUserId,
|
||||
)
|
||||
for sub_graph in await get_sub_graphs(
|
||||
store_listing_version.AgentGraph
|
||||
)
|
||||
]
|
||||
)
|
||||
heading = f"Sub-graph of {store_listing_version.name}v{store_listing_version.agentGraphVersion}"
|
||||
|
||||
# Update the AgentGraph with store listing data
|
||||
await prisma.models.AgentGraph.prisma().update(
|
||||
where={
|
||||
"graphVersionId": {
|
||||
"id": store_listing_version.agentGraphId,
|
||||
"version": store_listing_version.agentGraphVersion,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"name": store_listing_version.name,
|
||||
"description": store_listing_version.description,
|
||||
"recommendedScheduleCron": store_listing_version.recommendedScheduleCron,
|
||||
"instructions": store_listing_version.instructions,
|
||||
},
|
||||
sub_store_listing_versions = [
|
||||
prisma.types.StoreListingVersionCreateWithoutRelationsInput(
|
||||
agentGraphId=sub_graph.id,
|
||||
agentGraphVersion=sub_graph.version,
|
||||
name=sub_graph.name or heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
|
||||
subHeading=heading,
|
||||
description=f"{heading}: {sub_graph.description}",
|
||||
changesSummary=f"This listing is added as a {heading} / #{store_listing_version.agentGraphId}.",
|
||||
isAvailable=False, # Hide sub-graphs from the store by default.
|
||||
submittedAt=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
for sub_graph in await _get_missing_sub_store_listing(
|
||||
store_listing_version.AgentGraph
|
||||
)
|
||||
]
|
||||
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": store_listing_version.StoreListing.id},
|
||||
data={
|
||||
"hasApprovedVersion": True,
|
||||
"ActiveVersion": {"connect": {"id": store_listing_version_id}},
|
||||
},
|
||||
)
|
||||
await prisma.models.StoreListing.prisma().update(
|
||||
where={"id": store_listing_version.StoreListing.id},
|
||||
data={
|
||||
"hasApprovedVersion": True,
|
||||
"ActiveVersion": {"connect": {"id": store_listing_version_id}},
|
||||
"Versions": {"create": sub_store_listing_versions},
|
||||
},
|
||||
)
|
||||
|
||||
# If rejecting an approved agent, update the StoreListing accordingly
|
||||
if is_rejecting_approved:
|
||||
@@ -1593,7 +1453,6 @@ async def review_store_submission(
|
||||
else ""
|
||||
),
|
||||
description=submission.description,
|
||||
instructions=submission.instructions,
|
||||
image_urls=submission.imageUrls or [],
|
||||
date_submitted=submission.submittedAt or submission.createdAt,
|
||||
status=submission.submissionStatus,
|
||||
@@ -1729,7 +1588,6 @@ async def get_admin_listings_with_versions(
|
||||
sub_heading=version.subHeading,
|
||||
slug=listing.slug,
|
||||
description=version.description,
|
||||
instructions=version.instructions,
|
||||
image_urls=version.imageUrls or [],
|
||||
date_submitted=version.submittedAt or version.createdAt,
|
||||
status=version.submissionStatus,
|
||||
|
||||
@@ -41,7 +41,6 @@ async def test_get_store_agents(mocker):
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -83,53 +82,16 @@ async def test_get_store_agent_details(mocker):
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
)
|
||||
|
||||
# Mock active version agent (what we want to return for active version)
|
||||
mock_active_agent = prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
storeListingVersionId="active-version-id",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent Active",
|
||||
agent_video="active_video.mp4",
|
||||
agent_image=["active_image.jpg"],
|
||||
featured=False,
|
||||
creator_username="creator",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test heading active",
|
||||
description="Test description active",
|
||||
categories=["test"],
|
||||
runs=15,
|
||||
rating=4.8,
|
||||
versions=["1.0", "2.0"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=True,
|
||||
)
|
||||
|
||||
# Create a mock StoreListing result
|
||||
mock_store_listing = mocker.MagicMock()
|
||||
mock_store_listing.activeVersionId = "active-version-id"
|
||||
mock_store_listing.hasApprovedVersion = True
|
||||
mock_store_listing.ActiveVersion = mocker.MagicMock()
|
||||
mock_store_listing.ActiveVersion.recommendedScheduleCron = None
|
||||
|
||||
# Mock StoreAgent prisma call - need to handle multiple calls
|
||||
# Mock StoreAgent prisma call
|
||||
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
|
||||
|
||||
# Set up side_effect to return different results for different calls
|
||||
def mock_find_first_side_effect(*args, **kwargs):
|
||||
where_clause = kwargs.get("where", {})
|
||||
if "storeListingVersionId" in where_clause:
|
||||
# Second call for active version
|
||||
return mock_active_agent
|
||||
else:
|
||||
# First call for initial lookup
|
||||
return mock_agent
|
||||
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(
|
||||
side_effect=mock_find_first_side_effect
|
||||
)
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
# Mock Profile prisma call
|
||||
mock_profile = mocker.MagicMock()
|
||||
@@ -139,7 +101,7 @@ async def test_get_store_agent_details(mocker):
|
||||
return_value=mock_profile
|
||||
)
|
||||
|
||||
# Mock StoreListing prisma call
|
||||
# Mock StoreListing prisma call - this is what was missing
|
||||
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_store_listing
|
||||
@@ -148,25 +110,16 @@ async def test_get_store_agent_details(mocker):
|
||||
# Call function
|
||||
result = await db.get_store_agent_details("creator", "test-agent")
|
||||
|
||||
# Verify results - should use active version data
|
||||
# Verify results
|
||||
assert result.slug == "test-agent"
|
||||
assert result.agent_name == "Test Agent Active" # From active version
|
||||
assert result.agent_name == "Test Agent"
|
||||
assert result.active_version_id == "active-version-id"
|
||||
assert result.has_approved_version is True
|
||||
assert (
|
||||
result.store_listing_version_id == "active-version-id"
|
||||
) # Should be active version ID
|
||||
|
||||
# Verify mocks called correctly - now expecting 2 calls
|
||||
assert mock_store_agent.return_value.find_first.call_count == 2
|
||||
|
||||
# Check the specific calls
|
||||
calls = mock_store_agent.return_value.find_first.call_args_list
|
||||
assert calls[0] == mocker.call(
|
||||
# Verify mocks called correctly
|
||||
mock_store_agent.return_value.find_first.assert_called_once_with(
|
||||
where={"creator_username": "creator", "slug": "test-agent"}
|
||||
)
|
||||
assert calls[1] == mocker.call(where={"storeListingVersionId": "active-version-id"})
|
||||
|
||||
mock_store_listing_db.return_value.find_first.assert_called_once()
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user