Compare commits

..

2 Commits

Author SHA1 Message Date
Nicholas Tindle
f20693d02b chore(classic): remove deprecated benchmark suite (agbenchmark)
Remove the entire classic/benchmark directory containing the agbenchmark
suite, test agent reports (mini-agi, babyagi, beebot, etc.), and challenge
fixtures. This benchmark harness is no longer used.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 01:41:44 -05:00
Nicholas Tindle
a4188c5657 chore(classic): remove deprecated Flutter frontend and unneeded files
- Remove deprecated Flutter frontend (replaced by autogpt_platform)
- Remove shell scripts (run, setup, autogpt.sh, etc.)
- Remove tutorials (outdated)
- Remove CLI-USAGE.md, FORGE-QUICKSTART.md, and cli.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 01:34:01 -05:00
880 changed files with 103264 additions and 110637 deletions

View File

@@ -1 +0,0 @@
../.claude/skills

View File

@@ -1,10 +0,0 @@
{
"permissions": {
"allowedTools": [
"Read", "Grep", "Glob",
"Bash(ls:*)", "Bash(cat:*)", "Bash(grep:*)", "Bash(find:*)",
"Bash(git status:*)", "Bash(git diff:*)", "Bash(git log:*)", "Bash(git worktree:*)",
"Bash(tmux:*)", "Bash(sleep:*)", "Bash(branchlet:*)"
]
}
}

View File

@@ -1,106 +0,0 @@
---
name: open-pr
description: Open a pull request with proper PR template, test coverage, and review workflow. Guides agents through creating a PR that follows repo conventions, ensures existing behaviors aren't broken, covers new behaviors with tests, and handles review via bot when local testing isn't possible. TRIGGER when user asks to "open a PR", "create a PR", "make a PR", "submit a PR", "open pull request", "push and create PR", or any variation of opening/submitting a pull request.
user-invocable: true
args: "[base-branch] — optional target branch (defaults to dev)."
metadata:
author: autogpt-team
version: "1.0.0"
---
# Open a Pull Request
## Step 1: Pre-flight checks
Before opening the PR:
1. Ensure all changes are committed
2. Ensure the branch is pushed to the remote (`git push -u origin <branch>`)
3. Run linters/formatters across the whole repo (not just changed files) and commit any fixes
## Step 2: Test coverage
**This is critical.** Before opening the PR, verify:
### Existing behavior is not broken
- Identify which modules/components your changes touch
- Run the existing test suites for those areas
- If tests fail, fix them before opening the PR — do not open a PR with known regressions
### New behavior has test coverage
- Every new feature, endpoint, or behavior change needs tests
- If you added a new block, add tests for that block
- If you changed API behavior, add or update API tests
- If you changed frontend behavior, verify it doesn't break existing flows
If you cannot run the full test suite locally, note which tests you ran and which you couldn't in the test plan.
## Step 3: Create the PR using the repo template
Read the canonical PR template at `.github/PULL_REQUEST_TEMPLATE.md` and use it **verbatim** as your PR body:
1. Read the template: `cat .github/PULL_REQUEST_TEMPLATE.md`
2. Preserve the exact section titles and formatting, including:
- `### Why / What / How`
- `### Changes 🏗️`
- `### Checklist 📋`
3. Replace HTML comment prompts (`<!-- ... -->`) with actual content; do not leave them in
4. **Do not pre-check boxes** — leave all checkboxes as `- [ ]` until each step is actually completed
5. Do not alter the template structure, rename sections, or remove any checklist items
**PR title must use conventional commit format** (e.g., `feat(backend): add new block`, `fix(frontend): resolve routing bug`, `dx(skills): update PR workflow`). See CLAUDE.md for the full list of scopes.
Use `gh pr create` with the base branch (defaults to `dev` if no `[base-branch]` was provided). Use `--body-file` to avoid shell interpretation of backticks and special characters:
```bash
BASE_BRANCH="${BASE_BRANCH:-dev}"
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
<filled-in template from .github/PULL_REQUEST_TEMPLATE.md>
PREOF
gh pr create --base "$BASE_BRANCH" --title "<type>(scope): short description" --body-file "$PR_BODY"
rm "$PR_BODY"
```
## Step 4: Review workflow
### If you have a workspace that allows testing (docker, running backend, etc.)
- Run `/pr-test` to do E2E manual testing of the PR using docker compose, agent-browser, and API calls. This is the most thorough way to validate your changes before review.
- After testing, run `/pr-review` to self-review the PR for correctness, security, code quality, and testing gaps before requesting human review.
### If you do NOT have a workspace that allows testing
This is common for agents running in worktrees without a full stack. In this case:
1. Run `/pr-review` locally to catch obvious issues before pushing
2. **Comment `/review` on the PR** after creating it to trigger the review bot
3. **Poll for the review** rather than blindly waiting — check for new review comments every 30 seconds using `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` and the GraphQL inline threads query. The bot typically responds within 30 minutes, but polling lets the agent react as soon as it arrives.
4. Do NOT proceed or merge until the bot review comes back
5. Address any issues the bot raises — use `/pr-address` which has a full polling loop with CI + comment tracking
```bash
# After creating the PR:
PR_NUMBER=$(gh pr view --json number -q .number)
gh pr comment "$PR_NUMBER" --body "/review"
# Then use /pr-address to poll for and address the review when it arrives
```
## Step 5: Address review feedback
Once the review bot or human reviewers leave comments:
- Run `/pr-address` to address review comments. It will loop until CI is green and all comments are resolved.
- Do not merge without human approval.
## Related skills
| Skill | When to use |
|---|---|
| `/pr-test` | E2E testing with docker compose, agent-browser, API calls — use when you have a running workspace |
| `/pr-review` | Review for correctness, security, code quality — use before requesting human review |
| `/pr-address` | Address reviewer comments and loop until CI green — use after reviews come in |
## Step 6: Post-creation
After the PR is created and review is triggered:
- Share the PR URL with the user
- If waiting on the review bot, let the user know the expected wait time (~30 min)
- Do not merge without human approval

View File

@@ -17,14 +17,6 @@ gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoG
gh pr view {N}
```
## Read the PR description
Understand the **Why / What / How** before addressing comments — you need context to make good fixes:
```bash
gh pr view {N} --json body --jq '.body'
```
## Fetch comments (all sources)
### 1. Inline review threads — GraphQL (primary source of actionable items)
@@ -95,28 +87,6 @@ Address comments **one at a time**: fix → commit → push → inline reply →
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in <commit-sha>: <description>"` |
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in <commit-sha>: <description>"` |
## Codecov coverage
Codecov patch target is **80%** on changed lines. Checks are **informational** (not blocking) but should be green.
### Running coverage locally
**Backend** (from `autogpt_platform/backend/`):
```bash
poetry run pytest -s -vv --cov=backend --cov-branch --cov-report term-missing
```
**Frontend** (from `autogpt_platform/frontend/`):
```bash
pnpm vitest run --coverage
```
### When codecov/patch fails
1. Find uncovered files: `git diff --name-only $(gh pr view --json baseRefName --jq '.baseRefName')...HEAD`
2. For each uncovered file — extract inline logic to `helpers.ts`/`helpers.py` and test those (highest ROI). Colocate tests as `*_test.py` (backend) or `__tests__/*.test.ts` (frontend).
3. Run coverage locally to verify, commit, push.
## Format and commit
After fixing, format the changed code:
@@ -135,9 +105,7 @@ kill $REST_PID 2>/dev/null; trap - EXIT
```
Never manually edit files in `src/app/api/__generated__/`.
Then commit and **push immediately** — never batch commits without pushing. Each fix should be visible on GitHub right away so CI can start and reviewers can see progress.
**Never push empty commits** (`git commit --allow-empty`) to re-trigger CI or bot checks. When a check fails, investigate the root cause (unchecked PR checklist, unaddressed review comments, code issues) and fix those directly. Empty commits add noise to git history.
Then commit and **push immediately** — never batch commits without pushing.
For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).

View File

@@ -17,16 +17,6 @@ gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoG
gh pr view {N}
```
## Read the PR description
Before reading code, understand the **why**, **what**, and **how** from the PR description:
```bash
gh pr view {N} --json body --jq '.body'
```
Every PR should have a Why / What / How structure. If any of these are missing, note it as feedback.
## Read the diff
```bash
@@ -44,8 +34,6 @@ gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews
## What to check
**Description quality:** Does the PR description cover Why (motivation/problem), What (summary of changes), and How (approach/implementation details)? If any are missing, request them — you can't judge the approach without understanding the problem and intent.
**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error handling gaps, async correctness (missing `await`, unclosed resources).
**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).

View File

@@ -1,886 +0,0 @@
---
name: pr-test
description: "E2E manual testing of PRs/branches using docker compose, agent-browser, and API calls. TRIGGER when user asks to manually test a PR, test a feature end-to-end, or run integration tests against a running system."
user-invocable: true
argument-hint: "[worktree path or PR number] — tests the PR in the given worktree. Optional flags: --fix (auto-fix issues found)"
metadata:
author: autogpt-team
version: "2.0.0"
---
# Manual E2E Test
Test a PR/branch end-to-end by building the full platform, interacting via browser and API, capturing screenshots, and reporting results.
## Critical Requirements
These are NON-NEGOTIABLE. Every test run MUST satisfy ALL the following:
### 1. Screenshots at Every Step
- Take a screenshot at EVERY significant test step — not just at the end
- Every test scenario MUST have at least one BEFORE and one AFTER screenshot
- Name screenshots sequentially: `{NN}-{action}-{state}.png` (e.g., `01-credits-before.png`, `02-credits-after.png`)
- If a screenshot is missing for a scenario, the test is INCOMPLETE — go back and take it
### 2. Screenshots MUST Be Posted to PR
- Push ALL screenshots to a temp branch `test-screenshots/pr-{N}`
- Post a PR comment with ALL screenshots embedded inline using GitHub raw URLs
- This is NOT optional — every test run MUST end with a PR comment containing screenshots
- If screenshot upload fails, retry. If it still fails, list failed files and require manual drag-and-drop/paste attachment in the PR comment
### 3. State Verification with Before/After Evidence
- For EVERY state-changing operation (API call, user action), capture the state BEFORE and AFTER
- Log the actual API response values (e.g., `credits_before=100, credits_after=95`)
- Screenshot MUST show the relevant UI state change
- Compare expected vs actual values explicitly — do not just eyeball it
### 4. Negative Test Cases Are Mandatory
- Test at least ONE negative case per feature (e.g., insufficient credits, invalid input, unauthorized access)
- Verify error messages are user-friendly and accurate
- Verify the system state did NOT change after a rejected operation
### 5. Test Report Must Include Full Evidence
Each test scenario in the report MUST have:
- **Steps**: What was done (exact commands or UI actions)
- **Expected**: What should happen
- **Actual**: What actually happened
- **API Evidence**: Before/after API response values for state-changing operations
- **Screenshot Evidence**: Before/after screenshots with explanations
## State Manipulation for Realistic Testing
When testing features that depend on specific states (rate limits, credits, quotas):
1. **Use Redis CLI to set counters directly:**
```bash
# Find the Redis container
REDIS_CONTAINER=$(docker ps --format '{{.Names}}' | grep redis | head -1)
# Set a key with expiry
docker exec $REDIS_CONTAINER redis-cli SET key value EX ttl
# Example: Set rate limit counter to near-limit
docker exec $REDIS_CONTAINER redis-cli SET "rate_limit:user:test@test.com" 99 EX 3600
# Example: Check current value
docker exec $REDIS_CONTAINER redis-cli GET "rate_limit:user:test@test.com"
```
2. **Use API calls to check before/after state:**
```bash
# BEFORE: Record current state
BEFORE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
echo "Credits BEFORE: $BEFORE"
# Perform the action...
# AFTER: Record new state and compare
AFTER=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
echo "Credits AFTER: $AFTER"
echo "Delta: $(( BEFORE - AFTER ))"
```
3. **Take screenshots BEFORE and AFTER state changes** — the UI must reflect the backend state change
4. **Never rely on mocked/injected browser state** — always use real backend state. Do NOT use `agent-browser eval` to fake UI state. The backend must be the source of truth.
5. **Use direct DB queries when needed:**
```bash
# Query via Supabase's PostgREST or docker exec into the DB
docker exec supabase-db psql -U supabase_admin -d postgres -c "SELECT credits FROM user_credits WHERE user_id = '...';"
```
6. **After every API test, verify the state change actually persisted:**
```bash
# Example: After a credits purchase, verify DB matches API
API_CREDITS=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
DB_CREDITS=$(docker exec supabase-db psql -U supabase_admin -d postgres -t -c "SELECT credits FROM user_credits WHERE user_id = '...';" | tr -d ' ')
[ "$API_CREDITS" = "$DB_CREDITS" ] && echo "CONSISTENT" || echo "MISMATCH: API=$API_CREDITS DB=$DB_CREDITS"
```
## Arguments
- `$ARGUMENTS` — worktree path (e.g. `$REPO_ROOT`) or PR number
- If `--fix` flag is present, auto-fix bugs found and push fixes (like pr-address loop)
## Step 0: Resolve the target
```bash
# If argument is a PR number, find its worktree
gh pr view {N} --json headRefName --jq '.headRefName'
# If argument is a path, use it directly
```
Determine:
- `REPO_ROOT` — the root repo directory: `git -C "$WORKTREE_PATH" worktree list | head -1 | awk '{print $1}'` (or `git rev-parse --show-toplevel` if not a worktree)
- `WORKTREE_PATH` — the worktree directory
- `PLATFORM_DIR` — `$WORKTREE_PATH/autogpt_platform`
- `BACKEND_DIR` — `$PLATFORM_DIR/backend`
- `FRONTEND_DIR` — `$PLATFORM_DIR/frontend`
- `PR_NUMBER` — the PR number (from `gh pr list --head $(git branch --show-current)`)
- `PR_TITLE` — the PR title, slugified (e.g. "Add copilot permissions" → "add-copilot-permissions")
- `RESULTS_DIR` — `$REPO_ROOT/test-results/PR-{PR_NUMBER}-{slugified-title}`
Create the results directory:
```bash
PR_NUMBER=$(cd $WORKTREE_PATH && gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT --json number --jq '.[0].number')
PR_TITLE=$(cd $WORKTREE_PATH && gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT --json title --jq '.[0].title' | tr '[:upper:]' '[:lower:]' | sed 's/[^a-z0-9]/-/g' | sed 's/--*/-/g' | sed 's/^-//;s/-$//' | head -c 50)
RESULTS_DIR="$REPO_ROOT/test-results/PR-${PR_NUMBER}-${PR_TITLE}"
mkdir -p $RESULTS_DIR
```
**Test user credentials** (for logging into the UI or verifying results manually):
- Email: `test@test.com`
- Password: `testtest123`
## Step 1: Understand the PR
Before testing, understand what changed:
```bash
cd $WORKTREE_PATH
# Read PR description to understand the WHY
gh pr view {N} --json body --jq '.body'
git log --oneline dev..HEAD | head -20
git diff dev --stat
```
Read the PR description (Why / What / How) and changed files to understand:
0. **Why** does this PR exist? What problem does it solve?
1. **What** feature/fix does this PR implement?
2. **How** does it work? What's the approach?
3. What components are affected? (backend, frontend, copilot, executor, etc.)
4. What are the key user-facing behaviors to test?
## Step 2: Write test scenarios
Based on the PR analysis, write a test plan to `$RESULTS_DIR/test-plan.md`:
```markdown
# Test Plan: PR #{N} — {title}
## Scenarios
1. [Scenario name] — [what to verify]
2. ...
## API Tests (if applicable)
1. [Endpoint] — [expected behavior]
- Before state: [what to check before]
- After state: [what to verify changed]
## UI Tests (if applicable)
1. [Page/component] — [interaction to test]
- Screenshot before: [what to capture]
- Screenshot after: [what to capture]
## Negative Tests (REQUIRED — at least one per feature)
1. [What should NOT happen] — [how to trigger it]
- Expected error: [what error message/code]
- State unchanged: [what to verify did NOT change]
```
**Be critical** — include edge cases, error paths, and security checks. Every scenario MUST specify what screenshots to take and what state to verify.
## Step 3: Environment setup
### 3a. Copy .env files from the root worktree
The root worktree (`$REPO_ROOT`) has the canonical `.env` files with all API keys. Copy them to the target worktree:
```bash
# CRITICAL: .env files are NOT checked into git. They must be copied manually.
cp $REPO_ROOT/autogpt_platform/.env $PLATFORM_DIR/.env
cp $REPO_ROOT/autogpt_platform/backend/.env $BACKEND_DIR/.env
cp $REPO_ROOT/autogpt_platform/frontend/.env $FRONTEND_DIR/.env
```
### 3b. Configure copilot authentication
The copilot needs an LLM API to function. Two approaches (try subscription first):
#### Option 1: Subscription mode (preferred — uses your Claude Max/Pro subscription)
The `claude_agent_sdk` Python package **bundles its own Claude CLI binary** — no need to install `@anthropic-ai/claude-code` via npm. The backend auto-provisions credentials from environment variables on startup.
Run the helper script to extract tokens from your host and auto-update `backend/.env` (works on macOS, Linux, and Windows/WSL):
```bash
# Extracts OAuth tokens and writes CLAUDE_CODE_OAUTH_TOKEN + CLAUDE_CODE_REFRESH_TOKEN into .env
bash $BACKEND_DIR/scripts/refresh_claude_token.sh --env-file $BACKEND_DIR/.env
```
**How it works:** The script reads the OAuth token from:
- **macOS**: system keychain (`"Claude Code-credentials"`)
- **Linux/WSL**: `~/.claude/.credentials.json`
- **Windows**: `%APPDATA%/claude/.credentials.json`
It sets `CLAUDE_CODE_OAUTH_TOKEN`, `CLAUDE_CODE_REFRESH_TOKEN`, and `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true` in the `.env` file. On container startup, the backend auto-provisions `~/.claude/.credentials.json` inside the container from these env vars. The SDK's bundled CLI then authenticates using that file. No `claude login`, no npm install needed.
**Note:** The OAuth token expires (~24h). If copilot returns auth errors, re-run the script and restart: `$BACKEND_DIR/scripts/refresh_claude_token.sh --env-file $BACKEND_DIR/.env && docker compose up -d copilot_executor`
#### Option 2: OpenRouter API key mode (fallback)
If subscription mode doesn't work, switch to API key mode using OpenRouter:
```bash
# In $BACKEND_DIR/.env, ensure these are set:
CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=false
CHAT_API_KEY=<value of OPEN_ROUTER_API_KEY from the same .env>
CHAT_BASE_URL=https://openrouter.ai/api/v1
CHAT_USE_CLAUDE_AGENT_SDK=true
```
Use `sed` to update these values:
```bash
ORKEY=$(grep "^OPEN_ROUTER_API_KEY=" $BACKEND_DIR/.env | cut -d= -f2)
[ -n "$ORKEY" ] || { echo "ERROR: OPEN_ROUTER_API_KEY is missing in $BACKEND_DIR/.env"; exit 1; }
perl -i -pe 's/CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true/CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=false/' $BACKEND_DIR/.env
# Add or update CHAT_API_KEY and CHAT_BASE_URL
grep -q "^CHAT_API_KEY=" $BACKEND_DIR/.env && perl -i -pe "s|^CHAT_API_KEY=.*|CHAT_API_KEY=$ORKEY|" $BACKEND_DIR/.env || echo "CHAT_API_KEY=$ORKEY" >> $BACKEND_DIR/.env
grep -q "^CHAT_BASE_URL=" $BACKEND_DIR/.env && perl -i -pe 's|^CHAT_BASE_URL=.*|CHAT_BASE_URL=https://openrouter.ai/api/v1|' $BACKEND_DIR/.env || echo "CHAT_BASE_URL=https://openrouter.ai/api/v1" >> $BACKEND_DIR/.env
```
### 3c. Stop conflicting containers
```bash
# Stop any running app containers (keep infra: supabase, redis, rabbitmq, clamav)
docker ps --format "{{.Names}}" | grep -E "rest_server|executor|copilot|websocket|database_manager|scheduler|notification|frontend|migrate" | while read name; do
docker stop "$name" 2>/dev/null
done
```
### 3e. Build and start
```bash
cd $PLATFORM_DIR && docker compose build --no-cache 2>&1 | tail -20
if [ ${PIPESTATUS[0]} -ne 0 ]; then echo "ERROR: Docker build failed"; exit 1; fi
cd $PLATFORM_DIR && docker compose up -d 2>&1 | tail -20
if [ ${PIPESTATUS[0]} -ne 0 ]; then echo "ERROR: Docker compose up failed"; exit 1; fi
```
**Note:** If the container appears to be running old code (e.g. missing PR changes), use `docker compose build --no-cache` to force a full rebuild. Docker BuildKit may sometimes reuse cached `COPY` layers from a previous build on a different branch.
**Expected time: 3-8 minutes** for build, 5-10 minutes with `--no-cache`.
### 3f. Wait for services to be ready
```bash
# Poll until backend and frontend respond
for i in $(seq 1 60); do
BACKEND=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8006/docs 2>/dev/null)
FRONTEND=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:3000 2>/dev/null)
if [ "$BACKEND" = "200" ] && [ "$FRONTEND" = "200" ]; then
echo "Services ready"
break
fi
sleep 5
done
```
### 3h. Create test user and get auth token
```bash
ANON_KEY=$(grep "NEXT_PUBLIC_SUPABASE_ANON_KEY=" $FRONTEND_DIR/.env | sed 's/.*NEXT_PUBLIC_SUPABASE_ANON_KEY=//' | tr -d '[:space:]')
# Signup (idempotent — returns "User already registered" if exists)
RESULT=$(curl -s -X POST 'http://localhost:8000/auth/v1/signup' \
-H "apikey: $ANON_KEY" \
-H 'Content-Type: application/json' \
-d '{"email":"test@test.com","password":"testtest123"}')
# If "Database error finding user", restart supabase-auth and retry
if echo "$RESULT" | grep -q "Database error"; then
docker restart supabase-auth && sleep 5
curl -s -X POST 'http://localhost:8000/auth/v1/signup' \
-H "apikey: $ANON_KEY" \
-H 'Content-Type: application/json' \
-d '{"email":"test@test.com","password":"testtest123"}'
fi
# Get auth token
TOKEN=$(curl -s -X POST 'http://localhost:8000/auth/v1/token?grant_type=password' \
-H "apikey: $ANON_KEY" \
-H 'Content-Type: application/json' \
-d '{"email":"test@test.com","password":"testtest123"}' | jq -r '.access_token // ""')
```
**Use this token for ALL API calls:**
```bash
curl -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/...
```
## Step 4: Run tests
### Service ports reference
| Service | Port | URL |
|---------|------|-----|
| Frontend | 3000 | http://localhost:3000 |
| Backend REST | 8006 | http://localhost:8006 |
| Supabase Auth (via Kong) | 8000 | http://localhost:8000 |
| Executor | 8002 | http://localhost:8002 |
| Copilot Executor | 8008 | http://localhost:8008 |
| WebSocket | 8001 | http://localhost:8001 |
| Database Manager | 8005 | http://localhost:8005 |
| Redis | 6379 | localhost:6379 |
| RabbitMQ | 5672 | localhost:5672 |
### API testing
Use `curl` with the auth token for backend API tests. **For EVERY API call that changes state, record before/after values:**
```bash
# Example: List agents
curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/graphs | jq . | head -20
# Example: Create an agent
curl -s -X POST http://localhost:8006/api/graphs \
-H "Authorization: Bearer $TOKEN" \
-H 'Content-Type: application/json' \
-d '{...}' | jq .
# Example: Run an agent
curl -s -X POST "http://localhost:8006/api/graphs/{graph_id}/execute" \
-H "Authorization: Bearer $TOKEN" \
-H 'Content-Type: application/json' \
-d '{"data": {...}}'
# Example: Get execution results
curl -s -H "Authorization: Bearer $TOKEN" \
"http://localhost:8006/api/graphs/{graph_id}/executions/{exec_id}" | jq .
```
**State verification pattern (use for EVERY state-changing API call):**
```bash
# 1. Record BEFORE state
BEFORE_STATE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/{resource} | jq '{relevant_fields}')
echo "BEFORE: $BEFORE_STATE"
# 2. Perform the action
ACTION_RESULT=$(curl -s -X POST ... | jq .)
echo "ACTION RESULT: $ACTION_RESULT"
# 3. Record AFTER state
AFTER_STATE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/{resource} | jq '{relevant_fields}')
echo "AFTER: $AFTER_STATE"
# 4. Log the comparison
echo "=== STATE CHANGE VERIFICATION ==="
echo "Before: $BEFORE_STATE"
echo "After: $AFTER_STATE"
echo "Expected change: {describe what should have changed}"
```
### Browser testing with agent-browser
```bash
# Close any existing session
agent-browser close 2>/dev/null || true
# Use --session-name to persist cookies across navigations
# This means login only needs to happen once per test session
agent-browser --session-name pr-test open 'http://localhost:3000/login' --timeout 15000
# Get interactive elements
agent-browser --session-name pr-test snapshot | grep "textbox\|button"
# Login
agent-browser --session-name pr-test fill {email_ref} "test@test.com"
agent-browser --session-name pr-test fill {password_ref} "testtest123"
agent-browser --session-name pr-test click {login_button_ref}
sleep 5
# Dismiss cookie banner if present
agent-browser --session-name pr-test click 'text=Accept All' 2>/dev/null || true
# Navigate — cookies are preserved so login persists
agent-browser --session-name pr-test open 'http://localhost:3000/copilot' --timeout 10000
# Take screenshot
agent-browser --session-name pr-test screenshot $RESULTS_DIR/01-page.png
# Interact with elements
agent-browser --session-name pr-test fill {ref} "text"
agent-browser --session-name pr-test press "Enter"
agent-browser --session-name pr-test click {ref}
agent-browser --session-name pr-test click 'text=Button Text'
# Read page content
agent-browser --session-name pr-test snapshot | grep "text:"
```
**Key pages:**
- `/copilot` — CoPilot chat (for testing copilot features)
- `/build` — Agent builder (for testing block/node features)
- `/build?flowID={id}` — Specific agent in builder
- `/library` — Agent library (for testing listing/import features)
- `/library/agents/{id}` — Agent detail with run history
- `/marketplace` — Marketplace
### Checking logs
```bash
# Backend REST server
docker logs autogpt_platform-rest_server-1 2>&1 | tail -30
# Executor (runs agent graphs)
docker logs autogpt_platform-executor-1 2>&1 | tail -30
# Copilot executor (runs copilot chat sessions)
docker logs autogpt_platform-copilot_executor-1 2>&1 | tail -30
# Frontend
docker logs autogpt_platform-frontend-1 2>&1 | tail -30
# Filter for errors
docker logs autogpt_platform-executor-1 2>&1 | grep -i "error\|exception\|traceback" | tail -20
```
### Copilot chat testing
The copilot uses SSE streaming. To test via API:
```bash
# Create a session
SESSION_ID=$(curl -s -X POST 'http://localhost:8006/api/chat/sessions' \
-H "Authorization: Bearer $TOKEN" \
-H 'Content-Type: application/json' \
-d '{}' | jq -r '.id // .session_id // ""')
# Stream a message (SSE - will stream chunks)
curl -N -X POST "http://localhost:8006/api/chat/sessions/$SESSION_ID/stream" \
-H "Authorization: Bearer $TOKEN" \
-H 'Content-Type: application/json' \
-d '{"message": "Hello, what can you help me with?"}' \
--max-time 60 2>/dev/null | head -50
```
Or test via browser (preferred for UI verification):
```bash
agent-browser --session-name pr-test open 'http://localhost:3000/copilot' --timeout 10000
# ... fill chat input and press Enter, wait 20-30s for response
```
## Step 5: Record results and take screenshots
**Take a screenshot at EVERY significant test step** — before and after interactions, on success, and on failure. This is NON-NEGOTIABLE.
**Required screenshot pattern for each test scenario:**
```bash
# BEFORE the action
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{scenario}-before.png
# Perform the action...
# AFTER the action
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{scenario}-after.png
```
**Naming convention:**
```bash
# Examples:
# $RESULTS_DIR/01-login-page-before.png
# $RESULTS_DIR/02-login-page-after.png
# $RESULTS_DIR/03-credits-page-before.png
# $RESULTS_DIR/04-credits-purchase-after.png
# $RESULTS_DIR/05-negative-insufficient-credits.png
# $RESULTS_DIR/06-error-state.png
```
**Minimum requirements:**
- At least TWO screenshots per test scenario (before + after)
- At least ONE screenshot for each negative test case showing the error state
- If a test fails, screenshot the failure state AND any error logs visible in the UI
## Step 6: Show results to user with screenshots
**CRITICAL: After all tests complete, you MUST show every screenshot to the user using the Read tool, with an explanation of what each screenshot shows.** This is the most important part of the test report — the user needs to visually verify the results.
For each screenshot:
1. Use the `Read` tool to display the PNG file (Claude can read images)
2. Write a 1-2 sentence explanation below it describing:
- What page/state is being shown
- What the screenshot proves (which test scenario it validates)
- Any notable details visible in the UI
Format the output like this:
```markdown
### Screenshot 1: {descriptive title}
[Read the PNG file here]
**What it shows:** {1-2 sentence explanation of what this screenshot proves}
---
```
After showing all screenshots, output a **detailed** summary table:
| # | Scenario | Result | API Evidence | Screenshot Evidence |
|---|----------|--------|-------------|-------------------|
| 1 | {name} | PASS/FAIL | Before: X, After: Y | 01-before.png, 02-after.png |
| 2 | ... | ... | ... | ... |
**IMPORTANT:** As you show each screenshot and record test results, persist them in shell variables for Step 7:
```bash
# Build these variables during Step 6 — they are required by Step 7's script
# NOTE: declare -A requires Bash 4.0+. This is standard on modern systems (macOS ships zsh
# but Homebrew bash is 5.x; Linux typically has bash 5.x). If running on Bash <4, use a
# plain variable with a lookup function instead.
declare -A SCREENSHOT_EXPLANATIONS=(
# Each explanation MUST answer three things:
# 1. FLOW: Which test scenario / user journey is this part of?
# 2. STEPS: What exact actions were taken to reach this state?
# 3. EVIDENCE: What does this screenshot prove (pass/fail/data)?
#
# Good example:
# ["03-cost-log-after-run.png"]="Flow: LLM block cost tracking. Steps: Logged in as tester@gmail.com → ran 'Cost Test Agent' → waited for COMPLETED status. Evidence: PlatformCostLog table shows 1 new row with cost_microdollars=1234 and correct user_id."
#
# Bad example (too vague — never do this):
# ["03-cost-log.png"]="Shows the cost log table."
["01-login-page.png"]="Flow: Login flow. Steps: Opened /login. Evidence: Login page renders with email/password fields and SSO options visible."
["02-builder-with-block.png"]="Flow: Block execution. Steps: Logged in → /build → added LLM block. Evidence: Builder canvas shows block connected to trigger, ready to run."
# ... one entry per screenshot using the flow/steps/evidence format above
)
TEST_RESULTS_TABLE="| 1 | Login flow | PASS | N/A | 01-login-before.png, 02-login-after.png |
| 2 | Credits purchase | PASS | Before: 100, After: 95 | 03-credits-before.png, 04-credits-after.png |
| 3 | Insufficient credits (negative) | PASS | Credits: 0, rejected | 05-insufficient-credits-error.png |"
# ... one row per test scenario with actual results
```
## Step 7: Post test report as PR comment with screenshots
Upload screenshots to the PR using the GitHub Git API (no local git operations — safe for worktrees), then post a comment with inline images and per-screenshot explanations.
**This step is MANDATORY. Every test run MUST post a PR comment with screenshots. No exceptions.**
> **CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.**
> Every screenshot MUST appear as `![name](raw_url)` inline in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.
```bash
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
REPO="Significant-Gravitas/AutoGPT"
SCREENSHOTS_BRANCH="test-screenshots/pr-${PR_NUMBER}"
SCREENSHOTS_DIR="test-screenshots/PR-${PR_NUMBER}"
# Step 1: Create blobs for each screenshot and build tree JSON
# Retry each blob upload up to 3 times. If still failing, list them at end of report.
shopt -s nullglob
SCREENSHOT_FILES=("$RESULTS_DIR"/*.png)
if [ ${#SCREENSHOT_FILES[@]} -eq 0 ]; then
echo "ERROR: No screenshots found in $RESULTS_DIR. Test run is incomplete."
exit 1
fi
TREE_JSON='['
FIRST=true
FAILED_UPLOADS=()
for img in "${SCREENSHOT_FILES[@]}"; do
BASENAME=$(basename "$img")
B64=$(base64 < "$img")
BLOB_SHA=""
for attempt in 1 2 3; do
BLOB_SHA=$(gh api "repos/${REPO}/git/blobs" -f content="$B64" -f encoding="base64" --jq '.sha' 2>/dev/null || true)
[ -n "$BLOB_SHA" ] && break
sleep 1
done
if [ -z "$BLOB_SHA" ]; then
FAILED_UPLOADS+=("$img")
continue
fi
if [ "$FIRST" = true ]; then FIRST=false; else TREE_JSON+=','; fi
TREE_JSON+="{\"path\":\"${SCREENSHOTS_DIR}/${BASENAME}\",\"mode\":\"100644\",\"type\":\"blob\",\"sha\":\"${BLOB_SHA}\"}"
done
TREE_JSON+=']'
# Step 2: Create tree, commit (with parent), and branch ref
TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')
# Resolve existing branch tip as parent (avoids orphan commits on repeat runs)
PARENT_SHA=$(gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" --jq '.object.sha' 2>/dev/null || true)
if [ -n "$PARENT_SHA" ]; then
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
-f tree="$TREE_SHA" \
-f "parents[]=$PARENT_SHA" \
--jq '.sha')
else
# First commit on this branch — no parent
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
-f tree="$TREE_SHA" \
--jq '.sha')
fi
gh api "repos/${REPO}/git/refs" \
-f ref="refs/heads/${SCREENSHOTS_BRANCH}" \
-f sha="$COMMIT_SHA" 2>/dev/null \
|| gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" \
-X PATCH -f sha="$COMMIT_SHA" -f force=true
```
Then post the comment with **inline images AND explanations for each screenshot**:
```bash
REPO_URL="https://raw.githubusercontent.com/${REPO}/${SCREENSHOTS_BRANCH}"
# Build image markdown using uploaded image URLs; skip FAILED_UPLOADS (listed separately)
IMAGE_MARKDOWN=""
for img in "${SCREENSHOT_FILES[@]}"; do
BASENAME=$(basename "$img")
TITLE=$(echo "${BASENAME%.png}" | sed 's/^[0-9]*-//' | sed 's/-/ /g' | awk '{for(i=1;i<=NF;i++) $i=toupper(substr($i,1,1)) tolower(substr($i,2))}1')
# Skip images that failed to upload — they will be listed at the end
IS_FAILED=false
for failed in "${FAILED_UPLOADS[@]}"; do
[ "$(basename "$failed")" = "$BASENAME" ] && IS_FAILED=true && break
done
if [ "$IS_FAILED" = true ]; then
continue
fi
EXPLANATION="${SCREENSHOT_EXPLANATIONS[$BASENAME]}"
if [ -z "$EXPLANATION" ]; then
echo "ERROR: Missing screenshot explanation for $BASENAME. Add it to SCREENSHOT_EXPLANATIONS in Step 6."
exit 1
fi
IMAGE_MARKDOWN="${IMAGE_MARKDOWN}
### ${TITLE}
![${BASENAME}](${REPO_URL}/${SCREENSHOTS_DIR}/${BASENAME})
${EXPLANATION}
"
done
# Write comment body to file to avoid shell interpretation issues with special characters
COMMENT_FILE=$(mktemp)
# If any uploads failed, append a section listing them with instructions
FAILED_SECTION=""
if [ ${#FAILED_UPLOADS[@]} -gt 0 ]; then
FAILED_SECTION="
## ⚠️ Failed Screenshot Uploads
The following screenshots could not be uploaded via the GitHub API after 3 retries.
**To add them:** drag-and-drop or paste these files into a PR comment manually:
"
for failed in "${FAILED_UPLOADS[@]}"; do
FAILED_SECTION="${FAILED_SECTION}
- \`$(basename "$failed")\` (local path: \`$failed\`)"
done
FAILED_SECTION="${FAILED_SECTION}
**Run status:** INCOMPLETE until the files above are manually attached and visible inline in the PR."
fi
cat > "$COMMENT_FILE" <<INNEREOF
## E2E Test Report
| # | Scenario | Result | API Evidence | Screenshot Evidence |
|---|----------|--------|-------------|-------------------|
${TEST_RESULTS_TABLE}
${IMAGE_MARKDOWN}
${FAILED_SECTION}
INNEREOF
POSTED_BODY=$(gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE" --jq '.body')
rm -f "$COMMENT_FILE"
```
**The PR comment MUST include:**
1. A summary table of all scenarios with PASS/FAIL and before/after API evidence
2. Every successfully uploaded screenshot rendered inline; any failed uploads listed with manual attachment instructions
3. A structured explanation below each screenshot covering: **Flow** (which scenario), **Steps** (exact actions taken to reach this state), **Evidence** (what this proves — pass/fail/data values). A bare "shows the page" caption is not acceptable.
This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch.
**Verify inline rendering after posting — this is required, not optional:**
```bash
# 1. Confirm the posted comment body contains inline image markdown syntax
if ! echo "$POSTED_BODY" | grep -q '!\['; then
echo "❌ FAIL: No inline image tags in posted comment body. Re-check IMAGE_MARKDOWN and re-post."
exit 1
fi
# 2. Verify at least one raw URL actually resolves (catches wrong branch name, wrong path, etc.)
FIRST_IMG_URL=$(echo "$POSTED_BODY" | grep -o 'https://raw.githubusercontent.com[^)]*' | head -1)
if [ -n "$FIRST_IMG_URL" ]; then
HTTP_STATUS=$(curl -s -o /dev/null -w "%{http_code}" --max-time 10 "$FIRST_IMG_URL")
if [ "$HTTP_STATUS" = "200" ]; then
echo "✅ Inline images confirmed and raw URL resolves (HTTP 200)"
else
echo "❌ FAIL: Raw image URL returned HTTP $HTTP_STATUS — images will not render inline."
echo " URL: $FIRST_IMG_URL"
echo " Check branch name, path, and that the push succeeded."
exit 1
fi
else
echo "⚠️ Could not extract a raw URL from the comment — verify manually."
fi
```
## Step 8: Evaluate test completeness and post a GitHub review
After posting the PR comment, evaluate whether the test run actually covered everything it needed to. This is NOT a rubber-stamp — be critical. Then post a formal GitHub review so the PR author and reviewers can see the verdict.
### 8a. Evaluate against the test plan
Re-read `$RESULTS_DIR/test-plan.md` (written in Step 2) and `$RESULTS_DIR/test-report.md` (written in Step 5). For each scenario in the plan, answer:
> **Note:** `test-report.md` is written in Step 5. If it doesn't exist, write it before proceeding here — see the Step 5 template. Do not skip evaluation because the file is missing; create it from your notes instead.
| Question | Pass criteria |
|----------|--------------|
| Was it tested? | Explicit steps were executed, not just described |
| Is there screenshot evidence? | At least one before/after screenshot per scenario |
| Did the core feature work correctly? | Expected state matches actual state |
| Were negative cases tested? | At least one failure/rejection case per feature |
| Was DB/API state verified (not just UI)? | Raw API response or DB query confirms state change |
Build a verdict:
- **APPROVE** — every scenario tested, evidence present, no bugs found or all bugs are minor/known
- **REQUEST_CHANGES** — one or more: untested scenarios, missing evidence, bugs found, data not verified
### 8b. Post the GitHub review
```bash
EVAL_FILE=$(mktemp)
# === STEP A: Write header ===
cat > "$EVAL_FILE" << 'ENDEVAL'
## 🧪 Test Evaluation
### Coverage checklist
ENDEVAL
# === STEP B: Append ONE line per scenario — do this BEFORE calculating verdict ===
# Format: "- ✅ **Scenario N name**: <what was done and verified>"
# or "- ❌ **Scenario N name**: <what is missing or broken>"
# Examples:
# echo "- ✅ **Scenario 1 Login flow**: tested, screenshot evidence present, auth token verified via API" >> "$EVAL_FILE"
# echo "- ❌ **Scenario 3 Cost logging**: NOT verified in DB — UI showed entry but raw SQL query was skipped" >> "$EVAL_FILE"
#
# !!! IMPORTANT: append ALL scenario lines here before proceeding to STEP C !!!
# === STEP C: Derive verdict from the checklist — runs AFTER all lines are appended ===
FAIL_COUNT=$(grep -c "^- ❌" "$EVAL_FILE" || true)
if [ "$FAIL_COUNT" -eq 0 ]; then
VERDICT="APPROVE"
else
VERDICT="REQUEST_CHANGES"
fi
# === STEP D: Append verdict section ===
cat >> "$EVAL_FILE" << ENDVERDICT
### Verdict
ENDVERDICT
if [ "$VERDICT" = "APPROVE" ]; then
echo "✅ All scenarios covered with evidence. No blocking issues found." >> "$EVAL_FILE"
else
echo "$FAIL_COUNT scenario(s) incomplete or have confirmed bugs. See ❌ items above." >> "$EVAL_FILE"
echo "" >> "$EVAL_FILE"
echo "**Required before merge:** address each ❌ item above." >> "$EVAL_FILE"
fi
# === STEP E: Post the review ===
gh api "repos/${REPO}/pulls/$PR_NUMBER/reviews" \
--method POST \
-f body="$(cat "$EVAL_FILE")" \
-f event="$VERDICT"
rm -f "$EVAL_FILE"
```
**Rules:**
- Never auto-approve without checking every scenario in the test plan
- `REQUEST_CHANGES` if ANY scenario is untested, lacks DB/API evidence, or has a confirmed bug
- The evaluation body must list every scenario explicitly (✅ or ❌) — not just the failures
- If you find new bugs during evaluation, add them to the request-changes body and (if `--fix` flag is set) fix them before posting
## Fix mode (--fix flag)
When `--fix` is present, the standard is HIGHER. Do not just note issues — FIX them immediately.
### Fix protocol for EVERY issue found (including UX issues):
1. **Identify** the root cause in the code — read the relevant source files
2. **Write a failing test first** (TDD): For backend bugs, write a test marked with `pytest.mark.xfail(reason="...")`. For frontend/Playwright bugs, write a test with `.fixme` annotation. Run it to confirm it fails as expected.
3. **Screenshot** the broken state: `agent-browser screenshot $RESULTS_DIR/{NN}-broken-{description}.png`
4. **Fix** the code in the worktree
5. **Rebuild** ONLY the affected service (not the whole stack):
```bash
cd $PLATFORM_DIR && docker compose up --build -d {service_name}
# e.g., docker compose up --build -d rest_server
# e.g., docker compose up --build -d frontend
```
6. **Wait** for the service to be ready (poll health endpoint)
7. **Re-test** the same scenario
8. **Screenshot** the fixed state: `agent-browser screenshot $RESULTS_DIR/{NN}-fixed-{description}.png`
9. **Remove the xfail/fixme marker** from the test written in step 2, and verify it passes
10. **Verify** the fix did not break other scenarios (run a quick smoke test)
11. **Commit and push** immediately:
```bash
cd $WORKTREE_PATH
git add -A
git commit -m "fix: {description of fix}"
git push
```
12. **Continue** to the next test scenario
### Fix loop (like pr-address)
```text
test scenario → find issue (bug OR UX problem) → screenshot broken state
→ fix code → rebuild affected service only → re-test → screenshot fixed state
→ verify no regressions → commit + push
→ repeat for next scenario
→ after ALL scenarios pass, run full re-test to verify everything together
```
**Key differences from non-fix mode:**
- UX issues count as bugs — fix them (bad alignment, confusing labels, missing loading states)
- Every fix MUST have a before/after screenshot pair proving it works
- Commit after EACH fix, not in a batch at the end
- The final re-test must produce a clean set of all-passing screenshots
## Known issues and workarounds
### Problem: "Database error finding user" on signup
**Cause:** Supabase auth service schema cache is stale after migration.
**Fix:** `docker restart supabase-auth && sleep 5` then retry signup.
### Problem: Copilot returns auth errors in subscription mode
**Cause:** `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true` but `CLAUDE_CODE_OAUTH_TOKEN` is not set or expired.
**Fix:** Re-extract the OAuth token from macOS keychain (see step 3b, Option 1) and recreate the container (`docker compose up -d copilot_executor`). The backend auto-provisions `~/.claude/.credentials.json` from the env var on startup. No `npm install` or `claude login` needed — the SDK bundles its own CLI binary.
### Problem: agent-browser can't find chromium
**Cause:** The Dockerfile auto-provisions system chromium on all architectures (including ARM64). If your branch is behind `dev`, this may not be present yet.
**Fix:** Check if chromium exists: `which chromium || which chromium-browser`. If missing, install it: `apt-get install -y chromium` and set `AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium` in the container environment.
### Problem: agent-browser selector matches multiple elements
**Cause:** `text=X` matches all elements containing that text.
**Fix:** Use `agent-browser snapshot` to get specific `ref=eNN` references, then use those: `agent-browser click eNN`.
### Problem: Frontend shows cookie banner blocking interaction
**Fix:** `agent-browser click 'text=Accept All'` before other interactions.
### Problem: Container loses npm packages after rebuild
**Cause:** `docker compose up --build` rebuilds the image, losing runtime installs.
**Fix:** Add packages to the Dockerfile instead of installing at runtime.
### Problem: Services not starting after `docker compose up`
**Fix:** Wait and check health: `docker compose ps`. Common cause: migration hasn't finished. Check: `docker logs autogpt_platform-migrate-1 2>&1 | tail -5`. If supabase-db isn't healthy: `docker restart supabase-db && sleep 10`.
### Problem: Docker uses cached layers with old code (PR changes not visible)
**Cause:** `docker compose up --build` reuses cached `COPY` layers from previous builds. If the PR branch changes Python files but the previous build already cached that layer from `dev`, the container runs `dev` code.
**Fix:** Always use `docker compose build --no-cache` for the first build of a PR branch. Subsequent rebuilds within the same branch can use `--build`.
### Problem: `agent-browser open` loses login session
**Cause:** Without session persistence, `agent-browser open` starts fresh.
**Fix:** Use `--session-name pr-test` on ALL agent-browser commands. This auto-saves/restores cookies and localStorage across navigations. Alternatively, use `agent-browser eval "window.location.href = '...'"` to navigate within the same context.
### Problem: Supabase auth returns "Database error querying schema"
**Cause:** The database schema changed (migration ran) but supabase-auth has a stale schema cache.
**Fix:** `docker restart supabase-db && sleep 10 && docker restart supabase-auth && sleep 8`. If user data was lost, re-signup.

View File

@@ -1,195 +0,0 @@
---
name: setup-repo
description: Initialize a worktree-based repo layout for parallel development. Creates a main worktree, a reviews worktree for PR reviews, and N numbered work branches. Handles .env creation, dependency installation, and branchlet config. TRIGGER when user asks to set up the repo from scratch, initialize worktrees, bootstrap their dev environment, "setup repo", "setup worktrees", "initialize dev environment", "set up branches", or when a freshly cloned repo has no sibling worktrees.
user-invocable: true
args: "No arguments — interactive setup via prompts."
metadata:
author: autogpt-team
version: "1.0.0"
---
# Repository Setup
This skill sets up a worktree-based development layout from a freshly cloned repo. It creates:
- A **main** worktree (the primary checkout)
- A **reviews** worktree (for PR reviews)
- **N work branches** (branch1..branchN) for parallel development
## Step 1: Identify the repo
Determine the repo root and parent directory:
```bash
ROOT=$(git rev-parse --show-toplevel)
REPO_NAME=$(basename "$ROOT")
PARENT=$(dirname "$ROOT")
```
Detect if the repo is already inside a worktree layout by counting sibling worktrees (not just checking the directory name, which could be anything):
```bash
# Count worktrees that are siblings (live under $PARENT but aren't $ROOT itself)
SIBLING_COUNT=$(git worktree list --porcelain 2>/dev/null | grep "^worktree " | grep -c "$PARENT/" || true)
if [ "$SIBLING_COUNT" -gt 1 ]; then
echo "INFO: Existing worktree layout detected at $PARENT ($SIBLING_COUNT worktrees)"
# Use $ROOT as-is; skip renaming/restructuring
else
echo "INFO: Fresh clone detected, proceeding with setup"
fi
```
## Step 2: Ask the user questions
Use AskUserQuestion to gather setup preferences:
1. **How many parallel work branches do you need?** (Options: 4, 8, 16, or custom)
- These become `branch1` through `branchN`
2. **Which branch should be the base?** (Options: origin/master, origin/dev, or custom)
- All work branches and reviews will start from this
## Step 3: Fetch and set up branches
```bash
cd "$ROOT"
git fetch origin
# Create the reviews branch from base (skip if already exists)
if git show-ref --verify --quiet refs/heads/reviews; then
echo "INFO: Branch 'reviews' already exists, skipping"
else
git branch reviews <base-branch>
fi
# Create numbered work branches from base (skip if already exists)
for i in $(seq 1 "$COUNT"); do
if git show-ref --verify --quiet "refs/heads/branch$i"; then
echo "INFO: Branch 'branch$i' already exists, skipping"
else
git branch "branch$i" <base-branch>
fi
done
```
## Step 4: Create worktrees
Create worktrees as siblings to the main checkout:
```bash
if [ -d "$PARENT/reviews" ]; then
echo "INFO: Worktree '$PARENT/reviews' already exists, skipping"
else
git worktree add "$PARENT/reviews" reviews
fi
for i in $(seq 1 "$COUNT"); do
if [ -d "$PARENT/branch$i" ]; then
echo "INFO: Worktree '$PARENT/branch$i' already exists, skipping"
else
git worktree add "$PARENT/branch$i" "branch$i"
fi
done
```
## Step 5: Set up environment files
**Do NOT assume .env files exist.** For each worktree (including main if needed):
1. Check if `.env` exists in the source worktree for each path
2. If `.env` exists, copy it
3. If only `.env.default` or `.env.example` exists, copy that as `.env`
4. If neither exists, warn the user and list which env files are missing
Env file locations to check (same as the `/worktree` skill — keep these in sync):
- `autogpt_platform/.env`
- `autogpt_platform/backend/.env`
- `autogpt_platform/frontend/.env`
> **Note:** This env copying logic intentionally mirrors the `/worktree` skill's approach. If you update the path list or fallback logic here, update `/worktree` as well.
```bash
SOURCE="$ROOT"
WORKTREES="reviews"
for i in $(seq 1 "$COUNT"); do WORKTREES="$WORKTREES branch$i"; done
FOUND_ANY_ENV=0
for wt in $WORKTREES; do
TARGET="$PARENT/$wt"
for envpath in autogpt_platform autogpt_platform/backend autogpt_platform/frontend; do
if [ -f "$SOURCE/$envpath/.env" ]; then
FOUND_ANY_ENV=1
cp "$SOURCE/$envpath/.env" "$TARGET/$envpath/.env"
elif [ -f "$SOURCE/$envpath/.env.default" ]; then
FOUND_ANY_ENV=1
cp "$SOURCE/$envpath/.env.default" "$TARGET/$envpath/.env"
echo "NOTE: $wt/$envpath/.env was created from .env.default — you may need to edit it"
elif [ -f "$SOURCE/$envpath/.env.example" ]; then
FOUND_ANY_ENV=1
cp "$SOURCE/$envpath/.env.example" "$TARGET/$envpath/.env"
echo "NOTE: $wt/$envpath/.env was created from .env.example — you may need to edit it"
else
echo "WARNING: No .env, .env.default, or .env.example found at $SOURCE/$envpath/"
fi
done
done
if [ "$FOUND_ANY_ENV" -eq 0 ]; then
echo "WARNING: No environment files or templates were found in the source worktree."
# Use AskUserQuestion to confirm: "Continue setup without env files?"
# If the user declines, stop here and let them set up .env files first.
fi
```
## Step 6: Copy branchlet config
Copy `.branchlet.json` from main to each worktree so branchlet can manage sub-worktrees:
```bash
if [ -f "$ROOT/.branchlet.json" ]; then
for wt in $WORKTREES; do
cp "$ROOT/.branchlet.json" "$PARENT/$wt/.branchlet.json"
done
fi
```
## Step 7: Install dependencies
Install deps in all worktrees. Run these sequentially per worktree:
```bash
for wt in $WORKTREES; do
TARGET="$PARENT/$wt"
echo "=== Installing deps for $wt ==="
(cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install) &&
(cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate) &&
(cd "$TARGET/autogpt_platform/frontend" && pnpm install) &&
echo "=== Done: $wt ===" ||
echo "=== FAILED: $wt ==="
done
```
This is slow. Run in background if possible and notify when complete.
## Step 8: Verify and report
After setup, verify and report to the user:
```bash
git worktree list
```
Summarize:
- Number of worktrees created
- Which env files were copied vs created from defaults vs missing
- Any warnings or errors encountered
## Final directory layout
```
parent/
main/ # Primary checkout (already exists)
reviews/ # PR review worktree
branch1/ # Work branch 1
branch2/ # Work branch 2
...
branchN/ # Work branch N
```

View File

@@ -1,224 +0,0 @@
---
name: write-frontend-tests
description: "Analyze the current branch diff against dev, plan integration tests for changed frontend pages/components, and write them. TRIGGER when user asks to write frontend tests, add test coverage, or 'write tests for my changes'."
user-invocable: true
args: "[base branch] — defaults to dev. Optionally pass a specific base branch to diff against."
metadata:
author: autogpt-team
version: "1.0.0"
---
# Write Frontend Tests
Analyze the current branch's frontend changes, plan integration tests, and write them.
## References
Before writing any tests, read the testing rules and conventions:
- `autogpt_platform/frontend/TESTING.md` — testing strategy, file locations, examples
- `autogpt_platform/frontend/src/tests/AGENTS.md` — detailed testing rules, MSW patterns, decision flowchart
- `autogpt_platform/frontend/src/tests/integrations/test-utils.tsx` — custom render with providers
- `autogpt_platform/frontend/src/tests/integrations/vitest.setup.tsx` — MSW server setup
## Step 1: Identify changed frontend files
```bash
BASE_BRANCH="${ARGUMENTS:-dev}"
cd autogpt_platform/frontend
# Get changed frontend files (excluding generated, config, and test files)
git diff "$BASE_BRANCH"...HEAD --name-only -- src/ \
| grep -v '__generated__' \
| grep -v '__tests__' \
| grep -v '\.test\.' \
| grep -v '\.stories\.' \
| grep -v '\.spec\.'
```
Also read the diff to understand what changed:
```bash
git diff "$BASE_BRANCH"...HEAD --stat -- src/
git diff "$BASE_BRANCH"...HEAD -- src/ | head -500
```
## Step 2: Categorize changes and find test targets
For each changed file, determine:
1. **Is it a page?** (`page.tsx`) — these are the primary test targets
2. **Is it a hook?** (`use*.ts`) — test via the page that uses it
3. **Is it a component?** (`.tsx` in `components/`) — test via the parent page unless it's complex enough to warrant isolation
4. **Is it a helper?** (`helpers.ts`, `utils.ts`) — unit test directly if pure logic
**Priority order:**
1. Pages with new/changed data fetching or user interactions
2. Components with complex internal logic (modals, forms, wizards)
3. Hooks with non-trivial business logic
4. Pure helper functions
Skip: styling-only changes, type-only changes, config changes.
## Step 3: Check for existing tests
For each test target, check if tests already exist:
```bash
# For a page at src/app/(platform)/library/page.tsx
ls src/app/\(platform\)/library/__tests__/ 2>/dev/null
# For a component at src/app/(platform)/library/components/AgentCard/AgentCard.tsx
ls src/app/\(platform\)/library/components/AgentCard/__tests__/ 2>/dev/null
```
Note which targets have no tests (need new files) vs which have tests that need updating.
## Step 4: Identify API endpoints used
For each test target, find which API hooks are used:
```bash
# Find generated API hook imports in the changed files
grep -rn 'from.*__generated__/endpoints' src/app/\(platform\)/library/
grep -rn 'use[A-Z].*V[12]' src/app/\(platform\)/library/
```
For each API hook found, locate the corresponding MSW handler:
```bash
# If the page uses useGetV2ListLibraryAgents, find its MSW handlers
grep -rn 'getGetV2ListLibraryAgents.*Handler' src/app/api/__generated__/endpoints/library/library.msw.ts
```
List every MSW handler you will need (200 for happy path, 4xx for error paths).
## Step 5: Write the test plan
Before writing code, output a plan as a numbered list:
```
Test plan for [branch name]:
1. src/app/(platform)/library/__tests__/main.test.tsx (NEW)
- Renders page with agent list (MSW 200)
- Shows loading state
- Shows error state (MSW 422)
- Handles empty agent list
2. src/app/(platform)/library/__tests__/search.test.tsx (NEW)
- Filters agents by search query
- Shows no results message
- Clears search
3. src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx (UPDATE)
- Add test for new "duplicate" action
```
Present this plan to the user. Wait for confirmation before proceeding. If the user has feedback, adjust the plan.
## Step 6: Write the tests
For each test file in the plan, follow these conventions:
### File structure
```tsx
import { render, screen, waitFor } from "@/tests/integrations/test-utils";
import { server } from "@/mocks/mock-server";
// Import MSW handlers for endpoints the page uses
import {
getGetV2ListLibraryAgentsMockHandler200,
getGetV2ListLibraryAgentsMockHandler422,
} from "@/app/api/__generated__/endpoints/library/library.msw";
// Import the component under test
import LibraryPage from "../page";
describe("LibraryPage", () => {
test("renders agent list from API", async () => {
server.use(getGetV2ListLibraryAgentsMockHandler200());
render(<LibraryPage />);
expect(await screen.findByText(/my agents/i)).toBeDefined();
});
test("shows error state on API failure", async () => {
server.use(getGetV2ListLibraryAgentsMockHandler422());
render(<LibraryPage />);
expect(await screen.findByText(/error/i)).toBeDefined();
});
});
```
### Rules
- Use `render()` from `@/tests/integrations/test-utils` (NOT from `@testing-library/react` directly)
- Use `server.use()` to set up MSW handlers BEFORE rendering
- Use `findBy*` (async) for elements that appear after data fetching — NOT `getBy*`
- Use `getBy*` only for elements that are immediately present in the DOM
- Use `screen` queries — do NOT destructure from `render()`
- Use `waitFor` when asserting side effects or state changes after interactions
- Import `fireEvent` or `userEvent` from the test-utils for interactions
- Do NOT mock internal hooks or functions — mock at the API boundary via MSW
- Do NOT use `act()` manually — `render` and `fireEvent` handle it
- Keep tests focused: one behavior per test
- Use descriptive test names that read like sentences
### Test location
```
# For pages: __tests__/ next to page.tsx
src/app/(platform)/library/__tests__/main.test.tsx
# For complex standalone components: __tests__/ inside component folder
src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx
# For pure helpers: co-located .test.ts
src/app/(platform)/library/helpers.test.ts
```
### Custom MSW overrides
When the auto-generated faker data is not enough, override with specific data:
```tsx
import { http, HttpResponse } from "msw";
server.use(
http.get("http://localhost:3000/api/proxy/api/v2/library/agents", () => {
return HttpResponse.json({
agents: [
{ id: "1", name: "Test Agent", description: "A test agent" },
],
pagination: { total_items: 1, total_pages: 1, page: 1, page_size: 10 },
});
}),
);
```
Use the proxy URL pattern: `http://localhost:3000/api/proxy/api/v{version}/{path}` — this matches the MSW base URL configured in `orval.config.ts`.
## Step 7: Run and verify
After writing all tests:
```bash
cd autogpt_platform/frontend
pnpm test:unit --reporter=verbose
```
If tests fail:
1. Read the error output carefully
2. Fix the test (not the source code, unless there is a genuine bug)
3. Re-run until all pass
Then run the full checks:
```bash
pnpm format
pnpm lint
pnpm types
```

View File

@@ -1,12 +1,8 @@
### Why / What / How
<!-- Why: Why does this PR exist? What problem does it solve, or what's broken/missing without it? -->
<!-- What: What does this PR change? Summarize the changes at a high level. -->
<!-- How: How does it work? Describe the approach, key implementation details, or architecture decisions. -->
<!-- Clearly explain the need for these changes: -->
### Changes 🏗️
<!-- List the key changes. Keep it higher level than the diff but specific enough to highlight what's new/modified. -->
<!-- Concisely describe all of the changes made in this pull request: -->
### Checklist 📋

View File

@@ -6,19 +6,11 @@ on:
paths:
- '.github/workflows/classic-autogpt-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/direct_benchmark/**'
- 'classic/forge/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
pull_request:
branches: [ master, dev, release-* ]
paths:
- '.github/workflows/classic-autogpt-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/direct_benchmark/**'
- 'classic/forge/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
concurrency:
group: ${{ format('classic-autogpt-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -27,22 +19,47 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic
working-directory: classic/original_autogpt
jobs:
test:
permissions:
contents: read
timeout-minutes: 30
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
steps:
- name: Start MinIO service
# Quite slow on macOS (2~4 minutes to set up Docker)
# - name: Set up Docker (macOS)
# if: runner.os == 'macOS'
# uses: crazy-max/ghaction-setup-docker@v3
- name: Start MinIO service (Linux)
if: runner.os == 'Linux'
working-directory: '.'
run: |
docker pull minio/minio:edge-cicd
docker run -d -p 9000:9000 minio/minio:edge-cicd
- name: Start MinIO service (macOS)
if: runner.os == 'macOS'
working-directory: ${{ runner.temp }}
run: |
brew install minio/stable/minio
mkdir data
minio server ./data &
# No MinIO on Windows:
# - Windows doesn't support running Linux Docker containers
# - It doesn't seem possible to start background processes on Windows. They are
# killed after the step returns.
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
- name: Checkout repository
uses: actions/checkout@v4
with:
@@ -54,23 +71,41 @@ jobs:
git config --global user.name "Auto-GPT-Bot"
git config --global user.email "github-bot@agpt.co"
- name: Set up Python 3.12
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: "3.12"
python-version: ${{ matrix.python-version }}
- id: get_date
name: Get date
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/original_autogpt/poetry.lock') }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Python dependencies
run: poetry install
@@ -81,13 +116,12 @@ jobs:
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
--numprocesses=logical --durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
original_autogpt/tests/unit original_autogpt/tests/integration
tests/unit tests/integration
env:
CI: true
PLAIN_OUTPUT: True
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
S3_ENDPOINT_URL: http://127.0.0.1:9000
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
@@ -101,11 +135,11 @@ jobs:
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: autogpt-agent
flags: autogpt-agent,${{ runner.os }}
- name: Upload logs to artifact
if: always()
uses: actions/upload-artifact@v4
with:
name: test-logs
path: classic/logs/
path: classic/original_autogpt/logs/

View File

@@ -148,7 +148,7 @@ jobs:
--entrypoint poetry ${{ env.IMAGE_NAME }} run \
pytest -v --cov=autogpt --cov-branch --cov-report term-missing \
--numprocesses=4 --durations=10 \
original_autogpt/tests/unit original_autogpt/tests/integration 2>&1 | tee test_output.txt
tests/unit tests/integration 2>&1 | tee test_output.txt
test_failure=${PIPESTATUS[0]}

View File

@@ -10,9 +10,10 @@ on:
- '.github/workflows/classic-autogpts-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- 'classic/benchmark/**'
- 'classic/run'
- 'classic/cli.py'
- 'classic/setup.py'
- '!**/*.md'
pull_request:
branches: [ master, dev, release-* ]
@@ -20,9 +21,10 @@ on:
- '.github/workflows/classic-autogpts-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- 'classic/benchmark/**'
- 'classic/run'
- 'classic/cli.py'
- 'classic/setup.py'
- '!**/*.md'
defaults:
@@ -33,9 +35,13 @@ defaults:
jobs:
serve-agent-protocol:
runs-on: ubuntu-latest
strategy:
matrix:
agent-name: [ original_autogpt ]
fail-fast: false
timeout-minutes: 20
env:
min-python-version: '3.12'
min-python-version: '3.10'
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -49,22 +55,22 @@ jobs:
python-version: ${{ env.min-python-version }}
- name: Install Poetry
working-directory: ./classic/${{ matrix.agent-name }}/
run: |
curl -sSL https://install.python-poetry.org | python -
- name: Install dependencies
run: poetry install
- name: Run smoke tests with direct-benchmark
- name: Run regression tests
run: |
poetry run direct-benchmark run \
--strategies one_shot \
--models claude \
--tests ReadFile,WriteFile \
--json
./run agent start ${{ matrix.agent-name }}
cd ${{ matrix.agent-name }}
poetry run agbenchmark --mock --test=BasicRetrieval --test=Battleship --test=WebArenaTask_0
poetry run agbenchmark --test=WriteFile
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
AGENT_NAME: ${{ matrix.agent-name }}
REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
NONINTERACTIVE_MODE: "true"
CI: true
HELICONE_CACHE_ENABLED: false
HELICONE_PROPERTY_AGENT: ${{ matrix.agent-name }}
REPORTS_FOLDER: ${{ format('../../reports/{0}', matrix.agent-name) }}
TELEMETRY_ENVIRONMENT: autogpt-ci
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}

View File

@@ -1,24 +1,18 @@
name: Classic - Direct Benchmark CI
name: Classic - AGBenchmark CI
on:
push:
branches: [ master, dev, ci-test* ]
paths:
- 'classic/direct_benchmark/**'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- '!classic/benchmark/reports/**'
- .github/workflows/classic-benchmark-ci.yml
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
pull_request:
branches: [ master, dev, release-* ]
paths:
- 'classic/direct_benchmark/**'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- '!classic/benchmark/reports/**'
- .github/workflows/classic-benchmark-ci.yml
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
concurrency:
group: ${{ format('benchmark-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -29,16 +23,23 @@ defaults:
shell: bash
env:
min-python-version: '3.12'
min-python-version: '3.10'
jobs:
benchmark-tests:
runs-on: ubuntu-latest
test:
permissions:
contents: read
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
defaults:
run:
shell: bash
working-directory: classic
working-directory: classic/benchmark
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -46,88 +47,71 @@ jobs:
fetch-depth: 0
submodules: true
- name: Set up Python ${{ env.min-python-version }}
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ env.min-python-version }}
python-version: ${{ matrix.python-version }}
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/benchmark/poetry.lock') }}
- name: Install Poetry
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -
- name: Install dependencies
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Python dependencies
run: poetry install
- name: Run basic benchmark tests
- name: Run pytest with coverage
run: |
echo "Testing ReadFile challenge with one_shot strategy..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--tests ReadFile \
--json
echo "Testing WriteFile challenge..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--tests WriteFile \
--json
poetry run pytest -vv \
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
tests
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"
- name: Test category filtering
run: |
echo "Testing coding category..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--categories coding \
--tests ReadFile,WriteFile \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"
- name: Upload test results to Codecov
if: ${{ !cancelled() }} # Run even if tests fail
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Test multiple strategies
run: |
echo "Testing multiple strategies..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot,plan_execute \
--models claude \
--tests ReadFile \
--parallel 2 \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: agbenchmark,${{ runner.os }}
# Run regression tests on maintain challenges
regression-tests:
self-test-with-agent:
runs-on: ubuntu-latest
timeout-minutes: 45
if: github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev'
defaults:
run:
shell: bash
working-directory: classic
strategy:
matrix:
agent-name: [forge]
fail-fast: false
timeout-minutes: 20
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -140,31 +124,53 @@ jobs:
with:
python-version: ${{ env.min-python-version }}
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -
- name: Install dependencies
run: poetry install
curl -sSL https://install.python-poetry.org | python -
- name: Run regression tests
working-directory: classic
run: |
echo "Running regression tests (previously beaten challenges)..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--maintain \
--parallel 4 \
--json
./run agent start ${{ matrix.agent-name }}
cd ${{ matrix.agent-name }}
set +e # Ignore non-zero exit codes and continue execution
echo "Running the following command: poetry run agbenchmark --maintain --mock"
poetry run agbenchmark --maintain --mock
EXIT_CODE=$?
set -e # Stop ignoring non-zero exit codes
# Check if the exit code was 5, and if so, exit with 0 instead
if [ $EXIT_CODE -eq 5 ]; then
echo "regression_tests.json is empty."
fi
echo "Running the following command: poetry run agbenchmark --mock"
poetry run agbenchmark --mock
echo "Running the following command: poetry run agbenchmark --mock --category=data"
poetry run agbenchmark --mock --category=data
echo "Running the following command: poetry run agbenchmark --mock --category=coding"
poetry run agbenchmark --mock --category=coding
# echo "Running the following command: poetry run agbenchmark --test=WriteFile"
# poetry run agbenchmark --test=WriteFile
cd ../benchmark
poetry install
echo "Adding the BUILD_SKILL_TREE environment variable. This will attempt to add new elements in the skill tree. If new elements are added, the CI fails because they should have been pushed"
export BUILD_SKILL_TREE=true
# poetry run agbenchmark --mock
# CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../classic/frontend/assets)') || echo "No diffs"
# if [ ! -z "$CHANGED" ]; then
# echo "There are unstaged changes please run agbenchmark and commit those changes since they are needed."
# echo "$CHANGED"
# exit 1
# else
# echo "No unstaged changes."
# fi
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"
TELEMETRY_ENVIRONMENT: autogpt-benchmark-ci
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}

View File

@@ -6,15 +6,13 @@ on:
paths:
- '.github/workflows/classic-forge-ci.yml'
- 'classic/forge/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '!classic/forge/tests/vcr_cassettes'
pull_request:
branches: [ master, dev, release-* ]
paths:
- '.github/workflows/classic-forge-ci.yml'
- 'classic/forge/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '!classic/forge/tests/vcr_cassettes'
concurrency:
group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -23,60 +21,131 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic
working-directory: classic/forge
jobs:
test:
permissions:
contents: read
timeout-minutes: 30
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
steps:
- name: Start MinIO service
# Quite slow on macOS (2~4 minutes to set up Docker)
# - name: Set up Docker (macOS)
# if: runner.os == 'macOS'
# uses: crazy-max/ghaction-setup-docker@v3
- name: Start MinIO service (Linux)
if: runner.os == 'Linux'
working-directory: '.'
run: |
docker pull minio/minio:edge-cicd
docker run -d -p 9000:9000 minio/minio:edge-cicd
- name: Start MinIO service (macOS)
if: runner.os == 'macOS'
working-directory: ${{ runner.temp }}
run: |
brew install minio/stable/minio
mkdir data
minio server ./data &
# No MinIO on Windows:
# - Windows doesn't support running Linux Docker containers
# - It doesn't seem possible to start background processes on Windows. They are
# killed after the step returns.
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
- name: Set up Python 3.12
- name: Checkout cassettes
if: ${{ startsWith(github.event_name, 'pull_request') }}
env:
PR_BASE: ${{ github.event.pull_request.base.ref }}
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
run: |
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
cassette_base_branch="${PR_BASE}"
cd tests/vcr_cassettes
if ! git ls-remote --exit-code --heads origin $cassette_base_branch ; then
cassette_base_branch="master"
fi
if git ls-remote --exit-code --heads origin $cassette_branch ; then
git fetch origin $cassette_branch
git fetch origin $cassette_base_branch
git checkout $cassette_branch
# Pick non-conflicting cassette updates from the base branch
git merge --no-commit --strategy-option=ours origin/$cassette_base_branch
echo "Using cassettes from mirror branch '$cassette_branch'," \
"synced to upstream branch '$cassette_base_branch'."
else
git checkout -b $cassette_branch
echo "Branch '$cassette_branch' does not exist in cassette submodule." \
"Using cassettes from '$cassette_base_branch'."
fi
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: "3.12"
python-version: ${{ matrix.python-version }}
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/forge/poetry.lock') }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Python dependencies
run: poetry install
- name: Install Playwright browsers
run: poetry run playwright install chromium
- name: Run pytest with coverage
run: |
poetry run pytest -vv \
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
forge/forge forge/tests
forge
env:
CI: true
PLAIN_OUTPUT: True
# API keys - tests that need these will skip if not available
# Secrets are not available to fork PRs (GitHub security feature)
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
S3_ENDPOINT_URL: http://127.0.0.1:9000
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
@@ -90,11 +159,85 @@ jobs:
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: forge
flags: forge,${{ runner.os }}
- id: setup_git_auth
name: Set up git token authentication
# Cassettes may be pushed even when tests fail
if: success() || failure()
run: |
config_key="http.${{ github.server_url }}/.extraheader"
if [ "${{ runner.os }}" = 'macOS' ]; then
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64)
else
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64 -w0)
fi
git config "$config_key" \
"Authorization: Basic $base64_pat"
cd tests/vcr_cassettes
git config "$config_key" \
"Authorization: Basic $base64_pat"
echo "config_key=$config_key" >> $GITHUB_OUTPUT
- id: push_cassettes
name: Push updated cassettes
# For pull requests, push updated cassettes even when tests fail
if: github.event_name == 'push' || (! github.event.pull_request.head.repo.fork && (success() || failure()))
env:
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
run: |
if [ "${{ startsWith(github.event_name, 'pull_request') }}" = "true" ]; then
is_pull_request=true
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
else
cassette_branch="${{ github.ref_name }}"
fi
cd tests/vcr_cassettes
# Commit & push changes to cassettes if any
if ! git diff --quiet; then
git add .
git commit -m "Auto-update cassettes"
git push origin HEAD:$cassette_branch
if [ ! $is_pull_request ]; then
cd ../..
git add tests/vcr_cassettes
git commit -m "Update cassette submodule"
git push origin HEAD:$cassette_branch
fi
echo "updated=true" >> $GITHUB_OUTPUT
else
echo "updated=false" >> $GITHUB_OUTPUT
echo "No cassette changes to commit"
fi
- name: Post Set up git token auth
if: steps.setup_git_auth.outcome == 'success'
run: |
git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
git submodule foreach git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
- name: Apply "behaviour change" label and comment on PR
if: ${{ startsWith(github.event_name, 'pull_request') }}
run: |
PR_NUMBER="${{ github.event.pull_request.number }}"
TOKEN="${{ secrets.PAT_REVIEW }}"
REPO="${{ github.repository }}"
if [[ "${{ steps.push_cassettes.outputs.updated }}" == "true" ]]; then
echo "Adding label and comment..."
echo $TOKEN | gh auth login --with-token
gh issue edit $PR_NUMBER --add-label "behaviour change"
gh issue comment $PR_NUMBER --body "You changed AutoGPT's behaviour on ${{ runner.os }}. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
fi
- name: Upload logs to artifact
if: always()
uses: actions/upload-artifact@v4
with:
name: test-logs
path: classic/logs/
path: classic/forge/logs/

View File

@@ -0,0 +1,60 @@
name: Classic - Frontend CI/CD
on:
push:
branches:
- master
- dev
- 'ci-test*' # This will match any branch that starts with "ci-test"
paths:
- 'classic/frontend/**'
- '.github/workflows/classic-frontend-ci.yml'
pull_request:
paths:
- 'classic/frontend/**'
- '.github/workflows/classic-frontend-ci.yml'
jobs:
build:
permissions:
contents: write
pull-requests: write
runs-on: ubuntu-latest
env:
BUILD_BRANCH: ${{ format('classic-frontend-build/{0}', github.ref_name) }}
steps:
- name: Checkout Repo
uses: actions/checkout@v4
- name: Setup Flutter
uses: subosito/flutter-action@v2
with:
flutter-version: '3.13.2'
- name: Build Flutter to Web
run: |
cd classic/frontend
flutter build web --base-href /app/
# - name: Commit and Push to ${{ env.BUILD_BRANCH }}
# if: github.event_name == 'push'
# run: |
# git config --local user.email "action@github.com"
# git config --local user.name "GitHub Action"
# git add classic/frontend/build/web
# git checkout -B ${{ env.BUILD_BRANCH }}
# git commit -m "Update frontend build to ${GITHUB_SHA:0:7}" -a
# git push -f origin ${{ env.BUILD_BRANCH }}
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
if: github.event_name == 'push'
uses: peter-evans/create-pull-request@v8
with:
add-paths: classic/frontend/build/web
base: ${{ github.ref_name }}
branch: ${{ env.BUILD_BRANCH }}
delete-branch: true
title: "Update frontend build in `${{ github.ref_name }}`"
body: "This PR updates the frontend build based on commit ${{ github.sha }}."
commit-message: "Update frontend build based on commit ${{ github.sha }}"

View File

@@ -7,9 +7,7 @@ on:
- '.github/workflows/classic-python-checks-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- 'classic/benchmark/**'
- '**.py'
- '!classic/forge/tests/vcr_cassettes'
pull_request:
@@ -18,9 +16,7 @@ on:
- '.github/workflows/classic-python-checks-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- 'classic/benchmark/**'
- '**.py'
- '!classic/forge/tests/vcr_cassettes'
@@ -31,13 +27,44 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic
jobs:
get-changed-parts:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- id: changes-in
name: Determine affected subprojects
uses: dorny/paths-filter@v3
with:
filters: |
original_autogpt:
- classic/original_autogpt/autogpt/**
- classic/original_autogpt/tests/**
- classic/original_autogpt/poetry.lock
forge:
- classic/forge/forge/**
- classic/forge/tests/**
- classic/forge/poetry.lock
benchmark:
- classic/benchmark/agbenchmark/**
- classic/benchmark/tests/**
- classic/benchmark/poetry.lock
outputs:
changed-parts: ${{ steps.changes-in.outputs.changes }}
lint:
needs: get-changed-parts
runs-on: ubuntu-latest
env:
min-python-version: "3.12"
min-python-version: "3.10"
strategy:
matrix:
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
fail-fast: false
steps:
- name: Checkout repository
@@ -54,31 +81,42 @@ jobs:
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
# Install dependencies
- name: Install Python dependencies
run: poetry install
run: poetry -C classic/${{ matrix.sub-package }} install
# Lint
- name: Lint (isort)
run: poetry run isort --check .
working-directory: classic/${{ matrix.sub-package }}
- name: Lint (Black)
if: success() || failure()
run: poetry run black --check .
working-directory: classic/${{ matrix.sub-package }}
- name: Lint (Flake8)
if: success() || failure()
run: poetry run flake8 .
working-directory: classic/${{ matrix.sub-package }}
types:
needs: get-changed-parts
runs-on: ubuntu-latest
env:
min-python-version: "3.12"
min-python-version: "3.10"
strategy:
matrix:
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
fail-fast: false
steps:
- name: Checkout repository
@@ -95,16 +133,19 @@ jobs:
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
# Install dependencies
- name: Install Python dependencies
run: poetry install
run: poetry -C classic/${{ matrix.sub-package }} install
# Typecheck
- name: Typecheck
if: success() || failure()
run: poetry run pyright
working-directory: classic/${{ matrix.sub-package }}

View File

@@ -269,14 +269,12 @@ jobs:
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
- name: Run pytest with coverage
- name: Run pytest
run: |
if [[ "${{ runner.debug }}" == "1" ]]; then
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG \
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
else
poetry run pytest -s -vv \
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
poetry run pytest -s -vv
fi
env:
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
@@ -289,13 +287,11 @@ jobs:
REDIS_PORT: "6379"
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
- name: Upload coverage reports to Codecov
if: ${{ !cancelled() }}
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: platform-backend
files: ./autogpt_platform/backend/coverage.xml
# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v4
# with:
# token: ${{ secrets.CODECOV_TOKEN }}
# flags: backend,${{ runner.os }}
env:
CI: true

View File

@@ -148,11 +148,3 @@ jobs:
- name: Run Integration Tests
run: pnpm test:unit
- name: Upload coverage reports to Codecov
if: ${{ !cancelled() }}
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: platform-frontend
files: ./autogpt_platform/frontend/coverage/cobertura-coverage.xml

View File

@@ -179,30 +179,21 @@ jobs:
pip install pyyaml
# Resolve extends and generate a flat compose file that bake can understand
export NEXT_PUBLIC_SOURCEMAPS NEXT_PUBLIC_PW_TEST
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
# Ensure NEXT_PUBLIC_SOURCEMAPS is in resolved compose
# (docker compose config on some versions drops this arg)
if ! grep -q "NEXT_PUBLIC_SOURCEMAPS" docker-compose.resolved.yml; then
echo "Injecting NEXT_PUBLIC_SOURCEMAPS into resolved compose (docker compose config dropped it)"
sed -i '/NEXT_PUBLIC_PW_TEST/a\ NEXT_PUBLIC_SOURCEMAPS: "true"' docker-compose.resolved.yml
fi
# Add cache configuration to the resolved compose file
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
--source docker-compose.resolved.yml \
--cache-from "type=gha" \
--cache-to "type=gha,mode=max" \
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}-sourcemaps" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
--git-ref "${{ github.ref }}"
# Build with bake using the resolved compose file (now includes cache config)
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
env:
NEXT_PUBLIC_PW_TEST: true
NEXT_PUBLIC_SOURCEMAPS: true
- name: Set up tests - Cache E2E test data
id: e2e-data-cache
@@ -288,11 +279,6 @@ jobs:
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Copy source maps from Docker for E2E coverage
run: |
FRONTEND_CONTAINER=$(docker compose -f ../docker-compose.resolved.yml ps -q frontend)
docker cp "$FRONTEND_CONTAINER":/app/.next/static .next-static-coverage
- name: Set up tests - Install dependencies
run: pnpm install --frozen-lockfile
@@ -303,15 +289,6 @@ jobs:
run: pnpm test:no-build
continue-on-error: false
- name: Upload E2E coverage to Codecov
if: ${{ !cancelled() }}
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: platform-frontend-e2e
files: ./autogpt_platform/frontend/coverage/e2e/cobertura-coverage.xml
disable_search: true
- name: Upload Playwright report
if: always()
uses: actions/upload-artifact@v4

10
.gitignore vendored
View File

@@ -3,7 +3,6 @@
classic/original_autogpt/keys.py
classic/original_autogpt/*.json
auto_gpt_workspace/*
.autogpt/
*.mpeg
.env
# Root .env files
@@ -17,7 +16,6 @@ log-ingestion.txt
/logs
*.log
*.mp3
!autogpt_platform/frontend/public/notification.mp3
mem.sqlite3
venvAutoGPT
@@ -161,10 +159,6 @@ CURRENT_BULLETIN.md
# AgBenchmark
classic/benchmark/agbenchmark/reports/
classic/reports/
classic/direct_benchmark/reports/
classic/.benchmark_workspaces/
classic/direct_benchmark/.benchmark_workspaces/
# Nodejs
package-lock.json
@@ -183,13 +177,9 @@ autogpt_platform/backend/settings.py
*.ign.*
.test-contents
**/.claude/settings.local.json
.claude/settings.local.json
CLAUDE.local.md
/autogpt_platform/backend/logs
# Test database
test.db
.next
# Implementation plans (generated by AI agents)
plans/

View File

@@ -1,36 +0,0 @@
title = "AutoGPT Gitleaks Config"
[extend]
useDefault = true
[allowlist]
description = "Global allowlist"
paths = [
# Template/example env files (no real secrets)
'''\.env\.(default|example|template)$''',
# Lock files
'''pnpm-lock\.yaml$''',
'''poetry\.lock$''',
# Secrets baseline
'''\.secrets\.baseline$''',
# Build artifacts and caches (should not be committed)
'''__pycache__/''',
'''classic/frontend/build/''',
# Docker dev setup (local dev JWTs/keys only)
'''autogpt_platform/db/docker/''',
# Load test configs (dev JWTs)
'''load-tests/configs/''',
# Test files with fake/fixture keys (_test.py, test_*.py, conftest.py)
'''(_test|test_.*|conftest)\.py$''',
# Documentation (only contains placeholder keys in curl/API examples)
'''docs/.*\.md$''',
# Firebase config (public API keys by design)
'''google-services\.json$''',
'''classic/frontend/(lib|web)/''',
]
# CI test-only encryption key (marked DO NOT USE IN PRODUCTION)
regexes = [
'''dvziYgz0KSK8FENhju0ZYi8''',
# LLM model name enum values falsely flagged as API keys
'''Llama-\d.*Instruct''',
]

3
.gitmodules vendored Normal file
View File

@@ -0,0 +1,3 @@
[submodule "classic/forge/tests/vcr_cassettes"]
path = classic/forge/tests/vcr_cassettes
url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes

View File

@@ -23,15 +23,9 @@ repos:
- id: detect-secrets
name: Detect secrets
description: Detects high entropy strings that are likely to be passwords.
args: ["--baseline", ".secrets.baseline"]
files: ^autogpt_platform/
exclude: (pnpm-lock\.yaml|\.env\.(default|example|template))$
- repo: https://github.com/gitleaks/gitleaks
rev: v8.24.3
hooks:
- id: gitleaks
name: Detect secrets (gitleaks)
exclude: pnpm-lock\.yaml$
stages: [pre-push]
- repo: local
# For proper type checking, all dependencies need to be up-to-date.
@@ -90,16 +84,51 @@ repos:
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic
alias: poetry-install-classic
name: Check & Install dependencies - Classic - AutoGPT
alias: poetry-install-classic-autogpt
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/poetry\.lock$" || exit 0;
poetry -C classic install
fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
poetry -C classic/original_autogpt install
'
# include forge source (since it's a path dependency)
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - Forge
alias: poetry-install-classic-forge
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
poetry -C classic/forge install
'
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - Benchmark
alias: poetry-install-classic-benchmark
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
poetry -C classic/benchmark install
'
always_run: true
language: system
@@ -194,10 +223,26 @@ repos:
language: system
- id: isort
name: Lint (isort) - Classic
alias: isort-classic
entry: bash -c 'cd classic && poetry run isort $(echo "$@" | sed "s|classic/||g")' --
files: ^classic/(original_autogpt|forge|direct_benchmark)/
name: Lint (isort) - Classic - AutoGPT
alias: isort-classic-autogpt
entry: poetry -P classic/original_autogpt run isort -p autogpt
files: ^classic/original_autogpt/
types: [file, python]
language: system
- id: isort
name: Lint (isort) - Classic - Forge
alias: isort-classic-forge
entry: poetry -P classic/forge run isort -p forge
files: ^classic/forge/
types: [file, python]
language: system
- id: isort
name: Lint (isort) - Classic - Benchmark
alias: isort-classic-benchmark
entry: poetry -P classic/benchmark run isort -p agbenchmark
files: ^classic/benchmark/
types: [file, python]
language: system
@@ -211,13 +256,26 @@ repos:
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
# Use consolidated flake8 config at classic/.flake8
# To have flake8 load the config of the individual subprojects, we have to call
# them separately.
hooks:
- id: flake8
name: Lint (Flake8) - Classic
alias: flake8-classic
files: ^classic/(original_autogpt|forge|direct_benchmark)/
args: [--config=classic/.flake8]
name: Lint (Flake8) - Classic - AutoGPT
alias: flake8-classic-autogpt
files: ^classic/original_autogpt/(autogpt|scripts|tests)/
args: [--config=classic/original_autogpt/.flake8]
- id: flake8
name: Lint (Flake8) - Classic - Forge
alias: flake8-classic-forge
files: ^classic/forge/(forge|tests)/
args: [--config=classic/forge/.flake8]
- id: flake8
name: Lint (Flake8) - Classic - Benchmark
alias: flake8-classic-benchmark
files: ^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
args: [--config=classic/benchmark/.flake8]
- repo: local
hooks:
@@ -253,10 +311,29 @@ repos:
pass_filenames: false
- id: pyright
name: Typecheck - Classic
alias: pyright-classic
entry: poetry -C classic run pyright
files: ^classic/(original_autogpt|forge|direct_benchmark)/.*\.py$|^classic/poetry\.lock$
name: Typecheck - Classic - AutoGPT
alias: pyright-classic-autogpt
entry: poetry -C classic/original_autogpt run pyright
# include forge source (since it's a path dependency) but exclude *_test.py files:
files: ^(classic/original_autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
types: [file]
language: system
pass_filenames: false
- id: pyright
name: Typecheck - Classic - Forge
alias: pyright-classic-forge
entry: poetry -C classic/forge run pyright
files: ^classic/forge/(forge/|poetry\.lock$)
types: [file]
language: system
pass_filenames: false
- id: pyright
name: Typecheck - Classic - Benchmark
alias: pyright-classic-benchmark
entry: poetry -C classic/benchmark run pyright
files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
types: [file]
language: system
pass_filenames: false
@@ -283,9 +360,26 @@ repos:
# pass_filenames: false
# - id: pytest
# name: Run tests - Classic (excl. slow tests)
# alias: pytest-classic
# entry: bash -c 'cd classic && poetry run pytest -m "not slow"'
# files: ^classic/(original_autogpt|forge|direct_benchmark)/
# name: Run tests - Classic - AutoGPT (excl. slow tests)
# alias: pytest-classic-autogpt
# entry: bash -c 'cd classic/original_autogpt && poetry run pytest --cov=autogpt -m "not slow" tests/unit tests/integration'
# # include forge source (since it's a path dependency) but exclude *_test.py files:
# files: ^(classic/original_autogpt/((autogpt|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
# language: system
# pass_filenames: false
# - id: pytest
# name: Run tests - Classic - Forge (excl. slow tests)
# alias: pytest-classic-forge
# entry: bash -c 'cd classic/forge && poetry run pytest --cov=forge -m "not slow"'
# files: ^classic/forge/(forge/|tests/|poetry\.lock$)
# language: system
# pass_filenames: false
# - id: pytest
# name: Run tests - Classic - Benchmark
# alias: pytest-classic-benchmark
# entry: bash -c 'cd classic/benchmark && poetry run pytest --cov=benchmark'
# files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
# language: system
# pass_filenames: false

View File

@@ -1,467 +0,0 @@
{
"version": "1.5.0",
"plugins_used": [
{
"name": "ArtifactoryDetector"
},
{
"name": "AWSKeyDetector"
},
{
"name": "AzureStorageKeyDetector"
},
{
"name": "Base64HighEntropyString",
"limit": 4.5
},
{
"name": "BasicAuthDetector"
},
{
"name": "CloudantDetector"
},
{
"name": "DiscordBotTokenDetector"
},
{
"name": "GitHubTokenDetector"
},
{
"name": "GitLabTokenDetector"
},
{
"name": "HexHighEntropyString",
"limit": 3.0
},
{
"name": "IbmCloudIamDetector"
},
{
"name": "IbmCosHmacDetector"
},
{
"name": "IPPublicDetector"
},
{
"name": "JwtTokenDetector"
},
{
"name": "KeywordDetector",
"keyword_exclude": ""
},
{
"name": "MailchimpDetector"
},
{
"name": "NpmDetector"
},
{
"name": "OpenAIDetector"
},
{
"name": "PrivateKeyDetector"
},
{
"name": "PypiTokenDetector"
},
{
"name": "SendGridDetector"
},
{
"name": "SlackDetector"
},
{
"name": "SoftlayerDetector"
},
{
"name": "SquareOAuthDetector"
},
{
"name": "StripeDetector"
},
{
"name": "TelegramBotTokenDetector"
},
{
"name": "TwilioKeyDetector"
}
],
"filters_used": [
{
"path": "detect_secrets.filters.allowlist.is_line_allowlisted"
},
{
"path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies",
"min_level": 2
},
{
"path": "detect_secrets.filters.heuristic.is_indirect_reference"
},
{
"path": "detect_secrets.filters.heuristic.is_likely_id_string"
},
{
"path": "detect_secrets.filters.heuristic.is_lock_file"
},
{
"path": "detect_secrets.filters.heuristic.is_not_alphanumeric_string"
},
{
"path": "detect_secrets.filters.heuristic.is_potential_uuid"
},
{
"path": "detect_secrets.filters.heuristic.is_prefixed_with_dollar_sign"
},
{
"path": "detect_secrets.filters.heuristic.is_sequential_string"
},
{
"path": "detect_secrets.filters.heuristic.is_swagger_file"
},
{
"path": "detect_secrets.filters.heuristic.is_templated_secret"
},
{
"path": "detect_secrets.filters.regex.should_exclude_file",
"pattern": [
"\\.env$",
"pnpm-lock\\.yaml$",
"\\.env\\.(default|example|template)$",
"__pycache__",
"_test\\.py$",
"test_.*\\.py$",
"conftest\\.py$",
"poetry\\.lock$",
"node_modules"
]
}
],
"results": {
"autogpt_platform/backend/backend/api/external/v1/integrations.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/api/external/v1/integrations.py",
"hashed_secret": "665b1e3851eefefa3fb878654292f16597d25155",
"is_verified": false,
"line_number": 289
}
],
"autogpt_platform/backend/backend/blocks/airtable/_config.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/blocks/airtable/_config.py",
"hashed_secret": "57e168b03afb7c1ee3cdc4ee3db2fe1cc6e0df26",
"is_verified": false,
"line_number": 29
}
],
"autogpt_platform/backend/backend/blocks/dataforseo/_config.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/blocks/dataforseo/_config.py",
"hashed_secret": "32ce93887331fa5d192f2876ea15ec000c7d58b8",
"is_verified": false,
"line_number": 12
}
],
"autogpt_platform/backend/backend/blocks/github/checks.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/checks.py",
"hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238",
"is_verified": false,
"line_number": 108
}
],
"autogpt_platform/backend/backend/blocks/github/ci.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/ci.py",
"hashed_secret": "90bd1b48e958257948487b90bee080ba5ed00caa",
"is_verified": false,
"line_number": 123
}
],
"autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
"hashed_secret": "f96896dafced7387dcd22343b8ea29d3d2c65663",
"is_verified": false,
"line_number": 42
},
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
"hashed_secret": "b80a94d5e70bedf4f5f89d2f5a5255cc9492d12e",
"is_verified": false,
"line_number": 193
},
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
"hashed_secret": "75b17e517fe1b3136394f6bec80c4f892da75e42",
"is_verified": false,
"line_number": 344
},
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
"hashed_secret": "b0bfb5e4e2394e7f8906e5ed1dffd88b2bc89dd5",
"is_verified": false,
"line_number": 534
}
],
"autogpt_platform/backend/backend/blocks/github/statuses.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/statuses.py",
"hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238",
"is_verified": false,
"line_number": 85
}
],
"autogpt_platform/backend/backend/blocks/google/docs.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/google/docs.py",
"hashed_secret": "c95da0c6696342c867ef0c8258d2f74d20fd94d4",
"is_verified": false,
"line_number": 203
}
],
"autogpt_platform/backend/backend/blocks/google/sheets.py": [
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/google/sheets.py",
"hashed_secret": "bd5a04fa3667e693edc13239b6d310c5c7a8564b",
"is_verified": false,
"line_number": 57
}
],
"autogpt_platform/backend/backend/blocks/linear/_config.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/blocks/linear/_config.py",
"hashed_secret": "b37f020f42d6d613b6ce30103e4d408c4499b3bb",
"is_verified": false,
"line_number": 53
}
],
"autogpt_platform/backend/backend/blocks/medium.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/medium.py",
"hashed_secret": "ff998abc1ce6d8f01a675fa197368e44c8916e9c",
"is_verified": false,
"line_number": 131
}
],
"autogpt_platform/backend/backend/blocks/replicate/replicate_block.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/replicate/replicate_block.py",
"hashed_secret": "8bbdd6f26368f58ea4011d13d7f763cb662e66f0",
"is_verified": false,
"line_number": 55
}
],
"autogpt_platform/backend/backend/blocks/slant3d/webhook.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/slant3d/webhook.py",
"hashed_secret": "36263c76947443b2f6e6b78153967ac4a7da99f9",
"is_verified": false,
"line_number": 100
}
],
"autogpt_platform/backend/backend/blocks/talking_head.py": [
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/talking_head.py",
"hashed_secret": "44ce2d66222529eea4a32932823466fc0601c799",
"is_verified": false,
"line_number": 113
}
],
"autogpt_platform/backend/backend/blocks/wordpress/_config.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/blocks/wordpress/_config.py",
"hashed_secret": "e62679512436161b78e8a8d68c8829c2a1031ccb",
"is_verified": false,
"line_number": 17
}
],
"autogpt_platform/backend/backend/util/cache.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/util/cache.py",
"hashed_secret": "37f0c918c3fa47ca4a70e42037f9f123fdfbc75b",
"is_verified": false,
"line_number": 449
}
],
"autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts",
"hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8",
"is_verified": false,
"line_number": 6
}
],
"autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json",
"hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d",
"is_verified": false,
"line_number": 5
}
],
"autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json",
"hashed_secret": "5a6d1c612954979ea99ee33dbb2d231b00f6ac0a",
"is_verified": false,
"line_number": 5
}
],
"autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts",
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
"is_verified": false,
"line_number": 6
},
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts",
"hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380",
"is_verified": false,
"line_number": 8
}
],
"autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts",
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
"is_verified": false,
"line_number": 5
},
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts",
"hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380",
"is_verified": false,
"line_number": 7
}
],
"autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx",
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
"is_verified": false,
"line_number": 192
},
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx",
"hashed_secret": "86275db852204937bbdbdebe5fabe8536e030ab6",
"is_verified": false,
"line_number": 193
}
],
"autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts",
"hashed_secret": "47acd2028cf81b5da88ddeedb2aea4eca4b71fbd",
"is_verified": false,
"line_number": 102
},
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts",
"hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d",
"is_verified": false,
"line_number": 103
}
],
"autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts": [
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "9c486c92f1a7420e1045c7ad963fbb7ba3621025",
"is_verified": false,
"line_number": 73
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "9277508c7a6effc8fb59163efbfada189e35425c",
"is_verified": false,
"line_number": 75
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "8dc7e2cb1d0935897d541bf5facab389b8a50340",
"is_verified": false,
"line_number": 77
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "79a26ad48775944299be6aaf9fb1d5302c1ed75b",
"is_verified": false,
"line_number": 79
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "a3b62b44500a1612e48d4cab8294df81561b3b1a",
"is_verified": false,
"line_number": 81
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "a58979bd0b21ef4f50417d001008e60dd7a85c64",
"is_verified": false,
"line_number": 83
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "6cb6e075f8e8c7c850f9d128d6608e5dbe209a79",
"is_verified": false,
"line_number": 85
}
],
"autogpt_platform/frontend/src/lib/constants.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/lib/constants.ts",
"hashed_secret": "27b924db06a28cc755fb07c54f0fddc30659fe4d",
"is_verified": false,
"line_number": 10
}
],
"autogpt_platform/frontend/src/tests/credentials/index.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/tests/credentials/index.ts",
"hashed_secret": "c18006fc138809314751cd1991f1e0b820fabd37",
"is_verified": false,
"line_number": 4
}
]
},
"generated_at": "2026-04-02T13:10:54Z"
}

View File

@@ -1,6 +1,6 @@
# AutoGPT Platform Contribution Guide
This guide provides context for coding agents when updating the **autogpt_platform** folder.
This guide provides context for Codex when updating the **autogpt_platform** folder.
## Directory overview
@@ -30,7 +30,7 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
- Regenerate with `pnpm generate:api`
- Pattern: `use{Method}{Version}{OperationName}`
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Integration tests (Vitest + RTL + MSW) are the default (~90%, page-level). Playwright for E2E critical flows. Storybook for design system components. See `autogpt_platform/frontend/TESTING.md`
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
@@ -47,9 +47,7 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
## Testing
- Backend: `poetry run test` (runs pytest with a docker based postgres + prisma).
- Frontend integration tests: `pnpm test:unit` (Vitest + RTL + MSW, primary testing approach).
- Frontend E2E tests: `pnpm test` or `pnpm test-ui` for Playwright tests.
- See `autogpt_platform/frontend/TESTING.md` for the full testing strategy.
- Frontend: `pnpm test` or `pnpm test-ui` for Playwright tests. See `docs/content/platform/contributing/tests.md` for tips.
Always run the relevant linters and tests before committing.
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).

View File

@@ -1 +0,0 @@
@AGENTS.md

View File

@@ -83,13 +83,13 @@ The AutoGPT frontend is where users interact with our powerful AI automation pla
**Agent Builder:** For those who want to customize, our intuitive, low-code interface allows you to design and configure your own AI agents.
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
**Deployment Controls:** Manage the lifecycle of your agents, from testing to production.
**Ready-to-Use Agents:** Don't want to build? Simply select from our library of pre-configured agents and put them to work immediately.
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
**Monitoring and Analytics:** Keep track of your agents' performance and gain insights to continually improve your automation processes.

View File

@@ -1,120 +0,0 @@
# AutoGPT Platform
This file provides guidance to coding agents when working with code in this repository.
## Repository Overview
AutoGPT Platform is a monorepo containing:
- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
## Component Documentation
- **Backend**: See @backend/AGENTS.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/AGENTS.md for frontend-specific commands, architecture, and development patterns
## Key Concepts
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security
### Environment Configuration
#### Configuration Files
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
#### Docker Environment Loading Order
1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence
#### Key Points
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
### Branching Strategy
- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
```bash
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
## Summary
- use `backticks` freely here
PREOF
gh pr create --title "..." --body-file "$PR_BODY" --base dev
rm "$PR_BODY"
```
- Run the github pre-commit hooks to ensure code quality.
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, follow a test-first approach:
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
2. **Implement the fix/feature** — write the minimal code to make the test pass.
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
This ensures every change is covered by a test and that the test actually validates the intended behavior.
### Reviewing/Revising Pull Requests
Use `/pr-review` to review a PR or `/pr-address` to address comments.
When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
### Conventional Commits
Use this format for commit messages and Pull Request titles:
**Conventional Commit Types:**
- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience
**Recommended Base Scopes:**
- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks
**Subscope Examples:**
- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`
Use these scopes and subscopes for clarity and consistency in commit messages.

View File

@@ -1 +1,118 @@
@AGENTS.md
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Repository Overview
AutoGPT Platform is a monorepo containing:
- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
## Component Documentation
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
## Key Concepts
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security
### Environment Configuration
#### Configuration Files
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
#### Docker Environment Loading Order
1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence
#### Key Points
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
### Branching Strategy
- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
```bash
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
## Summary
- use `backticks` freely here
PREOF
gh pr create --title "..." --body-file "$PR_BODY" --base dev
rm "$PR_BODY"
```
- Run the github pre-commit hooks to ensure code quality.
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, follow a test-first approach:
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
2. **Implement the fix/feature** — write the minimal code to make the test pass.
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
This ensures every change is covered by a test and that the test actually validates the intended behavior.
### Reviewing/Revising Pull Requests
Use `/pr-review` to review a PR or `/pr-address` to address comments.
When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
### Conventional Commits
Use this format for commit messages and Pull Request titles:
**Conventional Commit Types:**
- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience
**Recommended Base Scopes:**
- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks
**Subscope Examples:**
- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`
Use these scopes and subscopes for clarity and consistency in commit messages.

View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
[[package]]
name = "annotated-doc"
@@ -67,7 +67,7 @@ description = "Backport of asyncio.Runner, a context manager that controls event
optional = false
python-versions = "<3.11,>=3.8"
groups = ["dev"]
markers = "python_version == \"3.10\""
markers = "python_version < \"3.11\""
files = [
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
{file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
@@ -541,7 +541,7 @@ description = "Backport of PEP 654 (exception groups)"
optional = false
python-versions = ">=3.7"
groups = ["main", "dev"]
markers = "python_version == \"3.10\""
markers = "python_version < \"3.11\""
files = [
{file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"},
{file = "exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88"},
@@ -2181,14 +2181,14 @@ testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
[[package]]
name = "pytest-cov"
version = "7.1.0"
version = "7.0.0"
description = "Pytest plugin for measuring coverage."
optional = false
python-versions = ">=3.9"
groups = ["dev"]
files = [
{file = "pytest_cov-7.1.0-py3-none-any.whl", hash = "sha256:a0461110b7865f9a271aa1b51e516c9a95de9d696734a2f71e3e78f46e1d4678"},
{file = "pytest_cov-7.1.0.tar.gz", hash = "sha256:30674f2b5f6351aa09702a9c8c364f6a01c27aae0c1366ae8016160d1efc56b2"},
{file = "pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861"},
{file = "pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1"},
]
[package.dependencies]
@@ -2342,30 +2342,30 @@ pyasn1 = ">=0.1.3"
[[package]]
name = "ruff"
version = "0.15.7"
version = "0.15.0"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
groups = ["dev"]
files = [
{file = "ruff-0.15.7-py3-none-linux_armv6l.whl", hash = "sha256:a81cc5b6910fb7dfc7c32d20652e50fa05963f6e13ead3c5915c41ac5d16668e"},
{file = "ruff-0.15.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:722d165bd52403f3bdabc0ce9e41fc47070ac56d7a91b4e0d097b516a53a3477"},
{file = "ruff-0.15.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7fbc2448094262552146cbe1b9643a92f66559d3761f1ad0656d4991491af49e"},
{file = "ruff-0.15.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b39329b60eba44156d138275323cc726bbfbddcec3063da57caa8a8b1d50adf"},
{file = "ruff-0.15.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87768c151808505f2bfc93ae44e5f9e7c8518943e5074f76ac21558ef5627c85"},
{file = "ruff-0.15.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb0511670002c6c529ec66c0e30641c976c8963de26a113f3a30456b702468b0"},
{file = "ruff-0.15.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0d19644f801849229db8345180a71bee5407b429dd217f853ec515e968a6912"},
{file = "ruff-0.15.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4806d8e09ef5e84eb19ba833d0442f7e300b23fe3f0981cae159a248a10f0036"},
{file = "ruff-0.15.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dce0896488562f09a27b9c91b1f58a097457143931f3c4d519690dea54e624c5"},
{file = "ruff-0.15.7-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:1852ce241d2bc89e5dc823e03cff4ce73d816b5c6cdadd27dbfe7b03217d2a12"},
{file = "ruff-0.15.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5f3e4b221fb4bd293f79912fc5e93a9063ebd6d0dcbd528f91b89172a9b8436c"},
{file = "ruff-0.15.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b15e48602c9c1d9bdc504b472e90b90c97dc7d46c7028011ae67f3861ceba7b4"},
{file = "ruff-0.15.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1b4705e0e85cedc74b0a23cf6a179dbb3df184cb227761979cc76c0440b5ab0d"},
{file = "ruff-0.15.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:112c1fa316a558bb34319282c1200a8bf0495f1b735aeb78bfcb2991e6087580"},
{file = "ruff-0.15.7-py3-none-win32.whl", hash = "sha256:6d39e2d3505b082323352f733599f28169d12e891f7dd407f2d4f54b4c2886de"},
{file = "ruff-0.15.7-py3-none-win_amd64.whl", hash = "sha256:4d53d712ddebcd7dace1bc395367aec12c057aacfe9adbb6d832302575f4d3a1"},
{file = "ruff-0.15.7-py3-none-win_arm64.whl", hash = "sha256:18e8d73f1c3fdf27931497972250340f92e8c861722161a9caeb89a58ead6ed2"},
{file = "ruff-0.15.7.tar.gz", hash = "sha256:04f1ae61fc20fe0b148617c324d9d009b5f63412c0b16474f3d5f1a1a665f7ac"},
{file = "ruff-0.15.0-py3-none-linux_armv6l.whl", hash = "sha256:aac4ebaa612a82b23d45964586f24ae9bc23ca101919f5590bdb368d74ad5455"},
{file = "ruff-0.15.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:dcd4be7cc75cfbbca24a98d04d0b9b36a270d0833241f776b788d59f4142b14d"},
{file = "ruff-0.15.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d747e3319b2bce179c7c1eaad3d884dc0a199b5f4d5187620530adf9105268ce"},
{file = "ruff-0.15.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:650bd9c56ae03102c51a5e4b554d74d825ff3abe4db22b90fd32d816c2e90621"},
{file = "ruff-0.15.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a6664b7eac559e3048223a2da77769c2f92b43a6dfd4720cef42654299a599c9"},
{file = "ruff-0.15.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f811f97b0f092b35320d1556f3353bf238763420ade5d9e62ebd2b73f2ff179"},
{file = "ruff-0.15.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:761ec0a66680fab6454236635a39abaf14198818c8cdf691e036f4bc0f406b2d"},
{file = "ruff-0.15.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:940f11c2604d317e797b289f4f9f3fa5555ffe4fb574b55ed006c3d9b6f0eb78"},
{file = "ruff-0.15.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcbca3d40558789126da91d7ef9a7c87772ee107033db7191edefa34e2c7f1b4"},
{file = "ruff-0.15.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9a121a96db1d75fa3eb39c4539e607f628920dd72ff1f7c5ee4f1b768ac62d6e"},
{file = "ruff-0.15.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5298d518e493061f2eabd4abd067c7e4fb89e2f63291c94332e35631c07c3662"},
{file = "ruff-0.15.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:afb6e603d6375ff0d6b0cee563fa21ab570fd15e65c852cb24922cef25050cf1"},
{file = "ruff-0.15.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:77e515f6b15f828b94dc17d2b4ace334c9ddb7d9468c54b2f9ed2b9c1593ef16"},
{file = "ruff-0.15.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:6f6e80850a01eb13b3e42ee0ebdf6e4497151b48c35051aab51c101266d187a3"},
{file = "ruff-0.15.0-py3-none-win32.whl", hash = "sha256:238a717ef803e501b6d51e0bdd0d2c6e8513fe9eec14002445134d3907cd46c3"},
{file = "ruff-0.15.0-py3-none-win_amd64.whl", hash = "sha256:dd5e4d3301dc01de614da3cdffc33d4b1b96fb89e45721f1598e5532ccf78b18"},
{file = "ruff-0.15.0-py3-none-win_arm64.whl", hash = "sha256:c480d632cc0ca3f0727acac8b7d053542d9e114a462a145d0b00e7cd658c515a"},
{file = "ruff-0.15.0.tar.gz", hash = "sha256:6bdea47cdbea30d40f8f8d7d69c0854ba7c15420ec75a26f463290949d7f7e9a"},
]
[[package]]
@@ -2564,7 +2564,7 @@ description = "A lil' TOML parser"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
markers = "python_version == \"3.10\""
markers = "python_version < \"3.11\""
files = [
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
{file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
@@ -2912,4 +2912,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
content-hash = "e0936a065565550afed18f6298b7e04e814b44100def7049f1a0d68662624a39"
content-hash = "9619cae908ad38fa2c48016a58bcf4241f6f5793aa0e6cc140276e91c433cbbb"

View File

@@ -26,8 +26,8 @@ pyright = "^1.1.408"
pytest = "^8.4.1"
pytest-asyncio = "^1.3.0"
pytest-mock = "^3.15.1"
pytest-cov = "^7.1.0"
ruff = "^0.15.7"
pytest-cov = "^7.0.0"
ruff = "^0.15.0"
[build-system]
requires = ["poetry-core"]

View File

@@ -178,7 +178,6 @@ SMTP_USERNAME=
SMTP_PASSWORD=
# Business & Marketing Tools
AGENTMAIL_API_KEY=
APOLLO_API_KEY=
ENRICHLAYER_API_KEY=
AYRSHARE_API_KEY=

View File

@@ -1,227 +0,0 @@
# Backend
This file provides guidance to coding agents when working with the backend.
## Essential Commands
To run something with Python package dependencies you MUST use `poetry run ...`.
```bash
# Install dependencies
poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d
# Run the backend as a whole
poetry run app
# Run tests
poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in @TESTING.md
### Creating/Updating Snapshots
When you first write a test or when the expected output changes:
```bash
poetry run pytest path/to/test.py --snapshot-update
```
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
## Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
## Code Style
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
- **List comprehensions** over manual loop-and-append
- **Early return** — guard clauses first, avoid deep nesting
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
- **`max(0, value)` guards** — for computed values that should never be negative
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
## Testing Approach
- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
- After refactoring, update mock targets to match new module paths
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, write the test **before** the implementation:
```python
# 1. Write a failing test marked xfail
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
def test_widget_handles_empty_input():
result = widget.process("")
assert result == Widget.EMPTY_RESULT
# 2. Run it — confirm it fails (XFAIL)
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
# 3. Implement the fix
# 4. Remove xfail, run again — confirm it passes
def test_widget_handles_empty_input():
result = widget.process("")
assert result == Widget.EMPTY_RESULT
```
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
## Database Schema
Key models (defined in `schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents
## Environment Configuration
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
## Common Development Tasks
### Adding a new block
Follow the comprehensive [Block SDK Guide](@../../docs/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
#### Handling files in blocks with `store_media_file()`
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc
# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
# OUTPUT: Returning result from block
result_url = await store_media_file(
file=generated_image_url,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```
**Key points:**
- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
### Modifying the API
1. Update route in `backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
## Workspace & Media Files
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
## Security Implementation
### Cache Protection Middleware
- Located in `backend/api/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications

View File

@@ -1 +1,226 @@
@AGENTS.md
# CLAUDE.md - Backend
This file provides guidance to Claude Code when working with the backend.
## Essential Commands
To run something with Python package dependencies you MUST use `poetry run ...`.
```bash
# Install dependencies
poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d
# Run the backend as a whole
poetry run app
# Run tests
poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in @TESTING.md
### Creating/Updating Snapshots
When you first write a test or when the expected output changes:
```bash
poetry run pytest path/to/test.py --snapshot-update
```
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
## Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
## Code Style
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
- **List comprehensions** over manual loop-and-append
- **Early return** — guard clauses first, avoid deep nesting
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
- **`max(0, value)` guards** — for computed values that should never be negative
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
## Testing Approach
- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
- After refactoring, update mock targets to match new module paths
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, write the test **before** the implementation:
```python
# 1. Write a failing test marked xfail
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
def test_widget_handles_empty_input():
result = widget.process("")
assert result == Widget.EMPTY_RESULT
# 2. Run it — confirm it fails (XFAIL)
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
# 3. Implement the fix
# 4. Remove xfail, run again — confirm it passes
def test_widget_handles_empty_input():
result = widget.process("")
assert result == Widget.EMPTY_RESULT
```
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
## Database Schema
Key models (defined in `schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents
## Environment Configuration
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
## Common Development Tasks
### Adding a new block
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
#### Handling files in blocks with `store_media_file()`
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc
# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
# OUTPUT: Returning result from block
result_url = await store_media_file(
file=generated_image_url,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```
**Key points:**
- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
### Modifying the API
1. Update route in `backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
## Workspace & Media Files
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
## Security Implementation
### Cache Protection Middleware
- Located in `backend/api/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications

View File

@@ -121,20 +121,36 @@ RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
# Install agent-browser (Copilot browser tool) using the system chromium package.
# Chrome for Testing (the binary agent-browser downloads via `agent-browser install`)
# has no ARM64 builds, so we use the distro-packaged chromium instead — verified to
# work with agent-browser via Docker tests on arm64; amd64 is validated in CI.
# Note: system chromium tracks the Debian package schedule rather than a pinned
# Chrome for Testing release. If agent-browser requires a specific Chrome version,
# verify compatibility against the chromium package version in the base image.
# Install agent-browser (Copilot browser tool) + Chromium.
# On amd64: install runtime libs + run `agent-browser install` to download
# Chrome for Testing (pinned version, tested with Playwright).
# On arm64: install system chromium package — Chrome for Testing has no ARM64
# binary. AGENT_BROWSER_EXECUTABLE_PATH is set at runtime by the entrypoint
# script (below) to redirect agent-browser to the system binary.
ARG TARGETARCH
RUN apt-get update \
&& apt-get install -y --no-install-recommends chromium fonts-liberation \
&& if [ "$TARGETARCH" = "arm64" ]; then \
apt-get install -y --no-install-recommends chromium fonts-liberation; \
else \
apt-get install -y --no-install-recommends \
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
fonts-liberation libfontconfig1; \
fi \
&& rm -rf /var/lib/apt/lists/* \
&& npm install -g agent-browser \
&& ([ "$TARGETARCH" = "arm64" ] || agent-browser install) \
&& rm -rf /tmp/* /root/.npm
ENV AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium
# On arm64 the system chromium is at /usr/bin/chromium; set
# AGENT_BROWSER_EXECUTABLE_PATH so agent-browser's daemon uses it instead of
# Chrome for Testing (which has no ARM64 binary). On amd64 the variable is left
# unset so agent-browser uses the Chrome for Testing binary it downloaded above.
RUN printf '#!/bin/sh\n[ -x /usr/bin/chromium ] && export AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium\nexec "$@"\n' \
> /usr/local/bin/entrypoint.sh \
&& chmod +x /usr/local/bin/entrypoint.sh
WORKDIR /app/autogpt_platform/backend
@@ -157,4 +173,5 @@ RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
ENV PORT=8000
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["rest"]

View File

@@ -18,22 +18,14 @@ from pydantic import BaseModel, Field, SecretStr
from backend.api.external.middleware import require_permission
from backend.api.features.integrations.models import get_all_provider_names
from backend.api.features.integrations.router import (
CredentialsMetaResponse,
to_meta_response,
)
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.model import (
APIKeyCredentials,
Credentials,
CredentialsType,
HostScopedCredentials,
OAuth2Credentials,
UserPasswordCredentials,
is_sdk_default,
)
from backend.integrations.credentials_store import (
is_system_credential,
provider_matches,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
@@ -99,6 +91,18 @@ class OAuthCompleteResponse(BaseModel):
)
class CredentialSummary(BaseModel):
"""Summary of a credential without sensitive data."""
id: str
provider: str
type: CredentialsType
title: Optional[str] = None
scopes: Optional[list[str]] = None
username: Optional[str] = None
host: Optional[str] = None
class ProviderInfo(BaseModel):
"""Information about an integration provider."""
@@ -469,12 +473,12 @@ async def complete_oauth(
)
@integrations_router.get("/credentials", response_model=list[CredentialsMetaResponse])
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
async def list_credentials(
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_INTEGRATIONS)
),
) -> list[CredentialsMetaResponse]:
) -> list[CredentialSummary]:
"""
List all credentials for the authenticated user.
@@ -482,19 +486,28 @@ async def list_credentials(
"""
credentials = await creds_manager.store.get_all_creds(auth.user_id)
return [
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
CredentialSummary(
id=cred.id,
provider=cred.provider,
type=cred.type,
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@integrations_router.get(
"/{provider}/credentials", response_model=list[CredentialsMetaResponse]
"/{provider}/credentials", response_model=list[CredentialSummary]
)
async def list_credentials_by_provider(
provider: Annotated[str, Path(title="The provider to list credentials for")],
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_INTEGRATIONS)
),
) -> list[CredentialsMetaResponse]:
) -> list[CredentialSummary]:
"""
List credentials for a specific provider.
"""
@@ -502,7 +515,16 @@ async def list_credentials_by_provider(
auth.user_id, provider
)
return [
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
CredentialSummary(
id=cred.id,
provider=cred.provider,
type=cred.type,
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@@ -575,11 +597,11 @@ async def create_credential(
# Store credentials
try:
await creds_manager.create(auth.user_id, credentials)
except Exception:
logger.exception("Failed to store credentials")
except Exception as e:
logger.error(f"Failed to store credentials: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to store credentials",
detail=f"Failed to store credentials: {str(e)}",
)
logger.info(f"Created {request.type} credentials for provider {provider}")
@@ -617,23 +639,15 @@ async def delete_credential(
use the main API's delete endpoint which handles webhook cleanup and
token revocation.
"""
if is_sdk_default(cred_id):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if is_system_credential(cred_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="System-managed credentials cannot be deleted",
)
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
if not creds:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if not provider_matches(creds.provider, provider):
if creds.provider != provider:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
status_code=status.HTTP_404_NOT_FOUND,
detail="Credentials do not match the specified provider",
)
await creds_manager.delete(auth.user_id, cred_id)

View File

@@ -72,7 +72,7 @@ class RunAgentRequest(BaseModel):
def _create_ephemeral_session(user_id: str) -> ChatSession:
"""Create an ephemeral session for stateless API requests."""
return ChatSession.new(user_id, dry_run=False)
return ChatSession.new(user_id)
@tools_router.post(

View File

@@ -1,98 +0,0 @@
import logging
from datetime import datetime
from autogpt_libs.auth import get_user_id, requires_admin_user
from cachetools import TTLCache
from fastapi import APIRouter, Query, Security
from pydantic import BaseModel
from backend.data.platform_cost import (
CostLogRow,
PlatformCostDashboard,
get_platform_cost_dashboard,
get_platform_cost_logs,
)
from backend.util.models import Pagination
logger = logging.getLogger(__name__)
# Cache dashboard results for 30 seconds per unique filter combination.
# The table is append-only so stale reads are acceptable for analytics.
_DASHBOARD_CACHE_TTL = 30
_dashboard_cache: TTLCache[tuple, PlatformCostDashboard] = TTLCache(
maxsize=256, ttl=_DASHBOARD_CACHE_TTL
)
router = APIRouter(
prefix="/platform-costs",
tags=["platform-cost", "admin"],
dependencies=[Security(requires_admin_user)],
)
class PlatformCostLogsResponse(BaseModel):
logs: list[CostLogRow]
pagination: Pagination
@router.get(
"/dashboard",
response_model=PlatformCostDashboard,
summary="Get Platform Cost Dashboard",
)
async def get_cost_dashboard(
admin_user_id: str = Security(get_user_id),
start: datetime | None = Query(None),
end: datetime | None = Query(None),
provider: str | None = Query(None),
user_id: str | None = Query(None),
):
logger.info("Admin %s fetching platform cost dashboard", admin_user_id)
cache_key = (start, end, provider, user_id)
cached = _dashboard_cache.get(cache_key)
if cached is not None:
return cached
result = await get_platform_cost_dashboard(
start=start,
end=end,
provider=provider,
user_id=user_id,
)
_dashboard_cache[cache_key] = result
return result
@router.get(
"/logs",
response_model=PlatformCostLogsResponse,
summary="Get Platform Cost Logs",
)
async def get_cost_logs(
admin_user_id: str = Security(get_user_id),
start: datetime | None = Query(None),
end: datetime | None = Query(None),
provider: str | None = Query(None),
user_id: str | None = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=1, le=200),
):
logger.info("Admin %s fetching platform cost logs", admin_user_id)
logs, total = await get_platform_cost_logs(
start=start,
end=end,
provider=provider,
user_id=user_id,
page=page,
page_size=page_size,
)
total_pages = (total + page_size - 1) // page_size
return PlatformCostLogsResponse(
logs=logs,
pagination=Pagination(
total_items=total,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)

View File

@@ -1,192 +0,0 @@
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from backend.data.platform_cost import PlatformCostDashboard
from . import platform_cost_routes
from .platform_cost_routes import router as platform_cost_router
app = fastapi.FastAPI()
app.include_router(platform_cost_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
"""Setup admin auth overrides for all tests in this module"""
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
# Clear TTL cache so each test starts cold.
platform_cost_routes._dashboard_cache.clear()
yield
app.dependency_overrides.clear()
def test_get_dashboard_success(
mocker: pytest_mock.MockerFixture,
) -> None:
real_dashboard = PlatformCostDashboard(
by_provider=[],
by_user=[],
total_cost_microdollars=0,
total_requests=0,
total_users=0,
)
mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
AsyncMock(return_value=real_dashboard),
)
response = client.get("/platform-costs/dashboard")
assert response.status_code == 200
data = response.json()
assert "by_provider" in data
assert "by_user" in data
assert data["total_cost_microdollars"] == 0
def test_get_logs_success(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
AsyncMock(return_value=([], 0)),
)
response = client.get("/platform-costs/logs")
assert response.status_code == 200
data = response.json()
assert data["logs"] == []
assert data["pagination"]["total_items"] == 0
def test_get_dashboard_with_filters(
mocker: pytest_mock.MockerFixture,
) -> None:
real_dashboard = PlatformCostDashboard(
by_provider=[],
by_user=[],
total_cost_microdollars=0,
total_requests=0,
total_users=0,
)
mock_dashboard = AsyncMock(return_value=real_dashboard)
mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
mock_dashboard,
)
response = client.get(
"/platform-costs/dashboard",
params={
"start": "2026-01-01T00:00:00",
"end": "2026-04-01T00:00:00",
"provider": "openai",
"user_id": "test-user-123",
},
)
assert response.status_code == 200
mock_dashboard.assert_called_once()
call_kwargs = mock_dashboard.call_args.kwargs
assert call_kwargs["provider"] == "openai"
assert call_kwargs["user_id"] == "test-user-123"
assert call_kwargs["start"] is not None
assert call_kwargs["end"] is not None
def test_get_logs_with_pagination(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
AsyncMock(return_value=([], 0)),
)
response = client.get(
"/platform-costs/logs",
params={"page": 2, "page_size": 25, "provider": "anthropic"},
)
assert response.status_code == 200
data = response.json()
assert data["pagination"]["current_page"] == 2
assert data["pagination"]["page_size"] == 25
def test_get_dashboard_requires_admin() -> None:
import fastapi
from fastapi import HTTPException
def reject_jwt(request: fastapi.Request):
raise HTTPException(status_code=401, detail="Not authenticated")
app.dependency_overrides[get_jwt_payload] = reject_jwt
try:
response = client.get("/platform-costs/dashboard")
assert response.status_code == 401
response = client.get("/platform-costs/logs")
assert response.status_code == 401
finally:
app.dependency_overrides.clear()
def test_get_dashboard_rejects_non_admin(mock_jwt_user, mock_jwt_admin) -> None:
"""Non-admin JWT must be rejected with 403 by requires_admin_user."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
try:
response = client.get("/platform-costs/dashboard")
assert response.status_code == 403
response = client.get("/platform-costs/logs")
assert response.status_code == 403
finally:
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
def test_get_logs_invalid_page_size_too_large() -> None:
"""page_size > 200 must be rejected with 422."""
response = client.get("/platform-costs/logs", params={"page_size": 201})
assert response.status_code == 422
def test_get_logs_invalid_page_size_zero() -> None:
"""page_size = 0 (below ge=1) must be rejected with 422."""
response = client.get("/platform-costs/logs", params={"page_size": 0})
assert response.status_code == 422
def test_get_logs_invalid_page_negative() -> None:
"""page < 1 must be rejected with 422."""
response = client.get("/platform-costs/logs", params={"page": 0})
assert response.status_code == 422
def test_get_dashboard_invalid_date_format() -> None:
"""Malformed start date must be rejected with 422."""
response = client.get("/platform-costs/dashboard", params={"start": "not-a-date"})
assert response.status_code == 422
def test_get_dashboard_cache_hit(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Second identical request returns cached result without calling the DB again."""
real_dashboard = PlatformCostDashboard(
by_provider=[],
by_user=[],
total_cost_microdollars=42,
total_requests=1,
total_users=1,
)
mock_fn = mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
AsyncMock(return_value=real_dashboard),
)
client.get("/platform-costs/dashboard")
client.get("/platform-costs/dashboard")
mock_fn.assert_awaited_once() # second request hit the cache

View File

@@ -1,259 +0,0 @@
"""Admin endpoints for checking and resetting user CoPilot rate limit usage."""
import logging
from typing import Optional
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, Body, HTTPException, Security
from pydantic import BaseModel
from backend.copilot.config import ChatConfig
from backend.copilot.rate_limit import (
SubscriptionTier,
get_global_rate_limits,
get_usage_status,
get_user_tier,
reset_user_usage,
set_user_tier,
)
from backend.data.user import get_user_by_email, get_user_email_by_id, search_users
logger = logging.getLogger(__name__)
config = ChatConfig()
router = APIRouter(
prefix="/admin",
tags=["copilot", "admin"],
dependencies=[Security(requires_admin_user)],
)
class UserRateLimitResponse(BaseModel):
user_id: str
user_email: Optional[str] = None
daily_token_limit: int
weekly_token_limit: int
daily_tokens_used: int
weekly_tokens_used: int
tier: SubscriptionTier
class UserTierResponse(BaseModel):
user_id: str
tier: SubscriptionTier
class SetUserTierRequest(BaseModel):
user_id: str
tier: SubscriptionTier
async def _resolve_user_id(
user_id: Optional[str], email: Optional[str]
) -> tuple[str, Optional[str]]:
"""Resolve a user_id and email from the provided parameters.
Returns (user_id, email). Accepts either user_id or email; at least one
must be provided. When both are provided, ``email`` takes precedence.
"""
if email:
user = await get_user_by_email(email)
if not user:
raise HTTPException(
status_code=404, detail="No user found with the provided email."
)
return user.id, email
if not user_id:
raise HTTPException(
status_code=400,
detail="Either user_id or email query parameter is required.",
)
# We have a user_id; try to look up their email for display purposes.
# This is non-critical -- a failure should not block the response.
try:
resolved_email = await get_user_email_by_id(user_id)
except Exception:
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
resolved_email = None
return user_id, resolved_email
@router.get(
"/rate_limit",
response_model=UserRateLimitResponse,
summary="Get User Rate Limit",
)
async def get_user_rate_limit(
user_id: Optional[str] = None,
email: Optional[str] = None,
admin_user_id: str = Security(get_user_id),
) -> UserRateLimitResponse:
"""Get a user's current usage and effective rate limits. Admin-only.
Accepts either ``user_id`` or ``email`` as a query parameter.
When ``email`` is provided the user is looked up by email first.
"""
resolved_id, resolved_email = await _resolve_user_id(user_id, email)
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
daily_limit, weekly_limit, tier = await get_global_rate_limits(
resolved_id, config.daily_token_limit, config.weekly_token_limit
)
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)
return UserRateLimitResponse(
user_id=resolved_id,
user_email=resolved_email,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
tier=tier,
)
@router.post(
"/rate_limit/reset",
response_model=UserRateLimitResponse,
summary="Reset User Rate Limit Usage",
)
async def reset_user_rate_limit(
user_id: str = Body(embed=True),
reset_weekly: bool = Body(False, embed=True),
admin_user_id: str = Security(get_user_id),
) -> UserRateLimitResponse:
"""Reset a user's daily usage counter (and optionally weekly). Admin-only."""
logger.info(
"Admin %s resetting rate limit for user %s (reset_weekly=%s)",
admin_user_id,
user_id,
reset_weekly,
)
try:
await reset_user_usage(user_id, reset_weekly=reset_weekly)
except Exception as e:
logger.exception("Failed to reset user usage")
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)
try:
resolved_email = await get_user_email_by_id(user_id)
except Exception:
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
resolved_email = None
return UserRateLimitResponse(
user_id=user_id,
user_email=resolved_email,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
tier=tier,
)
@router.get(
"/rate_limit/tier",
response_model=UserTierResponse,
summary="Get User Rate Limit Tier",
)
async def get_user_rate_limit_tier(
user_id: str,
admin_user_id: str = Security(get_user_id),
) -> UserTierResponse:
"""Get a user's current rate-limit tier. Admin-only.
Returns 404 if the user does not exist in the database.
"""
logger.info("Admin %s checking tier for user %s", admin_user_id, user_id)
resolved_email = await get_user_email_by_id(user_id)
if resolved_email is None:
raise HTTPException(status_code=404, detail=f"User {user_id} not found")
tier = await get_user_tier(user_id)
return UserTierResponse(user_id=user_id, tier=tier)
@router.post(
"/rate_limit/tier",
response_model=UserTierResponse,
summary="Set User Rate Limit Tier",
)
async def set_user_rate_limit_tier(
request: SetUserTierRequest,
admin_user_id: str = Security(get_user_id),
) -> UserTierResponse:
"""Set a user's rate-limit tier. Admin-only.
Returns 404 if the user does not exist in the database.
"""
try:
resolved_email = await get_user_email_by_id(request.user_id)
except Exception:
logger.warning(
"Failed to resolve email for user %s",
request.user_id,
exc_info=True,
)
resolved_email = None
if resolved_email is None:
raise HTTPException(status_code=404, detail=f"User {request.user_id} not found")
old_tier = await get_user_tier(request.user_id)
logger.info(
"Admin %s changing tier for user %s (%s): %s -> %s",
admin_user_id,
request.user_id,
resolved_email,
old_tier.value,
request.tier.value,
)
try:
await set_user_tier(request.user_id, request.tier)
except Exception as e:
logger.exception("Failed to set user tier")
raise HTTPException(status_code=500, detail="Failed to set tier") from e
return UserTierResponse(user_id=request.user_id, tier=request.tier)
class UserSearchResult(BaseModel):
user_id: str
user_email: Optional[str] = None
@router.get(
"/rate_limit/search_users",
response_model=list[UserSearchResult],
summary="Search Users by Name or Email",
)
async def admin_search_users(
query: str,
limit: int = 20,
admin_user_id: str = Security(get_user_id),
) -> list[UserSearchResult]:
"""Search users by partial email or name. Admin-only.
Queries the User table directly — returns results even for users
without credit transaction history.
"""
if len(query.strip()) < 3:
raise HTTPException(
status_code=400,
detail="Search query must be at least 3 characters.",
)
logger.info("Admin %s searching users with query=%r", admin_user_id, query)
results = await search_users(query, limit=max(1, min(limit, 50)))
return [UserSearchResult(user_id=uid, user_email=email) for uid, email in results]

View File

@@ -1,566 +0,0 @@
import json
from types import SimpleNamespace
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from pytest_snapshot.plugin import Snapshot
from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow
from .rate_limit_admin_routes import router as rate_limit_admin_router
app = fastapi.FastAPI()
app.include_router(rate_limit_admin_router)
client = fastapi.testclient.TestClient(app)
_MOCK_MODULE = "backend.api.features.admin.rate_limit_admin_routes"
_TARGET_EMAIL = "target@example.com"
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
"""Setup admin auth overrides for all tests in this module"""
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def _mock_usage_status(
daily_used: int = 500_000, weekly_used: int = 3_000_000
) -> CoPilotUsageStatus:
from datetime import UTC, datetime, timedelta
now = datetime.now(UTC)
return CoPilotUsageStatus(
daily=UsageWindow(
used=daily_used, limit=2_500_000, resets_at=now + timedelta(hours=6)
),
weekly=UsageWindow(
used=weekly_used, limit=12_500_000, resets_at=now + timedelta(days=3)
),
)
def _patch_rate_limit_deps(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
daily_used: int = 500_000,
weekly_used: int = 3_000_000,
):
"""Patch the common rate-limit + user-lookup dependencies."""
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
new_callable=AsyncMock,
return_value=_mock_usage_status(daily_used=daily_used, weekly_used=weekly_used),
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
def test_get_rate_limit(
mocker: pytest_mock.MockerFixture,
configured_snapshot: Snapshot,
target_user_id: str,
) -> None:
"""Test getting rate limit and usage for a user."""
_patch_rate_limit_deps(mocker, target_user_id)
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] == _TARGET_EMAIL
assert data["daily_token_limit"] == 2_500_000
assert data["weekly_token_limit"] == 12_500_000
assert data["daily_tokens_used"] == 500_000
assert data["weekly_tokens_used"] == 3_000_000
assert data["tier"] == "FREE"
configured_snapshot.assert_match(
json.dumps(data, indent=2, sort_keys=True) + "\n",
"get_rate_limit",
)
def test_get_rate_limit_by_email(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test looking up rate limits via email instead of user_id."""
_patch_rate_limit_deps(mocker, target_user_id)
mock_user = SimpleNamespace(id=target_user_id, email=_TARGET_EMAIL)
mocker.patch(
f"{_MOCK_MODULE}.get_user_by_email",
new_callable=AsyncMock,
return_value=mock_user,
)
response = client.get("/admin/rate_limit", params={"email": _TARGET_EMAIL})
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] == _TARGET_EMAIL
assert data["daily_token_limit"] == 2_500_000
def test_get_rate_limit_by_email_not_found(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Test that looking up a non-existent email returns 404."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_by_email",
new_callable=AsyncMock,
return_value=None,
)
response = client.get("/admin/rate_limit", params={"email": "nobody@example.com"})
assert response.status_code == 404
def test_get_rate_limit_no_params() -> None:
"""Test that omitting both user_id and email returns 400."""
response = client.get("/admin/rate_limit")
assert response.status_code == 400
def test_reset_user_usage_daily_only(
mocker: pytest_mock.MockerFixture,
configured_snapshot: Snapshot,
target_user_id: str,
) -> None:
"""Test resetting only daily usage (default behaviour)."""
mock_reset = mocker.patch(
f"{_MOCK_MODULE}.reset_user_usage",
new_callable=AsyncMock,
)
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=3_000_000)
response = client.post(
"/admin/rate_limit/reset",
json={"user_id": target_user_id},
)
assert response.status_code == 200
data = response.json()
assert data["daily_tokens_used"] == 0
# Weekly is untouched
assert data["weekly_tokens_used"] == 3_000_000
assert data["tier"] == "FREE"
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
configured_snapshot.assert_match(
json.dumps(data, indent=2, sort_keys=True) + "\n",
"reset_user_usage_daily_only",
)
def test_reset_user_usage_daily_and_weekly(
mocker: pytest_mock.MockerFixture,
configured_snapshot: Snapshot,
target_user_id: str,
) -> None:
"""Test resetting both daily and weekly usage."""
mock_reset = mocker.patch(
f"{_MOCK_MODULE}.reset_user_usage",
new_callable=AsyncMock,
)
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=0)
response = client.post(
"/admin/rate_limit/reset",
json={"user_id": target_user_id, "reset_weekly": True},
)
assert response.status_code == 200
data = response.json()
assert data["daily_tokens_used"] == 0
assert data["weekly_tokens_used"] == 0
assert data["tier"] == "FREE"
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)
configured_snapshot.assert_match(
json.dumps(data, indent=2, sort_keys=True) + "\n",
"reset_user_usage_daily_and_weekly",
)
def test_reset_user_usage_redis_failure(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that Redis failure on reset returns 500."""
mocker.patch(
f"{_MOCK_MODULE}.reset_user_usage",
new_callable=AsyncMock,
side_effect=Exception("Redis connection refused"),
)
response = client.post(
"/admin/rate_limit/reset",
json={"user_id": target_user_id},
)
assert response.status_code == 500
def test_get_rate_limit_email_lookup_failure(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that failing to resolve a user email degrades gracefully."""
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
new_callable=AsyncMock,
return_value=_mock_usage_status(),
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
side_effect=Exception("DB connection lost"),
)
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] is None
def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
"""Test that rate limit admin endpoints require admin role."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
response = client.get("/admin/rate_limit", params={"user_id": "test"})
assert response.status_code == 403
response = client.post(
"/admin/rate_limit/reset",
json={"user_id": "test"},
)
assert response.status_code == 403
# ---------------------------------------------------------------------------
# Tier management endpoints
# ---------------------------------------------------------------------------
def test_get_user_tier(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test getting a user's rate-limit tier."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.PRO,
)
response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["tier"] == "PRO"
def test_get_user_tier_user_not_found(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that getting tier for a non-existent user returns 404."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=None,
)
response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
assert response.status_code == 404
def test_set_user_tier(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test setting a user's rate-limit tier (upgrade)."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.FREE,
)
mock_set = mocker.patch(
f"{_MOCK_MODULE}.set_user_tier",
new_callable=AsyncMock,
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "ENTERPRISE"},
)
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["tier"] == "ENTERPRISE"
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.ENTERPRISE)
def test_set_user_tier_downgrade(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test downgrading a user's tier from PRO to FREE."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.PRO,
)
mock_set = mocker.patch(
f"{_MOCK_MODULE}.set_user_tier",
new_callable=AsyncMock,
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "FREE"},
)
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["tier"] == "FREE"
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.FREE)
def test_set_user_tier_invalid_tier(
target_user_id: str,
) -> None:
"""Test that setting an invalid tier returns 422."""
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "invalid"},
)
assert response.status_code == 422
def test_set_user_tier_invalid_tier_uppercase(
target_user_id: str,
) -> None:
"""Test that setting an unrecognised uppercase tier (e.g. 'INVALID') returns 422.
Regression: ensures Pydantic enum validation rejects values that are not
members of SubscriptionTier, even when they look like valid enum names.
"""
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "INVALID"},
)
assert response.status_code == 422
body = response.json()
assert "detail" in body
def test_set_user_tier_email_lookup_failure_returns_404(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that email lookup failure returns 404 (user unverifiable)."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
side_effect=Exception("DB connection failed"),
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "PRO"},
)
assert response.status_code == 404
def test_set_user_tier_user_not_found(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that setting tier for a non-existent user returns 404."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=None,
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "PRO"},
)
assert response.status_code == 404
def test_set_user_tier_db_failure(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that DB failure on set tier returns 500."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.FREE,
)
mocker.patch(
f"{_MOCK_MODULE}.set_user_tier",
new_callable=AsyncMock,
side_effect=Exception("DB connection refused"),
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "PRO"},
)
assert response.status_code == 500
def test_tier_endpoints_require_admin_role(mock_jwt_user) -> None:
"""Test that tier admin endpoints require admin role."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
response = client.get("/admin/rate_limit/tier", params={"user_id": "test"})
assert response.status_code == 403
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": "test", "tier": "PRO"},
)
assert response.status_code == 403
# ─── search_users endpoint ──────────────────────────────────────────
def test_search_users_returns_matching_users(
mocker: pytest_mock.MockerFixture,
admin_user_id: str,
) -> None:
"""Partial search should return all matching users from the User table."""
mocker.patch(
_MOCK_MODULE + ".search_users",
new_callable=AsyncMock,
return_value=[
("user-1", "zamil.majdy@gmail.com"),
("user-2", "zamil.majdy@agpt.co"),
],
)
response = client.get("/admin/rate_limit/search_users", params={"query": "zamil"})
assert response.status_code == 200
results = response.json()
assert len(results) == 2
assert results[0]["user_email"] == "zamil.majdy@gmail.com"
assert results[1]["user_email"] == "zamil.majdy@agpt.co"
def test_search_users_empty_results(
mocker: pytest_mock.MockerFixture,
admin_user_id: str,
) -> None:
"""Search with no matches returns empty list."""
mocker.patch(
_MOCK_MODULE + ".search_users",
new_callable=AsyncMock,
return_value=[],
)
response = client.get(
"/admin/rate_limit/search_users", params={"query": "nonexistent"}
)
assert response.status_code == 200
assert response.json() == []
def test_search_users_short_query_rejected(
admin_user_id: str,
) -> None:
"""Query shorter than 3 characters should return 400."""
response = client.get("/admin/rate_limit/search_users", params={"query": "ab"})
assert response.status_code == 400
def test_search_users_negative_limit_clamped(
mocker: pytest_mock.MockerFixture,
admin_user_id: str,
) -> None:
"""Negative limit should be clamped to 1, not passed through."""
mock_search = mocker.patch(
_MOCK_MODULE + ".search_users",
new_callable=AsyncMock,
return_value=[],
)
response = client.get(
"/admin/rate_limit/search_users", params={"query": "test", "limit": -1}
)
assert response.status_code == 200
mock_search.assert_awaited_once_with("test", limit=1)
def test_search_users_requires_admin_role(mock_jwt_user) -> None:
"""Test that the search_users endpoint requires admin role."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
response = client.get("/admin/rate_limit/search_users", params={"query": "test"})
assert response.status_code == 403

View File

@@ -7,8 +7,6 @@ import fastapi
import fastapi.responses
import prisma.enums
import backend.api.features.library.db as library_db
import backend.api.features.library.model as library_model
import backend.api.features.store.cache as store_cache
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
@@ -134,40 +132,3 @@ async def admin_download_agent_file(
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)
@router.get(
"/submissions/{store_listing_version_id}/preview",
summary="Admin Preview Submission Listing",
)
async def admin_preview_submission(
store_listing_version_id: str,
) -> store_model.StoreAgentDetails:
"""
Preview a marketplace submission as it would appear on the listing page.
Bypasses the APPROVED-only StoreAgent view so admins can preview pending
submissions before approving.
"""
return await store_db.get_store_agent_details_as_admin(store_listing_version_id)
@router.post(
"/submissions/{store_listing_version_id}/add-to-library",
summary="Admin Add Pending Agent to Library",
status_code=201,
)
async def admin_add_agent_to_library(
store_listing_version_id: str,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
) -> library_model.LibraryAgent:
"""
Add a pending marketplace agent to the admin's library for review.
Uses admin-level access to bypass marketplace APPROVED-only checks.
The builder can load the graph because get_graph() checks library
membership as a fallback: "you added it, you keep it."
"""
return await library_db.add_store_agent_to_library_as_admin(
store_listing_version_id=store_listing_version_id,
user_id=user_id,
)

View File

@@ -1,335 +0,0 @@
"""Tests for admin store routes and the bypass logic they depend on.
Tests are organized by what they protect:
- SECRT-2162: get_graph_as_admin bypasses ownership/marketplace checks
- SECRT-2167 security: admin endpoints reject non-admin users
- SECRT-2167 bypass: preview queries StoreListingVersion (not StoreAgent view),
and add-to-library uses get_graph_as_admin (not get_graph)
"""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import fastapi
import fastapi.responses
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from backend.data.graph import get_graph_as_admin
from backend.util.exceptions import NotFoundError
from .store_admin_routes import router as store_admin_router
# Shared constants
ADMIN_USER_ID = "admin-user-id"
CREATOR_USER_ID = "other-creator-id"
GRAPH_ID = "test-graph-id"
GRAPH_VERSION = 3
SLV_ID = "test-store-listing-version-id"
def _make_mock_graph(user_id: str = CREATOR_USER_ID) -> MagicMock:
graph = MagicMock()
graph.userId = user_id
graph.id = GRAPH_ID
graph.version = GRAPH_VERSION
graph.Nodes = []
return graph
# ---- SECRT-2162: get_graph_as_admin bypasses ownership checks ---- #
@pytest.mark.asyncio
async def test_admin_can_access_pending_agent_not_owned() -> None:
"""get_graph_as_admin must return a graph even when the admin doesn't own
it and it's not APPROVED in the marketplace."""
mock_graph = _make_mock_graph()
mock_graph_model = MagicMock(name="GraphModel")
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
patch(
"backend.data.graph.GraphModel.from_db",
return_value=mock_graph_model,
),
):
mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)
result = await get_graph_as_admin(
graph_id=GRAPH_ID,
version=GRAPH_VERSION,
user_id=ADMIN_USER_ID,
for_export=False,
)
assert result is mock_graph_model
@pytest.mark.asyncio
async def test_admin_download_pending_agent_with_subagents() -> None:
"""get_graph_as_admin with for_export=True must call get_sub_graphs
and pass sub_graphs to GraphModel.from_db."""
mock_graph = _make_mock_graph()
mock_sub_graph = MagicMock(name="SubGraph")
mock_graph_model = MagicMock(name="GraphModel")
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
patch(
"backend.data.graph.get_sub_graphs",
new_callable=AsyncMock,
return_value=[mock_sub_graph],
) as mock_get_sub,
patch(
"backend.data.graph.GraphModel.from_db",
return_value=mock_graph_model,
) as mock_from_db,
):
mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)
result = await get_graph_as_admin(
graph_id=GRAPH_ID,
version=GRAPH_VERSION,
user_id=ADMIN_USER_ID,
for_export=True,
)
assert result is mock_graph_model
mock_get_sub.assert_awaited_once_with(mock_graph)
mock_from_db.assert_called_once_with(
graph=mock_graph,
sub_graphs=[mock_sub_graph],
for_export=True,
)
# ---- SECRT-2167 security: admin endpoints reject non-admin users ---- #
app = fastapi.FastAPI()
app.include_router(store_admin_router)
@app.exception_handler(NotFoundError)
async def _not_found_handler(
request: fastapi.Request, exc: NotFoundError
) -> fastapi.responses.JSONResponse:
return fastapi.responses.JSONResponse(status_code=404, content={"detail": str(exc)})
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
"""Setup admin auth overrides for all route tests in this module."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def test_preview_requires_admin(mock_jwt_user) -> None:
"""Non-admin users must get 403 on the preview endpoint."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
response = client.get(f"/admin/submissions/{SLV_ID}/preview")
assert response.status_code == 403
def test_add_to_library_requires_admin(mock_jwt_user) -> None:
"""Non-admin users must get 403 on the add-to-library endpoint."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
response = client.post(f"/admin/submissions/{SLV_ID}/add-to-library")
assert response.status_code == 403
def test_preview_nonexistent_submission(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Preview of a nonexistent submission returns 404."""
mocker.patch(
"backend.api.features.admin.store_admin_routes.store_db"
".get_store_agent_details_as_admin",
side_effect=NotFoundError("not found"),
)
response = client.get(f"/admin/submissions/{SLV_ID}/preview")
assert response.status_code == 404
# ---- SECRT-2167 bypass: verify the right data sources are used ---- #
@pytest.mark.asyncio
async def test_preview_queries_store_listing_version_not_store_agent() -> None:
"""get_store_agent_details_as_admin must query StoreListingVersion
directly (not the APPROVED-only StoreAgent view). This is THE test that
prevents the bypass from being accidentally reverted."""
from backend.api.features.store.db import get_store_agent_details_as_admin
mock_slv = MagicMock()
mock_slv.id = SLV_ID
mock_slv.name = "Test Agent"
mock_slv.subHeading = "Short desc"
mock_slv.description = "Long desc"
mock_slv.videoUrl = None
mock_slv.agentOutputDemoUrl = None
mock_slv.imageUrls = ["https://example.com/img.png"]
mock_slv.instructions = None
mock_slv.categories = ["productivity"]
mock_slv.version = 1
mock_slv.agentGraphId = GRAPH_ID
mock_slv.agentGraphVersion = GRAPH_VERSION
mock_slv.updatedAt = datetime(2026, 3, 24, tzinfo=timezone.utc)
mock_slv.recommendedScheduleCron = "0 9 * * *"
mock_listing = MagicMock()
mock_listing.id = "listing-id"
mock_listing.slug = "test-agent"
mock_listing.activeVersionId = SLV_ID
mock_listing.hasApprovedVersion = False
mock_listing.CreatorProfile = MagicMock(username="creator", avatarUrl="")
mock_slv.StoreListing = mock_listing
with (
patch(
"backend.api.features.store.db.prisma.models" ".StoreListingVersion.prisma",
) as mock_slv_prisma,
patch(
"backend.api.features.store.db.prisma.models.StoreAgent.prisma",
) as mock_store_agent_prisma,
):
mock_slv_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
result = await get_store_agent_details_as_admin(SLV_ID)
# Verify it queried StoreListingVersion (not the APPROVED-only StoreAgent)
mock_slv_prisma.return_value.find_unique.assert_awaited_once()
await_args = mock_slv_prisma.return_value.find_unique.await_args
assert await_args is not None
assert await_args.kwargs["where"] == {"id": SLV_ID}
# Verify the APPROVED-only StoreAgent view was NOT touched
mock_store_agent_prisma.assert_not_called()
# Verify the result has the right data
assert result.agent_name == "Test Agent"
assert result.agent_image == ["https://example.com/img.png"]
assert result.has_approved_version is False
assert result.runs == 0
assert result.rating == 0.0
@pytest.mark.asyncio
async def test_resolve_graph_admin_uses_get_graph_as_admin() -> None:
"""resolve_graph_for_library(admin=True) must call get_graph_as_admin,
not get_graph. This is THE test that prevents the add-to-library bypass
from being accidentally reverted."""
from backend.api.features.library._add_to_library import resolve_graph_for_library
mock_slv = MagicMock()
mock_slv.AgentGraph = MagicMock(id=GRAPH_ID, version=GRAPH_VERSION)
mock_graph_model = MagicMock(name="GraphModel")
with (
patch(
"backend.api.features.library._add_to_library.prisma.models"
".StoreListingVersion.prisma",
) as mock_prisma,
patch(
"backend.api.features.library._add_to_library.graph_db"
".get_graph_as_admin",
new_callable=AsyncMock,
return_value=mock_graph_model,
) as mock_admin,
patch(
"backend.api.features.library._add_to_library.graph_db.get_graph",
new_callable=AsyncMock,
) as mock_regular,
):
mock_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
result = await resolve_graph_for_library(SLV_ID, ADMIN_USER_ID, admin=True)
assert result is mock_graph_model
mock_admin.assert_awaited_once_with(
graph_id=GRAPH_ID, version=GRAPH_VERSION, user_id=ADMIN_USER_ID
)
mock_regular.assert_not_awaited()
@pytest.mark.asyncio
async def test_resolve_graph_regular_uses_get_graph() -> None:
"""resolve_graph_for_library(admin=False) must call get_graph,
not get_graph_as_admin. Ensures the non-admin path is preserved."""
from backend.api.features.library._add_to_library import resolve_graph_for_library
mock_slv = MagicMock()
mock_slv.AgentGraph = MagicMock(id=GRAPH_ID, version=GRAPH_VERSION)
mock_graph_model = MagicMock(name="GraphModel")
with (
patch(
"backend.api.features.library._add_to_library.prisma.models"
".StoreListingVersion.prisma",
) as mock_prisma,
patch(
"backend.api.features.library._add_to_library.graph_db"
".get_graph_as_admin",
new_callable=AsyncMock,
) as mock_admin,
patch(
"backend.api.features.library._add_to_library.graph_db.get_graph",
new_callable=AsyncMock,
return_value=mock_graph_model,
) as mock_regular,
):
mock_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
result = await resolve_graph_for_library(SLV_ID, "regular-user-id", admin=False)
assert result is mock_graph_model
mock_regular.assert_awaited_once_with(
graph_id=GRAPH_ID, version=GRAPH_VERSION, user_id="regular-user-id"
)
mock_admin.assert_not_awaited()
# ---- Library membership grants graph access (product decision) ---- #
@pytest.mark.asyncio
async def test_library_member_can_view_pending_agent_in_builder() -> None:
"""After adding a pending agent to their library, the user should be
able to load the graph in the builder via get_graph()."""
mock_graph = _make_mock_graph()
mock_graph_model = MagicMock(name="GraphModel")
mock_library_agent = MagicMock()
mock_library_agent.AgentGraph = mock_graph
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch(
"backend.data.graph.StoreListingVersion.prisma",
) as mock_slv_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
patch(
"backend.data.graph.GraphModel.from_db",
return_value=mock_graph_model,
),
):
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
mock_lib_prisma.return_value.find_first = AsyncMock(
return_value=mock_library_agent
)
from backend.data.graph import get_graph
result = await get_graph(
graph_id=GRAPH_ID,
version=GRAPH_VERSION,
user_id=ADMIN_USER_ID,
)
assert result is mock_graph_model, "Library membership should grant graph access"

View File

@@ -11,17 +11,15 @@ from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, Field, field_validator
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig, CopilotMode
from backend.copilot.db import get_chat_messages_paginated
from backend.copilot.config import ChatConfig
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.model import (
ChatMessage,
ChatSession,
ChatSessionMetadata,
append_and_save_message,
create_chat_session,
delete_chat_session,
@@ -32,14 +30,8 @@ from backend.copilot.model import (
from backend.copilot.rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
acquire_reset_lock,
check_rate_limit,
get_daily_reset_count,
get_global_rate_limits,
get_usage_status,
increment_daily_reset_count,
release_reset_lock,
reset_daily_usage,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
@@ -67,16 +59,9 @@ from backend.copilot.tools.models import (
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.redis_client import get_redis_async
from backend.data.understanding import get_business_understanding
from backend.data.workspace import get_or_create_workspace
from backend.util.exceptions import InsufficientBalanceError, NotFoundError
from backend.util.settings import Settings
settings = Settings()
logger = logging.getLogger(__name__)
from backend.util.exceptions import NotFoundError
config = ChatConfig()
@@ -84,6 +69,8 @@ _UUID_RE = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
)
logger = logging.getLogger(__name__)
async def _validate_and_get_session(
session_id: str,
@@ -112,23 +99,6 @@ class StreamChatRequest(BaseModel):
file_ids: list[str] | None = Field(
default=None, max_length=20
) # Workspace file IDs attached to this message
mode: CopilotMode | None = Field(
default=None,
description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
"If None, uses the server default (extended_thinking).",
)
class CreateSessionRequest(BaseModel):
"""Request model for creating a new chat session.
``dry_run`` is a **top-level** field — do not nest it inside ``metadata``.
Extra/unknown fields are rejected (422) to prevent silent mis-use.
"""
model_config = ConfigDict(extra="forbid")
dry_run: bool = False
class CreateSessionResponse(BaseModel):
@@ -137,7 +107,6 @@ class CreateSessionResponse(BaseModel):
id: str
created_at: str
user_id: str | None
metadata: ChatSessionMetadata = ChatSessionMetadata()
class ActiveStreamInfo(BaseModel):
@@ -156,11 +125,8 @@ class SessionDetailResponse(BaseModel):
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
has_more_messages: bool = False
oldest_sequence: int | None = None
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
metadata: ChatSessionMetadata = ChatSessionMetadata()
class SessionSummaryResponse(BaseModel):
@@ -271,7 +237,6 @@ async def list_sessions(
)
async def create_session(
user_id: Annotated[str, Security(auth.get_user_id)],
request: CreateSessionRequest | None = None,
) -> CreateSessionResponse:
"""
Create a new chat session.
@@ -280,28 +245,22 @@ async def create_session(
Args:
user_id: The authenticated user ID parsed from the JWT (required).
request: Optional request body. When provided, ``dry_run=True``
forces run_block and run_agent calls to use dry-run simulation.
Returns:
CreateSessionResponse: Details of the created session.
"""
dry_run = request.dry_run if request else False
logger.info(
f"Creating session with user_id: "
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
f"{', dry_run=True' if dry_run else ''}"
)
session = await create_chat_session(user_id, dry_run=dry_run)
session = await create_chat_session(user_id)
return CreateSessionResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
user_id=session.user_id,
metadata=session.metadata,
)
@@ -397,78 +356,59 @@ async def update_session_title_route(
async def get_session(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
limit: int = Query(default=50, ge=1, le=200),
before_sequence: int | None = Query(default=None, ge=0),
) -> SessionDetailResponse:
"""
Retrieve the details of a specific chat session.
Supports cursor-based pagination via ``limit`` and ``before_sequence``.
When no pagination params are provided, returns the most recent messages.
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
If there's an active stream for this session, returns active_stream info for reconnection.
Args:
session_id: The unique identifier for the desired chat session.
user_id: The authenticated user's ID.
limit: Maximum number of messages to return (1-200, default 50).
before_sequence: Return messages with sequence < this value (cursor).
user_id: The optional authenticated user ID, or None for anonymous access.
Returns:
SessionDetailResponse: Details for the requested session, including
active_stream info and pagination metadata.
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
"""
page = await get_chat_messages_paginated(
session_id, limit, before_sequence, user_id=user_id
)
if page is None:
session = await get_chat_session(session_id, user_id)
if not session:
raise NotFoundError(f"Session {session_id} not found.")
messages = [message.model_dump() for message in page.messages]
# Only check active stream on initial load (not on "load more" requests)
messages = [message.model_dump() for message in session.messages]
# Check if there's an active stream for this session
active_stream_info = None
if before_sequence is None:
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_session:
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
)
# Skip session metadata on "load more" — frontend only needs messages
if before_sequence is not None:
return SessionDetailResponse(
id=page.session.session_id,
created_at=page.session.started_at.isoformat(),
updated_at=page.session.updated_at.isoformat(),
user_id=page.session.user_id or None,
messages=messages,
active_stream=None,
has_more_messages=page.has_more,
oldest_sequence=page.oldest_sequence,
total_prompt_tokens=0,
total_completion_tokens=0,
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_session:
# Keep the assistant message (including tool_calls) so the frontend can
# render the correct tool UI (e.g. CreateAgent with mini game).
# convertChatSessionToUiMessages handles isComplete=false by setting
# tool parts without output to state "input-available".
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
)
total_prompt = sum(u.prompt_tokens for u in page.session.usage)
total_completion = sum(u.completion_tokens for u in page.session.usage)
# Sum token usage from session
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
return SessionDetailResponse(
id=page.session.session_id,
created_at=page.session.started_at.isoformat(),
updated_at=page.session.updated_at.isoformat(),
user_id=page.session.user_id or None,
id=session.session_id,
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
has_more_messages=page.has_more,
oldest_sequence=page.oldest_sequence,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
metadata=page.session.metadata,
)
@@ -481,193 +421,11 @@ async def get_copilot_usage(
"""Get CoPilot usage status for the authenticated user.
Returns current token usage vs limits for daily and weekly windows.
Global defaults sourced from LaunchDarkly (falling back to config).
Includes the user's rate-limit tier.
"""
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
return await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
rate_limit_reset_cost=config.rate_limit_reset_cost,
tier=tier,
)
class RateLimitResetResponse(BaseModel):
"""Response from resetting the daily rate limit."""
success: bool
credits_charged: int = Field(description="Credits charged (in cents)")
remaining_balance: int = Field(description="Credit balance after charge (in cents)")
usage: CoPilotUsageStatus = Field(description="Updated usage status after reset")
@router.post(
"/usage/reset",
status_code=200,
responses={
400: {
"description": "Bad Request (feature disabled or daily limit not reached)"
},
402: {"description": "Payment Required (insufficient credits)"},
429: {
"description": "Too Many Requests (max daily resets exceeded or reset in progress)"
},
503: {
"description": "Service Unavailable (Redis reset failed; credits refunded or support needed)"
},
},
)
async def reset_copilot_usage(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> RateLimitResetResponse:
"""Reset the daily CoPilot rate limit by spending credits.
Allows users who have hit their daily token limit to spend credits
to reset their daily usage counter and continue working.
Returns 400 if the feature is disabled or the user is not over the limit.
Returns 402 if the user has insufficient credits.
"""
cost = config.rate_limit_reset_cost
if cost <= 0:
raise HTTPException(
status_code=400,
detail="Rate limit reset is not available.",
)
if not settings.config.enable_credit:
raise HTTPException(
status_code=400,
detail="Rate limit reset is not available (credit system is disabled).",
)
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
if daily_limit <= 0:
raise HTTPException(
status_code=400,
detail="No daily limit is configured — nothing to reset.",
)
# Check max daily resets. get_daily_reset_count returns None when Redis
# is unavailable; reject the reset in that case to prevent unlimited
# free resets when the counter store is down.
reset_count = await get_daily_reset_count(user_id)
if reset_count is None:
raise HTTPException(
status_code=503,
detail="Unable to verify reset eligibility — please try again later.",
)
if config.max_daily_resets > 0 and reset_count >= config.max_daily_resets:
raise HTTPException(
status_code=429,
detail=f"You've used all {config.max_daily_resets} resets for today.",
)
# Acquire a per-user lock to prevent TOCTOU races (concurrent resets).
if not await acquire_reset_lock(user_id):
raise HTTPException(
status_code=429,
detail="A reset is already in progress. Please try again.",
)
try:
# Verify the user is actually at or over their daily limit.
# (rate_limit_reset_cost intentionally omitted — this object is only
# used for limit checks, not returned to the client.)
usage_status = await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
tier=tier,
)
if daily_limit > 0 and usage_status.daily.used < daily_limit:
raise HTTPException(
status_code=400,
detail="You have not reached your daily limit yet.",
)
# If the weekly limit is also exhausted, resetting the daily counter
# won't help — the user would still be blocked by the weekly limit.
if weekly_limit > 0 and usage_status.weekly.used >= weekly_limit:
raise HTTPException(
status_code=400,
detail="Your weekly limit is also reached. Resetting the daily limit won't help.",
)
# Charge credits.
credit_model = await get_user_credit_model(user_id)
try:
remaining = await credit_model.spend_credits(
user_id=user_id,
cost=cost,
metadata=UsageTransactionMetadata(
reason="CoPilot daily rate limit reset",
),
)
except InsufficientBalanceError as e:
raise HTTPException(
status_code=402,
detail="Insufficient credits to reset your rate limit.",
) from e
# Reset daily usage in Redis. If this fails, refund the credits
# so the user is not charged for a service they did not receive.
if not await reset_daily_usage(user_id, daily_token_limit=daily_limit):
# Compensate: refund the charged credits.
refunded = False
try:
await credit_model.top_up_credits(user_id, cost)
refunded = True
logger.warning(
"Refunded %d credits to user %s after Redis reset failure",
cost,
user_id[:8],
)
except Exception:
logger.error(
"CRITICAL: Failed to refund %d credits to user %s "
"after Redis reset failure — manual intervention required",
cost,
user_id[:8],
exc_info=True,
)
if refunded:
raise HTTPException(
status_code=503,
detail="Rate limit reset failed — please try again later. "
"Your credits have not been charged.",
)
raise HTTPException(
status_code=503,
detail="Rate limit reset failed and the automatic refund "
"also failed. Please contact support for assistance.",
)
# Track the reset count for daily cap enforcement.
await increment_daily_reset_count(user_id)
finally:
await release_reset_lock(user_id)
# Return updated usage status.
updated_usage = await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
rate_limit_reset_cost=config.rate_limit_reset_cost,
tier=tier,
)
return RateLimitResetResponse(
success=True,
credits_charged=cost,
remaining_balance=remaining,
usage=updated_usage,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
@@ -768,16 +526,12 @@ async def stream_chat_post(
# Pre-turn rate limit check (token-based).
# check_rate_limit short-circuits internally when both limits are 0.
# Global defaults sourced from LaunchDarkly, falling back to config.
if user_id:
try:
daily_limit, weekly_limit, _ = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
await check_rate_limit(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
except RateLimitExceeded as e:
raise HTTPException(status_code=429, detail=str(e)) from e
@@ -866,7 +620,6 @@ async def stream_chat_post(
is_user_message=request.is_user_message,
context=request.context,
file_ids=sanitized_file_ids,
mode=request.mode,
)
setup_time = (time.perf_counter() - stream_start_time) * 1000
@@ -1141,47 +894,6 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Suggested Prompts ==========
class SuggestedTheme(BaseModel):
"""A themed group of suggested prompts."""
name: str
prompts: list[str]
class SuggestedPromptsResponse(BaseModel):
"""Response model for user-specific suggested prompts grouped by theme."""
themes: list[SuggestedTheme]
@router.get(
"/suggested-prompts",
dependencies=[Security(auth.requires_user)],
)
async def get_suggested_prompts(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> SuggestedPromptsResponse:
"""
Get LLM-generated suggested prompts grouped by theme.
Returns personalized quick-action prompts based on the user's
business understanding. Returns empty themes list if no custom
prompts are available.
"""
understanding = await get_business_understanding(user_id)
if understanding is None or not understanding.suggested_prompts:
return SuggestedPromptsResponse(themes=[])
themes = [
SuggestedTheme(name=name, prompts=prompts)
for name, prompts in understanding.suggested_prompts.items()
]
return SuggestedPromptsResponse(themes=themes)
# ========== Configuration ==========
@@ -1230,7 +942,7 @@ async def health_check() -> dict:
)
# Create and retrieve session to verify full data layer
session = await create_chat_session(health_check_user_id, dry_run=False)
session = await create_chat_session(health_check_user_id)
await get_chat_session(session.session_id, health_check_user_id)
return {

View File

@@ -1,7 +1,7 @@
"""Tests for chat API routes: session title update, file attachment validation, usage, and rate limiting."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
@@ -9,7 +9,6 @@ import pytest
import pytest_mock
from backend.api.features.chat import routes as chat_routes
from backend.copilot.rate_limit import SubscriptionTier
app = fastapi.FastAPI()
app.include_router(chat_routes.router)
@@ -332,28 +331,14 @@ def _mock_usage(
*,
daily_used: int = 500,
weekly_used: int = 2000,
daily_limit: int = 10000,
weekly_limit: int = 50000,
tier: "SubscriptionTier" = SubscriptionTier.FREE,
) -> AsyncMock:
"""Mock get_usage_status and get_global_rate_limits for usage endpoint tests.
Mocks both ``get_global_rate_limits`` (returns the given limits + tier) and
``get_usage_status`` so that tests exercise the endpoint without hitting
LaunchDarkly or Prisma.
"""
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
mocker.patch(
"backend.api.features.chat.routes.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(daily_limit, weekly_limit, tier),
)
resets_at = datetime.now(UTC) + timedelta(days=1)
status = CoPilotUsageStatus(
daily=UsageWindow(used=daily_used, limit=daily_limit, resets_at=resets_at),
weekly=UsageWindow(used=weekly_used, limit=weekly_limit, resets_at=resets_at),
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
)
return mocker.patch(
"backend.api.features.chat.routes.get_usage_status",
@@ -383,8 +368,6 @@ def test_usage_returns_daily_and_weekly(
user_id=test_user_id,
daily_token_limit=10000,
weekly_token_limit=50000,
rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
tier=SubscriptionTier.FREE,
)
@@ -392,10 +375,11 @@ def test_usage_uses_config_limits(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""The endpoint forwards resolved limits from get_global_rate_limits to get_usage_status."""
mock_get = _mock_usage(mocker, daily_limit=99999, weekly_limit=77777)
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
mock_get = _mock_usage(mocker)
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 500)
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
response = client.get("/usage")
@@ -404,8 +388,6 @@ def test_usage_uses_config_limits(
user_id=test_user_id,
daily_token_limit=99999,
weekly_token_limit=77777,
rate_limit_reset_cost=500,
tier=SubscriptionTier.FREE,
)
@@ -418,164 +400,3 @@ def test_usage_rejects_unauthenticated_request() -> None:
response = unauthenticated_client.get("/usage")
assert response.status_code == 401
# ─── Suggested prompts endpoint ──────────────────────────────────────
def _mock_get_business_understanding(
mocker: pytest_mock.MockerFixture,
*,
return_value=None,
):
"""Mock get_business_understanding."""
return mocker.patch(
"backend.api.features.chat.routes.get_business_understanding",
new_callable=AsyncMock,
return_value=return_value,
)
def test_suggested_prompts_returns_themes(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with themed prompts gets them back as themes list."""
mock_understanding = MagicMock()
mock_understanding.suggested_prompts = {
"Learn": ["L1", "L2"],
"Create": ["C1"],
}
_mock_get_business_understanding(mocker, return_value=mock_understanding)
response = client.get("/suggested-prompts")
assert response.status_code == 200
data = response.json()
assert "themes" in data
themes_by_name = {t["name"]: t["prompts"] for t in data["themes"]}
assert themes_by_name["Learn"] == ["L1", "L2"]
assert themes_by_name["Create"] == ["C1"]
def test_suggested_prompts_no_understanding(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with no understanding gets empty themes list."""
_mock_get_business_understanding(mocker, return_value=None)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"themes": []}
def test_suggested_prompts_empty_prompts(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with understanding but empty prompts gets empty themes list."""
mock_understanding = MagicMock()
mock_understanding.suggested_prompts = {}
_mock_get_business_understanding(mocker, return_value=mock_understanding)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"themes": []}
# ─── Create session: dry_run contract ─────────────────────────────────
def _mock_create_chat_session(mocker: pytest_mock.MockerFixture):
"""Mock create_chat_session to return a fake session."""
from backend.copilot.model import ChatSession
async def _fake_create(user_id: str, *, dry_run: bool):
return ChatSession.new(user_id, dry_run=dry_run)
return mocker.patch(
"backend.api.features.chat.routes.create_chat_session",
new_callable=AsyncMock,
side_effect=_fake_create,
)
def test_create_session_dry_run_true(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Sending ``{"dry_run": true}`` sets metadata.dry_run to True."""
_mock_create_chat_session(mocker)
response = client.post("/sessions", json={"dry_run": True})
assert response.status_code == 200
assert response.json()["metadata"]["dry_run"] is True
def test_create_session_dry_run_default_false(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Empty body defaults dry_run to False."""
_mock_create_chat_session(mocker)
response = client.post("/sessions")
assert response.status_code == 200
assert response.json()["metadata"]["dry_run"] is False
def test_create_session_rejects_nested_metadata(
test_user_id: str,
) -> None:
"""Sending ``{"metadata": {"dry_run": true}}`` must return 422, not silently
default to ``dry_run=False``. This guards against the common mistake of
nesting dry_run inside metadata instead of providing it at the top level."""
response = client.post(
"/sessions",
json={"metadata": {"dry_run": True}},
)
assert response.status_code == 422
class TestStreamChatRequestModeValidation:
"""Pydantic-level validation of the ``mode`` field on StreamChatRequest."""
def test_rejects_invalid_mode_value(self) -> None:
"""Any string outside the Literal set must raise ValidationError."""
from pydantic import ValidationError
from backend.api.features.chat.routes import StreamChatRequest
with pytest.raises(ValidationError):
StreamChatRequest(message="hi", mode="turbo") # type: ignore[arg-type]
def test_accepts_fast_mode(self) -> None:
from backend.api.features.chat.routes import StreamChatRequest
req = StreamChatRequest(message="hi", mode="fast")
assert req.mode == "fast"
def test_accepts_extended_thinking_mode(self) -> None:
from backend.api.features.chat.routes import StreamChatRequest
req = StreamChatRequest(message="hi", mode="extended_thinking")
assert req.mode == "extended_thinking"
def test_accepts_none_mode(self) -> None:
"""``mode=None`` is valid (server decides via feature flags)."""
from backend.api.features.chat.routes import StreamChatRequest
req = StreamChatRequest(message="hi", mode=None)
assert req.mode is None
def test_mode_defaults_to_none_when_omitted(self) -> None:
from backend.api.features.chat.routes import StreamChatRequest
req = StreamChatRequest(message="hi")
assert req.mode is None

View File

@@ -1,13 +0,0 @@
"""Override session-scoped fixtures so unit tests run without the server."""
import pytest
@pytest.fixture(scope="session")
def server():
yield None
@pytest.fixture(scope="session", autouse=True)
def graph_cleanup():
yield

View File

@@ -34,21 +34,16 @@ from backend.data.model import (
HostScopedCredentials,
OAuth2Credentials,
UserIntegrations,
is_sdk_default,
)
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.credentials_store import (
is_system_credential,
provider_matches,
)
from backend.integrations.credentials_store import provider_matches
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
create_mcp_oauth_handler,
)
from backend.integrations.managed_credentials import ensure_managed_credentials
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
@@ -114,7 +109,6 @@ class CredentialsMetaResponse(BaseModel):
default=None,
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
)
is_managed: bool = False
@model_validator(mode="before")
@classmethod
@@ -144,19 +138,6 @@ class CredentialsMetaResponse(BaseModel):
return None
def to_meta_response(cred: Credentials) -> CredentialsMetaResponse:
return CredentialsMetaResponse(
id=cred.id,
provider=cred.provider,
type=cred.type,
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
is_managed=cred.is_managed,
)
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
async def callback(
provider: Annotated[
@@ -223,20 +204,34 @@ async def callback(
f"and provider {provider.value}"
)
return to_meta_response(credentials)
return CredentialsMetaResponse(
id=credentials.id,
provider=credentials.provider,
type=credentials.type,
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=(CredentialsMetaResponse.get_host(credentials)),
)
@router.get("/credentials", summary="List Credentials")
async def list_credentials(
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
# Fire-and-forget: provision missing managed credentials in the background.
# The credential appears on the next page load; listing is never blocked.
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
credentials = await creds_manager.store.get_all_creds(user_id)
return [
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
CredentialsMetaResponse(
id=cred.id,
provider=cred.provider,
type=cred.type,
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
)
for cred in credentials
]
@@ -247,11 +242,19 @@ async def list_credentials_by_provider(
],
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
return [
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
CredentialsMetaResponse(
id=cred.id,
provider=cred.provider,
type=cred.type,
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
)
for cred in credentials
]
@@ -264,21 +267,18 @@ async def get_credential(
],
cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
user_id: Annotated[str, Security(get_user_id)],
) -> CredentialsMetaResponse:
if is_sdk_default(cred_id):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
) -> Credentials:
credential = await creds_manager.get(user_id, cred_id)
if not credential:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if not provider_matches(credential.provider, provider):
if credential.provider != provider:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
status_code=status.HTTP_404_NOT_FOUND,
detail="Credentials do not match the specified provider",
)
return to_meta_response(credential)
return credential
@router.post("/{provider}/credentials", status_code=201, summary="Create Credentials")
@@ -288,22 +288,16 @@ async def create_credentials(
ProviderName, Path(title="The provider to create credentials for")
],
credentials: Credentials,
) -> CredentialsMetaResponse:
if is_sdk_default(credentials.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Cannot create credentials with a reserved ID",
)
) -> Credentials:
credentials.provider = provider
try:
await creds_manager.create(user_id, credentials)
except Exception:
logger.exception("Failed to store credentials")
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to store credentials",
detail=f"Failed to store credentials: {str(e)}",
)
return to_meta_response(credentials)
return credentials
class CredentialsDeletionResponse(BaseModel):
@@ -338,29 +332,15 @@ async def delete_credentials(
bool, Query(title="Whether to proceed if any linked webhooks are still in use")
] = False,
) -> CredentialsDeletionResponse | CredentialsDeletionNeedsConfirmationResponse:
if is_sdk_default(cred_id):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if is_system_credential(cred_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="System-managed credentials cannot be deleted",
)
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
if not creds:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if not provider_matches(creds.provider, provider):
if creds.provider != provider:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Credentials not found",
)
if creds.is_managed:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="AutoGPT-managed credentials cannot be deleted",
detail="Credentials do not match the specified provider",
)
try:

View File

@@ -1,570 +0,0 @@
"""Tests for credentials API security: no secret leakage, SDK defaults filtered."""
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import fastapi
import fastapi.testclient
import pytest
from pydantic import SecretStr
from backend.api.features.integrations.router import router
from backend.data.model import (
APIKeyCredentials,
HostScopedCredentials,
OAuth2Credentials,
UserPasswordCredentials,
)
app = fastapi.FastAPI()
app.include_router(router)
client = fastapi.testclient.TestClient(app)
TEST_USER_ID = "test-user-id"
def _make_api_key_cred(cred_id: str = "cred-123", provider: str = "openai"):
return APIKeyCredentials(
id=cred_id,
provider=provider,
title="My API Key",
api_key=SecretStr("sk-secret-key-value"),
)
def _make_oauth2_cred(cred_id: str = "cred-456", provider: str = "github"):
return OAuth2Credentials(
id=cred_id,
provider=provider,
title="My OAuth",
access_token=SecretStr("ghp_secret_token"),
refresh_token=SecretStr("ghp_refresh_secret"),
scopes=["repo", "user"],
username="testuser",
)
def _make_user_password_cred(cred_id: str = "cred-789", provider: str = "openai"):
return UserPasswordCredentials(
id=cred_id,
provider=provider,
title="My Login",
username=SecretStr("admin"),
password=SecretStr("s3cret-pass"),
)
def _make_host_scoped_cred(cred_id: str = "cred-host", provider: str = "openai"):
return HostScopedCredentials(
id=cred_id,
provider=provider,
title="Host Cred",
host="https://api.example.com",
headers={"Authorization": SecretStr("Bearer top-secret")},
)
def _make_sdk_default_cred(provider: str = "openai"):
return APIKeyCredentials(
id=f"{provider}-default",
provider=provider,
title=f"{provider} (default)",
api_key=SecretStr("sk-platform-secret-key"),
)
@pytest.fixture(autouse=True)
def setup_auth(mock_jwt_user):
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
class TestGetCredentialReturnsMetaOnly:
"""GET /{provider}/credentials/{cred_id} must not return secrets."""
def test_api_key_credential_no_secret(self):
cred = _make_api_key_cred()
with (
patch.object(router, "dependencies", []),
patch("backend.api.features.integrations.router.creds_manager") as mock_mgr,
):
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.get("/openai/credentials/cred-123")
assert resp.status_code == 200
data = resp.json()
assert data["id"] == "cred-123"
assert data["provider"] == "openai"
assert data["type"] == "api_key"
assert "api_key" not in data
assert "sk-secret-key-value" not in str(data)
def test_oauth2_credential_no_secret(self):
cred = _make_oauth2_cred()
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.get("/github/credentials/cred-456")
assert resp.status_code == 200
data = resp.json()
assert data["id"] == "cred-456"
assert data["scopes"] == ["repo", "user"]
assert data["username"] == "testuser"
assert "access_token" not in data
assert "refresh_token" not in data
assert "ghp_" not in str(data)
def test_user_password_credential_no_secret(self):
cred = _make_user_password_cred()
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.get("/openai/credentials/cred-789")
assert resp.status_code == 200
data = resp.json()
assert data["id"] == "cred-789"
assert "password" not in data
assert "username" not in data or data["username"] is None
assert "s3cret-pass" not in str(data)
assert "admin" not in str(data)
def test_host_scoped_credential_no_secret(self):
cred = _make_host_scoped_cred()
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.get("/openai/credentials/cred-host")
assert resp.status_code == 200
data = resp.json()
assert data["id"] == "cred-host"
assert data["host"] == "https://api.example.com"
assert "headers" not in data
assert "top-secret" not in str(data)
def test_get_credential_wrong_provider_returns_404(self):
"""Provider mismatch should return generic 404, not leak credential existence."""
cred = _make_api_key_cred(provider="openai")
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.get("/github/credentials/cred-123")
assert resp.status_code == 404
assert resp.json()["detail"] == "Credentials not found"
def test_list_credentials_no_secrets(self):
"""List endpoint must not leak secrets in any credential."""
creds = [_make_api_key_cred(), _make_oauth2_cred()]
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_all_creds = AsyncMock(return_value=creds)
resp = client.get("/credentials")
assert resp.status_code == 200
raw = str(resp.json())
assert "sk-secret-key-value" not in raw
assert "ghp_secret_token" not in raw
assert "ghp_refresh_secret" not in raw
class TestSdkDefaultCredentialsNotAccessible:
"""SDK default credentials (ID ending in '-default') must be hidden."""
def test_get_sdk_default_returns_404(self):
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock()
resp = client.get("/openai/credentials/openai-default")
assert resp.status_code == 404
mock_mgr.get.assert_not_called()
def test_list_credentials_excludes_sdk_defaults(self):
user_cred = _make_api_key_cred()
sdk_cred = _make_sdk_default_cred("openai")
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_all_creds = AsyncMock(return_value=[user_cred, sdk_cred])
resp = client.get("/credentials")
assert resp.status_code == 200
data = resp.json()
ids = [c["id"] for c in data]
assert "cred-123" in ids
assert "openai-default" not in ids
def test_list_by_provider_excludes_sdk_defaults(self):
user_cred = _make_api_key_cred()
sdk_cred = _make_sdk_default_cred("openai")
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_creds_by_provider = AsyncMock(
return_value=[user_cred, sdk_cred]
)
resp = client.get("/openai/credentials")
assert resp.status_code == 200
data = resp.json()
ids = [c["id"] for c in data]
assert "cred-123" in ids
assert "openai-default" not in ids
def test_delete_sdk_default_returns_404(self):
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_creds_by_id = AsyncMock()
resp = client.request("DELETE", "/openai/credentials/openai-default")
assert resp.status_code == 404
mock_mgr.store.get_creds_by_id.assert_not_called()
class TestCreateCredentialNoSecretInResponse:
"""POST /{provider}/credentials must not return secrets."""
def test_create_api_key_no_secret_in_response(self):
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.create = AsyncMock()
resp = client.post(
"/openai/credentials",
json={
"id": "new-cred",
"provider": "openai",
"type": "api_key",
"title": "New Key",
"api_key": "sk-newsecret",
},
)
assert resp.status_code == 201
data = resp.json()
assert data["id"] == "new-cred"
assert "api_key" not in data
assert "sk-newsecret" not in str(data)
def test_create_with_sdk_default_id_rejected(self):
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.create = AsyncMock()
resp = client.post(
"/openai/credentials",
json={
"id": "openai-default",
"provider": "openai",
"type": "api_key",
"title": "Sneaky",
"api_key": "sk-evil",
},
)
assert resp.status_code == 403
mock_mgr.create.assert_not_called()
class TestManagedCredentials:
"""AutoGPT-managed credentials cannot be deleted by users."""
def test_delete_is_managed_returns_403(self):
cred = APIKeyCredentials(
id="managed-cred-1",
provider="agent_mail",
title="AgentMail (managed by AutoGPT)",
api_key=SecretStr("sk-managed-key"),
is_managed=True,
)
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_creds_by_id = AsyncMock(return_value=cred)
resp = client.request("DELETE", "/agent_mail/credentials/managed-cred-1")
assert resp.status_code == 403
assert "AutoGPT-managed" in resp.json()["detail"]
def test_list_credentials_includes_is_managed_field(self):
managed = APIKeyCredentials(
id="managed-1",
provider="agent_mail",
title="AgentMail (managed)",
api_key=SecretStr("sk-key"),
is_managed=True,
)
regular = APIKeyCredentials(
id="regular-1",
provider="openai",
title="My Key",
api_key=SecretStr("sk-key"),
)
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_all_creds = AsyncMock(return_value=[managed, regular])
resp = client.get("/credentials")
assert resp.status_code == 200
data = resp.json()
managed_cred = next(c for c in data if c["id"] == "managed-1")
regular_cred = next(c for c in data if c["id"] == "regular-1")
assert managed_cred["is_managed"] is True
assert regular_cred["is_managed"] is False
# ---------------------------------------------------------------------------
# Managed credential provisioning infrastructure
# ---------------------------------------------------------------------------
def _make_managed_cred(
provider: str = "agent_mail", pod_id: str = "pod-abc"
) -> APIKeyCredentials:
return APIKeyCredentials(
id="managed-auto",
provider=provider,
title="AgentMail (managed by AutoGPT)",
api_key=SecretStr("sk-pod-key"),
is_managed=True,
metadata={"pod_id": pod_id},
)
def _make_store_mock(**kwargs) -> MagicMock:
"""Create a store mock with a working async ``locks()`` context manager."""
@asynccontextmanager
async def _noop_locked(key):
yield
locks_obj = MagicMock()
locks_obj.locked = _noop_locked
store = MagicMock(**kwargs)
store.locks = AsyncMock(return_value=locks_obj)
return store
class TestEnsureManagedCredentials:
"""Unit tests for the ensure/cleanup helpers in managed_credentials.py."""
@pytest.mark.asyncio
async def test_provisions_when_missing(self):
"""Provider.provision() is called when no managed credential exists."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock(return_value=cred)
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=False)
store.add_managed_credential = AsyncMock()
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_awaited_once_with("user-1")
store.add_managed_credential.assert_awaited_once_with("user-1", cred)
@pytest.mark.asyncio
async def test_skips_when_already_exists(self):
"""Provider.provision() is NOT called when managed credential exists."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock()
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=True)
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_not_awaited()
@pytest.mark.asyncio
async def test_skips_when_unavailable(self):
"""Provider.provision() is NOT called when provider is not available."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=False)
provider.provision = AsyncMock()
store = _make_store_mock()
store.has_managed_credential = AsyncMock()
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_not_awaited()
store.has_managed_credential.assert_not_awaited()
@pytest.mark.asyncio
async def test_provision_failure_does_not_propagate(self):
"""A failed provision is logged but does not raise."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock(side_effect=RuntimeError("boom"))
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=False)
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
# No exception raised — provisioning failure is swallowed.
class TestCleanupManagedCredentials:
"""Unit tests for cleanup_managed_credentials."""
@pytest.mark.asyncio
async def test_calls_deprovision_for_managed_creds(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "agent_mail"
provider.deprovision = AsyncMock()
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[cred])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["agent_mail"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
provider.deprovision.assert_awaited_once_with("user-1", cred)
@pytest.mark.asyncio
async def test_skips_non_managed_creds(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
regular = _make_api_key_cred()
provider = MagicMock()
provider.provider_name = "openai"
provider.deprovision = AsyncMock()
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[regular])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["openai"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
provider.deprovision.assert_not_awaited()
@pytest.mark.asyncio
async def test_deprovision_failure_does_not_propagate(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "agent_mail"
provider.deprovision = AsyncMock(side_effect=RuntimeError("boom"))
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[cred])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["agent_mail"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
# No exception raised — cleanup failure is swallowed.

View File

@@ -1,120 +0,0 @@
"""Shared logic for adding store agents to a user's library.
Both `add_store_agent_to_library` and `add_store_agent_to_library_as_admin`
delegate to these helpers so the duplication-prone create/restore/dedup
logic lives in exactly one place.
"""
import logging
import prisma.errors
import prisma.models
import backend.api.features.library.model as library_model
import backend.data.graph as graph_db
from backend.data.graph import GraphModel, GraphSettings
from backend.data.includes import library_agent_include
from backend.util.exceptions import NotFoundError
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
async def resolve_graph_for_library(
store_listing_version_id: str,
user_id: str,
*,
admin: bool,
) -> GraphModel:
"""Look up a StoreListingVersion and resolve its graph.
When ``admin=True``, uses ``get_graph_as_admin`` to bypass the marketplace
APPROVED-only check. Otherwise uses the regular ``get_graph``.
"""
slv = await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}, include={"AgentGraph": True}
)
if not slv or not slv.AgentGraph:
raise NotFoundError(
f"Store listing version {store_listing_version_id} not found or invalid"
)
ag = slv.AgentGraph
if admin:
graph_model = await graph_db.get_graph_as_admin(
graph_id=ag.id, version=ag.version, user_id=user_id
)
else:
graph_model = await graph_db.get_graph(
graph_id=ag.id, version=ag.version, user_id=user_id
)
if not graph_model:
raise NotFoundError(f"Graph #{ag.id} v{ag.version} not found or accessible")
return graph_model
async def add_graph_to_library(
store_listing_version_id: str,
graph_model: GraphModel,
user_id: str,
) -> library_model.LibraryAgent:
"""Check existing / restore soft-deleted / create new LibraryAgent.
Uses a create-then-catch-UniqueViolationError-then-update pattern on
the (userId, agentGraphId, agentGraphVersion) composite unique constraint.
This is more robust than ``upsert`` because Prisma's upsert atomicity
guarantees are not well-documented for all versions.
"""
settings_json = SafeJson(GraphSettings.from_graph(graph_model).model_dump())
_include = library_agent_include(
user_id, include_nodes=False, include_executions=False
)
try:
added_agent = await prisma.models.LibraryAgent.prisma().create(
data={
"User": {"connect": {"id": user_id}},
"AgentGraph": {
"connect": {
"graphVersionId": {
"id": graph_model.id,
"version": graph_model.version,
}
}
},
"isCreatedByUser": False,
"useGraphIsActiveVersion": False,
"settings": settings_json,
},
include=_include,
)
except prisma.errors.UniqueViolationError:
# Already exists — update to restore if previously soft-deleted/archived
added_agent = await prisma.models.LibraryAgent.prisma().update(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": user_id,
"agentGraphId": graph_model.id,
"agentGraphVersion": graph_model.version,
}
},
data={
"isDeleted": False,
"isArchived": False,
"settings": settings_json,
},
include=_include,
)
if added_agent is None:
raise NotFoundError(
f"LibraryAgent for graph #{graph_model.id} "
f"v{graph_model.version} not found after UniqueViolationError"
)
logger.debug(
f"Added graph #{graph_model.id} v{graph_model.version} "
f"for store listing version #{store_listing_version_id} "
f"to library for user #{user_id}"
)
return library_model.LibraryAgent.from_db(added_agent)

View File

@@ -1,80 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch
import prisma.errors
import pytest
from ._add_to_library import add_graph_to_library
@pytest.mark.asyncio
async def test_add_graph_to_library_create_new_agent() -> None:
"""When no matching LibraryAgent exists, create inserts a new one."""
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
created_agent = MagicMock(name="CreatedLibraryAgent")
converted_agent = MagicMock(name="ConvertedLibraryAgent")
with (
patch(
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
) as mock_prisma,
patch(
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
return_value=converted_agent,
) as mock_from_db,
):
mock_prisma.return_value.create = AsyncMock(return_value=created_agent)
result = await add_graph_to_library("slv-id", graph_model, "user-id")
assert result is converted_agent
mock_from_db.assert_called_once_with(created_agent)
# Verify create was called with correct data
create_call = mock_prisma.return_value.create.call_args
create_data = create_call.kwargs["data"]
assert create_data["User"] == {"connect": {"id": "user-id"}}
assert create_data["AgentGraph"] == {
"connect": {"graphVersionId": {"id": "graph-id", "version": 2}}
}
assert create_data["isCreatedByUser"] is False
assert create_data["useGraphIsActiveVersion"] is False
@pytest.mark.asyncio
async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
"""UniqueViolationError on create falls back to update."""
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
updated_agent = MagicMock(name="UpdatedLibraryAgent")
converted_agent = MagicMock(name="ConvertedLibraryAgent")
with (
patch(
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
) as mock_prisma,
patch(
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
return_value=converted_agent,
) as mock_from_db,
):
mock_prisma.return_value.create = AsyncMock(
side_effect=prisma.errors.UniqueViolationError(
MagicMock(), message="unique constraint"
)
)
mock_prisma.return_value.update = AsyncMock(return_value=updated_agent)
result = await add_graph_to_library("slv-id", graph_model, "user-id")
assert result is converted_agent
mock_from_db.assert_called_once_with(updated_agent)
# Verify update was called with correct where and data
update_call = mock_prisma.return_value.update.call_args
assert update_call.kwargs["where"] == {
"userId_agentGraphId_agentGraphVersion": {
"userId": "user-id",
"agentGraphId": "graph-id",
"agentGraphVersion": 2,
}
}
update_data = update_call.kwargs["data"]
assert update_data["isDeleted"] is False
assert update_data["isArchived"] is False

View File

@@ -336,15 +336,12 @@ async def get_library_agent_by_graph_id(
user_id: str,
graph_id: str,
graph_version: Optional[int] = None,
include_archived: bool = False,
) -> library_model.LibraryAgent | None:
filter: prisma.types.LibraryAgentWhereInput = {
"agentGraphId": graph_id,
"userId": user_id,
"isDeleted": False,
}
if not include_archived:
filter["isArchived"] = False
if graph_version is not None:
filter["agentGraphVersion"] = graph_version
@@ -436,58 +433,32 @@ async def create_library_agent(
async with transaction() as tx:
library_agents = await asyncio.gather(
*(
prisma.models.LibraryAgent.prisma(tx).upsert(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": user_id,
"agentGraphId": graph_entry.id,
"agentGraphVersion": graph_entry.version,
}
},
data={
"create": prisma.types.LibraryAgentCreateInput(
isCreatedByUser=(user_id == graph.user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
AgentGraph={
"connect": {
"graphVersionId": {
"id": graph_entry.id,
"version": graph_entry.version,
}
prisma.models.LibraryAgent.prisma(tx).create(
data=prisma.types.LibraryAgentCreateInput(
isCreatedByUser=(user_id == user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
AgentGraph={
"connect": {
"graphVersionId": {
"id": graph_entry.id,
"version": graph_entry.version,
}
},
settings=SafeJson(
GraphSettings.from_graph(
graph_entry,
hitl_safe_mode=hitl_safe_mode,
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
**(
{"Folder": {"connect": {"id": folder_id}}}
if folder_id and graph_entry is graph
else {}
),
),
"update": {
"isDeleted": False,
"isArchived": False,
"useGraphIsActiveVersion": True,
"settings": SafeJson(
GraphSettings.from_graph(
graph_entry,
hitl_safe_mode=hitl_safe_mode,
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
**(
{"Folder": {"connect": {"id": folder_id}}}
if folder_id and graph_entry is graph
else {}
),
}
},
},
settings=SafeJson(
GraphSettings.from_graph(
graph_entry,
hitl_safe_mode=hitl_safe_mode,
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
**(
{"Folder": {"connect": {"id": folder_id}}}
if folder_id and graph_entry is graph
else {}
),
),
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
),
@@ -611,9 +582,7 @@ async def update_graph_in_library(
created_graph = await graph_db.create_graph(graph_model, user_id)
library_agent = await get_library_agent_by_graph_id(
user_id, created_graph.id, include_archived=True
)
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
if not library_agent:
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
@@ -849,38 +818,92 @@ async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
async def add_store_agent_to_library(
store_listing_version_id: str, user_id: str
) -> library_model.LibraryAgent:
"""Adds a marketplace agent to the users library.
See also: `add_store_agent_to_library_as_admin()` which uses
`get_graph_as_admin` to bypass marketplace status checks for admin review.
"""
from ._add_to_library import add_graph_to_library, resolve_graph_for_library
Adds an agent from a store listing version to the user's library if they don't already have it.
Args:
store_listing_version_id: The ID of the store listing version containing the agent.
user_id: The users library to which the agent is being added.
Returns:
The newly created LibraryAgent if successfully added, the existing corresponding one if any.
Raises:
NotFoundError: If the store listing or associated agent is not found.
DatabaseError: If there's an issue creating the LibraryAgent record.
"""
logger.debug(
f"Adding agent from store listing version #{store_listing_version_id} "
f"to library for user #{user_id}"
)
graph_model = await resolve_graph_for_library(
store_listing_version_id, user_id, admin=False
)
return await add_graph_to_library(store_listing_version_id, graph_model, user_id)
async def add_store_agent_to_library_as_admin(
store_listing_version_id: str, user_id: str
) -> library_model.LibraryAgent:
"""Admin variant that uses `get_graph_as_admin` to bypass marketplace
APPROVED-only checks, allowing admins to add pending agents for review."""
from ._add_to_library import add_graph_to_library, resolve_graph_for_library
logger.warning(
f"ADMIN adding agent from store listing version "
f"#{store_listing_version_id} to library for user #{user_id}"
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}, include={"AgentGraph": True}
)
)
graph_model = await resolve_graph_for_library(
store_listing_version_id, user_id, admin=True
if not store_listing_version or not store_listing_version.AgentGraph:
logger.warning(f"Store listing version not found: {store_listing_version_id}")
raise NotFoundError(
f"Store listing version {store_listing_version_id} not found or invalid"
)
graph = store_listing_version.AgentGraph
# Convert to GraphModel to check for HITL blocks
graph_model = await graph_db.get_graph(
graph_id=graph.id,
version=graph.version,
user_id=user_id,
include_subgraphs=False,
)
return await add_graph_to_library(store_listing_version_id, graph_model, user_id)
if not graph_model:
raise NotFoundError(
f"Graph #{graph.id} v{graph.version} not found or accessible"
)
# Check if user already has this agent (non-deleted)
if existing := await get_library_agent_by_graph_id(
user_id, graph.id, graph.version
):
return existing
# Check for soft-deleted version and restore it
deleted_agent = await prisma.models.LibraryAgent.prisma().find_unique(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": user_id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
}
},
)
if deleted_agent and deleted_agent.isDeleted:
return await update_library_agent(deleted_agent.id, user_id, is_deleted=False)
# Create LibraryAgent entry
added_agent = await prisma.models.LibraryAgent.prisma().create(
data={
"User": {"connect": {"id": user_id}},
"AgentGraph": {
"connect": {
"graphVersionId": {"id": graph.id, "version": graph.version}
}
},
"isCreatedByUser": False,
"useGraphIsActiveVersion": False,
"settings": SafeJson(GraphSettings.from_graph(graph_model).model_dump()),
},
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
),
)
logger.debug(
f"Added graph #{graph.id} v{graph.version}"
f"for store listing version #{store_listing_version.id} "
f"to library for user #{user_id}"
)
return library_model.LibraryAgent.from_db(added_agent)
##############################################

View File

@@ -1,6 +1,4 @@
from contextlib import asynccontextmanager
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import prisma.enums
import prisma.models
@@ -87,6 +85,10 @@ async def test_get_library_agents(mocker):
async def test_add_agent_to_library(mocker):
await connect()
# Mock the transaction context
mock_transaction = mocker.patch("backend.api.features.library.db.transaction")
mock_transaction.return_value.__aenter__ = mocker.AsyncMock(return_value=None)
mock_transaction.return_value.__aexit__ = mocker.AsyncMock(return_value=None)
# Mock data
mock_store_listing_data = prisma.models.StoreListingVersion(
id="version123",
@@ -141,18 +143,15 @@ async def test_add_agent_to_library(mocker):
)
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.create = mocker.AsyncMock(
return_value=mock_library_agent_data
)
# Mock graph_db.get_graph function that's called in resolve_graph_for_library
# (lives in _add_to_library.py after refactor, not db.py)
mock_graph_db = mocker.patch(
"backend.api.features.library._add_to_library.graph_db"
)
# Mock graph_db.get_graph function that's called to check for HITL blocks
mock_graph_db = mocker.patch("backend.api.features.library.db.graph_db")
mock_graph_model = mocker.Mock()
mock_graph_model.id = "agent1"
mock_graph_model.version = 1
mock_graph_model.nodes = (
[]
) # Empty list so _has_human_in_the_loop_blocks returns False
@@ -171,27 +170,37 @@ async def test_add_agent_to_library(mocker):
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
where={"id": "version123"}, include={"AgentGraph": True}
)
mock_library_agent.return_value.find_unique.assert_called_once_with(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": "test-user",
"agentGraphId": "agent1",
"agentGraphVersion": 1,
}
},
)
# Check that create was called with the expected data including settings
create_call_args = mock_library_agent.return_value.create.call_args
assert create_call_args is not None
# Verify the create data structure
create_data = create_call_args.kwargs["data"]
expected_create = {
# Verify the main structure
expected_data = {
"User": {"connect": {"id": "test-user"}},
"AgentGraph": {"connect": {"graphVersionId": {"id": "agent1", "version": 1}}},
"isCreatedByUser": False,
"useGraphIsActiveVersion": False,
}
for key, value in expected_create.items():
assert create_data[key] == value
actual_data = create_call_args[1]["data"]
# Check that all expected fields are present
for key, value in expected_data.items():
assert actual_data[key] == value
# Check that settings field is present and is a SafeJson object
assert "settings" in create_data
assert hasattr(create_data["settings"], "__class__") # Should be a SafeJson object
assert "settings" in actual_data
assert hasattr(actual_data["settings"], "__class__") # Should be a SafeJson object
# Check include parameter
assert create_call_args.kwargs["include"] == library_agent_include(
assert create_call_args[1]["include"] == library_agent_include(
"test-user", include_nodes=False, include_executions=False
)
@@ -215,141 +224,3 @@ async def test_add_agent_to_library_not_found(mocker):
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
where={"id": "version123"}, include={"AgentGraph": True}
)
@pytest.mark.asyncio
async def test_get_library_agent_by_graph_id_excludes_archived(mocker):
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
result = await db.get_library_agent_by_graph_id("test-user", "agent1", 7)
assert result is None
mock_library_agent.return_value.find_first.assert_called_once()
where = mock_library_agent.return_value.find_first.call_args.kwargs["where"]
assert where == {
"agentGraphId": "agent1",
"userId": "test-user",
"isDeleted": False,
"isArchived": False,
"agentGraphVersion": 7,
}
@pytest.mark.asyncio
async def test_get_library_agent_by_graph_id_can_include_archived(mocker):
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
result = await db.get_library_agent_by_graph_id(
"test-user",
"agent1",
7,
include_archived=True,
)
assert result is None
mock_library_agent.return_value.find_first.assert_called_once()
where = mock_library_agent.return_value.find_first.call_args.kwargs["where"]
assert where == {
"agentGraphId": "agent1",
"userId": "test-user",
"isDeleted": False,
"agentGraphVersion": 7,
}
@pytest.mark.asyncio
async def test_update_graph_in_library_allows_archived_library_agent(mocker):
graph = mocker.Mock(id="graph-id")
existing_version = mocker.Mock(version=1, is_active=True)
graph_model = mocker.Mock()
created_graph = mocker.Mock(id="graph-id", version=2, is_active=False)
current_library_agent = mocker.Mock()
updated_library_agent = mocker.Mock()
mocker.patch(
"backend.api.features.library.db.graph_db.get_graph_all_versions",
new=mocker.AsyncMock(return_value=[existing_version]),
)
mocker.patch(
"backend.api.features.library.db.graph_db.make_graph_model",
return_value=graph_model,
)
mocker.patch(
"backend.api.features.library.db.graph_db.create_graph",
new=mocker.AsyncMock(return_value=created_graph),
)
mock_get_library_agent = mocker.patch(
"backend.api.features.library.db.get_library_agent_by_graph_id",
new=mocker.AsyncMock(return_value=current_library_agent),
)
mock_update_library_agent = mocker.patch(
"backend.api.features.library.db.update_library_agent_version_and_settings",
new=mocker.AsyncMock(return_value=updated_library_agent),
)
result_graph, result_library_agent = await db.update_graph_in_library(
graph,
"test-user",
)
assert result_graph is created_graph
assert result_library_agent is updated_library_agent
assert graph.version == 2
graph_model.reassign_ids.assert_called_once_with(
user_id="test-user", reassign_graph_id=False
)
mock_get_library_agent.assert_awaited_once_with(
"test-user",
"graph-id",
include_archived=True,
)
mock_update_library_agent.assert_awaited_once_with("test-user", created_graph)
@pytest.mark.asyncio
async def test_create_library_agent_uses_upsert():
"""create_library_agent should use upsert (not create) to handle duplicates."""
mock_graph = MagicMock()
mock_graph.id = "graph-1"
mock_graph.version = 1
mock_graph.user_id = "user-1"
mock_graph.nodes = []
mock_graph.sub_graphs = []
mock_upserted = MagicMock(name="UpsertedLibraryAgent")
@asynccontextmanager
async def fake_tx():
yield None
with (
patch("backend.api.features.library.db.transaction", fake_tx),
patch("prisma.models.LibraryAgent.prisma") as mock_prisma,
patch(
"backend.api.features.library.db.add_generated_agent_image",
new=AsyncMock(),
),
patch(
"backend.api.features.library.model.LibraryAgent.from_db",
return_value=MagicMock(),
),
):
mock_prisma.return_value.upsert = AsyncMock(return_value=mock_upserted)
result = await db.create_library_agent(mock_graph, "user-1")
assert len(result) == 1
upsert_call = mock_prisma.return_value.upsert.call_args
assert upsert_call is not None
# Verify the upsert where clause uses the composite unique key
where = upsert_call.kwargs["where"]
assert "userId_agentGraphId_agentGraphVersion" in where
# Verify the upsert data has both create and update branches
data = upsert_call.kwargs["data"]
assert "create" in data
assert "update" in data
# Verify update branch restores soft-deleted/archived agents
assert data["update"]["isDeleted"] is False
assert data["update"]["isArchived"] is False

View File

@@ -12,7 +12,6 @@ Tests cover:
5. Complete OAuth flow end-to-end
"""
import asyncio
import base64
import hashlib
import secrets
@@ -59,27 +58,14 @@ async def test_user(server, test_user_id: str):
yield test_user_id
# Cleanup - delete in correct order due to foreign key constraints.
# Wrap in try/except because the event loop or Prisma engine may already
# be closed during session teardown on Python 3.12+.
try:
await asyncio.gather(
PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id}),
PrismaOAuthRefreshToken.prisma().delete_many(
where={"userId": test_user_id}
),
PrismaOAuthAuthorizationCode.prisma().delete_many(
where={"userId": test_user_id}
),
)
await asyncio.gather(
PrismaOAuthApplication.prisma().delete_many(
where={"ownerId": test_user_id}
),
PrismaUser.prisma().delete(where={"id": test_user_id}),
)
except RuntimeError:
pass
# Cleanup - delete in correct order due to foreign key constraints
await PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id})
await PrismaOAuthRefreshToken.prisma().delete_many(where={"userId": test_user_id})
await PrismaOAuthAuthorizationCode.prisma().delete_many(
where={"userId": test_user_id}
)
await PrismaOAuthApplication.prisma().delete_many(where={"ownerId": test_user_id})
await PrismaUser.prisma().delete(where={"id": test_user_id})
@pytest_asyncio.fixture

View File

@@ -1,61 +0,0 @@
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
from backend.api.features.v1 import v1_router
app = fastapi.FastAPI()
app.include_router(v1_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def test_onboarding_profile_success(mocker):
mock_extract = mocker.patch(
"backend.api.features.v1.extract_business_understanding",
new_callable=AsyncMock,
)
mock_upsert = mocker.patch(
"backend.api.features.v1.upsert_business_understanding",
new_callable=AsyncMock,
)
from backend.data.understanding import BusinessUnderstandingInput
mock_extract.return_value = BusinessUnderstandingInput.model_construct(
user_name="John",
user_role="Founder/CEO",
pain_points=["Finding leads"],
suggested_prompts={"Learn": ["How do I automate lead gen?"]},
)
mock_upsert.return_value = AsyncMock()
response = client.post(
"/onboarding/profile",
json={
"user_name": "John",
"user_role": "Founder/CEO",
"pain_points": ["Finding leads", "Email & outreach"],
},
)
assert response.status_code == 200
mock_extract.assert_awaited_once()
mock_upsert.assert_awaited_once()
def test_onboarding_profile_missing_fields():
response = client.post(
"/onboarding/profile",
json={"user_name": "John"},
)
assert response.status_code == 422

View File

@@ -391,11 +391,6 @@ async def get_available_graph(
async def get_store_agent_by_version_id(
store_listing_version_id: str,
) -> store_model.StoreAgentDetails:
"""Get agent details from the StoreAgent view (APPROVED agents only).
See also: `get_store_agent_details_as_admin()` which bypasses the
APPROVED-only StoreAgent view for admin preview of pending submissions.
"""
logger.debug(f"Getting store agent details for {store_listing_version_id}")
try:
@@ -416,57 +411,6 @@ async def get_store_agent_by_version_id(
raise DatabaseError("Failed to fetch agent details") from e
async def get_store_agent_details_as_admin(
store_listing_version_id: str,
) -> store_model.StoreAgentDetails:
"""Get agent details for admin preview, bypassing the APPROVED-only
StoreAgent view. Queries StoreListingVersion directly so pending
submissions are visible."""
slv = await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id},
include={
"StoreListing": {"include": {"CreatorProfile": True}},
},
)
if not slv or not slv.StoreListing:
raise NotFoundError(
f"Store listing version {store_listing_version_id} not found"
)
listing = slv.StoreListing
# CreatorProfile is a required FK relation — should always exist.
# If it's None, the DB is in a bad state.
profile = listing.CreatorProfile
if not profile:
raise DatabaseError(
f"StoreListing {listing.id} has no CreatorProfile — FK violated"
)
return store_model.StoreAgentDetails(
store_listing_version_id=slv.id,
slug=listing.slug,
agent_name=slv.name,
agent_video=slv.videoUrl or "",
agent_output_demo=slv.agentOutputDemoUrl or "",
agent_image=slv.imageUrls,
creator=profile.username,
creator_avatar=profile.avatarUrl or "",
sub_heading=slv.subHeading,
description=slv.description,
instructions=slv.instructions,
categories=slv.categories,
runs=0,
rating=0.0,
versions=[str(slv.version)],
graph_id=slv.agentGraphId,
graph_versions=[str(slv.agentGraphVersion)],
last_updated=slv.updatedAt,
recommended_schedule_cron=slv.recommendedScheduleCron,
active_version_id=listing.activeVersionId or slv.id,
has_approved_version=listing.hasApprovedVersion,
)
class StoreCreatorsSortOptions(Enum):
# NOTE: values correspond 1:1 to columns of the Creator view
AGENT_RATING = "agent_rating"

View File

@@ -189,7 +189,6 @@ async def test_create_store_submission(mocker):
notifyOnAgentApproved=True,
notifyOnAgentRejected=True,
timezone="Europe/Delft",
subscriptionTier=prisma.enums.SubscriptionTier.FREE, # type: ignore[reportCallIssue,reportAttributeAccessIssue]
)
mock_agent = prisma.models.AgentGraph(
id="agent-id",

View File

@@ -63,17 +63,12 @@ from backend.data.onboarding import (
UserOnboardingUpdate,
complete_onboarding_step,
complete_re_run_agent,
format_onboarding_for_extraction,
get_recommended_agents,
get_user_onboarding,
onboarding_enabled,
reset_user_onboarding,
update_user_onboarding,
)
from backend.data.tally import extract_business_understanding
from backend.data.understanding import (
BusinessUnderstandingInput,
upsert_business_understanding,
)
from backend.data.user import (
get_or_create_user,
get_user_by_id,
@@ -287,33 +282,35 @@ async def get_onboarding_agents(
return await get_recommended_agents(user_id)
class OnboardingProfileRequest(pydantic.BaseModel):
"""Request body for onboarding profile submission."""
user_name: str = pydantic.Field(min_length=1, max_length=100)
user_role: str = pydantic.Field(min_length=1, max_length=100)
pain_points: list[str] = pydantic.Field(default_factory=list, max_length=20)
class OnboardingStatusResponse(pydantic.BaseModel):
"""Response for onboarding completion check."""
"""Response for onboarding status check."""
is_completed: bool
is_onboarding_enabled: bool
is_chat_enabled: bool
@v1_router.get(
"/onboarding/completed",
summary="Check if onboarding is completed",
"/onboarding/enabled",
summary="Is onboarding enabled",
tags=["onboarding", "public"],
response_model=OnboardingStatusResponse,
dependencies=[Security(requires_user)],
)
async def is_onboarding_completed(
async def is_onboarding_enabled(
user_id: Annotated[str, Security(get_user_id)],
) -> OnboardingStatusResponse:
user_onboarding = await get_user_onboarding(user_id)
# Check if chat is enabled for user
is_chat_enabled = await is_feature_enabled(Flag.CHAT, user_id, False)
# If chat is enabled, skip legacy onboarding
if is_chat_enabled:
return OnboardingStatusResponse(
is_onboarding_enabled=False,
is_chat_enabled=True,
)
return OnboardingStatusResponse(
is_completed=OnboardingStep.VISIT_COPILOT in user_onboarding.completedSteps,
is_onboarding_enabled=await onboarding_enabled(),
is_chat_enabled=False,
)
@@ -328,38 +325,6 @@ async def reset_onboarding(user_id: Annotated[str, Security(get_user_id)]):
return await reset_user_onboarding(user_id)
@v1_router.post(
"/onboarding/profile",
summary="Submit onboarding profile",
tags=["onboarding"],
dependencies=[Security(requires_user)],
)
async def submit_onboarding_profile(
data: OnboardingProfileRequest,
user_id: Annotated[str, Security(get_user_id)],
):
formatted = format_onboarding_for_extraction(
user_name=data.user_name,
user_role=data.user_role,
pain_points=data.pain_points,
)
try:
understanding_input = await extract_business_understanding(formatted)
except Exception:
understanding_input = BusinessUnderstandingInput.model_construct()
# Ensure the direct fields are set even if LLM missed them
understanding_input.user_name = data.user_name
understanding_input.user_role = data.user_role
if not understanding_input.pain_points:
understanding_input.pain_points = data.pain_points
await upsert_business_understanding(user_id, understanding_input)
return {"status": "ok"}
########################################################
##################### Blocks ###########################
########################################################
@@ -627,11 +592,6 @@ async def fulfill_checkout(user_id: Annotated[str, Security(get_user_id)]):
async def configure_user_auto_top_up(
request: AutoTopUpConfig, user_id: Annotated[str, Security(get_user_id)]
) -> str:
"""Configure auto top-up settings and perform an immediate top-up if needed.
Raises HTTPException(422) if the request parameters are invalid or if
the credit top-up fails.
"""
if request.threshold < 0:
raise HTTPException(status_code=422, detail="Threshold must be greater than 0")
if request.amount < 500 and request.amount != 0:
@@ -646,20 +606,10 @@ async def configure_user_auto_top_up(
user_credit_model = await get_user_credit_model(user_id)
current_balance = await user_credit_model.get_credits(user_id)
try:
if current_balance < request.threshold:
await user_credit_model.top_up_credits(user_id, request.amount)
else:
await user_credit_model.top_up_credits(user_id, 0)
except ValueError as e:
known_messages = (
"must not be negative",
"already exists for user",
"No payment method found",
)
if any(msg in str(e) for msg in known_messages):
raise HTTPException(status_code=422, detail=str(e))
raise
if current_balance < request.threshold:
await user_credit_model.top_up_credits(user_id, request.amount)
else:
await user_credit_model.top_up_credits(user_id, 0)
await set_auto_top_up(
user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
@@ -1015,16 +965,14 @@ async def execute_graph(
source: Annotated[GraphExecutionSource | None, Body(embed=True)] = None,
graph_version: Optional[int] = None,
preset_id: Optional[str] = None,
dry_run: Annotated[bool, Body(embed=True)] = False,
) -> execution_db.GraphExecutionMeta:
if not dry_run:
user_credit_model = await get_user_credit_model(user_id)
current_balance = await user_credit_model.get_credits(user_id)
if current_balance <= 0:
raise HTTPException(
status_code=402,
detail="Insufficient balance to execute the agent. Please top up your account.",
)
user_credit_model = await get_user_credit_model(user_id)
current_balance = await user_credit_model.get_credits(user_id)
if current_balance <= 0:
raise HTTPException(
status_code=402,
detail="Insufficient balance to execute the agent. Please top up your account.",
)
try:
result = await execution_utils.add_graph_execution(
@@ -1034,7 +982,6 @@ async def execute_graph(
preset_id=preset_id,
graph_version=graph_version,
graph_credentials_inputs=credentials_inputs,
dry_run=dry_run,
)
# Record successful graph execution
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)

View File

@@ -12,7 +12,7 @@ import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi import Query, UploadFile
from fastapi.responses import Response
from pydantic import BaseModel, Field
from pydantic import BaseModel
from backend.data.workspace import (
WorkspaceFile,
@@ -131,26 +131,9 @@ class StorageUsageResponse(BaseModel):
file_count: int
class WorkspaceFileItem(BaseModel):
id: str
name: str
path: str
mime_type: str
size_bytes: int
metadata: dict = Field(default_factory=dict)
created_at: str
class ListFilesResponse(BaseModel):
files: list[WorkspaceFileItem]
offset: int = 0
has_more: bool = False
@router.get(
"/files/{file_id}/download",
summary="Download file by ID",
operation_id="getWorkspaceDownloadFileById",
)
async def download_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -175,7 +158,6 @@ async def download_file(
@router.delete(
"/files/{file_id}",
summary="Delete a workspace file",
operation_id="deleteWorkspaceFile",
)
async def delete_workspace_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -201,13 +183,11 @@ async def delete_workspace_file(
@router.post(
"/files/upload",
summary="Upload file to workspace",
operation_id="uploadWorkspaceFile",
)
async def upload_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file: UploadFile,
session_id: str | None = Query(default=None),
overwrite: bool = Query(default=False),
) -> UploadFileResponse:
"""
Upload a file to the user's workspace.
@@ -215,9 +195,6 @@ async def upload_file(
Files are stored in session-scoped paths when session_id is provided,
so the agent's session-scoped tools can discover them automatically.
"""
# Empty-string session_id drops session scoping; normalize to None.
session_id = session_id or None
config = Config()
# Sanitize filename — strip any directory components
@@ -271,28 +248,15 @@ async def upload_file(
# Write file via WorkspaceManager
manager = WorkspaceManager(user_id, workspace.id, session_id)
try:
workspace_file = await manager.write_file(
content, filename, overwrite=overwrite, metadata={"origin": "user-upload"}
)
workspace_file = await manager.write_file(content, filename)
except ValueError as e:
# write_file raises ValueError for both path-conflict and size-limit
# cases; map each to its correct HTTP status.
message = str(e)
if message.startswith("File too large"):
raise fastapi.HTTPException(status_code=413, detail=message) from e
raise fastapi.HTTPException(status_code=409, detail=message) from e
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
# Post-write storage check — eliminates TOCTOU race on the quota.
# If a concurrent upload pushed us over the limit, undo this write.
new_total = await get_workspace_total_size(workspace.id)
if storage_limit_bytes and new_total > storage_limit_bytes:
try:
await soft_delete_workspace_file(workspace_file.id, workspace.id)
except Exception as e:
logger.warning(
f"Failed to soft-delete over-quota file {workspace_file.id} "
f"in workspace {workspace.id}: {e}"
)
await soft_delete_workspace_file(workspace_file.id, workspace.id)
raise fastapi.HTTPException(
status_code=413,
detail={
@@ -314,7 +278,6 @@ async def upload_file(
@router.get(
"/storage/usage",
summary="Get workspace storage usage",
operation_id="getWorkspaceStorageUsage",
)
async def get_storage_usage(
user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -335,57 +298,3 @@ async def get_storage_usage(
used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
file_count=file_count,
)
@router.get(
"/files",
summary="List workspace files",
operation_id="listWorkspaceFiles",
)
async def list_workspace_files(
user_id: Annotated[str, fastapi.Security(get_user_id)],
session_id: str | None = Query(default=None),
limit: int = Query(default=200, ge=1, le=1000),
offset: int = Query(default=0, ge=0),
) -> ListFilesResponse:
"""
List files in the user's workspace.
When session_id is provided, only files for that session are returned.
Otherwise, all files across sessions are listed. Results are paginated
via `limit`/`offset`; `has_more` indicates whether additional pages exist.
"""
workspace = await get_or_create_workspace(user_id)
# Treat empty-string session_id the same as omitted — an empty value
# would otherwise silently list files across every session instead of
# scoping to one.
session_id = session_id or None
manager = WorkspaceManager(user_id, workspace.id, session_id)
include_all = session_id is None
# Fetch one extra to compute has_more without a separate count query.
files = await manager.list_files(
limit=limit + 1,
offset=offset,
include_all_sessions=include_all,
)
has_more = len(files) > limit
page = files[:limit]
return ListFilesResponse(
files=[
WorkspaceFileItem(
id=f.id,
name=f.name,
path=f.path,
mime_type=f.mime_type,
size_bytes=f.size_bytes,
metadata=f.metadata or {},
created_at=f.created_at.isoformat(),
)
for f in page
],
offset=offset,
has_more=has_more,
)

View File

@@ -1,28 +1,48 @@
"""Tests for workspace file upload and download routes."""
import io
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from backend.api.features.workspace.routes import router
from backend.data.workspace import Workspace, WorkspaceFile
from backend.api.features.workspace import routes as workspace_routes
from backend.data.workspace import WorkspaceFile
app = fastapi.FastAPI()
app.include_router(router)
app.include_router(workspace_routes.router)
@app.exception_handler(ValueError)
async def _value_error_handler(
request: fastapi.Request, exc: ValueError
) -> fastapi.responses.JSONResponse:
"""Mirror the production ValueError → 400 mapping from the REST app."""
"""Mirror the production ValueError → 400 mapping from rest_api.py."""
return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
client = fastapi.testclient.TestClient(app)
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
MOCK_FILE = WorkspaceFile(
id="file-aaa-bbb",
workspace_id="ws-1",
created_at=_NOW,
updated_at=_NOW,
name="hello.txt",
path="/session/hello.txt",
mime_type="text/plain",
size_bytes=13,
storage_path="local://hello.txt",
)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
@@ -33,201 +53,25 @@ def setup_app_auth(mock_jwt_user):
app.dependency_overrides.clear()
def _make_workspace(user_id: str = "test-user-id") -> Workspace:
return Workspace(
id="ws-001",
user_id=user_id,
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
)
def _make_file(**overrides) -> WorkspaceFile:
defaults = {
"id": "file-001",
"workspace_id": "ws-001",
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
"updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
"name": "test.txt",
"path": "/test.txt",
"storage_path": "local://test.txt",
"mime_type": "text/plain",
"size_bytes": 100,
"checksum": None,
"is_deleted": False,
"deleted_at": None,
"metadata": {},
}
defaults.update(overrides)
return WorkspaceFile(**defaults)
def _make_file_mock(**overrides) -> MagicMock:
"""Create a mock WorkspaceFile to simulate DB records with null fields."""
defaults = {
"id": "file-001",
"name": "test.txt",
"path": "/test.txt",
"mime_type": "text/plain",
"size_bytes": 100,
"metadata": {},
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
}
defaults.update(overrides)
mock = MagicMock(spec=WorkspaceFile)
for k, v in defaults.items():
setattr(mock, k, v)
return mock
# -- list_workspace_files tests --
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_returns_all_when_no_session(mock_manager_cls, mock_get_workspace):
mock_get_workspace.return_value = _make_workspace()
files = [
_make_file(id="f1", name="a.txt", metadata={"origin": "user-upload"}),
_make_file(id="f2", name="b.csv", metadata={"origin": "agent-created"}),
]
mock_instance = AsyncMock()
mock_instance.list_files.return_value = files
mock_manager_cls.return_value = mock_instance
response = client.get("/files")
assert response.status_code == 200
data = response.json()
assert len(data["files"]) == 2
assert data["has_more"] is False
assert data["offset"] == 0
assert data["files"][0]["id"] == "f1"
assert data["files"][0]["metadata"] == {"origin": "user-upload"}
assert data["files"][1]["id"] == "f2"
mock_instance.list_files.assert_called_once_with(
limit=201, offset=0, include_all_sessions=True
)
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_scopes_to_session_when_provided(
mock_manager_cls, mock_get_workspace, test_user_id
):
mock_get_workspace.return_value = _make_workspace(user_id=test_user_id)
mock_instance = AsyncMock()
mock_instance.list_files.return_value = []
mock_manager_cls.return_value = mock_instance
response = client.get("/files?session_id=sess-123")
assert response.status_code == 200
data = response.json()
assert data["files"] == []
assert data["has_more"] is False
mock_manager_cls.assert_called_once_with(test_user_id, "ws-001", "sess-123")
mock_instance.list_files.assert_called_once_with(
limit=201, offset=0, include_all_sessions=False
)
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_null_metadata_coerced_to_empty_dict(
mock_manager_cls, mock_get_workspace
):
"""Route uses `f.metadata or {}` for pre-existing files with null metadata."""
mock_get_workspace.return_value = _make_workspace()
mock_instance = AsyncMock()
mock_instance.list_files.return_value = [_make_file_mock(metadata=None)]
mock_manager_cls.return_value = mock_instance
response = client.get("/files")
assert response.status_code == 200
assert response.json()["files"][0]["metadata"] == {}
# -- upload_file metadata tests --
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
@patch("backend.api.features.workspace.routes.scan_content_safe")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_upload_passes_user_upload_origin_metadata(
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
):
mock_get_workspace.return_value = _make_workspace()
mock_total_size.return_value = 100
written = _make_file(id="new-file", name="doc.pdf")
mock_instance = AsyncMock()
mock_instance.write_file.return_value = written
mock_manager_cls.return_value = mock_instance
response = client.post(
"/files/upload",
files={"file": ("doc.pdf", b"fake-pdf-content", "application/pdf")},
)
assert response.status_code == 200
mock_instance.write_file.assert_called_once()
call_kwargs = mock_instance.write_file.call_args
assert call_kwargs.kwargs.get("metadata") == {"origin": "user-upload"}
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
@patch("backend.api.features.workspace.routes.scan_content_safe")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_upload_returns_409_on_file_conflict(
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
):
mock_get_workspace.return_value = _make_workspace()
mock_total_size.return_value = 100
mock_instance = AsyncMock()
mock_instance.write_file.side_effect = ValueError("File already exists at path")
mock_manager_cls.return_value = mock_instance
response = client.post(
"/files/upload",
files={"file": ("dup.txt", b"content", "text/plain")},
)
assert response.status_code == 409
assert "already exists" in response.json()["detail"]
# -- Restored upload/download/delete security + invariant tests --
def _upload(
filename: str = "hello.txt",
content: bytes = b"Hello, world!",
content_type: str = "text/plain",
):
"""Helper to POST a file upload."""
return client.post(
"/files/upload?session_id=sess-1",
files={"file": (filename, io.BytesIO(content), content_type)},
)
_MOCK_FILE = WorkspaceFile(
id="file-aaa-bbb",
workspace_id="ws-001",
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
name="hello.txt",
path="/sessions/sess-1/hello.txt",
mime_type="text/plain",
size_bytes=13,
storage_path="local://hello.txt",
)
# ---- Happy path ----
def test_upload_happy_path(mocker):
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -238,7 +82,7 @@ def test_upload_happy_path(mocker):
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -252,7 +96,10 @@ def test_upload_happy_path(mocker):
assert data["size_bytes"] == 13
def test_upload_exceeds_max_file_size(mocker):
# ---- Per-file size limit ----
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
"""Files larger than max_file_size_mb should be rejected with 413."""
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
@@ -262,11 +109,15 @@ def test_upload_exceeds_max_file_size(mocker):
assert response.status_code == 413
def test_upload_storage_quota_exceeded(mocker):
# ---- Storage quota exceeded ----
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
# Current usage already at limit
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=500 * 1024 * 1024,
@@ -277,22 +128,27 @@ def test_upload_storage_quota_exceeded(mocker):
assert "Storage limit exceeded" in response.text
def test_upload_post_write_quota_race(mocker):
"""Concurrent upload tipping over limit after write should soft-delete + 413."""
# ---- Post-write quota race (B2) ----
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
"""If a concurrent upload tips the total over the limit after write,
the file should be soft-deleted and 413 returned."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
# Pre-write check passes (under limit), but post-write check fails
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
side_effect=[0, 600 * 1024 * 1024],
side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -304,14 +160,17 @@ def test_upload_post_write_quota_race(mocker):
response = _upload()
assert response.status_code == 413
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-001")
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
def test_upload_any_extension(mocker):
# ---- Any extension accepted (no allowlist) ----
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
"""Any file extension should be accepted — ClamAV is the security layer."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -322,7 +181,7 @@ def test_upload_any_extension(mocker):
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -332,13 +191,16 @@ def test_upload_any_extension(mocker):
assert response.status_code == 200
def test_upload_blocked_by_virus_scan(mocker):
# ---- Virus scan rejection ----
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
"""Files flagged by ClamAV should be rejected and never written to storage."""
from backend.api.features.store.exceptions import VirusDetectedError
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -349,7 +211,7 @@ def test_upload_blocked_by_virus_scan(mocker):
side_effect=VirusDetectedError("Eicar-Test-Signature"),
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -357,14 +219,18 @@ def test_upload_blocked_by_virus_scan(mocker):
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
assert response.status_code == 400
assert "Virus detected" in response.text
mock_manager.write_file.assert_not_called()
def test_upload_file_without_extension(mocker):
# ---- No file extension ----
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
"""Files without an extension should be accepted and stored as-is."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -375,7 +241,7 @@ def test_upload_file_without_extension(mocker):
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -391,11 +257,14 @@ def test_upload_file_without_extension(mocker):
assert mock_manager.write_file.call_args[0][1] == "Makefile"
def test_upload_strips_path_components(mocker):
# ---- Filename sanitization (SF5) ----
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
"""Path-traversal filenames should be reduced to their basename."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -406,23 +275,28 @@ def test_upload_strips_path_components(mocker):
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
# Filename with traversal
_upload(filename="../../etc/passwd.txt")
# write_file should have been called with just the basename
mock_manager.write_file.assert_called_once()
call_args = mock_manager.write_file.call_args
assert call_args[0][1] == "passwd.txt"
def test_download_file_not_found(mocker):
# ---- Download ----
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_file",
@@ -433,11 +307,14 @@ def test_download_file_not_found(mocker):
assert response.status_code == 404
def test_delete_file_success(mocker):
# ---- Delete ----
def test_delete_file_success(mocker: pytest_mock.MockFixture):
"""Deleting an existing file should return {"deleted": true}."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
mock_manager = mocker.MagicMock()
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
@@ -452,11 +329,11 @@ def test_delete_file_success(mocker):
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
def test_delete_file_not_found(mocker):
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
"""Deleting a non-existent file should return 404."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
mock_manager = mocker.MagicMock()
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
@@ -470,7 +347,7 @@ def test_delete_file_not_found(mocker):
assert "File not found" in response.text
def test_delete_file_no_workspace(mocker):
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
"""Deleting when user has no workspace should return 404."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
@@ -480,123 +357,3 @@ def test_delete_file_no_workspace(mocker):
response = client.delete("/files/file-aaa-bbb")
assert response.status_code == 404
assert "Workspace not found" in response.text
def test_upload_write_file_too_large_returns_413(mocker):
"""write_file raises ValueError("File too large: …") → must map to 413."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(
side_effect=ValueError("File too large: 900 bytes exceeds 1MB limit")
)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload()
assert response.status_code == 413
assert "File too large" in response.text
def test_upload_write_file_conflict_returns_409(mocker):
"""Non-'File too large' ValueErrors from write_file stay as 409."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(
side_effect=ValueError("File already exists at path: /sessions/x/a.txt")
)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload()
assert response.status_code == 409
assert "already exists" in response.text
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_has_more_true_when_limit_exceeded(
mock_manager_cls, mock_get_workspace
):
"""The limit+1 fetch trick must flip has_more=True and trim the page."""
mock_get_workspace.return_value = _make_workspace()
# Backend was asked for limit+1=3, and returned exactly 3 items.
files = [
_make_file(id="f1", name="a.txt"),
_make_file(id="f2", name="b.txt"),
_make_file(id="f3", name="c.txt"),
]
mock_instance = AsyncMock()
mock_instance.list_files.return_value = files
mock_manager_cls.return_value = mock_instance
response = client.get("/files?limit=2")
assert response.status_code == 200
data = response.json()
assert data["has_more"] is True
assert len(data["files"]) == 2
assert data["files"][0]["id"] == "f1"
assert data["files"][1]["id"] == "f2"
mock_instance.list_files.assert_called_once_with(
limit=3, offset=0, include_all_sessions=True
)
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_has_more_false_when_exactly_page_size(
mock_manager_cls, mock_get_workspace
):
"""Exactly `limit` rows means we're on the last page — has_more=False."""
mock_get_workspace.return_value = _make_workspace()
files = [_make_file(id="f1", name="a.txt"), _make_file(id="f2", name="b.txt")]
mock_instance = AsyncMock()
mock_instance.list_files.return_value = files
mock_manager_cls.return_value = mock_instance
response = client.get("/files?limit=2")
assert response.status_code == 200
data = response.json()
assert data["has_more"] is False
assert len(data["files"]) == 2
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
mock_get_workspace.return_value = _make_workspace()
mock_instance = AsyncMock()
mock_instance.list_files.return_value = []
mock_manager_cls.return_value = mock_instance
response = client.get("/files?offset=50&limit=10")
assert response.status_code == 200
assert response.json()["offset"] == 50
mock_instance.list_files.assert_called_once_with(
limit=11, offset=50, include_all_sessions=True
)

View File

@@ -18,8 +18,6 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.platform_cost_routes
import backend.api.features.admin.rate_limit_admin_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
@@ -119,11 +117,6 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Register managed credential providers (e.g. AgentMail)
from backend.integrations.managed_providers import register_all
register_all()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
@@ -217,22 +210,13 @@ instrument_fastapi(
def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
def handler(request: fastapi.Request, exc: Exception):
if log_error:
if status_code >= 500:
logger.exception(
"%s %s failed. Investigate and resolve the underlying issue: %s",
request.method,
request.url.path,
exc,
exc_info=exc,
)
else:
logger.warning(
"%s %s failed with %d: %s",
request.method,
request.url.path,
status_code,
exc,
)
logger.exception(
"%s %s failed. Investigate and resolve the underlying issue: %s",
request.method,
request.url.path,
exc,
exc_info=exc,
)
hint = (
"Adjust the request and retry."
@@ -282,10 +266,12 @@ async def validation_error_handler(
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
app.add_exception_handler(FolderAlreadyExistsError, handle_internal_http_error(409))
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400))
app.add_exception_handler(NotFoundError, handle_internal_http_error(404))
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403))
app.add_exception_handler(
FolderAlreadyExistsError, handle_internal_http_error(409, False)
)
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400, False))
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
app.add_exception_handler(RequestValidationError, validation_error_handler)
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
@@ -325,16 +311,6 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/executions",
)
app.include_router(
backend.api.features.admin.rate_limit_admin_routes.router,
tags=["v2", "admin"],
prefix="/api/copilot",
)
app.include_router(
backend.api.features.admin.platform_cost_routes.router,
tags=["v2", "admin"],
prefix="/api/admin",
)
app.include_router(
backend.api.features.executions.review.routes.router,
tags=["v2", "executions", "review"],
@@ -545,11 +521,8 @@ class AgentServer(backend.util.service.AppProcess):
user_id: str,
provider: ProviderName,
credentials: Credentials,
):
from backend.api.features.integrations.router import (
create_credentials,
get_credential,
)
) -> Credentials:
from .features.integrations.router import create_credentials, get_credential
try:
return await create_credentials(

View File

@@ -698,30 +698,13 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
if should_pause:
return
# Validate the input data (original or reviewer-modified) once.
# In dry-run mode, credential fields may contain sentinel None values
# that would fail JSON schema required checks. We still validate the
# non-credential fields so blocks that execute for real during dry-run
# (e.g. AgentExecutorBlock) get proper input validation.
is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False)
if is_dry_run:
cred_field_names = set(self.input_schema.get_credentials_fields().keys())
non_cred_data = {
k: v for k, v in input_data.items() if k not in cred_field_names
}
if error := self.input_schema.validate_data(non_cred_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
block_id=self.id,
)
else:
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
block_id=self.id,
)
# Validate the input data (original or reviewer-modified) once
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
block_id=self.id,
)
# Use the validated input data
async for output_name, output_data in self.run(

View File

@@ -49,17 +49,11 @@ class AgentExecutorBlock(Block):
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
required_fields = cls.get_input_schema(data).get("required", [])
# Check against the nested `inputs` dict, not the top-level node
# data — required fields like "topic" live inside data["inputs"],
# not at data["topic"].
provided = data.get("inputs", {})
return set(required_fields) - set(provided)
return set(required_fields) - set(data)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return validate_with_jsonschema(
cls.get_input_schema(data), data.get("inputs", {})
)
return validate_with_jsonschema(cls.get_input_schema(data), data)
class Output(BlockSchema):
# Use BlockSchema to avoid automatic error field that could clash with graph outputs
@@ -94,7 +88,6 @@ class AgentExecutorBlock(Block):
execution_context=execution_context.model_copy(
update={"parent_execution_id": graph_exec_id},
),
dry_run=execution_context.dry_run,
)
logger = execution_utils.LogMetadata(
@@ -156,19 +149,14 @@ class AgentExecutorBlock(Block):
ExecutionStatus.TERMINATED,
ExecutionStatus.FAILED,
]:
logger.info(
f"Execution {log_id} skipping event {event.event_type} status={event.status} "
f"node={getattr(event, 'node_exec_id', '?')}"
logger.debug(
f"Execution {log_id} received event {event.event_type} with status {event.status}"
)
continue
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
# If the graph execution is COMPLETED, TERMINATED, or FAILED,
# we can stop listening for further events.
logger.info(
f"Execution {log_id} graph completed with status {event.status}, "
f"yielded {len(yielded_node_exec_ids)} outputs"
)
self.merge_stats(
NodeExecutionStats(
extra_cost=event.stats.cost if event.stats else 0,

View File

@@ -1,4 +1,3 @@
import re
from typing import Any
from backend.blocks._base import (
@@ -20,33 +19,6 @@ from backend.blocks.llm import (
)
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
# Minimum max_output_tokens accepted by OpenAI-compatible APIs.
# A true/false answer fits comfortably within this budget.
MIN_LLM_OUTPUT_TOKENS = 16
def _parse_boolean_response(response_text: str) -> tuple[bool, str | None]:
"""Parse an LLM response into a boolean result.
Returns a ``(result, error)`` tuple. *error* is ``None`` when the
response is unambiguous; otherwise it contains a diagnostic message
and *result* defaults to ``False``.
"""
text = response_text.strip().lower()
if text == "true":
return True, None
if text == "false":
return False, None
# Fuzzy match use word boundaries to avoid false positives like "untrue".
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", text))
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
return True, None
if tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
return False, None
return False, f"Unclear AI response: '{response_text}'"
class AIConditionBlock(AIBlockBase):
"""
@@ -190,26 +162,54 @@ class AIConditionBlock(AIBlockBase):
]
# Call the LLM
response = await self.llm_call(
credentials=credentials,
llm_model=input_data.model,
prompt=prompt,
max_tokens=MIN_LLM_OUTPUT_TOKENS,
)
# Extract the boolean result from the response
result, error = _parse_boolean_response(response.response)
if error:
yield "error", error
# Update internal stats
self.merge_stats(
NodeExecutionStats(
input_token_count=response.prompt_tokens,
output_token_count=response.completion_tokens,
try:
response = await self.llm_call(
credentials=credentials,
llm_model=input_data.model,
prompt=prompt,
max_tokens=10, # We only expect a true/false response
)
)
self.prompt = response.prompt
# Extract the boolean result from the response
response_text = response.response.strip().lower()
if response_text == "true":
result = True
elif response_text == "false":
result = False
else:
# If the response is not clear, try to interpret it using word boundaries
import re
# Use word boundaries to avoid false positives like 'untrue' or '10'
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
result = True
elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
result = False
else:
# Unclear or conflicting response - default to False and yield error
result = False
yield "error", f"Unclear AI response: '{response.response}'"
# Update internal stats
self.merge_stats(
NodeExecutionStats(
input_token_count=response.prompt_tokens,
output_token_count=response.completion_tokens,
)
)
self.prompt = response.prompt
except Exception as e:
# In case of any error, default to False to be safe
result = False
# Log the error but don't fail the block execution
import logging
logger = logging.getLogger(__name__)
logger.error(f"AI condition evaluation failed: {str(e)}")
yield "error", f"AI evaluation failed: {str(e)}"
# Yield results
yield "result", result

View File

@@ -1,147 +0,0 @@
"""Tests for AIConditionBlock regression coverage for max_tokens and error propagation."""
from __future__ import annotations
from typing import cast
import pytest
from backend.blocks.ai_condition import (
MIN_LLM_OUTPUT_TOKENS,
AIConditionBlock,
_parse_boolean_response,
)
from backend.blocks.llm import (
DEFAULT_LLM_MODEL,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
AICredentials,
LLMResponse,
)
_TEST_AI_CREDENTIALS = cast(AICredentials, TEST_CREDENTIALS_INPUT)
# ---------------------------------------------------------------------------
# Helper to collect all yields from the async generator
# ---------------------------------------------------------------------------
async def _collect_outputs(block: AIConditionBlock, input_data, credentials):
outputs: dict[str, object] = {}
async for name, value in block.run(input_data, credentials=credentials):
outputs[name] = value
return outputs
def _make_input(**overrides) -> AIConditionBlock.Input:
defaults: dict = {
"input_value": "hello@example.com",
"condition": "the input is an email address",
"yes_value": "yes!",
"no_value": "no!",
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
}
defaults.update(overrides)
return AIConditionBlock.Input(**defaults)
def _mock_llm_response(response_text: str) -> LLMResponse:
return LLMResponse(
raw_response="",
prompt=[],
response=response_text,
tool_calls=None,
prompt_tokens=10,
completion_tokens=5,
reasoning=None,
)
# ---------------------------------------------------------------------------
# _parse_boolean_response unit tests
# ---------------------------------------------------------------------------
class TestParseBooleanResponse:
def test_true_exact(self):
assert _parse_boolean_response("true") == (True, None)
def test_false_exact(self):
assert _parse_boolean_response("false") == (False, None)
def test_true_with_whitespace(self):
assert _parse_boolean_response(" True ") == (True, None)
def test_yes_fuzzy(self):
assert _parse_boolean_response("Yes") == (True, None)
def test_no_fuzzy(self):
assert _parse_boolean_response("no") == (False, None)
def test_one_fuzzy(self):
assert _parse_boolean_response("1") == (True, None)
def test_zero_fuzzy(self):
assert _parse_boolean_response("0") == (False, None)
def test_unclear_response(self):
result, error = _parse_boolean_response("I'm not sure")
assert result is False
assert error is not None
assert "Unclear" in error
def test_conflicting_tokens(self):
result, error = _parse_boolean_response("true and false")
assert result is False
assert error is not None
# ---------------------------------------------------------------------------
# Regression: max_tokens is set to MIN_LLM_OUTPUT_TOKENS
# ---------------------------------------------------------------------------
class TestMaxTokensRegression:
@pytest.mark.asyncio
async def test_llm_call_receives_min_output_tokens(self):
"""max_tokens must be MIN_LLM_OUTPUT_TOKENS (16) the previous value
of 1 was too low and caused OpenAI to reject the request."""
block = AIConditionBlock()
captured_kwargs: dict = {}
async def spy_llm_call(**kwargs):
captured_kwargs.update(kwargs)
return _mock_llm_response("true")
block.llm_call = spy_llm_call # type: ignore[assignment]
input_data = _make_input()
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
assert captured_kwargs["max_tokens"] == MIN_LLM_OUTPUT_TOKENS
assert captured_kwargs["max_tokens"] == 16
# ---------------------------------------------------------------------------
# Regression: exceptions from llm_call must propagate
# ---------------------------------------------------------------------------
class TestExceptionPropagation:
@pytest.mark.asyncio
async def test_llm_call_exception_propagates(self):
"""If llm_call raises, the exception must NOT be swallowed.
Previously the block caught all exceptions and silently returned
result=False."""
block = AIConditionBlock()
async def boom(**kwargs):
raise RuntimeError("LLM provider error")
block.llm_call = boom # type: ignore[assignment]
input_data = _make_input()
with pytest.raises(RuntimeError, match="LLM provider error"):
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)

View File

@@ -17,7 +17,7 @@ from backend.blocks.apollo.models import (
PrimaryPhone,
SearchOrganizationsRequest,
)
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
from backend.data.model import CredentialsField, SchemaField
class SearchOrganizationsBlock(Block):
@@ -218,11 +218,6 @@ To find IDs, identify the values for organization_id when you call this endpoint
) -> BlockOutput:
query = SearchOrganizationsRequest(**input_data.model_dump())
organizations = await self.search_organizations(query, credentials)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(organizations)), provider_cost_type="items"
)
)
for organization in organizations:
yield "organization", organization
yield "organizations", organizations

View File

@@ -21,7 +21,7 @@ from backend.blocks.apollo.models import (
SearchPeopleRequest,
SenorityLevels,
)
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
from backend.data.model import CredentialsField, SchemaField
class SearchPeopleBlock(Block):
@@ -366,9 +366,4 @@ class SearchPeopleBlock(Block):
*(enrich_or_fallback(person) for person in people)
)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(people)), provider_cost_type="items"
)
)
yield "people", people

View File

@@ -15,12 +15,6 @@ from backend.blocks._base import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.copilot.permissions import (
CopilotPermissions,
ToolName,
all_known_tool_names,
validate_block_identifiers,
)
from backend.data.model import SchemaField
if TYPE_CHECKING:
@@ -102,65 +96,6 @@ class AutoPilotBlock(Block):
advanced=True,
)
tools: list[ToolName] = SchemaField(
description=(
"Tool names to filter. Works with tools_exclude to form an "
"allow-list or deny-list. "
"Leave empty to apply no tool filter."
),
default=[],
advanced=True,
)
tools_exclude: bool = SchemaField(
description=(
"Controls how the 'tools' list is interpreted. "
"True (default): 'tools' is a deny-list — listed tools are blocked, "
"all others are allowed. An empty 'tools' list means allow everything. "
"False: 'tools' is an allow-list — only listed tools are permitted."
),
default=True,
advanced=True,
)
blocks: list[str] = SchemaField(
description=(
"Block identifiers to filter when the copilot uses run_block. "
"Each entry can be: a block name (e.g. 'HTTP Request'), "
"a full block UUID, or the first 8 hex characters of the UUID "
"(e.g. 'c069dc6b'). Works with blocks_exclude. "
"Leave empty to apply no block filter."
),
default=[],
advanced=True,
)
blocks_exclude: bool = SchemaField(
description=(
"Controls how the 'blocks' list is interpreted. "
"True (default): 'blocks' is a deny-list — listed blocks are blocked, "
"all others are allowed. An empty 'blocks' list means allow everything. "
"False: 'blocks' is an allow-list — only listed blocks are permitted."
),
default=True,
advanced=True,
)
dry_run: bool = SchemaField(
description=(
"When enabled, run_block and run_agent tool calls in this "
"autopilot session are forced to use dry-run simulation mode. "
"No real API calls, side effects, or credits are consumed "
"by those tools. Useful for testing agent wiring and "
"previewing outputs. "
"Only applies when creating a new session (session_id is empty). "
"When reusing an existing session_id, the session's original "
"dry_run setting is preserved."
),
default=False,
advanced=True,
)
# timeout_seconds removed: the SDK manages its own heartbeat-based
# timeouts internally; wrapping with asyncio.timeout corrupts the
# SDK's internal stream (see service.py CRITICAL comment).
@@ -247,11 +182,11 @@ class AutoPilotBlock(Block):
},
)
async def create_session(self, user_id: str, *, dry_run: bool) -> str:
async def create_session(self, user_id: str) -> str:
"""Create a new chat session and return its ID (mockable for tests)."""
from backend.copilot.model import create_chat_session # avoid circular import
from backend.copilot.model import create_chat_session
session = await create_chat_session(user_id, dry_run=dry_run)
session = await create_chat_session(user_id)
return session.session_id
async def execute_copilot(
@@ -261,7 +196,6 @@ class AutoPilotBlock(Block):
session_id: str,
max_recursion_depth: int,
user_id: str,
permissions: "CopilotPermissions | None" = None,
) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]:
"""Invoke the copilot and collect all stream results.
@@ -275,21 +209,14 @@ class AutoPilotBlock(Block):
session_id: Chat session to use.
max_recursion_depth: Maximum allowed recursion nesting.
user_id: Authenticated user ID.
permissions: Optional capability filter restricting tools/blocks.
Returns:
A tuple of (response_text, tool_calls, history_json, session_id, usage).
"""
from backend.copilot.sdk.collect import (
collect_copilot_response, # avoid circular import
)
from backend.copilot.sdk.collect import collect_copilot_response
tokens = _check_recursion(max_recursion_depth)
perm_token = None
try:
effective_permissions, perm_token = _merge_inherited_permissions(
permissions
)
effective_prompt = prompt
if system_context:
effective_prompt = f"[System Context: {system_context}]\n\n{prompt}"
@@ -298,7 +225,6 @@ class AutoPilotBlock(Block):
session_id=session_id,
message=effective_prompt,
user_id=user_id,
permissions=effective_permissions,
)
# Build a lightweight conversation summary from streamed data.
@@ -345,8 +271,6 @@ class AutoPilotBlock(Block):
)
finally:
_reset_recursion(tokens)
if perm_token is not None:
_inherited_permissions.reset(perm_token)
async def run(
self,
@@ -371,20 +295,11 @@ class AutoPilotBlock(Block):
yield "error", "max_recursion_depth must be at least 1."
return
# Validate and build permissions eagerly — fail before creating a session.
permissions = await _build_and_validate_permissions(input_data)
if isinstance(permissions, str):
# Validation error returned as a string message.
yield "error", permissions
return
# Create session eagerly so the user always gets the session_id,
# even if the downstream stream fails (avoids orphaned sessions).
sid = input_data.session_id
if not sid:
sid = await self.create_session(
execution_context.user_id, dry_run=input_data.dry_run
)
sid = await self.create_session(execution_context.user_id)
# NOTE: No asyncio.timeout() here — the SDK manages its own
# heartbeat-based timeouts internally. Wrapping with asyncio.timeout
@@ -397,7 +312,6 @@ class AutoPilotBlock(Block):
session_id=sid,
max_recursion_depth=input_data.max_recursion_depth,
user_id=execution_context.user_id,
permissions=permissions,
)
yield "response", response
@@ -460,78 +374,3 @@ def _reset_recursion(
"""Restore recursion depth and limit to their previous values."""
_autopilot_recursion_depth.reset(tokens[0])
_autopilot_recursion_limit.reset(tokens[1])
# ---------------------------------------------------------------------------
# Permission helpers
# ---------------------------------------------------------------------------
# Inherited permissions from a parent AutoPilotBlock execution.
# This acts as a ceiling: child executions can only be more restrictive.
_inherited_permissions: contextvars.ContextVar["CopilotPermissions | None"] = (
contextvars.ContextVar("_inherited_permissions", default=None)
)
async def _build_and_validate_permissions(
input_data: "AutoPilotBlock.Input",
) -> "CopilotPermissions | str":
"""Build a :class:`CopilotPermissions` from block input and validate it.
Returns a :class:`CopilotPermissions` on success or a human-readable
error string if validation fails.
"""
# Tool names are validated by Pydantic via the ToolName Literal type
# at model construction time — no runtime check needed here.
# Validate block identifiers against live block registry.
if input_data.blocks:
invalid_blocks = await validate_block_identifiers(input_data.blocks)
if invalid_blocks:
return (
f"Unknown block identifier(s) in 'blocks': {invalid_blocks}. "
"Use find_block to discover valid block names and IDs. "
"You may also use the first 8 characters of a block UUID."
)
return CopilotPermissions(
tools=list(input_data.tools),
tools_exclude=input_data.tools_exclude,
blocks=input_data.blocks,
blocks_exclude=input_data.blocks_exclude,
)
def _merge_inherited_permissions(
permissions: "CopilotPermissions | None",
) -> "tuple[CopilotPermissions | None, contextvars.Token[CopilotPermissions | None] | None]":
"""Merge *permissions* with any inherited parent permissions.
The merged result is stored back into the contextvar so that any nested
AutoPilotBlock invocation (sub-agent) inherits the merged ceiling.
Returns a tuple of (merged_permissions, reset_token). The caller MUST
reset the contextvar via ``_inherited_permissions.reset(token)`` in a
``finally`` block when ``reset_token`` is not None — this prevents
permission leakage between sequential independent executions in the same
asyncio task.
"""
parent = _inherited_permissions.get()
if permissions is None and parent is None:
return None, None
all_tools = all_known_tool_names()
if permissions is None:
permissions = CopilotPermissions() # allow-all; will be narrowed by parent
merged = (
permissions.merged_with_parent(parent, all_tools)
if parent is not None
else permissions
)
# Store merged permissions as the new inherited ceiling for nested calls.
# Return the token so the caller can restore the previous value in finally.
token = _inherited_permissions.set(merged)
return merged, token

View File

@@ -1,265 +0,0 @@
"""Tests for AutoPilotBlock permission fields and validation."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import ValidationError
from backend.blocks.autopilot import (
AutoPilotBlock,
_build_and_validate_permissions,
_inherited_permissions,
_merge_inherited_permissions,
)
from backend.copilot.permissions import CopilotPermissions, all_known_tool_names
from backend.data.execution import ExecutionContext
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_input(**kwargs) -> AutoPilotBlock.Input:
defaults = {
"prompt": "Do something",
"system_context": "",
"session_id": "",
"max_recursion_depth": 3,
"tools": [],
"tools_exclude": True,
"blocks": [],
"blocks_exclude": True,
}
defaults.update(kwargs)
return AutoPilotBlock.Input(**defaults)
# ---------------------------------------------------------------------------
# _build_and_validate_permissions
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestBuildAndValidatePermissions:
async def test_empty_inputs_returns_empty_permissions(self):
inp = _make_input()
result = await _build_and_validate_permissions(inp)
assert isinstance(result, CopilotPermissions)
assert result.is_empty()
async def test_valid_tool_names_accepted(self):
inp = _make_input(tools=["run_block", "web_fetch"], tools_exclude=True)
result = await _build_and_validate_permissions(inp)
assert isinstance(result, CopilotPermissions)
assert result.tools == ["run_block", "web_fetch"]
assert result.tools_exclude is True
async def test_invalid_tool_rejected_by_pydantic(self):
"""Invalid tool names are now caught at Pydantic validation time
(Literal type), before ``_build_and_validate_permissions`` is called."""
with pytest.raises(ValidationError, match="not_a_real_tool"):
_make_input(tools=["not_a_real_tool"])
async def test_valid_block_name_accepted(self):
mock_block_cls = MagicMock()
mock_block_cls.return_value.name = "HTTP Request"
with patch(
"backend.blocks.get_blocks",
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
):
inp = _make_input(blocks=["HTTP Request"], blocks_exclude=True)
result = await _build_and_validate_permissions(inp)
assert isinstance(result, CopilotPermissions)
assert result.blocks == ["HTTP Request"]
async def test_valid_partial_uuid_accepted(self):
mock_block_cls = MagicMock()
mock_block_cls.return_value.name = "HTTP Request"
with patch(
"backend.blocks.get_blocks",
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
):
inp = _make_input(blocks=["c069dc6b"], blocks_exclude=False)
result = await _build_and_validate_permissions(inp)
assert isinstance(result, CopilotPermissions)
async def test_invalid_block_identifier_returns_error(self):
mock_block_cls = MagicMock()
mock_block_cls.return_value.name = "HTTP Request"
with patch(
"backend.blocks.get_blocks",
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
):
inp = _make_input(blocks=["totally_fake_block"])
result = await _build_and_validate_permissions(inp)
assert isinstance(result, str)
assert "totally_fake_block" in result
assert "Unknown block identifier" in result
async def test_sdk_builtin_tool_names_accepted(self):
inp = _make_input(tools=["Read", "Task", "WebSearch"], tools_exclude=False)
result = await _build_and_validate_permissions(inp)
assert isinstance(result, CopilotPermissions)
assert not result.tools_exclude
async def test_empty_blocks_skips_validation(self):
# Should not call validate_block_identifiers at all when blocks=[].
with patch(
"backend.copilot.permissions.validate_block_identifiers"
) as mock_validate:
inp = _make_input(blocks=[])
await _build_and_validate_permissions(inp)
mock_validate.assert_not_called()
# ---------------------------------------------------------------------------
# _merge_inherited_permissions
# ---------------------------------------------------------------------------
class TestMergeInheritedPermissions:
def test_no_permissions_no_parent_returns_none(self):
merged, token = _merge_inherited_permissions(None)
assert merged is None
assert token is None
def test_permissions_no_parent_returned_unchanged(self):
perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
merged, token = _merge_inherited_permissions(perms)
try:
assert merged is perms
assert token is not None
finally:
if token is not None:
_inherited_permissions.reset(token)
def test_child_narrows_parent(self):
parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
# Set parent as inherited
outer_token = _inherited_permissions.set(parent)
try:
child = CopilotPermissions(tools=["web_fetch"], tools_exclude=True)
merged, inner_token = _merge_inherited_permissions(child)
try:
assert merged is not None
all_t = all_known_tool_names()
effective = merged.effective_allowed_tools(all_t)
assert "bash_exec" not in effective
assert "web_fetch" not in effective
finally:
if inner_token is not None:
_inherited_permissions.reset(inner_token)
finally:
_inherited_permissions.reset(outer_token)
def test_none_permissions_with_parent_uses_parent(self):
parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
outer_token = _inherited_permissions.set(parent)
try:
merged, inner_token = _merge_inherited_permissions(None)
try:
assert merged is not None
# Merged should have parent's restrictions
effective = merged.effective_allowed_tools(all_known_tool_names())
assert "bash_exec" not in effective
finally:
if inner_token is not None:
_inherited_permissions.reset(inner_token)
finally:
_inherited_permissions.reset(outer_token)
def test_child_cannot_expand_parent_whitelist(self):
parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)
outer_token = _inherited_permissions.set(parent)
try:
# Child tries to allow more tools
child = CopilotPermissions(
tools=["run_block", "bash_exec"], tools_exclude=False
)
merged, inner_token = _merge_inherited_permissions(child)
try:
assert merged is not None
effective = merged.effective_allowed_tools(all_known_tool_names())
assert "bash_exec" not in effective
assert "run_block" in effective
finally:
if inner_token is not None:
_inherited_permissions.reset(inner_token)
finally:
_inherited_permissions.reset(outer_token)
# ---------------------------------------------------------------------------
# AutoPilotBlock.run — validation integration
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestAutoPilotBlockRunPermissions:
async def _collect_outputs(self, block, input_data, user_id="test-user"):
"""Helper to collect all yields from block.run()."""
ctx = ExecutionContext(
user_id=user_id,
graph_id="g1",
graph_exec_id="ge1",
node_exec_id="ne1",
node_id="n1",
)
outputs = {}
async for key, val in block.run(input_data, execution_context=ctx):
outputs[key] = val
return outputs
async def test_invalid_tool_rejected_by_pydantic(self):
"""Invalid tool names are caught at Pydantic validation (Literal type)."""
with pytest.raises(ValidationError, match="not_a_tool"):
_make_input(tools=["not_a_tool"])
async def test_invalid_block_yields_error(self):
mock_block_cls = MagicMock()
mock_block_cls.return_value.name = "HTTP Request"
with patch(
"backend.blocks.get_blocks",
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
):
block = AutoPilotBlock()
inp = _make_input(blocks=["nonexistent_block"])
outputs = await self._collect_outputs(block, inp)
assert "error" in outputs
assert "nonexistent_block" in outputs["error"]
async def test_empty_prompt_yields_error_before_permission_check(self):
block = AutoPilotBlock()
inp = _make_input(prompt=" ", tools=["run_block"])
outputs = await self._collect_outputs(block, inp)
assert "error" in outputs
assert "Prompt cannot be empty" in outputs["error"]
async def test_valid_permissions_passed_to_execute(self):
"""Permissions are forwarded to execute_copilot when valid."""
block = AutoPilotBlock()
captured: dict = {}
async def fake_execute_copilot(self_inner, **kwargs):
captured["permissions"] = kwargs.get("permissions")
return (
"ok",
[],
'[{"role":"user","content":"hi"}]',
"test-sid",
{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
)
with patch.object(
AutoPilotBlock, "create_session", new=AsyncMock(return_value="test-sid")
), patch.object(AutoPilotBlock, "execute_copilot", new=fake_execute_copilot):
inp = _make_input(tools=["run_block"], tools_exclude=False)
outputs = await self._collect_outputs(block, inp)
assert "error" not in outputs
perms = captured.get("permissions")
assert isinstance(perms, CopilotPermissions)
assert perms.tools == ["run_block"]
assert perms.tools_exclude is False

View File

@@ -1,712 +0,0 @@
"""Unit tests for merge_stats cost tracking in individual blocks.
Covers the exa code_context, exa contents, and apollo organization blocks
to verify provider cost is correctly extracted and reported.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, NodeExecutionStats
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
TEST_EXA_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="exa",
api_key=SecretStr("mock-exa-api-key"),
title="Mock Exa API key",
expires_at=None,
)
TEST_EXA_CREDENTIALS_INPUT = {
"provider": TEST_EXA_CREDENTIALS.provider,
"id": TEST_EXA_CREDENTIALS.id,
"type": TEST_EXA_CREDENTIALS.type,
"title": TEST_EXA_CREDENTIALS.title,
}
# ---------------------------------------------------------------------------
# ExaCodeContextBlock — cost_dollars is a string like "0.005"
# ---------------------------------------------------------------------------
class TestExaCodeContextBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_float_cost(self):
"""float(cost_dollars) parsed from API string and passed to merge_stats."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
api_response = {
"requestId": "req-1",
"query": "how to use hooks",
"response": "Here are some examples...",
"resultsCount": 3,
"costDollars": "0.005",
"searchTime": 1.2,
"outputTokens": 100,
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.code_context.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = ExaCodeContextBlock.Input(
query="how to use hooks",
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
results = []
async for output in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
results.append(output)
assert len(accumulated) == 1
assert accumulated[0].provider_cost == pytest.approx(0.005)
@pytest.mark.asyncio
async def test_invalid_cost_dollars_does_not_raise(self):
"""When cost_dollars cannot be parsed as float, merge_stats is not called."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
api_response = {
"requestId": "req-2",
"query": "query",
"response": "response",
"resultsCount": 0,
"costDollars": "N/A",
"searchTime": 0.5,
"outputTokens": 0,
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
merge_calls: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.code_context.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: merge_calls.append(s)
),
):
input_data = ExaCodeContextBlock.Input(
query="query",
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
pass
assert merge_calls == []
@pytest.mark.asyncio
async def test_zero_cost_is_tracked(self):
"""A zero cost_dollars string '0.0' should still be recorded."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
api_response = {
"requestId": "req-3",
"query": "query",
"response": "...",
"resultsCount": 1,
"costDollars": "0.0",
"searchTime": 0.1,
"outputTokens": 10,
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.code_context.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = ExaCodeContextBlock.Input(
query="query",
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 0.0
# ---------------------------------------------------------------------------
# ExaContentsBlock — response.cost_dollars.total (CostDollars model)
# ---------------------------------------------------------------------------
class TestExaContentsBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_cost_dollars_total(self):
"""provider_cost equals response.cost_dollars.total when present."""
from backend.blocks.exa.contents import ExaContentsBlock
from backend.blocks.exa.helpers import CostDollars
block = ExaContentsBlock()
cost_dollars = CostDollars(total=0.012)
mock_response = MagicMock()
mock_response.results = []
mock_response.context = None
mock_response.statuses = None
mock_response.cost_dollars = cost_dollars
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.contents.AsyncExa",
return_value=MagicMock(
get_contents=AsyncMock(return_value=mock_response)
),
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = ExaContentsBlock.Input(
urls=["https://example.com"],
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == pytest.approx(0.012)
@pytest.mark.asyncio
async def test_no_merge_stats_when_cost_dollars_absent(self):
"""When response.cost_dollars is None, merge_stats is not called."""
from backend.blocks.exa.contents import ExaContentsBlock
block = ExaContentsBlock()
mock_response = MagicMock()
mock_response.results = []
mock_response.context = None
mock_response.statuses = None
mock_response.cost_dollars = None
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.contents.AsyncExa",
return_value=MagicMock(
get_contents=AsyncMock(return_value=mock_response)
),
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = ExaContentsBlock.Input(
urls=["https://example.com"],
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
pass
assert accumulated == []
# ---------------------------------------------------------------------------
# SearchOrganizationsBlock — provider_cost = float(len(organizations))
# ---------------------------------------------------------------------------
class TestSearchOrganizationsBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_org_count(self):
"""provider_cost == number of returned organizations, type == 'items'."""
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
)
from backend.blocks.apollo.models import Organization
from backend.blocks.apollo.organization import SearchOrganizationsBlock
block = SearchOrganizationsBlock()
fake_orgs = [Organization(id=str(i), name=f"Org{i}") for i in range(3)]
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
SearchOrganizationsBlock,
"search_organizations",
new_callable=AsyncMock,
return_value=fake_orgs,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = SearchOrganizationsBlock.Input(
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
)
results = []
async for output in block.run(
input_data,
credentials=APOLLO_CREDS,
):
results.append(output)
assert len(accumulated) == 1
assert accumulated[0].provider_cost == pytest.approx(3.0)
assert accumulated[0].provider_cost_type == "items"
@pytest.mark.asyncio
async def test_empty_org_list_tracks_zero(self):
"""An empty organization list results in provider_cost=0.0."""
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
)
from backend.blocks.apollo.organization import SearchOrganizationsBlock
block = SearchOrganizationsBlock()
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
SearchOrganizationsBlock,
"search_organizations",
new_callable=AsyncMock,
return_value=[],
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = SearchOrganizationsBlock.Input(
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=APOLLO_CREDS,
):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 0.0
assert accumulated[0].provider_cost_type == "items"
# ---------------------------------------------------------------------------
# JinaEmbeddingBlock — token count from usage.total_tokens
# ---------------------------------------------------------------------------
class TestJinaEmbeddingBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_token_count(self):
"""provider token count is recorded when API returns usage.total_tokens."""
from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
block = JinaEmbeddingBlock()
api_response = {
"data": [{"embedding": [0.1, 0.2, 0.3]}],
"usage": {"total_tokens": 42},
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.jina.embeddings.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = JinaEmbeddingBlock.Input(
texts=["hello world"],
credentials=JINA_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=JINA_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].input_token_count == 42
@pytest.mark.asyncio
async def test_no_merge_stats_when_usage_absent(self):
"""When API response omits usage field, merge_stats is not called."""
from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
block = JinaEmbeddingBlock()
api_response = {
"data": [{"embedding": [0.1, 0.2, 0.3]}],
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.jina.embeddings.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = JinaEmbeddingBlock.Input(
texts=["hello"],
credentials=JINA_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=JINA_CREDS):
pass
assert accumulated == []
# ---------------------------------------------------------------------------
# UnrealTextToSpeechBlock — character count from input text length
# ---------------------------------------------------------------------------
class TestUnrealTextToSpeechBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_character_count(self):
"""provider_cost equals len(text) with type='characters'."""
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
from backend.blocks.text_to_speech_block import (
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
)
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
block = UnrealTextToSpeechBlock()
test_text = "Hello, world!"
with (
patch.object(
UnrealTextToSpeechBlock,
"call_unreal_speech_api",
new_callable=AsyncMock,
return_value={"OutputUri": "https://example.com/audio.mp3"},
),
patch.object(block, "merge_stats") as mock_merge,
):
input_data = UnrealTextToSpeechBlock.Input(
text=test_text,
credentials=TTS_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=TTS_CREDS):
pass
mock_merge.assert_called_once()
stats = mock_merge.call_args[0][0]
assert stats.provider_cost == float(len(test_text))
assert stats.provider_cost_type == "characters"
@pytest.mark.asyncio
async def test_empty_text_gives_zero_characters(self):
"""An empty text string results in provider_cost=0.0."""
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
from backend.blocks.text_to_speech_block import (
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
)
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
block = UnrealTextToSpeechBlock()
with (
patch.object(
UnrealTextToSpeechBlock,
"call_unreal_speech_api",
new_callable=AsyncMock,
return_value={"OutputUri": "https://example.com/audio.mp3"},
),
patch.object(block, "merge_stats") as mock_merge,
):
input_data = UnrealTextToSpeechBlock.Input(
text="",
credentials=TTS_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=TTS_CREDS):
pass
mock_merge.assert_called_once()
stats = mock_merge.call_args[0][0]
assert stats.provider_cost == 0.0
assert stats.provider_cost_type == "characters"
# ---------------------------------------------------------------------------
# GoogleMapsSearchBlock — item count from search_places results
# ---------------------------------------------------------------------------
class TestGoogleMapsSearchBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_place_count(self):
"""provider_cost equals number of returned places, type == 'items'."""
from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
from backend.blocks.google_maps import (
TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
)
from backend.blocks.google_maps import GoogleMapsSearchBlock
block = GoogleMapsSearchBlock()
fake_places = [{"name": f"Place{i}", "address": f"Addr{i}"} for i in range(4)]
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
GoogleMapsSearchBlock,
"search_places",
return_value=fake_places,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = GoogleMapsSearchBlock.Input(
query="coffee shops",
credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=MAPS_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 4.0
assert accumulated[0].provider_cost_type == "items"
@pytest.mark.asyncio
async def test_empty_results_tracks_zero(self):
"""Zero places returned results in provider_cost=0.0."""
from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
from backend.blocks.google_maps import (
TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
)
from backend.blocks.google_maps import GoogleMapsSearchBlock
block = GoogleMapsSearchBlock()
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
GoogleMapsSearchBlock,
"search_places",
return_value=[],
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = GoogleMapsSearchBlock.Input(
query="nothing here",
credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=MAPS_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 0.0
assert accumulated[0].provider_cost_type == "items"
# ---------------------------------------------------------------------------
# SmartLeadAddLeadsBlock — item count from lead_list length
# ---------------------------------------------------------------------------
class TestSmartLeadAddLeadsBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_lead_count(self):
"""provider_cost equals number of leads uploaded, type == 'items'."""
from backend.blocks.smartlead._auth import TEST_CREDENTIALS as SL_CREDS
from backend.blocks.smartlead._auth import (
TEST_CREDENTIALS_INPUT as SL_CREDS_INPUT,
)
from backend.blocks.smartlead.campaign import AddLeadToCampaignBlock
from backend.blocks.smartlead.models import (
AddLeadsToCampaignResponse,
LeadInput,
)
block = AddLeadToCampaignBlock()
fake_leads = [
LeadInput(first_name="Alice", last_name="A", email="alice@example.com"),
LeadInput(first_name="Bob", last_name="B", email="bob@example.com"),
]
fake_response = AddLeadsToCampaignResponse(
ok=True,
upload_count=2,
total_leads=2,
block_count=0,
duplicate_count=0,
invalid_email_count=0,
invalid_emails=[],
already_added_to_campaign=0,
unsubscribed_leads=[],
is_lead_limit_exhausted=False,
lead_import_stopped_count=0,
bounce_count=0,
)
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
AddLeadToCampaignBlock,
"add_leads_to_campaign",
new_callable=AsyncMock,
return_value=fake_response,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = AddLeadToCampaignBlock.Input(
campaign_id=123,
lead_list=fake_leads,
credentials=SL_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=SL_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 2.0
assert accumulated[0].provider_cost_type == "items"
# ---------------------------------------------------------------------------
# SearchPeopleBlock — item count from people list length
# ---------------------------------------------------------------------------
class TestSearchPeopleBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_people_count(self):
"""provider_cost equals number of returned people, type == 'items'."""
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
)
from backend.blocks.apollo.models import Contact
from backend.blocks.apollo.people import SearchPeopleBlock
block = SearchPeopleBlock()
fake_people = [Contact(id=str(i), first_name=f"Person{i}") for i in range(5)]
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
SearchPeopleBlock,
"search_people",
new_callable=AsyncMock,
return_value=fake_people,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = SearchPeopleBlock.Input(
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=APOLLO_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == pytest.approx(5.0)
assert accumulated[0].provider_cost_type == "items"
@pytest.mark.asyncio
async def test_empty_people_list_tracks_zero(self):
"""An empty people list results in provider_cost=0.0."""
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
)
from backend.blocks.apollo.people import SearchPeopleBlock
block = SearchPeopleBlock()
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
SearchPeopleBlock,
"search_people",
new_callable=AsyncMock,
return_value=[],
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = SearchPeopleBlock.Input(
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=APOLLO_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 0.0
assert accumulated[0].provider_cost_type == "items"

View File

@@ -73,7 +73,7 @@ class ReadDiscordMessagesBlock(Block):
id="df06086a-d5ac-4abb-9996-2ad0acb2eff7",
input_schema=ReadDiscordMessagesBlock.Input, # Assign input schema
output_schema=ReadDiscordMessagesBlock.Output, # Assign output schema
description="Reads new messages from a Discord channel using a bot token and triggers when a new message is posted",
description="Reads messages from a Discord channel using a bot token.",
categories={BlockCategory.SOCIAL},
test_input={
"continuous_read": False,

View File

@@ -9,7 +9,6 @@ from typing import Union
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -117,10 +116,3 @@ class ExaCodeContextBlock(Block):
yield "cost_dollars", context.cost_dollars
yield "search_time", context.search_time
yield "output_tokens", context.output_tokens
# Parse cost_dollars (API returns as string, e.g. "0.005")
try:
cost_usd = float(context.cost_dollars)
self.merge_stats(NodeExecutionStats(provider_cost=cost_usd))
except (ValueError, TypeError):
pass

View File

@@ -4,7 +4,6 @@ from typing import Optional
from exa_py import AsyncExa
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -224,6 +223,3 @@ class ExaContentsBlock(Block):
if response.cost_dollars:
yield "cost_dollars", response.cost_dollars
self.merge_stats(
NodeExecutionStats(provider_cost=response.cost_dollars.total)
)

View File

@@ -1,575 +0,0 @@
"""Tests for cost tracking in Exa blocks.
Covers the cost_dollars → provider_cost → merge_stats path for both
ExaContentsBlock and ExaCodeContextBlock.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks.exa._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
from backend.data.model import NodeExecutionStats
class TestExaCodeContextCostTracking:
"""ExaCodeContextBlock parses cost_dollars (string) and calls merge_stats."""
@pytest.mark.asyncio
async def test_valid_cost_string_is_parsed_and_merged(self):
"""A numeric cost string like '0.005' is merged as provider_cost."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
api_response = {
"requestId": "req-1",
"query": "test query",
"response": "some code",
"resultsCount": 3,
"costDollars": "0.005",
"searchTime": 1.2,
"outputTokens": 100,
}
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
outputs = []
async for key, value in block.run(
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
outputs.append((key, value))
assert any(k == "cost_dollars" for k, _ in outputs)
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.005)
@pytest.mark.asyncio
async def test_invalid_cost_string_does_not_raise(self):
"""A non-numeric cost_dollars value is swallowed silently."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
api_response = {
"requestId": "req-2",
"query": "test",
"response": "code",
"resultsCount": 0,
"costDollars": "N/A",
"searchTime": 0.5,
"outputTokens": 0,
}
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
outputs = []
async for key, value in block.run(
block.Input(query="test", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
outputs.append((key, value))
# No merge_stats call because float() raised ValueError
assert len(merged) == 0
@pytest.mark.asyncio
async def test_zero_cost_string_is_merged(self):
"""'0.0' is a valid cost — should still be tracked."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
api_response = {
"requestId": "req-3",
"query": "free query",
"response": "result",
"resultsCount": 1,
"costDollars": "0.0",
"searchTime": 0.1,
"outputTokens": 10,
}
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
async for _ in block.run(
block.Input(query="free query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.0)
class TestExaContentsCostTracking:
"""ExaContentsBlock merges cost_dollars.total as provider_cost."""
@pytest.mark.asyncio
async def test_cost_dollars_total_is_merged(self):
"""When the SDK response includes cost_dollars, its total is merged."""
from backend.blocks.exa.contents import ExaContentsBlock
from backend.blocks.exa.helpers import CostDollars
block = ExaContentsBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.statuses = None
mock_sdk_response.cost_dollars = CostDollars(total=0.012)
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.012)
@pytest.mark.asyncio
async def test_no_cost_dollars_skips_merge(self):
"""When cost_dollars is absent, merge_stats is not called."""
from backend.blocks.exa.contents import ExaContentsBlock
block = ExaContentsBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.statuses = None
mock_sdk_response.cost_dollars = None
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 0
@pytest.mark.asyncio
async def test_zero_cost_dollars_is_merged(self):
"""A total of 0.0 (free tier) should still be merged."""
from backend.blocks.exa.contents import ExaContentsBlock
from backend.blocks.exa.helpers import CostDollars
block = ExaContentsBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.statuses = None
mock_sdk_response.cost_dollars = CostDollars(total=0.0)
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.0)
class TestExaSearchCostTracking:
"""ExaSearchBlock merges cost_dollars.total as provider_cost."""
@pytest.mark.asyncio
async def test_cost_dollars_total_is_merged(self):
"""When the SDK response includes cost_dollars, its total is merged."""
from backend.blocks.exa.helpers import CostDollars
from backend.blocks.exa.search import ExaSearchBlock
block = ExaSearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.resolved_search_type = None
mock_sdk_response.cost_dollars = CostDollars(total=0.008)
with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.search = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.008)
@pytest.mark.asyncio
async def test_no_cost_dollars_skips_merge(self):
"""When cost_dollars is absent, merge_stats is not called."""
from backend.blocks.exa.search import ExaSearchBlock
block = ExaSearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.resolved_search_type = None
mock_sdk_response.cost_dollars = None
with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.search = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 0
class TestExaSimilarCostTracking:
"""ExaFindSimilarBlock merges cost_dollars.total as provider_cost."""
@pytest.mark.asyncio
async def test_cost_dollars_total_is_merged(self):
"""When the SDK response includes cost_dollars, its total is merged."""
from backend.blocks.exa.helpers import CostDollars
from backend.blocks.exa.similar import ExaFindSimilarBlock
block = ExaFindSimilarBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.request_id = "req-1"
mock_sdk_response.cost_dollars = CostDollars(total=0.015)
with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.015)
@pytest.mark.asyncio
async def test_no_cost_dollars_skips_merge(self):
"""When cost_dollars is absent, merge_stats is not called."""
from backend.blocks.exa.similar import ExaFindSimilarBlock
block = ExaFindSimilarBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.request_id = "req-2"
mock_sdk_response.cost_dollars = None
with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 0
# ---------------------------------------------------------------------------
# ExaCreateResearchBlock — cost_dollars from completed poll response
# ---------------------------------------------------------------------------
COMPLETED_RESEARCH_RESPONSE = {
"researchId": "test-research-id",
"status": "completed",
"model": "exa-research",
"instructions": "test instructions",
"createdAt": 1700000000000,
"finishedAt": 1700000060000,
"costDollars": {
"total": 0.05,
"numSearches": 3,
"numPages": 10,
"reasoningTokens": 500,
},
"output": {"content": "Research findings...", "parsed": None},
}
PENDING_RESEARCH_RESPONSE = {
"researchId": "test-research-id",
"status": "pending",
"model": "exa-research",
"instructions": "test instructions",
"createdAt": 1700000000000,
}
class TestExaCreateResearchBlockCostTracking:
"""ExaCreateResearchBlock merges cost from completed poll response."""
@pytest.mark.asyncio
async def test_cost_merged_when_research_completes(self):
"""merge_stats called with provider_cost=total when poll returns completed."""
from backend.blocks.exa.research import ExaCreateResearchBlock
block = ExaCreateResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
create_resp = MagicMock()
create_resp.json.return_value = PENDING_RESEARCH_RESPONSE
poll_resp = MagicMock()
poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
mock_instance = MagicMock()
mock_instance.post = AsyncMock(return_value=create_resp)
mock_instance.get = AsyncMock(return_value=poll_resp)
with (
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
patch("asyncio.sleep", new=AsyncMock()),
):
async for _ in block.run(
block.Input(
instructions="test instructions",
wait_for_completion=True,
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.05)
@pytest.mark.asyncio
async def test_no_merge_when_no_cost_dollars(self):
"""When completed response has no costDollars, merge_stats is not called."""
from backend.blocks.exa.research import ExaCreateResearchBlock
block = ExaCreateResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
create_resp = MagicMock()
create_resp.json.return_value = PENDING_RESEARCH_RESPONSE
poll_resp = MagicMock()
poll_resp.json.return_value = no_cost_response
mock_instance = MagicMock()
mock_instance.post = AsyncMock(return_value=create_resp)
mock_instance.get = AsyncMock(return_value=poll_resp)
with (
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
patch("asyncio.sleep", new=AsyncMock()),
):
async for _ in block.run(
block.Input(
instructions="test instructions",
wait_for_completion=True,
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert merged == []
# ---------------------------------------------------------------------------
# ExaGetResearchBlock — cost_dollars from single GET response
# ---------------------------------------------------------------------------
class TestExaGetResearchBlockCostTracking:
"""ExaGetResearchBlock merges cost when the fetched research has cost_dollars."""
@pytest.mark.asyncio
async def test_cost_merged_from_completed_research(self):
"""merge_stats called with provider_cost=total when research has costDollars."""
from backend.blocks.exa.research import ExaGetResearchBlock
block = ExaGetResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
get_resp = MagicMock()
get_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
mock_instance = MagicMock()
mock_instance.get = AsyncMock(return_value=get_resp)
with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
async for _ in block.run(
block.Input(
research_id="test-research-id",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.05)
@pytest.mark.asyncio
async def test_no_merge_when_no_cost_dollars(self):
"""When research has no costDollars, merge_stats is not called."""
from backend.blocks.exa.research import ExaGetResearchBlock
block = ExaGetResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
get_resp = MagicMock()
get_resp.json.return_value = no_cost_response
mock_instance = MagicMock()
mock_instance.get = AsyncMock(return_value=get_resp)
with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
async for _ in block.run(
block.Input(
research_id="test-research-id",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert merged == []
# ---------------------------------------------------------------------------
# ExaWaitForResearchBlock — cost_dollars from polling response
# ---------------------------------------------------------------------------
class TestExaWaitForResearchBlockCostTracking:
"""ExaWaitForResearchBlock merges cost when the polled research has cost_dollars."""
@pytest.mark.asyncio
async def test_cost_merged_when_research_completes(self):
"""merge_stats called with provider_cost=total once polling returns completed."""
from backend.blocks.exa.research import ExaWaitForResearchBlock
block = ExaWaitForResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
poll_resp = MagicMock()
poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
mock_instance = MagicMock()
mock_instance.get = AsyncMock(return_value=poll_resp)
with (
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
patch("asyncio.sleep", new=AsyncMock()),
):
async for _ in block.run(
block.Input(
research_id="test-research-id",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.05)
@pytest.mark.asyncio
async def test_no_merge_when_no_cost_dollars(self):
"""When completed research has no costDollars, merge_stats is not called."""
from backend.blocks.exa.research import ExaWaitForResearchBlock
block = ExaWaitForResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
poll_resp = MagicMock()
poll_resp.json.return_value = no_cost_response
mock_instance = MagicMock()
mock_instance.get = AsyncMock(return_value=poll_resp)
with (
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
patch("asyncio.sleep", new=AsyncMock()),
):
async for _ in block.run(
block.Input(
research_id="test-research-id",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert merged == []

View File

@@ -12,7 +12,6 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -233,11 +232,6 @@ class ExaCreateResearchBlock(Block):
if research.cost_dollars:
yield "cost_total", research.cost_dollars.total
self.merge_stats(
NodeExecutionStats(
provider_cost=research.cost_dollars.total
)
)
return
await asyncio.sleep(check_interval)
@@ -352,9 +346,6 @@ class ExaGetResearchBlock(Block):
yield "cost_searches", research.cost_dollars.num_searches
yield "cost_pages", research.cost_dollars.num_pages
yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens
self.merge_stats(
NodeExecutionStats(provider_cost=research.cost_dollars.total)
)
yield "error_message", research.error
@@ -441,9 +432,6 @@ class ExaWaitForResearchBlock(Block):
if research.cost_dollars:
yield "cost_total", research.cost_dollars.total
self.merge_stats(
NodeExecutionStats(provider_cost=research.cost_dollars.total)
)
return

View File

@@ -4,7 +4,6 @@ from typing import Optional
from exa_py import AsyncExa
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -207,6 +206,3 @@ class ExaSearchBlock(Block):
if response.cost_dollars:
yield "cost_dollars", response.cost_dollars
self.merge_stats(
NodeExecutionStats(provider_cost=response.cost_dollars.total)
)

View File

@@ -3,7 +3,6 @@ from typing import Optional
from exa_py import AsyncExa
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -168,6 +167,3 @@ class ExaFindSimilarBlock(Block):
if response.cost_dollars:
yield "cost_dollars", response.cost_dollars
self.merge_stats(
NodeExecutionStats(provider_cost=response.cost_dollars.total)
)

View File

@@ -1,6 +1,5 @@
import asyncio
import base64
import re
from abc import ABC
from email import encoders
from email.mime.base import MIMEBase
@@ -9,7 +8,7 @@ from email.mime.text import MIMEText
from email.policy import SMTP
from email.utils import getaddresses, parseaddr
from pathlib import Path
from typing import List, Literal, Optional, Protocol, runtime_checkable
from typing import List, Literal, Optional
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
@@ -43,52 +42,8 @@ NO_WRAP_POLICY = SMTP.clone(max_line_length=0)
def serialize_email_recipients(recipients: list[str]) -> str:
"""Serialize recipients list to comma-separated string.
Strips leading/trailing whitespace from each address to keep MIME
headers clean (mirrors the strip done in ``validate_email_recipients``).
"""
return ", ".join(addr.strip() for addr in recipients)
# RFC 5322 simplified pattern: local@domain where domain has at least one dot
_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
def validate_email_recipients(recipients: list[str], field_name: str = "to") -> None:
"""Validate that all recipients are plausible email addresses.
Raises ``ValueError`` with a user-friendly message listing every
invalid entry so the caller (or LLM) can correct them in one pass.
"""
invalid = [addr for addr in recipients if not _EMAIL_RE.match(addr.strip())]
if invalid:
formatted = ", ".join(f"'{a}'" for a in invalid)
raise ValueError(
f"Invalid email address(es) in '{field_name}': {formatted}. "
f"Each entry must be a valid email address (e.g. user@example.com)."
)
@runtime_checkable
class HasRecipients(Protocol):
to: list[str]
cc: list[str]
bcc: list[str]
def validate_all_recipients(input_data: HasRecipients) -> None:
"""Validate to/cc/bcc recipient fields on an input namespace.
Calls ``validate_email_recipients`` for ``to`` (required) and
``cc``/``bcc`` (when non-empty), raising ``ValueError`` on the
first field that contains an invalid address.
"""
validate_email_recipients(input_data.to, "to")
if input_data.cc:
validate_email_recipients(input_data.cc, "cc")
if input_data.bcc:
validate_email_recipients(input_data.bcc, "bcc")
"""Serialize recipients list to comma-separated string."""
return ", ".join(recipients)
def _make_mime_text(
@@ -145,16 +100,14 @@ async def create_mime_message(
) -> str:
"""Create a MIME message with attachments and return base64-encoded raw message."""
validate_all_recipients(input_data)
message = MIMEMultipart()
message["to"] = serialize_email_recipients(input_data.to)
message["subject"] = input_data.subject
if input_data.cc:
message["cc"] = serialize_email_recipients(input_data.cc)
message["cc"] = ", ".join(input_data.cc)
if input_data.bcc:
message["bcc"] = serialize_email_recipients(input_data.bcc)
message["bcc"] = ", ".join(input_data.bcc)
# Use the new helper function with content_type if available
content_type = getattr(input_data, "content_type", None)
@@ -1214,15 +1167,13 @@ async def _build_reply_message(
references.append(headers["message-id"])
# Create MIME message
validate_all_recipients(input_data)
msg = MIMEMultipart()
if input_data.to:
msg["To"] = serialize_email_recipients(input_data.to)
msg["To"] = ", ".join(input_data.to)
if input_data.cc:
msg["Cc"] = serialize_email_recipients(input_data.cc)
msg["Cc"] = ", ".join(input_data.cc)
if input_data.bcc:
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
msg["Bcc"] = ", ".join(input_data.bcc)
msg["Subject"] = subject
if headers.get("message-id"):
msg["In-Reply-To"] = headers["message-id"]
@@ -1734,16 +1685,13 @@ To: {original_to}
else:
body = f"{forward_header}\n\n{original_body}"
# Validate all recipient lists before building the MIME message
validate_all_recipients(input_data)
# Create MIME message
msg = MIMEMultipart()
msg["To"] = serialize_email_recipients(input_data.to)
msg["To"] = ", ".join(input_data.to)
if input_data.cc:
msg["Cc"] = serialize_email_recipients(input_data.cc)
msg["Cc"] = ", ".join(input_data.cc)
if input_data.bcc:
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
msg["Bcc"] = ", ".join(input_data.bcc)
msg["Subject"] = subject
# Add body with proper content type

View File

@@ -14,7 +14,6 @@ from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -118,11 +117,6 @@ class GoogleMapsSearchBlock(Block):
input_data.radius,
input_data.max_results,
)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(places)), provider_cost_type="items"
)
)
for place in places:
yield "place", place

View File

@@ -2,8 +2,6 @@ import copy
from datetime import date, time
from typing import Any, Optional
from pydantic import AliasChoices, Field
from backend.blocks._base import (
Block,
BlockCategory,
@@ -30,9 +28,9 @@ class AgentInputBlock(Block):
"""
This block is used to provide input to the graph.
It takes in a value, name, and description.
It takes in a value, name, description, default values list and bool to limit selection to default values.
It outputs the value passed as input.
It Outputs the value passed as input.
"""
class Input(BlockSchemaInput):
@@ -49,6 +47,12 @@ class AgentInputBlock(Block):
default=None,
advanced=True,
)
placeholder_values: list = SchemaField(
description="The placeholder values to be passed as input.",
default_factory=list,
advanced=True,
hidden=True,
)
advanced: bool = SchemaField(
description="Whether to show the input in the advanced section, if the field is not required.",
default=False,
@@ -61,7 +65,10 @@ class AgentInputBlock(Block):
)
def generate_schema(self):
return copy.deepcopy(self.get_field_schema("value"))
schema = copy.deepcopy(self.get_field_schema("value"))
if possible_values := self.placeholder_values:
schema["enum"] = possible_values
return schema
class Output(BlockSchema):
# Use BlockSchema to avoid automatic error field for interface definition
@@ -79,16 +86,18 @@ class AgentInputBlock(Block):
"value": "Hello, World!",
"name": "input_1",
"description": "Example test input.",
"placeholder_values": [],
},
{
"value": 42,
"value": "Hello, World!",
"name": "input_2",
"description": "Example numeric input.",
"description": "Example test input with placeholders.",
"placeholder_values": ["Hello, World!"],
},
],
"test_output": [
("result", "Hello, World!"),
("result", 42),
("result", "Hello, World!"),
],
"categories": {BlockCategory.INPUT, BlockCategory.BASIC},
"block_type": BlockType.INPUT,
@@ -236,11 +245,13 @@ class AgentShortTextInputBlock(AgentInputBlock):
"value": "Hello",
"name": "short_text_1",
"description": "Short text example 1",
"placeholder_values": [],
},
{
"value": "Quick test",
"name": "short_text_2",
"description": "Short text example 2",
"placeholder_values": ["Quick test", "Another option"],
},
],
test_output=[
@@ -274,11 +285,13 @@ class AgentLongTextInputBlock(AgentInputBlock):
"value": "Lorem ipsum dolor sit amet...",
"name": "long_text_1",
"description": "Long text example 1",
"placeholder_values": [],
},
{
"value": "Another multiline text input.",
"name": "long_text_2",
"description": "Long text example 2",
"placeholder_values": ["Another multiline text input."],
},
],
test_output=[
@@ -312,11 +325,13 @@ class AgentNumberInputBlock(AgentInputBlock):
"value": 42,
"name": "number_input_1",
"description": "Number example 1",
"placeholder_values": [],
},
{
"value": 314,
"name": "number_input_2",
"description": "Number example 2",
"placeholder_values": [314, 2718],
},
],
test_output=[
@@ -469,8 +484,7 @@ class AgentFileInputBlock(AgentInputBlock):
class AgentDropdownInputBlock(AgentInputBlock):
"""
A specialized text input block that presents a dropdown selector
restricted to a fixed set of values.
A specialized text input block that relies on placeholder_values to present a dropdown.
"""
class Input(AgentInputBlock.Input):
@@ -480,26 +494,13 @@ class AgentDropdownInputBlock(AgentInputBlock):
advanced=False,
title="Default Value",
)
# Use Field() directly (not SchemaField) to pass validation_alias,
# which handles backward compat for legacy "placeholder_values" across
# all construction paths (model_construct, __init__, model_validate).
options: list = Field(
placeholder_values: list = SchemaField(
description="Possible values for the dropdown.",
default_factory=list,
advanced=False,
title="Dropdown Options",
description=(
"If provided, renders the input as a dropdown selector "
"restricted to these values. Leave empty for free-text input."
),
validation_alias=AliasChoices("options", "placeholder_values"),
json_schema_extra={"advanced": False, "secret": False},
)
def generate_schema(self):
schema = super().generate_schema()
if possible_values := self.options:
schema["enum"] = possible_values
return schema
class Output(AgentInputBlock.Output):
result: str = SchemaField(description="Selected dropdown value.")
@@ -514,13 +515,13 @@ class AgentDropdownInputBlock(AgentInputBlock):
{
"value": "Option A",
"name": "dropdown_1",
"options": ["Option A", "Option B", "Option C"],
"placeholder_values": ["Option A", "Option B", "Option C"],
"description": "Dropdown example 1",
},
{
"value": "Option C",
"name": "dropdown_2",
"options": ["Option A", "Option B", "Option C"],
"placeholder_values": ["Option A", "Option B", "Option C"],
"description": "Dropdown example 2",
},
],

View File

@@ -10,7 +10,7 @@ from backend.blocks.jina._auth import (
JinaCredentialsField,
JinaCredentialsInput,
)
from backend.data.model import NodeExecutionStats, SchemaField
from backend.data.model import SchemaField
from backend.util.request import Requests
@@ -45,13 +45,5 @@ class JinaEmbeddingBlock(Block):
}
data = {"input": input_data.texts, "model": input_data.model}
response = await Requests().post(url, headers=headers, json=data)
resp_json = response.json()
embeddings = [e["embedding"] for e in resp_json["data"]]
usage = resp_json.get("usage", {})
if usage.get("total_tokens"):
self.merge_stats(
NodeExecutionStats(
input_token_count=usage.get("total_tokens", 0),
)
)
embeddings = [e["embedding"] for e in response.json()["data"]]
yield "embeddings", embeddings

View File

@@ -1,7 +1,6 @@
# This file contains a lot of prompt block strings that would trigger "line too long"
# flake8: noqa: E501
import logging
import math
import re
import secrets
from abc import ABC
@@ -14,7 +13,6 @@ import ollama
import openai
from anthropic.types import ToolParam
from groq import AsyncGroq
from openai.types.chat import ChatCompletion as OpenAIChatCompletion
from pydantic import BaseModel, SecretStr
from backend.blocks._base import (
@@ -51,9 +49,6 @@ settings = Settings()
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
fmt = TextFormatter(autoescape=False)
# HTTP status codes for user-caused errors that should not be reported to Sentry.
USER_ERROR_STATUS_CODES = (401, 403, 429)
LLMProviderName = Literal[
ProviderName.AIML_API,
ProviderName.ANTHROPIC,
@@ -106,18 +101,6 @@ class LlmModelMeta(EnumMeta):
class LlmModel(str, Enum, metaclass=LlmModelMeta):
@classmethod
def _missing_(cls, value: object) -> "LlmModel | None":
"""Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'."""
if isinstance(value, str) and "/" in value:
stripped = value.split("/", 1)[1]
try:
return cls(stripped)
except ValueError:
return None
return None
# OpenAI models
O3_MINI = "o3-mini"
O3 = "o3-2025-04-16"
@@ -207,19 +190,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
KIMI_K2 = "moonshotai/kimi-k2"
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
QWEN3_CODER = "qwen/qwen3-coder"
# Z.ai (Zhipu) models
ZAI_GLM_4_32B = "z-ai/glm-4-32b"
ZAI_GLM_4_5 = "z-ai/glm-4.5"
ZAI_GLM_4_5_AIR = "z-ai/glm-4.5-air"
ZAI_GLM_4_5_AIR_FREE = "z-ai/glm-4.5-air:free"
ZAI_GLM_4_5V = "z-ai/glm-4.5v"
ZAI_GLM_4_6 = "z-ai/glm-4.6"
ZAI_GLM_4_6V = "z-ai/glm-4.6v"
ZAI_GLM_4_7 = "z-ai/glm-4.7"
ZAI_GLM_4_7_FLASH = "z-ai/glm-4.7-flash"
ZAI_GLM_5 = "z-ai/glm-5"
ZAI_GLM_5_TURBO = "z-ai/glm-5-turbo"
ZAI_GLM_5V_TURBO = "z-ai/glm-5v-turbo"
# Llama API models
LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
@@ -645,43 +615,6 @@ MODEL_METADATA = {
LlmModel.QWEN3_CODER: ModelMetadata(
"open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
),
# https://openrouter.ai/models?q=z-ai
LlmModel.ZAI_GLM_4_32B: ModelMetadata(
"open_router", 128000, 128000, "GLM 4 32B", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_5: ModelMetadata(
"open_router", 131072, 98304, "GLM 4.5", "OpenRouter", "Z.ai", 2
),
LlmModel.ZAI_GLM_4_5_AIR: ModelMetadata(
"open_router", 131072, 98304, "GLM 4.5 Air", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_5_AIR_FREE: ModelMetadata(
"open_router", 131072, 96000, "GLM 4.5 Air (Free)", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_5V: ModelMetadata(
"open_router", 65536, 16384, "GLM 4.5V", "OpenRouter", "Z.ai", 2
),
LlmModel.ZAI_GLM_4_6: ModelMetadata(
"open_router", 204800, 204800, "GLM 4.6", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_6V: ModelMetadata(
"open_router", 131072, 131072, "GLM 4.6V", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_7: ModelMetadata(
"open_router", 202752, 65535, "GLM 4.7", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_7_FLASH: ModelMetadata(
"open_router", 202752, 202752, "GLM 4.7 Flash", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_5: ModelMetadata(
"open_router", 80000, 80000, "GLM 5", "OpenRouter", "Z.ai", 2
),
LlmModel.ZAI_GLM_5_TURBO: ModelMetadata(
"open_router", 202752, 131072, "GLM 5 Turbo", "OpenRouter", "Z.ai", 3
),
LlmModel.ZAI_GLM_5V_TURBO: ModelMetadata(
"open_router", 202752, 131072, "GLM 5V Turbo", "OpenRouter", "Z.ai", 3
),
# Llama API models
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata(
"llama_api",
@@ -739,7 +672,6 @@ class LLMResponse(BaseModel):
prompt_tokens: int
completion_tokens: int
reasoning: Optional[str] = None
provider_cost: float | None = None
def convert_openai_tool_fmt_to_anthropic(
@@ -774,41 +706,9 @@ def convert_openai_tool_fmt_to_anthropic(
return anthropic_tools
def extract_openrouter_cost(response: OpenAIChatCompletion) -> float | None:
"""Extract OpenRouter's `x-total-cost` header from an OpenAI SDK response.
OpenRouter returns the per-request USD cost in a response header. The
OpenAI SDK exposes the raw httpx response via an undocumented `_response`
attribute. We use try/except AttributeError so that if the SDK ever drops
or renames that attribute, the warning is visible in logs rather than
silently degrading to no cost tracking.
"""
try:
raw_resp = response._response # type: ignore[attr-defined]
except AttributeError:
logger.warning(
"OpenAI SDK response missing _response attribute"
" — OpenRouter cost tracking unavailable"
)
return None
try:
cost_header = raw_resp.headers.get("x-total-cost")
if not cost_header:
return None
cost = float(cost_header)
if not math.isfinite(cost):
return None
return cost
except (ValueError, TypeError, AttributeError):
return None
def extract_openai_reasoning(response) -> str | None:
"""Extract reasoning from OpenAI-compatible response if available."""
"""Note: This will likely not working since the reasoning is not present in another Response API"""
if not response.choices:
logger.warning("LLM response has empty choices in extract_openai_reasoning")
return None
reasoning = None
choice = response.choices[0]
if hasattr(choice, "reasoning") and getattr(choice, "reasoning", None):
@@ -824,9 +724,6 @@ def extract_openai_reasoning(response) -> str | None:
def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
"""Extract tool calls from OpenAI-compatible response."""
if not response.choices:
logger.warning("LLM response has empty choices in extract_openai_tool_calls")
return None
if response.choices[0].message.tool_calls:
return [
ToolContentBlock(
@@ -899,19 +796,6 @@ async def llm_call(
)
prompt = result.messages
# Sanitize unpaired surrogates in message content to prevent
# UnicodeEncodeError when httpx encodes the JSON request body.
for msg in prompt:
content = msg.get("content")
if isinstance(content, str):
try:
content.encode("utf-8")
except UnicodeEncodeError:
logger.warning("Sanitized unpaired surrogates in LLM prompt content")
msg["content"] = content.encode("utf-8", errors="surrogatepass").decode(
"utf-8", errors="replace"
)
# Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt)
model_max_output = llm_model.max_output_tokens or int(2**15)
@@ -994,60 +878,65 @@ async def llm_call(
client = anthropic.AsyncAnthropic(
api_key=credentials.api_key.get_secret_value()
)
resp = await client.messages.create(
model=llm_model.value,
system=sysprompt,
messages=messages,
max_tokens=max_tokens,
tools=an_tools,
timeout=600,
)
if not resp.content:
raise ValueError("No content returned from Anthropic.")
tool_calls = None
for content_block in resp.content:
# Antropic is different to openai, need to iterate through
# the content blocks to find the tool calls
if content_block.type == "tool_use":
if tool_calls is None:
tool_calls = []
tool_calls.append(
ToolContentBlock(
id=content_block.id,
type=content_block.type,
function=ToolCall(
name=content_block.name,
arguments=json.dumps(content_block.input),
),
)
)
if not tool_calls and resp.stop_reason == "tool_use":
logger.warning(
f"Tool use stop reason but no tool calls found in content. {resp}"
try:
resp = await client.messages.create(
model=llm_model.value,
system=sysprompt,
messages=messages,
max_tokens=max_tokens,
tools=an_tools,
timeout=600,
)
reasoning = None
for content_block in resp.content:
if hasattr(content_block, "type") and content_block.type == "thinking":
reasoning = content_block.thinking
break
if not resp.content:
raise ValueError("No content returned from Anthropic.")
return LLMResponse(
raw_response=resp,
prompt=prompt,
response=(
resp.content[0].name
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
else getattr(resp.content[0], "text", "")
),
tool_calls=tool_calls,
prompt_tokens=resp.usage.input_tokens,
completion_tokens=resp.usage.output_tokens,
reasoning=reasoning,
)
tool_calls = None
for content_block in resp.content:
# Antropic is different to openai, need to iterate through
# the content blocks to find the tool calls
if content_block.type == "tool_use":
if tool_calls is None:
tool_calls = []
tool_calls.append(
ToolContentBlock(
id=content_block.id,
type=content_block.type,
function=ToolCall(
name=content_block.name,
arguments=json.dumps(content_block.input),
),
)
)
if not tool_calls and resp.stop_reason == "tool_use":
logger.warning(
f"Tool use stop reason but no tool calls found in content. {resp}"
)
reasoning = None
for content_block in resp.content:
if hasattr(content_block, "type") and content_block.type == "thinking":
reasoning = content_block.thinking
break
return LLMResponse(
raw_response=resp,
prompt=prompt,
response=(
resp.content[0].name
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
else getattr(resp.content[0], "text", "")
),
tool_calls=tool_calls,
prompt_tokens=resp.usage.input_tokens,
completion_tokens=resp.usage.output_tokens,
reasoning=reasoning,
)
except anthropic.APIError as e:
error_message = f"Anthropic API error: {str(e)}"
logger.error(error_message)
raise ValueError(error_message)
elif provider == "groq":
if tools:
raise ValueError("Groq does not support tools.")
@@ -1060,8 +949,6 @@ async def llm_call(
response_format=response_format, # type: ignore
max_tokens=max_tokens,
)
if not response.choices:
raise ValueError("Groq returned empty choices in response")
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
@@ -1121,8 +1008,12 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
# If there's no response, raise an error
if not response.choices:
raise ValueError(f"OpenRouter returned empty choices: {response}")
if response:
raise ValueError(f"OpenRouter error: {response}")
else:
raise ValueError("No response from OpenRouter.")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -1135,7 +1026,6 @@ async def llm_call(
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
reasoning=reasoning,
provider_cost=extract_openrouter_cost(response),
)
elif provider == "llama_api":
tools_param = tools if tools else openai.NOT_GIVEN
@@ -1160,8 +1050,12 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
# If there's no response, raise an error
if not response.choices:
raise ValueError(f"Llama API returned empty choices: {response}")
if response:
raise ValueError(f"Llama API error: {response}")
else:
raise ValueError("No response from Llama API.")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -1191,8 +1085,6 @@ async def llm_call(
messages=prompt, # type: ignore
max_tokens=max_tokens,
)
if not completion.choices:
raise ValueError("AI/ML API returned empty choices in response")
return LLMResponse(
raw_response=completion.choices[0].message,
@@ -1229,9 +1121,6 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
if not response.choices:
raise ValueError(f"v0 API returned empty choices: {response}")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -1443,7 +1332,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
error_feedback_message = ""
llm_model = input_data.model
last_attempt_cost: float | None = None
for retry_count in range(input_data.retry):
logger.debug(f"LLM request: {prompt}")
@@ -1461,15 +1349,12 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
max_tokens=input_data.max_tokens,
)
response_text = llm_response.response
# Merge token counts for every attempt (each call costs tokens).
# provider_cost (actual USD) is tracked separately and only merged
# on success to avoid double-counting across retries.
token_stats = NodeExecutionStats(
input_token_count=llm_response.prompt_tokens,
output_token_count=llm_response.completion_tokens,
self.merge_stats(
NodeExecutionStats(
input_token_count=llm_response.prompt_tokens,
output_token_count=llm_response.completion_tokens,
)
)
self.merge_stats(token_stats)
last_attempt_cost = llm_response.provider_cost
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
if input_data.expected_format:
@@ -1538,7 +1423,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
NodeExecutionStats(
llm_call_count=retry_count + 1,
llm_retry_count=retry_count,
provider_cost=last_attempt_cost,
)
)
yield "response", response_obj
@@ -1559,23 +1443,13 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
NodeExecutionStats(
llm_call_count=retry_count + 1,
llm_retry_count=retry_count,
provider_cost=last_attempt_cost,
)
)
yield "response", {"response": response_text}
yield "prompt", self.prompt
return
except Exception as e:
is_user_error = (
isinstance(e, (anthropic.APIStatusError, openai.APIStatusError))
and e.status_code in USER_ERROR_STATUS_CODES
)
if is_user_error:
logger.warning(f"Error calling LLM: {e}")
error_feedback_message = f"Error calling LLM: {e}"
break
else:
logger.exception(f"Error calling LLM: {e}")
logger.exception(f"Error calling LLM: {e}")
if (
"maximum context length" in str(e).lower()
or "token limit" in str(e).lower()
@@ -2105,19 +1979,6 @@ class AIConversationBlock(AIBlockBase):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
has_messages = any(
isinstance(m, dict)
and isinstance(m.get("content"), str)
and bool(m["content"].strip())
for m in (input_data.messages or [])
)
has_prompt = bool(input_data.prompt and input_data.prompt.strip())
if not has_messages and not has_prompt:
raise ValueError(
"Cannot call LLM with no messages and no prompt. "
"Provide at least one message or a non-empty prompt."
)
response = await self.llm_call(
AIStructuredResponseGeneratorBlock.Input(
prompt=input_data.prompt,

View File

@@ -89,12 +89,6 @@ class MCPToolBlock(Block):
default={},
hidden=True,
)
tool_description: str = SchemaField(
description="Description of the selected MCP tool. "
"Populated automatically when a tool is selected.",
default="",
hidden=True,
)
tool_arguments: dict[str, Any] = SchemaField(
description="Arguments to pass to the selected MCP tool. "

View File

@@ -23,7 +23,7 @@ from backend.blocks.smartlead.models import (
SaveSequencesResponse,
Sequence,
)
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
from backend.data.model import CredentialsField, SchemaField
class CreateCampaignBlock(Block):
@@ -226,12 +226,6 @@ class AddLeadToCampaignBlock(Block):
response = await self.add_leads_to_campaign(
input_data.campaign_id, input_data.lead_list, credentials
)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(input_data.lead_list)),
provider_cost_type="items",
)
)
yield "campaign_id", input_data.campaign_id
yield "upload_count", response.upload_count

View File

@@ -1,323 +0,0 @@
import asyncio
from typing import Any, Literal
from pydantic import SecretStr
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import DBAPIError, OperationalError, ProgrammingError
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.sql_query_helpers import (
_DATABASE_TYPE_DEFAULT_PORT,
_DATABASE_TYPE_TO_DRIVER,
DatabaseType,
_execute_query,
_sanitize_error,
_validate_query_is_read_only,
_validate_single_statement,
)
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
SchemaField,
UserPasswordCredentials,
)
from backend.integrations.providers import ProviderName
from backend.util.request import resolve_and_check_blocked
TEST_CREDENTIALS = UserPasswordCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="database",
username=SecretStr("test_user"),
password=SecretStr("test_pass"),
title="Mock Database credentials",
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
DatabaseCredentials = UserPasswordCredentials
DatabaseCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.DATABASE],
Literal["user_password"],
]
def DatabaseCredentialsField() -> DatabaseCredentialsInput:
return CredentialsField(
description="Database username and password",
)
class SQLQueryBlock(Block):
class Input(BlockSchemaInput):
database_type: DatabaseType = SchemaField(
default=DatabaseType.POSTGRES,
description="Database engine",
advanced=False,
)
host: SecretStr = SchemaField(
description=(
"Database hostname or IP address. "
"Treated as a secret to avoid leaking infrastructure details. "
"Private/internal IPs are blocked (SSRF protection)."
),
placeholder="db.example.com",
secret=True,
)
port: int | None = SchemaField(
default=None,
description=(
"Database port (leave empty for default: "
"PostgreSQL: 5432, MySQL: 3306, MSSQL: 1433)"
),
ge=1,
le=65535,
)
database: str = SchemaField(
description="Name of the database to connect to",
placeholder="my_database",
)
query: str = SchemaField(
description="SQL query to execute",
placeholder="SELECT * FROM analytics.daily_active_users LIMIT 10",
)
read_only: bool = SchemaField(
default=True,
description=(
"When enabled (default), only SELECT queries are allowed "
"and the database session is set to read-only mode. "
"Disable to allow write operations (INSERT, UPDATE, DELETE, etc.)."
),
)
timeout: int = SchemaField(
default=30,
description="Query timeout in seconds (max 120)",
ge=1,
le=120,
)
max_rows: int = SchemaField(
default=1000,
description="Maximum number of rows to return (max 10000)",
ge=1,
le=10000,
)
credentials: DatabaseCredentialsInput = DatabaseCredentialsField()
class Output(BlockSchemaOutput):
results: list[dict[str, Any]] = SchemaField(
description="Query results as a list of row dictionaries"
)
columns: list[str] = SchemaField(
description="Column names from the query result"
)
row_count: int = SchemaField(description="Number of rows returned")
truncated: bool = SchemaField(
description=(
"True when the result set was capped by max_rows, "
"indicating additional rows exist in the database"
)
)
affected_rows: int = SchemaField(
description="Number of rows affected by a write query (INSERT/UPDATE/DELETE)"
)
error: str = SchemaField(description="Error message if the query failed")
def __init__(self):
super().__init__(
id="4dc35c0f-4fd8-465e-9616-5a216f1ba2bc",
description=(
"Execute a SQL query. Read-only by default for safety "
"-- disable to allow write operations. "
"Supports PostgreSQL, MySQL, and MSSQL via SQLAlchemy."
),
categories={BlockCategory.DATA},
input_schema=SQLQueryBlock.Input,
output_schema=SQLQueryBlock.Output,
test_input={
"query": "SELECT 1 AS test_col",
"database_type": DatabaseType.POSTGRES,
"host": "localhost",
"database": "test_db",
"timeout": 30,
"max_rows": 1000,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("results", [{"test_col": 1}]),
("columns", ["test_col"]),
("row_count", 1),
("truncated", False),
],
test_mock={
"execute_query": lambda *_args, **_kwargs: (
[{"test_col": 1}],
["test_col"],
-1,
False,
),
"check_host_allowed": lambda *_args, **_kwargs: ["127.0.0.1"],
},
)
@staticmethod
async def check_host_allowed(host: str) -> list[str]:
"""Validate that the given host is not a private/blocked address.
Returns the list of resolved IP addresses so the caller can pin the
connection to the validated IP (preventing DNS rebinding / TOCTOU).
Raises ValueError or OSError if the host is blocked.
Extracted as a method so it can be mocked during block tests.
"""
return await resolve_and_check_blocked(host)
@staticmethod
def execute_query(
connection_url: URL | str,
query: str,
timeout: int,
max_rows: int,
read_only: bool = True,
database_type: DatabaseType = DatabaseType.POSTGRES,
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
"""Execute a SQL query and return (rows, columns, affected_rows, truncated).
Delegates to ``_execute_query`` in ``sql_query_helpers``.
Extracted as a method so it can be mocked during block tests.
"""
return _execute_query(
connection_url=connection_url,
query=query,
timeout=timeout,
max_rows=max_rows,
read_only=read_only,
database_type=database_type,
)
async def run(
self,
input_data: Input,
*,
credentials: DatabaseCredentials,
**_kwargs: Any,
) -> BlockOutput:
# Validate query structure and read-only constraints.
error = self._validate_query(input_data)
if error:
yield "error", error
return
# Validate host and resolve for SSRF protection.
host, pinned_host, error = await self._resolve_host(input_data)
if error:
yield "error", error
return
# Build connection URL and execute.
port = input_data.port or _DATABASE_TYPE_DEFAULT_PORT[input_data.database_type]
username = credentials.username.get_secret_value()
connection_url = URL.create(
drivername=_DATABASE_TYPE_TO_DRIVER[input_data.database_type],
username=username,
password=credentials.password.get_secret_value(),
host=pinned_host,
port=port,
database=input_data.database,
)
conn_str = connection_url.render_as_string(hide_password=True)
db_name = input_data.database
def _sanitize(err: Exception) -> str:
return _sanitize_error(
str(err).strip(),
conn_str,
host=pinned_host,
original_host=host,
username=username,
port=port,
database=db_name,
)
try:
results, columns, affected, truncated = await asyncio.to_thread(
self.execute_query,
connection_url=connection_url,
query=input_data.query,
timeout=input_data.timeout,
max_rows=input_data.max_rows,
read_only=input_data.read_only,
database_type=input_data.database_type,
)
yield "results", results
yield "columns", columns
yield "row_count", len(results)
yield "truncated", truncated
if affected >= 0:
yield "affected_rows", affected
except OperationalError as e:
yield (
"error",
self._classify_operational_error(
_sanitize(e),
input_data.timeout,
),
)
except ProgrammingError as e:
yield "error", f"SQL error: {_sanitize(e)}"
except DBAPIError as e:
yield "error", f"Database error: {_sanitize(e)}"
except ModuleNotFoundError:
yield (
"error",
(
f"Database driver not available for "
f"{input_data.database_type.value}. "
f"Please contact the platform administrator."
),
)
@staticmethod
def _validate_query(input_data: "SQLQueryBlock.Input") -> str | None:
"""Validate query structure and read-only constraints."""
stmt_error, parsed_stmt = _validate_single_statement(input_data.query)
if stmt_error:
return stmt_error
assert parsed_stmt is not None
if input_data.read_only:
return _validate_query_is_read_only(parsed_stmt)
return None
async def _resolve_host(
self, input_data: "SQLQueryBlock.Input"
) -> tuple[str, str, str | None]:
"""Validate and resolve the database host. Returns (host, pinned_ip, error)."""
host = input_data.host.get_secret_value().strip()
if not host:
return "", "", "Database host is required."
if host.startswith("/"):
return host, "", "Unix socket connections are not allowed."
try:
resolved_ips = await self.check_host_allowed(host)
except (ValueError, OSError) as e:
return host, "", f"Blocked host: {str(e).strip()}"
return host, resolved_ips[0], None
@staticmethod
def _classify_operational_error(sanitized_msg: str, timeout: int) -> str:
"""Classify an already-sanitized OperationalError for user display."""
lower = sanitized_msg.lower()
if "timeout" in lower or "cancel" in lower:
return f"Query timed out after {timeout}s."
if "connect" in lower:
return f"Failed to connect to database: {sanitized_msg}"
return f"Database error: {sanitized_msg}"

File diff suppressed because it is too large Load Diff

View File

@@ -1,430 +0,0 @@
import re
from datetime import date, datetime, time
from decimal import Decimal
from enum import Enum
from typing import Any
import sqlparse
from sqlalchemy import create_engine, text
from sqlalchemy.engine.url import URL
class DatabaseType(str, Enum):
POSTGRES = "postgres"
MYSQL = "mysql"
MSSQL = "mssql"
# Defense-in-depth: reject queries containing data-modifying keywords.
# These are checked against parsed SQL tokens (not raw text) so column names
# and string literals do not cause false positives.
_DISALLOWED_KEYWORDS = {
"INSERT",
"UPDATE",
"DELETE",
"DROP",
"ALTER",
"CREATE",
"TRUNCATE",
"GRANT",
"REVOKE",
"COPY",
"EXECUTE",
"CALL",
"SET",
"RESET",
"DISCARD",
"NOTIFY",
"DO",
# MySQL file exfiltration: LOAD DATA LOCAL INFILE reads server/client files
"LOAD",
# MySQL REPLACE is INSERT-or-UPDATE; data modification
"REPLACE",
# ANSI MERGE (UPSERT) modifies data
"MERGE",
# MSSQL BULK INSERT loads external files into tables
"BULK",
# MSSQL EXEC / EXEC sp_name runs stored procedures (arbitrary code)
"EXEC",
}
# Map DatabaseType enum values to the expected SQLAlchemy driver prefix.
_DATABASE_TYPE_TO_DRIVER = {
DatabaseType.POSTGRES: "postgresql",
DatabaseType.MYSQL: "mysql+pymysql",
DatabaseType.MSSQL: "mssql+pymssql",
}
# Connection timeout in seconds passed to the DBAPI driver (connect_timeout /
# login_timeout). This bounds how long the driver waits to establish a TCP
# connection to the database server. It is separate from the per-statement
# timeout configured via SET commands inside _configure_session().
_CONNECT_TIMEOUT_SECONDS = 10
# Default ports for each database type.
_DATABASE_TYPE_DEFAULT_PORT = {
DatabaseType.POSTGRES: 5432,
DatabaseType.MYSQL: 3306,
DatabaseType.MSSQL: 1433,
}
def _sanitize_error(
error_msg: str,
connection_string: str,
*,
host: str = "",
original_host: str = "",
username: str = "",
port: int = 0,
database: str = "",
) -> str:
"""Remove connection string, credentials, and infrastructure details
from error messages so they are safe to expose to the LLM.
Scrubs:
- The full connection string
- URL-embedded credentials (``://user:pass@``)
- ``password=<value>`` key-value pairs
- The database hostname / IP used for the connection
- The original (pre-resolution) hostname provided by the user
- Any IPv4 addresses that appear in the message
- Any bracketed IPv6 addresses (e.g. ``[::1]``, ``[fe80::1%eth0]``)
- The database username
- The database port number
- The database name
"""
sanitized = error_msg.replace(connection_string, "<connection_string>")
sanitized = re.sub(r"password=[^\s&]+", "password=***", sanitized)
sanitized = re.sub(r"://[^@]+@", "://***:***@", sanitized)
# Replace the known host (may be an IP already) before the generic IP pass.
# Also replace the original (pre-DNS-resolution) hostname if it differs.
if original_host and original_host != host:
sanitized = sanitized.replace(original_host, "<host>")
if host:
sanitized = sanitized.replace(host, "<host>")
# Replace any remaining IPv4 addresses (e.g. resolved IPs the driver logs)
sanitized = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", "<ip>", sanitized)
# Replace bracketed IPv6 addresses (e.g. "[::1]", "[fe80::1%eth0]")
sanitized = re.sub(r"\[[0-9a-fA-F:]+(?:%[^\]]+)?\]", "<ip>", sanitized)
# Replace the database username (handles double-quoted, single-quoted,
# and unquoted formats across PostgreSQL, MySQL, and MSSQL error messages).
if username:
sanitized = re.sub(
r"""for user ["']?""" + re.escape(username) + r"""["']?""",
"for user <user>",
sanitized,
)
# Catch remaining bare occurrences in various quote styles:
# - PostgreSQL: "FATAL: role "myuser" does not exist"
# - MySQL: "Access denied for user 'myuser'@'host'"
# - MSSQL: "Login failed for user 'myuser'"
sanitized = sanitized.replace(f'"{username}"', "<user>")
sanitized = sanitized.replace(f"'{username}'", "<user>")
# Replace the port number (handles "port 5432" and ":5432" formats)
if port:
port_str = re.escape(str(port))
sanitized = re.sub(
r"(?:port |:)" + port_str + r"(?![0-9])",
lambda m: ("port " if m.group().startswith("p") else ":") + "<port>",
sanitized,
)
# Replace the database name to avoid leaking internal infrastructure names.
# Use word-boundary regex to prevent mangling when the database name is a
# common substring (e.g. "test", "data", "on").
if database:
sanitized = re.sub(r"\b" + re.escape(database) + r"\b", "<database>", sanitized)
return sanitized
def _extract_keyword_tokens(parsed: sqlparse.sql.Statement) -> list[str]:
"""Extract keyword tokens from a parsed SQL statement.
Uses sqlparse token type classification to collect Keyword/DML/DDL/DCL
tokens. String literals and identifiers have different token types, so
they are naturally excluded from the result.
"""
return [
token.normalized.upper()
for token in parsed.flatten()
if token.ttype
in (
sqlparse.tokens.Keyword,
sqlparse.tokens.Keyword.DML,
sqlparse.tokens.Keyword.DDL,
sqlparse.tokens.Keyword.DCL,
)
]
def _has_disallowed_into(stmt: sqlparse.sql.Statement) -> bool:
"""Check if a statement contains a disallowed ``INTO`` clause.
``SELECT ... INTO @variable`` is a valid read-only MySQL syntax that stores
a query result into a session-scoped user variable. All other forms of
``INTO`` are data-modifying or file-writing and must be blocked:
* ``SELECT ... INTO new_table`` (PostgreSQL / MSSQL creates a table)
* ``SELECT ... INTO OUTFILE`` (MySQL writes to the filesystem)
* ``SELECT ... INTO DUMPFILE`` (MySQL writes to the filesystem)
* ``INSERT INTO ...`` (already blocked by INSERT being in the
disallowed set, but we reject INTO as well for defense-in-depth)
Returns ``True`` if the statement contains a disallowed ``INTO``.
"""
flat = list(stmt.flatten())
for i, token in enumerate(flat):
if not (
token.ttype in (sqlparse.tokens.Keyword,)
and token.normalized.upper() == "INTO"
):
continue
# Look at the first non-whitespace token after INTO.
j = i + 1
while j < len(flat) and flat[j].ttype is sqlparse.tokens.Text.Whitespace:
j += 1
if j >= len(flat):
# INTO at the very end malformed, block it.
return True
next_token = flat[j]
# MySQL user variable: either a single Name starting with "@"
# (e.g. ``@total``) or a bare ``@`` Operator token followed by a Name.
if next_token.ttype is sqlparse.tokens.Name and next_token.value.startswith(
"@"
):
continue
if next_token.ttype is sqlparse.tokens.Operator and next_token.value == "@":
continue
# Everything else (table name, OUTFILE, DUMPFILE, etc.) is disallowed.
return True
return False
def _validate_query_is_read_only(stmt: sqlparse.sql.Statement) -> str | None:
"""Validate that a parsed SQL statement is read-only (SELECT/WITH only).
Accepts an already-parsed statement from ``_validate_single_statement``
to avoid re-parsing. Checks:
1. Statement type must be SELECT (sqlparse classifies WITH...SELECT as SELECT)
2. No disallowed keywords (INSERT, UPDATE, DELETE, DROP, etc.)
3. No disallowed INTO clauses (allows MySQL ``SELECT ... INTO @variable``)
Returns an error message if the query is not read-only, None otherwise.
"""
# sqlparse returns 'SELECT' for SELECT and WITH...SELECT queries
if stmt.get_type() != "SELECT":
return "Only SELECT queries are allowed."
# Defense-in-depth: check parsed keyword tokens for disallowed keywords
for kw in _extract_keyword_tokens(stmt):
# Normalize multi-word tokens (e.g. "SET LOCAL" -> "SET")
base_kw = kw.split()[0] if " " in kw else kw
if base_kw in _DISALLOWED_KEYWORDS:
return f"Disallowed SQL keyword: {kw}"
# Contextual check for INTO: allow MySQL @variable syntax, block everything else
if _has_disallowed_into(stmt):
return "Disallowed SQL keyword: INTO"
return None
def _validate_single_statement(
query: str,
) -> tuple[str | None, sqlparse.sql.Statement | None]:
"""Validate that the query contains exactly one non-empty SQL statement.
Returns (error_message, parsed_statement). If error_message is not None,
the query is invalid and parsed_statement will be None.
"""
stripped = query.strip().rstrip(";").strip()
if not stripped:
return "Query is empty.", None
# Parse the SQL using sqlparse for proper tokenization
statements = sqlparse.parse(stripped)
# Filter out empty statements and comment-only statements
statements = [
s
for s in statements
if s.tokens
and str(s).strip()
and not all(
t.is_whitespace or t.ttype in sqlparse.tokens.Comment for t in s.flatten()
)
]
if not statements:
return "Query is empty.", None
# Reject multiple statements -- prevents injection via semicolons
if len(statements) > 1:
return "Only single statements are allowed.", None
return None, statements[0]
def _serialize_value(value: Any) -> Any:
"""Convert database-specific types to JSON-serializable Python types."""
if isinstance(value, Decimal):
# NaN / Infinity are not valid JSON numbers; serialize as strings.
if value.is_nan() or value.is_infinite():
return str(value)
# Use int for whole numbers; use str for fractional to preserve exact
# precision (float would silently round high-precision analytics values).
if value == value.to_integral_value():
return int(value)
return str(value)
if isinstance(value, (datetime, date, time)):
return value.isoformat()
if isinstance(value, memoryview):
return bytes(value).hex()
if isinstance(value, bytes):
return value.hex()
return value
def _configure_session(
conn: Any,
dialect_name: str,
timeout_ms: str,
read_only: bool,
) -> None:
"""Set session-level timeout and read-only mode for the given dialect.
Timeout limitations by database:
* **PostgreSQL** ``statement_timeout`` reliably cancels any running
statement (SELECT or DML) after the configured duration.
* **MySQL** ``MAX_EXECUTION_TIME`` only applies to **read-only SELECT**
statements. DML (INSERT/UPDATE/DELETE) and DDL are *not* bounded by
this hint; they rely on the server's ``wait_timeout`` /
``interactive_timeout`` instead. There is no session-level setting in
MySQL that reliably cancels long-running writes.
* **MSSQL** ``SET LOCK_TIMEOUT`` only limits how long the server waits
to acquire a **lock**. CPU-bound queries (e.g. large scans, hash
joins) that do not block on locks will *not* be cancelled. MSSQL has
no session-level ``statement_timeout`` equivalent; the closest
mechanism is Resource Governor (requires sysadmin configuration) or
``CONTEXT_INFO``-based external monitoring.
Note: SQLite is not supported by this block. The ``_configure_session``
function is a no-op for unrecognised dialect names, so an SQLite engine
would skip all SET commands silently. The block's ``DatabaseType`` enum
intentionally excludes SQLite.
"""
if dialect_name == "postgresql":
conn.execute(text("SET statement_timeout = " + timeout_ms))
if read_only:
conn.execute(text("SET default_transaction_read_only = ON"))
elif dialect_name == "mysql":
# NOTE: MAX_EXECUTION_TIME only applies to SELECT statements.
# Write queries (INSERT/UPDATE/DELETE) are not bounded by this
# setting; they rely on the database's wait_timeout instead.
# See docstring above for full limitations.
conn.execute(text("SET SESSION MAX_EXECUTION_TIME = " + timeout_ms))
if read_only:
conn.execute(text("SET SESSION TRANSACTION READ ONLY"))
elif dialect_name == "mssql":
# MSSQL: SET LOCK_TIMEOUT limits lock-wait time (ms) only.
# CPU-bound queries without lock contention are NOT cancelled.
# See docstring above for full limitations.
conn.execute(text("SET LOCK_TIMEOUT " + timeout_ms))
# MSSQL lacks a session-level read-only mode like
# PostgreSQL/MySQL. Read-only enforcement is handled by
# the SQL validation layer (_validate_query_is_read_only)
# and the ROLLBACK in the finally block.
def _run_in_transaction(
conn: Any,
dialect_name: str,
query: str,
max_rows: int,
read_only: bool,
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
"""Execute a query inside an explicit transaction, returning results.
Returns ``(rows, columns, affected_rows, truncated)`` where *truncated*
is ``True`` when ``fetchmany`` returned exactly ``max_rows`` rows,
indicating that additional rows may exist in the result set.
"""
# MSSQL uses T-SQL "BEGIN TRANSACTION"; others use "BEGIN".
begin_stmt = "BEGIN TRANSACTION" if dialect_name == "mssql" else "BEGIN"
conn.execute(text(begin_stmt))
try:
result = conn.execute(text(query))
affected = result.rowcount if not result.returns_rows else -1
columns = list(result.keys()) if result.returns_rows else []
rows = result.fetchmany(max_rows) if result.returns_rows else []
truncated = len(rows) == max_rows
results = [
{col: _serialize_value(val) for col, val in zip(columns, row)}
for row in rows
]
except Exception:
try:
conn.execute(text("ROLLBACK"))
except Exception:
pass
raise
else:
conn.execute(text("ROLLBACK" if read_only else "COMMIT"))
return results, columns, affected, truncated
def _execute_query(
connection_url: URL | str,
query: str,
timeout: int,
max_rows: int,
read_only: bool = True,
database_type: DatabaseType = DatabaseType.POSTGRES,
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
"""Execute a SQL query and return (rows, columns, affected_rows, truncated).
Uses SQLAlchemy to connect to any supported database.
For SELECT queries, rows are limited to ``max_rows`` via DBAPI fetchmany.
``truncated`` is ``True`` when the result set was capped by ``max_rows``.
For write queries, affected_rows contains the rowcount from the driver.
When ``read_only`` is True, the database session is set to read-only
mode and the transaction is always rolled back.
"""
# Determine driver-specific connection timeout argument.
# pymssql uses "login_timeout", while PostgreSQL/MySQL use "connect_timeout".
timeout_key = (
"login_timeout" if database_type == DatabaseType.MSSQL else "connect_timeout"
)
engine = create_engine(
connection_url, connect_args={timeout_key: _CONNECT_TIMEOUT_SECONDS}
)
try:
with engine.connect() as conn:
# Use AUTOCOMMIT so SET commands take effect immediately.
conn = conn.execution_options(isolation_level="AUTOCOMMIT")
# Compute timeout in milliseconds. The value is Pydantic-validated
# (ge=1, le=120), but we use int() as defense-in-depth.
# NOTE: SET commands do not support bind parameters in most
# databases, so we use str(int(...)) for safe interpolation.
timeout_ms = str(int(timeout * 1000))
_configure_session(conn, engine.dialect.name, timeout_ms, read_only)
return _run_in_transaction(
conn, engine.dialect.name, query, max_rows, read_only
)
finally:
engine.dispose()

View File

@@ -1,8 +1,13 @@
import logging
import signal
import threading
import warnings
from contextlib import contextmanager
from enum import Enum
from stagehand import AsyncStagehand
from stagehand.types.session_act_params import Options as ActOptions
# Monkey patch Stagehands to prevent signal handling in worker threads
import stagehand.main
from stagehand import Stagehand
from backend.blocks.llm import (
MODEL_METADATA,
@@ -23,6 +28,46 @@ from backend.sdk import (
SchemaField,
)
# Suppress false positive cleanup warning of litellm (a dependency of stagehand)
warnings.filterwarnings("ignore", module="litellm.llms.custom_httpx")
# Store the original method
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers
def safe_register_signal_handlers(self):
"""Only register signal handlers in the main thread"""
if threading.current_thread() is threading.main_thread():
original_register_signal_handlers(self)
else:
# Skip signal handling in worker threads
pass
# Replace the method
stagehand.main.Stagehand._register_signal_handlers = safe_register_signal_handlers
@contextmanager
def disable_signal_handling():
"""Context manager to temporarily disable signal handling"""
if threading.current_thread() is not threading.main_thread():
# In worker threads, temporarily replace signal.signal with a no-op
original_signal = signal.signal
def noop_signal(*args, **kwargs):
pass
signal.signal = noop_signal
try:
yield
finally:
signal.signal = original_signal
else:
# In main thread, don't modify anything
yield
logger = logging.getLogger(__name__)
@@ -103,10 +148,13 @@ class StagehandObserveBlock(Block):
instruction: str = SchemaField(
description="Natural language description of elements or actions to discover.",
)
dom_settle_timeout_ms: int = SchemaField(
description="Timeout in ms to wait for the DOM to settle after navigation.",
default=30000,
advanced=True,
iframes: bool = SchemaField(
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
default=True,
)
domSettleTimeoutMs: int = SchemaField(
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
default=45000,
)
class Output(BlockSchemaOutput):
@@ -137,28 +185,32 @@ class StagehandObserveBlock(Block):
logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}")
async with AsyncStagehand(
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
browserbase_project_id=input_data.browserbase_project_id,
model_api_key=model_credentials.api_key.get_secret_value(),
) as client:
session = await client.sessions.start(
with disable_signal_handling():
stagehand = Stagehand(
api_key=stagehand_credentials.api_key.get_secret_value(),
project_id=input_data.browserbase_project_id,
model_name=input_data.model.provider_name,
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
model_api_key=model_credentials.api_key.get_secret_value(),
)
try:
await session.navigate(url=input_data.url)
observe_response = await session.observe(
instruction=input_data.instruction,
)
for result in observe_response.data.result:
yield "selector", result.selector
yield "description", result.description
yield "method", result.method
yield "arguments", result.arguments
finally:
await session.end()
await stagehand.init()
page = stagehand.page
assert page is not None, "Stagehand page is not initialized"
await page.goto(input_data.url)
observe_results = await page.observe(
input_data.instruction,
iframes=input_data.iframes,
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
)
for result in observe_results:
yield "selector", result.selector
yield "description", result.description
yield "method", result.method
yield "arguments", result.arguments
class StagehandActBlock(Block):
@@ -190,22 +242,24 @@ class StagehandActBlock(Block):
description="Variables to use in the action. Variables contains data you want the action to use.",
default_factory=dict,
)
dom_settle_timeout_ms: int = SchemaField(
description="Timeout in ms to wait for the DOM to settle after navigation.",
default=30000,
advanced=True,
iframes: bool = SchemaField(
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
default=True,
)
timeout_ms: int = SchemaField(
description="Timeout in ms for each action.",
default=30000,
advanced=True,
domSettleTimeoutMs: int = SchemaField(
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
default=45000,
)
timeoutMs: int = SchemaField(
description="Timeout in milliseconds for DOM ready. Extended timeout for slow-loading forms",
default=60000,
)
class Output(BlockSchemaOutput):
success: bool = SchemaField(
description="Whether the action was completed successfully"
)
message: str = SchemaField(description="Details about the action's execution.")
message: str = SchemaField(description="Details about the actions execution.")
action: str = SchemaField(description="Action performed")
def __init__(self):
@@ -228,33 +282,32 @@ class StagehandActBlock(Block):
logger.debug(f"ACT: Using model provider {model_credentials.provider}")
async with AsyncStagehand(
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
browserbase_project_id=input_data.browserbase_project_id,
model_api_key=model_credentials.api_key.get_secret_value(),
) as client:
session = await client.sessions.start(
with disable_signal_handling():
stagehand = Stagehand(
api_key=stagehand_credentials.api_key.get_secret_value(),
project_id=input_data.browserbase_project_id,
model_name=input_data.model.provider_name,
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
model_api_key=model_credentials.api_key.get_secret_value(),
)
try:
await session.navigate(url=input_data.url)
for action in input_data.action:
act_options = ActOptions(
variables={k: v for k, v in input_data.variables.items()},
timeout=input_data.timeout_ms,
)
act_response = await session.act(
input=action,
options=act_options,
)
result = act_response.data.result
yield "success", result.success
yield "message", result.message
yield "action", result.action_description
finally:
await session.end()
await stagehand.init()
page = stagehand.page
assert page is not None, "Stagehand page is not initialized"
await page.goto(input_data.url)
for action in input_data.action:
action_results = await page.act(
action,
variables=input_data.variables,
iframes=input_data.iframes,
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
timeoutMs=input_data.timeoutMs,
)
yield "success", action_results.success
yield "message", action_results.message
yield "action", action_results.action
class StagehandExtractBlock(Block):
@@ -282,10 +335,13 @@ class StagehandExtractBlock(Block):
instruction: str = SchemaField(
description="Natural language description of elements or actions to discover.",
)
dom_settle_timeout_ms: int = SchemaField(
description="Timeout in ms to wait for the DOM to settle after navigation.",
default=30000,
advanced=True,
iframes: bool = SchemaField(
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
default=True,
)
domSettleTimeoutMs: int = SchemaField(
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
default=45000,
)
class Output(BlockSchemaOutput):
@@ -311,21 +367,24 @@ class StagehandExtractBlock(Block):
logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}")
async with AsyncStagehand(
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
browserbase_project_id=input_data.browserbase_project_id,
model_api_key=model_credentials.api_key.get_secret_value(),
) as client:
session = await client.sessions.start(
with disable_signal_handling():
stagehand = Stagehand(
api_key=stagehand_credentials.api_key.get_secret_value(),
project_id=input_data.browserbase_project_id,
model_name=input_data.model.provider_name,
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
model_api_key=model_credentials.api_key.get_secret_value(),
)
try:
await session.navigate(url=input_data.url)
extract_response = await session.extract(
instruction=input_data.instruction,
)
yield "extraction", str(extract_response.data.result)
finally:
await session.end()
await stagehand.init()
page = stagehand.page
assert page is not None, "Stagehand page is not initialized"
await page.goto(input_data.url)
extraction = await page.extract(
input_data.instruction,
iframes=input_data.iframes,
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
)
yield "extraction", str(extraction.model_dump()["extraction"])

View File

@@ -4,8 +4,6 @@ import pytest
from backend.blocks import get_blocks
from backend.blocks._base import Block, BlockSchemaInput
from backend.blocks.io import AgentDropdownInputBlock, AgentInputBlock
from backend.data.graph import BaseGraph
from backend.data.model import SchemaField
from backend.util.test import execute_block_test
@@ -281,113 +279,3 @@ class TestAutoCredentialsFieldsValidation:
assert "Duplicate auto_credentials kwarg_name 'credentials'" in str(
exc_info.value
)
def test_agent_input_block_ignores_legacy_placeholder_values():
"""Verify AgentInputBlock.Input.model_construct tolerates extra placeholder_values
for backward compatibility with existing agent JSON."""
legacy_data = {
"name": "url",
"value": "",
"description": "Enter a URL",
"placeholder_values": ["https://example.com"],
}
instance = AgentInputBlock.Input.model_construct(**legacy_data)
schema = instance.generate_schema()
assert (
"enum" not in schema
), "AgentInputBlock should not produce enum from legacy placeholder_values"
def test_dropdown_input_block_produces_enum():
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum
using the canonical 'options' field name."""
opts = ["Option A", "Option B"]
instance = AgentDropdownInputBlock.Input.model_construct(
name="choice", value=None, options=opts
)
schema = instance.generate_schema()
assert schema.get("enum") == opts
def test_dropdown_input_block_legacy_placeholder_values_produces_enum():
"""Verify backward compat: passing legacy 'placeholder_values' to
AgentDropdownInputBlock still produces enum via model_construct remap."""
opts = ["Option A", "Option B"]
instance = AgentDropdownInputBlock.Input.model_construct(
name="choice", value=None, placeholder_values=opts
)
schema = instance.generate_schema()
assert (
schema.get("enum") == opts
), "Legacy placeholder_values should be remapped to options"
def test_generate_schema_integration_legacy_placeholder_values():
"""Test the full Graph._generate_schema path with legacy placeholder_values
on AgentInputBlock — verifies no enum leaks through the graph loading path."""
legacy_input_default = {
"name": "url",
"value": "",
"description": "Enter a URL",
"placeholder_values": ["https://example.com"],
}
result = BaseGraph._generate_schema(
(AgentInputBlock.Input, legacy_input_default),
)
url_props = result["properties"]["url"]
assert (
"enum" not in url_props
), "Graph schema should not contain enum from AgentInputBlock placeholder_values"
def test_generate_schema_integration_dropdown_produces_enum():
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
— verifies enum IS produced for dropdown blocks using canonical field name."""
dropdown_input_default = {
"name": "color",
"value": None,
"options": ["Red", "Green", "Blue"],
}
result = BaseGraph._generate_schema(
(AgentDropdownInputBlock.Input, dropdown_input_default),
)
color_props = result["properties"]["color"]
assert color_props.get("enum") == [
"Red",
"Green",
"Blue",
], "Graph schema should contain enum from AgentDropdownInputBlock"
def test_generate_schema_integration_dropdown_legacy_placeholder_values():
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
using legacy 'placeholder_values' — verifies backward compat produces enum."""
legacy_dropdown_input_default = {
"name": "color",
"value": None,
"placeholder_values": ["Red", "Green", "Blue"],
}
result = BaseGraph._generate_schema(
(AgentDropdownInputBlock.Input, legacy_dropdown_input_default),
)
color_props = result["properties"]["color"]
assert color_props.get("enum") == [
"Red",
"Green",
"Blue",
], "Legacy placeholder_values should still produce enum via model_construct remap"
def test_dropdown_input_block_init_legacy_placeholder_values():
"""Verify backward compat: constructing AgentDropdownInputBlock.Input via
model_validate with legacy 'placeholder_values' correctly maps to 'options'."""
opts = ["Option A", "Option B"]
instance = AgentDropdownInputBlock.Input.model_validate(
{"name": "choice", "value": None, "placeholder_values": opts}
)
assert (
instance.options == opts
), "Legacy placeholder_values should be remapped to options via model_validate"
schema = instance.generate_schema()
assert schema.get("enum") == opts

View File

@@ -207,51 +207,6 @@ class TestXMLParserBlockSecurity:
pass
class TestXMLParserBlockSyntaxErrors:
"""XML syntax errors should raise ValueError (not SyntaxError).
This ensures the base Block.execute() wraps them as BlockExecutionError
(expected / user-caused) instead of BlockUnknownError (unexpected / alerts
Sentry).
"""
async def test_unclosed_tag_raises_value_error(self):
"""Unclosed tags should raise ValueError, not SyntaxError."""
block = XMLParserBlock()
bad_xml = "<root><unclosed>"
with pytest.raises(ValueError, match="Unclosed tag"):
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
pass
async def test_unexpected_closing_tag_raises_value_error(self):
"""Extra closing tags should raise ValueError, not SyntaxError."""
block = XMLParserBlock()
bad_xml = "</unexpected>"
with pytest.raises(ValueError):
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
pass
async def test_empty_xml_raises_value_error(self):
"""Empty XML input should raise ValueError."""
block = XMLParserBlock()
with pytest.raises(ValueError, match="XML input is empty"):
async for _ in block.run(XMLParserBlock.Input(input_xml="")):
pass
async def test_syntax_error_from_parser_becomes_value_error(self):
"""SyntaxErrors from gravitasml library become ValueError (BlockExecutionError)."""
block = XMLParserBlock()
# Malformed XML that might trigger a SyntaxError from the parser
bad_xml = "<root><child>no closing"
with pytest.raises(ValueError):
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
pass
class TestStoreMediaFileSecurity:
"""Test file storage security limits."""

View File

@@ -1,18 +1,9 @@
from typing import cast
from unittest.mock import AsyncMock, MagicMock, patch
import anthropic
import httpx
import openai
import pytest
import backend.blocks.llm as llm
from backend.data.model import NodeExecutionStats
# TEST_CREDENTIALS_INPUT is a plain dict that satisfies AICredentials at runtime
# but not at the type level. Cast once here to avoid per-test suppressors.
_TEST_AI_CREDENTIALS = cast(llm.AICredentials, llm.TEST_CREDENTIALS_INPUT)
class TestLLMStatsTracking:
"""Test that LLM blocks correctly track token usage statistics."""
@@ -199,66 +190,6 @@ class TestLLMStatsTracking:
assert block.execution_stats.llm_call_count == 2 # retry_count + 1 = 1 + 1 = 2
assert block.execution_stats.llm_retry_count == 1
@pytest.mark.asyncio
async def test_retry_cost_uses_last_attempt_only(self):
"""provider_cost is only merged from the final successful attempt.
Intermediate retry costs are intentionally dropped to avoid
double-counting: the cost of failed attempts is captured in
last_attempt_cost only when the loop eventually succeeds.
"""
import backend.blocks.llm as llm
block = llm.AIStructuredResponseGeneratorBlock()
call_count = 0
async def mock_llm_call(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
# First attempt: fails validation, returns cost $0.01
return llm.LLMResponse(
raw_response="",
prompt=[],
response='<json_output id="test123456">{"wrong": "key"}</json_output>',
tool_calls=None,
prompt_tokens=10,
completion_tokens=5,
reasoning=None,
provider_cost=0.01,
)
# Second attempt: succeeds, returns cost $0.02
return llm.LLMResponse(
raw_response="",
prompt=[],
response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
tool_calls=None,
prompt_tokens=20,
completion_tokens=10,
reasoning=None,
provider_cost=0.02,
)
block.llm_call = mock_llm_call # type: ignore
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test prompt",
expected_format={"key1": "desc1", "key2": "desc2"},
model=llm.DEFAULT_LLM_MODEL,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2,
)
with patch("secrets.token_hex", return_value="test123456"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
# Only the final successful attempt's cost is merged
assert block.execution_stats.provider_cost == pytest.approx(0.02)
# Tokens from both attempts accumulate
assert block.execution_stats.input_token_count == 30
assert block.execution_stats.output_token_count == 15
@pytest.mark.asyncio
async def test_ai_text_summarizer_multiple_chunks(self):
"""Test that AITextSummarizerBlock correctly accumulates stats across multiple chunks."""
@@ -548,154 +479,6 @@ class TestLLMStatsTracking:
assert outputs["response"] == {"result": "test"}
class TestAIConversationBlockValidation:
"""Test that AIConversationBlock validates inputs before calling the LLM."""
@pytest.mark.asyncio
async def test_empty_messages_and_empty_prompt_raises_error(self):
"""Empty messages with no prompt should raise ValueError, not a cryptic API error."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_empty_messages_with_prompt_succeeds(self):
"""Empty messages but a non-empty prompt should proceed without error."""
block = llm.AIConversationBlock()
async def mock_llm_call(input_data, credentials):
return {"response": "OK"}
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIConversationBlock.Input(
messages=[],
prompt="Hello, how are you?",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
outputs = {}
async for name, data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[name] = data
assert outputs["response"] == "OK"
@pytest.mark.asyncio
async def test_nonempty_messages_with_empty_prompt_succeeds(self):
"""Non-empty messages with no prompt should proceed without error."""
block = llm.AIConversationBlock()
async def mock_llm_call(input_data, credentials):
return {"response": "response from conversation"}
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": "Hello"}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
outputs = {}
async for name, data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[name] = data
assert outputs["response"] == "response from conversation"
@pytest.mark.asyncio
async def test_messages_with_empty_content_raises_error(self):
"""Messages with empty content strings should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": ""}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_whitespace_content_raises_error(self):
"""Messages with whitespace-only content should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": " "}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_none_entry_raises_error(self):
"""Messages list containing None should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[None],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_empty_dict_raises_error(self):
"""Messages list containing empty dict should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_none_content_raises_error(self):
"""Messages with content=None should not crash with AttributeError."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": None}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
class TestAITextSummarizerValidation:
"""Test that AITextSummarizerBlock validates LLM responses are strings."""
@@ -872,238 +655,3 @@ class TestAITextSummarizerValidation:
error_message = str(exc_info.value)
assert "Expected a string summary" in error_message
assert "received dict" in error_message
def _make_anthropic_status_error(status_code: int) -> anthropic.APIStatusError:
"""Create an anthropic.APIStatusError with the given status code."""
request = httpx.Request("POST", "https://api.anthropic.com/v1/messages")
response = httpx.Response(status_code, request=request)
return anthropic.APIStatusError(
f"Error code: {status_code}", response=response, body=None
)
def _make_openai_status_error(status_code: int) -> openai.APIStatusError:
"""Create an openai.APIStatusError with the given status code."""
response = httpx.Response(
status_code, request=httpx.Request("POST", "https://api.openai.com/v1/chat")
)
return openai.APIStatusError(
f"Error code: {status_code}", response=response, body=None
)
class TestUserErrorStatusCodeHandling:
"""Test that user-caused LLM API errors (401/403/429) break the retry loop
and are logged as warnings, while server errors (500) trigger retries."""
@pytest.mark.asyncio
@pytest.mark.parametrize("status_code", [401, 403, 429])
async def test_anthropic_user_error_breaks_retry_loop(self, status_code: int):
"""401/403/429 Anthropic errors should break immediately, not retry."""
import backend.blocks.llm as llm
block = llm.AIStructuredResponseGeneratorBlock()
call_count = 0
async def mock_llm_call(*args, **kwargs):
nonlocal call_count
call_count += 1
raise _make_anthropic_status_error(status_code)
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test",
expected_format={"key": "desc"},
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
retry=3,
)
with pytest.raises(RuntimeError):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
assert (
call_count == 1
), f"Expected exactly 1 call for status {status_code}, got {call_count}"
@pytest.mark.asyncio
@pytest.mark.parametrize("status_code", [401, 403, 429])
async def test_openai_user_error_breaks_retry_loop(self, status_code: int):
"""401/403/429 OpenAI errors should break immediately, not retry."""
import backend.blocks.llm as llm
block = llm.AIStructuredResponseGeneratorBlock()
call_count = 0
async def mock_llm_call(*args, **kwargs):
nonlocal call_count
call_count += 1
raise _make_openai_status_error(status_code)
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test",
expected_format={"key": "desc"},
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
retry=3,
)
with pytest.raises(RuntimeError):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
assert (
call_count == 1
), f"Expected exactly 1 call for status {status_code}, got {call_count}"
@pytest.mark.asyncio
async def test_server_error_retries(self):
"""500 errors should be retried (not break immediately)."""
import backend.blocks.llm as llm
block = llm.AIStructuredResponseGeneratorBlock()
call_count = 0
async def mock_llm_call(*args, **kwargs):
nonlocal call_count
call_count += 1
raise _make_anthropic_status_error(500)
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test",
expected_format={"key": "desc"},
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
retry=3,
)
with pytest.raises(RuntimeError):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
assert (
call_count > 1
), f"Expected multiple retry attempts for 500, got {call_count}"
@pytest.mark.asyncio
async def test_user_error_logs_warning_not_exception(self):
"""User-caused errors should log with logger.warning, not logger.exception."""
import backend.blocks.llm as llm
block = llm.AIStructuredResponseGeneratorBlock()
async def mock_llm_call(*args, **kwargs):
raise _make_anthropic_status_error(401)
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test",
expected_format={"key": "desc"},
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with (
patch.object(llm.logger, "warning") as mock_warning,
patch.object(llm.logger, "exception") as mock_exception,
pytest.raises(RuntimeError),
):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
mock_warning.assert_called_once()
mock_exception.assert_not_called()
class TestLlmModelMissing:
"""Test that LlmModel handles provider-prefixed model names."""
def test_provider_prefixed_model_resolves(self):
"""Provider-prefixed model string should resolve to the correct enum member."""
assert (
llm.LlmModel("anthropic/claude-sonnet-4-6")
== llm.LlmModel.CLAUDE_4_6_SONNET
)
def test_bare_model_still_works(self):
"""Bare (non-prefixed) model string should still resolve correctly."""
assert llm.LlmModel("claude-sonnet-4-6") == llm.LlmModel.CLAUDE_4_6_SONNET
def test_invalid_prefixed_model_raises(self):
"""Unknown provider-prefixed model string should raise ValueError."""
with pytest.raises(ValueError):
llm.LlmModel("invalid/nonexistent-model")
def test_slash_containing_value_direct_lookup(self):
"""Enum values with '/' (e.g., OpenRouter models) should resolve via direct lookup, not _missing_."""
assert llm.LlmModel("google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
def test_double_prefixed_slash_model(self):
"""Double-prefixed value should still resolve by stripping first prefix."""
assert (
llm.LlmModel("extra/google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
)
class TestExtractOpenRouterCost:
"""Tests for extract_openrouter_cost — the x-total-cost header parser."""
def _mk_response(self, headers: dict | None):
response = MagicMock()
if headers is None:
response._response = None
else:
raw = MagicMock()
raw.headers = headers
response._response = raw
return response
def test_extracts_numeric_cost(self):
response = self._mk_response({"x-total-cost": "0.0042"})
assert llm.extract_openrouter_cost(response) == 0.0042
def test_returns_none_when_header_missing(self):
response = self._mk_response({})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_header_empty_string(self):
response = self._mk_response({"x-total-cost": ""})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_header_non_numeric(self):
response = self._mk_response({"x-total-cost": "not-a-number"})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_no_response_attr(self):
response = MagicMock(spec=[]) # no _response attr
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_raw_is_none(self):
response = self._mk_response(None)
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_raw_has_no_headers(self):
response = MagicMock()
response._response = MagicMock(spec=[]) # no headers attr
assert llm.extract_openrouter_cost(response) is None
def test_returns_zero_for_zero_cost(self):
"""Zero-cost is a valid value (free tier) and must not become None."""
response = self._mk_response({"x-total-cost": "0"})
assert llm.extract_openrouter_cost(response) == 0.0
def test_returns_none_for_inf(self):
response = self._mk_response({"x-total-cost": "inf"})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_for_negative_inf(self):
response = self._mk_response({"x-total-cost": "-inf"})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_for_nan(self):
response = self._mk_response({"x-total-cost": "nan"})
assert llm.extract_openrouter_cost(response) is None

View File

@@ -1,87 +0,0 @@
"""Tests for empty-choices guard in extract_openai_tool_calls() and extract_openai_reasoning()."""
from unittest.mock import MagicMock
from backend.blocks.llm import extract_openai_reasoning, extract_openai_tool_calls
class TestExtractOpenaiToolCallsEmptyChoices:
"""extract_openai_tool_calls() must return None when choices is empty."""
def test_returns_none_for_empty_choices(self):
response = MagicMock()
response.choices = []
assert extract_openai_tool_calls(response) is None
def test_returns_none_for_none_choices(self):
response = MagicMock()
response.choices = None
assert extract_openai_tool_calls(response) is None
def test_returns_tool_calls_when_choices_present(self):
tool = MagicMock()
tool.id = "call_1"
tool.type = "function"
tool.function.name = "my_func"
tool.function.arguments = '{"a": 1}'
message = MagicMock()
message.tool_calls = [tool]
choice = MagicMock()
choice.message = message
response = MagicMock()
response.choices = [choice]
result = extract_openai_tool_calls(response)
assert result is not None
assert len(result) == 1
assert result[0].function.name == "my_func"
def test_returns_none_when_no_tool_calls(self):
message = MagicMock()
message.tool_calls = None
choice = MagicMock()
choice.message = message
response = MagicMock()
response.choices = [choice]
assert extract_openai_tool_calls(response) is None
class TestExtractOpenaiReasoningEmptyChoices:
"""extract_openai_reasoning() must return None when choices is empty."""
def test_returns_none_for_empty_choices(self):
response = MagicMock()
response.choices = []
assert extract_openai_reasoning(response) is None
def test_returns_none_for_none_choices(self):
response = MagicMock()
response.choices = None
assert extract_openai_reasoning(response) is None
def test_returns_reasoning_from_choice(self):
choice = MagicMock()
choice.reasoning = "Step-by-step reasoning"
choice.message = MagicMock(spec=[]) # no 'reasoning' attr on message
response = MagicMock(spec=[]) # no 'reasoning' attr on response
response.choices = [choice]
result = extract_openai_reasoning(response)
assert result == "Step-by-step reasoning"
def test_returns_none_when_no_reasoning(self):
choice = MagicMock(spec=[]) # no 'reasoning' attr
choice.message = MagicMock(spec=[]) # no 'reasoning' attr
response = MagicMock(spec=[]) # no 'reasoning' attr
response.choices = [choice]
result = extract_openai_reasoning(response)
assert result is None

View File

@@ -1,202 +0,0 @@
"""Tests for ExecutionMode enum and provider validation in the orchestrator.
Covers:
- ExecutionMode enum members exist and have stable values
- EXTENDED_THINKING provider validation (anthropic/open_router allowed, others rejected)
- EXTENDED_THINKING model-name validation (must start with "claude")
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks.llm import LlmModel
from backend.blocks.orchestrator import ExecutionMode, OrchestratorBlock
# ---------------------------------------------------------------------------
# ExecutionMode enum integrity
# ---------------------------------------------------------------------------
class TestExecutionModeEnum:
"""Guard against accidental renames or removals of enum members."""
def test_built_in_exists(self):
assert hasattr(ExecutionMode, "BUILT_IN")
assert ExecutionMode.BUILT_IN.value == "built_in"
def test_extended_thinking_exists(self):
assert hasattr(ExecutionMode, "EXTENDED_THINKING")
assert ExecutionMode.EXTENDED_THINKING.value == "extended_thinking"
def test_exactly_two_members(self):
"""If a new mode is added, this test should be updated intentionally."""
assert set(ExecutionMode.__members__.keys()) == {
"BUILT_IN",
"EXTENDED_THINKING",
}
def test_string_enum(self):
"""ExecutionMode is a str enum so it serialises cleanly to JSON."""
assert isinstance(ExecutionMode.BUILT_IN, str)
assert isinstance(ExecutionMode.EXTENDED_THINKING, str)
def test_round_trip_from_value(self):
"""Constructing from the string value should return the same member."""
assert ExecutionMode("built_in") is ExecutionMode.BUILT_IN
assert ExecutionMode("extended_thinking") is ExecutionMode.EXTENDED_THINKING
# ---------------------------------------------------------------------------
# Provider validation (inline in OrchestratorBlock.run)
# ---------------------------------------------------------------------------
def _make_model_stub(provider: str, value: str):
"""Create a lightweight stub that behaves like LlmModel for validation."""
metadata = MagicMock()
metadata.provider = provider
stub = MagicMock()
stub.metadata = metadata
stub.value = value
return stub
class TestExtendedThinkingProviderValidation:
"""The orchestrator rejects EXTENDED_THINKING for non-Anthropic providers."""
def test_anthropic_provider_accepted(self):
"""provider='anthropic' + claude model should not raise."""
model = _make_model_stub("anthropic", "claude-opus-4-6")
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
def test_open_router_provider_accepted(self):
"""provider='open_router' + claude model should not raise."""
model = _make_model_stub("open_router", "claude-sonnet-4-6")
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
def test_openai_provider_rejected(self):
"""provider='openai' should be rejected for EXTENDED_THINKING."""
model = _make_model_stub("openai", "gpt-4o")
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_groq_provider_rejected(self):
model = _make_model_stub("groq", "llama-3.3-70b-versatile")
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_non_claude_model_rejected_even_if_anthropic_provider(self):
"""A hypothetical non-Claude model with provider='anthropic' is rejected."""
model = _make_model_stub("anthropic", "not-a-claude-model")
model_name = model.value
assert not model_name.startswith("claude")
def test_real_gpt4o_model_rejected(self):
"""Verify a real LlmModel enum member (GPT4O) fails the provider check."""
model = LlmModel.GPT4O
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_real_claude_model_passes(self):
"""Verify a real LlmModel enum member (CLAUDE_4_6_SONNET) passes."""
model = LlmModel.CLAUDE_4_6_SONNET
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
# ---------------------------------------------------------------------------
# Integration-style: exercise the validation branch via OrchestratorBlock.run
# ---------------------------------------------------------------------------
def _make_input_data(model, execution_mode=ExecutionMode.EXTENDED_THINKING):
"""Build a minimal MagicMock that satisfies OrchestratorBlock.run's early path."""
inp = MagicMock()
inp.execution_mode = execution_mode
inp.model = model
inp.prompt = "test"
inp.sys_prompt = ""
inp.conversation_history = []
inp.last_tool_output = None
inp.prompt_values = {}
return inp
async def _collect_run_outputs(block, input_data, **kwargs):
"""Exhaust the OrchestratorBlock.run async generator, collecting outputs."""
outputs = []
async for item in block.run(input_data, **kwargs):
outputs.append(item)
return outputs
class TestExtendedThinkingValidationRaisesInBlock:
"""Call OrchestratorBlock.run far enough to trigger the ValueError."""
@pytest.mark.asyncio
async def test_non_anthropic_provider_raises_valueerror(self):
"""EXTENDED_THINKING + openai provider raises ValueError."""
block = OrchestratorBlock()
input_data = _make_input_data(model=LlmModel.GPT4O)
with (
patch.object(
block,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
),
pytest.raises(ValueError, match="Anthropic-compatible"),
):
await _collect_run_outputs(
block,
input_data,
credentials=MagicMock(),
graph_id="g",
node_id="n",
graph_exec_id="ge",
node_exec_id="ne",
user_id="u",
graph_version=1,
execution_context=MagicMock(),
execution_processor=MagicMock(),
)
@pytest.mark.asyncio
async def test_non_claude_model_with_anthropic_provider_raises(self):
"""A model with anthropic provider but non-claude name raises ValueError."""
block = OrchestratorBlock()
fake_model = _make_model_stub("anthropic", "not-a-claude-model")
input_data = _make_input_data(model=fake_model)
with (
patch.object(
block,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
),
pytest.raises(ValueError, match="only supports Claude models"),
):
await _collect_run_outputs(
block,
input_data,
credentials=MagicMock(),
graph_id="g",
node_id="n",
graph_exec_id="ge",
node_exec_id="ne",
user_id="u",
graph_version=1,
execution_context=MagicMock(),
execution_processor=MagicMock(),
)

View File

@@ -57,7 +57,7 @@ async def execute_graph(
@pytest.mark.asyncio(loop_scope="session")
async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.orchestrator import OrchestratorBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data import graph
test_user = await create_test_user()
@@ -66,7 +66,7 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
nodes = [
graph.Node(
block_id=OrchestratorBlock().id,
block_id=SmartDecisionMakerBlock().id,
input_default={
"prompt": "Hello, World!",
"credentials": creds,
@@ -108,10 +108,10 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
@pytest.mark.asyncio(loop_scope="session")
async def test_orchestrator_function_signature(server: SpinTestServer):
async def test_smart_decision_maker_function_signature(server: SpinTestServer):
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.basic import StoreValueBlock
from backend.blocks.orchestrator import OrchestratorBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data import graph
test_user = await create_test_user()
@@ -120,7 +120,7 @@ async def test_orchestrator_function_signature(server: SpinTestServer):
nodes = [
graph.Node(
block_id=OrchestratorBlock().id,
block_id=SmartDecisionMakerBlock().id,
input_default={
"prompt": "Hello, World!",
"credentials": creds,
@@ -169,7 +169,7 @@ async def test_orchestrator_function_signature(server: SpinTestServer):
)
test_graph = await create_graph(server, test_graph, test_user)
tool_functions = await OrchestratorBlock._create_tool_node_signatures(
tool_functions = await SmartDecisionMakerBlock._create_tool_node_signatures(
test_graph.nodes[0].id
)
assert tool_functions is not None, "Tool functions should not be None"
@@ -198,12 +198,12 @@ async def test_orchestrator_function_signature(server: SpinTestServer):
@pytest.mark.asyncio
async def test_orchestrator_tracks_llm_stats():
"""Test that OrchestratorBlock correctly tracks LLM usage stats."""
async def test_smart_decision_maker_tracks_llm_stats():
"""Test that SmartDecisionMakerBlock correctly tracks LLM usage stats."""
import backend.blocks.llm as llm_module
from backend.blocks.orchestrator import OrchestratorBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
# Mock the llm.llm_call function to return controlled data
mock_response = MagicMock()
@@ -224,14 +224,14 @@ async def test_orchestrator_tracks_llm_stats():
new_callable=AsyncMock,
return_value=mock_response,
), patch.object(
OrchestratorBlock,
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
):
# Create test input
input_data = OrchestratorBlock.Input(
input_data = SmartDecisionMakerBlock.Input(
prompt="Should I continue with this task?",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
@@ -274,12 +274,12 @@ async def test_orchestrator_tracks_llm_stats():
@pytest.mark.asyncio
async def test_orchestrator_parameter_validation():
"""Test that OrchestratorBlock correctly validates tool call parameters."""
async def test_smart_decision_maker_parameter_validation():
"""Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
import backend.blocks.llm as llm_module
from backend.blocks.orchestrator import OrchestratorBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
# Mock tool functions with specific parameter schema
mock_tool_functions = [
@@ -327,13 +327,13 @@ async def test_orchestrator_parameter_validation():
new_callable=AsyncMock,
return_value=mock_response_with_typo,
) as mock_llm_call, patch.object(
OrchestratorBlock,
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
):
input_data = OrchestratorBlock.Input(
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
@@ -394,13 +394,13 @@ async def test_orchestrator_parameter_validation():
new_callable=AsyncMock,
return_value=mock_response_missing_required,
), patch.object(
OrchestratorBlock,
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
):
input_data = OrchestratorBlock.Input(
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
@@ -454,13 +454,13 @@ async def test_orchestrator_parameter_validation():
new_callable=AsyncMock,
return_value=mock_response_valid,
), patch.object(
OrchestratorBlock,
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
):
input_data = OrchestratorBlock.Input(
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
@@ -518,13 +518,13 @@ async def test_orchestrator_parameter_validation():
new_callable=AsyncMock,
return_value=mock_response_all_params,
), patch.object(
OrchestratorBlock,
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
):
input_data = OrchestratorBlock.Input(
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
@@ -562,12 +562,12 @@ async def test_orchestrator_parameter_validation():
@pytest.mark.asyncio
async def test_orchestrator_raw_response_conversion():
"""Test that Orchestrator correctly handles different raw_response types with retry mechanism."""
async def test_smart_decision_maker_raw_response_conversion():
"""Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism."""
import backend.blocks.llm as llm_module
from backend.blocks.orchestrator import OrchestratorBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
# Mock tool functions
mock_tool_functions = [
@@ -637,7 +637,7 @@ async def test_orchestrator_raw_response_conversion():
with patch(
"backend.blocks.llm.llm_call", new_callable=AsyncMock
) as mock_llm_call, patch.object(
OrchestratorBlock,
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
@@ -646,7 +646,7 @@ async def test_orchestrator_raw_response_conversion():
# Second call returns successful response
mock_llm_call.side_effect = [mock_response_retry, mock_response_success]
input_data = OrchestratorBlock.Input(
input_data = SmartDecisionMakerBlock.Input(
prompt="Test prompt",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
@@ -715,12 +715,12 @@ async def test_orchestrator_raw_response_conversion():
new_callable=AsyncMock,
return_value=mock_response_ollama,
), patch.object(
OrchestratorBlock,
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[], # No tools for this test
):
input_data = OrchestratorBlock.Input(
input_data = SmartDecisionMakerBlock.Input(
prompt="Simple prompt",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
@@ -771,12 +771,12 @@ async def test_orchestrator_raw_response_conversion():
new_callable=AsyncMock,
return_value=mock_response_dict,
), patch.object(
OrchestratorBlock,
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
):
input_data = OrchestratorBlock.Input(
input_data = SmartDecisionMakerBlock.Input(
prompt="Another test",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
@@ -811,12 +811,12 @@ async def test_orchestrator_raw_response_conversion():
@pytest.mark.asyncio
async def test_orchestrator_agent_mode():
async def test_smart_decision_maker_agent_mode():
"""Test that agent mode executes tools directly and loops until finished."""
import backend.blocks.llm as llm_module
from backend.blocks.orchestrator import OrchestratorBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
# Mock tool call that requires multiple iterations
mock_tool_call_1 = MagicMock()
@@ -893,7 +893,7 @@ async def test_orchestrator_agent_mode():
with patch("backend.blocks.llm.llm_call", llm_call_mock), patch.object(
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
), patch(
"backend.blocks.orchestrator.get_database_manager_async_client",
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
), patch(
"backend.executor.manager.async_update_node_execution_status",
@@ -929,7 +929,7 @@ async def test_orchestrator_agent_mode():
}
# Test agent mode with max_iterations = 3
input_data = OrchestratorBlock.Input(
input_data = SmartDecisionMakerBlock.Input(
prompt="Complete this task using tools",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
@@ -969,12 +969,12 @@ async def test_orchestrator_agent_mode():
@pytest.mark.asyncio
async def test_orchestrator_traditional_mode_default():
async def test_smart_decision_maker_traditional_mode_default():
"""Test that default behavior (agent_mode_max_iterations=0) works as traditional mode."""
import backend.blocks.llm as llm_module
from backend.blocks.orchestrator import OrchestratorBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
# Mock tool call
mock_tool_call = MagicMock()
@@ -1018,7 +1018,7 @@ async def test_orchestrator_traditional_mode_default():
):
# Test default behavior (traditional mode)
input_data = OrchestratorBlock.Input(
input_data = SmartDecisionMakerBlock.Input(
prompt="Test prompt",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
@@ -1060,12 +1060,12 @@ async def test_orchestrator_traditional_mode_default():
@pytest.mark.asyncio
async def test_orchestrator_uses_customized_name_for_blocks():
"""Test that OrchestratorBlock uses customized_name from node metadata for tool names."""
async def test_smart_decision_maker_uses_customized_name_for_blocks():
"""Test that SmartDecisionMakerBlock uses customized_name from node metadata for tool names."""
from unittest.mock import MagicMock
from backend.blocks.basic import StoreValueBlock
from backend.blocks.orchestrator import OrchestratorBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node with customized_name in metadata
@@ -1074,14 +1074,13 @@ async def test_orchestrator_uses_customized_name_for_blocks():
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
mock_node.block = StoreValueBlock()
mock_node.input_default = {}
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "input"
# Call the function directly
result = await OrchestratorBlock._create_block_function_signature(
result = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, [mock_link]
)
@@ -1092,12 +1091,12 @@ async def test_orchestrator_uses_customized_name_for_blocks():
@pytest.mark.asyncio
async def test_orchestrator_falls_back_to_block_name():
"""Test that OrchestratorBlock falls back to block.name when no customized_name."""
async def test_smart_decision_maker_falls_back_to_block_name():
"""Test that SmartDecisionMakerBlock falls back to block.name when no customized_name."""
from unittest.mock import MagicMock
from backend.blocks.basic import StoreValueBlock
from backend.blocks.orchestrator import OrchestratorBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node without customized_name
@@ -1106,14 +1105,13 @@ async def test_orchestrator_falls_back_to_block_name():
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {} # No customized_name
mock_node.block = StoreValueBlock()
mock_node.input_default = {}
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "input"
# Call the function directly
result = await OrchestratorBlock._create_block_function_signature(
result = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, [mock_link]
)
@@ -1124,11 +1122,11 @@ async def test_orchestrator_falls_back_to_block_name():
@pytest.mark.asyncio
async def test_orchestrator_uses_customized_name_for_agents():
"""Test that OrchestratorBlock uses customized_name from metadata for agent nodes."""
async def test_smart_decision_maker_uses_customized_name_for_agents():
"""Test that SmartDecisionMakerBlock uses customized_name from metadata for agent nodes."""
from unittest.mock import AsyncMock, MagicMock, patch
from backend.blocks.orchestrator import OrchestratorBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node with customized_name in metadata
@@ -1154,10 +1152,10 @@ async def test_orchestrator_uses_customized_name_for_agents():
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
with patch(
"backend.blocks.orchestrator.get_database_manager_async_client",
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
):
result = await OrchestratorBlock._create_agent_function_signature(
result = await SmartDecisionMakerBlock._create_agent_function_signature(
mock_node, [mock_link]
)
@@ -1168,11 +1166,11 @@ async def test_orchestrator_uses_customized_name_for_agents():
@pytest.mark.asyncio
async def test_orchestrator_agent_falls_back_to_graph_name():
async def test_smart_decision_maker_agent_falls_back_to_graph_name():
"""Test that agent node falls back to graph name when no customized_name."""
from unittest.mock import AsyncMock, MagicMock, patch
from backend.blocks.orchestrator import OrchestratorBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node without customized_name
@@ -1198,10 +1196,10 @@ async def test_orchestrator_agent_falls_back_to_graph_name():
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
with patch(
"backend.blocks.orchestrator.get_database_manager_async_client",
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
):
result = await OrchestratorBlock._create_agent_function_signature(
result = await SmartDecisionMakerBlock._create_agent_function_signature(
mock_node, [mock_link]
)

View File

@@ -3,12 +3,12 @@ from unittest.mock import Mock
import pytest
from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
from backend.blocks.orchestrator import OrchestratorBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
@pytest.mark.asyncio
async def test_orchestrator_handles_dynamic_dict_fields():
"""Test Orchestrator can handle dynamic dictionary fields (_#_) for any block"""
async def test_smart_decision_maker_handles_dynamic_dict_fields():
"""Test Smart Decision Maker can handle dynamic dictionary fields (_#_) for any block"""
# Create a mock node for CreateDictionaryBlock
mock_node = Mock()
@@ -23,24 +23,24 @@ async def test_orchestrator_handles_dynamic_dict_fields():
source_name="tools_^_create_dict_~_name",
sink_name="values_#_name", # Dynamic dict field
sink_id="dict_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_create_dict_~_age",
sink_name="values_#_age", # Dynamic dict field
sink_id="dict_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_create_dict_~_city",
sink_name="values_#_city", # Dynamic dict field
sink_id="dict_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
]
# Generate function signature
signature = await OrchestratorBlock._create_block_function_signature(
signature = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, mock_links # type: ignore
)
@@ -70,8 +70,8 @@ async def test_orchestrator_handles_dynamic_dict_fields():
@pytest.mark.asyncio
async def test_orchestrator_handles_dynamic_list_fields():
"""Test Orchestrator can handle dynamic list fields (_$_) for any block"""
async def test_smart_decision_maker_handles_dynamic_list_fields():
"""Test Smart Decision Maker can handle dynamic list fields (_$_) for any block"""
# Create a mock node for AddToListBlock
mock_node = Mock()
@@ -86,18 +86,18 @@ async def test_orchestrator_handles_dynamic_list_fields():
source_name="tools_^_add_to_list_~_0",
sink_name="entries_$_0", # Dynamic list field
sink_id="list_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_add_to_list_~_1",
sink_name="entries_$_1", # Dynamic list field
sink_id="list_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
]
# Generate function signature
signature = await OrchestratorBlock._create_block_function_signature(
signature = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, mock_links # type: ignore
)

View File

@@ -1,4 +1,4 @@
"""Comprehensive tests for OrchestratorBlock dynamic field handling."""
"""Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""
import json
from unittest.mock import AsyncMock, MagicMock, Mock, patch
@@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
from backend.blocks.orchestrator import OrchestratorBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.text import MatchTextPatternBlock
from backend.data.dynamic_fields import get_dynamic_field_description
@@ -37,7 +37,7 @@ async def test_dynamic_field_description_generation():
@pytest.mark.asyncio
async def test_create_block_function_signature_with_dict_fields():
"""Test that function signatures are created correctly for dictionary dynamic fields."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
# Create a mock node for CreateDictionaryBlock
mock_node = Mock()
@@ -52,19 +52,19 @@ async def test_create_block_function_signature_with_dict_fields():
source_name="tools_^_create_dict_~_values___name", # Sanitized source
sink_name="values_#_name", # Original sink
sink_id="dict_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_create_dict_~_values___age", # Sanitized source
sink_name="values_#_age", # Original sink
sink_id="dict_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_create_dict_~_values___email", # Sanitized source
sink_name="values_#_email", # Original sink
sink_id="dict_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
]
@@ -100,7 +100,7 @@ async def test_create_block_function_signature_with_dict_fields():
@pytest.mark.asyncio
async def test_create_block_function_signature_with_list_fields():
"""Test that function signatures are created correctly for list dynamic fields."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
# Create a mock node for AddToListBlock
mock_node = Mock()
@@ -115,19 +115,19 @@ async def test_create_block_function_signature_with_list_fields():
source_name="tools_^_add_list_~_0",
sink_name="entries_$_0", # Dynamic list field
sink_id="list_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_add_list_~_1",
sink_name="entries_$_1", # Dynamic list field
sink_id="list_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_add_list_~_2",
sink_name="entries_$_2", # Dynamic list field
sink_id="list_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
]
@@ -154,7 +154,7 @@ async def test_create_block_function_signature_with_list_fields():
@pytest.mark.asyncio
async def test_create_block_function_signature_with_object_fields():
"""Test that function signatures are created correctly for object dynamic fields."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
# Create a mock node for MatchTextPatternBlock (simulating object fields)
mock_node = Mock()
@@ -169,13 +169,13 @@ async def test_create_block_function_signature_with_object_fields():
source_name="tools_^_extract_~_user_name",
sink_name="data_@_user_name", # Dynamic object field
sink_id="extract_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_extract_~_user_email",
sink_name="data_@_user_email", # Dynamic object field
sink_id="extract_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
]
@@ -197,11 +197,11 @@ async def test_create_block_function_signature_with_object_fields():
@pytest.mark.asyncio
async def test_create_tool_node_signatures():
"""Test that the mapping between sanitized and original field names is built correctly."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
# Mock the database client and connected nodes
with patch(
"backend.blocks.orchestrator.get_database_manager_async_client"
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
) as mock_db:
mock_client = AsyncMock()
mock_db.return_value = mock_client
@@ -281,7 +281,7 @@ async def test_create_tool_node_signatures():
@pytest.mark.asyncio
async def test_output_yielding_with_dynamic_fields():
"""Test that outputs are yielded correctly with dynamic field names mapped back."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
# No more sanitized mapping needed since we removed sanitization
@@ -309,13 +309,13 @@ async def test_output_yielding_with_dynamic_fields():
# Mock the LLM call
with patch(
"backend.blocks.orchestrator.llm.llm_call", new_callable=AsyncMock
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
) as mock_llm:
mock_llm.return_value = mock_response
# Mock the database manager to avoid HTTP calls during tool execution
with patch(
"backend.blocks.orchestrator.get_database_manager_async_client"
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
) as mock_db_manager, patch.object(
block, "_create_tool_node_signatures", new_callable=AsyncMock
) as mock_sig:
@@ -420,7 +420,7 @@ async def test_output_yielding_with_dynamic_fields():
@pytest.mark.asyncio
async def test_mixed_regular_and_dynamic_fields():
"""Test handling of blocks with both regular and dynamic fields."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
# Create a mock node
mock_node = Mock()
@@ -450,19 +450,19 @@ async def test_mixed_regular_and_dynamic_fields():
source_name="tools_^_test_~_regular",
sink_name="regular_field", # Regular field
sink_id="test_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_test_~_dict_key",
sink_name="values_#_key1", # Dynamic dict field
sink_id="test_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_test_~_dict_key2",
sink_name="values_#_key2", # Dynamic dict field
sink_id="test_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
]
@@ -488,7 +488,7 @@ async def test_mixed_regular_and_dynamic_fields():
@pytest.mark.asyncio
async def test_validation_errors_dont_pollute_conversation():
"""Test that validation errors are only used during retries and don't pollute the conversation."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
# Track conversation history changes
conversation_snapshots = []
@@ -535,7 +535,7 @@ async def test_validation_errors_dont_pollute_conversation():
# Mock the LLM call
with patch(
"backend.blocks.orchestrator.llm.llm_call", new_callable=AsyncMock
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
) as mock_llm:
mock_llm.side_effect = mock_llm_call
@@ -565,7 +565,7 @@ async def test_validation_errors_dont_pollute_conversation():
# Mock the database manager to avoid HTTP calls during tool execution
with patch(
"backend.blocks.orchestrator.get_database_manager_async_client"
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
) as mock_db_manager:
# Set up the mock database manager for agent mode
mock_db_client = AsyncMock()

View File

@@ -1,6 +1,6 @@
"""Tests for OrchestratorBlock compatibility with the OpenAI Responses API.
"""Tests for SmartDecisionMakerBlock compatibility with the OpenAI Responses API.
The OrchestratorBlock manages conversation history in the Chat Completions
The SmartDecisionMakerBlock manages conversation history in the Chat Completions
format, but OpenAI models now use the Responses API which has a fundamentally
different conversation structure. These tests document:
@@ -27,8 +27,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks.orchestrator import (
OrchestratorBlock,
from backend.blocks.smart_decision_maker import (
SmartDecisionMakerBlock,
_combine_tool_responses,
_convert_raw_response_to_dict,
_create_tool_response,
@@ -733,7 +733,7 @@ class TestUpdateConversation:
def test_dict_raw_response_no_reasoning_no_tools(self):
"""Dict raw_response, no reasoning → appends assistant dict."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
prompt: list[dict] = []
resp = self._make_response({"role": "assistant", "content": "hi"})
block._update_conversation(prompt, resp)
@@ -741,7 +741,7 @@ class TestUpdateConversation:
def test_dict_raw_response_with_reasoning_no_tool_calls(self):
"""Reasoning present, no tool calls → reasoning prepended."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
prompt: list[dict] = []
resp = self._make_response(
{"role": "assistant", "content": "answer"},
@@ -757,7 +757,7 @@ class TestUpdateConversation:
def test_dict_raw_response_with_reasoning_and_anthropic_tool_calls(self):
"""Reasoning + Anthropic tool_use in content → reasoning skipped."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
prompt: list[dict] = []
raw = {
"role": "assistant",
@@ -772,7 +772,7 @@ class TestUpdateConversation:
def test_with_tool_outputs(self):
"""Tool outputs → extended onto prompt."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
prompt: list[dict] = []
resp = self._make_response({"role": "assistant", "content": None})
outputs = [{"role": "tool", "tool_call_id": "call_1", "content": "r"}]
@@ -782,7 +782,7 @@ class TestUpdateConversation:
def test_without_tool_outputs(self):
"""No tool outputs → only assistant message appended."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
prompt: list[dict] = []
resp = self._make_response({"role": "assistant", "content": "done"})
block._update_conversation(prompt, resp, None)
@@ -790,7 +790,7 @@ class TestUpdateConversation:
def test_string_raw_response(self):
"""Ollama string → wrapped as assistant dict."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
prompt: list[dict] = []
resp = self._make_response("hello from ollama")
block._update_conversation(prompt, resp)
@@ -800,7 +800,7 @@ class TestUpdateConversation:
def test_responses_api_text_response_produces_valid_items(self):
"""Responses API text response → conversation items must have valid role."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
prompt: list[dict] = [
{"role": "system", "content": "sys"},
{"role": "user", "content": "user"},
@@ -820,7 +820,7 @@ class TestUpdateConversation:
def test_responses_api_function_call_produces_valid_items(self):
"""Responses API function_call → conversation items must have valid type."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
prompt: list[dict] = []
resp = self._make_response(
_MockResponse(output=[_MockFunctionCall("tool", "{}", call_id="call_1")])
@@ -856,7 +856,7 @@ async def test_agent_mode_conversation_valid_for_responses_api():
"""
import backend.blocks.llm as llm_module
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
# First response: tool call
mock_tc = MagicMock()
@@ -936,7 +936,7 @@ async def test_agent_mode_conversation_valid_for_responses_api():
with patch("backend.blocks.llm.llm_call", llm_mock), patch.object(
block, "_create_tool_node_signatures", return_value=tool_sigs
), patch(
"backend.blocks.orchestrator.get_database_manager_async_client",
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db,
), patch(
"backend.executor.manager.async_update_node_execution_status",
@@ -945,7 +945,7 @@ async def test_agent_mode_conversation_valid_for_responses_api():
"backend.integrations.creds_manager.IntegrationCredentialsManager"
):
inp = OrchestratorBlock.Input(
inp = SmartDecisionMakerBlock.Input(
prompt="Improve this",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
@@ -992,7 +992,7 @@ async def test_traditional_mode_conversation_valid_for_responses_api():
"""Traditional mode: the yielded conversation must contain only valid items."""
import backend.blocks.llm as llm_module
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
mock_tc = MagicMock()
mock_tc.function.name = "my_tool"
@@ -1028,7 +1028,7 @@ async def test_traditional_mode_conversation_valid_for_responses_api():
"backend.blocks.llm.llm_call", new_callable=AsyncMock, return_value=resp
), patch.object(block, "_create_tool_node_signatures", return_value=tool_sigs):
inp = OrchestratorBlock.Input(
inp = SmartDecisionMakerBlock.Input(
prompt="Do it",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore

View File

@@ -13,7 +13,6 @@ from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -105,10 +104,4 @@ class UnrealTextToSpeechBlock(Block):
input_data.text,
input_data.voice_id,
)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(input_data.text)),
provider_cost_type="characters",
)
)
yield "mp3_url", api_response["OutputUri"]

View File

@@ -44,7 +44,7 @@ class XMLParserBlock(Block):
elif token.type == "TAG_CLOSE":
depth -= 1
if depth < 0:
raise ValueError("Unexpected closing tag in XML input.")
raise SyntaxError("Unexpected closing tag in XML input.")
elif token.type in {"TEXT", "ESCAPE"}:
if depth == 0 and token.value:
raise ValueError(
@@ -53,7 +53,7 @@ class XMLParserBlock(Block):
)
if depth != 0:
raise ValueError("Unclosed tag detected in XML input.")
raise SyntaxError("Unclosed tag detected in XML input.")
if not root_seen:
raise ValueError("XML must include a root element.")
@@ -76,7 +76,4 @@ class XMLParserBlock(Block):
except ValueError as val_e:
raise ValueError(f"Validation error for dict:{val_e}") from val_e
except SyntaxError as syn_e:
# Raise as ValueError so the base Block.execute() wraps it as
# BlockExecutionError (expected user-caused failure) instead of
# BlockUnknownError (unexpected platform error that alerts Sentry).
raise ValueError(f"Error in input xml syntax: {syn_e}") from syn_e
raise SyntaxError(f"Error in input xml syntax: {syn_e}") from syn_e

Some files were not shown because too many files have changed in this diff Show More