mirror of https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00

Compare commits: master...remove-cla (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | f20693d02b |  |
|  | a4188c5657 |  |
@@ -17,14 +17,6 @@ gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoG
gh pr view {N}
```

## Read the PR description

Understand the **Why / What / How** before addressing comments — you need context to make good fixes:

```bash
gh pr view {N} --json body --jq '.body'
```

## Fetch comments (all sources)

### 1. Inline review threads — GraphQL (primary source of actionable items)
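The hunk cuts off before the query itself. As a hedged sketch (not necessarily the query this file used), unresolved inline threads can be pulled with `gh api graphql`; the field names below come from GitHub's public GraphQL schema, and the PR number `123` is illustrative:

```bash
gh api graphql -F owner='Significant-Gravitas' -F repo='AutoGPT' -F pr=123 -f query='
query($owner: String!, $repo: String!, $pr: Int!) {
  repository(owner: $owner, name: $repo) {
    pullRequest(number: $pr) {
      reviewThreads(first: 100) {
        nodes {
          isResolved
          path
          comments(first: 50) { nodes { author { login } body } }
        }
      }
    }
  }
}' --jq '.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved | not)'
```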
@@ -113,9 +105,7 @@ kill $REST_PID 2>/dev/null; trap - EXIT
```
Never manually edit files in `src/app/api/__generated__/`.

Then commit and **push immediately** — never batch commits without pushing. Each fix should be visible on GitHub right away so CI can start and reviewers can see progress.

**Never push empty commits** (`git commit --allow-empty`) to re-trigger CI or bot checks. When a check fails, investigate the root cause (unchecked PR checklist, unaddressed review comments, code issues) and fix those directly. Empty commits add noise to git history.
Then commit and **push immediately** — never batch commits without pushing.

For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
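A minimal sketch of that per-fix loop, assuming `$WORKTREE_PATH` points at the PR's worktree (the commit message is illustrative):

```bash
# After each fix: stage, commit (pre-commit hooks run via poetry), and push right away
cd "$WORKTREE_PATH"
git add -A
poetry run git commit -m "fix: address review comment"
git push
```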
@@ -17,16 +17,6 @@ gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoG
gh pr view {N}
```

## Read the PR description

Before reading code, understand the **why**, **what**, and **how** from the PR description:

```bash
gh pr view {N} --json body --jq '.body'
```

Every PR should have a Why / What / How structure. If any of these are missing, note it as feedback.

## Read the diff

```bash
@@ -44,8 +34,6 @@ gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews

## What to check

**Description quality:** Does the PR description cover Why (motivation/problem), What (summary of changes), and How (approach/implementation details)? If any are missing, request them — you can't judge the approach without understanding the problem and intent.

**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error handling gaps, async correctness (missing `await`, unclosed resources).

**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).
@@ -1,754 +0,0 @@
---
name: pr-test
description: "E2E manual testing of PRs/branches using docker compose, agent-browser, and API calls. TRIGGER when user asks to manually test a PR, test a feature end-to-end, or run integration tests against a running system."
user-invocable: true
argument-hint: "[worktree path or PR number] — tests the PR in the given worktree. Optional flags: --fix (auto-fix issues found)"
metadata:
  author: autogpt-team
  version: "2.0.0"
---

# Manual E2E Test

Test a PR/branch end-to-end by building the full platform, interacting via browser and API, capturing screenshots, and reporting results.

## Critical Requirements

These are NON-NEGOTIABLE. Every test run MUST satisfy ALL of the following:

### 1. Screenshots at Every Step
- Take a screenshot at EVERY significant test step — not just at the end
- Every test scenario MUST have at least one BEFORE and one AFTER screenshot
- Name screenshots sequentially: `{NN}-{action}-{state}.png` (e.g., `01-credits-before.png`, `02-credits-after.png`)
- If a screenshot is missing for a scenario, the test is INCOMPLETE — go back and take it

### 2. Screenshots MUST Be Posted to PR
- Push ALL screenshots to a temp branch `test-screenshots/pr-{N}`
- Post a PR comment with ALL screenshots embedded inline using GitHub raw URLs
- This is NOT optional — every test run MUST end with a PR comment containing screenshots
- If screenshot upload fails, retry. If it still fails, list the failed files and require manual drag-and-drop/paste attachment in the PR comment

### 3. State Verification with Before/After Evidence
- For EVERY state-changing operation (API call, user action), capture the state BEFORE and AFTER
- Log the actual API response values (e.g., `credits_before=100, credits_after=95`)
- The screenshot MUST show the relevant UI state change
- Compare expected vs. actual values explicitly — do not just eyeball it

### 4. Negative Test Cases Are Mandatory
- Test at least ONE negative case per feature (e.g., insufficient credits, invalid input, unauthorized access)
- Verify error messages are user-friendly and accurate
- Verify the system state did NOT change after a rejected operation

### 5. Test Report Must Include Full Evidence
Each test scenario in the report MUST have:
- **Steps**: What was done (exact commands or UI actions)
- **Expected**: What should happen
- **Actual**: What actually happened
- **API Evidence**: Before/after API response values for state-changing operations
- **Screenshot Evidence**: Before/after screenshots with explanations

## State Manipulation for Realistic Testing

When testing features that depend on specific states (rate limits, credits, quotas):

1. **Use Redis CLI to set counters directly:**
   ```bash
   # Find the Redis container
   REDIS_CONTAINER=$(docker ps --format '{{.Names}}' | grep redis | head -1)
   # Set a key with expiry
   docker exec $REDIS_CONTAINER redis-cli SET key value EX ttl
   # Example: Set rate limit counter to near-limit
   docker exec $REDIS_CONTAINER redis-cli SET "rate_limit:user:test@test.com" 99 EX 3600
   # Example: Check current value
   docker exec $REDIS_CONTAINER redis-cli GET "rate_limit:user:test@test.com"
   ```

2. **Use API calls to check before/after state:**
   ```bash
   # BEFORE: Record current state
   BEFORE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
   echo "Credits BEFORE: $BEFORE"

   # Perform the action...

   # AFTER: Record new state and compare
   AFTER=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
   echo "Credits AFTER: $AFTER"
   echo "Delta: $(( BEFORE - AFTER ))"
   ```

3. **Take screenshots BEFORE and AFTER state changes** — the UI must reflect the backend state change

4. **Never rely on mocked/injected browser state** — always use real backend state. Do NOT use `agent-browser eval` to fake UI state. The backend must be the source of truth.

5. **Use direct DB queries when needed:**
   ```bash
   # Query via Supabase's PostgREST or docker exec into the DB
   docker exec supabase-db psql -U supabase_admin -d postgres -c "SELECT credits FROM user_credits WHERE user_id = '...';"
   ```

6. **After every API test, verify the state change actually persisted:**
   ```bash
   # Example: After a credits purchase, verify DB matches API
   API_CREDITS=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
   DB_CREDITS=$(docker exec supabase-db psql -U supabase_admin -d postgres -t -c "SELECT credits FROM user_credits WHERE user_id = '...';" | tr -d ' ')
   [ "$API_CREDITS" = "$DB_CREDITS" ] && echo "CONSISTENT" || echo "MISMATCH: API=$API_CREDITS DB=$DB_CREDITS"
   ```

## Arguments

- `$ARGUMENTS` — worktree path (e.g. `$REPO_ROOT`) or PR number
- If the `--fix` flag is present, auto-fix bugs found and push fixes (like the pr-address loop)

## Step 0: Resolve the target

```bash
# If argument is a PR number, find its worktree
gh pr view {N} --json headRefName --jq '.headRefName'
# If argument is a path, use it directly
```
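The snippet above only resolves the branch name. A hedged sketch for mapping that branch to an existing worktree (assumes the branch is already checked out as a worktree; `$N` is the PR number):

```bash
BRANCH=$(gh pr view "$N" --repo Significant-Gravitas/AutoGPT --json headRefName --jq '.headRefName')
# `git worktree list --porcelain` emits "worktree <path>" followed later by "branch refs/heads/<name>"
WORKTREE_PATH=$(git worktree list --porcelain \
  | awk -v ref="refs/heads/$BRANCH" '$1 == "worktree" { w = $2 } $1 == "branch" && $2 == ref { print w }')
[ -n "$WORKTREE_PATH" ] || echo "No worktree found for branch $BRANCH"
```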
Determine:
- `REPO_ROOT` — the root repo directory: `git -C "$WORKTREE_PATH" worktree list | head -1 | awk '{print $1}'` (or `git rev-parse --show-toplevel` if not a worktree)
- `WORKTREE_PATH` — the worktree directory
- `PLATFORM_DIR` — `$WORKTREE_PATH/autogpt_platform`
- `BACKEND_DIR` — `$PLATFORM_DIR/backend`
- `FRONTEND_DIR` — `$PLATFORM_DIR/frontend`
- `PR_NUMBER` — the PR number (from `gh pr list --head $(git branch --show-current)`)
- `PR_TITLE` — the PR title, slugified (e.g. "Add copilot permissions" → "add-copilot-permissions")
- `RESULTS_DIR` — `$REPO_ROOT/test-results/PR-{PR_NUMBER}-{slugified-title}`

Create the results directory:
```bash
PR_NUMBER=$(cd "$WORKTREE_PATH" && gh pr list --head "$(git branch --show-current)" --repo Significant-Gravitas/AutoGPT --json number --jq '.[0].number')
PR_TITLE=$(cd "$WORKTREE_PATH" && gh pr list --head "$(git branch --show-current)" --repo Significant-Gravitas/AutoGPT --json title --jq '.[0].title' | tr '[:upper:]' '[:lower:]' | sed 's/[^a-z0-9]/-/g' | sed 's/--*/-/g' | sed 's/^-//;s/-$//' | head -c 50)
RESULTS_DIR="$REPO_ROOT/test-results/PR-${PR_NUMBER}-${PR_TITLE}"
mkdir -p "$RESULTS_DIR"
```

**Test user credentials** (for logging into the UI or verifying results manually):
- Email: `test@test.com`
- Password: `testtest123`

## Step 1: Understand the PR

Before testing, understand what changed:

```bash
cd "$WORKTREE_PATH"

# Read PR description to understand the WHY
gh pr view {N} --json body --jq '.body'

git log --oneline dev..HEAD | head -20
git diff dev --stat
```

Read the PR description (Why / What / How) and the changed files to understand:
1. **Why** does this PR exist? What problem does it solve?
2. **What** feature/fix does this PR implement?
3. **How** does it work? What's the approach?
4. What components are affected? (backend, frontend, copilot, executor, etc.)
5. What are the key user-facing behaviors to test?

## Step 2: Write test scenarios

Based on the PR analysis, write a test plan to `$RESULTS_DIR/test-plan.md`:

```markdown
# Test Plan: PR #{N} — {title}

## Scenarios
1. [Scenario name] — [what to verify]
2. ...

## API Tests (if applicable)
1. [Endpoint] — [expected behavior]
   - Before state: [what to check before]
   - After state: [what to verify changed]

## UI Tests (if applicable)
1. [Page/component] — [interaction to test]
   - Screenshot before: [what to capture]
   - Screenshot after: [what to capture]

## Negative Tests (REQUIRED — at least one per feature)
1. [What should NOT happen] — [how to trigger it]
   - Expected error: [what error message/code]
   - State unchanged: [what to verify did NOT change]
```

**Be critical** — include edge cases, error paths, and security checks. Every scenario MUST specify what screenshots to take and what state to verify.

## Step 3: Environment setup

### 3a. Copy .env files from the root worktree

The root worktree (`$REPO_ROOT`) has the canonical `.env` files with all API keys. Copy them to the target worktree:

```bash
# CRITICAL: .env files are NOT checked into git. They must be copied manually.
cp "$REPO_ROOT/autogpt_platform/.env" "$PLATFORM_DIR/.env"
cp "$REPO_ROOT/autogpt_platform/backend/.env" "$BACKEND_DIR/.env"
cp "$REPO_ROOT/autogpt_platform/frontend/.env" "$FRONTEND_DIR/.env"
```

### 3b. Configure copilot authentication

The copilot needs an LLM API to function. There are two approaches (try subscription first):

#### Option 1: Subscription mode (preferred — uses your Claude Max/Pro subscription)

The `claude_agent_sdk` Python package **bundles its own Claude CLI binary** — there is no need to install `@anthropic-ai/claude-code` via npm. The backend auto-provisions credentials from environment variables on startup.

Run the helper script to extract tokens from your host and auto-update `backend/.env` (works on macOS, Linux, and Windows/WSL):

```bash
# Extracts OAuth tokens and writes CLAUDE_CODE_OAUTH_TOKEN + CLAUDE_CODE_REFRESH_TOKEN into .env
bash "$BACKEND_DIR/scripts/refresh_claude_token.sh" --env-file "$BACKEND_DIR/.env"
```

**How it works:** The script reads the OAuth token from:
- **macOS**: system keychain (`"Claude Code-credentials"`)
- **Linux/WSL**: `~/.claude/.credentials.json`
- **Windows**: `%APPDATA%/claude/.credentials.json`

It sets `CLAUDE_CODE_OAUTH_TOKEN`, `CLAUDE_CODE_REFRESH_TOKEN`, and `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true` in the `.env` file. On container startup, the backend auto-provisions `~/.claude/.credentials.json` inside the container from these env vars. The SDK's bundled CLI then authenticates using that file. No `claude login`, no npm install needed.

**Note:** The OAuth token expires (~24h). If the copilot returns auth errors, re-run the script and restart: `$BACKEND_DIR/scripts/refresh_claude_token.sh --env-file $BACKEND_DIR/.env && docker compose up -d copilot_executor`

#### Option 2: OpenRouter API key mode (fallback)

If subscription mode doesn't work, switch to API key mode using OpenRouter:

```bash
# In $BACKEND_DIR/.env, ensure these are set:
CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=false
CHAT_API_KEY=<value of OPEN_ROUTER_API_KEY from the same .env>
CHAT_BASE_URL=https://openrouter.ai/api/v1
CHAT_USE_CLAUDE_AGENT_SDK=true
```

Use in-place edits (`perl -i`) to update these values:
```bash
ORKEY=$(grep "^OPEN_ROUTER_API_KEY=" $BACKEND_DIR/.env | cut -d= -f2)
[ -n "$ORKEY" ] || { echo "ERROR: OPEN_ROUTER_API_KEY is missing in $BACKEND_DIR/.env"; exit 1; }
perl -i -pe 's/CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true/CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=false/' $BACKEND_DIR/.env
# Add or update CHAT_API_KEY, CHAT_BASE_URL, and CHAT_USE_CLAUDE_AGENT_SDK
grep -q "^CHAT_API_KEY=" $BACKEND_DIR/.env && perl -i -pe "s|^CHAT_API_KEY=.*|CHAT_API_KEY=$ORKEY|" $BACKEND_DIR/.env || echo "CHAT_API_KEY=$ORKEY" >> $BACKEND_DIR/.env
grep -q "^CHAT_BASE_URL=" $BACKEND_DIR/.env && perl -i -pe 's|^CHAT_BASE_URL=.*|CHAT_BASE_URL=https://openrouter.ai/api/v1|' $BACKEND_DIR/.env || echo "CHAT_BASE_URL=https://openrouter.ai/api/v1" >> $BACKEND_DIR/.env
grep -q "^CHAT_USE_CLAUDE_AGENT_SDK=" $BACKEND_DIR/.env && perl -i -pe 's|^CHAT_USE_CLAUDE_AGENT_SDK=.*|CHAT_USE_CLAUDE_AGENT_SDK=true|' $BACKEND_DIR/.env || echo "CHAT_USE_CLAUDE_AGENT_SDK=true" >> $BACKEND_DIR/.env
```
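Whichever option you pick, a quick sketch to apply the new settings and check that the copilot container came back up without auth errors:

```bash
# Recreate the copilot container so it picks up the edited .env, then skim its logs
cd "$PLATFORM_DIR" && docker compose up -d copilot_executor
sleep 5
docker logs autogpt_platform-copilot_executor-1 2>&1 | tail -20
```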
### 3c. Stop conflicting containers

```bash
# Stop any running app containers (keep infra: supabase, redis, rabbitmq, clamav)
docker ps --format "{{.Names}}" | grep -E "rest_server|executor|copilot|websocket|database_manager|scheduler|notification|frontend|migrate" | while read name; do
  docker stop "$name" 2>/dev/null
done
```
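A one-line sanity check (sketch): anything this prints is an app container that should have been stopped:

```bash
docker ps --format "{{.Names}}" | grep -vE "supabase|redis|rabbitmq|clamav" || echo "Only infra containers running"
```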
### 3d. Build and start

```bash
cd $PLATFORM_DIR && docker compose build --no-cache 2>&1 | tail -20
if [ ${PIPESTATUS[0]} -ne 0 ]; then echo "ERROR: Docker build failed"; exit 1; fi

cd $PLATFORM_DIR && docker compose up -d 2>&1 | tail -20
if [ ${PIPESTATUS[0]} -ne 0 ]; then echo "ERROR: Docker compose up failed"; exit 1; fi
```

**Note:** If the container appears to be running old code (e.g. missing PR changes), use `docker compose build --no-cache` to force a full rebuild. Docker BuildKit may sometimes reuse cached `COPY` layers from a previous build on a different branch.

**Expected time:** 3-8 minutes for a cached build, 5-10 minutes with `--no-cache`.

### 3e. Wait for services to be ready

```bash
# Poll until backend and frontend respond (up to 5 minutes)
for i in $(seq 1 60); do
  BACKEND=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8006/docs 2>/dev/null)
  FRONTEND=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:3000 2>/dev/null)
  if [ "$BACKEND" = "200" ] && [ "$FRONTEND" = "200" ]; then
    echo "Services ready"
    break
  fi
  [ "$i" -eq 60 ] && { echo "ERROR: services not ready after 5 minutes"; exit 1; }
  sleep 5
done
```

### 3f. Create test user and get auth token

```bash
ANON_KEY=$(grep "NEXT_PUBLIC_SUPABASE_ANON_KEY=" $FRONTEND_DIR/.env | sed 's/.*NEXT_PUBLIC_SUPABASE_ANON_KEY=//' | tr -d '[:space:]')

# Signup (idempotent — returns "User already registered" if the user exists)
RESULT=$(curl -s -X POST 'http://localhost:8000/auth/v1/signup' \
  -H "apikey: $ANON_KEY" \
  -H 'Content-Type: application/json' \
  -d '{"email":"test@test.com","password":"testtest123"}')

# If "Database error finding user", restart supabase-auth and retry
if echo "$RESULT" | grep -q "Database error"; then
  docker restart supabase-auth && sleep 5
  curl -s -X POST 'http://localhost:8000/auth/v1/signup' \
    -H "apikey: $ANON_KEY" \
    -H 'Content-Type: application/json' \
    -d '{"email":"test@test.com","password":"testtest123"}'
fi

# Get auth token
TOKEN=$(curl -s -X POST 'http://localhost:8000/auth/v1/token?grant_type=password' \
  -H "apikey: $ANON_KEY" \
  -H 'Content-Type: application/json' \
  -d '{"email":"test@test.com","password":"testtest123"}' | jq -r '.access_token // ""')
```

**Use this token for ALL API calls:**
```bash
curl -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/...
```
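Before running scenarios, a short sanity check (sketch, reusing the `/api/credits` endpoint from the state-manipulation examples):

```bash
# Fail fast if auth didn't work; 200 means the backend accepts the token
[ -n "$TOKEN" ] || { echo "ERROR: failed to obtain auth token"; exit 1; }
curl -s -o /dev/null -w "auth check: HTTP %{http_code}\n" -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits
```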
## Step 4: Run tests

### Service ports reference

| Service | Port | URL |
|---------|------|-----|
| Frontend | 3000 | http://localhost:3000 |
| Backend REST | 8006 | http://localhost:8006 |
| Supabase Auth (via Kong) | 8000 | http://localhost:8000 |
| Executor | 8002 | http://localhost:8002 |
| Copilot Executor | 8008 | http://localhost:8008 |
| WebSocket | 8001 | http://localhost:8001 |
| Database Manager | 8005 | http://localhost:8005 |
| Redis | 6379 | localhost:6379 |
| RabbitMQ | 5672 | localhost:5672 |
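A quick reachability sweep over the HTTP ports above (sketch; Redis and RabbitMQ are not HTTP, so they are skipped):

```bash
for port in 3000 8000 8001 8002 8005 8006 8008; do
  printf "port %s -> HTTP %s\n" "$port" "$(curl -s -o /dev/null -w '%{http_code}' "http://localhost:$port")"
done
```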
### API testing

Use `curl` with the auth token for backend API tests. **For EVERY API call that changes state, record before/after values:**

```bash
# Example: List agents
curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/graphs | jq . | head -20

# Example: Create an agent
curl -s -X POST http://localhost:8006/api/graphs \
  -H "Authorization: Bearer $TOKEN" \
  -H 'Content-Type: application/json' \
  -d '{...}' | jq .

# Example: Run an agent
curl -s -X POST "http://localhost:8006/api/graphs/{graph_id}/execute" \
  -H "Authorization: Bearer $TOKEN" \
  -H 'Content-Type: application/json' \
  -d '{"data": {...}}'

# Example: Get execution results
curl -s -H "Authorization: Bearer $TOKEN" \
  "http://localhost:8006/api/graphs/{graph_id}/executions/{exec_id}" | jq .
```

**State verification pattern (use for EVERY state-changing API call):**
```bash
# 1. Record BEFORE state
BEFORE_STATE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/{resource} | jq '{relevant_fields}')
echo "BEFORE: $BEFORE_STATE"

# 2. Perform the action
ACTION_RESULT=$(curl -s -X POST ... | jq .)
echo "ACTION RESULT: $ACTION_RESULT"

# 3. Record AFTER state
AFTER_STATE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/{resource} | jq '{relevant_fields}')
echo "AFTER: $AFTER_STATE"

# 4. Log the comparison
echo "=== STATE CHANGE VERIFICATION ==="
echo "Before: $BEFORE_STATE"
echo "After: $AFTER_STATE"
echo "Expected change: {describe what should have changed}"
```
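To make step 4 an actual assertion rather than a log line, a hedged sketch (assumes the state objects expose a numeric `.credits` field, as in the earlier credit examples):

```bash
# Compare the exact expected delta instead of eyeballing the output
EXPECTED_DELTA=5  # illustrative value
DELTA=$(( $(echo "$BEFORE_STATE" | jq '.credits') - $(echo "$AFTER_STATE" | jq '.credits') ))
[ "$DELTA" -eq "$EXPECTED_DELTA" ] && echo "PASS: delta=$DELTA" || echo "FAIL: expected $EXPECTED_DELTA, got $DELTA"
```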
### Browser testing with agent-browser

```bash
# Close any existing session
agent-browser close 2>/dev/null || true

# Use --session-name to persist cookies across navigations
# This means login only needs to happen once per test session
agent-browser --session-name pr-test open 'http://localhost:3000/login' --timeout 15000

# Get interactive elements
agent-browser --session-name pr-test snapshot | grep "textbox\|button"

# Login
agent-browser --session-name pr-test fill {email_ref} "test@test.com"
agent-browser --session-name pr-test fill {password_ref} "testtest123"
agent-browser --session-name pr-test click {login_button_ref}
sleep 5

# Dismiss cookie banner if present
agent-browser --session-name pr-test click 'text=Accept All' 2>/dev/null || true

# Navigate — cookies are preserved so login persists
agent-browser --session-name pr-test open 'http://localhost:3000/copilot' --timeout 10000

# Take screenshot
agent-browser --session-name pr-test screenshot $RESULTS_DIR/01-page.png

# Interact with elements
agent-browser --session-name pr-test fill {ref} "text"
agent-browser --session-name pr-test press "Enter"
agent-browser --session-name pr-test click {ref}
agent-browser --session-name pr-test click 'text=Button Text'

# Read page content
agent-browser --session-name pr-test snapshot | grep "text:"
```

**Key pages** (smoke sketch below):
- `/copilot` — CoPilot chat (for testing copilot features)
- `/build` — Agent builder (for testing block/node features)
- `/build?flowID={id}` — A specific agent in the builder
- `/library` — Agent library (for testing listing/import features)
- `/library/agents/{id}` — Agent detail with run history
- `/marketplace` — Marketplace
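The smoke sketch referenced above; it assumes the `pr-test` session is already logged in:

```bash
# Screenshot each key page once as a baseline
for page in copilot build library marketplace; do
  agent-browser --session-name pr-test open "http://localhost:3000/$page" --timeout 10000
  agent-browser --session-name pr-test screenshot "$RESULTS_DIR/smoke-$page.png"
done
```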
### Checking logs

```bash
# Backend REST server
docker logs autogpt_platform-rest_server-1 2>&1 | tail -30

# Executor (runs agent graphs)
docker logs autogpt_platform-executor-1 2>&1 | tail -30

# Copilot executor (runs copilot chat sessions)
docker logs autogpt_platform-copilot_executor-1 2>&1 | tail -30

# Frontend
docker logs autogpt_platform-frontend-1 2>&1 | tail -30

# Filter for errors
docker logs autogpt_platform-executor-1 2>&1 | grep -i "error\|exception\|traceback" | tail -20
```
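To catch errors while a scenario runs instead of after it, a hedged sketch using `timeout` so the background tail cleans itself up:

```bash
# Watch executor logs for errors for up to 60 seconds while the scenario runs
timeout 60 docker logs -f autogpt_platform-executor-1 2>&1 | grep -i "error\|exception" &
# ... run the test scenario ...
wait
```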
### Copilot chat testing

The copilot uses SSE streaming. To test via API:

```bash
# Create a session
SESSION_ID=$(curl -s -X POST 'http://localhost:8006/api/chat/sessions' \
  -H "Authorization: Bearer $TOKEN" \
  -H 'Content-Type: application/json' \
  -d '{}' | jq -r '.id // .session_id // ""')

# Stream a message (SSE — will stream chunks)
curl -N -X POST "http://localhost:8006/api/chat/sessions/$SESSION_ID/stream" \
  -H "Authorization: Bearer $TOKEN" \
  -H 'Content-Type: application/json' \
  -d '{"message": "Hello, what can you help me with?"}' \
  --max-time 60 2>/dev/null | head -50
```

Or test via browser (preferred for UI verification):
```bash
agent-browser --session-name pr-test open 'http://localhost:3000/copilot' --timeout 10000
# ... fill chat input and press Enter, wait 20-30s for response
```
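If the raw stream is too noisy to read, a hedged filter (assumes standard SSE `data:` framing; the exact event payload shape is not specified here):

```bash
# Keep only the SSE payload lines for easier reading
curl -N -X POST "http://localhost:8006/api/chat/sessions/$SESSION_ID/stream" \
  -H "Authorization: Bearer $TOKEN" \
  -H 'Content-Type: application/json' \
  -d '{"message": "ping"}' --max-time 60 2>/dev/null \
  | grep '^data:' | sed 's/^data: //' | head -30
```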
## Step 5: Record results and take screenshots

**Take a screenshot at EVERY significant test step** — before and after interactions, on success, and on failure. This is NON-NEGOTIABLE.

**Required screenshot pattern for each test scenario:**
```bash
# BEFORE the action
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{scenario}-before.png

# Perform the action...

# AFTER the action
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{scenario}-after.png
```

**Naming convention:**
```bash
# Examples:
# $RESULTS_DIR/01-login-page-before.png
# $RESULTS_DIR/02-login-page-after.png
# $RESULTS_DIR/03-credits-page-before.png
# $RESULTS_DIR/04-credits-purchase-after.png
# $RESULTS_DIR/05-negative-insufficient-credits.png
# $RESULTS_DIR/06-error-state.png
```

**Minimum requirements:**
- At least TWO screenshots per test scenario (before + after)
- At least ONE screenshot for each negative test case showing the error state
- If a test fails, screenshot the failure state AND any error logs visible in the UI

## Step 6: Show results to user with screenshots

**CRITICAL: After all tests complete, you MUST show every screenshot to the user using the Read tool, with an explanation of what each screenshot shows.** This is the most important part of the test report — the user needs to visually verify the results.

For each screenshot:
1. Use the `Read` tool to display the PNG file (Claude can read images)
2. Write a 1-2 sentence explanation below it describing:
   - What page/state is being shown
   - What the screenshot proves (which test scenario it validates)
   - Any notable details visible in the UI

Format the output like this:

```markdown
### Screenshot 1: {descriptive title}
[Read the PNG file here]

**What it shows:** {1-2 sentence explanation of what this screenshot proves}

---
```

After showing all screenshots, output a **detailed** summary table:

| # | Scenario | Result | API Evidence | Screenshot Evidence |
|---|----------|--------|-------------|-------------------|
| 1 | {name} | PASS/FAIL | Before: X, After: Y | 01-before.png, 02-after.png |
| 2 | ... | ... | ... | ... |

**IMPORTANT:** As you show each screenshot and record test results, persist them in shell variables for Step 7:

```bash
# Build these variables during Step 6 — they are required by Step 7's script
# NOTE: declare -A requires Bash 4.0+. This is standard on modern systems (macOS ships zsh,
# but Homebrew bash is 5.x; Linux typically has bash 5.x). If running on Bash <4, use a
# plain variable with a lookup function instead.
declare -A SCREENSHOT_EXPLANATIONS=(
  ["01-login-page.png"]="Shows the login page loaded successfully with SSO options visible."
  ["02-builder-with-block.png"]="The builder canvas displays the newly added block connected to the trigger."
  # ... one entry per screenshot, using the same explanations you showed the user above
)

TEST_RESULTS_TABLE="| 1 | Login flow | PASS | N/A | 01-login-before.png, 02-login-after.png |
| 2 | Credits purchase | PASS | Before: 100, After: 95 | 03-credits-before.png, 04-credits-after.png |
| 3 | Insufficient credits (negative) | PASS | Credits: 0, rejected | 05-insufficient-credits-error.png |"
# ... one row per test scenario with actual results
```

## Step 7: Post test report as PR comment with screenshots

Upload screenshots to the PR using the GitHub Git API (no local git operations — safe for worktrees), then post a comment with inline images and per-screenshot explanations.

**This step is MANDATORY. Every test run MUST post a PR comment with screenshots. No exceptions.**

```bash
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
REPO="Significant-Gravitas/AutoGPT"
SCREENSHOTS_BRANCH="test-screenshots/pr-${PR_NUMBER}"
SCREENSHOTS_DIR="test-screenshots/PR-${PR_NUMBER}"

# Step 1: Create blobs for each screenshot and build tree JSON
# Retry each blob upload up to 3 times. If still failing, list them at the end of the report.
shopt -s nullglob
SCREENSHOT_FILES=("$RESULTS_DIR"/*.png)
if [ ${#SCREENSHOT_FILES[@]} -eq 0 ]; then
  echo "ERROR: No screenshots found in $RESULTS_DIR. Test run is incomplete."
  exit 1
fi
TREE_JSON='['
FIRST=true
FAILED_UPLOADS=()
for img in "${SCREENSHOT_FILES[@]}"; do
  BASENAME=$(basename "$img")
  B64=$(base64 < "$img" | tr -d '\n')  # strip wrapping newlines so the payload is one clean string
  BLOB_SHA=""
  for attempt in 1 2 3; do
    BLOB_SHA=$(gh api "repos/${REPO}/git/blobs" -f content="$B64" -f encoding="base64" --jq '.sha' 2>/dev/null || true)
    [ -n "$BLOB_SHA" ] && break
    sleep 1
  done
  if [ -z "$BLOB_SHA" ]; then
    FAILED_UPLOADS+=("$img")
    continue
  fi
  if [ "$FIRST" = true ]; then FIRST=false; else TREE_JSON+=','; fi
  TREE_JSON+="{\"path\":\"${SCREENSHOTS_DIR}/${BASENAME}\",\"mode\":\"100644\",\"type\":\"blob\",\"sha\":\"${BLOB_SHA}\"}"
done
TREE_JSON+=']'

# Step 2: Create tree, commit, and branch ref
TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
  -f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
  -f tree="$TREE_SHA" \
  --jq '.sha')
gh api "repos/${REPO}/git/refs" \
  -f ref="refs/heads/${SCREENSHOTS_BRANCH}" \
  -f sha="$COMMIT_SHA" 2>/dev/null \
  || gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" \
    -X PATCH -f sha="$COMMIT_SHA" -F force=true  # -F sends a typed boolean, not the string "true"
```

Then post the comment with **inline images AND explanations for each screenshot**:

```bash
REPO_URL="https://raw.githubusercontent.com/${REPO}/${SCREENSHOTS_BRANCH}"

# Build image markdown using the uploaded image URLs; skip FAILED_UPLOADS (listed separately)
IMAGE_MARKDOWN=""
for img in "${SCREENSHOT_FILES[@]}"; do
  BASENAME=$(basename "$img")
  TITLE=$(echo "${BASENAME%.png}" | sed 's/^[0-9]*-//' | sed 's/-/ /g' | awk '{for(i=1;i<=NF;i++) $i=toupper(substr($i,1,1)) tolower(substr($i,2))}1')
  # Skip images that failed to upload — they will be listed at the end
  IS_FAILED=false
  for failed in "${FAILED_UPLOADS[@]}"; do
    [ "$(basename "$failed")" = "$BASENAME" ] && IS_FAILED=true && break
  done
  if [ "$IS_FAILED" = true ]; then
    continue
  fi
  EXPLANATION="${SCREENSHOT_EXPLANATIONS[$BASENAME]}"
  if [ -z "$EXPLANATION" ]; then
    echo "ERROR: Missing screenshot explanation for $BASENAME. Add it to SCREENSHOT_EXPLANATIONS in Step 6."
    exit 1
  fi
  IMAGE_MARKDOWN="${IMAGE_MARKDOWN}
### ${TITLE}
![${BASENAME}](${REPO_URL}/${SCREENSHOTS_DIR}/${BASENAME})

${EXPLANATION}
"
done

# Write the comment body to a file to avoid shell interpretation issues with special characters
COMMENT_FILE=$(mktemp)
# If any uploads failed, append a section listing them with instructions
FAILED_SECTION=""
if [ ${#FAILED_UPLOADS[@]} -gt 0 ]; then
  FAILED_SECTION="
## ⚠️ Failed Screenshot Uploads
The following screenshots could not be uploaded via the GitHub API after 3 retries.
**To add them:** drag-and-drop or paste these files into a PR comment manually:
"
  for failed in "${FAILED_UPLOADS[@]}"; do
    FAILED_SECTION="${FAILED_SECTION}
- \`$(basename "$failed")\` (local path: \`$failed\`)"
  done
  FAILED_SECTION="${FAILED_SECTION}

**Run status:** INCOMPLETE until the files above are manually attached and visible inline in the PR."
fi

cat > "$COMMENT_FILE" <<INNEREOF
## E2E Test Report

| # | Scenario | Result | API Evidence | Screenshot Evidence |
|---|----------|--------|-------------|-------------------|
${TEST_RESULTS_TABLE}

${IMAGE_MARKDOWN}
${FAILED_SECTION}
INNEREOF

gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE"
rm -f "$COMMENT_FILE"
```

**The PR comment MUST include:**
1. A summary table of all scenarios with PASS/FAIL and before/after API evidence
2. Every successfully uploaded screenshot rendered inline; any failed uploads listed with manual attachment instructions
3. A 1-2 sentence explanation below each screenshot describing what it proves

This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch.

## Fix mode (--fix flag)

When `--fix` is present, the standard is HIGHER. Do not just note issues — FIX them immediately.

### Fix protocol for EVERY issue found (including UX issues):

1. **Identify** the root cause in the code — read the relevant source files
2. **Write a failing test first** (TDD): For backend bugs, write a test marked with `pytest.mark.xfail(reason="...")`. For frontend/Playwright bugs, write a test with the `.fixme` annotation. Run it to confirm it fails as expected.
3. **Screenshot** the broken state: `agent-browser screenshot $RESULTS_DIR/{NN}-broken-{description}.png`
4. **Fix** the code in the worktree
5. **Rebuild** ONLY the affected service (not the whole stack):
   ```bash
   cd $PLATFORM_DIR && docker compose up --build -d {service_name}
   # e.g., docker compose up --build -d rest_server
   # e.g., docker compose up --build -d frontend
   ```
6. **Wait** for the service to be ready (poll its health endpoint — see the sketch after this list)
7. **Re-test** the same scenario
8. **Screenshot** the fixed state: `agent-browser screenshot $RESULTS_DIR/{NN}-fixed-{description}.png`
9. **Remove the xfail/fixme marker** from the test written in step 2, and verify it passes
10. **Verify** the fix did not break other scenarios (run a quick smoke test)
11. **Commit and push** immediately:
    ```bash
    cd $WORKTREE_PATH
    git add -A
    git commit -m "fix: {description of fix}"
    git push
    ```
12. **Continue** to the next test scenario
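The health-poll sketch referenced in step 6, shown for the `rest_server` case (port taken from the Step 4 table):

```bash
# Rebuild just the changed service, then poll until it serves requests again
cd $PLATFORM_DIR && docker compose up --build -d rest_server
for i in $(seq 1 30); do
  [ "$(curl -s -o /dev/null -w '%{http_code}' http://localhost:8006/docs)" = "200" ] && { echo "rest_server ready"; break; }
  sleep 5
done
```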
### Fix loop (like pr-address)

```text
test scenario → find issue (bug OR UX problem) → screenshot broken state
→ fix code → rebuild affected service only → re-test → screenshot fixed state
→ verify no regressions → commit + push
→ repeat for next scenario
→ after ALL scenarios pass, run full re-test to verify everything together
```

**Key differences from non-fix mode:**
- UX issues count as bugs — fix them (bad alignment, confusing labels, missing loading states)
- Every fix MUST have a before/after screenshot pair proving it works
- Commit after EACH fix, not in a batch at the end
- The final re-test must produce a clean set of all-passing screenshots

## Known issues and workarounds

### Problem: "Database error finding user" on signup
**Cause:** The Supabase auth service's schema cache is stale after a migration.
**Fix:** `docker restart supabase-auth && sleep 5`, then retry signup.

### Problem: Copilot returns auth errors in subscription mode
**Cause:** `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true` but `CLAUDE_CODE_OAUTH_TOKEN` is not set or has expired.
**Fix:** Re-extract the OAuth token (see step 3b, Option 1) and recreate the container (`docker compose up -d copilot_executor`). The backend auto-provisions `~/.claude/.credentials.json` from the env var on startup. No `npm install` or `claude login` needed — the SDK bundles its own CLI binary.

### Problem: agent-browser can't find chromium
**Cause:** The Dockerfile auto-provisions system chromium on all architectures (including ARM64). If your branch is behind `dev`, this may not be present yet.
**Fix:** Check whether chromium exists: `which chromium || which chromium-browser`. If missing, install it (`apt-get install -y chromium`) and set `AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium` in the container environment.

### Problem: agent-browser selector matches multiple elements
**Cause:** `text=X` matches all elements containing that text.
**Fix:** Use `agent-browser snapshot` to get specific `ref=eNN` references, then use those: `agent-browser click eNN`.

### Problem: Frontend shows a cookie banner blocking interaction
**Fix:** `agent-browser click 'text=Accept All'` before other interactions.

### Problem: Container loses npm packages after rebuild
**Cause:** `docker compose up --build` rebuilds the image, losing runtime installs.
**Fix:** Add packages to the Dockerfile instead of installing them at runtime.

### Problem: Services not starting after `docker compose up`
**Fix:** Wait and check health: `docker compose ps`. Common cause: the migration hasn't finished. Check: `docker logs autogpt_platform-migrate-1 2>&1 | tail -5`. If supabase-db isn't healthy: `docker restart supabase-db && sleep 10`.

### Problem: Docker uses cached layers with old code (PR changes not visible)
**Cause:** `docker compose up --build` reuses cached `COPY` layers from previous builds. If the PR branch changes Python files but a previous build already cached that layer from `dev`, the container runs `dev` code.
**Fix:** Always use `docker compose build --no-cache` for the first build of a PR branch. Subsequent rebuilds within the same branch can use `--build`.

### Problem: `agent-browser open` loses the login session
**Cause:** Without session persistence, `agent-browser open` starts fresh.
**Fix:** Use `--session-name pr-test` on ALL agent-browser commands. This auto-saves/restores cookies and localStorage across navigations. Alternatively, use `agent-browser eval "window.location.href = '...'"` to navigate within the same context.

### Problem: Supabase auth returns "Database error querying schema"
**Cause:** The database schema changed (a migration ran) but supabase-auth has a stale schema cache.
**Fix:** `docker restart supabase-db && sleep 10 && docker restart supabase-auth && sleep 8`. If user data was lost, re-signup.
.github/PULL_REQUEST_TEMPLATE.md (vendored, 8 changes)
@@ -1,12 +1,8 @@
### Why / What / How

<!-- Why: Why does this PR exist? What problem does it solve, or what's broken/missing without it? -->
<!-- What: What does this PR change? Summarize the changes at a high level. -->
<!-- How: How does it work? Describe the approach, key implementation details, or architecture decisions. -->
<!-- Clearly explain the need for these changes: -->

### Changes 🏗️

<!-- List the key changes. Keep it higher level than the diff but specific enough to highlight what's new/modified. -->
<!-- Concisely describe all of the changes made in this pull request: -->

### Checklist 📋

@@ -53,10 +53,8 @@ AutoGPT Platform is a monorepo containing:
### Creating Pull Requests

- Create the PR against the `dev` branch of the repository.
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Always use `--body-file` to pass the PR body — this avoids shell interpretation of backticks and special characters:
```bash
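# The original example is truncated at this hunk boundary; a hypothetical sketch of the pattern:
gh pr create --base dev --title "feat: add new block" --body-file pr_body.md
```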
autogpt_platform/autogpt_libs/poetry.lock (generated, 54 changes)
@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.

[[package]]
name = "annotated-doc"
@@ -67,7 +67,7 @@ description = "Backport of asyncio.Runner, a context manager that controls event
optional = false
python-versions = "<3.11,>=3.8"
groups = ["dev"]
markers = "python_version == \"3.10\""
markers = "python_version < \"3.11\""
files = [
    {file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
    {file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
@@ -541,7 +541,7 @@ description = "Backport of PEP 654 (exception groups)"
optional = false
python-versions = ">=3.7"
groups = ["main", "dev"]
markers = "python_version == \"3.10\""
markers = "python_version < \"3.11\""
files = [
    {file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"},
    {file = "exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88"},
@@ -2181,14 +2181,14 @@ testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]

[[package]]
name = "pytest-cov"
version = "7.1.0"
version = "7.0.0"
description = "Pytest plugin for measuring coverage."
optional = false
python-versions = ">=3.9"
groups = ["dev"]
files = [
    {file = "pytest_cov-7.1.0-py3-none-any.whl", hash = "sha256:a0461110b7865f9a271aa1b51e516c9a95de9d696734a2f71e3e78f46e1d4678"},
    {file = "pytest_cov-7.1.0.tar.gz", hash = "sha256:30674f2b5f6351aa09702a9c8c364f6a01c27aae0c1366ae8016160d1efc56b2"},
    {file = "pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861"},
    {file = "pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1"},
]

[package.dependencies]
@@ -2342,30 +2342,30 @@ pyasn1 = ">=0.1.3"

[[package]]
name = "ruff"
version = "0.15.7"
version = "0.15.0"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
groups = ["dev"]
files = [
    {file = "ruff-0.15.7-py3-none-linux_armv6l.whl", hash = "sha256:a81cc5b6910fb7dfc7c32d20652e50fa05963f6e13ead3c5915c41ac5d16668e"},
    {file = "ruff-0.15.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:722d165bd52403f3bdabc0ce9e41fc47070ac56d7a91b4e0d097b516a53a3477"},
    {file = "ruff-0.15.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7fbc2448094262552146cbe1b9643a92f66559d3761f1ad0656d4991491af49e"},
    {file = "ruff-0.15.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b39329b60eba44156d138275323cc726bbfbddcec3063da57caa8a8b1d50adf"},
    {file = "ruff-0.15.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87768c151808505f2bfc93ae44e5f9e7c8518943e5074f76ac21558ef5627c85"},
    {file = "ruff-0.15.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb0511670002c6c529ec66c0e30641c976c8963de26a113f3a30456b702468b0"},
    {file = "ruff-0.15.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0d19644f801849229db8345180a71bee5407b429dd217f853ec515e968a6912"},
    {file = "ruff-0.15.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4806d8e09ef5e84eb19ba833d0442f7e300b23fe3f0981cae159a248a10f0036"},
    {file = "ruff-0.15.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dce0896488562f09a27b9c91b1f58a097457143931f3c4d519690dea54e624c5"},
    {file = "ruff-0.15.7-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:1852ce241d2bc89e5dc823e03cff4ce73d816b5c6cdadd27dbfe7b03217d2a12"},
    {file = "ruff-0.15.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5f3e4b221fb4bd293f79912fc5e93a9063ebd6d0dcbd528f91b89172a9b8436c"},
    {file = "ruff-0.15.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b15e48602c9c1d9bdc504b472e90b90c97dc7d46c7028011ae67f3861ceba7b4"},
    {file = "ruff-0.15.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1b4705e0e85cedc74b0a23cf6a179dbb3df184cb227761979cc76c0440b5ab0d"},
    {file = "ruff-0.15.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:112c1fa316a558bb34319282c1200a8bf0495f1b735aeb78bfcb2991e6087580"},
    {file = "ruff-0.15.7-py3-none-win32.whl", hash = "sha256:6d39e2d3505b082323352f733599f28169d12e891f7dd407f2d4f54b4c2886de"},
    {file = "ruff-0.15.7-py3-none-win_amd64.whl", hash = "sha256:4d53d712ddebcd7dace1bc395367aec12c057aacfe9adbb6d832302575f4d3a1"},
    {file = "ruff-0.15.7-py3-none-win_arm64.whl", hash = "sha256:18e8d73f1c3fdf27931497972250340f92e8c861722161a9caeb89a58ead6ed2"},
    {file = "ruff-0.15.7.tar.gz", hash = "sha256:04f1ae61fc20fe0b148617c324d9d009b5f63412c0b16474f3d5f1a1a665f7ac"},
    {file = "ruff-0.15.0-py3-none-linux_armv6l.whl", hash = "sha256:aac4ebaa612a82b23d45964586f24ae9bc23ca101919f5590bdb368d74ad5455"},
    {file = "ruff-0.15.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:dcd4be7cc75cfbbca24a98d04d0b9b36a270d0833241f776b788d59f4142b14d"},
    {file = "ruff-0.15.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d747e3319b2bce179c7c1eaad3d884dc0a199b5f4d5187620530adf9105268ce"},
    {file = "ruff-0.15.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:650bd9c56ae03102c51a5e4b554d74d825ff3abe4db22b90fd32d816c2e90621"},
    {file = "ruff-0.15.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a6664b7eac559e3048223a2da77769c2f92b43a6dfd4720cef42654299a599c9"},
    {file = "ruff-0.15.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f811f97b0f092b35320d1556f3353bf238763420ade5d9e62ebd2b73f2ff179"},
    {file = "ruff-0.15.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:761ec0a66680fab6454236635a39abaf14198818c8cdf691e036f4bc0f406b2d"},
    {file = "ruff-0.15.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:940f11c2604d317e797b289f4f9f3fa5555ffe4fb574b55ed006c3d9b6f0eb78"},
    {file = "ruff-0.15.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcbca3d40558789126da91d7ef9a7c87772ee107033db7191edefa34e2c7f1b4"},
    {file = "ruff-0.15.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9a121a96db1d75fa3eb39c4539e607f628920dd72ff1f7c5ee4f1b768ac62d6e"},
    {file = "ruff-0.15.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5298d518e493061f2eabd4abd067c7e4fb89e2f63291c94332e35631c07c3662"},
    {file = "ruff-0.15.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:afb6e603d6375ff0d6b0cee563fa21ab570fd15e65c852cb24922cef25050cf1"},
    {file = "ruff-0.15.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:77e515f6b15f828b94dc17d2b4ace334c9ddb7d9468c54b2f9ed2b9c1593ef16"},
    {file = "ruff-0.15.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:6f6e80850a01eb13b3e42ee0ebdf6e4497151b48c35051aab51c101266d187a3"},
    {file = "ruff-0.15.0-py3-none-win32.whl", hash = "sha256:238a717ef803e501b6d51e0bdd0d2c6e8513fe9eec14002445134d3907cd46c3"},
    {file = "ruff-0.15.0-py3-none-win_amd64.whl", hash = "sha256:dd5e4d3301dc01de614da3cdffc33d4b1b96fb89e45721f1598e5532ccf78b18"},
    {file = "ruff-0.15.0-py3-none-win_arm64.whl", hash = "sha256:c480d632cc0ca3f0727acac8b7d053542d9e114a462a145d0b00e7cd658c515a"},
    {file = "ruff-0.15.0.tar.gz", hash = "sha256:6bdea47cdbea30d40f8f8d7d69c0854ba7c15420ec75a26f463290949d7f7e9a"},
]

[[package]]
@@ -2564,7 +2564,7 @@ description = "A lil' TOML parser"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
markers = "python_version == \"3.10\""
markers = "python_version < \"3.11\""
files = [
    {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
    {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
@@ -2912,4 +2912,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
content-hash = "e0936a065565550afed18f6298b7e04e814b44100def7049f1a0d68662624a39"
content-hash = "9619cae908ad38fa2c48016a58bcf4241f6f5793aa0e6cc140276e91c433cbbb"

@@ -26,8 +26,8 @@ pyright = "^1.1.408"
pytest = "^8.4.1"
pytest-asyncio = "^1.3.0"
pytest-mock = "^3.15.1"
pytest-cov = "^7.1.0"
ruff = "^0.15.7"
pytest-cov = "^7.0.0"
ruff = "^0.15.0"

[build-system]
requires = ["poetry-core"]

@@ -61,7 +61,6 @@ poetry run pytest path/to/test.py --snapshot-update
## Code Style

- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code

@@ -121,20 +121,36 @@ RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
    && ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries

# Install agent-browser (Copilot browser tool) using the system chromium package.
# Chrome for Testing (the binary agent-browser downloads via `agent-browser install`)
# has no ARM64 builds, so we use the distro-packaged chromium instead — verified to
# work with agent-browser via Docker tests on arm64; amd64 is validated in CI.
# Note: system chromium tracks the Debian package schedule rather than a pinned
# Chrome for Testing release. If agent-browser requires a specific Chrome version,
# verify compatibility against the chromium package version in the base image.
# Install agent-browser (Copilot browser tool) + Chromium.
# On amd64: install runtime libs + run `agent-browser install` to download
# Chrome for Testing (pinned version, tested with Playwright).
# On arm64: install system chromium package — Chrome for Testing has no ARM64
# binary. AGENT_BROWSER_EXECUTABLE_PATH is set at runtime by the entrypoint
# script (below) to redirect agent-browser to the system binary.
ARG TARGETARCH
RUN apt-get update \
    && apt-get install -y --no-install-recommends chromium fonts-liberation \
    && if [ "$TARGETARCH" = "arm64" ]; then \
        apt-get install -y --no-install-recommends chromium fonts-liberation; \
    else \
        apt-get install -y --no-install-recommends \
            libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
            libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
            libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
            libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
            fonts-liberation libfontconfig1; \
    fi \
    && rm -rf /var/lib/apt/lists/* \
    && npm install -g agent-browser \
    && ([ "$TARGETARCH" = "arm64" ] || agent-browser install) \
    && rm -rf /tmp/* /root/.npm

ENV AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium
# On arm64 the system chromium is at /usr/bin/chromium; set
# AGENT_BROWSER_EXECUTABLE_PATH so agent-browser's daemon uses it instead of
# Chrome for Testing (which has no ARM64 binary). On amd64 the variable is left
# unset so agent-browser uses the Chrome for Testing binary it downloaded above.
RUN printf '#!/bin/sh\n[ -x /usr/bin/chromium ] && export AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium\nexec "$@"\n' \
    > /usr/local/bin/entrypoint.sh \
    && chmod +x /usr/local/bin/entrypoint.sh

WORKDIR /app/autogpt_platform/backend

@@ -157,4 +173,5 @@ RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \

ENV PORT=8000

ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["rest"]

@@ -18,20 +18,15 @@ from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.integrations.models import get_all_provider_names
|
||||
from backend.api.features.integrations.router import (
|
||||
CredentialsMetaResponse,
|
||||
to_meta_response,
|
||||
)
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
is_sdk_default,
|
||||
)
|
||||
from backend.integrations.credentials_store import provider_matches
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -96,6 +91,18 @@ class OAuthCompleteResponse(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class CredentialSummary(BaseModel):
    """Summary of a credential without sensitive data."""

    id: str
    provider: str
    type: CredentialsType
    title: Optional[str] = None
    scopes: Optional[list[str]] = None
    username: Optional[str] = None
    host: Optional[str] = None


class ProviderInfo(BaseModel):
    """Information about an integration provider."""

@@ -466,12 +473,12 @@ async def complete_oauth(
    )


@integrations_router.get("/credentials", response_model=list[CredentialsMetaResponse])
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
async def list_credentials(
    auth: APIAuthorizationInfo = Security(
        require_permission(APIKeyPermission.READ_INTEGRATIONS)
    ),
) -> list[CredentialsMetaResponse]:
) -> list[CredentialSummary]:
    """
    List all credentials for the authenticated user.

@@ -479,19 +486,28 @@ async def list_credentials(
    """
    credentials = await creds_manager.store.get_all_creds(auth.user_id)
    return [
        to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
        CredentialSummary(
            id=cred.id,
            provider=cred.provider,
            type=cred.type,
            title=cred.title,
            scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
            username=cred.username if isinstance(cred, OAuth2Credentials) else None,
            host=cred.host if isinstance(cred, HostScopedCredentials) else None,
        )
        for cred in credentials
    ]

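The hunk above swaps the external list endpoint's response model from `CredentialsMetaResponse` to the new `CredentialSummary`. A minimal client-side sketch of what a caller now receives; the base URL and API-key header name are assumptions about the deployment, not taken from this diff:

```python
import httpx

# Assumed host/prefix for the external API; adjust to your environment.
BASE_URL = "http://localhost:8006/api/external/integrations"

resp = httpx.get(
    f"{BASE_URL}/credentials",
    headers={"X-API-Key": "<external-api-key>"},  # header name assumed
)
resp.raise_for_status()

for cred in resp.json():
    # CredentialSummary fields only: no tokens, keys, or header secrets.
    print(cred["id"], cred["provider"], cred["type"], cred.get("title"))
```
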
@integrations_router.get(
    "/{provider}/credentials", response_model=list[CredentialsMetaResponse]
    "/{provider}/credentials", response_model=list[CredentialSummary]
)
async def list_credentials_by_provider(
    provider: Annotated[str, Path(title="The provider to list credentials for")],
    auth: APIAuthorizationInfo = Security(
        require_permission(APIKeyPermission.READ_INTEGRATIONS)
    ),
) -> list[CredentialsMetaResponse]:
) -> list[CredentialSummary]:
    """
    List credentials for a specific provider.
    """
@@ -499,7 +515,16 @@ async def list_credentials_by_provider(
        auth.user_id, provider
    )
    return [
        to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
        CredentialSummary(
            id=cred.id,
            provider=cred.provider,
            type=cred.type,
            title=cred.title,
            scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
            username=cred.username if isinstance(cred, OAuth2Credentials) else None,
            host=cred.host if isinstance(cred, HostScopedCredentials) else None,
        )
        for cred in credentials
    ]

@@ -572,11 +597,11 @@ async def create_credential(
    # Store credentials
    try:
        await creds_manager.create(auth.user_id, credentials)
    except Exception:
        logger.exception("Failed to store credentials")
    except Exception as e:
        logger.error(f"Failed to store credentials: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to store credentials",
            detail=f"Failed to store credentials: {str(e)}",
        )

    logger.info(f"Created {request.type} credentials for provider {provider}")
@@ -614,18 +639,15 @@ async def delete_credential(
    use the main API's delete endpoint which handles webhook cleanup and
    token revocation.
    """
    if is_sdk_default(cred_id):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
        )
    creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
    if not creds:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
        )
    if not provider_matches(creds.provider, provider):
    if creds.provider != provider:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Credentials do not match the specified provider",
        )

    await creds_manager.delete(auth.user_id, cred_id)

@@ -7,8 +7,6 @@ import fastapi
import fastapi.responses
import prisma.enums

import backend.api.features.library.db as library_db
import backend.api.features.library.model as library_model
import backend.api.features.store.cache as store_cache
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
@@ -134,40 +132,3 @@ async def admin_download_agent_file(
    return fastapi.responses.FileResponse(
        tmp_file.name, filename=file_name, media_type="application/json"
    )


@router.get(
    "/submissions/{store_listing_version_id}/preview",
    summary="Admin Preview Submission Listing",
)
async def admin_preview_submission(
    store_listing_version_id: str,
) -> store_model.StoreAgentDetails:
    """
    Preview a marketplace submission as it would appear on the listing page.
    Bypasses the APPROVED-only StoreAgent view so admins can preview pending
    submissions before approving.
    """
    return await store_db.get_store_agent_details_as_admin(store_listing_version_id)


@router.post(
    "/submissions/{store_listing_version_id}/add-to-library",
    summary="Admin Add Pending Agent to Library",
    status_code=201,
)
async def admin_add_agent_to_library(
    store_listing_version_id: str,
    user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
) -> library_model.LibraryAgent:
    """
    Add a pending marketplace agent to the admin's library for review.
    Uses admin-level access to bypass marketplace APPROVED-only checks.

    The builder can load the graph because get_graph() checks library
    membership as a fallback: "you added it, you keep it."
    """
    return await library_db.add_store_agent_to_library_as_admin(
        store_listing_version_id=store_listing_version_id,
        user_id=user_id,
    )

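For reviewers exercising these two admin routes by hand, a sketch of the call sequence; the mount prefix and auth header are assumptions, since the router's include path sits outside this diff:

```python
import httpx

# Assumed mount point for store_admin_router; check the app wiring first.
BASE = "http://localhost:8006/api/store/admin"
HEADERS = {"Authorization": "Bearer <admin-jwt>"}
slv_id = "<store-listing-version-id>"

# Preview the pending submission as it would render on the listing page.
preview = httpx.get(f"{BASE}/submissions/{slv_id}/preview", headers=HEADERS)
preview.raise_for_status()

# Pull the pending agent into the admin's library for hands-on review.
added = httpx.post(f"{BASE}/submissions/{slv_id}/add-to-library", headers=HEADERS)
assert added.status_code == 201
```
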
@@ -1,335 +0,0 @@
"""Tests for admin store routes and the bypass logic they depend on.

Tests are organized by what they protect:
- SECRT-2162: get_graph_as_admin bypasses ownership/marketplace checks
- SECRT-2167 security: admin endpoints reject non-admin users
- SECRT-2167 bypass: preview queries StoreListingVersion (not StoreAgent view),
  and add-to-library uses get_graph_as_admin (not get_graph)
"""

from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import fastapi
import fastapi.responses
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload

from backend.data.graph import get_graph_as_admin
from backend.util.exceptions import NotFoundError

from .store_admin_routes import router as store_admin_router

# Shared constants
ADMIN_USER_ID = "admin-user-id"
CREATOR_USER_ID = "other-creator-id"
GRAPH_ID = "test-graph-id"
GRAPH_VERSION = 3
SLV_ID = "test-store-listing-version-id"


def _make_mock_graph(user_id: str = CREATOR_USER_ID) -> MagicMock:
    graph = MagicMock()
    graph.userId = user_id
    graph.id = GRAPH_ID
    graph.version = GRAPH_VERSION
    graph.Nodes = []
    return graph


# ---- SECRT-2162: get_graph_as_admin bypasses ownership checks ---- #


@pytest.mark.asyncio
async def test_admin_can_access_pending_agent_not_owned() -> None:
    """get_graph_as_admin must return a graph even when the admin doesn't own
    it and it's not APPROVED in the marketplace."""
    mock_graph = _make_mock_graph()
    mock_graph_model = MagicMock(name="GraphModel")

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
        patch(
            "backend.data.graph.GraphModel.from_db",
            return_value=mock_graph_model,
        ),
    ):
        mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)

        result = await get_graph_as_admin(
            graph_id=GRAPH_ID,
            version=GRAPH_VERSION,
            user_id=ADMIN_USER_ID,
            for_export=False,
        )

    assert result is mock_graph_model


@pytest.mark.asyncio
async def test_admin_download_pending_agent_with_subagents() -> None:
    """get_graph_as_admin with for_export=True must call get_sub_graphs
    and pass sub_graphs to GraphModel.from_db."""
    mock_graph = _make_mock_graph()
    mock_sub_graph = MagicMock(name="SubGraph")
    mock_graph_model = MagicMock(name="GraphModel")

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
        patch(
            "backend.data.graph.get_sub_graphs",
            new_callable=AsyncMock,
            return_value=[mock_sub_graph],
        ) as mock_get_sub,
        patch(
            "backend.data.graph.GraphModel.from_db",
            return_value=mock_graph_model,
        ) as mock_from_db,
    ):
        mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)

        result = await get_graph_as_admin(
            graph_id=GRAPH_ID,
            version=GRAPH_VERSION,
            user_id=ADMIN_USER_ID,
            for_export=True,
        )

    assert result is mock_graph_model
    mock_get_sub.assert_awaited_once_with(mock_graph)
    mock_from_db.assert_called_once_with(
        graph=mock_graph,
        sub_graphs=[mock_sub_graph],
        for_export=True,
    )


# ---- SECRT-2167 security: admin endpoints reject non-admin users ---- #

app = fastapi.FastAPI()
app.include_router(store_admin_router)


@app.exception_handler(NotFoundError)
async def _not_found_handler(
    request: fastapi.Request, exc: NotFoundError
) -> fastapi.responses.JSONResponse:
    return fastapi.responses.JSONResponse(status_code=404, content={"detail": str(exc)})


client = fastapi.testclient.TestClient(app)


@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
    """Setup admin auth overrides for all route tests in this module."""
    app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
    yield
    app.dependency_overrides.clear()


def test_preview_requires_admin(mock_jwt_user) -> None:
    """Non-admin users must get 403 on the preview endpoint."""
    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
    response = client.get(f"/admin/submissions/{SLV_ID}/preview")
    assert response.status_code == 403


def test_add_to_library_requires_admin(mock_jwt_user) -> None:
    """Non-admin users must get 403 on the add-to-library endpoint."""
    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
    response = client.post(f"/admin/submissions/{SLV_ID}/add-to-library")
    assert response.status_code == 403


def test_preview_nonexistent_submission(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Preview of a nonexistent submission returns 404."""
    mocker.patch(
        "backend.api.features.admin.store_admin_routes.store_db"
        ".get_store_agent_details_as_admin",
        side_effect=NotFoundError("not found"),
    )
    response = client.get(f"/admin/submissions/{SLV_ID}/preview")
    assert response.status_code == 404


# ---- SECRT-2167 bypass: verify the right data sources are used ---- #


@pytest.mark.asyncio
async def test_preview_queries_store_listing_version_not_store_agent() -> None:
    """get_store_agent_details_as_admin must query StoreListingVersion
    directly (not the APPROVED-only StoreAgent view). This is THE test that
    prevents the bypass from being accidentally reverted."""
    from backend.api.features.store.db import get_store_agent_details_as_admin

    mock_slv = MagicMock()
    mock_slv.id = SLV_ID
    mock_slv.name = "Test Agent"
    mock_slv.subHeading = "Short desc"
    mock_slv.description = "Long desc"
    mock_slv.videoUrl = None
    mock_slv.agentOutputDemoUrl = None
    mock_slv.imageUrls = ["https://example.com/img.png"]
    mock_slv.instructions = None
    mock_slv.categories = ["productivity"]
    mock_slv.version = 1
    mock_slv.agentGraphId = GRAPH_ID
    mock_slv.agentGraphVersion = GRAPH_VERSION
    mock_slv.updatedAt = datetime(2026, 3, 24, tzinfo=timezone.utc)
    mock_slv.recommendedScheduleCron = "0 9 * * *"

    mock_listing = MagicMock()
    mock_listing.id = "listing-id"
    mock_listing.slug = "test-agent"
    mock_listing.activeVersionId = SLV_ID
    mock_listing.hasApprovedVersion = False
    mock_listing.CreatorProfile = MagicMock(username="creator", avatarUrl="")
    mock_slv.StoreListing = mock_listing

    with (
        patch(
            "backend.api.features.store.db.prisma.models" ".StoreListingVersion.prisma",
        ) as mock_slv_prisma,
        patch(
            "backend.api.features.store.db.prisma.models.StoreAgent.prisma",
        ) as mock_store_agent_prisma,
    ):
        mock_slv_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)

        result = await get_store_agent_details_as_admin(SLV_ID)

        # Verify it queried StoreListingVersion (not the APPROVED-only StoreAgent)
        mock_slv_prisma.return_value.find_unique.assert_awaited_once()
        await_args = mock_slv_prisma.return_value.find_unique.await_args
        assert await_args is not None
        assert await_args.kwargs["where"] == {"id": SLV_ID}

        # Verify the APPROVED-only StoreAgent view was NOT touched
        mock_store_agent_prisma.assert_not_called()

    # Verify the result has the right data
    assert result.agent_name == "Test Agent"
    assert result.agent_image == ["https://example.com/img.png"]
    assert result.has_approved_version is False
    assert result.runs == 0
    assert result.rating == 0.0


@pytest.mark.asyncio
async def test_resolve_graph_admin_uses_get_graph_as_admin() -> None:
    """resolve_graph_for_library(admin=True) must call get_graph_as_admin,
    not get_graph. This is THE test that prevents the add-to-library bypass
    from being accidentally reverted."""
    from backend.api.features.library._add_to_library import resolve_graph_for_library

    mock_slv = MagicMock()
    mock_slv.AgentGraph = MagicMock(id=GRAPH_ID, version=GRAPH_VERSION)
    mock_graph_model = MagicMock(name="GraphModel")

    with (
        patch(
            "backend.api.features.library._add_to_library.prisma.models"
            ".StoreListingVersion.prisma",
        ) as mock_prisma,
        patch(
            "backend.api.features.library._add_to_library.graph_db"
            ".get_graph_as_admin",
            new_callable=AsyncMock,
            return_value=mock_graph_model,
        ) as mock_admin,
        patch(
            "backend.api.features.library._add_to_library.graph_db.get_graph",
            new_callable=AsyncMock,
        ) as mock_regular,
    ):
        mock_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)

        result = await resolve_graph_for_library(SLV_ID, ADMIN_USER_ID, admin=True)

    assert result is mock_graph_model
    mock_admin.assert_awaited_once_with(
        graph_id=GRAPH_ID, version=GRAPH_VERSION, user_id=ADMIN_USER_ID
    )
    mock_regular.assert_not_awaited()


@pytest.mark.asyncio
async def test_resolve_graph_regular_uses_get_graph() -> None:
    """resolve_graph_for_library(admin=False) must call get_graph,
    not get_graph_as_admin. Ensures the non-admin path is preserved."""
    from backend.api.features.library._add_to_library import resolve_graph_for_library

    mock_slv = MagicMock()
    mock_slv.AgentGraph = MagicMock(id=GRAPH_ID, version=GRAPH_VERSION)
    mock_graph_model = MagicMock(name="GraphModel")

    with (
        patch(
            "backend.api.features.library._add_to_library.prisma.models"
            ".StoreListingVersion.prisma",
        ) as mock_prisma,
        patch(
            "backend.api.features.library._add_to_library.graph_db"
            ".get_graph_as_admin",
            new_callable=AsyncMock,
        ) as mock_admin,
        patch(
            "backend.api.features.library._add_to_library.graph_db.get_graph",
            new_callable=AsyncMock,
            return_value=mock_graph_model,
        ) as mock_regular,
    ):
        mock_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)

        result = await resolve_graph_for_library(SLV_ID, "regular-user-id", admin=False)

    assert result is mock_graph_model
    mock_regular.assert_awaited_once_with(
        graph_id=GRAPH_ID, version=GRAPH_VERSION, user_id="regular-user-id"
    )
    mock_admin.assert_not_awaited()


# ---- Library membership grants graph access (product decision) ---- #


@pytest.mark.asyncio
async def test_library_member_can_view_pending_agent_in_builder() -> None:
    """After adding a pending agent to their library, the user should be
    able to load the graph in the builder via get_graph()."""
    mock_graph = _make_mock_graph()
    mock_graph_model = MagicMock(name="GraphModel")
    mock_library_agent = MagicMock()
    mock_library_agent.AgentGraph = mock_graph

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
        patch(
            "backend.data.graph.StoreListingVersion.prisma",
        ) as mock_slv_prisma,
        patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
        patch(
            "backend.data.graph.GraphModel.from_db",
            return_value=mock_graph_model,
        ),
    ):
        mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
        mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
        mock_lib_prisma.return_value.find_first = AsyncMock(
            return_value=mock_library_agent
        )

        from backend.data.graph import get_graph

        result = await get_graph(
            graph_id=GRAPH_ID,
            version=GRAPH_VERSION,
            user_id=ADMIN_USER_ID,
        )

    assert result is mock_graph_model, "Library membership should grant graph access"
@@ -1,13 +0,0 @@
"""Override session-scoped fixtures so unit tests run without the server."""

import pytest


@pytest.fixture(scope="session")
def server():
    yield None


@pytest.fixture(scope="session", autouse=True)
def graph_cleanup():
    yield
@@ -34,7 +34,6 @@ from backend.data.model import (
    HostScopedCredentials,
    OAuth2Credentials,
    UserIntegrations,
    is_sdk_default,
)
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
@@ -139,18 +138,6 @@ class CredentialsMetaResponse(BaseModel):
        return None


def to_meta_response(cred: Credentials) -> CredentialsMetaResponse:
    return CredentialsMetaResponse(
        id=cred.id,
        provider=cred.provider,
        type=cred.type,
        title=cred.title,
        scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
        username=cred.username if isinstance(cred, OAuth2Credentials) else None,
        host=CredentialsMetaResponse.get_host(cred),
    )


@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
async def callback(
    provider: Annotated[
@@ -217,7 +204,15 @@ async def callback(
        f"and provider {provider.value}"
    )

    return to_meta_response(credentials)
    return CredentialsMetaResponse(
        id=credentials.id,
        provider=credentials.provider,
        type=credentials.type,
        title=credentials.title,
        scopes=credentials.scopes,
        username=credentials.username,
        host=(CredentialsMetaResponse.get_host(credentials)),
    )


@router.get("/credentials", summary="List Credentials")
@@ -227,7 +222,16 @@ async def list_credentials(
    credentials = await creds_manager.store.get_all_creds(user_id)

    return [
        to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
        CredentialsMetaResponse(
            id=cred.id,
            provider=cred.provider,
            type=cred.type,
            title=cred.title,
            scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
            username=cred.username if isinstance(cred, OAuth2Credentials) else None,
            host=CredentialsMetaResponse.get_host(cred),
        )
        for cred in credentials
    ]


@@ -241,7 +245,16 @@ async def list_credentials_by_provider(
    credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)

    return [
        to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
        CredentialsMetaResponse(
            id=cred.id,
            provider=cred.provider,
            type=cred.type,
            title=cred.title,
            scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
            username=cred.username if isinstance(cred, OAuth2Credentials) else None,
            host=CredentialsMetaResponse.get_host(cred),
        )
        for cred in credentials
    ]


@@ -254,21 +267,18 @@ async def get_credential(
    ],
    cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
    user_id: Annotated[str, Security(get_user_id)],
) -> CredentialsMetaResponse:
    if is_sdk_default(cred_id):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
        )
) -> Credentials:
    credential = await creds_manager.get(user_id, cred_id)
    if not credential:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
        )
    if not provider_matches(credential.provider, provider):
    if credential.provider != provider:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Credentials do not match the specified provider",
        )
    return to_meta_response(credential)
    return credential


@router.post("/{provider}/credentials", status_code=201, summary="Create Credentials")
@@ -278,22 +288,16 @@ async def create_credentials(
        ProviderName, Path(title="The provider to create credentials for")
    ],
    credentials: Credentials,
) -> CredentialsMetaResponse:
    if is_sdk_default(credentials.id):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Cannot create credentials with a reserved ID",
        )
) -> Credentials:
    credentials.provider = provider
    try:
        await creds_manager.create(user_id, credentials)
    except Exception:
        logger.exception("Failed to store credentials")
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to store credentials",
            detail=f"Failed to store credentials: {str(e)}",
        )
    return to_meta_response(credentials)
    return credentials


class CredentialsDeletionResponse(BaseModel):
@@ -328,19 +332,15 @@ async def delete_credentials(
        bool, Query(title="Whether to proceed if any linked webhooks are still in use")
    ] = False,
) -> CredentialsDeletionResponse | CredentialsDeletionNeedsConfirmationResponse:
    if is_sdk_default(cred_id):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
        )
    creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
    if not creds:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
        )
    if not provider_matches(creds.provider, provider):
    if creds.provider != provider:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Credentials not found",
            detail="Credentials do not match the specified provider",
        )

    try:

@@ -1,278 +0,0 @@
"""Tests for credentials API security: no secret leakage, SDK defaults filtered."""

from unittest.mock import AsyncMock, patch

import fastapi
import fastapi.testclient
import pytest
from pydantic import SecretStr

from backend.api.features.integrations.router import router
from backend.data.model import (
    APIKeyCredentials,
    HostScopedCredentials,
    OAuth2Credentials,
    UserPasswordCredentials,
)

app = fastapi.FastAPI()
app.include_router(router)
client = fastapi.testclient.TestClient(app)

TEST_USER_ID = "test-user-id"


def _make_api_key_cred(cred_id: str = "cred-123", provider: str = "openai"):
    return APIKeyCredentials(
        id=cred_id,
        provider=provider,
        title="My API Key",
        api_key=SecretStr("sk-secret-key-value"),
    )


def _make_oauth2_cred(cred_id: str = "cred-456", provider: str = "github"):
    return OAuth2Credentials(
        id=cred_id,
        provider=provider,
        title="My OAuth",
        access_token=SecretStr("ghp_secret_token"),
        refresh_token=SecretStr("ghp_refresh_secret"),
        scopes=["repo", "user"],
        username="testuser",
    )


def _make_user_password_cred(cred_id: str = "cred-789", provider: str = "openai"):
    return UserPasswordCredentials(
        id=cred_id,
        provider=provider,
        title="My Login",
        username=SecretStr("admin"),
        password=SecretStr("s3cret-pass"),
    )


def _make_host_scoped_cred(cred_id: str = "cred-host", provider: str = "openai"):
    return HostScopedCredentials(
        id=cred_id,
        provider=provider,
        title="Host Cred",
        host="https://api.example.com",
        headers={"Authorization": SecretStr("Bearer top-secret")},
    )


def _make_sdk_default_cred(provider: str = "openai"):
    return APIKeyCredentials(
        id=f"{provider}-default",
        provider=provider,
        title=f"{provider} (default)",
        api_key=SecretStr("sk-platform-secret-key"),
    )


@pytest.fixture(autouse=True)
def setup_auth(mock_jwt_user):
    from autogpt_libs.auth.jwt_utils import get_jwt_payload

    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
    yield
    app.dependency_overrides.clear()


class TestGetCredentialReturnsMetaOnly:
    """GET /{provider}/credentials/{cred_id} must not return secrets."""

    def test_api_key_credential_no_secret(self):
        cred = _make_api_key_cred()
        with (
            patch.object(router, "dependencies", []),
            patch("backend.api.features.integrations.router.creds_manager") as mock_mgr,
        ):
            mock_mgr.get = AsyncMock(return_value=cred)
            resp = client.get("/openai/credentials/cred-123")

        assert resp.status_code == 200
        data = resp.json()
        assert data["id"] == "cred-123"
        assert data["provider"] == "openai"
        assert data["type"] == "api_key"
        assert "api_key" not in data
        assert "sk-secret-key-value" not in str(data)

    def test_oauth2_credential_no_secret(self):
        cred = _make_oauth2_cred()
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.get = AsyncMock(return_value=cred)
            resp = client.get("/github/credentials/cred-456")

        assert resp.status_code == 200
        data = resp.json()
        assert data["id"] == "cred-456"
        assert data["scopes"] == ["repo", "user"]
        assert data["username"] == "testuser"
        assert "access_token" not in data
        assert "refresh_token" not in data
        assert "ghp_" not in str(data)

    def test_user_password_credential_no_secret(self):
        cred = _make_user_password_cred()
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.get = AsyncMock(return_value=cred)
            resp = client.get("/openai/credentials/cred-789")

        assert resp.status_code == 200
        data = resp.json()
        assert data["id"] == "cred-789"
        assert "password" not in data
        assert "username" not in data or data["username"] is None
        assert "s3cret-pass" not in str(data)
        assert "admin" not in str(data)

    def test_host_scoped_credential_no_secret(self):
        cred = _make_host_scoped_cred()
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.get = AsyncMock(return_value=cred)
            resp = client.get("/openai/credentials/cred-host")

        assert resp.status_code == 200
        data = resp.json()
        assert data["id"] == "cred-host"
        assert data["host"] == "https://api.example.com"
        assert "headers" not in data
        assert "top-secret" not in str(data)

    def test_get_credential_wrong_provider_returns_404(self):
        """Provider mismatch should return generic 404, not leak credential existence."""
        cred = _make_api_key_cred(provider="openai")
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.get = AsyncMock(return_value=cred)
            resp = client.get("/github/credentials/cred-123")

        assert resp.status_code == 404
        assert resp.json()["detail"] == "Credentials not found"

    def test_list_credentials_no_secrets(self):
        """List endpoint must not leak secrets in any credential."""
        creds = [_make_api_key_cred(), _make_oauth2_cred()]
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.store.get_all_creds = AsyncMock(return_value=creds)
            resp = client.get("/credentials")

        assert resp.status_code == 200
        raw = str(resp.json())
        assert "sk-secret-key-value" not in raw
        assert "ghp_secret_token" not in raw
        assert "ghp_refresh_secret" not in raw


class TestSdkDefaultCredentialsNotAccessible:
    """SDK default credentials (ID ending in '-default') must be hidden."""

    def test_get_sdk_default_returns_404(self):
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.get = AsyncMock()
            resp = client.get("/openai/credentials/openai-default")

        assert resp.status_code == 404
        mock_mgr.get.assert_not_called()

    def test_list_credentials_excludes_sdk_defaults(self):
        user_cred = _make_api_key_cred()
        sdk_cred = _make_sdk_default_cred("openai")
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.store.get_all_creds = AsyncMock(return_value=[user_cred, sdk_cred])
            resp = client.get("/credentials")

        assert resp.status_code == 200
        data = resp.json()
        ids = [c["id"] for c in data]
        assert "cred-123" in ids
        assert "openai-default" not in ids

    def test_list_by_provider_excludes_sdk_defaults(self):
        user_cred = _make_api_key_cred()
        sdk_cred = _make_sdk_default_cred("openai")
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.store.get_creds_by_provider = AsyncMock(
                return_value=[user_cred, sdk_cred]
            )
            resp = client.get("/openai/credentials")

        assert resp.status_code == 200
        data = resp.json()
        ids = [c["id"] for c in data]
        assert "cred-123" in ids
        assert "openai-default" not in ids

    def test_delete_sdk_default_returns_404(self):
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.store.get_creds_by_id = AsyncMock()
            resp = client.request("DELETE", "/openai/credentials/openai-default")

        assert resp.status_code == 404
        mock_mgr.store.get_creds_by_id.assert_not_called()


class TestCreateCredentialNoSecretInResponse:
    """POST /{provider}/credentials must not return secrets."""

    def test_create_api_key_no_secret_in_response(self):
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.create = AsyncMock()
            resp = client.post(
                "/openai/credentials",
                json={
                    "id": "new-cred",
                    "provider": "openai",
                    "type": "api_key",
                    "title": "New Key",
                    "api_key": "sk-newsecret",
                },
            )

        assert resp.status_code == 201
        data = resp.json()
        assert data["id"] == "new-cred"
        assert "api_key" not in data
        assert "sk-newsecret" not in str(data)

    def test_create_with_sdk_default_id_rejected(self):
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.create = AsyncMock()
            resp = client.post(
                "/openai/credentials",
                json={
                    "id": "openai-default",
                    "provider": "openai",
                    "type": "api_key",
                    "title": "Sneaky",
                    "api_key": "sk-evil",
                },
            )

        assert resp.status_code == 403
        mock_mgr.create.assert_not_called()
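The deleted tests above pin down the `-default` ID convention for platform-provided credentials. For orientation, a sketch of the rule `is_sdk_default` is assumed to encode; the real helper lives in `backend.data.model` and may check more than the suffix:

```python
def is_sdk_default(cred_id: str) -> bool:
    # Assumption: SDK default credentials use reserved IDs like
    # "openai-default"; such IDs must never be readable or deletable
    # through the credentials API.
    return cred_id.endswith("-default")
```
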
@@ -1,124 +0,0 @@
"""Shared logic for adding store agents to a user's library.

Both `add_store_agent_to_library` and `add_store_agent_to_library_as_admin`
delegate to these helpers so the duplication-prone create/restore/dedup
logic lives in exactly one place.
"""

import logging

import prisma.errors
import prisma.models

import backend.api.features.library.model as library_model
import backend.data.graph as graph_db
from backend.data.graph import GraphModel, GraphSettings
from backend.data.includes import library_agent_include
from backend.util.exceptions import NotFoundError
from backend.util.json import SafeJson

from .db import get_library_agent_by_graph_id, update_library_agent

logger = logging.getLogger(__name__)


async def resolve_graph_for_library(
    store_listing_version_id: str,
    user_id: str,
    *,
    admin: bool,
) -> GraphModel:
    """Look up a StoreListingVersion and resolve its graph.

    When ``admin=True``, uses ``get_graph_as_admin`` to bypass the marketplace
    APPROVED-only check. Otherwise uses the regular ``get_graph``.
    """
    slv = await prisma.models.StoreListingVersion.prisma().find_unique(
        where={"id": store_listing_version_id}, include={"AgentGraph": True}
    )
    if not slv or not slv.AgentGraph:
        raise NotFoundError(
            f"Store listing version {store_listing_version_id} not found or invalid"
        )

    ag = slv.AgentGraph
    if admin:
        graph_model = await graph_db.get_graph_as_admin(
            graph_id=ag.id, version=ag.version, user_id=user_id
        )
    else:
        graph_model = await graph_db.get_graph(
            graph_id=ag.id, version=ag.version, user_id=user_id
        )

    if not graph_model:
        raise NotFoundError(f"Graph #{ag.id} v{ag.version} not found or accessible")
    return graph_model


async def add_graph_to_library(
    store_listing_version_id: str,
    graph_model: GraphModel,
    user_id: str,
) -> library_model.LibraryAgent:
    """Check existing / restore soft-deleted / create new LibraryAgent."""
    if existing := await get_library_agent_by_graph_id(
        user_id, graph_model.id, graph_model.version
    ):
        return existing

    deleted_agent = await prisma.models.LibraryAgent.prisma().find_unique(
        where={
            "userId_agentGraphId_agentGraphVersion": {
                "userId": user_id,
                "agentGraphId": graph_model.id,
                "agentGraphVersion": graph_model.version,
            }
        },
    )
    if deleted_agent and (deleted_agent.isDeleted or deleted_agent.isArchived):
        return await update_library_agent(
            deleted_agent.id,
            user_id,
            is_deleted=False,
            is_archived=False,
        )

    try:
        added_agent = await prisma.models.LibraryAgent.prisma().create(
            data={
                "User": {"connect": {"id": user_id}},
                "AgentGraph": {
                    "connect": {
                        "graphVersionId": {
                            "id": graph_model.id,
                            "version": graph_model.version,
                        }
                    }
                },
                "isCreatedByUser": False,
                "useGraphIsActiveVersion": False,
                "settings": SafeJson(
                    GraphSettings.from_graph(graph_model).model_dump()
                ),
            },
            include=library_agent_include(
                user_id, include_nodes=False, include_executions=False
            ),
        )
    except prisma.errors.UniqueViolationError:
        # Race condition: concurrent request created the row between our
        # check and create. Re-read instead of crashing.
        existing = await get_library_agent_by_graph_id(
            user_id, graph_model.id, graph_model.version
        )
        if existing:
            return existing
        raise  # Shouldn't happen, but don't swallow unexpected errors

    logger.debug(
        f"Added graph #{graph_model.id} v{graph_model.version} "
        f"for store listing version #{store_listing_version_id} "
        f"to library for user #{user_id}"
    )
    return library_model.LibraryAgent.from_db(added_agent)
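The two helpers above separate 'which graph may this user see' from 'record it in the library'. A sketch of the intended composition, mirroring the delegation that `db.py` performs; the wrapper name is illustrative only:

```python
# Illustrative wrapper; the real call sites are
# add_store_agent_to_library(_as_admin) in db.py.
async def add_for_admin_review(slv_id: str, admin_user_id: str):
    graph = await resolve_graph_for_library(slv_id, admin_user_id, admin=True)
    return await add_graph_to_library(slv_id, graph, admin_user_id)
```
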
@@ -1,71 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from ._add_to_library import add_graph_to_library


@pytest.mark.asyncio
async def test_add_graph_to_library_restores_archived_agent() -> None:
    graph_model = MagicMock(id="graph-id", version=2)
    archived_agent = MagicMock(id="library-agent-id", isDeleted=False, isArchived=True)
    restored_agent = MagicMock(name="LibraryAgentModel")

    with (
        patch(
            "backend.api.features.library._add_to_library.get_library_agent_by_graph_id",
            new=AsyncMock(return_value=None),
        ),
        patch(
            "backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
        ) as mock_prisma,
        patch(
            "backend.api.features.library._add_to_library.update_library_agent",
            new=AsyncMock(return_value=restored_agent),
        ) as mock_update,
    ):
        mock_prisma.return_value.find_unique = AsyncMock(return_value=archived_agent)

        result = await add_graph_to_library("slv-id", graph_model, "user-id")

    assert result is restored_agent
    mock_update.assert_awaited_once_with(
        "library-agent-id",
        "user-id",
        is_deleted=False,
        is_archived=False,
    )
    mock_prisma.return_value.create.assert_not_called()


@pytest.mark.asyncio
async def test_add_graph_to_library_restores_deleted_agent() -> None:
    graph_model = MagicMock(id="graph-id", version=2)
    deleted_agent = MagicMock(id="library-agent-id", isDeleted=True, isArchived=False)
    restored_agent = MagicMock(name="LibraryAgentModel")

    with (
        patch(
            "backend.api.features.library._add_to_library.get_library_agent_by_graph_id",
            new=AsyncMock(return_value=None),
        ),
        patch(
            "backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
        ) as mock_prisma,
        patch(
            "backend.api.features.library._add_to_library.update_library_agent",
            new=AsyncMock(return_value=restored_agent),
        ) as mock_update,
    ):
        mock_prisma.return_value.find_unique = AsyncMock(return_value=deleted_agent)

        result = await add_graph_to_library("slv-id", graph_model, "user-id")

    assert result is restored_agent
    mock_update.assert_awaited_once_with(
        "library-agent-id",
        "user-id",
        is_deleted=False,
        is_archived=False,
    )
    mock_prisma.return_value.create.assert_not_called()
@@ -336,15 +336,12 @@ async def get_library_agent_by_graph_id(
    user_id: str,
    graph_id: str,
    graph_version: Optional[int] = None,
    include_archived: bool = False,
) -> library_model.LibraryAgent | None:
    filter: prisma.types.LibraryAgentWhereInput = {
        "agentGraphId": graph_id,
        "userId": user_id,
        "isDeleted": False,
    }
    if not include_archived:
        filter["isArchived"] = False
    if graph_version is not None:
        filter["agentGraphVersion"] = graph_version

@@ -585,9 +582,7 @@ async def update_graph_in_library(

    created_graph = await graph_db.create_graph(graph_model, user_id)

    library_agent = await get_library_agent_by_graph_id(
        user_id, created_graph.id, include_archived=True
    )
    library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
    if not library_agent:
        raise NotFoundError(f"Library agent not found for graph {created_graph.id}")

@@ -823,38 +818,92 @@ async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
async def add_store_agent_to_library(
    store_listing_version_id: str, user_id: str
) -> library_model.LibraryAgent:
    """Adds a marketplace agent to the user’s library.

    See also: `add_store_agent_to_library_as_admin()` which uses
    `get_graph_as_admin` to bypass marketplace status checks for admin review.
    """
    from ._add_to_library import add_graph_to_library, resolve_graph_for_library
    Adds an agent from a store listing version to the user's library if they don't already have it.

    Args:
        store_listing_version_id: The ID of the store listing version containing the agent.
        user_id: The user’s library to which the agent is being added.

    Returns:
        The newly created LibraryAgent if successfully added, the existing corresponding one if any.

    Raises:
        NotFoundError: If the store listing or associated agent is not found.
        DatabaseError: If there's an issue creating the LibraryAgent record.
    """
    logger.debug(
        f"Adding agent from store listing version #{store_listing_version_id} "
        f"to library for user #{user_id}"
    )
    graph_model = await resolve_graph_for_library(
        store_listing_version_id, user_id, admin=False
    )
    return await add_graph_to_library(store_listing_version_id, graph_model, user_id)


async def add_store_agent_to_library_as_admin(
    store_listing_version_id: str, user_id: str
) -> library_model.LibraryAgent:
    """Admin variant that uses `get_graph_as_admin` to bypass marketplace
    APPROVED-only checks, allowing admins to add pending agents for review."""
    from ._add_to_library import add_graph_to_library, resolve_graph_for_library

    logger.warning(
        f"ADMIN adding agent from store listing version "
        f"#{store_listing_version_id} to library for user #{user_id}"
    store_listing_version = (
        await prisma.models.StoreListingVersion.prisma().find_unique(
            where={"id": store_listing_version_id}, include={"AgentGraph": True}
        )
    )
    graph_model = await resolve_graph_for_library(
        store_listing_version_id, user_id, admin=True
    if not store_listing_version or not store_listing_version.AgentGraph:
        logger.warning(f"Store listing version not found: {store_listing_version_id}")
        raise NotFoundError(
            f"Store listing version {store_listing_version_id} not found or invalid"
        )

    graph = store_listing_version.AgentGraph

    # Convert to GraphModel to check for HITL blocks
    graph_model = await graph_db.get_graph(
        graph_id=graph.id,
        version=graph.version,
        user_id=user_id,
        include_subgraphs=False,
    )
    return await add_graph_to_library(store_listing_version_id, graph_model, user_id)
    if not graph_model:
        raise NotFoundError(
            f"Graph #{graph.id} v{graph.version} not found or accessible"
        )

    # Check if user already has this agent (non-deleted)
    if existing := await get_library_agent_by_graph_id(
        user_id, graph.id, graph.version
    ):
        return existing

    # Check for soft-deleted version and restore it
    deleted_agent = await prisma.models.LibraryAgent.prisma().find_unique(
        where={
            "userId_agentGraphId_agentGraphVersion": {
                "userId": user_id,
                "agentGraphId": graph.id,
                "agentGraphVersion": graph.version,
            }
        },
    )
    if deleted_agent and deleted_agent.isDeleted:
        return await update_library_agent(deleted_agent.id, user_id, is_deleted=False)

    # Create LibraryAgent entry
    added_agent = await prisma.models.LibraryAgent.prisma().create(
        data={
            "User": {"connect": {"id": user_id}},
            "AgentGraph": {
                "connect": {
                    "graphVersionId": {"id": graph.id, "version": graph.version}
                }
            },
            "isCreatedByUser": False,
            "useGraphIsActiveVersion": False,
            "settings": SafeJson(GraphSettings.from_graph(graph_model).model_dump()),
        },
        include=library_agent_include(
            user_id, include_nodes=False, include_executions=False
        ),
    )
    logger.debug(
        f"Added graph #{graph.id} v{graph.version} "
        f"for store listing version #{store_listing_version.id} "
        f"to library for user #{user_id}"
    )
    return library_model.LibraryAgent.from_db(added_agent)


##############################################

@@ -150,13 +150,8 @@ async def test_add_agent_to_library(mocker):
    )

    # Mock graph_db.get_graph function that's called to check for HITL blocks
    # (lives in _add_to_library.py after refactor, not db.py)
    mock_graph_db = mocker.patch(
        "backend.api.features.library._add_to_library.graph_db"
    )
    mock_graph_db = mocker.patch("backend.api.features.library.db.graph_db")
    mock_graph_model = mocker.Mock()
    mock_graph_model.id = "agent1"
    mock_graph_model.version = 1
    mock_graph_model.nodes = (
        []
    )  # Empty list so _has_human_in_the_loop_blocks returns False
@@ -229,94 +224,3 @@ async def test_add_agent_to_library_not_found(mocker):
    mock_store_listing_version.return_value.find_unique.assert_called_once_with(
        where={"id": "version123"}, include={"AgentGraph": True}
    )


@pytest.mark.asyncio
async def test_get_library_agent_by_graph_id_excludes_archived(mocker):
    mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
    mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)

    result = await db.get_library_agent_by_graph_id("test-user", "agent1", 7)

    assert result is None
    mock_library_agent.return_value.find_first.assert_called_once()
    where = mock_library_agent.return_value.find_first.call_args.kwargs["where"]
    assert where == {
        "agentGraphId": "agent1",
        "userId": "test-user",
        "isDeleted": False,
        "isArchived": False,
        "agentGraphVersion": 7,
    }


@pytest.mark.asyncio
async def test_get_library_agent_by_graph_id_can_include_archived(mocker):
    mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
    mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)

    result = await db.get_library_agent_by_graph_id(
        "test-user",
        "agent1",
        7,
        include_archived=True,
    )

    assert result is None
    mock_library_agent.return_value.find_first.assert_called_once()
    where = mock_library_agent.return_value.find_first.call_args.kwargs["where"]
    assert where == {
        "agentGraphId": "agent1",
        "userId": "test-user",
        "isDeleted": False,
        "agentGraphVersion": 7,
    }


@pytest.mark.asyncio
async def test_update_graph_in_library_allows_archived_library_agent(mocker):
    graph = mocker.Mock(id="graph-id")
    existing_version = mocker.Mock(version=1, is_active=True)
    graph_model = mocker.Mock()
    created_graph = mocker.Mock(id="graph-id", version=2, is_active=False)
    current_library_agent = mocker.Mock()
    updated_library_agent = mocker.Mock()

    mocker.patch(
        "backend.api.features.library.db.graph_db.get_graph_all_versions",
        new=mocker.AsyncMock(return_value=[existing_version]),
    )
    mocker.patch(
        "backend.api.features.library.db.graph_db.make_graph_model",
        return_value=graph_model,
    )
    mocker.patch(
        "backend.api.features.library.db.graph_db.create_graph",
        new=mocker.AsyncMock(return_value=created_graph),
    )
    mock_get_library_agent = mocker.patch(
        "backend.api.features.library.db.get_library_agent_by_graph_id",
        new=mocker.AsyncMock(return_value=current_library_agent),
    )
    mock_update_library_agent = mocker.patch(
        "backend.api.features.library.db.update_library_agent_version_and_settings",
        new=mocker.AsyncMock(return_value=updated_library_agent),
    )

    result_graph, result_library_agent = await db.update_graph_in_library(
        graph,
        "test-user",
    )

    assert result_graph is created_graph
    assert result_library_agent is updated_library_agent
    assert graph.version == 2
    graph_model.reassign_ids.assert_called_once_with(
        user_id="test-user", reassign_graph_id=False
    )
    mock_get_library_agent.assert_awaited_once_with(
        "test-user",
        "graph-id",
        include_archived=True,
    )
    mock_update_library_agent.assert_awaited_once_with("test-user", created_graph)

@@ -391,11 +391,6 @@ async def get_available_graph(
async def get_store_agent_by_version_id(
    store_listing_version_id: str,
) -> store_model.StoreAgentDetails:
    """Get agent details from the StoreAgent view (APPROVED agents only).

    See also: `get_store_agent_details_as_admin()` which bypasses the
    APPROVED-only StoreAgent view for admin preview of pending submissions.
    """
    logger.debug(f"Getting store agent details for {store_listing_version_id}")

    try:
@@ -416,57 +411,6 @@ async def get_store_agent_by_version_id(
        raise DatabaseError("Failed to fetch agent details") from e


async def get_store_agent_details_as_admin(
    store_listing_version_id: str,
) -> store_model.StoreAgentDetails:
    """Get agent details for admin preview, bypassing the APPROVED-only
    StoreAgent view. Queries StoreListingVersion directly so pending
    submissions are visible."""
    slv = await prisma.models.StoreListingVersion.prisma().find_unique(
        where={"id": store_listing_version_id},
        include={
            "StoreListing": {"include": {"CreatorProfile": True}},
        },
    )
    if not slv or not slv.StoreListing:
        raise NotFoundError(
            f"Store listing version {store_listing_version_id} not found"
        )

    listing = slv.StoreListing
    # CreatorProfile is a required FK relation — should always exist.
    # If it's None, the DB is in a bad state.
    profile = listing.CreatorProfile
    if not profile:
        raise DatabaseError(
            f"StoreListing {listing.id} has no CreatorProfile — FK violated"
        )

    return store_model.StoreAgentDetails(
        store_listing_version_id=slv.id,
        slug=listing.slug,
        agent_name=slv.name,
        agent_video=slv.videoUrl or "",
        agent_output_demo=slv.agentOutputDemoUrl or "",
        agent_image=slv.imageUrls,
        creator=profile.username,
        creator_avatar=profile.avatarUrl or "",
        sub_heading=slv.subHeading,
        description=slv.description,
        instructions=slv.instructions,
        categories=slv.categories,
        runs=0,
        rating=0.0,
        versions=[str(slv.version)],
        graph_id=slv.agentGraphId,
        graph_versions=[str(slv.agentGraphVersion)],
        last_updated=slv.updatedAt,
        recommended_schedule_cron=slv.recommendedScheduleCron,
        active_version_id=listing.activeVersionId or slv.id,
        has_approved_version=listing.hasApprovedVersion,
    )


class StoreCreatorsSortOptions(Enum):
    # NOTE: values correspond 1:1 to columns of the Creator view
    AGENT_RATING = "agent_rating"

@@ -592,11 +592,6 @@ async def fulfill_checkout(user_id: Annotated[str, Security(get_user_id)]):
async def configure_user_auto_top_up(
    request: AutoTopUpConfig, user_id: Annotated[str, Security(get_user_id)]
) -> str:
    """Configure auto top-up settings and perform an immediate top-up if needed.

    Raises HTTPException(422) if the request parameters are invalid or if
    the credit top-up fails.
    """
    if request.threshold < 0:
        raise HTTPException(status_code=422, detail="Threshold must be greater than 0")
    if request.amount < 500 and request.amount != 0:
@@ -611,20 +606,10 @@ async def configure_user_auto_top_up(
    user_credit_model = await get_user_credit_model(user_id)
    current_balance = await user_credit_model.get_credits(user_id)

    try:
        if current_balance < request.threshold:
            await user_credit_model.top_up_credits(user_id, request.amount)
        else:
            await user_credit_model.top_up_credits(user_id, 0)
    except ValueError as e:
        known_messages = (
            "must not be negative",
            "already exists for user",
            "No payment method found",
        )
        if any(msg in str(e) for msg in known_messages):
            raise HTTPException(status_code=422, detail=str(e))
        raise
    if current_balance < request.threshold:
        await user_credit_model.top_up_credits(user_id, request.amount)
    else:
        await user_credit_model.top_up_credits(user_id, 0)

    await set_auto_top_up(
        user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
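The `known_messages` block in this hunk is a narrow translation layer: `ValueError`s with recognized user-facing messages become 422 responses, while anything unexpected still propagates as a 500. A standalone sketch of the same pattern; the wrapper and message list here are illustrative, not part of this PR:

```python
from fastapi import HTTPException

KNOWN_USER_ERRORS = ("must not be negative", "No payment method found")

async def run_or_422(operation) -> None:
    try:
        await operation()
    except ValueError as e:
        # Only surface errors we recognize as caused by user input.
        if any(msg in str(e) for msg in KNOWN_USER_ERRORS):
            raise HTTPException(status_code=422, detail=str(e))
        raise  # Unknown ValueError: let it bubble up as a 500.
```
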
@@ -980,16 +965,14 @@ async def execute_graph(
    source: Annotated[GraphExecutionSource | None, Body(embed=True)] = None,
    graph_version: Optional[int] = None,
    preset_id: Optional[str] = None,
    dry_run: Annotated[bool, Body(embed=True)] = False,
) -> execution_db.GraphExecutionMeta:
    if not dry_run:
        user_credit_model = await get_user_credit_model(user_id)
        current_balance = await user_credit_model.get_credits(user_id)
        if current_balance <= 0:
            raise HTTPException(
                status_code=402,
                detail="Insufficient balance to execute the agent. Please top up your account.",
            )
    user_credit_model = await get_user_credit_model(user_id)
    current_balance = await user_credit_model.get_credits(user_id)
    if current_balance <= 0:
        raise HTTPException(
            status_code=402,
            detail="Insufficient balance to execute the agent. Please top up your account.",
        )

    try:
        result = await execution_utils.add_graph_execution(
@@ -999,7 +982,6 @@ async def execute_graph(
            preset_id=preset_id,
            graph_version=graph_version,
            graph_credentials_inputs=credentials_inputs,
            dry_run=dry_run,
        )
        # Record successful graph execution
        record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)

@@ -188,7 +188,6 @@ async def upload_file(
    user_id: Annotated[str, fastapi.Security(get_user_id)],
    file: UploadFile,
    session_id: str | None = Query(default=None),
    overwrite: bool = Query(default=False),
) -> UploadFileResponse:
    """
    Upload a file to the user's workspace.

@@ -249,9 +248,7 @@ async def upload_file(
    # Write file via WorkspaceManager
    manager = WorkspaceManager(user_id, workspace.id, session_id)
    try:
        workspace_file = await manager.write_file(
            content, filename, overwrite=overwrite
        )
        workspace_file = await manager.write_file(content, filename)
    except ValueError as e:
        raise fastapi.HTTPException(status_code=409, detail=str(e)) from e

@@ -210,22 +210,13 @@ instrument_fastapi(
|
||||
def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
|
||||
def handler(request: fastapi.Request, exc: Exception):
|
||||
if log_error:
|
||||
if status_code >= 500:
|
||||
logger.exception(
|
||||
"%s %s failed. Investigate and resolve the underlying issue: %s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"%s %s failed with %d: %s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
status_code,
|
||||
exc,
|
||||
)
|
||||
logger.exception(
|
||||
"%s %s failed. Investigate and resolve the underlying issue: %s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
|
||||
hint = (
|
||||
"Adjust the request and retry."
|
||||
@@ -275,10 +266,12 @@ async def validation_error_handler(
|
||||
|
||||
|
||||
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
|
||||
app.add_exception_handler(FolderAlreadyExistsError, handle_internal_http_error(409))
|
||||
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400))
|
||||
app.add_exception_handler(NotFoundError, handle_internal_http_error(404))
|
||||
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403))
|
||||
app.add_exception_handler(
|
||||
FolderAlreadyExistsError, handle_internal_http_error(409, False)
|
||||
)
|
||||
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400, False))
|
||||
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
|
||||
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
|
||||
app.add_exception_handler(RequestValidationError, validation_error_handler)
|
||||
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
|
||||
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
|
||||
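
# Hypothetical wiring, only to illustrate the factory above: assuming the
# two-argument form shown in this diff, the second argument (log_error=False)
# keeps expected client failures quiet, while 5xx handlers retain full
# tracebacks via logger.exception.
app.add_exception_handler(TimeoutError, handle_internal_http_error(504))        # traceback logged
app.add_exception_handler(LookupError, handle_internal_http_error(404, False))  # quiet 4xx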
@@ -528,11 +521,8 @@ class AgentServer(backend.util.service.AppProcess):
        user_id: str,
        provider: ProviderName,
        credentials: Credentials,
    ):
        from backend.api.features.integrations.router import (
            create_credentials,
            get_credential,
        )
    ) -> Credentials:
        from .features.integrations.router import create_credentials, get_credential

        try:
            return await create_credentials(

@@ -15,12 +6,6 @@ from backend.blocks._base import (
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.copilot.permissions import (
    CopilotPermissions,
    ToolName,
    all_known_tool_names,
    validate_block_identifiers,
)
from backend.data.model import SchemaField

if TYPE_CHECKING:
@@ -102,50 +96,6 @@ class AutoPilotBlock(Block):
            advanced=True,
        )

        tools: list[ToolName] = SchemaField(
            description=(
                "Tool names to filter. Works with tools_exclude to form an "
                "allow-list or deny-list. "
                "Leave empty to apply no tool filter."
            ),
            default=[],
            advanced=True,
        )

        tools_exclude: bool = SchemaField(
            description=(
                "Controls how the 'tools' list is interpreted. "
                "True (default): 'tools' is a deny-list — listed tools are blocked, "
                "all others are allowed. An empty 'tools' list means allow everything. "
                "False: 'tools' is an allow-list — only listed tools are permitted."
            ),
            default=True,
            advanced=True,
        )

        blocks: list[str] = SchemaField(
            description=(
                "Block identifiers to filter when the copilot uses run_block. "
                "Each entry can be: a block name (e.g. 'HTTP Request'), "
                "a full block UUID, or the first 8 hex characters of the UUID "
                "(e.g. 'c069dc6b'). Works with blocks_exclude. "
                "Leave empty to apply no block filter."
            ),
            default=[],
            advanced=True,
        )

        blocks_exclude: bool = SchemaField(
            description=(
                "Controls how the 'blocks' list is interpreted. "
                "True (default): 'blocks' is a deny-list — listed blocks are blocked, "
                "all others are allowed. An empty 'blocks' list means allow everything. "
                "False: 'blocks' is an allow-list — only listed blocks are permitted."
            ),
            default=True,
            advanced=True,
        )

        # timeout_seconds removed: the SDK manages its own heartbeat-based
        # timeouts internally; wrapping with asyncio.timeout corrupts the
        # SDK's internal stream (see service.py CRITICAL comment).
@@ -234,7 +184,7 @@ class AutoPilotBlock(Block):

    async def create_session(self, user_id: str) -> str:
        """Create a new chat session and return its ID (mockable for tests)."""
        from backend.copilot.model import create_chat_session  # avoid circular import
        from backend.copilot.model import create_chat_session

        session = await create_chat_session(user_id)
        return session.session_id
@@ -246,7 +196,6 @@ class AutoPilotBlock(Block):
        session_id: str,
        max_recursion_depth: int,
        user_id: str,
        permissions: "CopilotPermissions | None" = None,
    ) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]:
        """Invoke the copilot and collect all stream results.

@@ -260,21 +209,14 @@ class AutoPilotBlock(Block):
            session_id: Chat session to use.
            max_recursion_depth: Maximum allowed recursion nesting.
            user_id: Authenticated user ID.
            permissions: Optional capability filter restricting tools/blocks.

        Returns:
            A tuple of (response_text, tool_calls, history_json, session_id, usage).
        """
        from backend.copilot.sdk.collect import (
            collect_copilot_response,  # avoid circular import
        )
        from backend.copilot.sdk.collect import collect_copilot_response

        tokens = _check_recursion(max_recursion_depth)
        perm_token = None
        try:
            effective_permissions, perm_token = _merge_inherited_permissions(
                permissions
            )
            effective_prompt = prompt
            if system_context:
                effective_prompt = f"[System Context: {system_context}]\n\n{prompt}"
@@ -283,7 +225,6 @@ class AutoPilotBlock(Block):
                session_id=session_id,
                message=effective_prompt,
                user_id=user_id,
                permissions=effective_permissions,
            )

            # Build a lightweight conversation summary from streamed data.
@@ -330,8 +271,6 @@ class AutoPilotBlock(Block):
            )
        finally:
            _reset_recursion(tokens)
            if perm_token is not None:
                _inherited_permissions.reset(perm_token)

    async def run(
        self,
@@ -356,13 +295,6 @@ class AutoPilotBlock(Block):
            yield "error", "max_recursion_depth must be at least 1."
            return

        # Validate and build permissions eagerly — fail before creating a session.
        permissions = await _build_and_validate_permissions(input_data)
        if isinstance(permissions, str):
            # Validation error returned as a string message.
            yield "error", permissions
            return

        # Create session eagerly so the user always gets the session_id,
        # even if the downstream stream fails (avoids orphaned sessions).
        sid = input_data.session_id
@@ -380,7 +312,6 @@ class AutoPilotBlock(Block):
            session_id=sid,
            max_recursion_depth=input_data.max_recursion_depth,
            user_id=execution_context.user_id,
            permissions=permissions,
        )

        yield "response", response
@@ -443,78 +374,3 @@ def _reset_recursion(
    """Restore recursion depth and limit to their previous values."""
    _autopilot_recursion_depth.reset(tokens[0])
    _autopilot_recursion_limit.reset(tokens[1])


# ---------------------------------------------------------------------------
# Permission helpers
# ---------------------------------------------------------------------------

# Inherited permissions from a parent AutoPilotBlock execution.
# This acts as a ceiling: child executions can only be more restrictive.
_inherited_permissions: contextvars.ContextVar["CopilotPermissions | None"] = (
    contextvars.ContextVar("_inherited_permissions", default=None)
)


async def _build_and_validate_permissions(
    input_data: "AutoPilotBlock.Input",
) -> "CopilotPermissions | str":
    """Build a :class:`CopilotPermissions` from block input and validate it.

    Returns a :class:`CopilotPermissions` on success or a human-readable
    error string if validation fails.
    """
    # Tool names are validated by Pydantic via the ToolName Literal type
    # at model construction time — no runtime check needed here.
    # Validate block identifiers against live block registry.
    if input_data.blocks:
        invalid_blocks = await validate_block_identifiers(input_data.blocks)
        if invalid_blocks:
            return (
                f"Unknown block identifier(s) in 'blocks': {invalid_blocks}. "
                "Use find_block to discover valid block names and IDs. "
                "You may also use the first 8 characters of a block UUID."
            )

    return CopilotPermissions(
        tools=list(input_data.tools),
        tools_exclude=input_data.tools_exclude,
        blocks=input_data.blocks,
        blocks_exclude=input_data.blocks_exclude,
    )


def _merge_inherited_permissions(
    permissions: "CopilotPermissions | None",
) -> "tuple[CopilotPermissions | None, contextvars.Token[CopilotPermissions | None] | None]":
    """Merge *permissions* with any inherited parent permissions.

    The merged result is stored back into the contextvar so that any nested
    AutoPilotBlock invocation (sub-agent) inherits the merged ceiling.

    Returns a tuple of (merged_permissions, reset_token). The caller MUST
    reset the contextvar via ``_inherited_permissions.reset(token)`` in a
    ``finally`` block when ``reset_token`` is not None — this prevents
    permission leakage between sequential independent executions in the same
    asyncio task.
    """
    parent = _inherited_permissions.get()

    if permissions is None and parent is None:
        return None, None

    all_tools = all_known_tool_names()

    if permissions is None:
        permissions = CopilotPermissions()  # allow-all; will be narrowed by parent

    merged = (
        permissions.merged_with_parent(parent, all_tools)
        if parent is not None
        else permissions
    )

    # Store merged permissions as the new inherited ceiling for nested calls.
    # Return the token so the caller can restore the previous value in finally.
    token = _inherited_permissions.set(merged)
    return merged, token

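# Sketch of how the inheritance ceiling composes across nested invocations,
# assuming CopilotPermissions merges as documented above (children can only
# narrow, never widen, the parent allow-list). Demo function name is made up.
def _demo_ceiling():
    parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)
    outer = _inherited_permissions.set(parent)
    inner = None
    try:
        child = CopilotPermissions(tools=["run_block", "bash_exec"], tools_exclude=False)
        merged, inner = _merge_inherited_permissions(child)
        # merged still blocks bash_exec: the parent allow-list is the ceiling.
    finally:
        if inner is not None:
            _inherited_permissions.reset(inner)
        _inherited_permissions.reset(outer)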
@@ -1,265 +0,0 @@
"""Tests for AutoPilotBlock permission fields and validation."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import ValidationError

from backend.blocks.autopilot import (
    AutoPilotBlock,
    _build_and_validate_permissions,
    _inherited_permissions,
    _merge_inherited_permissions,
)
from backend.copilot.permissions import CopilotPermissions, all_known_tool_names
from backend.data.execution import ExecutionContext

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_input(**kwargs) -> AutoPilotBlock.Input:
    defaults = {
        "prompt": "Do something",
        "system_context": "",
        "session_id": "",
        "max_recursion_depth": 3,
        "tools": [],
        "tools_exclude": True,
        "blocks": [],
        "blocks_exclude": True,
    }
    defaults.update(kwargs)
    return AutoPilotBlock.Input(**defaults)


# ---------------------------------------------------------------------------
# _build_and_validate_permissions
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
class TestBuildAndValidatePermissions:
    async def test_empty_inputs_returns_empty_permissions(self):
        inp = _make_input()
        result = await _build_and_validate_permissions(inp)
        assert isinstance(result, CopilotPermissions)
        assert result.is_empty()

    async def test_valid_tool_names_accepted(self):
        inp = _make_input(tools=["run_block", "web_fetch"], tools_exclude=True)
        result = await _build_and_validate_permissions(inp)
        assert isinstance(result, CopilotPermissions)
        assert result.tools == ["run_block", "web_fetch"]
        assert result.tools_exclude is True

    async def test_invalid_tool_rejected_by_pydantic(self):
        """Invalid tool names are now caught at Pydantic validation time
        (Literal type), before ``_build_and_validate_permissions`` is called."""
        with pytest.raises(ValidationError, match="not_a_real_tool"):
            _make_input(tools=["not_a_real_tool"])

    async def test_valid_block_name_accepted(self):
        mock_block_cls = MagicMock()
        mock_block_cls.return_value.name = "HTTP Request"
        with patch(
            "backend.blocks.get_blocks",
            return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
        ):
            inp = _make_input(blocks=["HTTP Request"], blocks_exclude=True)
            result = await _build_and_validate_permissions(inp)
            assert isinstance(result, CopilotPermissions)
            assert result.blocks == ["HTTP Request"]

    async def test_valid_partial_uuid_accepted(self):
        mock_block_cls = MagicMock()
        mock_block_cls.return_value.name = "HTTP Request"
        with patch(
            "backend.blocks.get_blocks",
            return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
        ):
            inp = _make_input(blocks=["c069dc6b"], blocks_exclude=False)
            result = await _build_and_validate_permissions(inp)
            assert isinstance(result, CopilotPermissions)

    async def test_invalid_block_identifier_returns_error(self):
        mock_block_cls = MagicMock()
        mock_block_cls.return_value.name = "HTTP Request"
        with patch(
            "backend.blocks.get_blocks",
            return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
        ):
            inp = _make_input(blocks=["totally_fake_block"])
            result = await _build_and_validate_permissions(inp)
            assert isinstance(result, str)
            assert "totally_fake_block" in result
            assert "Unknown block identifier" in result

    async def test_sdk_builtin_tool_names_accepted(self):
        inp = _make_input(tools=["Read", "Task", "WebSearch"], tools_exclude=False)
        result = await _build_and_validate_permissions(inp)
        assert isinstance(result, CopilotPermissions)
        assert not result.tools_exclude

    async def test_empty_blocks_skips_validation(self):
        # Should not call validate_block_identifiers at all when blocks=[].
        with patch(
            "backend.copilot.permissions.validate_block_identifiers"
        ) as mock_validate:
            inp = _make_input(blocks=[])
            await _build_and_validate_permissions(inp)
            mock_validate.assert_not_called()


# ---------------------------------------------------------------------------
# _merge_inherited_permissions
# ---------------------------------------------------------------------------


class TestMergeInheritedPermissions:
    def test_no_permissions_no_parent_returns_none(self):
        merged, token = _merge_inherited_permissions(None)
        assert merged is None
        assert token is None

    def test_permissions_no_parent_returned_unchanged(self):
        perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
        merged, token = _merge_inherited_permissions(perms)
        try:
            assert merged is perms
            assert token is not None
        finally:
            if token is not None:
                _inherited_permissions.reset(token)

    def test_child_narrows_parent(self):
        parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
        # Set parent as inherited
        outer_token = _inherited_permissions.set(parent)
        try:
            child = CopilotPermissions(tools=["web_fetch"], tools_exclude=True)
            merged, inner_token = _merge_inherited_permissions(child)
            try:
                assert merged is not None
                all_t = all_known_tool_names()
                effective = merged.effective_allowed_tools(all_t)
                assert "bash_exec" not in effective
                assert "web_fetch" not in effective
            finally:
                if inner_token is not None:
                    _inherited_permissions.reset(inner_token)
        finally:
            _inherited_permissions.reset(outer_token)

    def test_none_permissions_with_parent_uses_parent(self):
        parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
        outer_token = _inherited_permissions.set(parent)
        try:
            merged, inner_token = _merge_inherited_permissions(None)
            try:
                assert merged is not None
                # Merged should have parent's restrictions
                effective = merged.effective_allowed_tools(all_known_tool_names())
                assert "bash_exec" not in effective
            finally:
                if inner_token is not None:
                    _inherited_permissions.reset(inner_token)
        finally:
            _inherited_permissions.reset(outer_token)

    def test_child_cannot_expand_parent_whitelist(self):
        parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)
        outer_token = _inherited_permissions.set(parent)
        try:
            # Child tries to allow more tools
            child = CopilotPermissions(
                tools=["run_block", "bash_exec"], tools_exclude=False
            )
            merged, inner_token = _merge_inherited_permissions(child)
            try:
                assert merged is not None
                effective = merged.effective_allowed_tools(all_known_tool_names())
                assert "bash_exec" not in effective
                assert "run_block" in effective
            finally:
                if inner_token is not None:
                    _inherited_permissions.reset(inner_token)
        finally:
            _inherited_permissions.reset(outer_token)


# ---------------------------------------------------------------------------
# AutoPilotBlock.run — validation integration
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
class TestAutoPilotBlockRunPermissions:
    async def _collect_outputs(self, block, input_data, user_id="test-user"):
        """Helper to collect all yields from block.run()."""
        ctx = ExecutionContext(
            user_id=user_id,
            graph_id="g1",
            graph_exec_id="ge1",
            node_exec_id="ne1",
            node_id="n1",
        )
        outputs = {}
        async for key, val in block.run(input_data, execution_context=ctx):
            outputs[key] = val
        return outputs

    async def test_invalid_tool_rejected_by_pydantic(self):
        """Invalid tool names are caught at Pydantic validation (Literal type)."""
        with pytest.raises(ValidationError, match="not_a_tool"):
            _make_input(tools=["not_a_tool"])

    async def test_invalid_block_yields_error(self):
        mock_block_cls = MagicMock()
        mock_block_cls.return_value.name = "HTTP Request"
        with patch(
            "backend.blocks.get_blocks",
            return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
        ):
            block = AutoPilotBlock()
            inp = _make_input(blocks=["nonexistent_block"])
            outputs = await self._collect_outputs(block, inp)
            assert "error" in outputs
            assert "nonexistent_block" in outputs["error"]

    async def test_empty_prompt_yields_error_before_permission_check(self):
        block = AutoPilotBlock()
        inp = _make_input(prompt="   ", tools=["run_block"])
        outputs = await self._collect_outputs(block, inp)
        assert "error" in outputs
        assert "Prompt cannot be empty" in outputs["error"]

    async def test_valid_permissions_passed_to_execute(self):
        """Permissions are forwarded to execute_copilot when valid."""
        block = AutoPilotBlock()
        captured: dict = {}

        async def fake_execute_copilot(self_inner, **kwargs):
            captured["permissions"] = kwargs.get("permissions")
            return (
                "ok",
                [],
                '[{"role":"user","content":"hi"}]',
                "test-sid",
                {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
            )

        with patch.object(
            AutoPilotBlock, "create_session", new=AsyncMock(return_value="test-sid")
        ), patch.object(AutoPilotBlock, "execute_copilot", new=fake_execute_copilot):
            inp = _make_input(tools=["run_block"], tools_exclude=False)
            outputs = await self._collect_outputs(block, inp)

        assert "error" not in outputs
        perms = captured.get("permissions")
        assert isinstance(perms, CopilotPermissions)
        assert perms.tools == ["run_block"]
        assert perms.tools_exclude is False
@@ -49,9 +49,6 @@ settings = Settings()
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
fmt = TextFormatter(autoescape=False)

# HTTP status codes for user-caused errors that should not be reported to Sentry.
USER_ERROR_STATUS_CODES = (401, 403, 429)

LLMProviderName = Literal[
    ProviderName.AIML_API,
    ProviderName.ANTHROPIC,
@@ -799,19 +796,6 @@ async def llm_call(
    )
    prompt = result.messages

    # Sanitize unpaired surrogates in message content to prevent
    # UnicodeEncodeError when httpx encodes the JSON request body.
    for msg in prompt:
        content = msg.get("content")
        if isinstance(content, str):
            try:
                content.encode("utf-8")
            except UnicodeEncodeError:
                logger.warning("Sanitized unpaired surrogates in LLM prompt content")
                msg["content"] = content.encode("utf-8", errors="surrogatepass").decode(
                    "utf-8", errors="replace"
                )
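
# Standalone illustration of the surrogate round-trip above: a lone surrogate
# fails strict UTF-8 encoding, so it is re-encoded with surrogatepass and
# decoded with replace, yielding a safe string (U+FFFD where the surrogate was).
bad = "ok \ud800 text"
try:
    bad.encode("utf-8")
except UnicodeEncodeError:
    bad = bad.encode("utf-8", errors="surrogatepass").decode("utf-8", errors="replace")
assert "\ufffd" in bad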

    # Calculate available tokens based on context window and input length
    estimated_input_tokens = estimate_token_count(prompt)
    model_max_output = llm_model.max_output_tokens or int(2**15)
@@ -894,60 +878,65 @@ async def llm_call(
        client = anthropic.AsyncAnthropic(
            api_key=credentials.api_key.get_secret_value()
        )
        resp = await client.messages.create(
            model=llm_model.value,
            system=sysprompt,
            messages=messages,
            max_tokens=max_tokens,
            tools=an_tools,
            timeout=600,
        )

        if not resp.content:
            raise ValueError("No content returned from Anthropic.")

        tool_calls = None
        for content_block in resp.content:
            # Anthropic differs from OpenAI: we need to iterate through
            # the content blocks to find the tool calls
            if content_block.type == "tool_use":
                if tool_calls is None:
                    tool_calls = []
                tool_calls.append(
                    ToolContentBlock(
                        id=content_block.id,
                        type=content_block.type,
                        function=ToolCall(
                            name=content_block.name,
                            arguments=json.dumps(content_block.input),
                        ),
                    )
                )

        if not tool_calls and resp.stop_reason == "tool_use":
            logger.warning(
                f"Tool use stop reason but no tool calls found in content. {resp}"
        try:
            resp = await client.messages.create(
                model=llm_model.value,
                system=sysprompt,
                messages=messages,
                max_tokens=max_tokens,
                tools=an_tools,
                timeout=600,
            )

            reasoning = None
            for content_block in resp.content:
                if hasattr(content_block, "type") and content_block.type == "thinking":
                    reasoning = content_block.thinking
                    break
            if not resp.content:
                raise ValueError("No content returned from Anthropic.")

            return LLMResponse(
                raw_response=resp,
                prompt=prompt,
                response=(
                    resp.content[0].name
                    if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
                    else getattr(resp.content[0], "text", "")
                ),
                tool_calls=tool_calls,
                prompt_tokens=resp.usage.input_tokens,
                completion_tokens=resp.usage.output_tokens,
                reasoning=reasoning,
            )
            tool_calls = None
            for content_block in resp.content:
                # Anthropic differs from OpenAI: we need to iterate through
                # the content blocks to find the tool calls
                if content_block.type == "tool_use":
                    if tool_calls is None:
                        tool_calls = []
                    tool_calls.append(
                        ToolContentBlock(
                            id=content_block.id,
                            type=content_block.type,
                            function=ToolCall(
                                name=content_block.name,
                                arguments=json.dumps(content_block.input),
                            ),
                        )
                    )

            if not tool_calls and resp.stop_reason == "tool_use":
                logger.warning(
                    f"Tool use stop reason but no tool calls found in content. {resp}"
                )

            reasoning = None
            for content_block in resp.content:
                if hasattr(content_block, "type") and content_block.type == "thinking":
                    reasoning = content_block.thinking
                    break

            return LLMResponse(
                raw_response=resp,
                prompt=prompt,
                response=(
                    resp.content[0].name
                    if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
                    else getattr(resp.content[0], "text", "")
                ),
                tool_calls=tool_calls,
                prompt_tokens=resp.usage.input_tokens,
                completion_tokens=resp.usage.output_tokens,
                reasoning=reasoning,
            )
        except anthropic.APIError as e:
            error_message = f"Anthropic API error: {str(e)}"
            logger.error(error_message)
            raise ValueError(error_message)
    elif provider == "groq":
        if tools:
            raise ValueError("Groq does not support tools.")
@@ -1460,16 +1449,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                yield "prompt", self.prompt
                return
            except Exception as e:
                is_user_error = (
                    isinstance(e, (anthropic.APIStatusError, openai.APIStatusError))
                    and e.status_code in USER_ERROR_STATUS_CODES
                )
                if is_user_error:
                    logger.warning(f"Error calling LLM: {e}")
                    error_feedback_message = f"Error calling LLM: {e}"
                    break
                else:
                    logger.exception(f"Error calling LLM: {e}")
                logger.exception(f"Error calling LLM: {e}")
                if (
                    "maximum context length" in str(e).lower()
                    or "token limit" in str(e).lower()

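# Hedged sketch of the retry policy above: 401/403/429 are user-caused and
# abort the retry loop immediately, while anything else is retried and logged
# in full. Helper name is illustrative only.
def should_retry(e: Exception) -> bool:
    status = getattr(e, "status_code", None)
    return status not in USER_ERROR_STATUS_CODES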
@@ -258,10 +258,9 @@ def get_pending_tool_calls(conversation_history: list[Any] | None) -> dict[str,
    return {call_id: count for call_id, count in pending_calls.items() if count > 0}


class OrchestratorBlock(Block):
class SmartDecisionMakerBlock(Block):
    """
    A block that uses a language model to orchestrate tool calls, supporting both
    single-shot and iterative agent mode execution.
    A block that uses a language model to make smart decisions based on a given prompt.
    """

    class Input(BlockSchemaInput):
@@ -402,8 +401,8 @@ class OrchestratorBlock(Block):
            description="Uses AI to intelligently decide what tool to use.",
            categories={BlockCategory.AI},
            block_type=BlockType.AI,
            input_schema=OrchestratorBlock.Input,
            output_schema=OrchestratorBlock.Output,
            input_schema=SmartDecisionMakerBlock.Input,
            output_schema=SmartDecisionMakerBlock.Output,
            test_input={
                "prompt": "Hello, World!",
                "credentials": llm.TEST_CREDENTIALS_INPUT,
@@ -441,7 +440,7 @@ class OrchestratorBlock(Block):
        tool_name = custom_name if custom_name else block.name

        tool_function: dict[str, Any] = {
            "name": OrchestratorBlock.cleanup(tool_name),
            "name": SmartDecisionMakerBlock.cleanup(tool_name),
            "description": block.description,
        }
        sink_block_input_schema = block.input_schema
@@ -452,7 +451,7 @@ class OrchestratorBlock(Block):
            field_name = link.sink_name
            is_dynamic = is_dynamic_field(field_name)
            # Clean property key to ensure Anthropic API compatibility for ALL fields
            clean_field_name = OrchestratorBlock.cleanup(field_name)
            clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
            field_mapping[clean_field_name] = field_name

            if is_dynamic:
@@ -486,7 +485,7 @@ class OrchestratorBlock(Block):
            field_name = link.sink_name
            is_dynamic = is_dynamic_field(field_name)
            # Always use cleaned field name for property key (Anthropic API compliance)
            clean_field_name = OrchestratorBlock.cleanup(field_name)
            clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)

            if is_dynamic:
                base_name = extract_base_field_name(field_name)
@@ -543,7 +542,7 @@ class OrchestratorBlock(Block):
        tool_name = custom_name if custom_name else sink_graph_meta.name

        tool_function: dict[str, Any] = {
            "name": OrchestratorBlock.cleanup(tool_name),
            "name": SmartDecisionMakerBlock.cleanup(tool_name),
            "description": sink_graph_meta.description,
        }

@@ -553,7 +552,7 @@ class OrchestratorBlock(Block):
        for link in links:
            field_name = link.sink_name

            clean_field_name = OrchestratorBlock.cleanup(field_name)
            clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
            field_mapping[clean_field_name] = field_name

            sink_block_input_schema = sink_node.input_default["input_schema"]
@@ -619,13 +618,17 @@ class OrchestratorBlock(Block):
            raise ValueError(f"Sink node not found: {links[0].sink_id}")

        if sink_node.block_id == AgentExecutorBlock().id:
            tool_func = await OrchestratorBlock._create_agent_function_signature(
                sink_node, links
            tool_func = (
                await SmartDecisionMakerBlock._create_agent_function_signature(
                    sink_node, links
                )
            )
            return_tool_functions.append(tool_func)
        else:
            tool_func = await OrchestratorBlock._create_block_function_signature(
                sink_node, links
            tool_func = (
                await SmartDecisionMakerBlock._create_block_function_signature(
                    sink_node, links
                )
            )
            return_tool_functions.append(tool_func)

@@ -905,7 +908,7 @@ class OrchestratorBlock(Block):
            task=node_exec_future,
        )

        # Execute the node directly since we're in the Orchestrator context
        # Execute the node directly since we're in the SmartDecisionMaker context
        node_exec_future.set_result(
            await execution_processor.on_node_execution(
                node_exec=node_exec_entry,
@@ -931,7 +934,7 @@ class OrchestratorBlock(Block):
        )

    except Exception as e:
        logger.warning(f"Tool execution with manager failed: {e}")
        logger.error(f"Tool execution with manager failed: {e}")
        # Return error response
        return _create_tool_response(
            tool_call.id,
@@ -1109,7 +1112,7 @@ class OrchestratorBlock(Block):
            return
        elif input_data.last_tool_output:
            logger.error(
                f"[OrchestratorBlock-node_exec_id={node_exec_id}] "
                f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
                f"No pending tool calls found. This may indicate an issue with the "
                f"conversation history, or the tool giving a response more than once. "
                f"This should not happen! Please check the conversation history for any inconsistencies."
@@ -1246,7 +1249,7 @@ class OrchestratorBlock(Block):
        emit_key = f"tools_^_{sink_node_id}_~_{original_field_name}"

        logger.debug(
            "[OrchestratorBlock|geid:%s|neid:%s] emit %s",
            "[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",
            graph_exec_id,
            node_exec_id,
            emit_key,
@@ -1,8 +1,13 @@
import logging
import signal
import threading
import warnings
from contextlib import contextmanager
from enum import Enum

from stagehand import AsyncStagehand
from stagehand.types.session_act_params import Options as ActOptions
# Monkey patch Stagehand to prevent signal handling in worker threads
import stagehand.main
from stagehand import Stagehand

from backend.blocks.llm import (
    MODEL_METADATA,
@@ -23,6 +28,46 @@ from backend.sdk import (
    SchemaField,
)

# Suppress false positive cleanup warning of litellm (a dependency of stagehand)
warnings.filterwarnings("ignore", module="litellm.llms.custom_httpx")

# Store the original method
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers


def safe_register_signal_handlers(self):
    """Only register signal handlers in the main thread"""
    if threading.current_thread() is threading.main_thread():
        original_register_signal_handlers(self)
    else:
        # Skip signal handling in worker threads
        pass


# Replace the method
stagehand.main.Stagehand._register_signal_handlers = safe_register_signal_handlers


@contextmanager
def disable_signal_handling():
    """Context manager to temporarily disable signal handling"""
    if threading.current_thread() is not threading.main_thread():
        # In worker threads, temporarily replace signal.signal with a no-op
        original_signal = signal.signal

        def noop_signal(*args, **kwargs):
            pass

        signal.signal = noop_signal
        try:
            yield
        finally:
            signal.signal = original_signal
    else:
        # In main thread, don't modify anything
        yield
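
# Why the guard above is needed, shown as a minimal repro (not project code):
# signal.signal() raises ValueError when invoked outside the main thread,
# which is exactly what an unpatched Stagehand constructor would trigger.
def _signal_repro_worker():
    try:
        signal.signal(signal.SIGINT, signal.SIG_IGN)
    except ValueError as exc:
        print(f"expected failure in worker thread: {exc}")

threading.Thread(target=_signal_repro_worker).start()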


logger = logging.getLogger(__name__)


@@ -103,10 +148,13 @@ class StagehandObserveBlock(Block):
        instruction: str = SchemaField(
            description="Natural language description of elements or actions to discover.",
        )
        dom_settle_timeout_ms: int = SchemaField(
            description="Timeout in ms to wait for the DOM to settle after navigation.",
            default=30000,
            advanced=True,
        iframes: bool = SchemaField(
            description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
            default=True,
        )
        domSettleTimeoutMs: int = SchemaField(
            description="Timeout in milliseconds for DOM settlement. Wait longer for dynamic content.",
            default=45000,
        )

    class Output(BlockSchemaOutput):
@@ -137,28 +185,32 @@ class StagehandObserveBlock(Block):

        logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}")

        async with AsyncStagehand(
            browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
            browserbase_project_id=input_data.browserbase_project_id,
            model_api_key=model_credentials.api_key.get_secret_value(),
        ) as client:
            session = await client.sessions.start(
        with disable_signal_handling():
            stagehand = Stagehand(
                api_key=stagehand_credentials.api_key.get_secret_value(),
                project_id=input_data.browserbase_project_id,
                model_name=input_data.model.provider_name,
                dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
                model_api_key=model_credentials.api_key.get_secret_value(),
            )
        try:
            await session.navigate(url=input_data.url)

            observe_response = await session.observe(
                instruction=input_data.instruction,
            )
            for result in observe_response.data.result:
                yield "selector", result.selector
                yield "description", result.description
                yield "method", result.method
                yield "arguments", result.arguments
        finally:
            await session.end()
        await stagehand.init()

        page = stagehand.page

        assert page is not None, "Stagehand page is not initialized"

        await page.goto(input_data.url)

        observe_results = await page.observe(
            input_data.instruction,
            iframes=input_data.iframes,
            domSettleTimeoutMs=input_data.domSettleTimeoutMs,
        )
        for result in observe_results:
            yield "selector", result.selector
            yield "description", result.description
            yield "method", result.method
            yield "arguments", result.arguments
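
# The migration skeleton shared by the three blocks in this diff: synchronous
# construction guarded against worker-thread signal registration, then an
# explicit init. A sketch assuming the Stagehand names shown above; the helper
# itself is hypothetical.
async def _run_stagehand(url: str, instruction: str, **keys):
    with disable_signal_handling():
        stagehand = Stagehand(**keys)  # api_key, project_id, model_name, model_api_key
    await stagehand.init()
    page = stagehand.page
    assert page is not None, "Stagehand page is not initialized"
    await page.goto(url)
    return await page.observe(instruction)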


class StagehandActBlock(Block):
@@ -190,22 +242,24 @@ class StagehandActBlock(Block):
            description="Variables to use in the action. Variables contains data you want the action to use.",
            default_factory=dict,
        )
        dom_settle_timeout_ms: int = SchemaField(
            description="Timeout in ms to wait for the DOM to settle after navigation.",
            default=30000,
            advanced=True,
        iframes: bool = SchemaField(
            description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
            default=True,
        )
        timeout_ms: int = SchemaField(
            description="Timeout in ms for each action.",
            default=30000,
            advanced=True,
        domSettleTimeoutMs: int = SchemaField(
            description="Timeout in milliseconds for DOM settlement. Wait longer for dynamic content.",
            default=45000,
        )
        timeoutMs: int = SchemaField(
            description="Timeout in milliseconds for DOM ready. Extended timeout for slow-loading forms",
            default=60000,
        )

    class Output(BlockSchemaOutput):
        success: bool = SchemaField(
            description="Whether the action was completed successfully"
        )
        message: str = SchemaField(description="Details about the action's execution.")
        message: str = SchemaField(description="Details about the action’s execution.")
        action: str = SchemaField(description="Action performed")

    def __init__(self):
@@ -228,33 +282,32 @@ class StagehandActBlock(Block):

        logger.debug(f"ACT: Using model provider {model_credentials.provider}")

        async with AsyncStagehand(
            browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
            browserbase_project_id=input_data.browserbase_project_id,
            model_api_key=model_credentials.api_key.get_secret_value(),
        ) as client:
            session = await client.sessions.start(
        with disable_signal_handling():
            stagehand = Stagehand(
                api_key=stagehand_credentials.api_key.get_secret_value(),
                project_id=input_data.browserbase_project_id,
                model_name=input_data.model.provider_name,
                dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
                model_api_key=model_credentials.api_key.get_secret_value(),
            )
        try:
            await session.navigate(url=input_data.url)

            for action in input_data.action:
                act_options = ActOptions(
                    variables={k: v for k, v in input_data.variables.items()},
                    timeout=input_data.timeout_ms,
                )
                act_response = await session.act(
                    input=action,
                    options=act_options,
                )
                result = act_response.data.result
                yield "success", result.success
                yield "message", result.message
                yield "action", result.action_description
        finally:
            await session.end()
        await stagehand.init()

        page = stagehand.page

        assert page is not None, "Stagehand page is not initialized"

        await page.goto(input_data.url)
        for action in input_data.action:
            action_results = await page.act(
                action,
                variables=input_data.variables,
                iframes=input_data.iframes,
                domSettleTimeoutMs=input_data.domSettleTimeoutMs,
                timeoutMs=input_data.timeoutMs,
            )
            yield "success", action_results.success
            yield "message", action_results.message
            yield "action", action_results.action


class StagehandExtractBlock(Block):
@@ -282,10 +335,13 @@ class StagehandExtractBlock(Block):
        instruction: str = SchemaField(
            description="Natural language description of elements or actions to discover.",
        )
        dom_settle_timeout_ms: int = SchemaField(
            description="Timeout in ms to wait for the DOM to settle after navigation.",
            default=30000,
            advanced=True,
        iframes: bool = SchemaField(
            description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
            default=True,
        )
        domSettleTimeoutMs: int = SchemaField(
            description="Timeout in milliseconds for DOM settlement. Wait longer for dynamic content.",
            default=45000,
        )

    class Output(BlockSchemaOutput):
@@ -311,21 +367,24 @@ class StagehandExtractBlock(Block):

        logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}")

        async with AsyncStagehand(
            browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
            browserbase_project_id=input_data.browserbase_project_id,
            model_api_key=model_credentials.api_key.get_secret_value(),
        ) as client:
            session = await client.sessions.start(
        with disable_signal_handling():
            stagehand = Stagehand(
                api_key=stagehand_credentials.api_key.get_secret_value(),
                project_id=input_data.browserbase_project_id,
                model_name=input_data.model.provider_name,
                dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
                model_api_key=model_credentials.api_key.get_secret_value(),
            )
        try:
            await session.navigate(url=input_data.url)

            extract_response = await session.extract(
                instruction=input_data.instruction,
            )
            yield "extraction", str(extract_response.data.result)
        finally:
            await session.end()
        await stagehand.init()

        page = stagehand.page

        assert page is not None, "Stagehand page is not initialized"

        await page.goto(input_data.url)
        extraction = await page.extract(
            input_data.instruction,
            iframes=input_data.iframes,
            domSettleTimeoutMs=input_data.domSettleTimeoutMs,
        )
        yield "extraction", str(extraction.model_dump()["extraction"])

@@ -1,18 +1,9 @@
from typing import cast
from unittest.mock import AsyncMock, MagicMock, patch

import anthropic
import httpx
import openai
import pytest

import backend.blocks.llm as llm
from backend.data.model import NodeExecutionStats

# TEST_CREDENTIALS_INPUT is a plain dict that satisfies AICredentials at runtime
# but not at the type level. Cast once here to avoid per-test suppressors.
_TEST_AI_CREDENTIALS = cast(llm.AICredentials, llm.TEST_CREDENTIALS_INPUT)


class TestLLMStatsTracking:
    """Test that LLM blocks correctly track token usage statistics."""
@@ -664,148 +655,3 @@ class TestAITextSummarizerValidation:
        error_message = str(exc_info.value)
        assert "Expected a string summary" in error_message
        assert "received dict" in error_message


def _make_anthropic_status_error(status_code: int) -> anthropic.APIStatusError:
    """Create an anthropic.APIStatusError with the given status code."""
    request = httpx.Request("POST", "https://api.anthropic.com/v1/messages")
    response = httpx.Response(status_code, request=request)
    return anthropic.APIStatusError(
        f"Error code: {status_code}", response=response, body=None
    )


def _make_openai_status_error(status_code: int) -> openai.APIStatusError:
    """Create an openai.APIStatusError with the given status code."""
    response = httpx.Response(
        status_code, request=httpx.Request("POST", "https://api.openai.com/v1/chat")
    )
    return openai.APIStatusError(
        f"Error code: {status_code}", response=response, body=None
    )
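
# Quick usage check for the helpers above (hypothetical snippet, not part of
# the test suite): both error types expose .status_code, which is what the
# retry logic in llm.py inspects.
err = _make_anthropic_status_error(429)
assert err.status_code == 429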


class TestUserErrorStatusCodeHandling:
    """Test that user-caused LLM API errors (401/403/429) break the retry loop
    and are logged as warnings, while server errors (500) trigger retries."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize("status_code", [401, 403, 429])
    async def test_anthropic_user_error_breaks_retry_loop(self, status_code: int):
        """401/403/429 Anthropic errors should break immediately, not retry."""
        import backend.blocks.llm as llm

        block = llm.AIStructuredResponseGeneratorBlock()
        call_count = 0

        async def mock_llm_call(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            raise _make_anthropic_status_error(status_code)

        with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
            input_data = llm.AIStructuredResponseGeneratorBlock.Input(
                prompt="Test",
                expected_format={"key": "desc"},
                model=llm.DEFAULT_LLM_MODEL,
                credentials=_TEST_AI_CREDENTIALS,
                retry=3,
            )

            with pytest.raises(RuntimeError):
                async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
                    pass

            assert (
                call_count == 1
            ), f"Expected exactly 1 call for status {status_code}, got {call_count}"

    @pytest.mark.asyncio
    @pytest.mark.parametrize("status_code", [401, 403, 429])
    async def test_openai_user_error_breaks_retry_loop(self, status_code: int):
        """401/403/429 OpenAI errors should break immediately, not retry."""
        import backend.blocks.llm as llm

        block = llm.AIStructuredResponseGeneratorBlock()
        call_count = 0

        async def mock_llm_call(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            raise _make_openai_status_error(status_code)

        with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
            input_data = llm.AIStructuredResponseGeneratorBlock.Input(
                prompt="Test",
                expected_format={"key": "desc"},
                model=llm.DEFAULT_LLM_MODEL,
                credentials=_TEST_AI_CREDENTIALS,
                retry=3,
            )

            with pytest.raises(RuntimeError):
                async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
                    pass

            assert (
                call_count == 1
            ), f"Expected exactly 1 call for status {status_code}, got {call_count}"

    @pytest.mark.asyncio
    async def test_server_error_retries(self):
        """500 errors should be retried (not break immediately)."""
        import backend.blocks.llm as llm

        block = llm.AIStructuredResponseGeneratorBlock()
        call_count = 0

        async def mock_llm_call(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            raise _make_anthropic_status_error(500)

        with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
            input_data = llm.AIStructuredResponseGeneratorBlock.Input(
                prompt="Test",
                expected_format={"key": "desc"},
                model=llm.DEFAULT_LLM_MODEL,
                credentials=_TEST_AI_CREDENTIALS,
                retry=3,
            )

            with pytest.raises(RuntimeError):
                async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
                    pass

            assert (
                call_count > 1
            ), f"Expected multiple retry attempts for 500, got {call_count}"

    @pytest.mark.asyncio
    async def test_user_error_logs_warning_not_exception(self):
        """User-caused errors should log with logger.warning, not logger.exception."""
        import backend.blocks.llm as llm

        block = llm.AIStructuredResponseGeneratorBlock()

        async def mock_llm_call(*args, **kwargs):
            raise _make_anthropic_status_error(401)

        with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
            input_data = llm.AIStructuredResponseGeneratorBlock.Input(
                prompt="Test",
                expected_format={"key": "desc"},
                model=llm.DEFAULT_LLM_MODEL,
                credentials=_TEST_AI_CREDENTIALS,
            )

            with (
                patch.object(llm.logger, "warning") as mock_warning,
                patch.object(llm.logger, "exception") as mock_exception,
                pytest.raises(RuntimeError),
            ):
                async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
                    pass

            mock_warning.assert_called_once()
            mock_exception.assert_not_called()

@@ -57,7 +57,7 @@ async def execute_graph(
@pytest.mark.asyncio(loop_scope="session")
async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
    from backend.blocks.agent import AgentExecutorBlock
    from backend.blocks.orchestrator import OrchestratorBlock
    from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
    from backend.data import graph

    test_user = await create_test_user()
@@ -66,7 +66,7 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):

    nodes = [
        graph.Node(
            block_id=OrchestratorBlock().id,
            block_id=SmartDecisionMakerBlock().id,
            input_default={
                "prompt": "Hello, World!",
                "credentials": creds,
@@ -108,10 +108,10 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):


@pytest.mark.asyncio(loop_scope="session")
async def test_orchestrator_function_signature(server: SpinTestServer):
async def test_smart_decision_maker_function_signature(server: SpinTestServer):
    from backend.blocks.agent import AgentExecutorBlock
    from backend.blocks.basic import StoreValueBlock
    from backend.blocks.orchestrator import OrchestratorBlock
    from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
    from backend.data import graph

    test_user = await create_test_user()
@@ -120,7 +120,7 @@ async def test_orchestrator_function_signature(server: SpinTestServer):

    nodes = [
        graph.Node(
            block_id=OrchestratorBlock().id,
            block_id=SmartDecisionMakerBlock().id,
            input_default={
                "prompt": "Hello, World!",
                "credentials": creds,
@@ -169,7 +169,7 @@ async def test_orchestrator_function_signature(server: SpinTestServer):
    )
    test_graph = await create_graph(server, test_graph, test_user)

    tool_functions = await OrchestratorBlock._create_tool_node_signatures(
    tool_functions = await SmartDecisionMakerBlock._create_tool_node_signatures(
        test_graph.nodes[0].id
    )
    assert tool_functions is not None, "Tool functions should not be None"
@@ -198,12 +198,12 @@ async def test_orchestrator_function_signature(server: SpinTestServer):


@pytest.mark.asyncio
async def test_orchestrator_tracks_llm_stats():
    """Test that OrchestratorBlock correctly tracks LLM usage stats."""
async def test_smart_decision_maker_tracks_llm_stats():
    """Test that SmartDecisionMakerBlock correctly tracks LLM usage stats."""
    import backend.blocks.llm as llm_module
    from backend.blocks.orchestrator import OrchestratorBlock
    from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock

    block = OrchestratorBlock()
    block = SmartDecisionMakerBlock()

    # Mock the llm.llm_call function to return controlled data
    mock_response = MagicMock()
@@ -224,14 +224,14 @@ async def test_orchestrator_tracks_llm_stats():
        new_callable=AsyncMock,
        return_value=mock_response,
    ), patch.object(
        OrchestratorBlock,
        SmartDecisionMakerBlock,
        "_create_tool_node_signatures",
        new_callable=AsyncMock,
        return_value=[],
    ):

        # Create test input
        input_data = OrchestratorBlock.Input(
        input_data = SmartDecisionMakerBlock.Input(
            prompt="Should I continue with this task?",
            model=llm_module.DEFAULT_LLM_MODEL,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
@@ -274,12 +274,12 @@ async def test_orchestrator_tracks_llm_stats():


@pytest.mark.asyncio
async def test_orchestrator_parameter_validation():
    """Test that OrchestratorBlock correctly validates tool call parameters."""
async def test_smart_decision_maker_parameter_validation():
    """Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
    import backend.blocks.llm as llm_module
    from backend.blocks.orchestrator import OrchestratorBlock
    from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock

    block = OrchestratorBlock()
    block = SmartDecisionMakerBlock()

    # Mock tool functions with specific parameter schema
    mock_tool_functions = [
@@ -327,13 +327,13 @@ async def test_orchestrator_parameter_validation():
        new_callable=AsyncMock,
        return_value=mock_response_with_typo,
    ) as mock_llm_call, patch.object(
        OrchestratorBlock,
        SmartDecisionMakerBlock,
        "_create_tool_node_signatures",
        new_callable=AsyncMock,
        return_value=mock_tool_functions,
    ):

        input_data = OrchestratorBlock.Input(
        input_data = SmartDecisionMakerBlock.Input(
            prompt="Search for keywords",
            model=llm_module.DEFAULT_LLM_MODEL,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
@@ -394,13 +394,13 @@ async def test_orchestrator_parameter_validation():
        new_callable=AsyncMock,
        return_value=mock_response_missing_required,
    ), patch.object(
        OrchestratorBlock,
        SmartDecisionMakerBlock,
        "_create_tool_node_signatures",
        new_callable=AsyncMock,
        return_value=mock_tool_functions,
    ):

        input_data = OrchestratorBlock.Input(
        input_data = SmartDecisionMakerBlock.Input(
            prompt="Search for keywords",
            model=llm_module.DEFAULT_LLM_MODEL,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
@@ -454,13 +454,13 @@ async def test_orchestrator_parameter_validation():
        new_callable=AsyncMock,
        return_value=mock_response_valid,
    ), patch.object(
        OrchestratorBlock,
        SmartDecisionMakerBlock,
        "_create_tool_node_signatures",
        new_callable=AsyncMock,
        return_value=mock_tool_functions,
    ):

        input_data = OrchestratorBlock.Input(
        input_data = SmartDecisionMakerBlock.Input(
            prompt="Search for keywords",
            model=llm_module.DEFAULT_LLM_MODEL,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
@@ -518,13 +518,13 @@ async def test_orchestrator_parameter_validation():
        new_callable=AsyncMock,
        return_value=mock_response_all_params,
    ), patch.object(
        OrchestratorBlock,
        SmartDecisionMakerBlock,
        "_create_tool_node_signatures",
        new_callable=AsyncMock,
        return_value=mock_tool_functions,
    ):

        input_data = OrchestratorBlock.Input(
        input_data = SmartDecisionMakerBlock.Input(
            prompt="Search for keywords",
            model=llm_module.DEFAULT_LLM_MODEL,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
@@ -562,12 +562,12 @@ async def test_orchestrator_parameter_validation():


@pytest.mark.asyncio
async def test_orchestrator_raw_response_conversion():
    """Test that Orchestrator correctly handles different raw_response types with retry mechanism."""
async def test_smart_decision_maker_raw_response_conversion():
    """Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism."""
    import backend.blocks.llm as llm_module
    from backend.blocks.orchestrator import OrchestratorBlock
    from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock

    block = OrchestratorBlock()
    block = SmartDecisionMakerBlock()

    # Mock tool functions
    mock_tool_functions = [
@@ -637,7 +637,7 @@ async def test_orchestrator_raw_response_conversion():
    with patch(
        "backend.blocks.llm.llm_call", new_callable=AsyncMock
    ) as mock_llm_call, patch.object(
        OrchestratorBlock,
        SmartDecisionMakerBlock,
        "_create_tool_node_signatures",
        new_callable=AsyncMock,
        return_value=mock_tool_functions,
@@ -646,7 +646,7 @@ async def test_orchestrator_raw_response_conversion():
        # Second call returns successful response
        mock_llm_call.side_effect = [mock_response_retry, mock_response_success]

        input_data = OrchestratorBlock.Input(
        input_data = SmartDecisionMakerBlock.Input(
            prompt="Test prompt",
            model=llm_module.DEFAULT_LLM_MODEL,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
@@ -715,12 +715,12 @@ async def test_orchestrator_raw_response_conversion():
        new_callable=AsyncMock,
        return_value=mock_response_ollama,
    ), patch.object(
        OrchestratorBlock,
        SmartDecisionMakerBlock,
        "_create_tool_node_signatures",
        new_callable=AsyncMock,
        return_value=[],  # No tools for this test
    ):
        input_data = OrchestratorBlock.Input(
        input_data = SmartDecisionMakerBlock.Input(
            prompt="Simple prompt",
            model=llm_module.DEFAULT_LLM_MODEL,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
@@ -771,12 +771,12 @@ async def test_orchestrator_raw_response_conversion():
        new_callable=AsyncMock,
        return_value=mock_response_dict,
    ), patch.object(
        OrchestratorBlock,
        SmartDecisionMakerBlock,
        "_create_tool_node_signatures",
        new_callable=AsyncMock,
        return_value=[],
    ):
        input_data = OrchestratorBlock.Input(
        input_data = SmartDecisionMakerBlock.Input(
            prompt="Another test",
            model=llm_module.DEFAULT_LLM_MODEL,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
@@ -811,12 +811,12 @@ async def test_orchestrator_raw_response_conversion():


@pytest.mark.asyncio
async def test_orchestrator_agent_mode():
async def test_smart_decision_maker_agent_mode():
    """Test that agent mode executes tools directly and loops until finished."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
|
||||
block = OrchestratorBlock()
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Mock tool call that requires multiple iterations
|
||||
mock_tool_call_1 = MagicMock()
|
||||
@@ -893,7 +893,7 @@ async def test_orchestrator_agent_mode():
|
||||
with patch("backend.blocks.llm.llm_call", llm_call_mock), patch.object(
|
||||
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
|
||||
), patch(
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client",
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
), patch(
|
||||
"backend.executor.manager.async_update_node_execution_status",
|
||||
@@ -929,7 +929,7 @@ async def test_orchestrator_agent_mode():
|
||||
}
|
||||
|
||||
# Test agent mode with max_iterations = 3
|
||||
input_data = OrchestratorBlock.Input(
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Complete this task using tools",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -969,12 +969,12 @@ async def test_orchestrator_agent_mode():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orchestrator_traditional_mode_default():
|
||||
async def test_smart_decision_maker_traditional_mode_default():
|
||||
"""Test that default behavior (agent_mode_max_iterations=0) works as traditional mode."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
|
||||
block = OrchestratorBlock()
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Mock tool call
|
||||
mock_tool_call = MagicMock()
|
||||
@@ -1018,7 +1018,7 @@ async def test_orchestrator_traditional_mode_default():
|
||||
):
|
||||
|
||||
# Test default behavior (traditional mode)
|
||||
input_data = OrchestratorBlock.Input(
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -1060,12 +1060,12 @@ async def test_orchestrator_traditional_mode_default():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orchestrator_uses_customized_name_for_blocks():
|
||||
"""Test that OrchestratorBlock uses customized_name from node metadata for tool names."""
|
||||
async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from node metadata for tool names."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
@@ -1080,7 +1080,7 @@ async def test_orchestrator_uses_customized_name_for_blocks():
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await OrchestratorBlock._create_block_function_signature(
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -1091,12 +1091,12 @@ async def test_orchestrator_uses_customized_name_for_blocks():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orchestrator_falls_back_to_block_name():
|
||||
"""Test that OrchestratorBlock falls back to block.name when no customized_name."""
|
||||
async def test_smart_decision_maker_falls_back_to_block_name():
|
||||
"""Test that SmartDecisionMakerBlock falls back to block.name when no customized_name."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
@@ -1111,7 +1111,7 @@ async def test_orchestrator_falls_back_to_block_name():
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await OrchestratorBlock._create_block_function_signature(
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -1122,11 +1122,11 @@ async def test_orchestrator_falls_back_to_block_name():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orchestrator_uses_customized_name_for_agents():
|
||||
"""Test that OrchestratorBlock uses customized_name from metadata for agent nodes."""
|
||||
async def test_smart_decision_maker_uses_customized_name_for_agents():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from metadata for agent nodes."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
@@ -1152,10 +1152,10 @@ async def test_orchestrator_uses_customized_name_for_agents():
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client",
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await OrchestratorBlock._create_agent_function_signature(
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -1166,11 +1166,11 @@ async def test_orchestrator_uses_customized_name_for_agents():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orchestrator_agent_falls_back_to_graph_name():
|
||||
async def test_smart_decision_maker_agent_falls_back_to_graph_name():
|
||||
"""Test that agent node falls back to graph name when no customized_name."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
@@ -1196,10 +1196,10 @@ async def test_orchestrator_agent_falls_back_to_graph_name():
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client",
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await OrchestratorBlock._create_agent_function_signature(
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -3,12 +3,12 @@ from unittest.mock import Mock
import pytest

from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
from backend.blocks.orchestrator import OrchestratorBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock


@pytest.mark.asyncio
async def test_orchestrator_handles_dynamic_dict_fields():
"""Test Orchestrator can handle dynamic dictionary fields (_#_) for any block"""
async def test_smart_decision_maker_handles_dynamic_dict_fields():
"""Test Smart Decision Maker can handle dynamic dictionary fields (_#_) for any block"""

# Create a mock node for CreateDictionaryBlock
mock_node = Mock()
@@ -23,24 +23,24 @@ async def test_orchestrator_handles_dynamic_dict_fields():
source_name="tools_^_create_dict_~_name",
sink_name="values_#_name",  # Dynamic dict field
sink_id="dict_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_create_dict_~_age",
sink_name="values_#_age",  # Dynamic dict field
sink_id="dict_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_create_dict_~_city",
sink_name="values_#_city",  # Dynamic dict field
sink_id="dict_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
]

# Generate function signature
signature = await OrchestratorBlock._create_block_function_signature(
signature = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, mock_links  # type: ignore
)

@@ -70,8 +70,8 @@ async def test_orchestrator_handles_dynamic_dict_fields():


@pytest.mark.asyncio
async def test_orchestrator_handles_dynamic_list_fields():
"""Test Orchestrator can handle dynamic list fields (_$_) for any block"""
async def test_smart_decision_maker_handles_dynamic_list_fields():
"""Test Smart Decision Maker can handle dynamic list fields (_$_) for any block"""

# Create a mock node for AddToListBlock
mock_node = Mock()
@@ -86,18 +86,18 @@ async def test_orchestrator_handles_dynamic_list_fields():
source_name="tools_^_add_to_list_~_0",
sink_name="entries_$_0",  # Dynamic list field
sink_id="list_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_add_to_list_~_1",
sink_name="entries_$_1",  # Dynamic list field
sink_id="list_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
]

# Generate function signature
signature = await OrchestratorBlock._create_block_function_signature(
signature = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, mock_links  # type: ignore
)

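Context for the sink names in the fixtures above: the `_#_`, `_$_`, and `_@_` markers route a tool argument into a dict key, list index, or object attribute of the target pin. The helper below is hypothetical, written only to illustrate that naming convention as it appears in these tests; the real parsing lives in `backend.data.dynamic_fields`:

```python
# Hypothetical sketch of the dynamic-field sink-name convention seen above.
# Assumption: "values_#_name" targets dict key "name", "entries_$_0" targets
# list index 0, and "data_@_user_name" targets object attribute "user_name".
def classify_sink(sink_name: str) -> tuple[str, str, str]:
    """Return (base_field, kind, key) for a dynamic sink name."""
    for marker, kind in (("_#_", "dict"), ("_$_", "list"), ("_@_", "object")):
        if marker in sink_name:
            base, _, key = sink_name.partition(marker)
            return base, kind, key
    return sink_name, "regular", ""

assert classify_sink("values_#_name") == ("values", "dict", "name")
assert classify_sink("entries_$_0") == ("entries", "list", "0")
assert classify_sink("data_@_user_name") == ("data", "object", "user_name")
assert classify_sink("regular_field") == ("regular_field", "regular", "")
```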
@@ -1,4 +1,4 @@
"""Comprehensive tests for OrchestratorBlock dynamic field handling."""
"""Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""

import json
from unittest.mock import AsyncMock, MagicMock, Mock, patch
@@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest

from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
from backend.blocks.orchestrator import OrchestratorBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.text import MatchTextPatternBlock
from backend.data.dynamic_fields import get_dynamic_field_description

@@ -37,7 +37,7 @@ async def test_dynamic_field_description_generation():
@pytest.mark.asyncio
async def test_create_block_function_signature_with_dict_fields():
"""Test that function signatures are created correctly for dictionary dynamic fields."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()

# Create a mock node for CreateDictionaryBlock
mock_node = Mock()
@@ -52,19 +52,19 @@ async def test_create_block_function_signature_with_dict_fields():
source_name="tools_^_create_dict_~_values___name",  # Sanitized source
sink_name="values_#_name",  # Original sink
sink_id="dict_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_create_dict_~_values___age",  # Sanitized source
sink_name="values_#_age",  # Original sink
sink_id="dict_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_create_dict_~_values___email",  # Sanitized source
sink_name="values_#_email",  # Original sink
sink_id="dict_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
]

@@ -100,7 +100,7 @@ async def test_create_block_function_signature_with_dict_fields():
@pytest.mark.asyncio
async def test_create_block_function_signature_with_list_fields():
"""Test that function signatures are created correctly for list dynamic fields."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()

# Create a mock node for AddToListBlock
mock_node = Mock()
@@ -115,19 +115,19 @@ async def test_create_block_function_signature_with_list_fields():
source_name="tools_^_add_list_~_0",
sink_name="entries_$_0",  # Dynamic list field
sink_id="list_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_add_list_~_1",
sink_name="entries_$_1",  # Dynamic list field
sink_id="list_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_add_list_~_2",
sink_name="entries_$_2",  # Dynamic list field
sink_id="list_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
]

@@ -154,7 +154,7 @@ async def test_create_block_function_signature_with_list_fields():
@pytest.mark.asyncio
async def test_create_block_function_signature_with_object_fields():
"""Test that function signatures are created correctly for object dynamic fields."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()

# Create a mock node for MatchTextPatternBlock (simulating object fields)
mock_node = Mock()
@@ -169,13 +169,13 @@ async def test_create_block_function_signature_with_object_fields():
source_name="tools_^_extract_~_user_name",
sink_name="data_@_user_name",  # Dynamic object field
sink_id="extract_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_extract_~_user_email",
sink_name="data_@_user_email",  # Dynamic object field
sink_id="extract_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
]

@@ -197,11 +197,11 @@ async def test_create_block_function_signature_with_object_fields():
@pytest.mark.asyncio
async def test_create_tool_node_signatures():
"""Test that the mapping between sanitized and original field names is built correctly."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()

# Mock the database client and connected nodes
with patch(
"backend.blocks.orchestrator.get_database_manager_async_client"
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
) as mock_db:
mock_client = AsyncMock()
mock_db.return_value = mock_client
@@ -281,7 +281,7 @@ async def test_create_tool_node_signatures():
@pytest.mark.asyncio
async def test_output_yielding_with_dynamic_fields():
"""Test that outputs are yielded correctly with dynamic field names mapped back."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()

# No more sanitized mapping needed since we removed sanitization

@@ -309,13 +309,13 @@ async def test_output_yielding_with_dynamic_fields():

# Mock the LLM call
with patch(
"backend.blocks.orchestrator.llm.llm_call", new_callable=AsyncMock
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
) as mock_llm:
mock_llm.return_value = mock_response

# Mock the database manager to avoid HTTP calls during tool execution
with patch(
"backend.blocks.orchestrator.get_database_manager_async_client"
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
) as mock_db_manager, patch.object(
block, "_create_tool_node_signatures", new_callable=AsyncMock
) as mock_sig:
@@ -420,7 +420,7 @@ async def test_output_yielding_with_dynamic_fields():
@pytest.mark.asyncio
async def test_mixed_regular_and_dynamic_fields():
"""Test handling of blocks with both regular and dynamic fields."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()

# Create a mock node
mock_node = Mock()
@@ -450,19 +450,19 @@ async def test_mixed_regular_and_dynamic_fields():
source_name="tools_^_test_~_regular",
sink_name="regular_field",  # Regular field
sink_id="test_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_test_~_dict_key",
sink_name="values_#_key1",  # Dynamic dict field
sink_id="test_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_test_~_dict_key2",
sink_name="values_#_key2",  # Dynamic dict field
sink_id="test_node_id",
source_id="orchestrator_node_id",
source_id="smart_decision_node_id",
),
]

@@ -488,7 +488,7 @@ async def test_mixed_regular_and_dynamic_fields():
@pytest.mark.asyncio
async def test_validation_errors_dont_pollute_conversation():
"""Test that validation errors are only used during retries and don't pollute the conversation."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()

# Track conversation history changes
conversation_snapshots = []
@@ -535,7 +535,7 @@ async def test_validation_errors_dont_pollute_conversation():

# Mock the LLM call
with patch(
"backend.blocks.orchestrator.llm.llm_call", new_callable=AsyncMock
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
) as mock_llm:
mock_llm.side_effect = mock_llm_call

@@ -565,7 +565,7 @@ async def test_validation_errors_dont_pollute_conversation():

# Mock the database manager to avoid HTTP calls during tool execution
with patch(
"backend.blocks.orchestrator.get_database_manager_async_client"
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
) as mock_db_manager:
# Set up the mock database manager for agent mode
mock_db_client = AsyncMock()
@@ -1,6 +1,6 @@
"""Tests for OrchestratorBlock compatibility with the OpenAI Responses API.
"""Tests for SmartDecisionMakerBlock compatibility with the OpenAI Responses API.

The OrchestratorBlock manages conversation history in the Chat Completions
The SmartDecisionMakerBlock manages conversation history in the Chat Completions
format, but OpenAI models now use the Responses API which has a fundamentally
different conversation structure. These tests document:

@@ -27,8 +27,8 @@ from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.blocks.orchestrator import (
OrchestratorBlock,
from backend.blocks.smart_decision_maker import (
SmartDecisionMakerBlock,
_combine_tool_responses,
_convert_raw_response_to_dict,
_create_tool_response,
@@ -733,7 +733,7 @@ class TestUpdateConversation:

def test_dict_raw_response_no_reasoning_no_tools(self):
"""Dict raw_response, no reasoning → appends assistant dict."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
prompt: list[dict] = []
resp = self._make_response({"role": "assistant", "content": "hi"})
block._update_conversation(prompt, resp)
@@ -741,7 +741,7 @@ class TestUpdateConversation:

def test_dict_raw_response_with_reasoning_no_tool_calls(self):
"""Reasoning present, no tool calls → reasoning prepended."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
prompt: list[dict] = []
resp = self._make_response(
{"role": "assistant", "content": "answer"},
@@ -757,7 +757,7 @@ class TestUpdateConversation:

def test_dict_raw_response_with_reasoning_and_anthropic_tool_calls(self):
"""Reasoning + Anthropic tool_use in content → reasoning skipped."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
prompt: list[dict] = []
raw = {
"role": "assistant",
@@ -772,7 +772,7 @@ class TestUpdateConversation:

def test_with_tool_outputs(self):
"""Tool outputs → extended onto prompt."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
prompt: list[dict] = []
resp = self._make_response({"role": "assistant", "content": None})
outputs = [{"role": "tool", "tool_call_id": "call_1", "content": "r"}]
@@ -782,7 +782,7 @@ class TestUpdateConversation:

def test_without_tool_outputs(self):
"""No tool outputs → only assistant message appended."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
prompt: list[dict] = []
resp = self._make_response({"role": "assistant", "content": "done"})
block._update_conversation(prompt, resp, None)
@@ -790,7 +790,7 @@ class TestUpdateConversation:

def test_string_raw_response(self):
"""Ollama string → wrapped as assistant dict."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
prompt: list[dict] = []
resp = self._make_response("hello from ollama")
block._update_conversation(prompt, resp)
@@ -800,7 +800,7 @@ class TestUpdateConversation:

def test_responses_api_text_response_produces_valid_items(self):
"""Responses API text response → conversation items must have valid role."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
prompt: list[dict] = [
{"role": "system", "content": "sys"},
{"role": "user", "content": "user"},
@@ -820,7 +820,7 @@ class TestUpdateConversation:

def test_responses_api_function_call_produces_valid_items(self):
"""Responses API function_call → conversation items must have valid type."""
block = OrchestratorBlock()
block = SmartDecisionMakerBlock()
prompt: list[dict] = []
resp = self._make_response(
_MockResponse(output=[_MockFunctionCall("tool", "{}", call_id="call_1")])
@@ -856,7 +856,7 @@ async def test_agent_mode_conversation_valid_for_responses_api():
"""
import backend.blocks.llm as llm_module

block = OrchestratorBlock()
block = SmartDecisionMakerBlock()

# First response: tool call
mock_tc = MagicMock()
@@ -936,7 +936,7 @@ async def test_agent_mode_conversation_valid_for_responses_api():
with patch("backend.blocks.llm.llm_call", llm_mock), patch.object(
block, "_create_tool_node_signatures", return_value=tool_sigs
), patch(
"backend.blocks.orchestrator.get_database_manager_async_client",
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db,
), patch(
"backend.executor.manager.async_update_node_execution_status",
@@ -945,7 +945,7 @@ async def test_agent_mode_conversation_valid_for_responses_api():
"backend.integrations.creds_manager.IntegrationCredentialsManager"
):

inp = OrchestratorBlock.Input(
inp = SmartDecisionMakerBlock.Input(
prompt="Improve this",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
@@ -992,7 +992,7 @@ async def test_traditional_mode_conversation_valid_for_responses_api():
"""Traditional mode: the yielded conversation must contain only valid items."""
import backend.blocks.llm as llm_module

block = OrchestratorBlock()
block = SmartDecisionMakerBlock()

mock_tc = MagicMock()
mock_tc.function.name = "my_tool"
@@ -1028,7 +1028,7 @@ async def test_traditional_mode_conversation_valid_for_responses_api():
"backend.blocks.llm.llm_call", new_callable=AsyncMock, return_value=resp
), patch.object(block, "_create_tool_node_signatures", return_value=tool_sigs):

inp = OrchestratorBlock.Input(
inp = SmartDecisionMakerBlock.Input(
prompt="Do it",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
@@ -17,9 +17,6 @@ from backend.util.workspace import WorkspaceManager
if TYPE_CHECKING:
from e2b import AsyncSandbox

from backend.copilot.permissions import CopilotPermissions


# Allowed base directory for the Read tool. Public so service.py can use it
# for sweep operations without depending on a private implementation detail.
# Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's
@@ -46,12 +43,6 @@ _current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
)
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")

# Current execution's capability filter. None means "no restrictions".
# Set by set_execution_context(); read by run_block and service.py.
_current_permissions: "ContextVar[CopilotPermissions | None]" = ContextVar(
"_current_permissions", default=None
)


def encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does.
@@ -72,7 +63,6 @@ def set_execution_context(
session: ChatSession,
sandbox: "AsyncSandbox | None" = None,
sdk_cwd: str | None = None,
permissions: "CopilotPermissions | None" = None,
) -> None:
"""Set per-turn context variables used by file-resolution tool handlers."""
_current_user_id.set(user_id)
@@ -80,7 +70,6 @@ def set_execution_context(
_current_sandbox.set(sandbox)
_current_sdk_cwd.set(sdk_cwd or "")
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
_current_permissions.set(permissions)


def get_execution_context() -> tuple[str | None, ChatSession | None]:
@@ -88,11 +77,6 @@ def get_execution_context() -> tuple[str | None, ChatSession | None]:
return _current_user_id.get(), _current_session.get()


def get_current_permissions() -> "CopilotPermissions | None":
"""Return the capability filter for the current execution, or None if unrestricted."""
return _current_permissions.get()


def get_current_sandbox() -> "AsyncSandbox | None":
"""Return the E2B sandbox for the current session, or None if not active."""
return _current_sandbox.get()
@@ -104,32 +88,17 @@ def get_sdk_cwd() -> str:


E2B_WORKDIR = "/home/user"
E2B_ALLOWED_DIRS: tuple[str, ...] = (E2B_WORKDIR, "/tmp")
E2B_ALLOWED_DIRS_STR: str = " or ".join(E2B_ALLOWED_DIRS)


def is_within_allowed_dirs(path: str) -> bool:
"""Return True if *path* is within one of the allowed sandbox directories."""
for allowed in E2B_ALLOWED_DIRS:
if path == allowed or path.startswith(allowed + "/"):
return True
return False


def resolve_sandbox_path(path: str) -> str:
"""Normalise *path* to an absolute sandbox path under an allowed directory.

Allowed directories: ``/home/user`` and ``/tmp``.
Relative paths are resolved against ``/home/user``.
"""Normalise *path* to an absolute sandbox path under ``/home/user``.

Raises :class:`ValueError` if the resolved path escapes the sandbox.
"""
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
normalized = os.path.normpath(candidate)
if not is_within_allowed_dirs(normalized):
raise ValueError(
f"Path must be within {E2B_ALLOWED_DIRS_STR}: {os.path.basename(path)}"
)
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
return normalized


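The hunk above narrows `resolve_sandbox_path` from two allowed roots (`/home/user` and `/tmp`) back to `/home/user` only. A minimal standalone sketch of the retained check, reproduced from the new lines of the hunk so it can be run outside the module:

```python
import os

E2B_WORKDIR = "/home/user"

def resolve_sandbox_path(path: str) -> str:
    # Relative paths resolve against the sandbox workdir; normpath collapses
    # ".." segments *before* the prefix check, so traversal cannot escape.
    candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
    normalized = os.path.normpath(candidate)
    if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
        raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
    return normalized

assert resolve_sandbox_path("notes.txt") == "/home/user/notes.txt"
# "/home/user/../../etc/passwd" normalizes to "/etc/passwd" and is rejected.
```

Note the two-part condition: comparing against `E2B_WORKDIR + "/"` (not a bare prefix) avoids a prefix collision with sibling paths such as `/home/username`.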
@@ -11,7 +11,6 @@ import pytest
from backend.copilot.context import (
SDK_PROJECTS_DIR,
_current_project_dir,
get_current_permissions,
get_current_sandbox,
get_execution_context,
get_sdk_cwd,
@@ -19,7 +18,6 @@ from backend.copilot.context import (
resolve_sandbox_path,
set_execution_context,
)
from backend.copilot.permissions import CopilotPermissions


def _make_session() -> MagicMock:
@@ -63,19 +61,6 @@ def test_get_current_sandbox_returns_set_value():
assert get_current_sandbox() is mock_sandbox


def test_set_and_get_current_permissions():
"""set_execution_context stores permissions; get_current_permissions returns it."""
perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
set_execution_context("u1", _make_session(), permissions=perms)
assert get_current_permissions() is perms


def test_get_current_permissions_defaults_to_none():
"""get_current_permissions returns None when no permissions have been set."""
set_execution_context("u1", _make_session())
assert get_current_permissions() is None


def test_get_sdk_cwd_empty_when_not_set():
"""get_sdk_cwd returns empty string when sdk_cwd is not set."""
set_execution_context("u1", _make_session(), sdk_cwd=None)
@@ -198,32 +183,10 @@ def test_resolve_sandbox_path_normalizes_dots():


def test_resolve_sandbox_path_escape_raises():
with pytest.raises(ValueError, match="must be within"):
with pytest.raises(ValueError, match="/home/user"):
resolve_sandbox_path("/home/user/../../etc/passwd")


def test_resolve_sandbox_path_absolute_outside_raises():
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="/home/user"):
resolve_sandbox_path("/etc/passwd")


def test_resolve_sandbox_path_tmp_allowed():
assert resolve_sandbox_path("/tmp/data.txt") == "/tmp/data.txt"


def test_resolve_sandbox_path_tmp_nested():
assert resolve_sandbox_path("/tmp/a/b/c.txt") == "/tmp/a/b/c.txt"


def test_resolve_sandbox_path_tmp_itself():
assert resolve_sandbox_path("/tmp") == "/tmp"


def test_resolve_sandbox_path_tmp_escape_raises():
with pytest.raises(ValueError):
resolve_sandbox_path("/tmp/../etc/passwd")


def test_resolve_sandbox_path_tmp_prefix_collision_raises():
with pytest.raises(ValueError):
resolve_sandbox_path("/tmp_evil/malicious.txt")

@@ -14,7 +14,7 @@ import time
from backend.copilot import stream_registry
from backend.copilot.baseline import stream_chat_completion_baseline
from backend.copilot.config import ChatConfig
from backend.copilot.response_model import StreamError
from backend.copilot.response_model import StreamFinish
from backend.copilot.sdk import service as sdk_service
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
from backend.executor.cluster_lock import ClusterLock
@@ -23,7 +23,6 @@ from backend.util.feature_flag import Flag, is_feature_enabled
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.process import set_service_name
from backend.util.retry import func_retry
from backend.util.workspace_storage import shutdown_workspace_storage

from .utils import CoPilotExecutionEntry, CoPilotLogMetadata

@@ -154,6 +153,8 @@ class CoPilotProcessor:
worker's event loop, ensuring ``aiohttp.ClientSession.close()``
runs on the same loop that created the session.
"""
from backend.util.workspace_storage import shutdown_workspace_storage

coro = shutdown_workspace_storage()
try:
future = asyncio.run_coroutine_threadsafe(coro, self.execution_loop)
@@ -267,37 +268,35 @@ class CoPilotProcessor:
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")

# Stream chat completion and publish chunks to Redis.
# stream_and_publish wraps the raw stream with registry
# publishing (shared with collect_copilot_response).
raw_stream = stream_fn(
async for chunk in stream_fn(
session_id=entry.session_id,
message=entry.message if entry.message else None,
is_user_message=entry.is_user_message,
user_id=entry.user_id,
context=entry.context,
file_ids=entry.file_ids,
)
async for chunk in stream_registry.stream_and_publish(
session_id=entry.session_id,
turn_id=entry.turn_id,
stream=raw_stream,
):
if cancel.is_set():
log.info("Cancel requested, breaking stream")
break

# Capture StreamError so mark_session_completed receives
# the error message (stream_and_publish yields but does
# not publish StreamError — that's done by mark_session_completed).
if isinstance(chunk, StreamError):
error_msg = chunk.errorText
break

current_time = time.monotonic()
if current_time - last_refresh >= refresh_interval:
cluster_lock.refresh()
last_refresh = current_time

# Skip StreamFinish — mark_session_completed publishes it.
if isinstance(chunk, StreamFinish):
continue

try:
await stream_registry.publish_chunk(entry.turn_id, chunk)
except Exception as e:
log.error(
f"Error publishing chunk {type(chunk).__name__}: {e}",
exc_info=True,
)

# Stream loop completed
if cancel.is_set():
log.info("Stream cancelled by user")

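The hunk above reverts the processor from consuming `stream_registry.stream_and_publish(...)` (a wrapper that republishes each chunk) back to iterating `stream_fn(...)` directly and publishing inline. A minimal, self-contained sketch of the wrapper pattern being removed — the function names here are stand-ins for illustration, not the real `stream_registry` API:

```python
import asyncio
from typing import AsyncIterator

# Hypothetical stand-in for stream_registry.stream_and_publish: wrap the raw
# stream in a generator that publishes each chunk, then re-yields it, so the
# consumer loop and the publishing logic live in one shared place.
async def stream_and_publish(stream: AsyncIterator[str]) -> AsyncIterator[str]:
    async for chunk in stream:
        # (publish chunk to the registry here)
        yield chunk

async def raw_stream() -> AsyncIterator[str]:
    for part in ("hello", "world"):
        yield part

async def main() -> None:
    async for chunk in stream_and_publish(raw_stream()):
        print(chunk)

asyncio.run(main())
```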
@@ -1,430 +0,0 @@
"""Copilot execution permissions — tool and block allow/deny filtering.

:class:`CopilotPermissions` is the single model used everywhere:

- ``AutoPilotBlock`` reads four block-input fields and builds one instance.
- ``stream_chat_completion_sdk`` applies it when constructing
``ClaudeAgentOptions.allowed_tools`` / ``disallowed_tools``.
- ``run_block`` reads it from the contextvar to gate block execution.
- Recursive (sub-agent) invocations merge parent and child so children
can only be *more* restrictive, never more permissive.

Tool names
----------
Users specify the **short name** as it appears in ``TOOL_REGISTRY`` (e.g.
``run_block``, ``web_fetch``) or as an SDK built-in (e.g. ``Read``,
``Task``, ``WebSearch``). Internally these are mapped to the full SDK
format (``mcp__copilot__run_block``, ``Read``, …) by
:func:`apply_tool_permissions`.

Block identifiers
-----------------
Each entry in ``blocks`` may be one of:

- A **full UUID** (``c069dc6b-c3ed-4c12-b6e5-d47361e64ce6``)
- A **partial UUID** — the first 8-character hex segment (``c069dc6b``)
- A **block name** (case-insensitive, e.g. ``"HTTP Request"``)

:func:`validate_block_identifiers` resolves all entries against the live
block registry and returns any that could not be matched.

Semantics
---------
``tools_exclude=True`` (default) — ``tools`` is a **blacklist**; listed
tools are denied and everything else is allowed. An empty list means
"allow all" (no filtering).

``tools_exclude=False`` — ``tools`` is a **whitelist**; only listed tools
are allowed.

``blocks_exclude`` follows the same pattern for ``blocks``.

Recursion inheritance
---------------------
:meth:`CopilotPermissions.merged_with_parent` produces a new instance that
is at most as permissive as the parent:

- Tools: effective-allowed sets are intersected then stored as a whitelist.
- Blocks: the parent is stored in ``_parent`` and consulted during every
:meth:`is_block_allowed` call so both constraints must pass.
"""

from __future__ import annotations

import re
from typing import Literal, get_args

from pydantic import BaseModel, PrivateAttr

# ---------------------------------------------------------------------------
# Constants — single source of truth for all accepted tool names
# ---------------------------------------------------------------------------

# Literal type combining all valid tool names — used by AutoPilotBlock.Input
# so the frontend renders a multi-select dropdown.
# This is the SINGLE SOURCE OF TRUTH. All other name sets are derived from it.
ToolName = Literal[
# Platform tools (must match keys in TOOL_REGISTRY)
"add_understanding",
"bash_exec",
"browser_act",
"browser_navigate",
"browser_screenshot",
"connect_integration",
"continue_run_block",
"create_agent",
"create_feature_request",
"create_folder",
"customize_agent",
"delete_folder",
"delete_workspace_file",
"edit_agent",
"find_agent",
"find_block",
"find_library_agent",
"fix_agent_graph",
"get_agent_building_guide",
"get_doc_page",
"get_mcp_guide",
"list_folders",
"list_workspace_files",
"move_agents_to_folder",
"move_folder",
"read_workspace_file",
"run_agent",
"run_block",
"run_mcp_tool",
"search_docs",
"search_feature_requests",
"update_folder",
"validate_agent_graph",
"view_agent_output",
"web_fetch",
"write_workspace_file",
# SDK built-ins
"Edit",
"Glob",
"Grep",
"Read",
"Task",
"TodoWrite",
"WebSearch",
"Write",
]

# Frozen set of all valid tool names — derived from the Literal.
ALL_TOOL_NAMES: frozenset[str] = frozenset(get_args(ToolName))

# SDK built-in tool names — uppercase-initial names are SDK built-ins.
SDK_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
n for n in ALL_TOOL_NAMES if n[0].isupper()
)

# Platform tool names — everything that isn't an SDK built-in.
PLATFORM_TOOL_NAMES: frozenset[str] = ALL_TOOL_NAMES - SDK_BUILTIN_TOOL_NAMES

# Compiled regex patterns for block identifier classification.
_FULL_UUID_RE = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
re.IGNORECASE,
)
_PARTIAL_UUID_RE = re.compile(r"^[0-9a-f]{8}$", re.IGNORECASE)


# ---------------------------------------------------------------------------
# Helper — block identifier matching
# ---------------------------------------------------------------------------


def _block_matches(identifier: str, block_id: str, block_name: str) -> bool:
"""Return True if *identifier* resolves to the given block.

Resolution order:
1. Full UUID — exact case-insensitive match against *block_id*.
2. Partial UUID (8 hex chars, first segment) — prefix match.
3. Name — case-insensitive equality against *block_name*.
"""
ident = identifier.strip()
if _FULL_UUID_RE.match(ident):
return ident.lower() == block_id.lower()
if _PARTIAL_UUID_RE.match(ident):
return block_id.lower().startswith(ident.lower())
return ident.lower() == block_name.lower()


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------


class CopilotPermissions(BaseModel):
"""Capability filter for a single copilot execution.

Attributes:
tools: Tool names to filter (short names, e.g. ``run_block``).
tools_exclude: When True (default) ``tools`` is a blacklist;
when False it is a whitelist. Ignored when *tools* is empty.
blocks: Block identifiers (name, full UUID, or 8-char partial UUID).
blocks_exclude: Same semantics as *tools_exclude* but for blocks.
"""

tools: list[str] = []
tools_exclude: bool = True
blocks: list[str] = []
blocks_exclude: bool = True

# Private: parent permissions for recursion inheritance.
# Set only by merged_with_parent(); never exposed in block input schema.
_parent: CopilotPermissions | None = PrivateAttr(default=None)

# ------------------------------------------------------------------
# Tool helpers
# ------------------------------------------------------------------

def effective_allowed_tools(self, all_tools: frozenset[str]) -> frozenset[str]:
"""Compute the set of short tool names that are permitted.

Args:
all_tools: Universe of valid short tool names.

Returns:
Subset of *all_tools* that pass the filter.
"""
if not self.tools:
return frozenset(all_tools)
tool_set = frozenset(self.tools)
if self.tools_exclude:
return all_tools - tool_set
return all_tools & tool_set

# ------------------------------------------------------------------
# Block helpers
# ------------------------------------------------------------------

def is_block_allowed(self, block_id: str, block_name: str) -> bool:
"""Return True if the block may be executed under these permissions.

Checks this instance first, then consults the parent (if any) so
the entire inheritance chain is respected.
"""
if not self._check_block_locally(block_id, block_name):
return False
if self._parent is not None:
return self._parent.is_block_allowed(block_id, block_name)
return True

def _check_block_locally(self, block_id: str, block_name: str) -> bool:
"""Check *only* this instance's block filter (ignores parent)."""
if not self.blocks:
return True  # No filter → allow all
matched = any(
_block_matches(identifier, block_id, block_name)
for identifier in self.blocks
)
return not matched if self.blocks_exclude else matched

# ------------------------------------------------------------------
# Recursion / merging
# ------------------------------------------------------------------

def merged_with_parent(
self,
parent: CopilotPermissions,
all_tools: frozenset[str],
) -> CopilotPermissions:
"""Return a new instance that is at most as permissive as *parent*.

- Tools: intersection of effective-allowed sets, stored as a whitelist.
- Blocks: parent is stored internally; both constraints are applied
during :meth:`is_block_allowed`.
"""
merged_tools = self.effective_allowed_tools(
all_tools
) & parent.effective_allowed_tools(all_tools)
result = CopilotPermissions(
tools=sorted(merged_tools),
tools_exclude=False,
blocks=self.blocks,
blocks_exclude=self.blocks_exclude,
)
result._parent = parent
return result

# ------------------------------------------------------------------
# Convenience
# ------------------------------------------------------------------

def is_empty(self) -> bool:
"""Return True when no filtering is configured (allow-all passthrough)."""
return not self.tools and not self.blocks and self._parent is None


# ---------------------------------------------------------------------------
# Validation helpers
# ---------------------------------------------------------------------------


def all_known_tool_names() -> frozenset[str]:
"""Return all short tool names accepted in *tools*.

Returns the pre-computed ``ALL_TOOL_NAMES`` set (derived from the
``ToolName`` Literal). On first call, also verifies consistency with
the live ``TOOL_REGISTRY``.
"""
_assert_tool_names_consistent()
return ALL_TOOL_NAMES


def validate_tool_names(tools: list[str]) -> list[str]:
"""Return entries in *tools* that are not valid tool names.

Args:
tools: List of short tool name strings to validate.

Returns:
List of invalid names (empty if all are valid).
"""
return [t for t in tools if t not in ALL_TOOL_NAMES]


_tool_names_checked = False


def _assert_tool_names_consistent() -> None:
"""Verify that ``PLATFORM_TOOL_NAMES`` matches ``TOOL_REGISTRY`` keys.

Called once lazily (TOOL_REGISTRY has heavy imports). Raises
``AssertionError`` with a helpful diff if they diverge.
"""
global _tool_names_checked
if _tool_names_checked:
return
_tool_names_checked = True

from backend.copilot.tools import TOOL_REGISTRY

registry_keys: frozenset[str] = frozenset(TOOL_REGISTRY.keys())
declared: frozenset[str] = PLATFORM_TOOL_NAMES
if registry_keys != declared:
missing = registry_keys - declared
extra = declared - registry_keys
parts: list[str] = [
"PLATFORM_TOOL_NAMES in permissions.py is out of sync with TOOL_REGISTRY."
]
if missing:
parts.append(f"  Missing from PLATFORM_TOOL_NAMES: {sorted(missing)}")
if extra:
parts.append(f"  Extra in PLATFORM_TOOL_NAMES: {sorted(extra)}")
parts.append("  Update the ToolName Literal to match.")
raise AssertionError("\n".join(parts))


async def validate_block_identifiers(
identifiers: list[str],
) -> list[str]:
"""Resolve each block identifier and return those that could not be matched.

Args:
identifiers: List of block identifiers (name, full UUID, or partial UUID).

Returns:
List of identifiers that matched no known block.
"""
from backend.blocks import get_blocks

# get_blocks() returns dict[block_id_str, BlockClass]; instantiate once to get names.
block_registry = get_blocks()
block_info = {bid: cls().name for bid, cls in block_registry.items()}
invalid: list[str] = []
for ident in identifiers:
matched = any(
_block_matches(ident, bid, bname) for bid, bname in block_info.items()
)
if not matched:
invalid.append(ident)
return invalid


# ---------------------------------------------------------------------------
# SDK tool-list application
# ---------------------------------------------------------------------------


def apply_tool_permissions(
permissions: CopilotPermissions,
*,
use_e2b: bool = False,
) -> tuple[list[str], list[str]]:
"""Compute (allowed_tools, extra_disallowed) for :class:`ClaudeAgentOptions`.

Takes the base allowed/disallowed lists from
:func:`~backend.copilot.sdk.tool_adapter.get_copilot_tool_names` /
:func:`~backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools` and
applies *permissions* on top.

Returns:
``(allowed_tools, extra_disallowed)`` where *allowed_tools* is the
possibly-narrowed list to pass to ``ClaudeAgentOptions.allowed_tools``
and *extra_disallowed* is the list to pass to
``ClaudeAgentOptions.disallowed_tools``.
"""
from backend.copilot.sdk.tool_adapter import (
_READ_TOOL_NAME,
MCP_TOOL_PREFIX,
get_copilot_tool_names,
get_sdk_disallowed_tools,
)
from backend.copilot.tools import TOOL_REGISTRY

base_allowed = get_copilot_tool_names(use_e2b=use_e2b)
base_disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)

if permissions.is_empty():
return base_allowed, base_disallowed

all_tools = all_known_tool_names()
effective = permissions.effective_allowed_tools(all_tools)

# In E2B mode, SDK built-in file tools (Read, Write, Edit, Glob, Grep)
# are replaced by MCP equivalents (read_file, write_file, ...).
# Map each SDK built-in name to its E2B MCP name so users can use the
# familiar names in their permissions and the E2B tools are included.
_SDK_TO_E2B: dict[str, str] = {}
if use_e2b:
from backend.copilot.sdk.e2b_file_tools import E2B_FILE_TOOL_NAMES

_SDK_TO_E2B = dict(
zip(
["Read", "Write", "Edit", "Glob", "Grep"],
E2B_FILE_TOOL_NAMES,
strict=False,
)
)

# Build an updated allowed list by mapping short names → SDK names and
# keeping only those present in the original base_allowed list.
def to_sdk_names(short: str) -> list[str]:
names: list[str] = []
if short in TOOL_REGISTRY:
names.append(f"{MCP_TOOL_PREFIX}{short}")
elif short in _SDK_TO_E2B:
# E2B mode: map SDK built-in file tool to its MCP equivalent.
names.append(f"{MCP_TOOL_PREFIX}{_SDK_TO_E2B[short]}")
else:
names.append(short)  # SDK built-in — used as-is
return names

# short names permitted by permissions
permitted_sdk: set[str] = set()
for s in effective:
permitted_sdk.update(to_sdk_names(s))
# Always include the internal Read tool (used by SDK for large/truncated outputs)
permitted_sdk.add(f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}")

filtered_allowed = [t for t in base_allowed if t in permitted_sdk]

# Extra disallowed = tools that were in base_allowed but are now removed
removed = set(base_allowed) - set(filtered_allowed)
extra_disallowed = list(set(base_disallowed) | removed)

return filtered_allowed, extra_disallowed
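For reference while reading the deletion above: the blacklist/whitelist and merge semantics the module documented can be exercised in a few lines. A sketch using the removed API exactly as it appears in the hunk (this only runs against a checkout that still contains `backend/copilot/permissions.py`; the tool universe below is an arbitrary example):

```python
from backend.copilot.permissions import CopilotPermissions

ALL = frozenset({"run_block", "web_fetch", "bash_exec", "Task"})

# Default (tools_exclude=True): the list is a blacklist.
deny_bash = CopilotPermissions(tools=["bash_exec"])
assert deny_bash.effective_allowed_tools(ALL) == ALL - {"bash_exec"}

# tools_exclude=False flips it to a whitelist.
only_run = CopilotPermissions(tools=["run_block"], tools_exclude=False)
assert only_run.effective_allowed_tools(ALL) == {"run_block"}

# A child merged with a parent can only get more restrictive: the
# effective-allowed sets are intersected and stored as a whitelist.
child = CopilotPermissions(tools=["run_block", "web_fetch"], tools_exclude=False)
merged = child.merged_with_parent(deny_bash, ALL)
assert merged.effective_allowed_tools(ALL) == {"run_block", "web_fetch"}
```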
@@ -1,579 +0,0 @@
"""Tests for CopilotPermissions — tool/block capability filtering."""

from __future__ import annotations

import pytest

from backend.copilot.permissions import (
    ALL_TOOL_NAMES,
    PLATFORM_TOOL_NAMES,
    SDK_BUILTIN_TOOL_NAMES,
    CopilotPermissions,
    _block_matches,
    all_known_tool_names,
    apply_tool_permissions,
    validate_block_identifiers,
    validate_tool_names,
)
from backend.copilot.tools import TOOL_REGISTRY

# ---------------------------------------------------------------------------
# _block_matches
# ---------------------------------------------------------------------------


class TestBlockMatches:
    BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
    BLOCK_NAME = "HTTP Request"

    def test_full_uuid_match(self):
        assert _block_matches(self.BLOCK_ID, self.BLOCK_ID, self.BLOCK_NAME)

    def test_full_uuid_case_insensitive(self):
        assert _block_matches(self.BLOCK_ID.upper(), self.BLOCK_ID, self.BLOCK_NAME)

    def test_full_uuid_no_match(self):
        other = "aaaaaaaa-0000-0000-0000-000000000000"
        assert not _block_matches(other, self.BLOCK_ID, self.BLOCK_NAME)

    def test_partial_uuid_match(self):
        assert _block_matches("c069dc6b", self.BLOCK_ID, self.BLOCK_NAME)

    def test_partial_uuid_case_insensitive(self):
        assert _block_matches("C069DC6B", self.BLOCK_ID, self.BLOCK_NAME)

    def test_partial_uuid_no_match(self):
        assert not _block_matches("deadbeef", self.BLOCK_ID, self.BLOCK_NAME)

    def test_name_match(self):
        assert _block_matches("HTTP Request", self.BLOCK_ID, self.BLOCK_NAME)

    def test_name_case_insensitive(self):
        assert _block_matches("http request", self.BLOCK_ID, self.BLOCK_NAME)
        assert _block_matches("HTTP REQUEST", self.BLOCK_ID, self.BLOCK_NAME)

    def test_name_no_match(self):
        assert not _block_matches("Unknown Block", self.BLOCK_ID, self.BLOCK_NAME)

    def test_partial_uuid_not_matching_as_name(self):
        # "c069dc6b" is 8 hex chars → treated as partial UUID, NOT name match
        assert not _block_matches(
            "c069dc6b", "ffffffff-0000-0000-0000-000000000000", "c069dc6b"
        )


# ---------------------------------------------------------------------------
# CopilotPermissions.effective_allowed_tools
# ---------------------------------------------------------------------------


ALL_TOOLS = frozenset(
    ["run_block", "web_fetch", "bash_exec", "find_agent", "Task", "Read"]
)


class TestEffectiveAllowedTools:
    def test_empty_list_allows_all(self):
        perms = CopilotPermissions(tools=[], tools_exclude=True)
        assert perms.effective_allowed_tools(ALL_TOOLS) == ALL_TOOLS

    def test_empty_whitelist_allows_all(self):
        # edge: tools_exclude=False but empty list → allow all
        perms = CopilotPermissions(tools=[], tools_exclude=False)
        assert perms.effective_allowed_tools(ALL_TOOLS) == ALL_TOOLS

    def test_blacklist_removes_listed(self):
        perms = CopilotPermissions(tools=["bash_exec", "web_fetch"], tools_exclude=True)
        result = perms.effective_allowed_tools(ALL_TOOLS)
        assert "bash_exec" not in result
        assert "web_fetch" not in result
        assert "run_block" in result
        assert "Task" in result

    def test_whitelist_keeps_only_listed(self):
        perms = CopilotPermissions(tools=["run_block", "Task"], tools_exclude=False)
        result = perms.effective_allowed_tools(ALL_TOOLS)
        assert result == frozenset(["run_block", "Task"])

    def test_whitelist_unknown_tool_yields_empty(self):
        perms = CopilotPermissions(tools=["nonexistent"], tools_exclude=False)
        result = perms.effective_allowed_tools(ALL_TOOLS)
        assert result == frozenset()

    def test_blacklist_unknown_tool_ignored(self):
        perms = CopilotPermissions(tools=["nonexistent"], tools_exclude=True)
        result = perms.effective_allowed_tools(ALL_TOOLS)
        assert result == ALL_TOOLS


# ---------------------------------------------------------------------------
# CopilotPermissions.is_block_allowed
# ---------------------------------------------------------------------------


BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
BLOCK_NAME = "HTTP Request"


class TestIsBlockAllowed:
    def test_empty_allows_everything(self):
        perms = CopilotPermissions(blocks=[], blocks_exclude=True)
        assert perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)

    def test_blacklist_blocks_listed(self):
        perms = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
        assert not perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)

    def test_blacklist_allows_unlisted(self):
        perms = CopilotPermissions(blocks=["Other Block"], blocks_exclude=True)
        assert perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)

    def test_whitelist_allows_listed(self):
        perms = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=False)
        assert perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)

    def test_whitelist_blocks_unlisted(self):
        perms = CopilotPermissions(blocks=["Other Block"], blocks_exclude=False)
        assert not perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)

    def test_partial_uuid_blacklist(self):
        perms = CopilotPermissions(blocks=["c069dc6b"], blocks_exclude=True)
        assert not perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)

    def test_full_uuid_whitelist(self):
        perms = CopilotPermissions(blocks=[BLOCK_ID], blocks_exclude=False)
        assert perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)

    def test_parent_blocks_when_child_allows(self):
        parent = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
        child = CopilotPermissions(blocks=[], blocks_exclude=True)
        child._parent = parent
        assert not child.is_block_allowed(BLOCK_ID, BLOCK_NAME)

    def test_parent_allows_when_child_blocks(self):
        parent = CopilotPermissions(blocks=[], blocks_exclude=True)
        child = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
        child._parent = parent
        assert not child.is_block_allowed(BLOCK_ID, BLOCK_NAME)

    def test_both_must_allow(self):
        parent = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=False)
        child = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=False)
        child._parent = parent
        assert child.is_block_allowed(BLOCK_ID, BLOCK_NAME)

    def test_grandparent_blocks_propagate(self):
        grandparent = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
        parent = CopilotPermissions(blocks=[], blocks_exclude=True)
        parent._parent = grandparent
        child = CopilotPermissions(blocks=[], blocks_exclude=True)
        child._parent = parent
        assert not child.is_block_allowed(BLOCK_ID, BLOCK_NAME)


# ---------------------------------------------------------------------------
# CopilotPermissions.merged_with_parent
# ---------------------------------------------------------------------------


class TestMergedWithParent:
    def test_tool_intersection(self):
        all_t = frozenset(["run_block", "web_fetch", "bash_exec"])
        parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
        child = CopilotPermissions(tools=["web_fetch"], tools_exclude=True)
        merged = child.merged_with_parent(parent, all_t)
        effective = merged.effective_allowed_tools(all_t)
        assert "bash_exec" not in effective
        assert "web_fetch" not in effective
        assert "run_block" in effective

    def test_parent_whitelist_narrows_child(self):
        all_t = frozenset(["run_block", "web_fetch", "bash_exec"])
        parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)
        child = CopilotPermissions(tools=[], tools_exclude=True)  # allow all
        merged = child.merged_with_parent(parent, all_t)
        effective = merged.effective_allowed_tools(all_t)
        assert effective == frozenset(["run_block"])

    def test_child_cannot_expand_parent_whitelist(self):
        all_t = frozenset(["run_block", "web_fetch", "bash_exec"])
        parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)
        child = CopilotPermissions(
            tools=["run_block", "bash_exec"], tools_exclude=False
        )
        merged = child.merged_with_parent(parent, all_t)
        effective = merged.effective_allowed_tools(all_t)
        # bash_exec was not in parent's whitelist → must not appear
        assert "bash_exec" not in effective
        assert "run_block" in effective

    def test_merged_stored_as_whitelist(self):
        all_t = frozenset(["run_block", "web_fetch"])
        parent = CopilotPermissions(tools=[], tools_exclude=True)
        child = CopilotPermissions(tools=[], tools_exclude=True)
        merged = child.merged_with_parent(parent, all_t)
        assert not merged.tools_exclude  # stored as whitelist
        assert set(merged.tools) == {"run_block", "web_fetch"}

    def test_block_parent_stored(self):
        all_t = frozenset(["run_block"])
        parent = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
        child = CopilotPermissions(blocks=[], blocks_exclude=True)
        merged = child.merged_with_parent(parent, all_t)
        # Parent restriction is preserved via _parent
        assert not merged.is_block_allowed(BLOCK_ID, BLOCK_NAME)


# ---------------------------------------------------------------------------
# CopilotPermissions.is_empty
# ---------------------------------------------------------------------------


class TestIsEmpty:
    def test_default_is_empty(self):
        assert CopilotPermissions().is_empty()

    def test_with_tools_not_empty(self):
        assert not CopilotPermissions(tools=["bash_exec"]).is_empty()

    def test_with_blocks_not_empty(self):
        assert not CopilotPermissions(blocks=["HTTP Request"]).is_empty()

    def test_with_parent_not_empty(self):
        perms = CopilotPermissions()
        perms._parent = CopilotPermissions(tools=["bash_exec"])
        assert not perms.is_empty()


# ---------------------------------------------------------------------------
# validate_tool_names
# ---------------------------------------------------------------------------


class TestValidateToolNames:
    def test_valid_registry_tool(self):
        assert validate_tool_names(["run_block", "web_fetch"]) == []

    def test_valid_sdk_builtin(self):
        assert validate_tool_names(["Read", "Task", "WebSearch"]) == []

    def test_invalid_tool(self):
        result = validate_tool_names(["nonexistent_tool"])
        assert "nonexistent_tool" in result

    def test_mixed(self):
        result = validate_tool_names(["run_block", "fake_tool"])
        assert "fake_tool" in result
        assert "run_block" not in result

    def test_empty_list(self):
        assert validate_tool_names([]) == []


# ---------------------------------------------------------------------------
# validate_block_identifiers (async)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
class TestValidateBlockIdentifiers:
    async def test_empty_list(self):
        result = await validate_block_identifiers([])
        assert result == []

    async def test_valid_full_uuid(self, mocker):
        mock_block = mocker.MagicMock()
        mock_block.return_value.name = "HTTP Request"
        mocker.patch(
            "backend.blocks.get_blocks",
            return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block},
        )
        result = await validate_block_identifiers(
            ["c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"]
        )
        assert result == []

    async def test_invalid_identifier(self, mocker):
        mock_block = mocker.MagicMock()
        mock_block.return_value.name = "HTTP Request"
        mocker.patch(
            "backend.blocks.get_blocks",
            return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block},
        )
        result = await validate_block_identifiers(["totally_unknown"])
        assert "totally_unknown" in result

    async def test_partial_uuid_match(self, mocker):
        mock_block = mocker.MagicMock()
        mock_block.return_value.name = "HTTP Request"
        mocker.patch(
            "backend.blocks.get_blocks",
            return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block},
        )
        result = await validate_block_identifiers(["c069dc6b"])
        assert result == []

    async def test_name_match(self, mocker):
        mock_block = mocker.MagicMock()
        mock_block.return_value.name = "HTTP Request"
        mocker.patch(
            "backend.blocks.get_blocks",
            return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block},
        )
        result = await validate_block_identifiers(["http request"])
        assert result == []


# ---------------------------------------------------------------------------
# apply_tool_permissions
# ---------------------------------------------------------------------------


class TestApplyToolPermissions:
    def test_empty_permissions_returns_base_unchanged(self, mocker):
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
            return_value=["mcp__copilot__run_block", "mcp__copilot__web_fetch", "Task"],
        )
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
            return_value=["Bash"],
        )
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": object(), "web_fetch": object()},
        )
        perms = CopilotPermissions()
        allowed, disallowed = apply_tool_permissions(perms, use_e2b=False)
        assert "mcp__copilot__run_block" in allowed
        assert "mcp__copilot__web_fetch" in allowed

    def test_blacklist_removes_tool(self, mocker):
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
            return_value=[
                "mcp__copilot__run_block",
                "mcp__copilot__web_fetch",
                "mcp__copilot__bash_exec",
                "Task",
            ],
        )
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
            return_value=["Bash"],
        )
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {
                "run_block": object(),
                "web_fetch": object(),
                "bash_exec": object(),
            },
        )
        mocker.patch(
            "backend.copilot.permissions.all_known_tool_names",
            return_value=frozenset(["run_block", "web_fetch", "bash_exec", "Task"]),
        )
        perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
        allowed, _ = apply_tool_permissions(perms, use_e2b=False)
        assert "mcp__copilot__bash_exec" not in allowed
        assert "mcp__copilot__run_block" in allowed

    def test_whitelist_keeps_only_listed(self, mocker):
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
            return_value=[
                "mcp__copilot__run_block",
                "mcp__copilot__web_fetch",
                "Task",
                "WebSearch",
            ],
        )
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
            return_value=["Bash"],
        )
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": object(), "web_fetch": object()},
        )
        mocker.patch(
            "backend.copilot.permissions.all_known_tool_names",
            return_value=frozenset(["run_block", "web_fetch", "Task", "WebSearch"]),
        )
        perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
        allowed, _ = apply_tool_permissions(perms, use_e2b=False)
        assert "mcp__copilot__run_block" in allowed
        assert "mcp__copilot__web_fetch" not in allowed
        assert "Task" not in allowed

    def test_read_tool_always_included_even_when_blacklisted(self, mocker):
        """mcp__copilot__Read must stay in allowed even if Read is explicitly blacklisted."""
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
            return_value=[
                "mcp__copilot__run_block",
                "mcp__copilot__Read",
                "Task",
            ],
        )
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
            return_value=[],
        )
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": object()},
        )
        mocker.patch(
            "backend.copilot.permissions.all_known_tool_names",
            return_value=frozenset(["run_block", "Read", "Task"]),
        )
        # Explicitly blacklist Read
        perms = CopilotPermissions(tools=["Read"], tools_exclude=True)
        allowed, _ = apply_tool_permissions(perms, use_e2b=False)
        assert "mcp__copilot__Read" in allowed  # always preserved for SDK internals
        assert "mcp__copilot__run_block" in allowed
        assert "Task" in allowed

    def test_read_tool_always_included_with_narrow_whitelist(self, mocker):
        """mcp__copilot__Read must stay in allowed even when not in a whitelist."""
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
            return_value=[
                "mcp__copilot__run_block",
                "mcp__copilot__Read",
                "Task",
            ],
        )
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
            return_value=[],
        )
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": object()},
        )
        mocker.patch(
            "backend.copilot.permissions.all_known_tool_names",
            return_value=frozenset(["run_block", "Read", "Task"]),
        )
        # Whitelist only run_block — Read not listed
        perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
        allowed, _ = apply_tool_permissions(perms, use_e2b=False)
        assert "mcp__copilot__Read" in allowed  # always preserved for SDK internals
        assert "mcp__copilot__run_block" in allowed

    def test_e2b_file_tools_included_when_sdk_builtin_whitelisted(self, mocker):
        """In E2B mode, whitelisting 'Read' must include mcp__copilot__read_file."""
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
            return_value=[
                "mcp__copilot__run_block",
                "mcp__copilot__Read",
                "mcp__copilot__read_file",
                "mcp__copilot__write_file",
                "Task",
            ],
        )
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
            return_value=["Bash", "Read", "Write", "Edit", "Glob", "Grep"],
        )
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": object()},
        )
        mocker.patch(
            "backend.copilot.permissions.all_known_tool_names",
            return_value=frozenset(["run_block", "Read", "Write", "Task"]),
        )
        mocker.patch(
            "backend.copilot.sdk.e2b_file_tools.E2B_FILE_TOOL_NAMES",
            ["read_file", "write_file", "edit_file", "glob", "grep"],
        )
        # Whitelist Read and run_block — E2B read_file should be included
        perms = CopilotPermissions(tools=["Read", "run_block"], tools_exclude=False)
        allowed, _ = apply_tool_permissions(perms, use_e2b=True)
        assert "mcp__copilot__read_file" in allowed
        assert "mcp__copilot__run_block" in allowed
        # Write not whitelisted — write_file should NOT be included
        assert "mcp__copilot__write_file" not in allowed

    def test_e2b_file_tools_excluded_when_sdk_builtin_blacklisted(self, mocker):
        """In E2B mode, blacklisting 'Read' must also remove mcp__copilot__read_file."""
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
            return_value=[
                "mcp__copilot__run_block",
                "mcp__copilot__Read",
                "mcp__copilot__read_file",
                "Task",
            ],
        )
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
            return_value=["Bash", "Read", "Write", "Edit", "Glob", "Grep"],
        )
        mocker.patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": object()},
        )
        mocker.patch(
            "backend.copilot.permissions.all_known_tool_names",
            return_value=frozenset(["run_block", "Read", "Task"]),
        )
        mocker.patch(
            "backend.copilot.sdk.e2b_file_tools.E2B_FILE_TOOL_NAMES",
            ["read_file", "write_file", "edit_file", "glob", "grep"],
        )
        # Blacklist Read — E2B read_file should also be removed
        perms = CopilotPermissions(tools=["Read"], tools_exclude=True)
        allowed, _ = apply_tool_permissions(perms, use_e2b=True)
        assert "mcp__copilot__read_file" not in allowed
        assert "mcp__copilot__run_block" in allowed
        # mcp__copilot__Read is always preserved for SDK internals
        assert "mcp__copilot__Read" in allowed


# ---------------------------------------------------------------------------
# SDK_BUILTIN_TOOL_NAMES sanity check
# ---------------------------------------------------------------------------


class TestSdkBuiltinToolNames:
    def test_expected_builtins_present(self):
        expected = {
            "Read",
            "Write",
            "Edit",
            "Glob",
            "Grep",
            "Task",
            "WebSearch",
            "TodoWrite",
        }
        assert expected.issubset(SDK_BUILTIN_TOOL_NAMES)

    def test_platform_names_match_tool_registry(self):
        """PLATFORM_TOOL_NAMES (derived from ToolName Literal) must match TOOL_REGISTRY keys."""
        registry_keys = frozenset(TOOL_REGISTRY.keys())
        assert PLATFORM_TOOL_NAMES == registry_keys, (
            f"ToolName Literal is out of sync with TOOL_REGISTRY. "
            f"Missing: {registry_keys - PLATFORM_TOOL_NAMES}, "
            f"Extra: {PLATFORM_TOOL_NAMES - registry_keys}"
        )

    def test_all_tool_names_is_union(self):
        """ALL_TOOL_NAMES must equal PLATFORM_TOOL_NAMES | SDK_BUILTIN_TOOL_NAMES."""
        assert ALL_TOOL_NAMES == PLATFORM_TOOL_NAMES | SDK_BUILTIN_TOOL_NAMES

    def test_no_overlap_between_platform_and_sdk(self):
        """Platform and SDK built-in names must not overlap."""
        assert PLATFORM_TOOL_NAMES.isdisjoint(SDK_BUILTIN_TOOL_NAMES)

    def test_known_tools_includes_registry_and_builtins(self):
        known = all_known_tool_names()
        assert "run_block" in known
        assert "Read" in known
        assert "Task" in known
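Taken together, these tests pin down the merge rule: a child permission set can only narrow, never expand, what its parent grants. A minimal sketch of that invariant, using the same names as the tests above:

```python
from backend.copilot.permissions import CopilotPermissions

all_tools = frozenset(["run_block", "web_fetch", "bash_exec"])

parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)  # whitelist
child = CopilotPermissions(tools=["run_block", "bash_exec"], tools_exclude=False)

merged = child.merged_with_parent(parent, all_tools)
# The intersection wins: bash_exec was never in the parent's whitelist,
# so the child cannot smuggle it back in.
assert merged.effective_allowed_tools(all_tools) == frozenset(["run_block"])
```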
@@ -12,18 +12,34 @@ from backend.copilot.tools import TOOL_REGISTRY
# Shared technical notes that apply to both SDK and baseline modes
_SHARED_TOOL_NOTES = f"""\

### Sharing files
After `write_workspace_file`, embed the `download_url` in Markdown:
- File: `[report.csv](workspace://file_id#text/csv)`
- Image: ``
- Video: ``
### Sharing files with the user
After saving a file to the persistent workspace with `write_workspace_file`,
share it with the user by embedding the `download_url` from the response in
your message as a Markdown link or image:

### File references — @@agptfile:
Pass large file content to tools by reference: `@@agptfile:<uri>[<start>-<end>]`
- `workspace://<file_id>` or `workspace:///<path>` — workspace files
- `/absolute/path` — local/sandbox files
- `[start-end]` — optional 1-indexed line range
- Multiple refs per argument supported. Only `workspace://` and absolute paths are expanded.
- **Any file** — shows as a clickable download link:
  `[report.csv](workspace://file_id#text/csv)`
- **Image** — renders inline in chat:
  ``
- **Video** — renders inline in chat with player controls:
  ``

The `download_url` field in the `write_workspace_file` response is already
in the correct format — paste it directly after the `(` in the Markdown.

### Passing file content to tools — @@agptfile: references
Instead of copying large file contents into a tool argument, pass a file
reference and the platform will load the content for you.

Syntax: `@@agptfile:<uri>[<start>-<end>]`

- `<uri>` **must** start with `workspace://` or `/` (absolute path):
  - `workspace://<file_id>` — workspace file by ID
  - `workspace:///<path>` — workspace file by virtual path
  - `/absolute/local/path` — ephemeral or sdk_cwd file
  - E2B sandbox absolute path (e.g. `/home/user/script.py`)
- `[<start>-<end>]` is an optional 1-indexed inclusive line range.
- URIs that do not start with `workspace://` or `/` are **not** expanded.

Examples:
```
@@ -34,9 +50,21 @@ Examples:
@@agptfile:/home/user/script.py
```

**Structured data**: When the entire argument is a single file reference, the platform auto-parses by extension/MIME. Supported: JSON, JSONL, CSV, TSV, YAML, TOML, Parquet, Excel (.xlsx only; legacy `.xls` is NOT supported). Unrecognised formats return plain string.
You can embed a reference inside any string argument, or use it as the entire
value. Multiple references in one argument are all expanded.

**Type coercion**: The platform auto-coerces expanded string values to match block input types (e.g. JSON string → `list[list[str]]`).
**Structured data**: When the **entire** argument value is a single file
reference (no surrounding text), the platform automatically parses the file
content based on its extension or MIME type. Supported formats: JSON, JSONL,
CSV, TSV, YAML, TOML, Parquet, and Excel (.xlsx — first sheet only).
For example, pass `@@agptfile:workspace://<id>` where the file is a `.csv` and
the rows will be parsed into `list[list[str]]` automatically. If the format is
unrecognised or parsing fails, the content is returned as a plain string.
Legacy `.xls` files are **not** supported — only the modern `.xlsx` format.

**Type coercion**: The platform also coerces expanded values to match the
block's expected input types. For example, if a block expects `list[list[str]]`
and the expanded value is a JSON string, it will be parsed into the correct type.

### Media file inputs (format: "file")
Some block inputs accept media files — their schema shows `"format": "file"`.
@@ -63,50 +91,6 @@ Example — committing an image file to GitHub:
}}
```

### Writing large files — CRITICAL
**Never write an entire large document in a single tool call.** When the
content you want to write exceeds ~2000 words the tool call's output token
limit will silently truncate the arguments, producing an empty `{{}}` input
that fails repeatedly.

**Preferred: compose from file references.** If the data is already in
files (tool outputs, workspace files), compose the report in one call
using `@@agptfile:` references — the system expands them inline:

```bash
cat > report.md << 'EOF'
# Research Report
## Data from web research
@@agptfile:/home/user/web_results.txt
## Block execution output
@@agptfile:workspace://<file_id>
## Conclusion
<brief synthesis>
EOF
```

**Fallback: write section-by-section.** When you must generate content
from conversation context (no files to reference), split into multiple
`bash_exec` calls — one section per call:

```bash
cat > report.md << 'EOF'
# Section 1
<content from your earlier tool call results>
EOF
```
```bash
cat >> report.md << 'EOF'
# Section 2
<content from your earlier tool call results>
EOF
```
Use `cat >` for the first chunk and `cat >>` to append subsequent chunks.
Do not re-fetch or re-generate data you already have from prior tool calls.

After building the file, reference it with `@@agptfile:` in other tools:
`@@agptfile:/home/user/report.md`

### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
  All tasks must run in the foreground.
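The `@@agptfile:` grammar described in this hunk is regular enough to capture in a few lines. A hypothetical sketch of a matcher for it — the platform's real expander is not shown in this diff, so the function and pattern here are illustrative only:

```python
import re

# Hypothetical matcher for @@agptfile:<uri>[<start>-<end>] references.
# Per the prompt text above, only workspace:// and absolute paths expand.
_REF = re.compile(
    r"@@agptfile:(?P<uri>(?:workspace://|/)\S+?)"       # workspace:// or absolute
    r"(?:\[(?P<start>\d+)-(?P<end>\d+)\])?(?=\s|$)"     # optional 1-indexed range
)


def find_refs(argument: str) -> list[tuple[str, int | None, int | None]]:
    """Return (uri, start, end) for every expandable reference in *argument*."""
    return [
        (
            m["uri"],
            int(m["start"]) if m["start"] else None,
            int(m["end"]) if m["end"] else None,
        )
        for m in _REF.finditer(argument)
    ]


assert find_refs("@@agptfile:/home/user/report.md[1-40]") == [
    ("/home/user/report.md", 1, 40)
]
```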
@@ -182,12 +166,17 @@ def _build_storage_supplement(

## Tool notes

### Shell & filesystem
- The SDK built-in Bash tool is NOT available. Use `bash_exec` for shell commands ({sandbox_type}). Working dir: `{working_dir}`
- SDK file tools (Read/Write/Edit/Glob/Grep) and `bash_exec` share one filesystem — use relative or absolute paths under this dir.
- `read_workspace_file`/`write_workspace_file` operate on **persistent cloud workspace storage** (separate from the working dir).
### Shell commands
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
  for shell commands — it runs {sandbox_type}.

### Working directory
- Your working directory is: `{working_dir}`
- All SDK file tools AND `bash_exec` operate on the same filesystem
- Use relative paths or absolute paths under `{working_dir}` for all file operations

### Two storage systems — CRITICAL to understand

1. **{storage_system_1_name}** (`{working_dir}`):
   {characteristics}
   {persistence}
@@ -143,11 +143,11 @@ To use an MCP (Model Context Protocol) tool as a node in the agent:
   tool_arguments.
6. Output: `result` (the tool's return value) and `error` (error message)

### Using OrchestratorBlock (AI Orchestrator with Agent Mode)
### Using SmartDecisionMakerBlock (AI Orchestrator with Agent Mode)

To create an agent where AI autonomously decides which tools or sub-agents to
call in a loop until the task is complete:
1. Create a `OrchestratorBlock` node
1. Create a `SmartDecisionMakerBlock` node
   (ID: `3b191d9f-356f-482d-8238-ba04b6d18381`)
2. Set `input_default`:
   - `agent_mode_max_iterations`: Choose based on task complexity:
@@ -169,8 +169,8 @@ call in a loop until the task is complete:
3. Wire the `prompt` input from an `AgentInputBlock` (the user's task)
4. Create downstream tool blocks — regular blocks **or** `AgentExecutorBlock`
   nodes that call sub-agents
5. Link each tool to the Orchestrator: set `source_name: "tools"` on
   the Orchestrator side and `sink_name: <input_field>` on each tool
5. Link each tool to the SmartDecisionMaker: set `source_name: "tools"` on
   the SmartDecisionMaker side and `sink_name: <input_field>` on each tool
   block's input. Create one link per input field the tool needs.
6. Wire the `finished` output to an `AgentOutputBlock` for the final result
7. Credentials (LLM API key) are configured by the user in the platform UI
@@ -178,49 +178,35 @@ call in a loop until the task is complete:

**Example — Orchestrator calling two sub-agents:**
- Node 1: `AgentInputBlock` (input_default: `{"name": "task"}`)
- Node 2: `OrchestratorBlock` (input_default:
- Node 2: `SmartDecisionMakerBlock` (input_default:
  `{"agent_mode_max_iterations": 10, "conversation_compaction": true}`)
- Node 3: `AgentExecutorBlock` (sub-agent A — set `graph_id`, `graph_version`,
  `input_schema`, `output_schema` from library agent)
- Node 4: `AgentExecutorBlock` (sub-agent B — same pattern)
- Node 5: `AgentOutputBlock` (input_default: `{"name": "result"}`)
- Links:
  - Input→Orchestrator: `source_name: "result"`, `sink_name: "prompt"`
  - Orchestrator→Agent A (per input field): `source_name: "tools"`,
  - Input→SDM: `source_name: "result"`, `sink_name: "prompt"`
  - SDM→Agent A (per input field): `source_name: "tools"`,
    `sink_name: "<agent_a_input_field>"`
  - Orchestrator→Agent B (per input field): `source_name: "tools"`,
  - SDM→Agent B (per input field): `source_name: "tools"`,
    `sink_name: "<agent_b_input_field>"`
  - Orchestrator→Output: `source_name: "finished"`, `sink_name: "value"`
  - SDM→Output: `source_name: "finished"`, `sink_name: "value"`

**Example — Orchestrator calling regular blocks as tools:**
- Node 1: `AgentInputBlock` (input_default: `{"name": "task"}`)
- Node 2: `OrchestratorBlock` (input_default:
- Node 2: `SmartDecisionMakerBlock` (input_default:
  `{"agent_mode_max_iterations": 5, "conversation_compaction": true}`)
- Node 3: `GetWebpageBlock` (regular block — the AI calls it as a tool)
- Node 4: `AITextGeneratorBlock` (another regular block as a tool)
- Node 5: `AgentOutputBlock` (input_default: `{"name": "result"}`)
- Links:
  - Input→Orchestrator: `source_name: "result"`, `sink_name: "prompt"`
  - Orchestrator→GetWebpage: `source_name: "tools"`, `sink_name: "url"`
  - Orchestrator→AITextGenerator: `source_name: "tools"`, `sink_name: "prompt"`
  - Orchestrator→Output: `source_name: "finished"`, `sink_name: "value"`
  - Input→SDM: `source_name: "result"`, `sink_name: "prompt"`
  - SDM→GetWebpage: `source_name: "tools"`, `sink_name: "url"`
  - SDM→AITextGenerator: `source_name: "tools"`, `sink_name: "prompt"`
  - SDM→Output: `source_name: "finished"`, `sink_name: "value"`

Regular blocks work exactly like sub-agents as tools — wire each input
field from `source_name: "tools"` on the Orchestrator side.

### Testing with Dry Run

After saving an agent, suggest a dry run to validate wiring without consuming
real API calls, credentials, or credits:

1. **Run**: Call `run_agent` or `run_block` with `dry_run=True` and provide
   sample inputs. This executes the graph with mock outputs, verifying that
   links resolve correctly and required inputs are satisfied.
2. **Check results**: Call `view_agent_output` with `show_execution_details=True`
   to inspect the full node-by-node execution trace. This shows what each node
   received as input and produced as output, making it easy to spot wiring issues.
3. **Iterate**: If the dry run reveals wiring issues or missing inputs, fix
   the agent JSON and re-save before suggesting a real execution.
field from `source_name: "tools"` on the SmartDecisionMaker side.

### Example: Simple AI Text Processor
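As a compact companion to the wiring steps in this hunk, here is the two-sub-agent topology as a graph skeleton. The node and link field names follow the prompt's own examples; the full platform schema (positions, versions, credentials) is omitted, so treat the exact shape as illustrative:

```python
# Illustrative graph skeleton for the SmartDecisionMaker wiring above.
graph = {
    "nodes": [
        {"id": "n1", "block": "AgentInputBlock", "input_default": {"name": "task"}},
        {
            "id": "n2",
            # SmartDecisionMakerBlock (ID per the prompt text above)
            "block_id": "3b191d9f-356f-482d-8238-ba04b6d18381",
            "input_default": {
                "agent_mode_max_iterations": 10,
                "conversation_compaction": True,
            },
        },
        {"id": "n3", "block": "AgentExecutorBlock"},  # sub-agent A
        {"id": "n4", "block": "AgentExecutorBlock"},  # sub-agent B
        {"id": "n5", "block": "AgentOutputBlock", "input_default": {"name": "result"}},
    ],
    "links": [
        {"source_id": "n1", "sink_id": "n2", "source_name": "result", "sink_name": "prompt"},
        # One "tools" link per input field each sub-agent needs:
        {"source_id": "n2", "sink_id": "n3", "source_name": "tools", "sink_name": "<agent_a_input_field>"},
        {"source_id": "n2", "sink_id": "n4", "source_name": "tools", "sink_name": "<agent_b_input_field>"},
        {"source_id": "n2", "sink_id": "n5", "source_name": "finished", "sink_name": "value"},
    ],
}
```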
@@ -7,35 +7,7 @@ without implementing their own event loop.

from __future__ import annotations

import logging
import uuid
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from backend.copilot.permissions import CopilotPermissions

from pydantic import BaseModel, Field
from redis.exceptions import RedisError

from .. import stream_registry
from ..response_model import (
    StreamError,
    StreamTextDelta,
    StreamToolInputAvailable,
    StreamToolOutputAvailable,
    StreamUsage,
)
from .service import stream_chat_completion_sdk

logger = logging.getLogger(__name__)

# Identifiers used when registering AutoPilot-originated streams in the
# stream registry. Distinct from "chat_stream"/"chat" used by the HTTP SSE
# endpoint, making it easy to filter AutoPilot streams in logs/observability.
AUTOPILOT_TOOL_CALL_ID = "autopilot_stream"
AUTOPILOT_TOOL_NAME = "autopilot"
from typing import Any


class CopilotResult:
@@ -61,131 +33,26 @@ class CopilotResult:
        self.total_tokens: int = 0


class _RegistryHandle(BaseModel):
    """Tracks stream registry session state for cleanup."""

    publish_turn_id: str = ""
    error_msg: str | None = None
    error_already_published: bool = False


@asynccontextmanager
async def _registry_session(
    session_id: str, user_id: str, turn_id: str
) -> AsyncIterator[_RegistryHandle]:
    """Create a stream registry session and ensure it is finalized."""
    handle = _RegistryHandle(publish_turn_id=turn_id)
    try:
        await stream_registry.create_session(
            session_id=session_id,
            user_id=user_id,
            tool_call_id=AUTOPILOT_TOOL_CALL_ID,
            tool_name=AUTOPILOT_TOOL_NAME,
            turn_id=turn_id,
        )
    except (RedisError, ConnectionError, OSError):
        logger.warning(
            "[collect] Failed to create stream registry session for %s, "
            "frontend will not receive real-time updates",
            session_id[:12],
            exc_info=True,
        )
        # Disable chunk publishing but keep finalization enabled so
        # mark_session_completed can clean up any partial registry state.
        handle.publish_turn_id = ""

    try:
        yield handle
    finally:
        try:
            await stream_registry.mark_session_completed(
                session_id,
                error_message=handle.error_msg,
                skip_error_publish=handle.error_already_published,
            )
        except (RedisError, ConnectionError, OSError):
            logger.warning(
                "[collect] Failed to mark stream completed for %s",
                session_id[:12],
                exc_info=True,
            )


class _ToolCallEntry(BaseModel):
    """A single tool call observed during stream consumption."""

    tool_call_id: str
    tool_name: str
    input: Any
    output: Any = None
    success: bool | None = None


class _EventAccumulator(BaseModel):
    """Mutable accumulator for stream events."""

    response_parts: list[str] = Field(default_factory=list)
    tool_calls: list[_ToolCallEntry] = Field(default_factory=list)
    tool_calls_by_id: dict[str, _ToolCallEntry] = Field(default_factory=dict)
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


def _process_event(event: object, acc: _EventAccumulator) -> str | None:
    """Process a single stream event and return error_msg if StreamError.

    Uses structural pattern matching for dispatch per project guidelines.
    """
    match event:
        case StreamTextDelta(delta=delta):
            acc.response_parts.append(delta)
        case StreamToolInputAvailable() as e:
            entry = _ToolCallEntry(
                tool_call_id=e.toolCallId,
                tool_name=e.toolName,
                input=e.input,
            )
            acc.tool_calls.append(entry)
            acc.tool_calls_by_id[e.toolCallId] = entry
        case StreamToolOutputAvailable() as e:
            if tc := acc.tool_calls_by_id.get(e.toolCallId):
                tc.output = e.output
                tc.success = e.success
            else:
                logger.debug(
                    "Received tool output for unknown tool_call_id: %s",
                    e.toolCallId,
                )
        case StreamUsage() as e:
            acc.prompt_tokens += e.prompt_tokens
            acc.completion_tokens += e.completion_tokens
            acc.total_tokens += e.total_tokens
        case StreamError(errorText=err):
            return err
    return None


async def collect_copilot_response(
    *,
    session_id: str,
    message: str,
    user_id: str,
    is_user_message: bool = True,
    permissions: "CopilotPermissions | None" = None,
) -> CopilotResult:
    """Consume :func:`stream_chat_completion_sdk` and return aggregated results.

    Registers with the stream registry so the frontend can connect via SSE
    and receive real-time updates while the AutoPilot block is executing.
    This is the recommended entry-point for callers that need a simple
    request-response interface (e.g. the AutoPilot block) rather than
    streaming individual events. It avoids duplicating the event-collection
    logic and does NOT wrap the stream in ``asyncio.timeout`` — the SDK
    manages its own heartbeat-based timeouts internally.

    Args:
        session_id: Chat session to use.
        message: The user message / prompt.
        user_id: Authenticated user ID.
        is_user_message: Whether this is a user-initiated message.
        permissions: Optional capability filter. When provided, restricts
            which tools and blocks the copilot may use during this execution.

    Returns:
        A :class:`CopilotResult` with the aggregated response text,
@@ -194,39 +61,48 @@ async def collect_copilot_response(
    Raises:
        RuntimeError: If the stream yields a ``StreamError`` event.
    """
    turn_id = str(uuid.uuid4())
    async with _registry_session(session_id, user_id, turn_id) as handle:
        try:
            raw_stream = stream_chat_completion_sdk(
                session_id=session_id,
                message=message,
                is_user_message=is_user_message,
                user_id=user_id,
                permissions=permissions,
            )
            published_stream = stream_registry.stream_and_publish(
                session_id=session_id,
                turn_id=handle.publish_turn_id,
                stream=raw_stream,
            )
    from backend.copilot.response_model import (
        StreamError,
        StreamTextDelta,
        StreamToolInputAvailable,
        StreamToolOutputAvailable,
        StreamUsage,
    )

            acc = _EventAccumulator()
            async for event in published_stream:
                if err := _process_event(event, acc):
                    handle.error_msg = err
                    # stream_and_publish skips StreamError events, so
                    # mark_session_completed must publish the error to Redis.
                    handle.error_already_published = False
                    raise RuntimeError(f"Copilot error: {err}")
        except Exception:
            if handle.error_msg is None:
                handle.error_msg = "AutoPilot execution failed"
            raise
    from .service import stream_chat_completion_sdk

    result = CopilotResult()
    result.response_text = "".join(acc.response_parts)
    result.tool_calls = [tc.model_dump() for tc in acc.tool_calls]
    result.prompt_tokens = acc.prompt_tokens
    result.completion_tokens = acc.completion_tokens
    result.total_tokens = acc.total_tokens
    response_parts: list[str] = []
    tool_calls_by_id: dict[str, dict[str, Any]] = {}

    async for event in stream_chat_completion_sdk(
        session_id=session_id,
        message=message,
        is_user_message=is_user_message,
        user_id=user_id,
    ):
        if isinstance(event, StreamTextDelta):
            response_parts.append(event.delta)
        elif isinstance(event, StreamToolInputAvailable):
            entry: dict[str, Any] = {
                "tool_call_id": event.toolCallId,
                "tool_name": event.toolName,
                "input": event.input,
                "output": None,
                "success": None,
            }
            result.tool_calls.append(entry)
            tool_calls_by_id[event.toolCallId] = entry
        elif isinstance(event, StreamToolOutputAvailable):
            if tc := tool_calls_by_id.get(event.toolCallId):
                tc["output"] = event.output
                tc["success"] = event.success
        elif isinstance(event, StreamUsage):
            result.prompt_tokens += event.prompt_tokens
            result.completion_tokens += event.completion_tokens
            result.total_tokens += event.total_tokens
        elif isinstance(event, StreamError):
            raise RuntimeError(f"Copilot error: {event.errorText}")

    result.response_text = "".join(response_parts)
    return result
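A minimal usage sketch of this aggregating entry-point, assuming an async caller with a valid session and user (the IDs and message are placeholders):

```python
from backend.copilot.sdk.collect import collect_copilot_response


async def run_autopilot_turn(session_id: str, user_id: str) -> str:
    result = await collect_copilot_response(
        session_id=session_id,
        message="Summarize today's failed runs.",
        user_id=user_id,
    )
    # The aggregated fields mirror the streamed events: text deltas joined
    # into response_text, tool calls collected with outputs, usage summed.
    print(f"{result.total_tokens} tokens, {len(result.tool_calls)} tool calls")
    return result.response_text
```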
@@ -1,177 +0,0 @@
"""Tests for collect_copilot_response stream registry integration."""

from unittest.mock import AsyncMock, patch

import pytest

from backend.copilot.response_model import (
    StreamError,
    StreamFinish,
    StreamTextDelta,
    StreamToolInputAvailable,
    StreamToolOutputAvailable,
    StreamUsage,
)
from backend.copilot.sdk.collect import collect_copilot_response


def _mock_stream_fn(*events):
    """Return a callable that returns an async generator."""

    async def _gen(**_kwargs):
        for e in events:
            yield e

    return _gen


@pytest.fixture
def mock_registry():
    """Patch stream_registry module used by collect."""
    with patch("backend.copilot.sdk.collect.stream_registry") as m:
        m.create_session = AsyncMock()
        m.publish_chunk = AsyncMock()
        m.mark_session_completed = AsyncMock()

        # stream_and_publish: pass-through that also publishes (real logic)
        # We re-implement the pass-through here so the event loop works,
        # but still track publish_chunk calls via the mock.
        async def _stream_and_publish(session_id, turn_id, stream):
            async for event in stream:
                if turn_id and not isinstance(event, (StreamFinish, StreamError)):
                    await m.publish_chunk(turn_id, event)
                yield event

        m.stream_and_publish = _stream_and_publish
        yield m


@pytest.fixture
def stream_fn_patch():
    """Helper to patch stream_chat_completion_sdk."""

    def _patch(events):
        return patch(
            "backend.copilot.sdk.collect.stream_chat_completion_sdk",
            new=_mock_stream_fn(*events),
        )

    return _patch


@pytest.mark.asyncio
async def test_stream_registry_called_on_success(mock_registry, stream_fn_patch):
    """Stream registry create/publish/complete are called correctly on success."""
    events = [
        StreamTextDelta(id="t1", delta="Hello "),
        StreamTextDelta(id="t1", delta="world"),
        StreamUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
        StreamFinish(),
    ]

    with stream_fn_patch(events):
        result = await collect_copilot_response(
            session_id="test-session",
            message="hi",
            user_id="user-1",
        )

    assert result.response_text == "Hello world"
    assert result.total_tokens == 15

    mock_registry.create_session.assert_awaited_once()
    # StreamFinish should NOT be published (mark_session_completed does it)
    published_types = [
        type(call.args[1]).__name__
        for call in mock_registry.publish_chunk.call_args_list
    ]
    assert "StreamFinish" not in published_types
    assert "StreamTextDelta" in published_types

    mock_registry.mark_session_completed.assert_awaited_once()
    _, kwargs = mock_registry.mark_session_completed.call_args
    assert kwargs.get("error_message") is None


@pytest.mark.asyncio
async def test_stream_registry_error_on_stream_error(mock_registry, stream_fn_patch):
    """mark_session_completed receives error message when StreamError occurs."""
    events = [
        StreamTextDelta(id="t1", delta="partial"),
        StreamError(errorText="something broke"),
    ]

    with stream_fn_patch(events):
        with pytest.raises(RuntimeError, match="something broke"):
            await collect_copilot_response(
                session_id="test-session",
                message="hi",
                user_id="user-1",
            )

    _, kwargs = mock_registry.mark_session_completed.call_args
    assert kwargs.get("error_message") == "something broke"
    # stream_and_publish skips StreamError, so mark_session_completed must
    # publish it (skip_error_publish=False).
    assert kwargs.get("skip_error_publish") is False

    # StreamError should NOT be published via publish_chunk — mark_session_completed
    # handles it to avoid double-publication.
    published_types = [
        type(call.args[1]).__name__
        for call in mock_registry.publish_chunk.call_args_list
    ]
    assert "StreamError" not in published_types


@pytest.mark.asyncio
async def test_graceful_degradation_when_create_session_fails(
    mock_registry, stream_fn_patch
):
    """AutoPilot still works when stream registry create_session raises."""
    events = [
        StreamTextDelta(id="t1", delta="works"),
        StreamFinish(),
    ]
    mock_registry.create_session = AsyncMock(side_effect=ConnectionError("Redis down"))

    with stream_fn_patch(events):
        result = await collect_copilot_response(
            session_id="test-session",
            message="hi",
            user_id="user-1",
        )

    assert result.response_text == "works"
    # publish_chunk should NOT be called because turn_id was cleared
    mock_registry.publish_chunk.assert_not_awaited()
    # mark_session_completed IS still called to clean up any partial state
    mock_registry.mark_session_completed.assert_awaited_once()


@pytest.mark.asyncio
async def test_tool_calls_published_and_collected(mock_registry, stream_fn_patch):
    """Tool call events are both published to registry and collected in result."""
    events = [
        StreamToolInputAvailable(
            toolCallId="tc-1", toolName="read_file", input={"path": "/tmp"}
        ),
        StreamToolOutputAvailable(
            toolCallId="tc-1", output="file contents", success=True
        ),
        StreamTextDelta(id="t1", delta="done"),
        StreamFinish(),
    ]

    with stream_fn_patch(events):
        result = await collect_copilot_response(
            session_id="test-session",
            message="hi",
            user_id="user-1",
        )

    assert len(result.tool_calls) == 1
    assert result.tool_calls[0]["tool_name"] == "read_file"
    assert result.tool_calls[0]["output"] == "file contents"
    assert result.tool_calls[0]["success"] is True
    assert result.response_text == "done"
@@ -2,7 +2,7 @@

When E2B is active, these tools replace the SDK built-in Read/Write/Edit/
Glob/Grep so that all file operations share the same ``/home/user``
and ``/tmp`` filesystems as ``bash_exec``.
filesystem as ``bash_exec``.

SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
@@ -16,13 +16,10 @@ import shlex
from typing import Any, Callable

from backend.copilot.context import (
    E2B_ALLOWED_DIRS,
    E2B_ALLOWED_DIRS_STR,
    E2B_WORKDIR,
    get_current_sandbox,
    get_sdk_cwd,
    is_allowed_local_path,
    is_within_allowed_dirs,
    resolve_sandbox_path,
)

@@ -39,7 +36,7 @@ async def _check_sandbox_symlink_escape(
    ``readlink -f`` follows actual symlinks on the sandbox filesystem.

    Returns the canonical parent path, or ``None`` if the path escapes
    the allowed sandbox directories.
    ``E2B_WORKDIR``.

    Note: There is an inherent TOCTOU window between this check and the
    subsequent ``sandbox.files.write()``. A symlink could theoretically be
@@ -55,7 +52,10 @@ async def _check_sandbox_symlink_escape(
    if (
        canonical_res.exit_code != 0
        or not canonical_parent
        or not is_within_allowed_dirs(canonical_parent)
        or (
            canonical_parent != E2B_WORKDIR
            and not canonical_parent.startswith(E2B_WORKDIR + "/")
        )
    ):
        return None
    return canonical_parent
@@ -89,38 +89,6 @@ def _get_sandbox_and_path(
    return sandbox, remote


async def _sandbox_write(sandbox: Any, path: str, content: str) -> None:
    """Write *content* to *path* inside the sandbox.

    The E2B filesystem API (``sandbox.files.write``) and the command API
    (``sandbox.commands.run``) run as **different users**. On ``/tmp``
    (which has the sticky bit set) this means ``sandbox.files.write`` can
    create new files but cannot overwrite files previously created by
    ``sandbox.commands.run`` (or itself), because the sticky bit restricts
    deletion/rename to the file owner.

    To work around this, writes targeting ``/tmp`` are routed through the
    command API (base64-encoded content piped into a shell redirect), which
    runs as the sandbox ``user`` and can therefore always overwrite
    user-owned files.
    """
    if path == "/tmp" or path.startswith("/tmp/"):
        import base64 as _b64

        encoded = _b64.b64encode(content.encode()).decode()
        result = await sandbox.commands.run(
            f"echo {shlex.quote(encoded)} | base64 -d > {shlex.quote(path)}",
            cwd=E2B_WORKDIR,
            timeout=10,
        )
        if result.exit_code != 0:
            raise RuntimeError(
                f"shell write failed (exit {result.exit_code}): "
                + (result.stderr or "").strip()
            )
    else:
        await sandbox.files.write(path, content)

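# Why the /tmp special-case above exists — a hypothetical repro sketch
# (this helper is illustrative and not used anywhere in this module).
# The command API and the filesystem API act as different users, and
# /tmp's sticky bit lets only a file's owner replace it:
#
#     async def _tmp_overwrite_repro(sandbox: Any) -> None:
#         await sandbox.commands.run("echo one > /tmp/x.txt")  # command-API user
#         # The filesystem API runs as a different user, so this
#         # overwrite of a user-owned file under /tmp can fail:
#         await sandbox.files.write("/tmp/x.txt", "two")
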
# Tool handlers
|
||||
|
||||
|
||||
@@ -171,16 +139,13 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
try:
|
||||
parent = os.path.dirname(remote)
|
||||
if parent and parent not in E2B_ALLOWED_DIRS:
|
||||
if parent and parent != E2B_WORKDIR:
|
||||
await sandbox.files.make_dir(parent)
|
||||
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
|
||||
if canonical_parent is None:
|
||||
return _mcp(
|
||||
f"Path must be within {E2B_ALLOWED_DIRS_STR}: {os.path.basename(parent)}",
|
||||
error=True,
|
||||
)
|
||||
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
|
||||
remote = os.path.join(canonical_parent, os.path.basename(remote))
|
||||
await _sandbox_write(sandbox, remote, content)
|
||||
await sandbox.files.write(remote, content)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to write {remote}: {exc}", error=True)
|
||||
|
||||
@@ -207,10 +172,7 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
parent = os.path.dirname(remote)
|
||||
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
|
||||
if canonical_parent is None:
|
||||
return _mcp(
|
||||
f"Path must be within {E2B_ALLOWED_DIRS_STR}: {os.path.basename(parent)}",
|
||||
error=True,
|
||||
)
|
||||
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
|
||||
remote = os.path.join(canonical_parent, os.path.basename(remote))
|
||||
|
||||
try:
|
||||
@@ -235,7 +197,7 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
else content.replace(old_string, new_string, 1)
|
||||
)
|
||||
try:
|
||||
await _sandbox_write(sandbox, remote, updated)
|
||||
await sandbox.files.write(remote, updated)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to write {remote}: {exc}", error=True)
|
||||
|
||||
@@ -328,14 +290,14 @@ def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
|
||||
E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
|
||||
(
|
||||
"read_file",
|
||||
"Read a file from the cloud sandbox (/home/user or /tmp). "
|
||||
"Read a file from the cloud sandbox (/home/user). "
|
||||
"Use offset and limit for large files.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute under /home/user or /tmp).",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
@@ -352,7 +314,7 @@ E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
|
||||
),
|
||||
(
|
||||
"write_file",
|
||||
"Write or create a file in the cloud sandbox (/home/user or /tmp). "
|
||||
"Write or create a file in the cloud sandbox (/home/user). "
|
||||
"Parent directories are created automatically. "
|
||||
"To copy a workspace file into the sandbox, use "
|
||||
"read_workspace_file with save_to_path instead.",
|
||||
@@ -361,7 +323,7 @@ E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute under /home/user or /tmp).",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
},
|
||||
"content": {"type": "string", "description": "Content to write."},
|
||||
},
|
||||
@@ -378,7 +340,7 @@ E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute under /home/user or /tmp).",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
},
|
||||
"old_string": {"type": "string", "description": "Text to find."},
|
||||
"new_string": {"type": "string", "description": "Replacement text."},
|
||||
|
||||
@@ -15,7 +15,6 @@ from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_proj
from .e2b_file_tools import (
    _check_sandbox_symlink_escape,
    _read_local,
    _sandbox_write,
    resolve_sandbox_path,
)

@@ -40,23 +39,23 @@ class TestResolveSandboxPath:
        assert resolve_sandbox_path("./README.md") == f"{E2B_WORKDIR}/README.md"

    def test_traversal_blocked(self):
        with pytest.raises(ValueError, match="must be within"):
        with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
            resolve_sandbox_path("../../etc/passwd")

    def test_absolute_traversal_blocked(self):
        with pytest.raises(ValueError, match="must be within"):
        with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
            resolve_sandbox_path(f"{E2B_WORKDIR}/../../etc/passwd")

    def test_absolute_outside_sandbox_blocked(self):
        with pytest.raises(ValueError, match="must be within"):
        with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
            resolve_sandbox_path("/etc/passwd")

    def test_root_blocked(self):
        with pytest.raises(ValueError, match="must be within"):
        with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
            resolve_sandbox_path("/")

    def test_home_other_user_blocked(self):
        with pytest.raises(ValueError, match="must be within"):
        with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
            resolve_sandbox_path("/home/other/file.txt")

    def test_deep_nested_allowed(self):
@@ -69,24 +68,6 @@ class TestResolveSandboxPath:
        """Path that resolves back within E2B_WORKDIR is allowed."""
        assert resolve_sandbox_path("a/b/../c.txt") == f"{E2B_WORKDIR}/a/c.txt"

    def test_tmp_absolute_allowed(self):
        assert resolve_sandbox_path("/tmp/data.txt") == "/tmp/data.txt"

    def test_tmp_nested_allowed(self):
        assert resolve_sandbox_path("/tmp/a/b/c.txt") == "/tmp/a/b/c.txt"

    def test_tmp_itself_allowed(self):
        assert resolve_sandbox_path("/tmp") == "/tmp"

    def test_tmp_escape_blocked(self):
        with pytest.raises(ValueError, match="must be within"):
            resolve_sandbox_path("/tmp/../etc/passwd")

    def test_tmp_prefix_collision_blocked(self):
        """A path like /tmp_evil should be blocked (not a prefix match)."""
        with pytest.raises(ValueError, match="must be within"):
            resolve_sandbox_path("/tmp_evil/malicious.txt")

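The `test_tmp_prefix_collision_blocked` case being deleted above guards a classic pitfall worth keeping in mind: a naive string-prefix check accepts sibling directories. A tiny standalone illustration (not the module under test):

```python
import os

allowed = "/tmp"
good = "/tmp/data.txt"
bad = "/tmp_evil/malicious.txt"

# Naive prefix check: WRONG, because "/tmp_evil" starts with "/tmp".
assert bad.startswith(allowed)

# Segment-aware check: only /tmp itself or true descendants pass.
def inside(path: str) -> bool:
    return os.path.commonpath([allowed, os.path.normpath(path)]) == allowed

assert inside(good)
assert not inside(bad)
```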
# ---------------------------------------------------------------------------
# _read_local — host filesystem reads with allowlist enforcement
@@ -246,92 +227,3 @@ class TestCheckSandboxSymlinkEscape:
        sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/a/b/c/d\n", exit_code=0)
        result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/a/b/c/d")
        assert result == f"{E2B_WORKDIR}/a/b/c/d"

    @pytest.mark.asyncio
    async def test_tmp_path_allowed(self):
        """Paths resolving to /tmp are allowed."""
        sandbox = _make_sandbox(stdout="/tmp/workdir\n", exit_code=0)
        result = await _check_sandbox_symlink_escape(sandbox, "/tmp/workdir")
        assert result == "/tmp/workdir"

    @pytest.mark.asyncio
    async def test_tmp_itself_allowed(self):
        """The /tmp directory itself is allowed."""
        sandbox = _make_sandbox(stdout="/tmp\n", exit_code=0)
        result = await _check_sandbox_symlink_escape(sandbox, "/tmp")
        assert result == "/tmp"


# ---------------------------------------------------------------------------
# _sandbox_write — routing writes through shell for /tmp paths
# ---------------------------------------------------------------------------


class TestSandboxWrite:
    @pytest.mark.asyncio
    async def test_tmp_path_uses_shell_command(self):
        """Writes to /tmp should use commands.run (shell) instead of files.write."""
        run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
        commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
        files = SimpleNamespace(write=AsyncMock())
        sandbox = SimpleNamespace(commands=commands, files=files)

        await _sandbox_write(sandbox, "/tmp/test.py", "print('hello')")

        commands.run.assert_called_once()
        files.write.assert_not_called()

    @pytest.mark.asyncio
    async def test_home_user_path_uses_files_api(self):
        """Writes to /home/user should use sandbox.files.write."""
        run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
        commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
        files = SimpleNamespace(write=AsyncMock())
        sandbox = SimpleNamespace(commands=commands, files=files)

        await _sandbox_write(sandbox, "/home/user/test.py", "print('hello')")

        files.write.assert_called_once_with("/home/user/test.py", "print('hello')")
        commands.run.assert_not_called()

    @pytest.mark.asyncio
    async def test_tmp_nested_path_uses_shell_command(self):
        """Writes to nested /tmp paths should use commands.run."""
        run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
        commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
        files = SimpleNamespace(write=AsyncMock())
        sandbox = SimpleNamespace(commands=commands, files=files)

        await _sandbox_write(sandbox, "/tmp/subdir/file.txt", "content")

        commands.run.assert_called_once()
        files.write.assert_not_called()

    @pytest.mark.asyncio
    async def test_tmp_write_shell_failure_raises(self):
        """Shell write failure should raise RuntimeError."""
        run_result = SimpleNamespace(stdout="", stderr="No space left", exit_code=1)
        commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
        sandbox = SimpleNamespace(commands=commands)

        with pytest.raises(RuntimeError, match="shell write failed"):
            await _sandbox_write(sandbox, "/tmp/test.txt", "content")

    @pytest.mark.asyncio
    async def test_tmp_write_preserves_content_with_special_chars(self):
        """Content with special shell characters should be preserved via base64."""
        import base64

        run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
        commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
        sandbox = SimpleNamespace(commands=commands)

        content = "print(\"Hello $USER\")\n# a `backtick` and 'quotes'\n"
        await _sandbox_write(sandbox, "/tmp/special.py", content)

        # Verify the command contains base64-encoded content
        call_args = commands.run.call_args[0][0]
        # Extract the base64 string from the command
        encoded_in_cmd = call_args.split("echo ")[1].split(" |")[0].strip("'")
        decoded = base64.b64decode(encoded_in_cmd).decode()
        assert decoded == content

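These deleted tests exercised a base64-over-shell write. A sketch of that trick, assuming an async `commands.run(cmd)` API like E2B's; the helper name and error text here are illustrative only:

```python
import base64
import shlex

def build_shell_write_cmd(path: str, content: str) -> str:
    """Build a shell command that writes *content* to *path* byte-for-byte.

    Base64 round-trips quotes, backticks and "$VARS" untouched, so the
    content itself needs no shell quoting (the base64 alphabet contains
    no characters the shell cares about).
    """
    encoded = base64.b64encode(content.encode()).decode()
    return f"echo '{encoded}' | base64 -d > {shlex.quote(path)}"

cmd = build_shell_write_cmd("/tmp/special.py", "print(\"Hello $USER\")\n")
# hypothetical usage: result = await sandbox.commands.run(cmd)
# then raise RuntimeError("shell write failed: ...") if result.exit_code != 0
```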
@@ -2,20 +2,19 @@

import asyncio
import base64
import functools
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import time
import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NamedTuple, cast

if TYPE_CHECKING:
    from backend.copilot.permissions import CopilotPermissions
from typing import Any, NamedTuple, cast

from claude_agent_sdk import (
    AssistantMessage,
@@ -32,7 +31,6 @@ from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
from pydantic import BaseModel

from backend.copilot.context import get_workspace_manager
from backend.copilot.permissions import apply_tool_permissions
from backend.data.redis_client import get_redis_async
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import NotFoundError
@@ -79,15 +77,10 @@ from ..tracking import track_user_message
from .compaction import CompactionTracker, filter_compaction_messages
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .subscription import validate_subscription as _validate_claude_code_subscription
from .tool_adapter import (
    cancel_pending_tool_tasks,
    create_copilot_mcp_server,
    get_copilot_tool_names,
    get_sdk_disallowed_tools,
    pre_launch_tool_call,
    reset_stash_event,
    reset_tool_failure_counters,
    set_execution_context,
    wait_for_stash,
)
@@ -113,20 +106,6 @@ config = ChatConfig()
# Non-context errors (network, auth, rate-limit) are NOT retried.
_MAX_STREAM_ATTEMPTS = 3

# Hard circuit breaker: abort the stream if the model sends this many
# consecutive tool calls with empty parameters (a sign of context
# saturation or serialization failure). Empty input ({}) is never
# legitimate — even one is suspicious, three is conclusive.
_EMPTY_TOOL_CALL_LIMIT = 3

# User-facing error shown when the empty-tool-call circuit breaker trips.
_CIRCUIT_BREAKER_ERROR_MSG = (
    "AutoPilot was unable to complete the tool call "
    "— this usually happens when the response is "
    "too large to fit in a single tool call. "
    "Try breaking your request into smaller parts."
)

# Patterns that indicate the prompt/request exceeds the model's context limit.
# Matched case-insensitively against the full exception chain.
_PROMPT_TOO_LONG_PATTERNS: tuple[str, ...] = (
@@ -185,19 +164,6 @@ def _is_prompt_too_long(err: BaseException) -> bool:
    return False

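`_is_prompt_too_long` (its body is elided by the hunk above) matches those patterns against the full exception chain. One way such a walk is commonly written, sketched from the comments rather than the actual implementation:

```python
def chain_matches(err: BaseException, patterns: tuple[str, ...]) -> bool:
    """Walk err -> __cause__/__context__, matching case-insensitively."""
    seen: set[int] = set()  # guard against cycles in the chain
    current: BaseException | None = err
    while current is not None and id(current) not in seen:
        seen.add(id(current))
        text = str(current).lower()
        if any(p.lower() in text for p in patterns):
            return True
        current = current.__cause__ or current.__context__
    return False
```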
def _is_tool_only_message(sdk_msg: object) -> bool:
    """Return True if *sdk_msg* is an AssistantMessage containing only ToolUseBlocks.

    Such a message represents a parallel tool-call batch (no text output yet).
    The ``bool(…content)`` guard prevents vacuous-truth evaluation on an empty list.
    """
    return (
        isinstance(sdk_msg, AssistantMessage)
        and bool(sdk_msg.content)
        and all(isinstance(b, ToolUseBlock) for b in sdk_msg.content)
    )

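The `bool(…content)` guard matters because `all()` over an empty iterable is vacuously true:

```python
content: list = []
assert all(isinstance(b, int) for b in content)  # vacuously True on []
assert not (bool(content) and all(isinstance(b, int) for b in content))
```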

class ReducedContext(NamedTuple):
    builder: TranscriptBuilder
    use_resume: bool
@@ -492,6 +458,37 @@ def _resolve_sdk_model() -> str | None:
    return model


@functools.cache
def _validate_claude_code_subscription() -> None:
    """Validate Claude CLI is installed and responds to `--version`.

    Cached so the blocking subprocess check runs at most once per process
    lifetime. A failure (CLI not installed) is a config error that requires
    a process restart anyway.
    """
    claude_path = shutil.which("claude")
    if not claude_path:
        raise RuntimeError(
            "Claude Code CLI not found. Install it with: "
            "npm install -g @anthropic-ai/claude-code"
        )
    result = subprocess.run(
        [claude_path, "--version"],
        capture_output=True,
        text=True,
        timeout=10,
    )
    if result.returncode != 0:
        raise RuntimeError(
            f"Claude CLI check failed (exit {result.returncode}): "
            f"{result.stderr.strip()}"
        )
    logger.info(
        "Claude Code subscription mode: CLI version %s",
        result.stdout.strip(),
    )


def _build_sdk_env(
    session_id: str | None = None,
    user_id: str | None = None,
@@ -1031,122 +1028,15 @@ def _dispatch_response(
    return response


class _HandledStreamError(Exception):
class _TransientErrorHandled(Exception):
    """Raised by `_run_stream_attempt` after it has already yielded a
    `StreamError` to the client (e.g. transient API error, circuit breaker).
    `StreamError` for a transient API error.

    This signals the outer retry loop that the attempt failed so it can
    perform session-message rollback and set the `ended_with_stream_error`
    flag, **without** yielding a duplicate `StreamError` to the client.

    Attributes:
        error_msg: The user-facing error message to persist.
        code: Machine-readable error code (e.g. ``circuit_breaker_empty_tool_calls``).
        retryable: Whether the frontend should offer a retry button.
    """

    def __init__(
        self,
        message: str,
        error_msg: str | None = None,
        code: str | None = None,
        retryable: bool = True,
    ):
        super().__init__(message)
        self.error_msg = error_msg
        self.code = code
        self.retryable = retryable


@dataclass
class _EmptyToolBreakResult:
    """Result of checking for empty tool calls in a single AssistantMessage."""

    count: int  # Updated consecutive counter
    tripped: bool  # Whether the circuit breaker fired
    error: StreamError | None  # StreamError to yield (if tripped)
    error_msg: str | None  # Error message (if tripped)
    error_code: str | None  # Error code (if tripped)


def _check_empty_tool_breaker(
    sdk_msg: object,
    consecutive: int,
    ctx: _StreamContext,
    state: _RetryState,
) -> _EmptyToolBreakResult:
    """Detect consecutive empty tool calls and trip the circuit breaker.

    Returns an ``_EmptyToolBreakResult`` with the updated counter and, if the
    breaker tripped, the ``StreamError`` to yield plus the error metadata.
    """
    if not isinstance(sdk_msg, AssistantMessage):
        return _EmptyToolBreakResult(consecutive, False, None, None, None)

    empty_tools = [
        b.name for b in sdk_msg.content if isinstance(b, ToolUseBlock) and not b.input
    ]
    if not empty_tools:
        # Reset on any non-empty-tool AssistantMessage (including text-only
        # messages — any() over empty content is False).
        return _EmptyToolBreakResult(0, False, None, None, None)

    consecutive += 1

    # Log full diagnostics on first occurrence only; subsequent hits just
    # log the counter to reduce noise.
    if consecutive == 1:
        logger.warning(
            "%s Empty tool call detected (%d/%d): "
            "tools=%s, model=%s, error=%s, "
            "block_types=%s, cumulative_usage=%s",
            ctx.log_prefix,
            consecutive,
            _EMPTY_TOOL_CALL_LIMIT,
            empty_tools,
            sdk_msg.model,
            sdk_msg.error,
            [type(b).__name__ for b in sdk_msg.content],
            {
                "prompt": state.usage.prompt_tokens,
                "completion": state.usage.completion_tokens,
                "cache_read": state.usage.cache_read_tokens,
            },
        )
    else:
        logger.warning(
            "%s Empty tool call detected (%d/%d): tools=%s",
            ctx.log_prefix,
            consecutive,
            _EMPTY_TOOL_CALL_LIMIT,
            empty_tools,
        )

    if consecutive < _EMPTY_TOOL_CALL_LIMIT:
        return _EmptyToolBreakResult(consecutive, False, None, None, None)

    logger.error(
        "%s Circuit breaker: aborting stream after %d "
        "consecutive empty tool calls. "
        "This is likely caused by the model attempting "
        "to write content too large for a single tool "
        "call's output token limit. The model should "
        "write large files in chunks using bash_exec "
        "with cat >> (append).",
        ctx.log_prefix,
        consecutive,
    )
    error_msg = _CIRCUIT_BREAKER_ERROR_MSG
    error_code = "circuit_breaker_empty_tool_calls"
    _append_error_marker(ctx.session, error_msg, retryable=True)
    return _EmptyToolBreakResult(
        count=consecutive,
        tripped=True,
        error=StreamError(errorText=error_msg, code=error_code),
        error_msg=error_msg,
        error_code=error_code,
    )

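Reduced to its core, `_check_empty_tool_breaker` is a reset-on-success streak counter. A hypothetical standalone sketch of just that shape:

```python
LIMIT = 3  # stand-in for _EMPTY_TOOL_CALL_LIMIT

def step(count: int, is_empty_call: bool) -> tuple[int, bool]:
    """Return (new_count, tripped). Any healthy message resets the streak."""
    if not is_empty_call:
        return 0, False
    count += 1
    return count, count >= LIMIT

count, tripped = 0, False
for empty in (True, False, True, True, True):  # streak broken, then rebuilt
    count, tripped = step(count, empty)
assert tripped  # three consecutive empty calls after the reset
```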

async def _run_stream_attempt(
    ctx: _StreamContext,
@@ -1181,12 +1071,6 @@ async def _run_stream_attempt(
        accumulated_tool_calls=[],
    )
    ended_with_stream_error = False
    # Stores the error message used by _append_error_marker so the outer
    # retry loop can re-append the correct message after session rollback.
    stream_error_msg: str | None = None
    stream_error_code: str | None = None

    consecutive_empty_tool_calls = 0

    async with ClaudeSDKClient(options=state.options) as client:
        logger.info(
@@ -1277,43 +1161,18 @@ async def _run_stream_attempt(
                    "suppressing raw error text",
                    ctx.log_prefix,
                )
                stream_error_msg = FRIENDLY_TRANSIENT_MSG
                stream_error_code = "transient_api_error"
                _append_error_marker(
                    ctx.session,
                    stream_error_msg,
                    FRIENDLY_TRANSIENT_MSG,
                    retryable=True,
                )
                yield StreamError(
                    errorText=stream_error_msg,
                    code=stream_error_code,
                    errorText=FRIENDLY_TRANSIENT_MSG,
                    code="transient_api_error",
                )
                ended_with_stream_error = True
                break

            # Parallel tool execution: pre-launch every ToolUseBlock as an
            # asyncio.Task the moment its AssistantMessage arrives. The SDK
            # sends one AssistantMessage per tool call when issuing parallel
            # calls, so each message is pre-launched independently. The MCP
            # handlers will await the already-running task instead of executing
            # fresh, making all concurrent tool calls run in parallel.
            #
            # Also determine if the message is a tool-only batch (all content
            # items are ToolUseBlocks) — such messages have no text output yet,
            # so we skip the wait_for_stash flush below.
            is_tool_only = False
            if isinstance(sdk_msg, AssistantMessage) and sdk_msg.content:
                is_tool_only = True
                # NOTE: Pre-launches are sequential (each await completes
                # file-ref expansion before the next starts). This is fine
                # since expansion is typically sub-ms; a future optimisation
                # could gather all pre-launches concurrently.
                for tool_use in sdk_msg.content:
                    if isinstance(tool_use, ToolUseBlock):
                        await pre_launch_tool_call(tool_use.name, tool_use.input)
                    else:
                        is_tool_only = False

            # Race-condition fix: SDK hooks (PostToolUse) are
            # executed asynchronously via start_soon() — the next
            # message can arrive before the hook stashes output.
@@ -1327,12 +1186,15 @@ async def _run_stream_attempt(
            # AssistantMessages (each containing only
            # ToolUseBlocks), we must NOT wait/flush — the prior
            # tools are still executing concurrently.
            is_parallel_continuation = isinstance(sdk_msg, AssistantMessage) and all(
                isinstance(b, ToolUseBlock) for b in sdk_msg.content
            )
            if (
                state.adapter.has_unresolved_tool_calls
                and isinstance(sdk_msg, (AssistantMessage, ResultMessage))
                and not is_tool_only
                and not is_parallel_continuation
            ):
                if await wait_for_stash():
                if await wait_for_stash(timeout=0.5):
                    await asyncio.sleep(0)
                else:
                    logger.warning(
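The pre-launch idea in the comments above, boiled down: start each tool call as a task the moment it arrives, and let the later handler await the already-running work. A self-contained sketch, not the PR's actual API:

```python
import asyncio

queue: "asyncio.Queue[asyncio.Task]" = asyncio.Queue()

async def slow_tool(name: str) -> str:
    await asyncio.sleep(0.1)
    return f"{name} done"

async def main() -> None:
    # Pre-launch: both tools start running now, concurrently.
    for name in ("alpha", "beta"):
        queue.put_nowait(asyncio.create_task(slow_tool(name)))
    # Dispatch: each handler dequeues FIFO and awaits the running task.
    print(await queue.get_nowait())
    print(await queue.get_nowait())

asyncio.run(main())  # total wall time ~0.1 s, not 0.2 s
```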
@@ -1347,17 +1209,13 @@ async def _run_stream_attempt(
            if isinstance(sdk_msg, ResultMessage):
                logger.info(
                    "%s Received: ResultMessage %s "
                    "(unresolved=%d, current=%d, resolved=%d, "
                    "num_turns=%d, cost_usd=%s, result=%s)",
                    "(unresolved=%d, current=%d, resolved=%d)",
                    ctx.log_prefix,
                    sdk_msg.subtype,
                    len(state.adapter.current_tool_calls)
                    - len(state.adapter.resolved_tool_calls),
                    len(state.adapter.current_tool_calls),
                    len(state.adapter.resolved_tool_calls),
                    sdk_msg.num_turns,
                    sdk_msg.total_cost_usd,
                    (sdk_msg.result or "")[:200],
                )
                if sdk_msg.subtype in (
                    "error",
@@ -1414,18 +1272,6 @@ async def _run_stream_attempt(
                )
                entries_replaced = True

            # --- Hard circuit breaker for empty tool calls ---
            breaker = _check_empty_tool_breaker(
                sdk_msg, consecutive_empty_tool_calls, ctx, state
            )
            consecutive_empty_tool_calls = breaker.count
            if breaker.tripped and breaker.error is not None:
                stream_error_msg = breaker.error_msg
                stream_error_code = breaker.error_code
                yield breaker.error
                ended_with_stream_error = True
                break

            # --- Dispatch adapter responses ---
            for response in state.adapter.convert_message(sdk_msg):
                dispatched = _dispatch_response(
@@ -1506,10 +1352,8 @@ async def _run_stream_attempt(
    # to the client (StreamError yielded above), raise so the outer retry
    # loop can rollback session messages and set its error flags properly.
    if ended_with_stream_error:
        raise _HandledStreamError(
            "Stream error handled — StreamError already yielded",
            error_msg=stream_error_msg,
            code=stream_error_code,
        raise _TransientErrorHandled(
            "Transient API error handled — StreamError already yielded"
        )

@@ -1520,7 +1364,6 @@ async def stream_chat_completion_sdk(
    user_id: str | None = None,
    session: ChatSession | None = None,
    file_ids: list[str] | None = None,
    permissions: "CopilotPermissions | None" = None,
    **_kwargs: Any,
) -> AsyncIterator[StreamBaseResponse]:
    """Stream chat completion using Claude Agent SDK.
@@ -1766,13 +1609,7 @@ async def stream_chat_completion_sdk(

    yield StreamStart(messageId=message_id, sessionId=session_id)

    set_execution_context(
        user_id,
        session,
        sandbox=e2b_sandbox,
        sdk_cwd=sdk_cwd,
        permissions=permissions,
    )
    set_execution_context(user_id, session, sandbox=e2b_sandbox, sdk_cwd=sdk_cwd)

    # Fail fast when no API credentials are available at all.
    sdk_env = _build_sdk_env(session_id=session_id, user_id=user_id)
@@ -1798,11 +1635,8 @@ async def stream_chat_completion_sdk(
        on_compact=compaction.on_compact,
    )

    if permissions is not None:
        allowed, disallowed = apply_tool_permissions(permissions, use_e2b=use_e2b)
    else:
        allowed = get_copilot_tool_names(use_e2b=use_e2b)
        disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
    allowed = get_copilot_tool_names(use_e2b=use_e2b)
    disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)

    def _on_stderr(line: str) -> None:
        """Log a stderr line emitted by the Claude CLI subprocess."""
@@ -1912,12 +1746,6 @@ async def stream_chat_completion_sdk(
    )

    for attempt in range(_MAX_STREAM_ATTEMPTS):
        # Clear any stale stash signal from the previous attempt so
        # wait_for_stash() doesn't fire prematurely on a leftover event.
        reset_stash_event()
        # Reset tool-level circuit breaker so failures from a previous
        # (rolled-back) attempt don't carry over to the fresh attempt.
        reset_tool_failure_counters()
        if attempt > 0:
            logger.info(
                "%s Retrying with reduced context (%d/%d)",
@@ -1973,10 +1801,6 @@ async def stream_chat_completion_sdk(
                if not isinstance(event, StreamHeartbeat):
                    events_yielded += 1
                yield event
            # Cancel any pre-launched tasks that were never dispatched
            # by the SDK (e.g. edge-case SDK behaviour changes). Symmetric
            # with the three error-path await cancel_pending_tool_tasks() calls.
            await cancel_pending_tool_tasks()
            break  # Stream completed — exit retry loop
        except asyncio.CancelledError:
            logger.warning(
@@ -1985,42 +1809,26 @@ async def stream_chat_completion_sdk(
                attempt + 1,
                _MAX_STREAM_ATTEMPTS,
            )
            # Cancel any pre-launched tasks so they don't continue executing
            # against a rolled-back or abandoned session.
            await cancel_pending_tool_tasks()
            raise
        except _HandledStreamError as exc:
        except _TransientErrorHandled:
            # _run_stream_attempt already yielded a StreamError and
            # appended an error marker. We only need to rollback
            # session messages and set the error flag — do NOT set
            # stream_err so the post-loop code won't emit a
            # duplicate StreamError.
            logger.warning(
                "%s Stream error handled in attempt "
                "(attempt %d/%d, code=%s, events_yielded=%d)",
                "%s Transient error handled in stream attempt "
                "(attempt %d/%d, events_yielded=%d)",
                log_prefix,
                attempt + 1,
                _MAX_STREAM_ATTEMPTS,
                exc.code or "transient",
                events_yielded,
            )
            session.messages = session.messages[:pre_attempt_msg_count]
            # transcript_builder still contains entries from the aborted
            # attempt that no longer match session.messages. Skip upload
            # so a future --resume doesn't replay rolled-back content.
            skip_transcript_upload = True
            # Re-append the error marker so it survives the rollback
            # and is persisted by the finally block (see #2947655365).
            # Use the specific error message from the attempt (e.g.
            # circuit breaker msg) rather than always the generic one.
            _append_error_marker(
                session,
                exc.error_msg or FRIENDLY_TRANSIENT_MSG,
                retryable=True,
            )
            _append_error_marker(session, FRIENDLY_TRANSIENT_MSG, retryable=True)
            ended_with_stream_error = True
            # Cancel any pre-launched tasks from the failed attempt.
            await cancel_pending_tool_tasks()
            break
        except Exception as e:
            stream_err = e
@@ -2037,9 +1845,6 @@ async def stream_chat_completion_sdk(
                exc_info=True,
            )
            session.messages = session.messages[:pre_attempt_msg_count]
            # Cancel any pre-launched tasks from the failed attempt so they
            # don't continue executing against the rolled-back session.
            await cancel_pending_tool_tasks()
            if events_yielded > 0:
                # Events were already sent to the frontend and cannot be
                # unsent. Retrying would produce duplicate/inconsistent
@@ -2049,13 +1854,11 @@ async def stream_chat_completion_sdk(
                    log_prefix,
                    events_yielded,
                )
                skip_transcript_upload = True
                ended_with_stream_error = True
                break
            if not is_context_error:
                # Non-context errors (network, auth, rate-limit) should
                # not trigger compaction — surface the error immediately.
                skip_transcript_upload = True
                ended_with_stream_error = True
                break
            continue
@@ -2151,16 +1954,6 @@ async def stream_chat_completion_sdk(
            log_prefix,
            len(session.messages),
        )
    except GeneratorExit:
        # GeneratorExit is raised when the async generator is closed by the
        # caller (e.g. client disconnect, page refresh). We MUST release the
        # stream lock here because the ``finally`` block at the end of this
        # function may not execute when GeneratorExit propagates through nested
        # async generators. Without this, the lock stays held for its full TTL
        # and the user sees "Another stream is already active" on every retry.
        logger.warning("%s GeneratorExit — releasing stream lock", log_prefix)
        await lock.release()
        raise
    except BaseException as e:
        # Catch BaseException to handle both Exception and CancelledError
        # (CancelledError inherits from BaseException in Python 3.8+)

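The GeneratorExit handling being deleted above is worth seeing in isolation: when a consumer abandons an async generator, eagerly release held resources in the `except GeneratorExit` block, because nested `finally` blocks may never run. A reduced sketch with a stand-in lock (the real code holds an AsyncClusterLock with a TTL):

```python
import asyncio

class FakeLock:
    async def release(self) -> None:
        print("lock released")

async def stream(lock: FakeLock):
    try:
        while True:
            yield "event"
            await asyncio.sleep(0)
    except GeneratorExit:
        # Awaiting here is legal during aclose() as long as the generator
        # does not yield again; it must re-raise (or return) afterwards.
        await lock.release()
        raise

async def main() -> None:
    gen = stream(FakeLock())
    print(await gen.__anext__())
    await gen.aclose()  # triggers GeneratorExit inside the generator

asyncio.run(main())
```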
@@ -1,23 +1,21 @@
"""Unit tests for extracted service helpers.

Covers ``_is_prompt_too_long``, ``_reduce_context``, ``_iter_sdk_messages``,
``ReducedContext``, and the ``is_parallel_continuation`` logic.
and the ``ReducedContext`` named tuple.
"""

from __future__ import annotations

import asyncio
from collections.abc import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, patch

import pytest
from claude_agent_sdk import AssistantMessage, TextBlock, ToolUseBlock

from .conftest import build_test_transcript as _build_transcript
from .service import (
    ReducedContext,
    _is_prompt_too_long,
    _is_tool_only_message,
    _iter_sdk_messages,
    _reduce_context,
)
@@ -283,55 +281,3 @@ class TestIterSdkMessages:
        first = await gen.__anext__()
        assert first == "first"
        await gen.aclose()  # should cancel pending task cleanly


# ---------------------------------------------------------------------------
# is_parallel_continuation logic
# ---------------------------------------------------------------------------


class TestIsParallelContinuation:
    """Unit tests for the is_parallel_continuation expression in the streaming loop.

    Verifies the vacuous-truth guard (empty content must return False) and the
    boundary cases for mixed TextBlock+ToolUseBlock messages.
    """

    def _make_tool_block(self) -> MagicMock:
        block = MagicMock(spec=ToolUseBlock)
        return block

    def test_all_tool_use_blocks_is_parallel(self):
        """AssistantMessage with only ToolUseBlocks is a parallel continuation."""
        msg = MagicMock(spec=AssistantMessage)
        msg.content = [self._make_tool_block(), self._make_tool_block()]
        assert _is_tool_only_message(msg) is True

    def test_empty_content_is_not_parallel(self):
        """AssistantMessage with empty content must NOT be treated as parallel.

        Without the bool(sdk_msg.content) guard, all() on an empty iterable
        returns True via vacuous truth — this test ensures the guard is present.
        """
        msg = MagicMock(spec=AssistantMessage)
        msg.content = []
        assert _is_tool_only_message(msg) is False

    def test_mixed_text_and_tool_blocks_not_parallel(self):
        """AssistantMessage with text + tool blocks is NOT a parallel continuation."""
        msg = MagicMock(spec=AssistantMessage)
        text_block = MagicMock(spec=TextBlock)
        msg.content = [text_block, self._make_tool_block()]
        assert _is_tool_only_message(msg) is False

    def test_non_assistant_message_not_parallel(self):
        """Non-AssistantMessage types are never parallel continuations."""
        assert _is_tool_only_message("not a message") is False
        assert _is_tool_only_message(None) is False
        assert _is_tool_only_message(42) is False

    def test_single_tool_block_is_parallel(self):
        """Single ToolUseBlock AssistantMessage is a parallel continuation."""
        msg = MagicMock(spec=AssistantMessage)
        msg.content = [self._make_tool_block()]
        assert _is_tool_only_message(msg) is True

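These tests lean on `MagicMock(spec=...)`, which makes `isinstance()` checks pass without constructing real SDK objects; a minimal demonstration:

```python
from unittest.mock import MagicMock

class ToolUseBlock:  # stand-in for the SDK class
    pass

fake = MagicMock(spec=ToolUseBlock)
assert isinstance(fake, ToolUseBlock)  # spec wires __class__ for isinstance
```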
@@ -1,144 +0,0 @@
"""Claude Code subscription auth helpers.

Handles locating the SDK-bundled CLI binary, provisioning credentials from
environment variables, and validating that subscription auth is functional.
"""

import functools
import json
import logging
import os
import shutil
import subprocess

logger = logging.getLogger(__name__)


def find_bundled_cli() -> str:
    """Locate the Claude CLI binary bundled inside ``claude_agent_sdk``.

    Falls back to ``shutil.which("claude")`` if the SDK bundle is absent.
    """
    try:
        from claude_agent_sdk._internal.transport.subprocess_cli import (
            SubprocessCLITransport,
        )

        path = SubprocessCLITransport._find_bundled_cli(None)  # type: ignore[arg-type]
        if path:
            return str(path)
    except Exception:
        pass
    system_path = shutil.which("claude")
    if system_path:
        return system_path
    raise RuntimeError(
        "Claude CLI not found — neither the SDK-bundled binary nor a "
        "system-installed `claude` could be located."
    )


def provision_credentials_file() -> None:
    """Write ``~/.claude/.credentials.json`` from env when running headless.

    If ``CLAUDE_CODE_OAUTH_TOKEN`` is set (an OAuth *access* token obtained
    from ``claude auth status`` or extracted from the macOS keychain), this
    helper writes a minimal credentials file so the bundled CLI can
    authenticate without an interactive ``claude login``.

    A ``CLAUDE_CODE_REFRESH_TOKEN`` env var is optional but recommended —
    it lets the CLI silently refresh an expired access token.
    """
    access_token = os.environ.get("CLAUDE_CODE_OAUTH_TOKEN", "").strip()
    if not access_token:
        return

    creds_dir = os.path.expanduser("~/.claude")
    creds_path = os.path.join(creds_dir, ".credentials.json")

    # Don't overwrite an existing credentials file (e.g. from a volume mount).
    if os.path.exists(creds_path):
        logger.debug("Credentials file already exists at %s — skipping", creds_path)
        return

    os.makedirs(creds_dir, exist_ok=True)

    creds = {
        "claudeAiOauth": {
            "accessToken": access_token,
            "refreshToken": os.environ.get("CLAUDE_CODE_REFRESH_TOKEN", "").strip(),
            "expiresAt": 0,
            "scopes": [
                "user:inference",
                "user:profile",
                "user:sessions:claude_code",
            ],
        }
    }
    with open(creds_path, "w") as f:
        json.dump(creds, f)
    logger.info("Provisioned Claude credentials file at %s", creds_path)


@functools.cache
def validate_subscription() -> None:
    """Validate the bundled Claude CLI is reachable and authenticated.

    Cached so the blocking subprocess check runs at most once per process
    lifetime. On first call, also provisions ``~/.claude/.credentials.json``
    from the ``CLAUDE_CODE_OAUTH_TOKEN`` env var when available.
    """
    provision_credentials_file()

    cli = find_bundled_cli()
    result = subprocess.run(
        [cli, "--version"],
        capture_output=True,
        text=True,
        timeout=10,
    )
    if result.returncode != 0:
        raise RuntimeError(
            f"Claude CLI check failed (exit {result.returncode}): "
            f"{result.stderr.strip()}"
        )
    logger.info(
        "Claude Code subscription mode: CLI version %s",
        result.stdout.strip(),
    )

    # Verify the CLI is actually authenticated.
    auth_result = subprocess.run(
        [cli, "auth", "status"],
        capture_output=True,
        text=True,
        timeout=10,
        env={
            **os.environ,
            "ANTHROPIC_API_KEY": "",
            "ANTHROPIC_AUTH_TOKEN": "",
            "ANTHROPIC_BASE_URL": "",
        },
    )
    if auth_result.returncode != 0:
        raise RuntimeError(
            "Claude CLI is not authenticated. Either:\n"
            "  • Set CLAUDE_CODE_OAUTH_TOKEN env var (from `claude auth status` "
            "or macOS keychain), or\n"
            "  • Mount ~/.claude/.credentials.json into the container, or\n"
            "  • Run `claude login` inside the container."
        )
    try:
        status = json.loads(auth_result.stdout)
        if not status.get("loggedIn"):
            raise RuntimeError(
                "Claude CLI reports loggedIn=false. Set CLAUDE_CODE_OAUTH_TOKEN "
                "or run `claude login`."
            )
        logger.info(
            "Claude subscription auth: method=%s, email=%s",
            status.get("authMethod"),
            status.get("email"),
        )
    except json.JSONDecodeError:
        logger.warning("Could not parse `claude auth status` output")
@@ -1,96 +0,0 @@
"""Tests for the tool call circuit breaker in tool_adapter.py."""

import pytest

from backend.copilot.sdk.tool_adapter import (
    _MAX_CONSECUTIVE_TOOL_FAILURES,
    _check_circuit_breaker,
    _clear_tool_failures,
    _consecutive_tool_failures,
    _record_tool_failure,
)


@pytest.fixture(autouse=True)
def _reset_tracker():
    """Reset the circuit breaker tracker for each test."""
    token = _consecutive_tool_failures.set({})
    yield
    _consecutive_tool_failures.reset(token)


class TestCircuitBreaker:
    def test_no_trip_below_threshold(self):
        """Circuit breaker should not trip before reaching the limit."""
        args = {"file_path": "/tmp/test.txt"}
        for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES - 1):
            assert _check_circuit_breaker("write_file", args) is None
            _record_tool_failure("write_file", args)
        # Still under the limit
        assert _check_circuit_breaker("write_file", args) is None

    def test_trips_at_threshold(self):
        """Circuit breaker should trip after reaching the failure limit."""
        args = {"file_path": "/tmp/test.txt"}
        for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
            assert _check_circuit_breaker("write_file", args) is None
            _record_tool_failure("write_file", args)
        # Now it should trip
        result = _check_circuit_breaker("write_file", args)
        assert result is not None
        assert "STOP" in result
        assert "write_file" in result

    def test_different_args_tracked_separately(self):
        """Different args should have separate failure counters."""
        args_a = {"file_path": "/tmp/a.txt"}
        args_b = {"file_path": "/tmp/b.txt"}
        for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
            _record_tool_failure("write_file", args_a)
        # args_a should trip
        assert _check_circuit_breaker("write_file", args_a) is not None
        # args_b should NOT trip
        assert _check_circuit_breaker("write_file", args_b) is None

    def test_different_tools_tracked_separately(self):
        """Different tools should have separate failure counters."""
        args = {"file_path": "/tmp/test.txt"}
        for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
            _record_tool_failure("tool_a", args)
        # tool_a should trip
        assert _check_circuit_breaker("tool_a", args) is not None
        # tool_b with same args should NOT trip
        assert _check_circuit_breaker("tool_b", args) is None

    def test_empty_args_tracked(self):
        """Empty args ({}) — the exact failure pattern from the bug — should be tracked."""
        args = {}
        for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
            _record_tool_failure("write_file", args)
        assert _check_circuit_breaker("write_file", args) is not None

    def test_clear_resets_counter(self):
        """Clearing failures should reset the counter."""
        args = {}
        for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
            _record_tool_failure("write_file", args)
        _clear_tool_failures("write_file")
        assert _check_circuit_breaker("write_file", args) is None

    def test_success_clears_failures(self):
        """A successful call should reset the failure counter."""
        args = {}
        for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES - 1):
            _record_tool_failure("write_file", args)
        # Success clears failures
        _clear_tool_failures("write_file")
        # Should be able to fail again without tripping
        for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES - 1):
            _record_tool_failure("write_file", args)
        assert _check_circuit_breaker("write_file", args) is None

    def test_no_tracker_returns_none(self):
        """If tracker is not initialized, circuit breaker should not trip."""
        _consecutive_tool_failures.set(None)  # type: ignore[arg-type]
        _record_tool_failure("write_file", {})  # should not raise
        assert _check_circuit_breaker("write_file", {}) is None
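The deleted `_reset_tracker` fixture shows the canonical set/reset dance for ContextVars in tests: `set()` returns a Token, and `reset(token)` restores the exact previous value. In isolation:

```python
from contextvars import ContextVar

counter: "ContextVar[dict | None]" = ContextVar("counter", default=None)

token = counter.set({})         # set() returns a Token…
counter.get()["write_file"] = 2
counter.reset(token)            # …which restores the previous value exactly
assert counter.get() is None
```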
@@ -16,7 +16,6 @@ from typing import TYPE_CHECKING, Any
from claude_agent_sdk import create_sdk_mcp_server, tool

from backend.copilot.context import (
    _current_permissions,
    _current_project_dir,
    _current_sandbox,
    _current_sdk_cwd,
@@ -42,8 +41,6 @@ from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS
if TYPE_CHECKING:
    from e2b import AsyncSandbox

    from backend.copilot.permissions import CopilotPermissions

logger = logging.getLogger(__name__)

# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
@@ -53,14 +50,6 @@ _MCP_MAX_CHARS = 500_000
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"

# Map from tool_name -> Queue of pre-launched (task, args) pairs.
# Initialised per-session in set_execution_context() so concurrent sessions
# never share the same dict.
_TaskQueueItem = tuple[asyncio.Task[dict[str, Any]], dict[str, Any]]
_tool_task_queues: ContextVar[dict[str, asyncio.Queue[_TaskQueueItem]] | None] = (
    ContextVar("_tool_task_queues", default=None)
)

# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
@@ -77,23 +66,12 @@ _stash_event: ContextVar[asyncio.Event | None] = ContextVar(
    "_stash_event", default=None
)

# Circuit breaker: tracks consecutive tool failures to detect infinite retry loops.
# When a tool is called repeatedly with empty/identical args and keeps failing,
# this counter is incremented. After _MAX_CONSECUTIVE_TOOL_FAILURES identical
# failures the tool handler returns a hard-stop message instead of the raw error.
_MAX_CONSECUTIVE_TOOL_FAILURES = 3
_consecutive_tool_failures: ContextVar[dict[str, int]] = ContextVar(
    "_consecutive_tool_failures",
    default=None,  # type: ignore[arg-type]
)


def set_execution_context(
    user_id: str | None,
    session: ChatSession,
    sandbox: "AsyncSandbox | None" = None,
    sdk_cwd: str | None = None,
    permissions: "CopilotPermissions | None" = None,
) -> None:
    """Set the execution context for tool calls.

@@ -105,83 +83,14 @@ def set_execution_context(
        session: Current chat session.
        sandbox: Optional E2B sandbox; when set, bash_exec routes commands there.
        sdk_cwd: SDK working directory; used to scope tool-results reads.
        permissions: Optional capability filter restricting tools/blocks.
    """
    _current_user_id.set(user_id)
    _current_session.set(session)
    _current_sandbox.set(sandbox)
    _current_sdk_cwd.set(sdk_cwd or "")
    _current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
    _current_permissions.set(permissions)
    _pending_tool_outputs.set({})
    _stash_event.set(asyncio.Event())
    _tool_task_queues.set({})
    _consecutive_tool_failures.set({})


def reset_stash_event() -> None:
    """Clear any stale stash signal left over from a previous stream attempt.

    ``_stash_event`` is set once per session in ``set_execution_context`` and
    reused across retry attempts. A PostToolUse hook from a failed attempt may
    leave the event set; calling this at the start of each retry prevents
    ``wait_for_stash`` from returning prematurely on a stale signal.
    """
    event = _stash_event.get(None)
    if event is not None:
        event.clear()


async def cancel_pending_tool_tasks() -> None:
    """Cancel all queued pre-launched tasks for the current execution context.

    Call this when a stream attempt aborts (error, cancellation) to prevent
    pre-launched tasks from continuing to execute against a rolled-back session.
    Tasks that are already done are skipped; in-flight tasks are cancelled and
    awaited so that any cleanup (``finally`` blocks, DB rollbacks) completes
    before the next retry starts.
    """
    queues = _tool_task_queues.get()
    if not queues:
        return
    cancelled_tasks: list[asyncio.Task] = []
    for tool_name, queue in list(queues.items()):
        cancelled = 0
        while not queue.empty():
            task, _args = queue.get_nowait()
            if not task.done():
                task.cancel()
                cancelled_tasks.append(task)
                cancelled += 1
        if cancelled:
            logger.debug(
                "Cancelled %d pre-launched task(s) for tool '%s'", cancelled, tool_name
            )
    queues.clear()
    # Await all cancelled tasks so their cleanup (finally blocks, DB rollbacks)
    # completes before the next retry attempt starts new pre-launches.
    # Use a timeout to prevent hanging indefinitely if a task's cleanup is stuck.
    if cancelled_tasks:
        try:
            await asyncio.wait_for(
                asyncio.gather(*cancelled_tasks, return_exceptions=True),
                timeout=5.0,
            )
        except TimeoutError:
            logger.warning(
                "Timed out waiting for %d cancelled task(s) to clean up",
                len(cancelled_tasks),
            )


def reset_tool_failure_counters() -> None:
    """Reset all tool-level circuit breaker counters.

    Called at the start of each SDK retry attempt so that failure counts
    from a previous (rolled-back) attempt do not carry over and prematurely
    trip the breaker on a fresh attempt with different context.
    """
    _consecutive_tool_failures.set({})

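The cancel-then-await shape in `cancel_pending_tool_tasks` is worth isolating, since `task.cancel()` only requests cancellation; awaiting is what lets `finally` blocks finish. A standalone sketch (the 5-second bound is illustrative):

```python
import asyncio

async def cleanup(tasks: list[asyncio.Task]) -> None:
    for t in tasks:
        t.cancel()  # request cancellation…
    try:
        # …then await so each task's finally-blocks actually run before we
        # proceed; return_exceptions=True keeps CancelledError contained.
        await asyncio.wait_for(
            asyncio.gather(*tasks, return_exceptions=True), timeout=5.0
        )
    except TimeoutError:
        print("a task's cleanup is stuck")
```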

def pop_pending_tool_output(tool_name: str) -> str | None:
@@ -246,13 +155,12 @@ async def wait_for_stash(timeout: float = 2.0) -> bool:
    by waiting on the ``_stash_event``, which is signaled by
    :func:`stash_pending_tool_output`.

    Uses ``asyncio.Event.wait()`` so it returns the instant the hook signals —
    the timeout is purely a safety net for the case where the hook never fires.
    Returns ``True`` if the stash signal was received, ``False`` on timeout.
    Returns ``True`` if a stash signal was received, ``False`` on timeout.

    The 2.0 s default was chosen to accommodate slower tool startup in cloud
    sandboxes while still failing fast when the hook genuinely will not fire.
    With the parallel pre-launch path, hooks typically fire well under 1 ms.
    The 2.0 s default was chosen based on production metrics: the original
    0.5 s caused frequent timeouts under load (parallel tool calls, large
    outputs). 2.0 s gives a comfortable margin while still failing fast
    when the hook genuinely will not fire.
    """
    event = _stash_event.get(None)
    if event is None:
@@ -261,7 +169,7 @@ async def wait_for_stash(timeout: float = 2.0) -> bool:
    if event.is_set():
        event.clear()
        return True
    # Slow path: block until the hook signals or the safety timeout expires.
    # Slow path: wait for the hook to signal.
    try:
        async with asyncio.timeout(timeout):
            await event.wait()
@@ -271,82 +179,6 @@ async def wait_for_stash(timeout: float = 2.0) -> bool:
    return False

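`wait_for_stash` is the standard fast-path/slow-path event wait; as a standalone sketch (Python 3.11+ for `asyncio.timeout`):

```python
import asyncio

async def wait_signal(event: asyncio.Event, timeout: float = 2.0) -> bool:
    if event.is_set():  # fast path: the signal already arrived
        event.clear()
        return True
    try:
        async with asyncio.timeout(timeout):  # slow path with a safety net
            await event.wait()
        event.clear()
        return True
    except TimeoutError:
        return False
```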
async def pre_launch_tool_call(tool_name: str, args: dict[str, Any]) -> None:
    """Pre-launch a tool as a background task so parallel calls run concurrently.

    Called when an AssistantMessage with ToolUseBlocks is received, before the
    SDK dispatches the MCP tool/call requests. The tool_handler will await the
    pre-launched task instead of executing fresh.

    The tool_name may include an MCP prefix (e.g. ``mcp__copilot__run_block``);
    the prefix is stripped automatically before looking up the tool.

    Ordering guarantee: the Claude Agent SDK dispatches MCP ``tools/call`` requests
    in the same order as the ToolUseBlocks appear in the AssistantMessage.
    Pre-launched tasks are queued FIFO per tool name, so the N-th handler for a
    given tool name dequeues the N-th pre-launched task — result and args always
    correspond when the SDK preserves order (which it does in the current SDK).
    """
    queues = _tool_task_queues.get()
    if queues is None:
        return

    # Strip the MCP server prefix (e.g. "mcp__copilot__") to get the bare tool name.
    # Use removeprefix so tool names that themselves contain "__" are handled correctly.
    bare_name = tool_name.removeprefix(MCP_TOOL_PREFIX)

    base_tool = TOOL_REGISTRY.get(bare_name)
    if base_tool is None:
        return

    user_id, session = get_execution_context()
    if session is None:
        return

    # Expand @@agptfile: references before launching the task.
    # The _truncating wrapper (which normally handles expansion) runs AFTER
    # pre_launch_tool_call — the pre-launched task would otherwise receive raw
    # @@agptfile: tokens and fail to resolve them inside _execute_tool_sync.
    # Use _build_input_schema (same path as _truncating) for schema-aware expansion.
    input_schema: dict[str, Any] | None
    try:
        input_schema = _build_input_schema(base_tool)
    except Exception:
        input_schema = None  # schema unavailable — skip schema-aware expansion
    try:
        args = await expand_file_refs_in_args(
            args, user_id, session, input_schema=input_schema
        )
    except FileRefExpansionError as exc:
        logger.warning(
            "pre_launch_tool_call: @@agptfile expansion failed for %s: %s — skipping pre-launch",
            bare_name,
            exc,
        )
        return

    task = asyncio.create_task(_execute_tool_sync(base_tool, user_id, session, args))
    # Log unhandled exceptions so "Task exception was never retrieved" warnings
    # do not pollute stderr when a task is pre-launched but never dequeued.
    task.add_done_callback(
        lambda t, name=bare_name: (
            logger.warning(
                "Pre-launched task for %s raised unhandled: %s",
                name,
                t.exception(),
            )
            if not t.cancelled() and t.exception()
            else None
        )
    )

    if bare_name not in queues:
        queues[bare_name] = asyncio.Queue[_TaskQueueItem]()
    # Store (task, args) so the handler can log a warning if the SDK dispatches
    # calls in a different order than the ToolUseBlocks appeared in the message.
    queues[bare_name].put_nowait((task, args))

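The `add_done_callback` trick above silences Python's "Task exception was never retrieved" warning for tasks that may never be awaited: retrieving `t.exception()` in the callback marks the exception as handled. Minimal form:

```python
import asyncio
import logging

logger = logging.getLogger(__name__)

def launch_logged(coro, name: str) -> asyncio.Task:
    task = asyncio.create_task(coro)
    # Calling t.exception() here marks the exception as retrieved.
    task.add_done_callback(
        lambda t: logger.warning("%s raised: %s", name, t.exception())
        if not t.cancelled() and t.exception()
        else None
    )
    return task
```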

async def _execute_tool_sync(
    base_tool: BaseTool,
    user_id: str | None,
@@ -355,10 +187,8 @@ ) -> dict[str, Any]:
    """Execute a tool synchronously and return MCP-formatted response.

    Note: ``@@agptfile:`` expansion should be performed by the caller before
    invoking this function. For the normal (non-parallel) path it is handled
    by the ``_truncating`` wrapper; for the pre-launched parallel path it is
    handled in :func:`pre_launch_tool_call` before the task is created.
    Note: ``@@agptfile:`` expansion is handled upstream in the ``_truncating`` wrapper
    so all registered handlers (BaseTool, E2B, Read) expand uniformly.
    """
    effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
    result = await base_tool.execute(
@@ -387,66 +217,6 @@ def _mcp_error(message: str) -> dict[str, Any]:
    }


def _failure_key(tool_name: str, args: dict[str, Any]) -> str:
    """Compute a stable fingerprint for (tool_name, args) used by the circuit breaker."""
    args_key = json.dumps(args, sort_keys=True, default=str)
    return f"{tool_name}:{args_key}"


def _check_circuit_breaker(tool_name: str, args: dict[str, Any]) -> str | None:
    """Check if a tool has hit the consecutive failure limit.

    Tracks failures keyed by (tool_name, args_fingerprint). Returns an error
    message if the circuit breaker has tripped, or None if the call should proceed.
    """
    tracker = _consecutive_tool_failures.get(None)
    if tracker is None:
        return None

    key = _failure_key(tool_name, args)
    count = tracker.get(key, 0)
    if count >= _MAX_CONSECUTIVE_TOOL_FAILURES:
        logger.warning(
            "Circuit breaker tripped for tool %s after %d consecutive "
            "identical failures (args=%s)",
            tool_name,
            count,
            key[len(tool_name) + 1 :][:200],
        )
        return (
            f"STOP: Tool '{tool_name}' has failed {count} consecutive times with "
            f"the same arguments. Do NOT retry this tool call. "
            f"If you were trying to write content to a file, instead respond with "
            f"the content directly as a text message to the user."
        )
    return None


def _record_tool_failure(tool_name: str, args: dict[str, Any]) -> None:
    """Record a tool failure for circuit breaker tracking."""
    tracker = _consecutive_tool_failures.get(None)
    if tracker is None:
        return
    key = _failure_key(tool_name, args)
    tracker[key] = tracker.get(key, 0) + 1


def _clear_tool_failures(tool_name: str) -> None:
    """Clear failure tracking for a tool on success.

    Clears ALL args variants for the tool, not just the successful call's args.
    This gives the tool a "fresh start" on any success, which is appropriate for
    the primary use case (detecting infinite loops with identical failing args).
    """
    tracker = _consecutive_tool_failures.get(None)
    if tracker is None:
        return
    # Clear all entries for this tool name
    keys_to_remove = [k for k in tracker if k.startswith(f"{tool_name}:")]
    for k in keys_to_remove:
        del tracker[k]

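`_failure_key`'s fingerprint relies on `sort_keys=True` so that logically equal argument dicts collide on purpose, regardless of insertion order (`default=str` covers non-JSON values):

```python
import json

a = {"file_path": "/tmp/x", "mode": "w"}
b = {"mode": "w", "file_path": "/tmp/x"}  # same args, different order

def key(args: dict) -> str:
    return json.dumps(args, sort_keys=True, default=str)

assert key(a) == key(b)  # one counter per logical argument set
```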
|
def create_tool_handler(base_tool: BaseTool):
    """Create an async handler function for a BaseTool.

@@ -455,83 +225,7 @@ def create_tool_handler(base_tool: BaseTool):
    """

    async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
        """Execute the wrapped tool and return MCP-formatted response.

        If a pre-launched task exists (from parallel tool pre-launch in the
        message loop), await it instead of executing fresh.
        """
        queues = _tool_task_queues.get()
        if queues and base_tool.name in queues:
            queue = queues[base_tool.name]
            if not queue.empty():
                task, launch_args = queue.get_nowait()
                # Sanity-check: warn if the args don't match — this can happen
                # if the SDK dispatches tool calls in a different order than the
                # ToolUseBlocks appeared in the AssistantMessage (unlikely but
                # could occur in future SDK versions or with SDK bugs).
                # We compare full values (not just keys) so that two run_block
                # calls with different block_id values are caught even though
                # both have the same key set.
                if launch_args != args:
                    logger.warning(
                        "Pre-launched task for %s: arg mismatch "
                        "(launch_keys=%s, call_keys=%s) — cancelling "
                        "pre-launched task and falling back to direct execution",
                        base_tool.name,
                        (
                            sorted(launch_args.keys())
                            if isinstance(launch_args, dict)
                            else type(launch_args).__name__
                        ),
                        (
                            sorted(args.keys())
                            if isinstance(args, dict)
                            else type(args).__name__
                        ),
                    )
                    if not task.done():
                        task.cancel()
                        # Await cancellation to prevent duplicate concurrent
                        # execution for blocks with side effects.
                        try:
                            await task
                        except (asyncio.CancelledError, Exception):
                            pass
                    # Fall through to the direct-execution path below.
                else:
                    # Args match — await the pre-launched task.
                    try:
                        result = await task
                    except asyncio.CancelledError:
                        # Re-raise: CancelledError may be propagating from the
                        # outer streaming loop being cancelled — swallowing it
                        # would mask the cancellation and prevent proper cleanup.
                        logger.warning(
                            "Pre-launched tool %s was cancelled — re-raising",
                            base_tool.name,
                        )
                        raise
                    except Exception as e:
                        logger.error(
                            "Pre-launched tool %s failed: %s",
                            base_tool.name,
                            e,
                            exc_info=True,
                        )
                        return _mcp_error(
                            f"Failed to execute {base_tool.name}. "
                            "Check server logs for details."
                        )

                    # Pre-truncate the result so the _truncating wrapper (which
                    # wraps this handler) receives an already-within-budget
                    # value. _truncating handles stashing — we must NOT stash
                    # here or the output will be appended twice to the FIFO
                    # queue and pop_pending_tool_output would return a duplicate
                    # entry on the second call for the same tool.
                    return truncate(result, _MCP_MAX_CHARS)

        # No pre-launched task — execute directly (fallback for non-parallel calls).
        """Execute the wrapped tool and return MCP-formatted response."""
        user_id, session = get_execution_context()

        if session is None:
@@ -540,12 +234,8 @@ def create_tool_handler(base_tool: BaseTool):
        try:
            return await _execute_tool_sync(base_tool, user_id, session, args)
        except Exception as e:
            logger.error(
                "Error executing tool %s: %s", base_tool.name, e, exc_info=True
            )
            return _mcp_error(
                f"Failed to execute {base_tool.name}. Check server logs for details."
            )
            logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
            return _mcp_error(f"Failed to execute {base_tool.name}: {e}")

    return tool_handler

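
# Illustrative sketch (not part of the diff above) of the queue contract the
# handler relies on: one FIFO asyncio.Queue of (task, launch_args) tuples per
# tool name. The real queues live in the _tool_task_queues ContextVar; the
# plain module-level dict here is an assumption to keep the sketch
# self-contained.
import asyncio

sketch_queues: dict[str, asyncio.Queue] = {}


def pre_launch_sketch(tool_name: str, args: dict[str, Any], coro) -> None:
    # Enqueue at ToolUseBlock time so execution starts before SDK dispatch.
    sketch_queues.setdefault(tool_name, asyncio.Queue()).put_nowait(
        (asyncio.ensure_future(coro), args)
    )


async def consume_sketch(tool_name: str, args: dict[str, Any]):
    # Handler side: pop FIFO, verify args, await on match, else cancel.
    queue = sketch_queues.get(tool_name)
    if queue is not None and not queue.empty():
        task, launch_args = queue.get_nowait()
        if launch_args == args:
            return await task
        task.cancel()  # mismatch: fall back to direct execution
    return None  # caller executes the tool directly
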
@@ -668,15 +358,6 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
        Applied once to every registered tool."""

        async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
            # Circuit breaker: stop infinite retry loops with identical args.
            # Use the original (pre-expansion) args for fingerprinting so
            # check and record always use the same key — @@agptfile:
            # expansion mutates args, which would cause a key mismatch.
            original_args = args
            stop_msg = _check_circuit_breaker(tool_name, original_args)
            if stop_msg:
                return _mcp_error(stop_msg)

            user_id, session = get_execution_context()
            if session is not None:
                try:
@@ -684,7 +365,6 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
                        args, user_id, session, input_schema=input_schema
                    )
                except FileRefExpansionError as exc:
                    _record_tool_failure(tool_name, original_args)
                    return _mcp_error(
                        f"@@agptfile: reference could not be resolved: {exc}. "
                        "Ensure the file exists before referencing it. "
@@ -694,12 +374,6 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
            result = await fn(args)
            truncated = truncate(result, _MCP_MAX_CHARS)

            # Track consecutive failures for circuit breaker
            if truncated.get("isError"):
                _record_tool_failure(tool_name, original_args)
            else:
                _clear_tool_failures(tool_name)

            # Stash the text so the response adapter can forward our
            # middle-out truncated version to the frontend instead of the
            # SDK's head-truncated version (for outputs >~100 KB the SDK

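
# Illustrative sketch (not part of the diff above) of the middle-out idea the
# comment contrasts with the SDK's head truncation: keep head and tail, elide
# the middle. The real truncate() in backend.util.truncate may differ in
# marker text and budget accounting; both are assumptions here.
def middle_out_sketch(text: str, max_chars: int) -> str:
    marker = "\n...[truncated]...\n"
    if len(text) <= max_chars:
        return text
    keep = max(0, max_chars - len(marker))
    head = keep // 2
    return text[:head] + marker + text[len(text) - (keep - head):]
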
@@ -1,26 +1,16 @@
"""Tests for tool_adapter helpers: truncation, stash, context vars, parallel pre-launch."""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
"""Tests for tool_adapter helpers: truncation, stash, context vars."""

import pytest

from backend.copilot.context import get_sdk_cwd
from backend.copilot.response_model import StreamToolOutputAvailable
from backend.copilot.sdk.file_ref import FileRefExpansionError
from backend.util.truncate import truncate

from .tool_adapter import (
    _MCP_MAX_CHARS,
    _text_from_mcp_result,
    cancel_pending_tool_tasks,
    create_tool_handler,
    pop_pending_tool_output,
    pre_launch_tool_call,
    reset_stash_event,
    set_execution_context,
    stash_pending_tool_output,
    wait_for_stash,
)

# ---------------------------------------------------------------------------
@@ -130,69 +120,6 @@ class TestToolOutputStash:
        assert pop_pending_tool_output("a") == "alpha"


# ---------------------------------------------------------------------------
# reset_stash_event / wait_for_stash
# ---------------------------------------------------------------------------


class TestResetStashEvent:
    """Tests for reset_stash_event — the stale-signal fix for retry attempts."""

    @pytest.fixture(autouse=True)
    def _init_context(self):
        set_execution_context(
            user_id="test",
            session=None,  # type: ignore[arg-type]
            sandbox=None,
        )

    @pytest.mark.asyncio
    async def test_reset_clears_stale_signal(self):
        """After reset, wait_for_stash does NOT return immediately (blocks until timeout)."""
        # Simulate a stale signal left by a failed attempt's PostToolUse hook.
        stash_pending_tool_output("some_tool", "stale output")
        # The stash_pending_tool_output call sets the event.
        # Now reset it — simulating start of a new retry attempt.
        reset_stash_event()
        # wait_for_stash should block and time out since the event was cleared.
        result = await wait_for_stash(timeout=0.05)
        assert result is False, (
            "wait_for_stash should have timed out after reset_stash_event, "
            "but it returned True — stale signal was not cleared"
        )

    @pytest.mark.asyncio
    async def test_wait_returns_true_when_signaled_after_reset(self):
        """After reset, a new stash signal is correctly detected."""
        reset_stash_event()

        async def _signal_after_delay():
            await asyncio.sleep(0.01)
            stash_pending_tool_output("tool", "fresh output")

        asyncio.create_task(_signal_after_delay())
        result = await wait_for_stash(timeout=1.0)
        assert result is True

    @pytest.mark.asyncio
    async def test_retry_scenario_stale_event_does_not_fire_prematurely(self):
        """Simulates: attempt 1 leaves event set → reset → attempt 2 waits correctly."""
        # Attempt 1: hook fires and sets the event
        stash_pending_tool_output("t", "attempt-1-output")
        # Pop it so the stash is empty (simulating normal consumption)
        pop_pending_tool_output("t")

        # Between attempts: reset (as service.py does before each retry)
        reset_stash_event()

        # Attempt 2: wait_for_stash should NOT return True immediately
        result = await wait_for_stash(timeout=0.05)
        assert result is False, (
            "Stale event from attempt 1 caused wait_for_stash to return "
            "prematurely in attempt 2"
        )

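
# Illustrative sketch (not part of the diff above): the stale-signal semantics
# under test are plain asyncio.Event mechanics. The real reset_stash_event and
# wait_for_stash live in tool_adapter and may track extra per-tool state; the
# single module-level event here is an assumption.
_stash_event_sketch = asyncio.Event()


def reset_stash_event_sketch() -> None:
    # Clear any signal left by a previous (failed) attempt so the next
    # wait blocks until a fresh stash arrives.
    _stash_event_sketch.clear()


async def wait_for_stash_sketch(timeout: float) -> bool:
    try:
        await asyncio.wait_for(_stash_event_sketch.wait(), timeout)
        return True
    except asyncio.TimeoutError:
        return False
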
# ---------------------------------------------------------------------------
# _truncating wrapper (integration via create_copilot_mcp_server)
# ---------------------------------------------------------------------------
@@ -241,534 +168,3 @@ class TestTruncationAndStashIntegration:
        text = _text_from_mcp_result(truncated)
        assert len(text) < len(big_text)
        assert len(str(truncated)) <= _MCP_MAX_CHARS


# ---------------------------------------------------------------------------
# Parallel pre-launch infrastructure
# ---------------------------------------------------------------------------


def _make_mock_tool(name: str, output: str = "result") -> MagicMock:
    """Return a BaseTool mock that returns a successful StreamToolOutputAvailable."""
    tool = MagicMock()
    tool.name = name
    tool.parameters = {"properties": {}, "required": []}
    tool.execute = AsyncMock(
        return_value=StreamToolOutputAvailable(
            toolCallId="test-id",
            output=output,
            toolName=name,
            success=True,
        )
    )
    return tool


def _make_mock_session() -> MagicMock:
    """Return a minimal ChatSession mock."""
    return MagicMock()


def _init_ctx(session=None):
    set_execution_context(
        user_id="user-1",
        session=session,  # type: ignore[arg-type]
        sandbox=None,
    )

class TestPreLaunchToolCall:
    """Tests for pre_launch_tool_call and the queue-based parallel dispatch."""

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.asyncio
    async def test_unknown_tool_is_silently_ignored(self):
        """pre_launch_tool_call does nothing for tools not in TOOL_REGISTRY."""
        # Should not raise even if the tool name is completely unknown
        await pre_launch_tool_call("nonexistent_tool", {})

    @pytest.mark.asyncio
    async def test_mcp_prefix_stripped_before_registry_lookup(self):
        """mcp__copilot__run_block is looked up as 'run_block'."""
        mock_tool = _make_mock_tool("run_block")
        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("mcp__copilot__run_block", {"block_id": "b1"})

            # The task was enqueued — mock_tool.execute should be called once
            # (may not complete immediately but should start)
            await asyncio.sleep(0)  # yield to event loop
            mock_tool.execute.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_bare_tool_name_without_prefix(self):
        """Tool names without __ separator are looked up as-is."""
        mock_tool = _make_mock_tool("run_block")
        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("run_block", {"block_id": "b1"})

            await asyncio.sleep(0)
            mock_tool.execute.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_task_enqueued_fifo_for_same_tool(self):
        """Two pre-launched calls for the same tool name are enqueued FIFO."""
        results = []

        async def slow_execute(*args, **kwargs):
            results.append(len(results))
            return StreamToolOutputAvailable(
                toolCallId="id",
                output=str(len(results) - 1),
                toolName="t",
                success=True,
            )

        mock_tool = _make_mock_tool("t")
        mock_tool.execute = AsyncMock(side_effect=slow_execute)

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"t": mock_tool},
        ):
            await pre_launch_tool_call("t", {"n": 1})
            await pre_launch_tool_call("t", {"n": 2})
            await asyncio.sleep(0)

        assert mock_tool.execute.await_count == 2

    @pytest.mark.asyncio
    async def test_file_ref_expansion_failure_skips_pre_launch(self):
        """When @@agptfile: expansion fails, pre_launch_tool_call skips the task.

        The handler should then fall back to direct execution (which will also
        fail with a proper MCP error via _truncating's own expansion).
        """
        mock_tool = _make_mock_tool("run_block", output="should-not-execute")

        with (
            patch(
                "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
                {"run_block": mock_tool},
            ),
            patch(
                "backend.copilot.sdk.tool_adapter.expand_file_refs_in_args",
                AsyncMock(side_effect=FileRefExpansionError("@@agptfile:missing.txt")),
            ),
        ):
            # Should not raise — expansion failure is handled gracefully
            await pre_launch_tool_call("run_block", {"text": "@@agptfile:missing.txt"})
            await asyncio.sleep(0)

            # No task was pre-launched — execute was not called
            mock_tool.execute.assert_not_awaited()

class TestCreateToolHandlerParallel:
    """Tests for create_tool_handler using pre-launched tasks."""

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.asyncio
    async def test_handler_uses_prelaunched_task(self):
        """Handler pops and awaits the pre-launched task rather than re-executing."""
        mock_tool = _make_mock_tool("run_block", output="pre-launched result")

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("run_block", {"block_id": "b1"})
            await asyncio.sleep(0)  # let task start

            handler = create_tool_handler(mock_tool)
            result = await handler({"block_id": "b1"})

        assert result["isError"] is False
        text = result["content"][0]["text"]
        assert "pre-launched result" in text
        # Should only have been called once (the pre-launched task), not twice
        mock_tool.execute.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_handler_does_not_double_stash_for_prelaunched_task(self):
        """Pre-launched task result must NOT be stashed by tool_handler directly.

        The _truncating wrapper wraps tool_handler and handles stashing after
        tool_handler returns. If tool_handler also stashed, the output would be
        appended twice to the FIFO queue and pop_pending_tool_output would return
        a duplicate on the second call.

        This test calls tool_handler directly (without _truncating) and asserts
        that nothing was stashed — confirming stashing is deferred to _truncating.
        """
        mock_tool = _make_mock_tool("run_block", output="stash-me")

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("run_block", {"block_id": "b1"})
            await asyncio.sleep(0)

            handler = create_tool_handler(mock_tool)
            result = await handler({"block_id": "b1"})

        assert result["isError"] is False
        assert "stash-me" in result["content"][0]["text"]
        # tool_handler must NOT stash — _truncating (which wraps handler) does it.
        # Calling pop here (without going through _truncating) should return None.
        not_stashed = pop_pending_tool_output("run_block")
        assert not_stashed is None, (
            "tool_handler must not stash directly — _truncating handles stashing "
            "to prevent double-stash in the FIFO queue"
        )

    @pytest.mark.asyncio
    async def test_handler_falls_back_when_queue_empty(self):
        """When no pre-launched task exists, handler executes directly."""
        mock_tool = _make_mock_tool("run_block", output="direct result")

        # Don't call pre_launch_tool_call — queue is empty
        handler = create_tool_handler(mock_tool)
        result = await handler({"block_id": "b1"})

        assert result["isError"] is False
        text = result["content"][0]["text"]
        assert "direct result" in text
        mock_tool.execute.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_handler_cancelled_error_propagates(self):
        """CancelledError from a pre-launched task is re-raised to preserve cancellation semantics."""
        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=asyncio.CancelledError())

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("run_block", {"block_id": "b1"})
            await asyncio.sleep(0)

            handler = create_tool_handler(mock_tool)
            with pytest.raises(asyncio.CancelledError):
                await handler({"block_id": "b1"})

    @pytest.mark.asyncio
    async def test_handler_exception_returns_mcp_error(self):
        """Exception from a pre-launched task is caught and returned as MCP error."""
        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=RuntimeError("block exploded"))

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("run_block", {"block_id": "b1"})
            await asyncio.sleep(0)

            handler = create_tool_handler(mock_tool)
            result = await handler({"block_id": "b1"})

        assert result["isError"] is True
        assert "Failed to execute run_block" in result["content"][0]["text"]

    @pytest.mark.asyncio
    async def test_two_same_tool_calls_dispatched_in_order(self):
        """Two pre-launched tasks for the same tool are consumed in FIFO order."""
        call_order = []

        async def execute_with_tag(*args, **kwargs):
            tag = kwargs.get("block_id", "?")
            call_order.append(tag)
            return StreamToolOutputAvailable(
                toolCallId="id", output=f"out-{tag}", toolName="run_block", success=True
            )

        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=execute_with_tag)

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("run_block", {"block_id": "first"})
            await pre_launch_tool_call("run_block", {"block_id": "second"})
            await asyncio.sleep(0)

            handler = create_tool_handler(mock_tool)
            r1 = await handler({"block_id": "first"})
            r2 = await handler({"block_id": "second"})

        assert "out-first" in r1["content"][0]["text"]
        assert "out-second" in r2["content"][0]["text"]
        assert call_order == [
            "first",
            "second",
        ], f"Expected FIFO dispatch order but got {call_order}"

    @pytest.mark.asyncio
    async def test_arg_mismatch_falls_back_to_direct_execution(self):
        """When pre-launched args differ from SDK args, handler cancels pre-launched
        task and falls back to direct execution with the correct args."""
        mock_tool = _make_mock_tool("run_block", output="direct-result")

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            # Pre-launch with args {"block_id": "wrong"}
            await pre_launch_tool_call("run_block", {"block_id": "wrong"})
            await asyncio.sleep(0)

            # SDK dispatches with different args
            handler = create_tool_handler(mock_tool)
            result = await handler({"block_id": "correct"})

        assert result["isError"] is False
        # The tool was called twice: once by pre-launch (wrong args), once by
        # direct fallback (correct args). The result should come from the
        # direct execution path.
        assert mock_tool.execute.await_count == 2

    @pytest.mark.asyncio
    async def test_no_session_falls_back_gracefully(self):
        """When session is None and no pre-launched task, handler returns MCP error."""
        mock_tool = _make_mock_tool("run_block")
        # session=None means get_execution_context returns (user_id, None)
        set_execution_context(user_id="u", session=None, sandbox=None)  # type: ignore[arg-type]

        handler = create_tool_handler(mock_tool)
        result = await handler({"block_id": "b1"})

        assert result["isError"] is True
        assert "session" in result["content"][0]["text"].lower()

# ---------------------------------------------------------------------------
# cancel_pending_tool_tasks
# ---------------------------------------------------------------------------


class TestCancelPendingToolTasks:
    """Tests for cancel_pending_tool_tasks — the stream-abort cleanup helper."""

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.asyncio
    async def test_cancels_queued_tasks(self):
        """Queued tasks are cancelled and the queue is cleared."""
        ran = False

        async def never_run(*_args, **_kwargs):
            nonlocal ran
            await asyncio.sleep(10)  # long enough to still be pending
            ran = True

        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=never_run)

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("run_block", {"block_id": "b1"})
            await asyncio.sleep(0)  # let task start
            await cancel_pending_tool_tasks()
            await asyncio.sleep(0)  # let cancellation propagate

        assert not ran, "Task should have been cancelled before completing"

    @pytest.mark.asyncio
    async def test_noop_when_no_tasks_queued(self):
        """cancel_pending_tool_tasks does not raise when queues are empty."""
        await cancel_pending_tool_tasks()  # should not raise

    @pytest.mark.asyncio
    async def test_handler_does_not_find_cancelled_task(self):
        """After cancel, tool_handler falls back to direct execution."""
        mock_tool = _make_mock_tool("run_block", output="direct-fallback")

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("run_block", {"block_id": "b1"})
            await asyncio.sleep(0)
            await cancel_pending_tool_tasks()

            # Queue is now empty — handler should execute directly
            handler = create_tool_handler(mock_tool)
            result = await handler({"block_id": "b1"})

        assert result["isError"] is False
        assert "direct-fallback" in result["content"][0]["text"]


# ---------------------------------------------------------------------------
# Concurrent / parallel pre-launch scenarios
# ---------------------------------------------------------------------------

class TestAllParallelToolsPrelaunchedIndependently:
    """Simulate SDK sending N separate AssistantMessages for the same tool concurrently."""

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.asyncio
    async def test_all_parallel_tools_prelaunched_independently(self):
        """5 pre-launches for the same tool all enqueue independently and run concurrently.

        Each task sleeps for PER_TASK_S seconds. If they ran sequentially the total
        wall time would be ~5*PER_TASK_S. Running concurrently it should finish in
        roughly PER_TASK_S (plus scheduling overhead).
        """
        PER_TASK_S = 0.05
        N = 5
        started: list[int] = []
        finished: list[int] = []

        async def slow_execute(*args, **kwargs):
            idx = len(started)
            started.append(idx)
            await asyncio.sleep(PER_TASK_S)
            finished.append(idx)
            return StreamToolOutputAvailable(
                toolCallId=f"id-{idx}",
                output=f"result-{idx}",
                toolName="bash_exec",
                success=True,
            )

        mock_tool = _make_mock_tool("bash_exec")
        mock_tool.execute = AsyncMock(side_effect=slow_execute)

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"bash_exec": mock_tool},
        ):
            for i in range(N):
                await pre_launch_tool_call("bash_exec", {"cmd": f"echo {i}"})

            # Measure only the concurrent execution window, not pre-launch overhead.
            # Starting the timer here avoids false failures on slow CI runners where
            # the pre_launch_tool_call setup takes longer than the concurrent sleep.
            t0 = asyncio.get_running_loop().time()
            await asyncio.sleep(PER_TASK_S * 2)
            elapsed = asyncio.get_running_loop().time() - t0

        assert mock_tool.execute.await_count == N
        assert len(finished) == N
        # Wall time of the sleep window should be well under N * PER_TASK_S
        # (sequential would be ~0.25s; concurrent finishes in ~PER_TASK_S = 0.05s)
        assert elapsed < N * PER_TASK_S, (
            f"Expected concurrent execution (<{N * PER_TASK_S:.2f}s) "
            f"but sleep window took {elapsed:.2f}s"
        )


class TestHandlerReturnsResultFromCorrectPrelaunchedTask:
    """Pop pre-launched tasks in order and verify each returns its own result."""

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.asyncio
    async def test_handler_returns_result_from_correct_prelaunched_task(self):
        """Two pre-launches for the same tool: first handler gets first result, second gets second."""

        async def execute_with_cmd(*args, **kwargs):
            cmd = kwargs.get("cmd", "?")
            return StreamToolOutputAvailable(
                toolCallId="id",
                output=f"output-for-{cmd}",
                toolName="bash_exec",
                success=True,
            )

        mock_tool = _make_mock_tool("bash_exec")
        mock_tool.execute = AsyncMock(side_effect=execute_with_cmd)

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"bash_exec": mock_tool},
        ):
            await pre_launch_tool_call("bash_exec", {"cmd": "alpha"})
            await pre_launch_tool_call("bash_exec", {"cmd": "beta"})
            await asyncio.sleep(0)  # let both tasks start

            handler = create_tool_handler(mock_tool)
            r1 = await handler({"cmd": "alpha"})
            r2 = await handler({"cmd": "beta"})

        text1 = r1["content"][0]["text"]
        text2 = r2["content"][0]["text"]
        assert "output-for-alpha" in text1, f"Expected alpha result, got: {text1}"
        assert "output-for-beta" in text2, f"Expected beta result, got: {text2}"
        assert mock_tool.execute.await_count == 2


class TestFiveConcurrentPrelaunchAllComplete:
    """Pre-launch 5 tasks; consume all 5 via handlers; assert all succeed."""

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.asyncio
    async def test_five_concurrent_prelaunch_all_complete(self):
        """All 5 pre-launched tasks complete and return successful results."""
        N = 5
        call_count = 0

        async def counting_execute(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            n = call_count
            return StreamToolOutputAvailable(
                toolCallId=f"id-{n}",
                output=f"done-{n}",
                toolName="bash_exec",
                success=True,
            )

        mock_tool = _make_mock_tool("bash_exec")
        mock_tool.execute = AsyncMock(side_effect=counting_execute)

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"bash_exec": mock_tool},
        ):
            for i in range(N):
                await pre_launch_tool_call("bash_exec", {"cmd": f"task-{i}"})

            await asyncio.sleep(0)  # let all tasks start

            handler = create_tool_handler(mock_tool)
            results = []
            for i in range(N):
                results.append(await handler({"cmd": f"task-{i}"}))

        assert (
            mock_tool.execute.await_count == N
        ), f"Expected {N} execute calls, got {mock_tool.execute.await_count}"
        for i, result in enumerate(results):
            assert result["isError"] is False, f"Result {i} should not be an error"
            text = result["content"][0]["text"]
            assert "done-" in text, f"Result {i} missing expected output: {text}"

@@ -17,13 +17,11 @@ Subscribers:
import asyncio
import logging
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal

import orjson
from redis.exceptions import RedisError

from backend.api.model import CopilotCompletionPayload
from backend.data.notification_bus import (
@@ -35,21 +33,12 @@ from backend.data.redis_client import get_redis_async
from .config import ChatConfig
from .executor.utils import COPILOT_CONSUMER_TIMEOUT_SECONDS
from .response_model import (
    ResponseType,
    StreamBaseResponse,
    StreamError,
    StreamFinish,
    StreamFinishStep,
    StreamHeartbeat,
    StreamStart,
    StreamStartStep,
    StreamTextDelta,
    StreamTextEnd,
    StreamTextStart,
    StreamToolInputAvailable,
    StreamToolInputStart,
    StreamToolOutputAvailable,
    StreamUsage,
)

logger = logging.getLogger(__name__)
@@ -291,56 +280,6 @@ async def publish_chunk(
    return message_id


async def stream_and_publish(
    session_id: str,
    turn_id: str,
    stream: AsyncIterator[StreamBaseResponse],
) -> AsyncIterator[StreamBaseResponse]:
    """Wrap an async stream iterator with registry publishing.

    Publishes each chunk to the stream registry for frontend SSE consumption,
    skipping ``StreamFinish`` and ``StreamError`` (which are published by
    :func:`mark_session_completed`).

    This is a pass-through: every event from *stream* is yielded unchanged so
    the caller can still consume and aggregate them. The caller is responsible
    for calling :func:`create_session` before and :func:`mark_session_completed`
    after iterating.

    Args:
        session_id: Chat session ID (for logging only).
        turn_id: Turn UUID that identifies the Redis stream to publish to.
            If empty, publishing is silently skipped (graceful degradation).
        stream: The underlying async iterator of stream events.

    Yields:
        Every event from *stream*, unchanged.
    """
    publish_failed_once = False

    async for event in stream:
        if turn_id and not isinstance(event, (StreamFinish, StreamError)):
            try:
                await publish_chunk(turn_id, event)
            except (RedisError, ConnectionError, OSError):
                if not publish_failed_once:
                    publish_failed_once = True
                    logger.warning(
                        "[stream_and_publish] Failed to publish chunk %s for %s "
                        "(further failures logged at DEBUG)",
                        type(event).__name__,
                        session_id[:12],
                        exc_info=True,
                    )
                else:
                    logger.debug(
                        "[stream_and_publish] Failed to publish chunk %s",
                        type(event).__name__,
                        exc_info=True,
                    )
        yield event


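# Illustrative caller sketch (not part of the diff above) for the contract in
# the docstring: bracket the iteration with create_session before and
# mark_session_completed after. create_session's exact signature is not shown
# in this diff and is an assumption here.
async def run_turn_sketch(session_id: str, turn_id: str, stream) -> None:
    await create_session(turn_id)  # assumed setup call, per the docstring
    error: str | None = None
    try:
        async for _event in stream_and_publish(session_id, turn_id, stream):
            pass  # aggregate events here; chunks are already published
    except Exception as exc:
        error = str(exc)
    finally:
        # Publishes StreamFinish (plus StreamError when error_message is set).
        await mark_session_completed(session_id, error_message=error)
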
async def subscribe_to_session(
    session_id: str,
    user_id: str | None,
@@ -754,8 +693,6 @@ async def _stream_listener(
async def mark_session_completed(
    session_id: str,
    error_message: str | None = None,
    *,
    skip_error_publish: bool = False,
) -> bool:
    """Mark a session as completed, then publish StreamFinish.

@@ -771,10 +708,6 @@ async def mark_session_completed(
        session_id: Session ID to mark as completed
        error_message: If provided, marks as "failed" and publishes a
            StreamError before StreamFinish. Otherwise marks as "completed".
        skip_error_publish: If True, still marks the session as "failed" but
            does NOT publish a StreamError event. Use this when the error has
            already been published to the stream (e.g. via stream_and_publish)
            to avoid duplicate error delivery to the frontend.

    Returns:
        True if session was newly marked completed, False if already completed/failed
@@ -794,7 +727,7 @@ async def mark_session_completed(
        logger.debug(f"Session {session_id} already completed/failed, skipping")
        return False

    if error_message and not skip_error_publish:
    if error_message:
        try:
            await publish_chunk(turn_id, StreamError(errorText=error_message))
        except Exception as e:
@@ -980,6 +913,21 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
    Returns:
        Reconstructed response object, or None if unknown type
    """
    from .response_model import (
        ResponseType,
        StreamError,
        StreamFinish,
        StreamFinishStep,
        StreamHeartbeat,
        StreamStart,
        StreamStartStep,
        StreamTextEnd,
        StreamToolInputAvailable,
        StreamToolInputStart,
        StreamToolOutputAvailable,
        StreamUsage,
    )

    # Map response types to their corresponding classes
    type_to_class: dict[str, type[StreamBaseResponse]] = {
        ResponseType.START.value: StreamStart,

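
# Illustrative sketch (not part of the diff above): given the type map, the
# reconstruction tail plausibly reduces to a lookup plus validation. Only the
# map itself appears in this hunk; the "type" key and the pydantic
# model_validate call are assumptions.
def _reconstruct_chunk_tail_sketch(chunk_data: dict):
    cls = type_to_class.get(chunk_data.get("type", ""))
    if cls is None:
        return None  # unknown type: caller skips the chunk
    return cls.model_validate(chunk_data)  # assumed pydantic v2 API
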
@@ -22,12 +22,13 @@ class AddUnderstandingTool(BaseTool):

    @property
    def description(self) -> str:
        return (
            "Store user's business context, workflows, pain points, and automation goals. "
            "Call whenever the user shares business info. Each call incrementally merges "
            "with existing data — provide only the fields you have. "
            "Builds a profile that helps recommend better agents for the user's needs."
        )
        return """Capture and store information about the user's business context,
workflows, pain points, and automation goals. Call this tool whenever the user
shares information about their business. Each call incrementally adds to the
existing understanding - you don't need to provide all fields at once.

Use this to build a comprehensive profile that helps recommend better agents
and automations for the user's specific needs."""

    @property
    def parameters(self) -> dict[str, Any]:

@@ -20,9 +20,9 @@ SSRF protection:

Requires:
    npm install -g agent-browser
    In Docker: system chromium package with AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium
    (set automatically — no `agent-browser install` needed).
    Locally: run `agent-browser install` to download Chromium.
    agent-browser install (downloads Chromium, one-time — skipped in Docker
                           where system chromium is pre-installed and
                           AGENT_BROWSER_EXECUTABLE_PATH is set)
"""

import asyncio
@@ -410,11 +410,18 @@ class BrowserNavigateTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Navigate to a URL in a real browser. Returns accessibility tree with @ref IDs "
            "for browser_act. Session persists (cookies/auth carry over). "
            "For static pages, prefer web_fetch. "
            "For SPAs, elements may load late — use browser_act with wait + browser_screenshot to verify. "
            "For auth: navigate to login, fill creds and submit with browser_act, then navigate to target."
            "Navigate to a URL using a real browser. Returns an accessibility "
            "tree snapshot listing the page's interactive elements with @ref IDs "
            "(e.g. @e3) that can be used with browser_act. "
            "Session persists — cookies and login state carry over between calls. "
            "Use this (with browser_act) for multi-step interaction: login flows, "
            "form filling, button clicks, or anything requiring page interaction. "
            "For plain static pages, prefer web_fetch — no browser overhead. "
            "For authenticated pages: navigate to the login page first, use browser_act "
            "to fill credentials and submit, then navigate to the target page. "
            "Note: for slow SPAs, the returned snapshot may reflect a partially-loaded "
            "state. If elements seem missing, use browser_act with action='wait' and a "
            "CSS selector or millisecond delay, then take a browser_screenshot to verify."
        )

    @property
@@ -424,13 +431,13 @@ class BrowserNavigateTool(BaseTool):
            "properties": {
                "url": {
                    "type": "string",
                    "description": "HTTP/HTTPS URL to navigate to.",
                    "description": "The HTTP/HTTPS URL to navigate to.",
                },
                "wait_for": {
                    "type": "string",
                    "enum": ["networkidle", "load", "domcontentloaded"],
                    "default": "networkidle",
                    "description": "Navigation completion strategy (default: networkidle).",
                    "description": "When to consider navigation complete. Use 'networkidle' for SPAs (default).",
                },
            },
            "required": ["url"],
@@ -549,12 +556,14 @@ class BrowserActTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Interact with the current browser page using @ref IDs from the snapshot. "
            "Actions: click, dblclick, fill, type, scroll, hover, press, "
            "Interact with the current browser page. Use @ref IDs from the "
            "snapshot (e.g. '@e3') to target elements. Returns an updated snapshot. "
            "Supported actions: click, dblclick, fill, type, scroll, hover, press, "
            "check, uncheck, select, wait, back, forward, reload. "
            "fill clears field first; type appends. "
            "wait accepts CSS selector or milliseconds (e.g. '1000'). "
            "Returns updated snapshot."
            "fill clears the field before typing; type appends without clearing. "
            "wait accepts a CSS selector (waits for element) or milliseconds string (e.g. '1000'). "
            "Example login flow: fill @e1 with email → fill @e2 with password → "
            "click @e3 (submit) → browser_navigate to the target page."
        )

    @property
@@ -580,21 +589,30 @@ class BrowserActTool(BaseTool):
                        "forward",
                        "reload",
                    ],
                    "description": "Action to perform.",
                    "description": "The action to perform.",
                },
                "target": {
                    "type": "string",
                    "description": "@ref ID (e.g. '@e3'), CSS selector, or text. Required for: click, dblclick, fill, type, hover, check, uncheck, select. For wait: CSS selector or milliseconds string (e.g. '1000').",
                    "description": (
                        "Element to target. Use @ref from snapshot (e.g. '@e3'), "
                        "a CSS selector, or a text description. "
                        "Required for: click, dblclick, fill, type, hover, check, uncheck, select. "
                        "For wait: a CSS selector to wait for, or milliseconds as a string (e.g. '1000')."
                    ),
                },
                "value": {
                    "type": "string",
                    "description": "Text for fill/type, key for press (e.g. 'Enter'), option for select.",
                    "description": (
                        "For fill/type: the text to enter. "
                        "For press: key name (e.g. 'Enter', 'Tab', 'Control+a'). "
                        "For select: the option value to select."
                    ),
                },
                "direction": {
                    "type": "string",
                    "enum": ["up", "down", "left", "right"],
                    "default": "down",
                    "description": "Scroll direction (default: down).",
                    "description": "For scroll: direction to scroll.",
                },
            },
            "required": ["action"],
@@ -741,10 +759,12 @@ class BrowserScreenshotTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Screenshot the current browser page and save to workspace. "
            "annotate=true overlays @ref labels on elements. "
            "IMPORTANT: After calling, you MUST immediately call read_workspace_file with the "
            "returned file_id to display the image inline."
            "Take a screenshot of the current browser page and save it to the workspace. "
            "IMPORTANT: After calling this tool, immediately call read_workspace_file "
            "with the returned file_id to display the image inline to the user — "
            "the screenshot is not visible until you do this. "
            "With annotate=true (default), @ref labels are overlaid on interactive "
            "elements, making it easy to see which @ref ID maps to which element on screen."
        )

    @property
@@ -755,12 +775,12 @@ class BrowserScreenshotTool(BaseTool):
            "annotate": {
                "type": "boolean",
                "default": True,
                "description": "Overlay @ref labels (default: true).",
                "description": "Overlay @ref labels on interactive elements (default: true).",
            },
            "filename": {
                "type": "string",
                "default": "screenshot.png",
                "description": "Workspace filename (default: screenshot.png).",
                "description": "Filename to save in the workspace.",
            },
        },
    }

@@ -1,351 +0,0 @@
"""Integration tests for agent-browser + system chromium.

These tests actually invoke the agent-browser binary via subprocess and require:
  - agent-browser installed (npm install -g agent-browser)
  - AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium (set in Docker)

Run with:
    poetry run test

Or to run only this file:
    poetry run pytest backend/copilot/tools/agent_browser_integration_test.py -v -p no:autogpt_platform

Skipped automatically when agent-browser binary is not found.
Tests that hit external sites are marked ``integration`` and skipped by default
in CI (use ``-m integration`` to include them).

Two test tiers:
  - CLI tests: call agent-browser subprocess directly (no backend imports needed)
  - Tool class tests: call BrowserNavigateTool/BrowserActTool._execute() directly
    with user_id=None (skips workspace/DB interactions — no Postgres/RabbitMQ needed)
"""

import concurrent.futures
import os
import shutil
import subprocess
import tempfile
from datetime import datetime, timezone
from urllib.parse import urlparse

import pytest

from backend.copilot.model import ChatSession
from backend.copilot.tools.agent_browser import BrowserActTool, BrowserNavigateTool
from backend.copilot.tools.models import (
    BrowserActResponse,
    BrowserNavigateResponse,
    ErrorResponse,
)

pytestmark = pytest.mark.skipif(
    shutil.which("agent-browser") is None,
    reason="agent-browser binary not found",
)

_SESSION = "integration-test-session"


def _agent_browser(
    *args: str, session: str = _SESSION, timeout: int = 30
) -> tuple[int, str, str]:
    """Run agent-browser for the given session, return (rc, stdout, stderr)."""
    result = subprocess.run(
        ["agent-browser", "--session", session, "--session-name", session, *args],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return result.returncode, result.stdout, result.stderr


def _close_session(session: str, timeout: int = 5) -> None:
    """Best-effort close for a browser session; never raises on failure."""
    try:
        subprocess.run(
            ["agent-browser", "--session", session, "--session-name", session, "close"],
            capture_output=True,
            timeout=timeout,
        )
    except (subprocess.TimeoutExpired, OSError):
        pass


@pytest.fixture(autouse=True)
def _teardown():
    """Close the shared test session after each test (best-effort)."""
    yield
    _close_session(_SESSION)


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


def test_chromium_executable_env_is_set():
    """AGENT_BROWSER_EXECUTABLE_PATH must be set and point to an executable binary."""
    exe = os.environ.get("AGENT_BROWSER_EXECUTABLE_PATH", "")
    assert exe, "AGENT_BROWSER_EXECUTABLE_PATH is not set"
    assert os.path.isfile(exe), f"Chromium binary not found at {exe}"
    assert os.access(exe, os.X_OK), f"Chromium binary at {exe} is not executable"


@pytest.mark.integration
def test_navigate_returns_success():
    """agent-browser can open a public URL using system chromium."""
    rc, _, stderr = _agent_browser("open", "https://example.com")
    assert rc == 0, f"open failed (rc={rc}): {stderr}"


@pytest.mark.integration
def test_get_title_after_navigate():
    """get title returns the page title after navigation."""
    rc, _, _ = _agent_browser("open", "https://example.com")
    assert rc == 0

    rc, stdout, stderr = _agent_browser("get", "title", timeout=10)
    assert rc == 0, f"get title failed: {stderr}"
    assert "example" in stdout.lower()


@pytest.mark.integration
def test_get_url_after_navigate():
    """get url returns the navigated URL."""
    rc, _, _ = _agent_browser("open", "https://example.com")
    assert rc == 0

    rc, stdout, stderr = _agent_browser("get", "url", timeout=10)
    assert rc == 0, f"get url failed: {stderr}"
    assert urlparse(stdout.strip()).netloc == "example.com"


@pytest.mark.integration
def test_snapshot_returns_interactive_elements():
    """snapshot -i -c lists interactive elements on the page."""
    rc, _, _ = _agent_browser("open", "https://example.com")
    assert rc == 0

    rc, stdout, stderr = _agent_browser("snapshot", "-i", "-c", timeout=15)
    assert rc == 0, f"snapshot failed: {stderr}"
    assert len(stdout.strip()) > 0, "snapshot returned empty output"


@pytest.mark.integration
def test_screenshot_produces_valid_png():
    """screenshot saves a non-empty, valid PNG file."""
    rc, _, _ = _agent_browser("open", "https://example.com")
    assert rc == 0

    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
        tmp = f.name
    try:
        rc, _, stderr = _agent_browser("screenshot", tmp, timeout=15)
        assert rc == 0, f"screenshot failed: {stderr}"
        size = os.path.getsize(tmp)
        assert size > 1000, f"PNG too small ({size} bytes) — likely blank or corrupt"
        with open(tmp, "rb") as f:
            assert f.read(4) == b"\x89PNG", "Output is not a valid PNG"
    finally:
        os.unlink(tmp)


@pytest.mark.integration
def test_scroll_down():
    """scroll down succeeds without error."""
    rc, _, _ = _agent_browser("open", "https://example.com")
    assert rc == 0

    rc, _, stderr = _agent_browser("scroll", "down", timeout=10)
    assert rc == 0, f"scroll failed: {stderr}"


@pytest.mark.integration
def test_fill_form_field():
    """fill writes text into an input field."""
    rc, _, _ = _agent_browser("open", "https://httpbin.org/forms/post")
    assert rc == 0

    rc, _, stderr = _agent_browser(
        "fill", "input[name=custname]", "IntegrationTestUser", timeout=10
    )
    assert rc == 0, f"fill failed: {stderr}"


@pytest.mark.integration
def test_concurrent_independent_sessions():
    """Two independent sessions can navigate in parallel without interference."""
    session_a = "integration-concurrent-a"
    session_b = "integration-concurrent-b"

    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
            fut_a = pool.submit(
                _agent_browser, "open", "https://example.com", session=session_a
            )
            fut_b = pool.submit(
                _agent_browser, "open", "https://httpbin.org/html", session=session_b
            )
            rc_a, _, err_a = fut_a.result(timeout=40)
            rc_b, _, err_b = fut_b.result(timeout=40)
        assert rc_a == 0, f"session_a open failed: {err_a}"
        assert rc_b == 0, f"session_b open failed: {err_b}"

        rc_ua, url_a, err_ua = _agent_browser(
            "get", "url", session=session_a, timeout=10
        )
        rc_ub, url_b, err_ub = _agent_browser(
            "get", "url", session=session_b, timeout=10
        )
        assert rc_ua == 0, f"session_a get url failed: {err_ua}"
        assert rc_ub == 0, f"session_b get url failed: {err_ub}"
        assert urlparse(url_a.strip()).netloc == "example.com"
        assert urlparse(url_b.strip()).netloc == "httpbin.org"
    finally:
        _close_session(session_a)
        _close_session(session_b)


@pytest.mark.integration
def test_close_session():
    """close shuts down the browser daemon cleanly."""
    rc, _, _ = _agent_browser("open", "https://example.com")
    assert rc == 0

    rc, _, stderr = _agent_browser("close", timeout=10)
    assert rc == 0, f"close failed: {stderr}"


# ---------------------------------------------------------------------------
# Python tool class integration tests
#
# These tests exercise the actual BrowserNavigateTool / BrowserActTool Python
# classes (not just the CLI binary) to verify the full call path — URL
# validation, subprocess dispatch, response parsing — works with system
# chromium. user_id=None skips workspace/DB interactions so no Postgres or
# RabbitMQ is needed.
# ---------------------------------------------------------------------------

_TOOL_SESSION_ID = "integration-tool-test-session"
_TEST_SESSION = ChatSession(
    session_id=_TOOL_SESSION_ID,
    user_id="test-user",
    messages=[],
    usage=[],
    started_at=datetime.now(timezone.utc),
    updated_at=datetime.now(timezone.utc),
)


@pytest.fixture(autouse=False)
def _close_tool_session():
    """Tear down the tool-test browser session after each tool test."""
    yield
    _close_session(_TOOL_SESSION_ID)


@pytest.mark.integration
@pytest.mark.asyncio
async def test_tool_navigate_returns_response(_close_tool_session):
    """BrowserNavigateTool._execute returns a BrowserNavigateResponse with real content."""
    tool = BrowserNavigateTool()
    resp = await tool._execute(
        user_id=None, session=_TEST_SESSION, url="https://example.com"
    )
    assert isinstance(
        resp, BrowserNavigateResponse
    ), f"Expected BrowserNavigateResponse, got: {resp}"
    assert urlparse(resp.url).netloc == "example.com"
    assert resp.title, "Expected non-empty page title"
    assert resp.snapshot, "Expected non-empty accessibility snapshot"


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "ssrf_url",
    [
        "http://169.254.169.254/",  # AWS/GCP/Azure metadata endpoint
        "http://127.0.0.1/",  # IPv4 loopback
        "http://10.0.0.1/",  # RFC-1918 private range
        "http://[::1]/",  # IPv6 loopback
        "http://0.0.0.0/",  # Wildcard / INADDR_ANY
    ],
)
async def test_tool_navigate_blocked_url(ssrf_url: str, _close_tool_session):
    """BrowserNavigateTool._execute rejects internal/private URLs (SSRF guard)."""
    tool = BrowserNavigateTool()
    resp = await tool._execute(user_id=None, session=_TEST_SESSION, url=ssrf_url)
    assert isinstance(
        resp, ErrorResponse
    ), f"Expected ErrorResponse for SSRF URL {ssrf_url!r}, got: {resp}"
    assert resp.error == "blocked_url"


@pytest.mark.asyncio
async def test_tool_navigate_missing_url(_close_tool_session):
    """BrowserNavigateTool._execute returns an error when url is empty."""
    tool = BrowserNavigateTool()
    resp = await tool._execute(user_id=None, session=_TEST_SESSION, url="")
    assert isinstance(resp, ErrorResponse)
    assert resp.error == "missing_url"


@pytest.mark.integration
@pytest.mark.asyncio
async def test_tool_act_scroll(_close_tool_session):
    """BrowserActTool._execute can scroll after a navigate."""
    nav = BrowserNavigateTool()
    nav_resp = await nav._execute(
        user_id=None, session=_TEST_SESSION, url="https://example.com"
    )
    assert isinstance(nav_resp, BrowserNavigateResponse)

    act = BrowserActTool()
    resp = await act._execute(
        user_id=None, session=_TEST_SESSION, action="scroll", direction="down"
    )
    assert isinstance(
        resp, BrowserActResponse
    ), f"Expected BrowserActResponse, got: {resp}"
    assert resp.action == "scroll"


@pytest.mark.integration
@pytest.mark.asyncio
async def test_tool_act_fill_and_click(_close_tool_session):
    """BrowserActTool._execute can fill a form field."""
    nav = BrowserNavigateTool()
    nav_resp = await nav._execute(
        user_id=None, session=_TEST_SESSION, url="https://httpbin.org/forms/post"
    )
    assert isinstance(nav_resp, BrowserNavigateResponse)

    act = BrowserActTool()
    resp = await act._execute(
        user_id=None,
        session=_TEST_SESSION,
        action="fill",
        target="input[name=custname]",
        value="ToolIntegrationTest",
    )
    assert isinstance(resp, BrowserActResponse), f"fill failed: {resp}"


@pytest.mark.asyncio
async def test_tool_act_missing_action(_close_tool_session):
    """BrowserActTool._execute returns an error when action is missing."""
    act = BrowserActTool()
    resp = await act._execute(user_id=None, session=_TEST_SESSION, action="")
    assert isinstance(resp, ErrorResponse)
    assert resp.error == "missing_action"


@pytest.mark.asyncio
async def test_tool_act_missing_target(_close_tool_session):
    """BrowserActTool._execute returns an error when click target is missing."""
    act = BrowserActTool()
    resp = await act._execute(
        user_id=None, session=_TEST_SESSION, action="click", target=""
    )
    assert isinstance(resp, ErrorResponse)
    assert resp.error == "missing_target"
@@ -7,7 +7,7 @@ from typing import Any
from .helpers import (
    AGENT_EXECUTOR_BLOCK_ID,
    MCP_TOOL_BLOCK_ID,
    TOOL_ORCHESTRATOR_BLOCK_ID,
    SMART_DECISION_MAKER_BLOCK_ID,
    AgentDict,
    are_types_compatible,
    generate_uuid,
@@ -31,7 +31,7 @@ _GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
_GMAIL_SEND_BLOCK_ID = "6c27abc2-e51d-499e-a85f-5a0041ba94f0"
_TEXT_REPLACE_BLOCK_ID = "7e7c87ab-3469-4bcc-9abe-67705091b713"

# Defaults applied to OrchestratorBlock nodes by the fixer.
# Defaults applied to SmartDecisionMakerBlock nodes by the fixer.
_SDM_DEFAULTS: dict[str, int | bool] = {
    "agent_mode_max_iterations": 10,
    "conversation_compaction": True,
@@ -1639,8 +1639,8 @@ class AgentFixer:

        return agent

    def fix_orchestrator_blocks(self, agent: AgentDict) -> AgentDict:
        """Fix OrchestratorBlock nodes to ensure agent-mode defaults.
    def fix_smart_decision_maker_blocks(self, agent: AgentDict) -> AgentDict:
        """Fix SmartDecisionMakerBlock nodes to ensure agent-mode defaults.

        Ensures:
        1. ``agent_mode_max_iterations`` defaults to ``10`` (bounded agent mode)
@@ -1657,7 +1657,7 @@ class AgentFixer:
        nodes = agent.get("nodes", [])

        for node in nodes:
            if node.get("block_id") != TOOL_ORCHESTRATOR_BLOCK_ID:
            if node.get("block_id") != SMART_DECISION_MAKER_BLOCK_ID:
                continue

            node_id = node.get("id", "unknown")
@@ -1670,7 +1670,7 @@ class AgentFixer:
            if field not in input_default or input_default[field] is None:
                input_default[field] = default_value
                self.add_fix_log(
                    f"OrchestratorBlock {node_id}: "
                    f"SmartDecisionMakerBlock {node_id}: "
                    f"Set {field}={default_value!r}"
                )

@@ -1763,8 +1763,8 @@ class AgentFixer:
        # Apply fixes for MCPToolBlock nodes
        agent = self.fix_mcp_tool_blocks(agent)

        # Apply fixes for OrchestratorBlock nodes (agent-mode defaults)
        agent = self.fix_orchestrator_blocks(agent)
        # Apply fixes for SmartDecisionMakerBlock nodes (agent-mode defaults)
        agent = self.fix_smart_decision_maker_blocks(agent)

        # Apply fixes for AgentExecutorBlock nodes (sub-agents)
        if library_agents:

@@ -12,7 +12,7 @@ __all__ = [
    "AGENT_OUTPUT_BLOCK_ID",
    "AgentDict",
    "MCP_TOOL_BLOCK_ID",
    "TOOL_ORCHESTRATOR_BLOCK_ID",
    "SMART_DECISION_MAKER_BLOCK_ID",
    "UUID_REGEX",
    "are_types_compatible",
    "generate_uuid",
@@ -34,7 +34,7 @@ UUID_REGEX = re.compile(r"^" + UUID_RE_STR + r"$")

AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
MCP_TOOL_BLOCK_ID = "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
TOOL_ORCHESTRATOR_BLOCK_ID = "3b191d9f-356f-482d-8238-ba04b6d18381"
SMART_DECISION_MAKER_BLOCK_ID = "3b191d9f-356f-482d-8238-ba04b6d18381"
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"


@@ -10,7 +10,7 @@ from .helpers import (
    AGENT_INPUT_BLOCK_ID,
    AGENT_OUTPUT_BLOCK_ID,
    MCP_TOOL_BLOCK_ID,
    TOOL_ORCHESTRATOR_BLOCK_ID,
    SMART_DECISION_MAKER_BLOCK_ID,
    AgentDict,
    are_types_compatible,
    get_defined_property_type,
@@ -827,18 +827,18 @@ class AgentValidator:

        return valid

    def validate_orchestrator_blocks(
    def validate_smart_decision_maker_blocks(
        self,
        agent: AgentDict,
        node_lookup: dict[str, dict[str, Any]] | None = None,
    ) -> bool:
        """Validate that OrchestratorBlock nodes have downstream tools.
        """Validate that SmartDecisionMakerBlock nodes have downstream tools.

        Checks that each OrchestratorBlock node has at least one link
        Checks that each SmartDecisionMakerBlock node has at least one link
        with ``source_name == "tools"`` connecting to a downstream block.
        Without tools, the block has nothing to call and will error at runtime.

        Returns True if all OrchestratorBlock nodes are valid.
        Returns True if all SmartDecisionMakerBlock nodes are valid.
        """
        valid = True
        nodes = agent.get("nodes", [])
@@ -848,7 +848,7 @@ class AgentValidator:
        non_tool_block_ids = {AGENT_INPUT_BLOCK_ID, AGENT_OUTPUT_BLOCK_ID}

        for node in nodes:
            if node.get("block_id") != TOOL_ORCHESTRATOR_BLOCK_ID:
            if node.get("block_id") != SMART_DECISION_MAKER_BLOCK_ID:
                continue

            node_id = node.get("id", "unknown")
@@ -863,7 +863,7 @@ class AgentValidator:
            max_iter = input_default.get("agent_mode_max_iterations")
            if max_iter is not None and not isinstance(max_iter, int):
                self.add_error(
                    f"OrchestratorBlock node '{customized_name}' "
                    f"SmartDecisionMakerBlock node '{customized_name}' "
                    f"({node_id}) has non-integer "
                    f"agent_mode_max_iterations={max_iter!r}. "
                    f"This field must be an integer."
@@ -871,7 +871,7 @@ class AgentValidator:
                valid = False
            elif isinstance(max_iter, int) and max_iter < -1:
                self.add_error(
                    f"OrchestratorBlock node '{customized_name}' "
                    f"SmartDecisionMakerBlock node '{customized_name}' "
                    f"({node_id}) has invalid "
                    f"agent_mode_max_iterations={max_iter}. "
                    f"Use -1 for infinite or a positive number for "
@@ -880,7 +880,7 @@ class AgentValidator:
                valid = False
            elif isinstance(max_iter, int) and max_iter > 100:
                self.add_error(
                    f"OrchestratorBlock node '{customized_name}' "
                    f"SmartDecisionMakerBlock node '{customized_name}' "
                    f"({node_id}) has agent_mode_max_iterations="
                    f"{max_iter} which is unusually high. Values above "
                    f"100 risk excessive cost and long execution times. "
@@ -890,7 +890,7 @@ class AgentValidator:
                valid = False
            elif max_iter == 0:
                self.add_error(
                    f"OrchestratorBlock node '{customized_name}' "
                    f"SmartDecisionMakerBlock node '{customized_name}' "
                    f"({node_id}) has agent_mode_max_iterations=0 "
                    f"(traditional mode). The agent generator only supports "
                    f"agent mode (set to -1 for infinite or a positive "
@@ -908,7 +908,7 @@ class AgentValidator:

            if not has_tools:
                self.add_error(
                    f"OrchestratorBlock node '{customized_name}' "
                    f"SmartDecisionMakerBlock node '{customized_name}' "
                    f"({node_id}) has no downstream tool blocks connected. "
                    f"Connect at least one block to its 'tools' output so "
                    f"the AI has tools to call."
@@ -1025,8 +1025,8 @@ class AgentValidator:
                self.validate_mcp_tool_blocks(agent),
            ),
            (
                "Orchestrator blocks",
                self.validate_orchestrator_blocks(agent, node_lookup),
                "SmartDecisionMaker blocks",
                self.validate_smart_decision_maker_blocks(agent, node_lookup),
            ),
        ]


@@ -10,12 +10,7 @@ from pydantic import BaseModel, Field, field_validator
from backend.api.features.library.model import LibraryAgent
from backend.copilot.model import ChatSession
from backend.data.db_accessors import execution_db, library_db
from backend.data.execution import (
    ExecutionStatus,
    GraphExecution,
    GraphExecutionMeta,
    GraphExecutionWithNodes,
)
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta

from .base import BaseTool
from .execution_utils import TERMINAL_STATUSES, wait_for_execution
@@ -40,7 +35,6 @@ class AgentOutputInput(BaseModel):
    execution_id: str = ""
    run_time: str = "latest"
    wait_if_running: int = Field(default=0, ge=0, le=300)
    show_execution_details: bool = False

    @field_validator(
        "agent_name",
@@ -114,12 +108,22 @@ class AgentOutputTool(BaseTool):

    @property
    def description(self) -> str:
        return (
            "Retrieve execution outputs from a library agent. "
            "Identify by agent_name, library_agent_id, or store_slug. "
            "Filter by execution_id or run_time. "
            "Optionally wait for running executions."
        )
        return """Retrieve execution outputs from agents in the user's library.

Identify the agent using one of:
- agent_name: Fuzzy search in user's library
- library_agent_id: Exact library agent ID
- store_slug: Marketplace format 'username/agent-name'

Select which run to retrieve using:
- execution_id: Specific execution ID
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'

Wait for completion (optional):
- wait_if_running: Max seconds to wait if execution is still running (0-300).
  If the execution is running/queued, waits up to this many seconds for completion.
  Returns current status on timeout. If already finished, returns immediately.
"""

    @property
    def parameters(self) -> dict[str, Any]:
@@ -128,33 +132,32 @@ class AgentOutputTool(BaseTool):
            "properties": {
                "agent_name": {
                    "type": "string",
                    "description": "Agent name (fuzzy match).",
                    "description": "Agent name to search for in user's library (fuzzy match)",
                },
                "library_agent_id": {
                    "type": "string",
                    "description": "Library agent ID.",
                    "description": "Exact library agent ID",
                },
                "store_slug": {
                    "type": "string",
                    "description": "Marketplace 'username/agent-name'.",
                    "description": "Marketplace identifier: 'username/agent-slug'",
                },
                "execution_id": {
                    "type": "string",
                    "description": "Specific execution ID.",
                    "description": "Specific execution ID to retrieve",
                },
                "run_time": {
                    "type": "string",
                    "description": "Time filter: 'latest', 'today', 'yesterday', 'last week', 'last 7 days', 'last month', 'last 30 days', 'YYYY-MM-DD', or ISO datetime.",
                    "description": (
                        "Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
                    ),
                },
                "wait_if_running": {
                    "type": "integer",
                    "description": "Max seconds to wait if still running (0-300). Returns current state on timeout.",
                    "minimum": 0,
                    "maximum": 300,
                },
                "show_execution_details": {
                    "type": "boolean",
                    "description": "If true, include full node-by-node execution trace (inputs, outputs, status, timing for each node). Useful for debugging agent wiring. Default: false.",
                    "description": (
                        "Max seconds to wait if execution is still running (0-300). "
                        "If running, waits for completion. Returns current state on timeout."
                    ),
                },
            },
            "required": [],
@@ -236,19 +239,13 @@ class AgentOutputTool(BaseTool):
        time_start: datetime | None,
        time_end: datetime | None,
        include_running: bool = False,
        include_node_executions: bool = False,
    ) -> tuple[
        GraphExecution | GraphExecutionWithNodes | None,
        list[GraphExecutionMeta],
        str | None,
    ]:
    ) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
        """
        Fetch execution(s) based on filters.
        Returns (single_execution, available_executions_meta, error_message).

        Args:
            include_running: If True, also look for running/queued executions (for waiting)
            include_node_executions: If True, include node-by-node execution details
        """
        exec_db = execution_db()

@@ -257,7 +254,7 @@ class AgentOutputTool(BaseTool):
            execution = await exec_db.get_graph_execution(
                user_id=user_id,
                execution_id=execution_id,
                include_node_executions=include_node_executions,
                include_node_executions=False,
            )
            if not execution:
                return None, [], f"Execution '{execution_id}' not found"
@@ -295,7 +292,7 @@ class AgentOutputTool(BaseTool):
            full_execution = await exec_db.get_graph_execution(
                user_id=user_id,
                execution_id=executions[0].id,
                include_node_executions=include_node_executions,
                include_node_executions=False,
            )
            return full_execution, [], None

@@ -303,14 +300,14 @@ class AgentOutputTool(BaseTool):
        full_execution = await exec_db.get_graph_execution(
            user_id=user_id,
            execution_id=executions[0].id,
            include_node_executions=include_node_executions,
            include_node_executions=False,
        )
        return full_execution, executions, None

    def _build_response(
        self,
        agent: LibraryAgent,
        execution: GraphExecution | GraphExecutionWithNodes | None,
        execution: GraphExecution | None,
        available_executions: list[GraphExecutionMeta],
        session_id: str | None,
    ) -> AgentOutputResponse:
@@ -328,21 +325,6 @@ class AgentOutputTool(BaseTool):
                total_executions=0,
            )

        node_executions_data = None
        if isinstance(execution, GraphExecutionWithNodes):
            node_executions_data = [
                {
                    "node_id": ne.node_id,
                    "block_id": ne.block_id,
                    "status": ne.status.value,
                    "input_data": ne.input_data,
                    "output_data": dict(ne.output_data),
                    "start_time": ne.start_time.isoformat() if ne.start_time else None,
                    "end_time": ne.end_time.isoformat() if ne.end_time else None,
                }
                for ne in execution.node_executions
            ]

        execution_info = ExecutionOutputInfo(
            execution_id=execution.id,
            status=execution.status.value,
@@ -350,7 +332,6 @@ class AgentOutputTool(BaseTool):
            ended_at=execution.ended_at,
            outputs=dict(execution.outputs),
            inputs_summary=execution.inputs if execution.inputs else None,
            node_executions=node_executions_data,
        )

        available_list = None
@@ -460,7 +441,7 @@ class AgentOutputTool(BaseTool):
            execution = await execution_db().get_graph_execution(
                user_id=user_id,
                execution_id=input_data.execution_id,
                include_node_executions=input_data.show_execution_details,
                include_node_executions=False,
            )
            if not execution:
                return ErrorResponse(
@@ -516,7 +497,6 @@ class AgentOutputTool(BaseTool):
            time_start=time_start,
            time_end=time_end,
            include_running=wait_timeout > 0,
            include_node_executions=input_data.show_execution_details,
        )

        if exec_error:

|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Execute a Bash command or script. Shares filesystem with SDK file tools. "
|
||||
"Useful for scripts, data processing, and package installation. "
|
||||
"Killed after timeout (default 30s, max 120s)."
|
||||
"Execute a Bash command or script. "
|
||||
"Full Bash scripting is supported (loops, conditionals, pipes, "
|
||||
"functions, etc.). "
|
||||
"The working directory is shared with the SDK Read/Write/Edit/Glob/Grep "
|
||||
"tools — files created by either are immediately visible to both. "
|
||||
"Execution is killed after the timeout (default 30s, max 120s). "
|
||||
"Returns stdout and stderr. "
|
||||
"Useful for file manipulation, data processing, running scripts, "
|
||||
"and installing packages."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -54,11 +60,13 @@ class BashExecTool(BaseTool):
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Bash command or script.",
|
||||
"description": "Bash command or script to execute.",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Max seconds (default 30, max 120).",
|
||||
"description": (
|
||||
"Max execution time in seconds (default 30, max 120)."
|
||||
),
|
||||
"default": 30,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -1,20 +0,0 @@
"""Local conftest for copilot/tools tests.

Overrides the session-scoped `server` and `graph_cleanup` autouse fixtures from
backend/conftest.py so that integration tests in this directory do not trigger
the full SpinTestServer startup (which requires Postgres + RabbitMQ).
"""

import pytest_asyncio


@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def server():  # type: ignore[override]
    """No-op server stub — tools tests don't need the full backend."""
    return None


@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
async def graph_cleanup():  # type: ignore[override]
    """No-op graph cleanup stub."""
    yield
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Resume block execution after a run_block call returned review_required. Pass the review_id."
|
||||
return (
|
||||
"Continue executing a block after human review approval. "
|
||||
"Use this after a run_block call returned review_required. "
|
||||
"Pass the review_id from the review_required response. "
|
||||
"The block will execute with the original pre-approved input data."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
@@ -39,7 +44,10 @@ class ContinueRunBlockTool(BaseTool):
|
||||
"properties": {
|
||||
"review_id": {
|
||||
"type": "string",
|
||||
"description": "review_id from the review_required response.",
|
||||
"description": (
|
||||
"The review_id from a previous review_required response. "
|
||||
"This resumes execution with the pre-approved input data."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["review_id"],
|
||||
@@ -119,11 +127,8 @@ class ContinueRunBlockTool(BaseTool):
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Continuing block %s (%s) for user %s with review_id=%s",
|
||||
block.name,
|
||||
block_id,
|
||||
user_id,
|
||||
review_id,
|
||||
f"Continuing block {block.name} ({block_id}) for user {user_id} "
|
||||
f"with review_id={review_id}"
|
||||
)
|
||||
|
||||
matched_creds, missing_creds = await resolve_block_credentials(
|
||||
@@ -135,9 +140,6 @@ class ContinueRunBlockTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# dry_run=False is safe here: run_block's dry-run fast-path (line ~241)
|
||||
# skips HITL entirely, so continue_run_block is never called during a
|
||||
# dry run — only real executions reach the human review gate.
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id=block_id,
|
||||
@@ -146,7 +148,6 @@ class ContinueRunBlockTool(BaseTool):
|
||||
session_id=session_id,
|
||||
node_exec_id=review_id,
|
||||
matched_credentials=matched_creds,
|
||||
dry_run=False,
|
||||
)
|
||||
|
||||
# Delete review record after successful execution (one-time use)
|
||||
|
||||
@@ -23,8 +23,12 @@ class CreateAgentTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Create a new agent from JSON (nodes + links). Validates, auto-fixes, and saves. "
|
||||
"Before calling, search for existing agents with find_library_agent."
|
||||
"Create a new agent workflow. Pass `agent_json` with the complete "
|
||||
"agent graph JSON you generated using block schemas from find_block. "
|
||||
"The tool validates, auto-fixes, and saves.\n\n"
|
||||
"IMPORTANT: Before calling this tool, search for relevant existing agents "
|
||||
"using find_library_agent that could be used as building blocks. "
|
||||
"Pass their IDs in the library_agent_ids parameter."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -38,21 +42,34 @@ class CreateAgentTool(BaseTool):
|
||||
"properties": {
|
||||
"agent_json": {
|
||||
"type": "object",
|
||||
"description": "Agent graph with 'nodes' and 'links' arrays.",
|
||||
"description": (
|
||||
"The agent JSON to validate and save. "
|
||||
"Must contain 'nodes' and 'links' arrays, and optionally "
|
||||
"'name' and 'description'."
|
||||
),
|
||||
},
|
||||
"library_agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Library agent IDs as building blocks.",
|
||||
"description": (
|
||||
"List of library agent IDs to use as building blocks."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": "Save the agent (default: true). False for preview.",
|
||||
"description": (
|
||||
"Whether to save the agent. Default is true. "
|
||||
"Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": "Folder ID to save into (default: root).",
|
||||
"description": (
|
||||
"Optional folder ID to save the agent into. "
|
||||
"If not provided, the agent is saved at root level. "
|
||||
"Use list_folders to find available folders."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_json"],
|
||||
|
||||
@@ -23,7 +23,9 @@ class CustomizeAgentTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Customize a marketplace/template agent. Validates, auto-fixes, and saves."
|
||||
"Customize a marketplace or template agent. Pass `agent_json` "
|
||||
"with the complete customized agent JSON. The tool validates, "
|
||||
"auto-fixes, and saves."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -37,21 +39,32 @@ class CustomizeAgentTool(BaseTool):
|
||||
"properties": {
|
||||
"agent_json": {
|
||||
"type": "object",
|
||||
"description": "Customized agent JSON with nodes and links.",
|
||||
"description": (
|
||||
"Complete customized agent JSON to validate and save. "
|
||||
"Optionally include 'name' and 'description'."
|
||||
),
|
||||
},
|
||||
"library_agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Library agent IDs as building blocks.",
|
||||
"description": (
|
||||
"List of library agent IDs to use as building blocks."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": "Save the agent (default: true). False for preview.",
|
||||
"description": (
|
||||
"Whether to save the customized agent. Default is true."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": "Folder ID to save into (default: root).",
|
||||
"description": (
|
||||
"Optional folder ID to save the agent into. "
|
||||
"If not provided, the agent is saved at root level. "
|
||||
"Use list_folders to find available folders."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_json"],
|
||||
|
||||
@@ -23,8 +23,12 @@ class EditAgentTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit an existing agent. Validates, auto-fixes, and saves. "
|
||||
"Before calling, search for existing agents with find_library_agent."
|
||||
"Edit an existing agent. Pass `agent_json` with the complete "
|
||||
"updated agent JSON you generated. The tool validates, auto-fixes, "
|
||||
"and saves.\n\n"
|
||||
"IMPORTANT: Before calling this tool, if the changes involve adding new "
|
||||
"functionality, search for relevant existing agents using find_library_agent "
|
||||
"that could be used as building blocks."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -38,20 +42,33 @@ class EditAgentTool(BaseTool):
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": "Graph ID or library agent ID to edit.",
|
||||
"description": (
|
||||
"The ID of the agent to edit. "
|
||||
"Can be a graph ID or library agent ID."
|
||||
),
|
||||
},
|
||||
"agent_json": {
|
||||
"type": "object",
|
||||
"description": "Updated agent JSON with nodes and links.",
|
||||
"description": (
|
||||
"Complete updated agent JSON to validate and save. "
|
||||
"Must contain 'nodes' and 'links'. "
|
||||
"Include 'name' and/or 'description' if they need "
|
||||
"to be updated."
|
||||
),
|
||||
},
|
||||
"library_agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Library agent IDs as building blocks.",
|
||||
"description": (
|
||||
"List of library agent IDs to use as building blocks for the changes."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": "Save changes (default: true). False for preview.",
|
||||
"description": (
|
||||
"Whether to save the changes. "
|
||||
"Default is true. Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -134,7 +134,11 @@ class SearchFeatureRequestsTool(BaseTool):

    @property
    def description(self) -> str:
        return "Search existing feature requests. Check before creating a new one."
        return (
            "Search existing feature requests to check if a similar request "
            "already exists before creating a new one. Returns matching feature "
            "requests with their ID, title, and description."
        )

    @property
    def parameters(self) -> dict[str, Any]:
@@ -230,9 +234,14 @@ class CreateFeatureRequestTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Create a feature request or add need to existing one. "
            "Search first to avoid duplicates. Pass existing_issue_id to add to existing. "
            "Never include PII (names, emails, phone numbers, company names) in title/description."
            "Create a new feature request or add a customer need to an existing one. "
            "Always search first with search_feature_requests to avoid duplicates. "
            "If a matching request exists, pass its ID as existing_issue_id to add "
            "the user's need to it instead of creating a duplicate. "
            "IMPORTANT: Never include personally identifiable information (PII) in "
            "the title or description — no names, emails, phone numbers, company "
            "names, or other identifying details. Write titles and descriptions in "
            "generic, feature-focused language."
        )

    @property
@@ -242,15 +251,28 @@ class CreateFeatureRequestTool(BaseTool):
            "properties": {
                "title": {
                    "type": "string",
                    "description": "Feature request title. No names, emails, or company info.",
                    "description": (
                        "Title for the feature request. Must be generic and "
                        "feature-focused — do not include any user names, emails, "
                        "company names, or other PII."
                    ),
                },
                "description": {
                    "type": "string",
                    "description": "What the user wants and why. No names, emails, or company info.",
                    "description": (
                        "Detailed description of what the user wants and why. "
                        "Must not contain any personally identifiable information "
                        "(PII) — describe the feature need generically without "
                        "referencing specific users, companies, or contact details."
                    ),
                },
                "existing_issue_id": {
                    "type": "string",
                    "description": "Linear issue ID to add need to (from search results).",
                    "description": (
                        "If adding a need to an existing feature request, "
                        "provide its Linear issue ID (from search results). "
                        "Omit to create a new feature request."
                    ),
                },
            },
            "required": ["title", "description"],

@@ -18,7 +18,10 @@ class FindAgentTool(BaseTool):

    @property
    def description(self) -> str:
        return "Search marketplace agents by capability, or look up by slug ('username/agent-name')."
        return (
            "Discover agents from the marketplace based on capabilities and "
            "user needs, or look up a specific agent by its creator/slug ID."
        )

    @property
    def parameters(self) -> dict[str, Any]:
@@ -27,7 +30,7 @@ class FindAgentTool(BaseTool):
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search keywords, or 'username/agent-name' for direct slug lookup.",
                    "description": "Search query describing what the user wants to accomplish, or a creator/slug ID (e.g. 'username/agent-name') for direct lookup. Use single keywords for best results.",
                },
            },
            "required": ["query"],

@@ -5,7 +5,6 @@ from prisma.enums import ContentType

from backend.blocks import get_block
from backend.blocks._base import BlockType
from backend.copilot.context import get_current_permissions
from backend.copilot.model import ChatSession
from backend.data.db_accessors import search

@@ -39,7 +38,7 @@ COPILOT_EXCLUDED_BLOCK_TYPES = {

# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
COPILOT_EXCLUDED_BLOCK_IDS = {
    # OrchestratorBlock - dynamically discovers downstream blocks via graph topology;
    # SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology;
    # usable in agent graphs (guide hardcodes its ID) but cannot run standalone.
    "3b191d9f-356f-482d-8238-ba04b6d18381",
}
@@ -55,9 +54,13 @@ class FindBlockTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Search blocks by name or description. Returns block IDs for run_block. "
            "Always call this FIRST to get block IDs before using run_block. "
            "Then call run_block with the block's id and empty input_data to see its detailed schema."
            "Search for available blocks by name or description, or look up a "
            "specific block by its ID. "
            "Blocks are reusable components that perform specific tasks like "
            "sending emails, making API calls, processing text, etc. "
            "IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
            "The response includes each block's id, name, and description. "
            "Call run_block with the block's id **with no inputs** to see detailed inputs/outputs and execute it."
        )

    @property
@@ -67,11 +70,19 @@ class FindBlockTool(BaseTool):
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search keywords (e.g. 'email', 'http', 'ai').",
                    "description": (
                        "Search query to find blocks by name or description, "
                        "or a block ID (UUID) for direct lookup. "
                        "Use keywords like 'email', 'http', 'text', 'ai', etc."
                    ),
                },
                "include_schemas": {
                    "type": "boolean",
                    "description": "Include full input/output schemas (for agent JSON generation).",
                    "description": (
                        "If true, include full input_schema and output_schema "
                        "for each block. Use when generating agent JSON that "
                        "needs block schemas. Default is false."
                    ),
                    "default": False,
                },
            },
@@ -150,19 +161,6 @@ class FindBlockTool(BaseTool):
                session_id=session_id,
            )

        # Check block-level permissions — hide denied blocks entirely
        perms = get_current_permissions()
        if perms is not None and not perms.is_block_allowed(
            block.id, block.name
        ):
            return NoResultsResponse(
                message=f"No blocks found for '{query}'",
                suggestions=[
                    "Search for an alternative block by name",
                ],
                session_id=session_id,
            )

        summary = BlockInfoSummary(
            id=block.id,
            name=block.name,
@@ -209,7 +207,6 @@ class FindBlockTool(BaseTool):
        )

        # Enrich results with block information
        perms = get_current_permissions()
        blocks: list[BlockInfoSummary] = []
        for result in results:
            block_id = result["content_id"]
@@ -226,12 +223,6 @@ class FindBlockTool(BaseTool):
            ):
                continue

            # Skip blocks denied by execution permissions
            if perms is not None and not perms.is_block_allowed(
                block.id, block.name
            ):
                continue

            summary = BlockInfoSummary(
                id=block_id,
                name=block.name,

|
||||
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
|
||||
def test_excluded_block_ids_contains_orchestrator(self):
|
||||
"""Verify OrchestratorBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
|
||||
def test_excluded_block_ids_contains_smart_decision_maker(self):
|
||||
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
|
||||
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@@ -120,18 +120,18 @@ class TestFindBlockFiltering:
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_excluded_block_id_filtered_from_results(self):
|
||||
"""Verify OrchestratorBlock is filtered from search results."""
|
||||
"""Verify SmartDecisionMakerBlock is filtered from search results."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
orchestrator_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
search_results = [
|
||||
{"content_id": orchestrator_id, "score": 0.9},
|
||||
{"content_id": smart_decision_id, "score": 0.9},
|
||||
{"content_id": "normal-block-id", "score": 0.8},
|
||||
]
|
||||
|
||||
# OrchestratorBlock has STANDARD type but is excluded by ID
|
||||
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
|
||||
smart_block = make_mock_block(
|
||||
orchestrator_id, "Orchestrator", BlockType.STANDARD
|
||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||
)
|
||||
normal_block = make_mock_block(
|
||||
"normal-block-id", "Normal Block", BlockType.STANDARD
|
||||
@@ -139,7 +139,7 @@ class TestFindBlockFiltering:
|
||||
|
||||
def mock_get_block(block_id):
|
||||
return {
|
||||
orchestrator_id: smart_block,
|
||||
smart_decision_id: smart_block,
|
||||
"normal-block-id": normal_block,
|
||||
}.get(block_id)
|
||||
|
||||
@@ -161,7 +161,7 @@ class TestFindBlockFiltering:
|
||||
user_id=_TEST_USER_ID, session=session, query="decision"
|
||||
)
|
||||
|
||||
# Should only return normal block, not OrchestratorBlock
|
||||
# Should only return normal block, not SmartDecisionMakerBlock
|
||||
assert isinstance(response, BlockListResponse)
|
||||
assert len(response.blocks) == 1
|
||||
assert response.blocks[0].id == "normal-block-id"
|
||||
@@ -601,8 +601,10 @@ class TestFindBlockDirectLookup:
|
||||
async def test_uuid_lookup_excluded_block_id(self):
|
||||
"""UUID matching an excluded block ID returns NoResultsResponse."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
orchestrator_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
block = make_mock_block(orchestrator_id, "Orchestrator", BlockType.STANDARD)
|
||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
block = make_mock_block(
|
||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
@@ -610,7 +612,7 @@ class TestFindBlockDirectLookup:
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query=orchestrator_id
|
||||
user_id=_TEST_USER_ID, session=session, query=smart_decision_id
|
||||
)
|
||||
|
||||
from .models import NoResultsResponse
|
||||
|
||||
@@ -19,8 +19,13 @@ class FindLibraryAgentTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Search user's library agents. Returns graph_id, schemas for sub-agent composition. "
            "Omit query to list all."
            "Search for or list agents in the user's library. Use this to find "
            "agents the user has already added to their library, including agents "
            "they created or added from the marketplace. "
            "When creating agents with sub-agent composition, use this to get "
            "the agent's graph_id, graph_version, input_schema, and output_schema "
            "needed for AgentExecutorBlock nodes. "
            "Omit the query to list all agents."
        )

    @property
@@ -30,7 +35,10 @@ class FindLibraryAgentTool(BaseTool):
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search by name/description. Omit to list all.",
                    "description": (
                        "Search query to find agents by name or description. "
                        "Omit to list all agents in the library."
                    ),
                },
            },
            "required": [],

@@ -22,10 +22,20 @@ class FixAgentGraphTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Auto-fix common agent JSON issues: missing/invalid UUIDs, StoreValueBlock prerequisites, "
            "double curly brace escaping, AddToList/AddToDictionary prerequisites, credentials, "
            "node spacing, AI model defaults, link static properties, and type mismatches. "
            "Returns fixed JSON and list of fixes applied."
            "Auto-fix common issues in an agent JSON graph. Applies fixes for:\n"
            "- Missing or invalid UUIDs on nodes and links\n"
            "- StoreValueBlock prerequisites for ConditionBlock\n"
            "- Double curly brace escaping in prompt templates\n"
            "- AddToList/AddToDictionary prerequisite blocks\n"
            "- CodeExecutionBlock output field naming\n"
            "- Missing credentials configuration\n"
            "- Node X coordinate spacing (800+ units apart)\n"
            "- AI model default parameters\n"
            "- Link static properties based on input schema\n"
            "- Type mismatches (inserts conversion blocks)\n\n"
            "Returns the fixed agent JSON plus a list of fixes applied. "
            "After fixing, the agent is re-validated. If still invalid, "
            "the remaining errors are included in the response."
        )

    @property


|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Get the agent JSON building guide (nodes, links, AgentExecutorBlock, MCPToolBlock usage). Call before generating agent JSON."
|
||||
return (
|
||||
"Returns the complete guide for building agent JSON graphs, including "
|
||||
"block IDs, link structure, AgentInputBlock, AgentOutputBlock, "
|
||||
"AgentExecutorBlock (for sub-agent composition), and MCPToolBlock usage. "
|
||||
"Call this before generating agent JSON to ensure correct structure."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
|
||||
@@ -25,7 +25,8 @@ class GetDocPageTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Read full documentation page content by path (from search_docs results)."
            "Get the full content of a documentation page by its path. "
            "Use this after search_docs to read the complete content of a relevant page."
        )

    @property
@@ -35,7 +36,10 @@ class GetDocPageTool(BaseTool):
            "properties": {
                "path": {
                    "type": "string",
                    "description": "Doc file path (e.g. 'platform/block-sdk-guide.md').",
                    "description": (
                        "The path to the documentation file, as returned by search_docs. "
                        "Example: 'platform/block-sdk-guide.md'"
                    ),
                },
            },
            "required": ["path"],

@@ -38,7 +38,11 @@ class GetMCPGuideTool(BaseTool):

    @property
    def description(self) -> str:
        return "Get MCP server URLs and auth guide. Call before run_mcp_tool if you need a server URL or auth info."
        return (
            "Returns the MCP tool guide: known hosted server URLs (Notion, Linear, "
            "Stripe, Intercom, Cloudflare, Atlassian) and authentication workflow. "
            "Call before using run_mcp_tool if you need a server URL or auth info."
        )

    @property
    def parameters(self) -> dict[str, Any]:

@@ -1,46 +1,24 @@
"""Shared helpers for chat tools."""

import logging
import uuid
from collections import defaultdict
from dataclasses import dataclass
from typing import Any

from pydantic_core import PydanticUndefined

from backend.blocks import BlockType, get_block
from backend.blocks._base import AnyBlockSchema
from backend.copilot.constants import (
    COPILOT_NODE_EXEC_ID_SEPARATOR,
    COPILOT_NODE_PREFIX,
    COPILOT_SESSION_PREFIX,
)
from backend.copilot.model import ChatSession
from backend.copilot.sdk.file_ref import FileRefExpansionError, expand_file_refs_in_args
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
from backend.data.credit import UsageTransactionMetadata
from backend.data.db_accessors import credit_db, review_db, workspace_db
from backend.data.db_accessors import credit_db, workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.executor.simulator import simulate_block
from backend.executor.utils import block_usage_cost
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError, InsufficientBalanceError
from backend.util.type import coerce_inputs_to_schema

from .models import (
    BlockOutputResponse,
    ErrorResponse,
    InputValidationErrorResponse,
    ReviewRequiredResponse,
    SetupInfo,
    SetupRequirementsResponse,
    ToolResponseBase,
    UserReadiness,
)
from .utils import (
    build_missing_credentials_from_field_info,
    match_credentials_to_requirements,
)
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
from .utils import match_credentials_to_requirements

logger = logging.getLogger(__name__)

@@ -81,7 +59,6 @@ async def execute_block(
    node_exec_id: str,
    matched_credentials: dict[str, CredentialsMetaInput],
    sensitive_action_safe_mode: bool = False,
    dry_run: bool = False,
) -> ToolResponseBase:
    """Execute a block with full context setup, credential injection, and error handling.

@@ -91,49 +68,6 @@ async def execute_block(
    Returns:
        BlockOutputResponse on success, ErrorResponse on failure.
    """
    # Dry-run path: simulate the block with an LLM, no real execution.
    # HITL review is intentionally skipped — no real execution occurs.
    if dry_run:
        try:
            # Coerce types to match the block's input schema, same as real execution.
            # This ensures the simulated preview is consistent with real execution
            # (e.g., "42" → 42, string booleans → bool, enum defaults applied).
            coerce_inputs_to_schema(input_data, block.input_schema)
            outputs: dict[str, list[Any]] = defaultdict(list)
            async for output_name, output_data in simulate_block(block, input_data):
                outputs[output_name].append(output_data)
            # simulator signals internal failure via ("error", "[SIMULATOR ERROR …]")
            sim_error = outputs.get("error", [])
            if (
                sim_error
                and isinstance(sim_error[0], str)
                and sim_error[0].startswith("[SIMULATOR ERROR")
            ):
                return ErrorResponse(
                    message=sim_error[0],
                    error=sim_error[0],
                    session_id=session_id,
                )
            return BlockOutputResponse(
                message=(
                    f"[DRY RUN] Block '{block.name}' simulated successfully "
                    "— no real execution occurred."
                ),
                block_id=block_id,
                block_name=block.name,
                outputs=dict(outputs),
                success=True,
                is_dry_run=True,
                session_id=session_id,
            )
        except Exception as e:
            logger.error("Dry-run simulation failed: %s", e, exc_info=True)
            return ErrorResponse(
                message=f"Dry-run simulation failed: {e}",
                error=str(e),
                session_id=session_id,
            )

    try:
        workspace = await workspace_db().get_or_create_workspace(user_id)

@@ -297,287 +231,6 @@ async def resolve_block_credentials(
    return await match_credentials_to_requirements(user_id, requirements)


@dataclass
class BlockPreparation:
    """Result of successful block validation, ready for execution or task creation.

    Attributes:
        block: The resolved block instance (schema definition + execute method).
        block_id: UUID of the block being prepared.
        input_data: User-supplied input values after file-ref expansion.
        matched_credentials: Credential field name -> resolved credential metadata.
        input_schema: JSON Schema for the block's input, with credential
            discriminators resolved for the user's available providers.
        credentials_fields: Set of field names in the schema that are credential
            inputs (e.g. ``{"credentials", "api_key"}``).
        required_non_credential_keys: Schema-required fields minus credential
            fields — the fields the user must supply directly.
        provided_input_keys: Keys the user actually provided in ``input_data``.
        synthetic_graph_id: Auto-generated graph UUID used for CoPilot
            single-block executions (no real graph exists in the DB).
        synthetic_node_id: Auto-generated node UUID paired with
            ``synthetic_graph_id`` to form the execution context for the block.
    """

    block: AnyBlockSchema
    block_id: str
    input_data: dict[str, Any]
    matched_credentials: dict[str, CredentialsMetaInput]
    input_schema: dict[str, Any]
    credentials_fields: set[str]
    required_non_credential_keys: set[str]
    provided_input_keys: set[str]
    synthetic_graph_id: str
    synthetic_node_id: str


async def prepare_block_for_execution(
    block_id: str,
    input_data: dict[str, Any],
    user_id: str,
    session: ChatSession,
    session_id: str,
    dry_run: bool = False,
) -> "BlockPreparation | ToolResponseBase":
    """Validate and prepare a block for execution.

    Performs: block lookup, disabled/excluded-type checks, credential resolution,
    input schema generation, file-ref expansion, missing-credentials check, and
    unrecognized-field validation.

    Does NOT check for missing required fields (tools differ: run_block shows a
    schema preview) and does NOT run the HITL review check (use check_hitl_review
    separately).

    Args:
        block_id: Block UUID to prepare.
        input_data: Input values provided by the caller.
        user_id: Authenticated user ID.
        session: Current chat session (needed for file-ref expansion).
        session_id: Chat session ID (used in error responses).

    Returns:
        BlockPreparation on success, or a ToolResponseBase error/setup response.
    """
    # Lazy import: find_block imports from .base and .models (siblings), not
    # from helpers — no actual circular dependency exists today. Kept lazy as a
    # precaution since find_block is the block-registry module and future changes
    # could introduce a cycle.
    from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES

    block = get_block(block_id)
    if not block:
        return ErrorResponse(
            message=f"Block '{block_id}' not found", session_id=session_id
        )
    if block.disabled:
        return ErrorResponse(
            message=f"Block '{block_id}' is disabled", session_id=session_id
        )

    if (
        block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
        or block.id in COPILOT_EXCLUDED_BLOCK_IDS
    ):
        if block.block_type == BlockType.MCP_TOOL:
            hint = (
                " Use the `run_mcp_tool` tool instead — it handles "
                "MCP server discovery, authentication, and execution."
            )
        elif block.block_type == BlockType.AGENT:
            hint = " Use the `run_agent` tool instead."
        else:
            hint = " This block is designed for use within graphs only."
        return ErrorResponse(
            message=f"Block '{block.name}' cannot be run directly.{hint}",
            session_id=session_id,
        )

    matched_credentials, missing_credentials = await resolve_block_credentials(
        user_id, block, input_data
    )

    try:
        input_schema: dict[str, Any] = block.input_schema.jsonschema()
    except Exception as e:
        logger.warning("Failed to generate input schema for block %s: %s", block_id, e)
        return ErrorResponse(
            message=f"Block '{block.name}' has an invalid input schema",
            error=str(e),
            session_id=session_id,
        )

    # Expand @@agptfile: refs using the block's input schema so string/list
    # fields get the correct deserialization.
    if input_data:
        try:
            input_data = await expand_file_refs_in_args(
                input_data, user_id, session, input_schema=input_schema
            )
        except FileRefExpansionError as exc:
            return ErrorResponse(
                message=(
                    f"Failed to resolve file reference: {exc}. "
                    "Ensure the file exists before referencing it."
                ),
                session_id=session_id,
            )

    credentials_fields = set(block.input_schema.get_credentials_fields().keys())

    if missing_credentials and not dry_run:
        credentials_fields_info = _resolve_discriminated_credentials(block, input_data)
        missing_creds_dict = build_missing_credentials_from_field_info(
            credentials_fields_info, set(matched_credentials.keys())
        )
        missing_creds_list = list(missing_creds_dict.values())
        return SetupRequirementsResponse(
            message=(
                f"Block '{block.name}' requires credentials that are not configured. "
                "Please set up the required credentials before running this block."
            ),
            session_id=session_id,
            setup_info=SetupInfo(
                agent_id=block_id,
                agent_name=block.name,
                user_readiness=UserReadiness(
                    has_all_credentials=False,
                    missing_credentials=missing_creds_dict,
                    ready_to_run=False,
                ),
                requirements={
                    "credentials": missing_creds_list,
                    "inputs": get_inputs_from_schema(
                        input_schema, exclude_fields=credentials_fields
                    ),
                    "execution_modes": ["immediate"],
                },
            ),
            graph_id=None,
            graph_version=None,
        )
    required_keys = set(input_schema.get("required", []))
    required_non_credential_keys = required_keys - credentials_fields
    provided_input_keys = set(input_data.keys()) - credentials_fields

    valid_fields = set(input_schema.get("properties", {}).keys()) - credentials_fields
    unrecognized_fields = provided_input_keys - valid_fields
    if unrecognized_fields:
        return InputValidationErrorResponse(
            message=(
                f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
                "Block was not executed. Please use the correct field names from the schema."
            ),
            session_id=session_id,
            unrecognized_fields=sorted(unrecognized_fields),
            inputs=input_schema,
        )

    synthetic_graph_id = f"{COPILOT_SESSION_PREFIX}{session_id}"
    synthetic_node_id = f"{COPILOT_NODE_PREFIX}{block_id}"

    return BlockPreparation(
        block=block,
        block_id=block_id,
        input_data=input_data,
        matched_credentials=matched_credentials,
        input_schema=input_schema,
        credentials_fields=credentials_fields,
        required_non_credential_keys=required_non_credential_keys,
        provided_input_keys=provided_input_keys,
        synthetic_graph_id=synthetic_graph_id,
        synthetic_node_id=synthetic_node_id,
    )


async def check_hitl_review(
    prep: BlockPreparation,
    user_id: str,
    session_id: str,
) -> "tuple[str, dict[str, Any]] | ToolResponseBase":
    """Check for an existing or new HITL review requirement.

    If a review is needed, stores the review record and returns a
    ReviewRequiredResponse. Otherwise returns
    ``(synthetic_node_exec_id, input_data)`` ready for execute_block.
    """
    block = prep.block
    block_id = prep.block_id
    synthetic_graph_id = prep.synthetic_graph_id
    synthetic_node_id = prep.synthetic_node_id
    input_data = prep.input_data

    # Reuse an existing WAITING review for identical input (LLM retry guard)
    existing_reviews = await review_db().get_pending_reviews_for_execution(
        synthetic_graph_id, user_id
    )
    existing_review = next(
        (
            r
            for r in existing_reviews
            if r.node_id == synthetic_node_id
            and r.status.value == "WAITING"
            and r.payload == input_data
        ),
        None,
    )
    if existing_review:
        return ReviewRequiredResponse(
            message=(
                f"Block '{block.name}' requires human review. "
                f"After the user approves, call continue_run_block with "
                f"review_id='{existing_review.node_exec_id}' to execute."
            ),
            session_id=session_id,
            block_id=block_id,
            block_name=block.name,
            review_id=existing_review.node_exec_id,
            graph_exec_id=synthetic_graph_id,
            input_data=input_data,
        )

    synthetic_node_exec_id = (
        f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}" f"{uuid.uuid4().hex[:8]}"
    )

    review_context = ExecutionContext(
        user_id=user_id,
        graph_id=synthetic_graph_id,
        graph_exec_id=synthetic_graph_id,
        graph_version=1,
        node_id=synthetic_node_id,
        node_exec_id=synthetic_node_exec_id,
        sensitive_action_safe_mode=True,
    )
    should_pause, input_data = await block.is_block_exec_need_review(
        input_data,
        user_id=user_id,
        node_id=synthetic_node_id,
        node_exec_id=synthetic_node_exec_id,
        graph_exec_id=synthetic_graph_id,
        graph_id=synthetic_graph_id,
        graph_version=1,
        execution_context=review_context,
        is_graph_execution=False,
    )
    if should_pause:
        return ReviewRequiredResponse(
            message=(
                f"Block '{block.name}' requires human review. "
                f"After the user approves, call continue_run_block with "
                f"review_id='{synthetic_node_exec_id}' to execute."
            ),
            session_id=session_id,
            block_id=block_id,
            block_name=block.name,
            review_id=synthetic_node_exec_id,
            graph_exec_id=synthetic_graph_id,
            input_data=input_data,
        )

    return synthetic_node_exec_id, input_data


def _resolve_discriminated_credentials(
    block: AnyBlockSchema,
    input_data: dict[str, Any],
|
||||
effective_field_info = field_info.discriminate(discriminator_value)
|
||||
effective_field_info.discriminator_values.add(discriminator_value)
|
||||
logger.debug(
|
||||
"Discriminated provider for %s: %s -> %s",
|
||||
field_name,
|
||||
discriminator_value,
|
||||
effective_field_info.provider,
|
||||
f"Discriminated provider for {field_name}: "
|
||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||
)
|
||||
|
||||
resolved[field_name] = effective_field_info
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Tests for execute_block, prepare_block_for_execution, and check_hitl_review."""
|
||||
"""Tests for execute_block — credit charging and type coercion."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
@@ -7,20 +7,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from backend.blocks._base import BlockType
|
||||
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
|
||||
from backend.copilot.tools.helpers import (
|
||||
BlockPreparation,
|
||||
check_hitl_review,
|
||||
execute_block,
|
||||
prepare_block_for_execution,
|
||||
)
|
||||
from backend.copilot.tools.models import (
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
InputValidationErrorResponse,
|
||||
ReviewRequiredResponse,
|
||||
SetupRequirementsResponse,
|
||||
)
|
||||
from backend.copilot.tools.helpers import execute_block
|
||||
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
|
||||
|
||||
_USER = "test-user-helpers"
|
||||
_SESSION = "test-session-helpers"
|
||||
@@ -522,341 +510,3 @@ async def test_coerce_inner_elements_of_generic():
    # Inner elements should be coerced from int to str
    assert block._captured_inputs["values"] == ["1", "2", "3"]
    assert all(isinstance(v, str) for v in block._captured_inputs["values"])


# ---------------------------------------------------------------------------
# prepare_block_for_execution tests
# ---------------------------------------------------------------------------

_PREP_USER = "prep-user"
_PREP_SESSION = "prep-session"


def _make_prep_session(session_id: str = _PREP_SESSION) -> MagicMock:
    session = MagicMock()
    session.session_id = session_id
    return session


def _make_simple_block(
    block_id: str = "blk-1",
    name: str = "Simple Block",
    disabled: bool = False,
    required: list[str] | None = None,
    properties: dict[str, Any] | None = None,
) -> MagicMock:
    block = MagicMock()
    block.id = block_id
    block.name = name
    block.disabled = disabled
    block.description = ""
    block.block_type = MagicMock()

    schema = {
        "type": "object",
        "properties": properties or {"text": {"type": "string"}},
        "required": required or [],
    }
    block.input_schema.jsonschema.return_value = schema
    block.input_schema.get_credentials_fields.return_value = {}
    block.input_schema.get_credentials_fields_info.return_value = {}
    return block


def _patch_excluded(block_ids: set | None = None, block_types: set | None = None):
    return (
        patch(
            "backend.copilot.tools.find_block.COPILOT_EXCLUDED_BLOCK_IDS",
            new=block_ids or set(),
            create=True,
        ),
        patch(
            "backend.copilot.tools.find_block.COPILOT_EXCLUDED_BLOCK_TYPES",
            new=block_types or set(),
            create=True,
        ),
    )


@pytest.mark.asyncio
async def test_prepare_block_not_found() -> None:
    excl_ids, excl_types = _patch_excluded()
    with (
        patch("backend.copilot.tools.helpers.get_block", return_value=None),
        excl_ids,
        excl_types,
    ):
        result = await prepare_block_for_execution(
            block_id="missing",
            input_data={},
            user_id=_PREP_USER,
            session=_make_prep_session(),
            session_id=_PREP_SESSION,
        )
    assert isinstance(result, ErrorResponse)
    assert "not found" in result.message


@pytest.mark.asyncio
async def test_prepare_block_disabled() -> None:
    block = _make_simple_block(disabled=True)
    excl_ids, excl_types = _patch_excluded()
    with (
        patch("backend.copilot.tools.helpers.get_block", return_value=block),
        excl_ids,
        excl_types,
    ):
        result = await prepare_block_for_execution(
            block_id="blk-1",
            input_data={},
            user_id=_PREP_USER,
            session=_make_prep_session(),
            session_id=_PREP_SESSION,
        )
    assert isinstance(result, ErrorResponse)
    assert "disabled" in result.message


@pytest.mark.asyncio
async def test_prepare_block_unrecognized_fields() -> None:
    block = _make_simple_block(properties={"text": {"type": "string"}})
    excl_ids, excl_types = _patch_excluded()
    with (
        patch("backend.copilot.tools.helpers.get_block", return_value=block),
        excl_ids,
        excl_types,
        patch(
            "backend.copilot.tools.helpers.resolve_block_credentials",
            AsyncMock(return_value=({}, [])),
        ),
        patch(
            "backend.copilot.tools.helpers.expand_file_refs_in_args",
            AsyncMock(side_effect=lambda d, *a, **kw: d),
        ),
    ):
        result = await prepare_block_for_execution(
            block_id="blk-1",
            input_data={"text": "hi", "unknown_field": "oops"},
            user_id=_PREP_USER,
            session=_make_prep_session(),
            session_id=_PREP_SESSION,
        )
    assert isinstance(result, InputValidationErrorResponse)
    assert "unknown_field" in result.unrecognized_fields


@pytest.mark.asyncio
async def test_prepare_block_missing_credentials() -> None:
    block = _make_simple_block()
    mock_cred = MagicMock()
    excl_ids, excl_types = _patch_excluded()
    with (
        patch("backend.copilot.tools.helpers.get_block", return_value=block),
        excl_ids,
        excl_types,
        patch(
            "backend.copilot.tools.helpers.resolve_block_credentials",
            AsyncMock(return_value=({}, [mock_cred])),
        ),
        patch(
            "backend.copilot.tools.helpers.build_missing_credentials_from_field_info",
            return_value={"cred_key": mock_cred},
        ),
    ):
        result = await prepare_block_for_execution(
            block_id="blk-1",
            input_data={},
            user_id=_PREP_USER,
            session=_make_prep_session(),
            session_id=_PREP_SESSION,
        )
    assert isinstance(result, SetupRequirementsResponse)


@pytest.mark.asyncio
async def test_prepare_block_success_returns_preparation() -> None:
    block = _make_simple_block(
        required=["text"], properties={"text": {"type": "string"}}
    )
    excl_ids, excl_types = _patch_excluded()
with (
|
||||
patch("backend.copilot.tools.helpers.get_block", return_value=block),
|
||||
excl_ids,
|
||||
excl_types,
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.resolve_block_credentials",
|
||||
AsyncMock(return_value=({}, [])),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.expand_file_refs_in_args",
|
||||
AsyncMock(side_effect=lambda d, *a, **kw: d),
|
||||
),
|
||||
):
|
||||
result = await prepare_block_for_execution(
|
||||
block_id="blk-1",
|
||||
input_data={"text": "hello"},
|
||||
user_id=_PREP_USER,
|
||||
session=_make_prep_session(),
|
||||
session_id=_PREP_SESSION,
|
||||
)
|
||||
assert isinstance(result, BlockPreparation)
|
||||
assert result.required_non_credential_keys == {"text"}
|
||||
assert result.provided_input_keys == {"text"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_hitl_review tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_hitl_prep(
|
||||
block_id: str = "blk-hitl",
|
||||
input_data: dict | None = None,
|
||||
session_id: str = "hitl-sess",
|
||||
needs_review: bool = False,
|
||||
) -> BlockPreparation:
|
||||
block = MagicMock()
|
||||
block.id = block_id
|
||||
block.name = "HITL Block"
|
||||
data = input_data if input_data is not None else {"action": "delete"}
|
||||
block.is_block_exec_need_review = AsyncMock(return_value=(needs_review, data))
|
||||
return BlockPreparation(
|
||||
block=block,
|
||||
block_id=block_id,
|
||||
input_data=data,
|
||||
matched_credentials={},
|
||||
input_schema={},
|
||||
credentials_fields=set(),
|
||||
required_non_credential_keys=set(),
|
||||
provided_input_keys=set(),
|
||||
synthetic_graph_id=f"{COPILOT_SESSION_PREFIX}{session_id}",
|
||||
synthetic_node_id=f"{COPILOT_NODE_PREFIX}{block_id}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_hitl_no_review_needed() -> None:
|
||||
prep = _make_hitl_prep(input_data={"action": "read"}, needs_review=False)
|
||||
mock_rdb = MagicMock()
|
||||
mock_rdb.get_pending_reviews_for_execution = AsyncMock(return_value=[])
|
||||
|
||||
with patch("backend.copilot.tools.helpers.review_db", return_value=mock_rdb):
|
||||
result = await check_hitl_review(prep, "user1", "hitl-sess")
|
||||
|
||||
assert isinstance(result, tuple)
|
||||
node_exec_id, returned_data = result
|
||||
assert node_exec_id.startswith(f"{COPILOT_NODE_PREFIX}blk-hitl")
|
||||
assert returned_data == {"action": "read"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_hitl_review_required() -> None:
|
||||
prep = _make_hitl_prep(input_data={"action": "delete"}, needs_review=True)
|
||||
mock_rdb = MagicMock()
|
||||
mock_rdb.get_pending_reviews_for_execution = AsyncMock(return_value=[])
|
||||
|
||||
with patch("backend.copilot.tools.helpers.review_db", return_value=mock_rdb):
|
||||
result = await check_hitl_review(prep, "user1", "hitl-sess")
|
||||
|
||||
assert isinstance(result, ReviewRequiredResponse)
|
||||
assert result.block_id == "blk-hitl"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_hitl_reuses_existing_waiting_review() -> None:
|
||||
prep = _make_hitl_prep(input_data={"action": "delete"}, needs_review=False)
|
||||
|
||||
existing = MagicMock()
|
||||
existing.node_id = prep.synthetic_node_id
|
||||
existing.status.value = "WAITING"
|
||||
existing.payload = {"action": "delete"}
|
||||
existing.node_exec_id = "existing-review-42"
|
||||
|
||||
mock_rdb = MagicMock()
|
||||
mock_rdb.get_pending_reviews_for_execution = AsyncMock(return_value=[existing])
|
||||
|
||||
with patch("backend.copilot.tools.helpers.review_db", return_value=mock_rdb):
|
||||
result = await check_hitl_review(prep, "user1", "hitl-sess")
|
||||
|
||||
assert isinstance(result, ReviewRequiredResponse)
|
||||
assert result.review_id == "existing-review-42"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_block_excluded_by_type() -> None:
|
||||
"""prepare_block_for_execution returns ErrorResponse for excluded block types."""
|
||||
from backend.blocks import BlockType
|
||||
|
||||
block = _make_simple_block()
|
||||
block.block_type = BlockType.AGENT
|
||||
|
||||
excl_ids, excl_types = _patch_excluded(block_types={BlockType.AGENT})
|
||||
with (
|
||||
patch("backend.copilot.tools.helpers.get_block", return_value=block),
|
||||
excl_ids,
|
||||
excl_types,
|
||||
):
|
||||
result = await prepare_block_for_execution(
|
||||
block_id="blk-agent",
|
||||
input_data={},
|
||||
user_id=_PREP_USER,
|
||||
session=_make_prep_session(),
|
||||
session_id=_PREP_SESSION,
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "cannot be run directly" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_block_excluded_by_id() -> None:
|
||||
"""prepare_block_for_execution returns ErrorResponse for excluded block IDs."""
|
||||
block = _make_simple_block(block_id="blk-excluded")
|
||||
|
||||
excl_ids, excl_types = _patch_excluded(block_ids={"blk-excluded"})
|
||||
with (
|
||||
patch("backend.copilot.tools.helpers.get_block", return_value=block),
|
||||
excl_ids,
|
||||
excl_types,
|
||||
):
|
||||
result = await prepare_block_for_execution(
|
||||
block_id="blk-excluded",
|
||||
input_data={},
|
||||
user_id=_PREP_USER,
|
||||
session=_make_prep_session(),
|
||||
session_id=_PREP_SESSION,
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "cannot be run directly" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_block_file_ref_expansion_error() -> None:
|
||||
"""prepare_block_for_execution returns ErrorResponse when file-ref expansion fails."""
|
||||
from backend.copilot.sdk.file_ref import FileRefExpansionError
|
||||
|
||||
block = _make_simple_block(properties={"text": {"type": "string"}})
|
||||
excl_ids, excl_types = _patch_excluded()
|
||||
with (
|
||||
patch("backend.copilot.tools.helpers.get_block", return_value=block),
|
||||
excl_ids,
|
||||
excl_types,
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.resolve_block_credentials",
|
||||
AsyncMock(return_value=({}, [])),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.expand_file_refs_in_args",
|
||||
AsyncMock(
|
||||
side_effect=FileRefExpansionError("@@agptfile:missing.txt not found")
|
||||
),
|
||||
),
|
||||
):
|
||||
result = await prepare_block_for_execution(
|
||||
block_id="blk-1",
|
||||
input_data={"text": "@@agptfile:missing.txt"},
|
||||
user_id=_PREP_USER,
|
||||
session=_make_prep_session(),
|
||||
session_id=_PREP_SESSION,
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "file reference" in result.message.lower()
|
||||
|
||||
@@ -88,7 +88,10 @@ class CreateFolderTool(BaseTool):

    @property
    def description(self) -> str:
        return "Create a library folder. Use parent_id to nest inside another folder."
        return (
            "Create a new folder in the user's library to organize agents. "
            "Optionally nest it inside an existing folder using parent_id."
        )

    @property
    def requires_auth(self) -> bool:
@@ -101,19 +104,22 @@ class CreateFolderTool(BaseTool):
            "properties": {
                "name": {
                    "type": "string",
                    "description": "Folder name (max 100 chars).",
                    "description": "Name for the new folder (max 100 chars).",
                },
                "parent_id": {
                    "type": "string",
                    "description": "Parent folder ID (omit for root).",
                    "description": (
                        "ID of the parent folder to nest inside. "
                        "Omit to create at root level."
                    ),
                },
                "icon": {
                    "type": "string",
                    "description": "Icon identifier.",
                    "description": "Optional icon identifier for the folder.",
                },
                "color": {
                    "type": "string",
                    "description": "Hex color (#RRGGBB).",
                    "description": "Optional hex color code (#RRGGBB).",
                },
            },
            "required": ["name"],
@@ -169,9 +175,13 @@ class ListFoldersTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "List library folders. Omit parent_id for full tree. "
            "Set include_agents=true when user asks about agents, wants to see "
            "what's in their folders, or mentions agents alongside folders."
            "List the user's library folders. "
            "Omit parent_id to get the full folder tree. "
            "Provide parent_id to list only direct children of that folder. "
            "Set include_agents=true to also return the agents inside each folder "
            "and root-level agents not in any folder. Always set include_agents=true "
            "when the user asks about agents, wants to see what's in their folders, "
            "or mentions agents alongside folders."
        )

    @property
@@ -185,11 +195,17 @@ class ListFoldersTool(BaseTool):
            "properties": {
                "parent_id": {
                    "type": "string",
                    "description": "List children of this folder (omit for full tree).",
                    "description": (
                        "List children of this folder. "
                        "Omit to get the full folder tree."
                    ),
                },
                "include_agents": {
                    "type": "boolean",
                    "description": "Include agents in each folder (default: false).",
                    "description": (
                        "Whether to include the list of agents inside each folder. "
                        "Defaults to false."
                    ),
                },
            },
            "required": [],
@@ -341,7 +357,10 @@ class MoveFolderTool(BaseTool):

    @property
    def description(self) -> str:
        return "Move a folder. Set target_parent_id to null for root."
        return (
            "Move a folder to a different parent folder. "
            "Set target_parent_id to null to move to root level."
        )

    @property
    def requires_auth(self) -> bool:
@@ -354,11 +373,14 @@ class MoveFolderTool(BaseTool):
            "properties": {
                "folder_id": {
                    "type": "string",
                    "description": "Folder ID.",
                    "description": "ID of the folder to move.",
                },
                "target_parent_id": {
                    "type": ["string", "null"],
                    "description": "New parent folder ID (null for root).",
                    "description": (
                        "ID of the new parent folder. "
                        "Use null to move to root level."
                    ),
                },
            },
            "required": ["folder_id"],
@@ -411,7 +433,10 @@ class DeleteFolderTool(BaseTool):

    @property
    def description(self) -> str:
        return "Delete a folder. Agents inside move to root (not deleted)."
        return (
            "Delete a folder from the user's library. "
            "Agents inside the folder are moved to root level (not deleted)."
        )

    @property
    def requires_auth(self) -> bool:
@@ -474,7 +499,10 @@ class MoveAgentsToFolderTool(BaseTool):

    @property
    def description(self) -> str:
        return "Move agents to a folder. Set folder_id to null for root."
        return (
            "Move one or more agents to a folder. "
            "Set folder_id to null to move agents to root level."
        )

    @property
    def requires_auth(self) -> bool:
@@ -488,11 +516,13 @@ class MoveAgentsToFolderTool(BaseTool):
                "agent_ids": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Library agent IDs to move.",
                    "description": "List of library agent IDs to move.",
                },
                "folder_id": {
                    "type": ["string", "null"],
                    "description": "Target folder ID (null for root).",
                    "description": (
                        "Target folder ID. Use null to move to root level."
                    ),
                },
            },
            "required": ["agent_ids"],

@@ -272,7 +272,6 @@ class ExecutionOutputInfo(BaseModel):
    ended_at: datetime | None = None
    outputs: dict[str, list[Any]]
    inputs_summary: dict[str, Any] | None = None
    node_executions: list[dict[str, Any]] | None = None


class AgentOutputResponse(ToolResponseBase):
@@ -458,7 +457,6 @@ class BlockOutputResponse(ToolResponseBase):
    block_name: str
    outputs: dict[str, list[Any]]
    success: bool = True
    is_dry_run: bool = False


class ReviewRequiredResponse(ToolResponseBase):

@@ -71,7 +71,6 @@ class RunAgentInput(BaseModel):
    cron: str = ""
    timezone: str = "UTC"
    wait_for_result: int = Field(default=0, ge=0, le=300)
    dry_run: bool = False

    @field_validator(
        "username_agent_slug",
@@ -105,11 +104,19 @@ class RunAgentTool(BaseTool):

    @property
    def description(self) -> str:
        return (
            "Run or schedule an agent. Automatically checks inputs and credentials. "
            "Identify by username_agent_slug ('user/agent') or library_agent_id. "
            "For scheduling, provide schedule_name + cron."
        )
        return """Run or schedule an agent from the marketplace or user's library.

The tool automatically handles the setup flow:
- Returns missing inputs if required fields are not provided
- Returns missing credentials if user needs to configure them
- Executes immediately if all requirements are met
- Schedules execution if cron expression is provided

Identify the agent using either:
- username_agent_slug: Marketplace format 'username/agent-name'
- library_agent_id: ID of an agent in the user's library

For scheduled execution, provide: schedule_name, cron, and optionally timezone."""

    @property
    def parameters(self) -> dict[str, Any]:
@@ -118,45 +125,39 @@ class RunAgentTool(BaseTool):
            "properties": {
                "username_agent_slug": {
                    "type": "string",
                    "description": "Marketplace format 'username/agent-name'.",
                    "description": "Agent identifier in format 'username/agent-name'",
                },
                "library_agent_id": {
                    "type": "string",
                    "description": "Library agent ID.",
                    "description": "Library agent ID from user's library",
                },
                "inputs": {
                    "type": "object",
                    "description": "Input values for the agent.",
                    "description": "Input values for the agent",
                    "additionalProperties": True,
                },
                "use_defaults": {
                    "type": "boolean",
                    "description": "Run with default values (confirm with user first).",
                    "description": "Set to true to run with default values (user must confirm)",
                },
                "schedule_name": {
                    "type": "string",
                    "description": "Name for scheduled execution. Providing this triggers scheduling mode (also requires cron).",
                    "description": "Name for scheduled execution (triggers scheduling mode)",
                },
                "cron": {
                    "type": "string",
                    "description": "Cron expression (min hour day month weekday).",
                    "description": "Cron expression (5 fields: min hour day month weekday)",
                },
                "timezone": {
                    "type": "string",
                    "description": "IANA timezone (default: UTC).",
                    "description": "IANA timezone for schedule (default: UTC)",
                },
                "wait_for_result": {
                    "type": "integer",
                    "description": "Max seconds to wait for completion (0-300).",
                    "minimum": 0,
                    "maximum": 300,
                },
                "dry_run": {
                    "type": "boolean",
                    "description": (
                        "When true, simulates the entire agent execution using an LLM "
                        "for each block — no real API calls, no credentials needed, "
                        "no credits charged. Useful for testing agent wiring end-to-end."
                        "Max seconds to wait for execution to complete (0-300). "
                        "If >0, blocks until the execution finishes or times out. "
                        "Returns execution outputs when complete."
                    ),
                },
            },
@@ -238,17 +239,103 @@ class RunAgentTool(BaseTool):
                session_id=session_id,
            )

            # Step 2: Check credentials and inputs
            graph_credentials, prereq_error = await self._check_prerequisites(
                graph=graph,
                user_id=user_id,
                params=params,
                session_id=session_id,
            # Step 2: Check credentials
            graph_credentials, missing_creds = await match_user_credentials_to_graph(
                user_id, graph
            )
            if prereq_error:
                return prereq_error

            # Step 3: Execute or Schedule
            if missing_creds:
                # Return credentials needed response with input data info
                # The UI handles credential setup automatically, so the message
                # focuses on asking about input data
                requirements_creds_dict = build_missing_credentials_from_graph(
                    graph, None
                )
                missing_credentials_dict = build_missing_credentials_from_graph(
                    graph, graph_credentials
                )
                requirements_creds_list = list(requirements_creds_dict.values())

                return SetupRequirementsResponse(
                    message=self._build_inputs_message(graph, MSG_WHAT_VALUES_TO_USE),
                    session_id=session_id,
                    setup_info=SetupInfo(
                        agent_id=graph.id,
                        agent_name=graph.name,
                        user_readiness=UserReadiness(
                            has_all_credentials=False,
                            missing_credentials=missing_credentials_dict,
                            ready_to_run=False,
                        ),
                        requirements={
                            "credentials": requirements_creds_list,
                            "inputs": get_inputs_from_schema(graph.input_schema),
                            "execution_modes": self._get_execution_modes(graph),
                        },
                    ),
                    graph_id=graph.id,
                    graph_version=graph.version,
                )

            # Step 3: Check inputs
            # Get all available input fields from schema
            input_properties = graph.input_schema.get("properties", {})
            required_fields = set(graph.input_schema.get("required", []))
            provided_inputs = set(params.inputs.keys())
            valid_fields = set(input_properties.keys())

            # Check for unknown input fields
            unrecognized_fields = provided_inputs - valid_fields
            if unrecognized_fields:
                return InputValidationErrorResponse(
                    message=(
                        f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
                        f"Agent was not executed. Please use the correct field names from the schema."
                    ),
                    session_id=session_id,
                    unrecognized_fields=sorted(unrecognized_fields),
                    inputs=graph.input_schema,
                    graph_id=graph.id,
                    graph_version=graph.version,
                )

            # If agent has inputs but none were provided AND use_defaults is not set,
            # always show what's available first so user can decide
            if input_properties and not provided_inputs and not params.use_defaults:
                credentials = extract_credentials_from_schema(
                    graph.credentials_input_schema
                )
                return AgentDetailsResponse(
                    message=self._build_inputs_message(graph, MSG_ASK_USER_FOR_VALUES),
                    session_id=session_id,
                    agent=self._build_agent_details(graph, credentials),
                    user_authenticated=True,
                    graph_id=graph.id,
                    graph_version=graph.version,
                )

            # Check if required inputs are missing (and not using defaults)
            missing_inputs = required_fields - provided_inputs

            if missing_inputs and not params.use_defaults:
                # Return agent details with missing inputs info
                credentials = extract_credentials_from_schema(
                    graph.credentials_input_schema
                )
                return AgentDetailsResponse(
                    message=(
                        f"Agent '{graph.name}' is missing required inputs: "
                        f"{', '.join(missing_inputs)}. "
                        "Please provide these values to run the agent."
                    ),
                    session_id=session_id,
                    agent=self._build_agent_details(graph, credentials),
                    user_authenticated=True,
                    graph_id=graph.id,
                    graph_version=graph.version,
                )

            # Step 4: Execute or Schedule
            if is_schedule:
                return await self._schedule_agent(
                    user_id=user_id,
@@ -268,7 +355,6 @@ class RunAgentTool(BaseTool):
                graph_credentials=graph_credentials,
                inputs=params.inputs,
                wait_for_result=params.wait_for_result,
                dry_run=params.dry_run,
            )

        except NotFoundError as e:
@@ -278,14 +364,14 @@ class RunAgentTool(BaseTool):
                session_id=session_id,
            )
        except DatabaseError as e:
            logger.error("Database error: %s", e, exc_info=True)
            logger.error(f"Database error: {e}", exc_info=True)
            return ErrorResponse(
                message=f"Failed to process request: {e!s}",
                error=str(e),
                session_id=session_id,
            )
        except Exception as e:
            logger.error("Error processing agent request: %s", e, exc_info=True)
            logger.error(f"Error processing agent request: {e}", exc_info=True)
            return ErrorResponse(
                message=f"Failed to process request: {e!s}",
                error=str(e),
@@ -345,112 +431,6 @@ class RunAgentTool(BaseTool):
            trigger_info=trigger_info,
        )

    async def _check_prerequisites(
        self,
        graph: GraphModel,
        user_id: str,
        params: "RunAgentInput",
        session_id: str,
    ) -> tuple[dict[str, CredentialsMetaInput], ToolResponseBase | None]:
        """Validate credentials and inputs before execution.

        Dry runs skip all prerequisite gates (credentials, input prompts)
        since simulate_block doesn't need real credentials or complete inputs.

        Returns:
            (graph_credentials, error_response) — error_response is None when ready.
        """
        graph_credentials, missing_creds = await match_user_credentials_to_graph(
            user_id, graph
        )

        # --- Reject unknown input fields (always, even for dry runs) ---
        input_properties = graph.input_schema.get("properties", {})
        provided_inputs = set(params.inputs.keys())
        valid_fields = set(input_properties.keys())
        unrecognized_fields = provided_inputs - valid_fields
        if unrecognized_fields:
            return graph_credentials, InputValidationErrorResponse(
                message=(
                    f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
                    f"Agent was not executed. Please use the correct field names from the schema."
                ),
                session_id=session_id,
                unrecognized_fields=sorted(unrecognized_fields),
                inputs=graph.input_schema,
                graph_id=graph.id,
                graph_version=graph.version,
            )

        # Dry runs bypass remaining prerequisite gates (credentials, missing inputs)
        if params.dry_run:
            return graph_credentials, None

        # --- Credential gate ---
        if missing_creds:
            requirements_creds_dict = build_missing_credentials_from_graph(graph, None)
            missing_credentials_dict = build_missing_credentials_from_graph(
                graph, graph_credentials
            )
            return graph_credentials, SetupRequirementsResponse(
                message=self._build_inputs_message(graph, MSG_WHAT_VALUES_TO_USE),
                session_id=session_id,
                setup_info=SetupInfo(
                    agent_id=graph.id,
                    agent_name=graph.name,
                    user_readiness=UserReadiness(
                        has_all_credentials=False,
                        missing_credentials=missing_credentials_dict,
                        ready_to_run=False,
                    ),
                    requirements={
                        "credentials": list(requirements_creds_dict.values()),
                        "inputs": get_inputs_from_schema(graph.input_schema),
                        "execution_modes": self._get_execution_modes(graph),
                    },
                ),
                graph_id=graph.id,
                graph_version=graph.version,
            )

        # --- Input gates ---
        required_fields = set(graph.input_schema.get("required", []))

        # Prompt user when inputs exist but none were provided
        if input_properties and not provided_inputs and not params.use_defaults:
            credentials = extract_credentials_from_schema(
                graph.credentials_input_schema
            )
            return graph_credentials, AgentDetailsResponse(
                message=self._build_inputs_message(graph, MSG_ASK_USER_FOR_VALUES),
                session_id=session_id,
                agent=self._build_agent_details(graph, credentials),
                user_authenticated=True,
                graph_id=graph.id,
                graph_version=graph.version,
            )

        # Required inputs missing
        missing_inputs = required_fields - provided_inputs
        if missing_inputs and not params.use_defaults:
            credentials = extract_credentials_from_schema(
                graph.credentials_input_schema
            )
            return graph_credentials, AgentDetailsResponse(
                message=(
                    f"Agent '{graph.name}' is missing required inputs: "
                    f"{', '.join(missing_inputs)}. "
                    "Please provide these values to run the agent."
                ),
                session_id=session_id,
                agent=self._build_agent_details(graph, credentials),
                user_authenticated=True,
                graph_id=graph.id,
                graph_version=graph.version,
            )

        return graph_credentials, None

    async def _run_agent(
        self,
        user_id: str,
@@ -459,16 +439,12 @@ class RunAgentTool(BaseTool):
        graph_credentials: dict[str, CredentialsMetaInput],
        inputs: dict[str, Any],
        wait_for_result: int = 0,
        dry_run: bool = False,
    ) -> ToolResponseBase:
        """Execute an agent immediately, optionally waiting for completion."""
        session_id = session.session_id

        # Check rate limits (dry runs don't count against the session limit)
        if (
            not dry_run
            and session.successful_agent_runs.get(graph.id, 0) >= config.max_agent_runs
        ):
        # Check rate limits
        if session.successful_agent_runs.get(graph.id, 0) >= config.max_agent_runs:
            return ErrorResponse(
                message="Maximum agent runs reached for this session. Please try again later.",
                session_id=session_id,
@@ -483,14 +459,12 @@ class RunAgentTool(BaseTool):
            user_id=user_id,
            inputs=inputs,
            graph_credentials_inputs=graph_credentials,
            dry_run=dry_run,
        )

        # Track successful run (dry runs don't count against the session limit)
        if not dry_run:
            session.successful_agent_runs[library_agent.graph_id] = (
                session.successful_agent_runs.get(library_agent.graph_id, 0) + 1
            )
        # Track successful run
        session.successful_agent_runs[library_agent.graph_id] = (
            session.successful_agent_runs.get(library_agent.graph_id, 0) + 1
        )

        # Track in PostHog
        track_agent_run_success(

@@ -4,18 +4,33 @@ import logging
import uuid
from typing import Any

from backend.copilot.constants import COPILOT_NODE_EXEC_ID_SEPARATOR
from backend.copilot.context import get_current_permissions
from backend.blocks import BlockType, get_block
from backend.blocks._base import AnyBlockSchema
from backend.copilot.constants import (
    COPILOT_NODE_EXEC_ID_SEPARATOR,
    COPILOT_NODE_PREFIX,
    COPILOT_SESSION_PREFIX,
)
from backend.copilot.model import ChatSession
from backend.copilot.sdk.file_ref import FileRefExpansionError, expand_file_refs_in_args
from backend.data.db_accessors import review_db
from backend.data.execution import ExecutionContext

from .base import BaseTool
from .helpers import (
    BlockPreparation,
    check_hitl_review,
    execute_block,
    prepare_block_for_execution,
from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES
from .helpers import execute_block, get_inputs_from_schema, resolve_block_credentials
from .models import (
    BlockDetails,
    BlockDetailsResponse,
    ErrorResponse,
    InputValidationErrorResponse,
    ReviewRequiredResponse,
    SetupInfo,
    SetupRequirementsResponse,
    ToolResponseBase,
    UserReadiness,
)
from .models import BlockDetails, BlockDetailsResponse, ErrorResponse, ToolResponseBase
from .utils import build_missing_credentials_from_field_info

logger = logging.getLogger(__name__)

@@ -30,10 +45,13 @@ class RunBlockTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Execute a block. IMPORTANT: Always get block_id from find_block first "
            "— do NOT guess or fabricate IDs. "
            "Call with empty input_data to see schema, then with data to execute. "
            "If review_required, use continue_run_block."
            "Execute a specific block with the provided input data. "
            "IMPORTANT: You MUST call find_block first to get the block's 'id' - "
            "do NOT guess or make up block IDs. "
            "On first attempt (without input_data), returns detailed schema showing "
            "required inputs and outputs. Then call again with proper input_data to execute. "
            "If a block requires human review, use continue_run_block with the "
            "review_id after the user approves."
        )

    @property
@@ -43,22 +61,28 @@ class RunBlockTool(BaseTool):
            "properties": {
                "block_id": {
                    "type": "string",
                    "description": "Block ID from find_block results.",
                    "description": (
                        "The block's 'id' field from find_block results. "
                        "NEVER guess this - always get it from find_block first."
                    ),
                },
                "block_name": {
                    "type": "string",
                    "description": (
                        "The block's human-readable name from find_block results. "
                        "Used for display purposes in the UI."
                    ),
                },
                "input_data": {
                    "type": "object",
                    "description": "Input values. Use {} first to see schema.",
                },
                "dry_run": {
                    "type": "boolean",
                    "description": (
                        "When true, simulates block execution using an LLM without making any "
                        "real API calls or producing side effects. Useful for testing agent "
                        "wiring and previewing outputs. Default: false."
                        "Input values for the block. "
                        "First call with empty {} to see the block's schema, "
                        "then call again with proper values to execute."
                    ),
                },
            },
            "required": ["block_id", "input_data"],
            "required": ["block_id", "block_name", "input_data"],
        }

    @property
@@ -86,7 +110,6 @@ class RunBlockTool(BaseTool):
        """
        block_id = kwargs.get("block_id", "").strip()
        input_data = kwargs.get("input_data", {})
        dry_run = bool(kwargs.get("dry_run", False))
        session_id = session.session_id

        if not block_id:
@@ -107,108 +130,267 @@ class RunBlockTool(BaseTool):
                session_id=session_id,
            )

        logger.info("Preparing block %s for user %s", block_id, user_id)
        # Get the block
        block = get_block(block_id)
        if not block:
            return ErrorResponse(
                message=f"Block '{block_id}' not found",
                session_id=session_id,
            )
        if block.disabled:
            return ErrorResponse(
                message=f"Block '{block_id}' is disabled",
                session_id=session_id,
            )

        prep_or_err = await prepare_block_for_execution(
            block_id=block_id,
            input_data=input_data,
            user_id=user_id,
            session=session,
            session_id=session_id,
            dry_run=dry_run,
        )
        if isinstance(prep_or_err, ToolResponseBase):
            return prep_or_err
        prep: BlockPreparation = prep_or_err

        # Check block-level permissions before execution.
        perms = get_current_permissions()
        if perms is not None and not perms.is_block_allowed(block_id, prep.block.name):
            available_hint = (
                f"Allowed identifiers: {perms.blocks!r}. "
                if not perms.blocks_exclude and perms.blocks
                else (
                    f"Blocked identifiers: {perms.blocks!r}. "
                    if perms.blocks_exclude and perms.blocks
                    else ""
        # Check if block is excluded from CoPilot (graph-only blocks)
        if (
            block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
            or block.id in COPILOT_EXCLUDED_BLOCK_IDS
        ):
            # Provide actionable guidance for blocks with dedicated tools
            if block.block_type == BlockType.MCP_TOOL:
                hint = (
                    " Use the `run_mcp_tool` tool instead — it handles "
                    "MCP server discovery, authentication, and execution."
                )
            elif block.block_type == BlockType.AGENT:
                hint = " Use the `run_agent` tool instead."
            else:
                hint = " This block is designed for use within graphs only."
            return ErrorResponse(
                message=f"Block '{block.name}' cannot be run directly.{hint}",
                session_id=session_id,
            )

        logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")

        (
            matched_credentials,
            missing_credentials,
        ) = await resolve_block_credentials(user_id, block, input_data)

        # Get block schemas for details/validation
        try:
            input_schema: dict[str, Any] = block.input_schema.jsonschema()
        except Exception as e:
            logger.warning(
                "Failed to generate input schema for block %s: %s",
                block_id,
                e,
            )
            return ErrorResponse(
                message=(
                    f"Block '{prep.block.name}' ({block_id}) is not permitted "
                    f"by the current execution permissions. {available_hint}"
                    "Use find_block to discover blocks that are allowed."
                ),
                message=f"Block '{block.name}' has an invalid input schema",
                error=str(e),
                session_id=session_id,
            )
        try:
            output_schema: dict[str, Any] = block.output_schema.jsonschema()
        except Exception as e:
            logger.warning(
                "Failed to generate output schema for block %s: %s",
                block_id,
                e,
            )
            return ErrorResponse(
                message=f"Block '{block.name}' has an invalid output schema",
                error=str(e),
                session_id=session_id,
            )

        # Dry-run fast-path: skip credential/HITL checks — simulation never calls
        # the real service so credentials and review gates are not needed.
        # Input field validation (unrecognized fields) is already handled by
        # prepare_block_for_execution above.
        if dry_run:
            synthetic_node_exec_id = (
                f"{prep.synthetic_node_id}"
                f"{COPILOT_NODE_EXEC_ID_SEPARATOR}"
                f"{uuid.uuid4().hex[:8]}"
            )
            return await execute_block(
                block=prep.block,
                block_id=block_id,
                input_data=prep.input_data,
                user_id=user_id,
                session_id=session_id,
                node_exec_id=synthetic_node_exec_id,
                matched_credentials=prep.matched_credentials,
                dry_run=True,
            )

        # Show block details when required inputs are not yet provided.
        # This is run_block's two-step UX: first call returns the schema,
        # second call (with inputs) actually executes.
        if not (prep.required_non_credential_keys <= prep.provided_input_keys):
            # Expand @@agptfile: refs in input_data with the block's input
            # schema. The generic _truncating wrapper skips opaque object
            # properties (input_data has no declared inner properties in the
            # tool schema), so file ref tokens are still intact here.
            # Using the block's schema lets us return raw text for string-typed
            # fields and parsed structures for list/dict-typed fields.
            if input_data:
                try:
                    output_schema: dict[str, Any] = prep.block.output_schema.jsonschema()
                except Exception as e:
                    logger.warning(
                        "Failed to generate output schema for block %s: %s", block_id, e
                    input_data = await expand_file_refs_in_args(
                        input_data,
                        user_id,
                        session,
                        input_schema=input_schema,
                    )
                except FileRefExpansionError as exc:
                    return ErrorResponse(
                        message=f"Block '{prep.block.name}' has an invalid output schema",
                        error=str(e),
                        message=(
                            f"Failed to resolve file reference: {exc}. "
                            "Ensure the file exists before referencing it."
                        ),
                        session_id=session_id,
                    )

            credentials_meta = list(prep.matched_credentials.values())
        if missing_credentials:
            # Return setup requirements response with missing credentials
            credentials_fields_info = block.input_schema.get_credentials_fields_info()
            missing_creds_dict = build_missing_credentials_from_field_info(
                credentials_fields_info, set(matched_credentials.keys())
            )
            missing_creds_list = list(missing_creds_dict.values())

            return SetupRequirementsResponse(
                message=(
                    f"Block '{block.name}' requires credentials that are not configured. "
                    "Please set up the required credentials before running this block."
                ),
                session_id=session_id,
                setup_info=SetupInfo(
                    agent_id=block_id,
                    agent_name=block.name,
                    user_readiness=UserReadiness(
                        has_all_credentials=False,
                        missing_credentials=missing_creds_dict,
                        ready_to_run=False,
                    ),
                    requirements={
                        "credentials": missing_creds_list,
                        "inputs": self._get_inputs_list(block),
                        "execution_modes": ["immediate"],
                    },
                ),
                graph_id=None,
                graph_version=None,
            )

        # Check if this is a first attempt (required inputs missing)
        # Return block details so user can see what inputs are needed
        credentials_fields = set(block.input_schema.get_credentials_fields().keys())
        required_keys = set(input_schema.get("required", []))
        required_non_credential_keys = required_keys - credentials_fields
        provided_input_keys = set(input_data.keys()) - credentials_fields

        # Check for unknown input fields
        valid_fields = (
            set(input_schema.get("properties", {}).keys()) - credentials_fields
        )
        unrecognized_fields = provided_input_keys - valid_fields
        if unrecognized_fields:
            return InputValidationErrorResponse(
                message=(
                    f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
                    f"Block was not executed. Please use the correct field names from the schema."
                ),
                session_id=session_id,
                unrecognized_fields=sorted(unrecognized_fields),
                inputs=input_schema,
            )

        # Show details when not all required non-credential inputs are provided
        if not (required_non_credential_keys <= provided_input_keys):
            # Get credentials info for the response
            credentials_meta = []
            for field_name, cred_meta in matched_credentials.items():
                credentials_meta.append(cred_meta)

            return BlockDetailsResponse(
                message=(
                    f"Block '{prep.block.name}' details. "
                    f"Block '{block.name}' details. "
                    "Provide input_data matching the inputs schema to execute the block."
                ),
                session_id=session_id,
                block=BlockDetails(
                    id=block_id,
                    name=prep.block.name,
                    description=prep.block.description or "",
                    inputs=prep.input_schema,
                    name=block.name,
                    description=block.description or "",
                    inputs=input_schema,
                    outputs=output_schema,
                    credentials=credentials_meta,
                ),
                user_authenticated=True,
            )

        hitl_or_err = await check_hitl_review(prep, user_id, session_id)
        if isinstance(hitl_or_err, ToolResponseBase):
            return hitl_or_err
        synthetic_node_exec_id, input_data = hitl_or_err
        # Generate synthetic IDs for CoPilot context.
        # Encode node_id in node_exec_id so it can be extracted later
        # (e.g. for auto-approve, where we need node_id but have no NodeExecution row).
        synthetic_graph_id = f"{COPILOT_SESSION_PREFIX}{session.session_id}"
        synthetic_node_id = f"{COPILOT_NODE_PREFIX}{block_id}"

        # Check for an existing WAITING review for this block with the same input.
        # If the LLM retries run_block with identical input, we reuse the existing
        # review instead of creating duplicates. Different inputs = new execution.
        existing_reviews = await review_db().get_pending_reviews_for_execution(
            synthetic_graph_id, user_id
        )
        existing_review = next(
            (
                r
                for r in existing_reviews
                if r.node_id == synthetic_node_id
                and r.status.value == "WAITING"
                and r.payload == input_data
            ),
            None,
        )
        if existing_review:
            return ReviewRequiredResponse(
                message=(
                    f"Block '{block.name}' requires human review. "
                    f"After the user approves, call continue_run_block with "
                    f"review_id='{existing_review.node_exec_id}' to execute."
                ),
                session_id=session_id,
                block_id=block_id,
                block_name=block.name,
                review_id=existing_review.node_exec_id,
                graph_exec_id=synthetic_graph_id,
                input_data=input_data,
            )

        synthetic_node_exec_id = (
            f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}"
            f"{uuid.uuid4().hex[:8]}"
        )

        # Check for HITL review before execution.
        # This creates the review record in the DB for CoPilot flows.
        review_context = ExecutionContext(
            user_id=user_id,
            graph_id=synthetic_graph_id,
            graph_exec_id=synthetic_graph_id,
            graph_version=1,
            node_id=synthetic_node_id,
            node_exec_id=synthetic_node_exec_id,
            sensitive_action_safe_mode=True,
        )
        should_pause, input_data = await block.is_block_exec_need_review(
            input_data,
            user_id=user_id,
            node_id=synthetic_node_id,
            node_exec_id=synthetic_node_exec_id,
            graph_exec_id=synthetic_graph_id,
            graph_id=synthetic_graph_id,
            graph_version=1,
            execution_context=review_context,
            is_graph_execution=False,
        )
        if should_pause:
            return ReviewRequiredResponse(
                message=(
                    f"Block '{block.name}' requires human review. "
                    f"After the user approves, call continue_run_block with "
                    f"review_id='{synthetic_node_exec_id}' to execute."
                ),
                session_id=session_id,
                block_id=block_id,
                block_name=block.name,
                review_id=synthetic_node_exec_id,
                graph_exec_id=synthetic_graph_id,
                input_data=input_data,
            )

        return await execute_block(
            block=prep.block,
            block=block,
            block_id=block_id,
            input_data=input_data,
            user_id=user_id,
            session_id=session_id,
            node_exec_id=synthetic_node_exec_id,
            matched_credentials=prep.matched_credentials,
            dry_run=dry_run,
            matched_credentials=matched_credentials,
        )

    def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
        """Extract non-credential inputs from block schema."""
        schema = block.input_schema.jsonschema()
        credentials_fields = set(block.input_schema.get_credentials_fields().keys())
        return get_inputs_from_schema(schema, exclude_fields=credentials_fields)

@@ -5,8 +5,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest

from backend.blocks._base import BlockType
from backend.copilot.context import _current_permissions
from backend.copilot.permissions import CopilotPermissions

from ._test_data import make_session
from .models import (
@@ -94,7 +92,7 @@ class TestRunBlockFiltering:
        input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)

        with patch(
            "backend.copilot.tools.helpers.get_block",
            "backend.copilot.tools.run_block.get_block",
            return_value=input_block,
        ):
            tool = RunBlockTool()
@@ -111,92 +109,29 @@ class TestRunBlockFiltering:

    @pytest.mark.asyncio(loop_scope="session")
    async def test_excluded_block_id_returns_error(self):
        """Attempting to execute OrchestratorBlock returns error."""
        """Attempting to execute SmartDecisionMakerBlock returns error."""
        session = make_session(user_id=_TEST_USER_ID)

        orchestrator_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
        smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
        smart_block = make_mock_block(
            orchestrator_id, "Orchestrator", BlockType.STANDARD
            smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
        )

        with patch(
            "backend.copilot.tools.helpers.get_block",
            "backend.copilot.tools.run_block.get_block",
            return_value=smart_block,
        ):
            tool = RunBlockTool()
            response = await tool._execute(
                user_id=_TEST_USER_ID,
                session=session,
                block_id=orchestrator_id,
                block_id=smart_decision_id,
                input_data={},
            )

        assert isinstance(response, ErrorResponse)
        assert "cannot be run directly" in response.message

    @pytest.mark.asyncio(loop_scope="session")
    async def test_block_denied_by_permissions_returns_error(self):
        """A block denied by CopilotPermissions returns an ErrorResponse."""
        session = make_session(user_id=_TEST_USER_ID)
        block_id = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
        standard_block = make_mock_block(block_id, "HTTP Request", BlockType.STANDARD)

        perms = CopilotPermissions(blocks=[block_id], blocks_exclude=True)
        token = _current_permissions.set(perms)
        try:
            with patch(
                "backend.copilot.tools.helpers.get_block",
                return_value=standard_block,
            ):
                tool = RunBlockTool()
                response = await tool._execute(
                    user_id=_TEST_USER_ID,
                    session=session,
                    block_id=block_id,
                    input_data={},
                )
        finally:
            _current_permissions.reset(token)

        assert isinstance(response, ErrorResponse)
        assert "not permitted" in response.message

    @pytest.mark.asyncio(loop_scope="session")
    async def test_allowed_by_permissions_passes_guard(self):
        """A block explicitly allowed by a whitelist CopilotPermissions passes the guard."""
        session = make_session(user_id=_TEST_USER_ID)
        block_id = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
        standard_block = make_mock_block(block_id, "HTTP Request", BlockType.STANDARD)

        perms = CopilotPermissions(blocks=[block_id], blocks_exclude=False)
        token = _current_permissions.set(perms)
        try:
            with (
                patch(
                    "backend.copilot.tools.helpers.get_block",
                    return_value=standard_block,
                ),
                patch(
                    "backend.copilot.tools.helpers.match_credentials_to_requirements",
                    return_value=({}, []),
                ),
            ):
                tool = RunBlockTool()
                response = await tool._execute(
                    user_id=_TEST_USER_ID,
                    session=session,
                    block_id=block_id,
                    input_data={},
                )
        finally:
            _current_permissions.reset(token)

        # Must NOT be blocked by permissions — assert it's not a permission error
        assert (
            not isinstance(response, ErrorResponse)
            or "not permitted" not in response.message
        )

    @pytest.mark.asyncio(loop_scope="session")
    async def test_non_excluded_block_passes_guard(self):
        """Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
@@ -208,7 +143,7 @@ class TestRunBlockFiltering:

        with (
            patch(
                "backend.copilot.tools.helpers.get_block",
                "backend.copilot.tools.run_block.get_block",
                return_value=standard_block,
            ),
            patch(
@@ -265,7 +200,7 @@ class TestRunBlockInputValidation:

        with (
            patch(
                "backend.copilot.tools.helpers.get_block",
                "backend.copilot.tools.run_block.get_block",
                return_value=mock_block,
            ),
            patch(
@@ -308,7 +243,7 @@ class TestRunBlockInputValidation:

        with (
            patch(
                "backend.copilot.tools.helpers.get_block",
                "backend.copilot.tools.run_block.get_block",
                return_value=mock_block,
            ),
            patch(
@@ -354,7 +289,7 @@ class TestRunBlockInputValidation:

        with (
            patch(
                "backend.copilot.tools.helpers.get_block",
                "backend.copilot.tools.run_block.get_block",
                return_value=mock_block,
            ),
            patch(
@@ -402,7 +337,7 @@ class TestRunBlockInputValidation:

        with (
            patch(
                "backend.copilot.tools.helpers.get_block",
                "backend.copilot.tools.run_block.get_block",
                return_value=mock_block,
            ),
            patch(
@@ -446,7 +381,7 @@ class TestRunBlockInputValidation:

        with (
            patch(
                "backend.copilot.tools.helpers.get_block",
                "backend.copilot.tools.run_block.get_block",
                return_value=mock_block,
            ),
            patch(
@@ -500,7 +435,7 @@ class TestRunBlockSensitiveAction:

        with (
            patch(
                "backend.copilot.tools.helpers.get_block",
                "backend.copilot.tools.run_block.get_block",
                return_value=mock_block,
            ),
            patch(
@@ -556,7 +491,7 @@ class TestRunBlockSensitiveAction:

        with (
            patch(
                "backend.copilot.tools.helpers.get_block",
                "backend.copilot.tools.run_block.get_block",
                return_value=mock_block,
            ),
            patch(
@@ -610,7 +545,7 @@ class TestRunBlockSensitiveAction:

        with (
            patch(
                "backend.copilot.tools.helpers.get_block",
                "backend.copilot.tools.run_block.get_block",
                return_value=mock_block,
            ),
            patch(

@@ -57,9 +57,10 @@ class RunMCPToolTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Discover and execute MCP server tools. "
            "Call with server_url only to list tools, then with tool_name + tool_arguments to execute. "
            "Call get_mcp_guide first for server URLs and auth."
            "Connect to an MCP (Model Context Protocol) server to discover and execute its tools. "
            "Two-step: (1) call with server_url to list available tools, "
            "(2) call again with server_url + tool_name + tool_arguments to execute. "
            "Call get_mcp_guide for known server URLs and auth details."
        )

    @property
@@ -69,15 +70,24 @@ class RunMCPToolTool(BaseTool):
            "properties": {
                "server_url": {
                    "type": "string",
                    "description": "MCP server URL (Streamable HTTP endpoint).",
                    "description": (
                        "URL of the MCP server (Streamable HTTP endpoint), "
                        "e.g. https://mcp.example.com/mcp"
                    ),
                },
                "tool_name": {
                    "type": "string",
                    "description": "Tool to execute. Omit to discover available tools.",
                    "description": (
                        "Name of the MCP tool to execute. "
                        "Omit on first call to discover available tools."
                    ),
                },
                "tool_arguments": {
                    "type": "object",
                    "description": "Arguments matching the tool's input schema.",
                    "description": (
                        "Arguments to pass to the selected tool. "
                        "Must match the tool's input schema returned during discovery."
                    ),
                },
            },
            "required": ["server_url"],

@@ -38,7 +38,11 @@ class SearchDocsTool(BaseTool):

    @property
    def description(self) -> str:
        return "Search platform documentation by keyword. Use get_doc_page to read full results."
        return (
            "Search the AutoGPT platform documentation for information about "
            "how to use the platform, build agents, configure blocks, and more. "
            "Returns relevant documentation sections. Use get_doc_page to read full content."
        )

    @property
    def parameters(self) -> dict[str, Any]:
@@ -47,7 +51,10 @@ class SearchDocsTool(BaseTool):
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Documentation search query.",
                    "description": (
                        "Search query to find relevant documentation. "
                        "Use natural language to describe what you're looking for."
                    ),
                },
            },
            "required": ["query"],

@@ -1,358 +0,0 @@
"""Tests for dry-run execution mode."""

import inspect
import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

import backend.copilot.tools.run_block as run_block_module
from backend.copilot.tools.helpers import execute_block
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
from backend.copilot.tools.run_block import RunBlockTool
from backend.executor.simulator import build_simulation_prompt, simulate_block

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def make_mock_block(
    name: str = "TestBlock",
    description: str = "A test block",
    input_props: dict | None = None,
    output_props: dict | None = None,
):
    """Create a minimal mock block with jsonschema() methods."""
    block = MagicMock()
    block.name = name
    block.description = description

    in_props = input_props or {"query": {"type": "string"}}
    out_props = output_props or {
        "result": {"type": "string"},
        "error": {"type": "string"},
    }

    block.input_schema = MagicMock()
    block.input_schema.jsonschema.return_value = {
        "type": "object",
        "properties": in_props,
        "required": list(in_props.keys()),
    }
    block.input_schema.get_credentials_fields.return_value = {}
    block.input_schema.get_credentials_fields_info.return_value = {}

    block.output_schema = MagicMock()
    block.output_schema.jsonschema.return_value = {
        "type": "object",
        "properties": out_props,
        "required": ["result"],
    }

    return block


def make_openai_response(
    content: str, prompt_tokens: int = 100, completion_tokens: int = 50
):
    """Build a mock OpenAI chat completion response."""
    response = MagicMock()
    response.choices = [MagicMock()]
    response.choices[0].message.content = content
    response.usage = MagicMock()
    response.usage.prompt_tokens = prompt_tokens
    response.usage.completion_tokens = completion_tokens
    return response


# ---------------------------------------------------------------------------
# simulate_block tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_simulate_block_basic():
    """simulate_block returns correct (output_name, output_data) tuples."""
    mock_block = make_mock_block()
    mock_client = AsyncMock()
    mock_client.chat.completions.create = AsyncMock(
        return_value=make_openai_response('{"result": "simulated output", "error": ""}')
    )

    with patch(
        "backend.executor.simulator.get_openai_client", return_value=mock_client
    ):
        outputs = []
        async for name, data in simulate_block(mock_block, {"query": "test"}):
            outputs.append((name, data))

    assert ("result", "simulated output") in outputs
    assert ("error", "") in outputs


@pytest.mark.asyncio
async def test_simulate_block_json_retry():
    """LLM returns invalid JSON twice then valid; verifies 3 total calls."""
    mock_block = make_mock_block()
    mock_client = AsyncMock()
    mock_client.chat.completions.create = AsyncMock(
        side_effect=[
            make_openai_response("not json at all"),
            make_openai_response("still not json"),
            make_openai_response('{"result": "ok", "error": ""}'),
        ]
    )

    with patch(
        "backend.executor.simulator.get_openai_client", return_value=mock_client
    ):
        outputs = []
        async for name, data in simulate_block(mock_block, {"query": "test"}):
            outputs.append((name, data))

    assert mock_client.chat.completions.create.call_count == 3
    assert ("result", "ok") in outputs
@pytest.mark.asyncio
|
||||
async def test_simulate_block_all_retries_exhausted():
|
||||
"""LLM always returns invalid JSON; verify error tuple is yielded."""
|
||||
mock_block = make_mock_block()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create = AsyncMock(
|
||||
return_value=make_openai_response("bad json !!!")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.simulator.get_openai_client", return_value=mock_client
|
||||
):
|
||||
outputs = []
|
||||
async for name, data in simulate_block(mock_block, {"query": "test"}):
|
||||
outputs.append((name, data))
|
||||
|
||||
# All retry attempts should have been consumed
|
||||
assert mock_client.chat.completions.create.call_count == 5 # _MAX_JSON_RETRIES
|
||||
assert len(outputs) == 1
|
||||
name, data = outputs[0]
|
||||
assert name == "error"
|
||||
assert "[SIMULATOR ERROR" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simulate_block_missing_output_pins():
|
||||
"""LLM response missing some output pins; verify they're filled with None."""
|
||||
mock_block = make_mock_block(
|
||||
output_props={
|
||||
"result": {"type": "string"},
|
||||
"count": {"type": "integer"},
|
||||
"error": {"type": "string"},
|
||||
}
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
# Only returns "result", missing "count" and "error"
|
||||
mock_client.chat.completions.create = AsyncMock(
|
||||
return_value=make_openai_response('{"result": "hello"}')
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.simulator.get_openai_client", return_value=mock_client
|
||||
):
|
||||
outputs = {}
|
||||
async for name, data in simulate_block(mock_block, {"query": "hi"}):
|
||||
outputs[name] = data
|
||||
|
||||
assert outputs["result"] == "hello"
|
||||
assert outputs["count"] is None # missing pin filled with None
|
||||
assert outputs["error"] == "" # "error" pin filled with ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simulate_block_no_client():
|
||||
"""When no OpenAI client is available, yields SIMULATOR ERROR."""
|
||||
mock_block = make_mock_block()
|
||||
|
||||
with patch("backend.executor.simulator.get_openai_client", return_value=None):
|
||||
outputs = []
|
||||
async for name, data in simulate_block(mock_block, {}):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
name, data = outputs[0]
|
||||
assert name == "error"
|
||||
assert "[SIMULATOR ERROR" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simulate_block_truncates_long_inputs():
|
||||
"""Inputs with very long strings should be truncated in the prompt."""
|
||||
mock_block = make_mock_block(input_props={"text": {"type": "string"}})
|
||||
long_text = "x" * 30000 # 30k chars, above the 20k threshold
|
||||
|
||||
system_prompt, user_prompt = build_simulation_prompt(
|
||||
mock_block, {"text": long_text}
|
||||
)
|
||||
|
||||
# The user prompt should contain TRUNCATED marker
|
||||
assert "[TRUNCATED]" in user_prompt
|
||||
# And the total length of the value in the prompt should be well under 30k chars
|
||||
parsed = json.loads(user_prompt.split("## Current Inputs\n", 1)[1])
|
||||
assert len(parsed["text"]) < 25000
|
||||
|
||||
|
||||
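A minimal sketch of the truncation behavior this test pins down. Illustrative only, assuming a 20k-character per-value limit; the real logic lives in `backend.executor.simulator.build_simulation_prompt`.

```python
_TRUNCATION_LIMIT = 20_000  # assumed limit, inferred from the test above


def truncate_value(value: str, limit: int = _TRUNCATION_LIMIT) -> str:
    """Cap long string inputs and append the marker the test asserts on."""
    if len(value) <= limit:
        return value
    return value[:limit] + " [TRUNCATED]"
```
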
# ---------------------------------------------------------------------------
# execute_block dry-run tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_execute_block_dry_run_skips_real_execution():
    """execute_block(dry_run=True) calls simulate_block, NOT block.execute."""
    mock_block = make_mock_block()
    mock_block.execute = AsyncMock()  # should NOT be called

    async def fake_simulate(block, input_data):
        yield "result", "simulated"

    # Patching at helpers.simulate_block works because helpers.py imports
    # simulate_block at the top of the module. If the import were lazy
    # (inside the function), we'd need to patch the source module instead.
    with patch(
        "backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate
    ):
        response = await execute_block(
            block=mock_block,
            block_id="test-block-id",
            input_data={"query": "hello"},
            user_id="user-1",
            session_id="session-1",
            node_exec_id="node-exec-1",
            matched_credentials={},
            dry_run=True,
        )

    mock_block.execute.assert_not_called()
    assert isinstance(response, BlockOutputResponse)
    assert response.success is True


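The comment in that test is the general `unittest.mock` rule: patch the name where it is looked up, not where it is defined. In miniature, as comments:

```python
# helpers.py does `from backend.executor.simulator import simulate_block`,
# binding the name into helpers' own namespace at import time. Calls inside
# helpers resolve through that namespace, so this interception works:
#
#     patch("backend.copilot.tools.helpers.simulate_block", ...)
#
# Patching the defining module would leave helpers' already-bound name intact:
#
#     patch("backend.executor.simulator.simulate_block", ...)  # no effect here
```
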
@pytest.mark.asyncio
async def test_execute_block_dry_run_response_format():
    """Dry-run response should contain [DRY RUN] in message and success=True."""
    mock_block = make_mock_block()

    async def fake_simulate(block, input_data):
        yield "result", "simulated"

    with patch(
        "backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate
    ):
        response = await execute_block(
            block=mock_block,
            block_id="test-block-id",
            input_data={"query": "hello"},
            user_id="user-1",
            session_id="session-1",
            node_exec_id="node-exec-1",
            matched_credentials={},
            dry_run=True,
        )

    assert isinstance(response, BlockOutputResponse)
    assert "[DRY RUN]" in response.message
    assert response.success is True
    assert response.outputs == {"result": ["simulated"]}


@pytest.mark.asyncio
async def test_execute_block_real_execution_unchanged():
    """dry_run=False should still go through the real execution path."""
    mock_block = make_mock_block()

    # We expect it to hit the real path, which will fail on workspace_db() call.
    # Just verify simulate_block is NOT called.
    simulate_called = False

    async def fake_simulate(block, input_data):
        nonlocal simulate_called
        simulate_called = True
        yield "result", "should not happen"

    with patch(
        "backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate
    ):
        with patch(
            "backend.copilot.tools.helpers.workspace_db",
            side_effect=Exception("db not available"),
        ):
            response = await execute_block(
                block=mock_block,
                block_id="test-block-id",
                input_data={"query": "hello"},
                user_id="user-1",
                session_id="session-1",
                node_exec_id="node-exec-1",
                matched_credentials={},
                dry_run=False,
            )

    assert simulate_called is False
    # The real path raised an exception, so we get an ErrorResponse (which has .error attr)
    assert hasattr(response, "error")


# ---------------------------------------------------------------------------
# RunBlockTool parameter tests
# ---------------------------------------------------------------------------


def test_run_block_tool_dry_run_param():
    """RunBlockTool parameters should include 'dry_run'."""
    tool = RunBlockTool()
    params = tool.parameters
    assert "dry_run" in params["properties"]
    assert params["properties"]["dry_run"]["type"] == "boolean"


def test_run_block_tool_dry_run_calls_execute():
    """RunBlockTool._execute extracts dry_run from kwargs correctly.

    We verify the extraction logic directly by inspecting the source, then confirm
    the kwarg is forwarded in the execute_block call site.
    """
    source = inspect.getsource(run_block_module.RunBlockTool._execute)
    # Verify dry_run is extracted from kwargs
    assert "dry_run" in source
    assert 'kwargs.get("dry_run"' in source

    # Scope to _execute method source only — module-wide search is brittle
    # and can match unrelated text/comments.
    source_execute = inspect.getsource(run_block_module.RunBlockTool._execute)
    # Verify dry_run is passed through to execute_block call
    assert "dry_run=dry_run" in source_execute


@pytest.mark.asyncio
async def test_execute_block_dry_run_simulator_error_returns_error_response():
    """When simulate_block yields a SIMULATOR ERROR tuple, execute_block returns ErrorResponse."""
    mock_block = make_mock_block()

    async def fake_simulate_error(block, input_data):
        yield "error", "[SIMULATOR ERROR — NOT A BLOCK FAILURE] No LLM client available (missing OpenAI/OpenRouter API key)."

    with patch(
        "backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate_error
    ):
        response = await execute_block(
            block=mock_block,
            block_id="test-block-id",
            input_data={"query": "hello"},
            user_id="user-1",
            session_id="session-1",
            node_exec_id="node-exec-1",
            matched_credentials={},
            dry_run=True,
        )

    assert isinstance(response, ErrorResponse)
    assert "[SIMULATOR ERROR" in response.message
@@ -61,12 +61,12 @@ async def test_run_block_returns_details_when_no_input_provided():
    )

    with patch(
        "backend.copilot.tools.helpers.get_block",
        "backend.copilot.tools.run_block.get_block",
        return_value=http_block,
    ):
        # Mock credentials check to return no missing credentials
        with patch(
            "backend.copilot.tools.helpers.resolve_block_credentials",
            "backend.copilot.tools.run_block.resolve_block_credentials",
            new_callable=AsyncMock,
            return_value=({}, []),  # (matched_credentials, missing_credentials)
        ):
@@ -119,11 +119,11 @@ async def test_run_block_returns_details_when_only_credentials_provided():
    }

    with patch(
        "backend.copilot.tools.helpers.get_block",
        "backend.copilot.tools.run_block.get_block",
        return_value=mock,
    ):
        with patch(
            "backend.copilot.tools.helpers.resolve_block_credentials",
            "backend.copilot.tools.run_block.resolve_block_credentials",
            new_callable=AsyncMock,
            return_value=(
                {

@@ -1,119 +0,0 @@
"""Schema regression tests for all registered CoPilot tools.

Validates that every tool in TOOL_REGISTRY produces a well-formed schema:
- description is non-empty
- all `required` fields exist in `properties`
- every property has a `type` and `description`
- total schema character budget does not regress past threshold
"""

import json
from typing import Any, cast

import pytest

from backend.copilot.tools import TOOL_REGISTRY

# Character budget (~4 chars/token heuristic, targeting ~8000 tokens)
_CHAR_BUDGET = 32_000


@pytest.fixture(scope="module")
def all_tool_schemas() -> list[tuple[str, Any]]:
    """Return (tool_name, openai_schema) pairs for every registered tool."""
    return [(name, tool.as_openai_tool()) for name, tool in TOOL_REGISTRY.items()]


def _get_parametrize_data() -> list[tuple[str, object]]:
    """Build parametrize data at collection time."""
    return [(name, tool.as_openai_tool()) for name, tool in TOOL_REGISTRY.items()]


@pytest.mark.parametrize(
    "tool_name,schema",
    _get_parametrize_data(),
    ids=[name for name, _ in _get_parametrize_data()],
)
class TestToolSchema:
    """Validate schema invariants for every registered tool."""

    def test_description_non_empty(self, tool_name: str, schema: dict) -> None:
        desc = schema["function"].get("description", "")
        assert desc, f"Tool '{tool_name}' has an empty description"

    def test_required_fields_exist_in_properties(
        self, tool_name: str, schema: dict
    ) -> None:
        params = schema["function"].get("parameters", {})
        properties = params.get("properties", {})
        required = params.get("required", [])
        for field in required:
            assert field in properties, (
                f"Tool '{tool_name}': required field '{field}' "
                f"not found in properties {list(properties.keys())}"
            )

    def test_every_property_has_type_and_description(
        self, tool_name: str, schema: dict
    ) -> None:
        params = schema["function"].get("parameters", {})
        properties = params.get("properties", {})
        for prop_name, prop_def in properties.items():
            assert (
                "type" in prop_def
            ), f"Tool '{tool_name}', property '{prop_name}' is missing 'type'"
            assert (
                "description" in prop_def
            ), f"Tool '{tool_name}', property '{prop_name}' is missing 'description'"


def test_browser_act_action_enum_complete() -> None:
    """Assert browser_act action enum still contains all 14 supported actions.

    This prevents future PRs from accidentally dropping actions during description
    trimming. The enum is the authoritative list — this locks it at 14 values.
    """
    tool = TOOL_REGISTRY["browser_act"]
    schema = tool.as_openai_tool()
    fn_def = schema["function"]
    params = cast(dict[str, Any], fn_def.get("parameters", {}))
    actions = params["properties"]["action"]["enum"]
    expected = {
        "click",
        "dblclick",
        "fill",
        "type",
        "scroll",
        "hover",
        "press",
        "check",
        "uncheck",
        "select",
        "wait",
        "back",
        "forward",
        "reload",
    }
    assert set(actions) == expected, (
        f"browser_act action enum changed. Got {set(actions)}, expected {expected}. "
        "If you added/removed an action, update this test intentionally."
    )


def test_total_schema_char_budget() -> None:
    """Assert total tool schema size stays under the character budget.

    This locks in the 34% token reduction from #12398 and prevents future
    description bloat from eroding the gains. Uses character count with a
    ~4 chars/token heuristic (budget of 32000 chars ≈ 8000 tokens).
    Character count is tokenizer-agnostic — no dependency on GPT or Claude
    tokenizers — while still providing a stable regression gate.
    """
    schemas = [tool.as_openai_tool() for tool in TOOL_REGISTRY.values()]
    serialized = json.dumps(schemas)
    total_chars = len(serialized)
    assert total_chars < _CHAR_BUDGET, (
        f"Tool schemas use {total_chars} chars (~{total_chars // 4} tokens), "
        f"exceeding budget of {_CHAR_BUDGET} chars (~{_CHAR_BUDGET // 4} tokens). "
        f"Description bloat detected — trim descriptions or raise the budget intentionally."
    )
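
The budget gate above is plain serialized-length arithmetic; the same heuristic stands alone as a sketch:

```python
import json

CHAR_BUDGET = 32_000  # ~8,000 tokens at the ~4 chars/token heuristic


def schema_budget(schemas: list[dict]) -> tuple[int, int]:
    """Return (total_chars, approx_tokens) for a list of OpenAI tool schemas."""
    total_chars = len(json.dumps(schemas))
    return total_chars, total_chars // 4

# e.g. a 28,400-char payload is ~7,100 tokens, inside the 32,000-char budget.
```
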
@@ -22,9 +22,17 @@ class ValidateAgentGraphTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Validate agent JSON for correctness: block_ids, links, required fields, "
            "type compatibility, nested sink notation, prompt brace escaping, "
            "and AgentExecutorBlock configs. On failure, use fix_agent_graph to auto-fix."
            "Validate an agent JSON graph for correctness. Checks:\n"
            "- All block_ids reference real blocks\n"
            "- All links reference valid source/sink nodes and fields\n"
            "- Required input fields are wired or have defaults\n"
            "- Data types are compatible across links\n"
            "- Nested sink links use correct notation\n"
            "- Prompt templates use proper curly brace escaping\n"
            "- AgentExecutorBlock configurations are valid\n\n"
            "Call this after generating agent JSON to verify correctness. "
            "If validation fails, either fix issues manually based on the error "
            "descriptions, or call fix_agent_graph to auto-fix common problems."
        )

    @property
@@ -38,7 +46,11 @@ class ValidateAgentGraphTool(BaseTool):
            "properties": {
                "agent_json": {
                    "type": "object",
                    "description": "Agent JSON with 'nodes' and 'links' arrays.",
                    "description": (
                        "The agent JSON to validate. Must contain 'nodes' and 'links' arrays. "
                        "Each node needs: id (UUID), block_id, input_default, metadata. "
                        "Each link needs: id (UUID), source_id, source_name, sink_id, sink_name."
                    ),
                },
            },
            "required": ["agent_json"],

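A minimal `agent_json` skeleton with the fields that description requires. The UUIDs, block reference, and pin names are placeholders:

```python
agent_json = {
    "nodes": [
        {
            "id": "11111111-1111-1111-1111-111111111111",  # placeholder UUID
            "block_id": "<real-block-uuid>",               # must reference a real block
            "input_default": {},
            "metadata": {},
        },
    ],
    "links": [
        {
            "id": "22222222-2222-2222-2222-222222222222",  # placeholder UUID
            "source_id": "11111111-1111-1111-1111-111111111111",
            "source_name": "output",  # assumed pin name
            "sink_id": "11111111-1111-1111-1111-111111111111",
            "sink_name": "input",     # assumed pin name
        },
    ],
}
```
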
@@ -59,7 +59,13 @@ class WebFetchTool(BaseTool):

    @property
    def description(self) -> str:
        return "Fetch a public web page. Public URLs only — internal addresses blocked. Returns readable text from HTML by default."
        return (
            "Fetch the content of a public web page by URL. "
            "Returns readable text extracted from HTML by default. "
            "Useful for reading documentation, articles, and API responses. "
            "Only supports HTTP/HTTPS GET requests to public URLs "
            "(private/internal network addresses are blocked)."
        )

    @property
    def parameters(self) -> dict[str, Any]:
@@ -68,11 +74,14 @@ class WebFetchTool(BaseTool):
            "properties": {
                "url": {
                    "type": "string",
                    "description": "Public HTTP/HTTPS URL.",
                    "description": "The public HTTP/HTTPS URL to fetch.",
                },
                "extract_text": {
                    "type": "boolean",
                    "description": "Extract text from HTML (default: true).",
                    "description": (
                        "If true (default), extract readable text from HTML. "
                        "If false, return raw content."
                    ),
                    "default": True,
                },
            },

@@ -27,8 +27,6 @@ from .models import ErrorResponse, ResponseType, ToolResponseBase

logger = logging.getLogger(__name__)

_MAX_FILE_SIZE_MB = Config().max_file_size_mb

# Sentinel file_id used when a tool-result file is read directly from the local
# host filesystem (rather than from workspace storage).
_LOCAL_TOOL_RESULT_FILE_ID = "local"
@@ -417,7 +415,13 @@ class ListWorkspaceFilesTool(BaseTool):

    @property
    def description(self) -> str:
        return "List persistent workspace files. For ephemeral session files, use SDK Glob/Read instead. Optionally filter by path prefix."
        return (
            "List files in the user's persistent workspace (cloud storage). "
            "These files survive across sessions. "
            "For ephemeral session files, use the SDK Read/Glob tools instead. "
            "Returns file names, paths, sizes, and metadata. "
            "Optionally filter by path prefix."
        )

    @property
    def parameters(self) -> dict[str, Any]:
@@ -426,17 +430,24 @@ class ListWorkspaceFilesTool(BaseTool):
            "properties": {
                "path_prefix": {
                    "type": "string",
                    "description": "Filter by path prefix (e.g. '/documents/').",
                    "description": (
                        "Optional path prefix to filter files "
                        "(e.g., '/documents/' to list only files in documents folder). "
                        "By default, only files from the current session are listed."
                    ),
                },
                "limit": {
                    "type": "integer",
                    "description": "Max files to return (default 50, max 100).",
                    "description": "Maximum number of files to return (default 50, max 100)",
                    "minimum": 1,
                    "maximum": 100,
                },
                "include_all_sessions": {
                    "type": "boolean",
                    "description": "Include files from all sessions (default: false).",
                    "description": (
                        "If true, list files from all sessions. "
                        "Default is false (only current session's files)."
                    ),
                },
            },
            "required": [],
@@ -519,11 +530,18 @@ class ReadWorkspaceFileTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Read a file from persistent workspace. Specify file_id or path. "
            "Small text/image files return inline; large/binary return metadata+URL. "
            "Use save_to_path to copy to working dir for processing. "
            "Use offset/length for paginated reads. "
            "Paths scoped to current session; use /sessions/<id>/... for cross-session access."
            "Read a file from the user's persistent workspace (cloud storage). "
            "These files survive across sessions. "
            "For ephemeral session files, use the SDK Read tool instead. "
            "Specify either file_id or path to identify the file. "
            "For small text files, returns content directly. "
            "For large or binary files, returns metadata and a download URL. "
            "Use 'save_to_path' to copy the file to the working directory "
            "(sandbox or ephemeral) for processing with bash_exec or file tools. "
            "Use 'offset' and 'length' for paginated reads of large files "
            "(e.g., persisted tool outputs). "
            "Paths are scoped to the current session by default. "
            "Use /sessions/<session_id>/... for cross-session access."
        )

    @property
@@ -533,30 +551,48 @@ class ReadWorkspaceFileTool(BaseTool):
            "properties": {
                "file_id": {
                    "type": "string",
                    "description": "File ID from list_workspace_files.",
                    "description": "The file's unique ID (from list_workspace_files)",
                },
                "path": {
                    "type": "string",
                    "description": "Virtual file path (e.g. '/documents/report.pdf').",
                    "description": (
                        "The virtual file path (e.g., '/documents/report.pdf'). "
                        "Scoped to current session by default."
                    ),
                },
                "save_to_path": {
                    "type": "string",
                    "description": "Copy file to this working directory path for processing.",
                    "description": (
                        "If provided, save the file to this path in the working "
                        "directory (cloud sandbox when E2B is active, or "
                        "ephemeral dir otherwise) so it can be processed with "
                        "bash_exec or file tools. "
                        "The file content is still returned in the response."
                    ),
                },
                "force_download_url": {
                    "type": "boolean",
                    "description": "Always return metadata+URL instead of inline content.",
                    "description": (
                        "If true, always return metadata+URL instead of inline content. "
                        "Default is false (auto-selects based on file size/type)."
                    ),
                },
                "offset": {
                    "type": "integer",
                    "description": "Character offset for paginated reads (0-based).",
                    "description": (
                        "Character offset to start reading from (0-based). "
                        "Use with 'length' for paginated reads of large files."
                    ),
                },
                "length": {
                    "type": "integer",
                    "description": "Max characters to return for paginated reads.",
                    "description": (
                        "Maximum number of characters to return. "
                        "Defaults to full file. Use with 'offset' for paginated reads."
                    ),
                },
            },
            "required": [],  # At least one of file_id or path must be provided
            "required": [],  # At least one must be provided
        }

    @property
@@ -719,10 +755,15 @@ class WriteWorkspaceFileTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Write a file to persistent workspace (survives across sessions). "
            "Provide exactly one of: content (text), content_base64 (binary), "
            f"or source_path (copy from working dir). Max {_MAX_FILE_SIZE_MB}MB. "
            "Paths scoped to current session; use /sessions/<id>/... for cross-session access."
            "Write or create a file in the user's persistent workspace (cloud storage). "
            "These files survive across sessions. "
            "For ephemeral session files, use the SDK Write tool instead. "
            "Provide content as plain text via 'content', OR base64-encoded via "
            "'content_base64', OR copy a file from the ephemeral working directory "
            "via 'source_path'. Exactly one of these three is required. "
            f"Maximum file size is {Config().max_file_size_mb}MB. "
            "Files are saved to the current session's folder by default. "
            "Use /sessions/<session_id>/... for cross-session access."
        )

    @property
@@ -732,31 +773,51 @@ class WriteWorkspaceFileTool(BaseTool):
            "properties": {
                "filename": {
                    "type": "string",
                    "description": "Filename (e.g. 'report.pdf').",
                    "description": "Name for the file (e.g., 'report.pdf')",
                },
                "content": {
                    "type": "string",
                    "description": "Plain text content. Mutually exclusive with content_base64/source_path.",
                    "description": (
                        "Plain text content to write. Use this for text files "
                        "(code, configs, documents, etc.). "
                        "Mutually exclusive with content_base64 and source_path."
                    ),
                },
                "content_base64": {
                    "type": "string",
                    "description": "Base64-encoded binary content. Mutually exclusive with content/source_path.",
                    "description": (
                        "Base64-encoded file content. Use this for binary files "
                        "(images, PDFs, etc.). "
                        "Mutually exclusive with content and source_path."
                    ),
                },
                "source_path": {
                    "type": "string",
                    "description": "Working directory path to copy to workspace. Mutually exclusive with content/content_base64.",
                    "description": (
                        "Path to a file in the ephemeral working directory to "
                        "copy to workspace (e.g., '/tmp/copilot-.../output.csv'). "
                        "Use this to persist files created by bash_exec or SDK Write. "
                        "Mutually exclusive with content and content_base64."
                    ),
                },
                "path": {
                    "type": "string",
                    "description": "Virtual path (e.g. '/documents/report.pdf'). Defaults to '/{filename}'.",
                    "description": (
                        "Optional virtual path where to save the file "
                        "(e.g., '/documents/report.pdf'). "
                        "Defaults to '/{filename}'. Scoped to current session."
                    ),
                },
                "mime_type": {
                    "type": "string",
                    "description": "MIME type. Auto-detected from filename if omitted.",
                    "description": (
                        "Optional MIME type of the file. "
                        "Auto-detected from filename if not provided."
                    ),
                },
                "overwrite": {
                    "type": "boolean",
                    "description": "Overwrite if file exists (default: false).",
                    "description": "Whether to overwrite if file exists at path (default: false)",
                },
            },
            "required": ["filename"],
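
Per the mutual-exclusion rules above, a valid call provides exactly one content source. Sketched payloads, with illustrative paths and values:

```python
# Text file via `content`:
write_text = {"filename": "notes.txt", "content": "hello"}

# Binary file via `content_base64` ("aGVsbG8=" is base64 for "hello"):
write_binary = {"filename": "blob.bin", "content_base64": "aGVsbG8="}

# Persist a working-directory file via `source_path`:
write_copy = {"filename": "output.csv", "source_path": "/tmp/copilot-abc/output.csv"}

# Invalid: two content sources at once would be rejected.
bad_call = {"filename": "x.txt", "content": "a", "content_base64": "YQ=="}
```
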
@@ -798,10 +859,10 @@ class WriteWorkspaceFileTool(BaseTool):
            return resolved
        content: bytes = resolved

        max_size = _MAX_FILE_SIZE_MB * 1024 * 1024
        max_size = Config().max_file_size_mb * 1024 * 1024
        if len(content) > max_size:
            return ErrorResponse(
                message=f"File too large. Maximum size is {_MAX_FILE_SIZE_MB}MB",
                message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
                session_id=session_id,
            )

@@ -883,7 +944,12 @@ class DeleteWorkspaceFileTool(BaseTool):

    @property
    def description(self) -> str:
        return "Delete a file from persistent workspace. Specify file_id or path. Paths scoped to current session; use /sessions/<id>/... for cross-session access."
        return (
            "Delete a file from the user's persistent workspace (cloud storage). "
            "Specify either file_id or path to identify the file. "
            "Paths are scoped to the current session by default. "
            "Use /sessions/<session_id>/... for cross-session access."
        )

    @property
    def parameters(self) -> dict[str, Any]:
@@ -892,14 +958,17 @@ class DeleteWorkspaceFileTool(BaseTool):
            "properties": {
                "file_id": {
                    "type": "string",
                    "description": "File ID from list_workspace_files.",
                    "description": "The file's unique ID (from list_workspace_files)",
                },
                "path": {
                    "type": "string",
                    "description": "Virtual file path.",
                    "description": (
                        "The virtual file path (e.g., '/documents/report.pdf'). "
                        "Scoped to current session by default."
                    ),
                },
            },
            "required": [],  # At least one of file_id or path must be provided
            "required": [],  # At least one must be provided
        }

    @property

@@ -32,9 +32,9 @@ from backend.blocks.llm import (
    AITextSummarizerBlock,
    LlmModel,
)
from backend.blocks.orchestrator import OrchestratorBlock
from backend.blocks.replicate.flux_advanced import ReplicateFluxAdvancedModelBlock
from backend.blocks.replicate.replicate_block import ReplicateModelBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
from backend.blocks.video.narration import VideoNarrationBlock
@@ -548,6 +548,7 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
            },
        )
    ],
    SmartDecisionMakerBlock: LLM_COST,
    SearchOrganizationsBlock: [
        BlockCost(
            cost_amount=2,
@@ -699,7 +700,6 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
            },
        ),
    ],
    OrchestratorBlock: LLM_COST,
    VideoNarrationBlock: [
        BlockCost(
            cost_amount=5,  # ElevenLabs TTS cost

@@ -38,10 +38,6 @@ POOL_TIMEOUT = os.getenv("DB_POOL_TIMEOUT")
if POOL_TIMEOUT:
    DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)

STMT_CACHE_SIZE = os.getenv("DB_STATEMENT_CACHE_SIZE")
if STMT_CACHE_SIZE:
    DATABASE_URL = add_param(DATABASE_URL, "statement_cache_size", STMT_CACHE_SIZE)

HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None

prisma = Prisma(

@@ -89,7 +89,6 @@ class ExecutionContext(BaseModel):
    # Safety settings
    human_in_the_loop_safe_mode: bool = True
    sensitive_action_safe_mode: bool = False
    dry_run: bool = False  # When True, blocks are LLM-simulated, no real execution

    # User settings
    user_timezone: str = "UTC"
@@ -179,7 +178,6 @@ class GraphExecutionMeta(BaseDbModel):
    )
    is_shared: bool = False
    share_token: Optional[str] = None
    is_dry_run: bool = False

    class Stats(BaseModel):
        model_config = ConfigDict(
@@ -308,7 +306,6 @@ class GraphExecutionMeta(BaseDbModel):
            ),
            is_shared=_graph_exec.isShared,
            share_token=_graph_exec.shareToken,
            is_dry_run=stats.is_dry_run if stats else False,
        )


@@ -721,12 +718,11 @@ async def create_graph_execution(
    graph_version: int,
    starting_nodes_input: list[tuple[str, BlockInput]],  # list[(node_id, BlockInput)]
    inputs: Mapping[str, JsonValue],
    user_id: str,  # Validated by callers (API auth layer / service-level checks)
    user_id: str,
    preset_id: Optional[str] = None,
    credential_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
    nodes_input_masks: Optional[NodesInputMasks] = None,
    parent_graph_exec_id: Optional[str] = None,
    is_dry_run: bool = False,
) -> GraphExecutionWithNodes:
    """
    Create a new AgentGraphExecution record.
@@ -764,7 +760,6 @@ async def create_graph_execution(
            "userId": user_id,
            "agentPresetId": preset_id,
            "parentGraphExecutionId": parent_graph_exec_id,
            **({"stats": Json({"is_dry_run": True})} if is_dry_run else {}),
        },
        include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
    )

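The `**({...} if is_dry_run else {})` line above uses the conditional dict-spread idiom: include a key only when the flag is set, spread nothing otherwise. In isolation:

```python
is_dry_run = True
record = {
    "userId": "user-1",
    # Spread an extra key only when the flag is set; otherwise spread nothing.
    **({"stats": {"is_dry_run": True}} if is_dry_run else {}),
}
# -> {"userId": "user-1", "stats": {"is_dry_run": True}}
# With is_dry_run = False, the "stats" key is omitted entirely.
```
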
@@ -7,7 +7,7 @@ the function returns plain values instead of lists, it causes:
1 validation error for dict[str,list[any]] response
Input should be a valid list [type=list_type, input_value='', input_type=str]

This breaks OrchestratorBlock agent mode tool execution.
This breaks SmartDecisionMakerBlock agent mode tool execution.
"""

from unittest.mock import AsyncMock, MagicMock, patch

@@ -737,7 +737,7 @@ class GraphModel(Graph, GraphMeta):
        # Collect errors per node
        node_errors: dict[str, dict[str, str]] = defaultdict(dict)

        # Validate tool orchestrator nodes
        # Validate smart decision maker nodes
        nodes_block = {
            node.id: block
            for node in graph.nodes
@@ -1096,9 +1096,6 @@ async def get_graph(
    Retrieves a graph from the DB.
    Defaults to the version with `is_active` if `version` is not passed.

    See also: `get_graph_as_admin()` which bypasses ownership and marketplace
    checks for admin-only routes.

    Returns `None` if the record is not found.
    """
    graph = None
@@ -1136,27 +1133,6 @@ async def get_graph(
    ):
        graph = store_listing.AgentGraph

    # Fall back to library membership: if the user has the agent in their
    # library (non-deleted, non-archived), grant access even if the agent is
    # no longer published. "You added it, you keep it."
    if graph is None and user_id is not None:
        library_where: dict[str, object] = {
            "userId": user_id,
            "agentGraphId": graph_id,
            "isDeleted": False,
            "isArchived": False,
        }
        if version is not None:
            library_where["agentGraphVersion"] = version

        library_agent = await LibraryAgent.prisma().find_first(
            where=library_where,
            include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
            order={"agentGraphVersion": "desc"},
        )
        if library_agent and library_agent.AgentGraph:
            graph = library_agent.AgentGraph

    if graph is None:
        return None

@@ -1231,9 +1207,13 @@ async def get_graph_as_admin(
        order={"version": "desc"},
    )

    # Admin access bypasses ownership and marketplace checks — route-level
    # auth already ensures only admins can call this function.
    if graph is None:
    # For access, the graph must be owned by the user or listed in the store
    if graph is None or (
        graph.userId != user_id
        and not await is_graph_published_in_marketplace(
            graph_id, version or graph.version
        )
    ):
        return None

    if for_export:
@@ -1392,9 +1372,8 @@ async def validate_graph_execution_permissions(
    ## Logic
    A user can execute a graph if any of these is true:
    1. They own the graph and some version of it is still listed in their library
    2. The graph is in the user's library (non-deleted, non-archived)
    3. The graph is published in the marketplace and listed in their library
    4. The graph is published in the marketplace and is being executed as a sub-agent
    2. The graph is published in the marketplace and listed in their library
    3. The graph is published in the marketplace and is being executed as a sub-agent

    Args:
        graph_id: The ID of the graph to check
@@ -1416,7 +1395,6 @@ async def validate_graph_execution_permissions(
        where={
            "userId": user_id,
            "agentGraphId": graph_id,
            "agentGraphVersion": graph_version,
            "isDeleted": False,
            "isArchived": False,
        }
@@ -1426,39 +1404,19 @@ async def validate_graph_execution_permissions(
    # Step 1: Check if user owns this graph
    user_owns_graph = graph and graph.userId == user_id

    # Step 2: Check if the exact graph version is in the library.
    # Step 2: Check if agent is in the library *and not deleted*
    user_has_in_library = library_agent is not None
    owner_has_live_library_entry = user_has_in_library
    if user_owns_graph and not user_has_in_library:
        # Owners are allowed to execute a new version as long as some live
        # library entry still exists for the graph. Non-owners stay
        # version-specific.
        owner_has_live_library_entry = (
            await LibraryAgent.prisma().find_first(
                where={
                    "userId": user_id,
                    "agentGraphId": graph_id,
                    "isDeleted": False,
                    "isArchived": False,
                }
            )
            is not None
        )

    # Step 3: Apply permission logic
    # Access is granted if the user owns it, it's in the marketplace, OR
    # it's in the user's library ("you added it, you keep it").
    if not (
        user_owns_graph
        or user_has_in_library
        or await is_graph_published_in_marketplace(graph_id, graph_version)
    ):
        raise GraphNotAccessibleError(
            f"You do not have access to graph #{graph_id} v{graph_version}: "
            "it is not owned by you, not in your library, "
            "and not available in the Marketplace"
            "it is not owned by you and not available in the Marketplace"
        )
    elif not (user_has_in_library or owner_has_live_library_entry or is_sub_graph):
    elif not (user_has_in_library or is_sub_graph):
        raise GraphNotInLibraryError(f"Graph #{graph_id} is not in your library")

    # Step 6: Check execution-specific permissions (raises generic NotAuthorizedError)

@@ -1,6 +1,6 @@
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, patch
from uuid import UUID

import fastapi.exceptions
@@ -13,17 +13,10 @@ from backend.api.model import CreateGraph
from backend.blocks._base import BlockSchema, BlockSchemaInput
from backend.blocks.basic import StoreValueBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.data.graph import (
    Graph,
    Link,
    Node,
    get_graph,
    validate_graph_execution_permissions,
)
from backend.data.graph import Graph, Link, Node
from backend.data.model import SchemaField
from backend.data.user import DEFAULT_USER_ID
from backend.usecases.sample import create_test_user
from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError
from backend.util.test import SpinTestServer


@@ -602,861 +595,3 @@ def test_mcp_credential_combine_no_discriminator_values():
        f"Expected 1 credential entry for MCP blocks without discriminator_values, "
        f"got {len(combined)}: {list(combined.keys())}"
    )


# --------------- get_graph access-control truth table --------------- #
#
# Full matrix of access scenarios for get_graph() and get_graph_as_admin().
# Access priority: ownership > marketplace APPROVED > library membership.
# Library is version-specific. get_graph_as_admin bypasses everything.
#
# | User    | Owns? | Marketplace | Library          | Version | Result | Test
# |---------|-------|-------------|------------------|---------|--------|-----
# | regular | yes   | any         | any              | v1      | ACCESS | test_get_graph_library_not_queried_when_owned
# | regular | no    | APPROVED    | any              | v1      | ACCESS | test_get_graph_non_owner_approved_marketplace_agent
# | regular | no    | not listed  | active, same ver | v1      | ACCESS | test_get_graph_library_member_can_access_unpublished
# | regular | no    | not listed  | active, diff ver | v2      | DENIED | test_get_graph_library_wrong_version_denied
# | regular | no    | not listed  | deleted          | v1      | DENIED | test_get_graph_deleted_library_agent_denied
# | regular | no    | not listed  | archived         | v1      | DENIED | test_get_graph_archived_library_agent_denied
# | regular | no    | not listed  | not present      | v1      | DENIED | test_get_graph_non_owner_pending_not_in_library_denied
# | regular | no    | PENDING     | active v1        | v2      | DENIED | test_library_v1_does_not_grant_access_to_pending_v2
# | regular | no    | not listed  | null AgentGraph  | v1      | DENIED | test_get_graph_library_with_null_agent_graph_denied
# | anon    | no    | not listed  | -                | v1      | DENIED | test_get_graph_library_fallback_not_used_for_anonymous
# | anon    | no    | APPROVED    | -                | v1      | ACCESS | test_get_graph_anonymous_approved_marketplace_access
# | admin*  | no    | PENDING     | -                | v2      | ACCESS | test_admin_can_access_pending_v2_via_get_graph_as_admin
#
# Efficiency (no unnecessary queries):
# | regular | yes   | -           | -                | v1      | no mkt/lib | test_get_graph_library_not_queried_when_owned
# | regular | no    | APPROVED    | -                | v1      | no lib     | test_get_graph_library_not_queried_when_marketplace_approved
#
# * = via get_graph_as_admin (admin-only routes)


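Condensed, the table reduces to a priority check. A sketch of the decision only, not the actual implementation, which also performs the three lookups lazily, as the efficiency rows assert:

```python
def may_access(owns: bool, marketplace_approved: bool, in_library_same_version: bool) -> bool:
    # Priority: ownership, then APPROVED marketplace listing, then
    # version-specific library membership ("you added it, you keep it").
    return owns or marketplace_approved or in_library_same_version
```
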
def _make_mock_db_graph(user_id: str = "owner-user-id") -> MagicMock:
|
||||
graph = MagicMock()
|
||||
graph.userId = user_id
|
||||
graph.id = "graph-id"
|
||||
graph.version = 1
|
||||
graph.Nodes = []
|
||||
return graph
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_non_owner_approved_marketplace_agent() -> None:
|
||||
"""A non-owner should be able to access a graph that has an APPROVED
|
||||
marketplace listing. This is the normal marketplace download flow."""
|
||||
owner_id = "owner-user-id"
|
||||
requester_id = "different-user-id"
|
||||
graph_id = "graph-id"
|
||||
mock_graph = _make_mock_db_graph(owner_id)
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
mock_listing = MagicMock()
|
||||
mock_listing.AgentGraph = mock_graph
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
|
||||
patch(
|
||||
"backend.data.graph.StoreListingVersion.prisma",
|
||||
) as mock_slv_prisma,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
),
|
||||
):
|
||||
# First lookup (owned graph) returns None — requester != owner
|
||||
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
# Marketplace fallback finds an APPROVED listing
|
||||
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=mock_listing)
|
||||
|
||||
result = await get_graph(
|
||||
graph_id=graph_id,
|
||||
version=1,
|
||||
user_id=requester_id,
|
||||
)
|
||||
|
||||
assert result is not None, "Non-owner should access APPROVED marketplace agent"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_non_owner_pending_not_in_library_denied() -> None:
|
||||
"""A non-owner with no library membership and no APPROVED marketplace
|
||||
listing must be denied access."""
|
||||
requester_id = "different-user-id"
|
||||
graph_id = "graph-id"
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
|
||||
patch(
|
||||
"backend.data.graph.StoreListingVersion.prisma",
|
||||
) as mock_slv_prisma,
|
||||
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
|
||||
):
|
||||
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
|
||||
result = await get_graph(
|
||||
graph_id=graph_id,
|
||||
version=1,
|
||||
user_id=requester_id,
|
||||
)
|
||||
|
||||
assert (
|
||||
result is None
|
||||
), "User without ownership, marketplace, or library access must be denied"
|
||||
|
||||
|
||||
# --------------- Library membership grants graph access --------------- #
|
||||
# "You added it, you keep it" — product decision from SECRT-2167.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_library_member_can_access_unpublished() -> None:
|
||||
"""A user who has the agent in their library should be able to access it
|
||||
even if it's no longer published in the marketplace."""
|
||||
requester_id = "library-user-id"
|
||||
graph_id = "graph-id"
|
||||
mock_graph = _make_mock_db_graph("original-creator-id")
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
mock_library_agent = MagicMock()
|
||||
mock_library_agent.AgentGraph = mock_graph
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
|
||||
patch(
|
||||
"backend.data.graph.StoreListingVersion.prisma",
|
||||
) as mock_slv_prisma,
|
||||
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
),
|
||||
):
|
||||
# Not owned
|
||||
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
# Not in marketplace (unpublished)
|
||||
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
# But IS in user's library
|
||||
mock_lib_prisma.return_value.find_first = AsyncMock(
|
||||
return_value=mock_library_agent
|
||||
)
|
||||
|
||||
result = await get_graph(
|
||||
graph_id=graph_id,
|
||||
version=1,
|
||||
user_id=requester_id,
|
||||
)
|
||||
|
||||
assert result is mock_graph_model, "Library member should access unpublished agent"
|
||||
|
||||
# Verify library query filters on non-deleted, non-archived
|
||||
lib_call = mock_lib_prisma.return_value.find_first
|
||||
lib_call.assert_awaited_once()
|
||||
assert lib_call.await_args is not None
|
||||
lib_where = lib_call.await_args.kwargs["where"]
|
||||
assert lib_where["userId"] == requester_id
|
||||
assert lib_where["agentGraphId"] == graph_id
|
||||
assert lib_where["isDeleted"] is False
|
||||
assert lib_where["isArchived"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_deleted_library_agent_denied() -> None:
|
||||
"""If the user soft-deleted the agent from their library, they should
|
||||
NOT get access via the library fallback."""
|
||||
requester_id = "library-user-id"
|
||||
graph_id = "graph-id"
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
|
||||
patch(
|
||||
"backend.data.graph.StoreListingVersion.prisma",
|
||||
) as mock_slv_prisma,
|
||||
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
|
||||
):
|
||||
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
# Library query returns None because isDeleted=False filter excludes it
|
||||
mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
|
||||
result = await get_graph(
|
||||
graph_id=graph_id,
|
||||
version=1,
|
||||
user_id=requester_id,
|
||||
)
|
||||
|
||||
assert result is None, "Deleted library agent should not grant graph access"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_anonymous_approved_marketplace_access() -> None:
|
||||
"""Anonymous users (user_id=None) should still access APPROVED marketplace
|
||||
agents — the marketplace fallback doesn't require authentication."""
|
||||
graph_id = "graph-id"
|
||||
mock_graph = _make_mock_db_graph("creator-id")
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
mock_listing = MagicMock()
|
||||
mock_listing.AgentGraph = mock_graph
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
|
||||
patch(
|
||||
"backend.data.graph.StoreListingVersion.prisma",
|
||||
) as mock_slv_prisma,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
),
|
||||
):
|
||||
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=mock_listing)
|
||||
|
||||
result = await get_graph(
|
||||
graph_id=graph_id,
|
||||
version=1,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
assert (
|
||||
result is mock_graph_model
|
||||
), "Anonymous user should access APPROVED marketplace agent"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_library_fallback_not_used_for_anonymous() -> None:
|
||||
"""Anonymous requests (user_id=None) must not trigger the library
|
||||
fallback — there's no user to check library membership for."""
|
||||
graph_id = "graph-id"
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
|
||||
patch(
|
||||
"backend.data.graph.StoreListingVersion.prisma",
|
||||
) as mock_slv_prisma,
|
||||
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
|
||||
):
|
||||
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
|
||||
result = await get_graph(
|
||||
graph_id=graph_id,
|
||||
version=1,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
# Library should never be queried for anonymous users
|
||||
mock_lib_prisma.return_value.find_first.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_library_not_queried_when_owned() -> None:
|
||||
"""If the user owns the graph, the library fallback should NOT be
|
||||
triggered — ownership is sufficient."""
|
||||
owner_id = "owner-user-id"
|
||||
graph_id = "graph-id"
|
||||
mock_graph = _make_mock_db_graph(owner_id)
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
|
||||
patch(
|
||||
"backend.data.graph.StoreListingVersion.prisma",
|
||||
) as mock_slv_prisma,
|
||||
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
),
|
||||
):
|
||||
# User owns the graph — first lookup succeeds
|
||||
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)
|
||||
|
||||
result = await get_graph(
|
||||
graph_id=graph_id,
|
||||
version=1,
|
||||
user_id=owner_id,
|
||||
)
|
||||
|
||||
assert result is mock_graph_model
|
||||
# Neither marketplace nor library should be queried
|
||||
mock_slv_prisma.return_value.find_first.assert_not_called()
|
||||
mock_lib_prisma.return_value.find_first.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_graph_library_not_queried_when_marketplace_approved() -> None:
    """If the graph is APPROVED in the marketplace, the library fallback
    should NOT be triggered — marketplace access is sufficient."""
    requester_id = "different-user-id"
    graph_id = "graph-id"
    mock_graph = _make_mock_db_graph("original-creator-id")
    mock_graph_model = MagicMock(name="GraphModel")

    mock_listing = MagicMock()
    mock_listing.AgentGraph = mock_graph

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
        patch(
            "backend.data.graph.StoreListingVersion.prisma",
        ) as mock_slv_prisma,
        patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
        patch(
            "backend.data.graph.GraphModel.from_db",
            return_value=mock_graph_model,
        ),
    ):
        mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
        mock_slv_prisma.return_value.find_first = AsyncMock(return_value=mock_listing)

        result = await get_graph(
            graph_id=graph_id,
            version=1,
            user_id=requester_id,
        )

        assert result is mock_graph_model
        # Library should not be queried — marketplace was sufficient
        mock_lib_prisma.return_value.find_first.assert_not_called()


@pytest.mark.asyncio
async def test_get_graph_archived_library_agent_denied() -> None:
    """If the user archived the agent in their library, they should
    NOT get access via the library fallback."""
    requester_id = "library-user-id"
    graph_id = "graph-id"

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
        patch(
            "backend.data.graph.StoreListingVersion.prisma",
        ) as mock_slv_prisma,
        patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
    ):
        mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
        mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
        # Library query returns None because isArchived=False filter excludes it
        mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)

        result = await get_graph(
            graph_id=graph_id,
            version=1,
            user_id=requester_id,
        )

        assert result is None, "Archived library agent should not grant graph access"


@pytest.mark.asyncio
async def test_get_graph_library_with_null_agent_graph_denied() -> None:
    """If LibraryAgent exists but its AgentGraph relation is None
    (data integrity issue), access must be denied, not crash."""
    requester_id = "library-user-id"
    graph_id = "graph-id"

    mock_library_agent = MagicMock()
    mock_library_agent.AgentGraph = None  # broken relation

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
        patch(
            "backend.data.graph.StoreListingVersion.prisma",
        ) as mock_slv_prisma,
        patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
    ):
        mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
        mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
        mock_lib_prisma.return_value.find_first = AsyncMock(
            return_value=mock_library_agent
        )

        result = await get_graph(
            graph_id=graph_id,
            version=1,
            user_id=requester_id,
        )

        assert (
            result is None
        ), "Library agent with missing graph relation should not grant access"


@pytest.mark.asyncio
async def test_get_graph_library_wrong_version_denied() -> None:
    """Having version 1 in your library must NOT grant access to version 2."""
    requester_id = "library-user-id"
    graph_id = "graph-id"

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
        patch(
            "backend.data.graph.StoreListingVersion.prisma",
        ) as mock_slv_prisma,
        patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
    ):
        mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
        mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
        # Library has version 1 but we're requesting version 2 —
        # the where clause includes agentGraphVersion so this returns None
        mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)

        result = await get_graph(
            graph_id=graph_id,
            version=2,
            user_id=requester_id,
        )

        assert (
            result is None
        ), "Library agent for version 1 must not grant access to version 2"
        # Verify version was included in the library query
        lib_call = mock_lib_prisma.return_value.find_first
        lib_call.assert_called_once()
        lib_where = lib_call.call_args.kwargs["where"]
        assert lib_where["agentGraphVersion"] == 2


@pytest.mark.asyncio
async def test_library_v1_does_not_grant_access_to_pending_v2() -> None:
    """A regular user has v1 in their library. v2 is pending (not approved).
    They must NOT get access to v2 — library membership is version-specific."""
    requester_id = "regular-user-id"
    graph_id = "graph-id"

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
        patch(
            "backend.data.graph.StoreListingVersion.prisma",
        ) as mock_slv_prisma,
        patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
    ):
        # Not owned
        mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
        # v2 is not APPROVED in marketplace
        mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
        # Library has v1 but not v2 — version filter excludes it
        mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)

        result = await get_graph(
            graph_id=graph_id,
            version=2,
            user_id=requester_id,
        )

        assert result is None, "Regular user with v1 in library must not access pending v2"


@pytest.mark.asyncio
async def test_admin_can_access_pending_v2_via_get_graph_as_admin() -> None:
    """An admin can access v2 (pending) via get_graph_as_admin even though
    only v1 is approved. get_graph_as_admin bypasses all access checks."""
    from backend.data.graph import get_graph_as_admin

    admin_id = "admin-user-id"
    mock_graph = _make_mock_db_graph("creator-user-id")
    mock_graph.version = 2
    mock_graph_model = MagicMock(name="GraphModel")

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
        patch(
            "backend.data.graph.GraphModel.from_db",
            return_value=mock_graph_model,
        ),
    ):
        mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)

        result = await get_graph_as_admin(
            graph_id="graph-id",
            version=2,
            user_id=admin_id,
            for_export=False,
        )

        assert (
            result is mock_graph_model
        ), "Admin must access pending v2 via get_graph_as_admin"


# --------------- execution permission truth table --------------- #
#
# validate_graph_execution_permissions() has two gates:
#   1. Accessible graph: owner OR exact-version library entry OR marketplace-published
#   2. Runnable graph: exact-version library entry OR owner fallback to any live
#      library entry for the graph OR sub-graph exception
#
# Desired owner behavior differs from non-owners:
# owners should be allowed to run a new version when some non-archived/non-deleted
# version of that graph is still in their library. Non-owners stay
# version-specific.
#
# | User     | Owns? | Marketplace | Library state                | is_sub_graph | Result   | Test
# |----------|-------|-------------|------------------------------|--------------|----------|-----
# | regular  | no    | no          | exact version present        | false        | ALLOW    | test_validate_graph_execution_permissions_library_member_same_version_allowed
# | owner    | yes   | no          | exact version present        | false        | ALLOW    | test_validate_graph_execution_permissions_owner_same_version_in_library_allowed
# | owner    | yes   | no          | previous version present     | false        | ALLOW    | test_validate_graph_execution_permissions_owner_previous_library_version_allowed
# | owner    | yes   | no          | none present                 | false        | DENY lib | test_validate_graph_execution_permissions_owner_without_library_denied
# | owner    | yes   | no          | only archived/deleted older  | false        | DENY lib | test_validate_graph_execution_permissions_owner_previous_archived_library_version_denied
# | regular  | no    | yes         | none present                 | false        | DENY lib | test_validate_graph_execution_permissions_marketplace_graph_not_in_library_denied
# | admin    | no    | no          | none present                 | false        | DENY acc | test_validate_graph_execution_permissions_admin_without_library_or_marketplace_denied
# | regular  | no    | yes         | none present                 | true         | ALLOW    | test_validate_graph_execution_permissions_marketplace_sub_graph_without_library_allowed
# | regular  | no    | no          | none present                 | true         | DENY acc | test_validate_graph_execution_permissions_unpublished_sub_graph_without_library_denied
# | regular  | no    | no          | wrong version only           | false        | DENY acc | test_validate_graph_execution_permissions_library_wrong_version_denied
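
Read alongside the table, the gate structure that the tests below exercise can be sketched roughly as follows. This is an illustrative paraphrase, not the implementation (the real `validate_graph_execution_permissions()` lives in `backend.data.graph` and is not shown in this diff); the `_owns`, `_in_library`, and `_published` helpers are hypothetical stand-ins for the underlying Prisma queries:

```python
# Illustrative sketch only; helper names (_owns, _in_library, _published)
# are hypothetical stand-ins for the Prisma lookups used by the real code.
async def validate_sketch(user_id, graph_id, version, is_sub_graph=False):
    owns = await _owns(user_id, graph_id)
    exact = await _in_library(user_id, graph_id, version=version)

    # Gate 1 (accessible): owner, exact-version library entry, or published.
    # The `or` short-circuits before the marketplace lookup, which is why
    # the library-member tests assert mock_is_published.assert_not_awaited().
    if not (owns or exact or await _published(graph_id, version)):
        raise GraphNotAccessibleError()

    # Gate 2 (runnable): exact-version entry, the sub-graph exception, or the
    # owner fallback to any live (non-archived, non-deleted) library entry.
    if exact or is_sub_graph:
        return
    if owns and await _in_library(user_id, graph_id, version=None):
        return
    raise GraphNotInLibraryError()
```
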
@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_library_member_same_version_allowed() -> (
    None
):
    requester_id = "library-user-id"
    graph_id = "graph-id"
    graph_version = 2
    mock_graph = MagicMock(userId="creator-user-id")

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
        patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
        patch(
            "backend.data.graph.is_graph_published_in_marketplace",
            new_callable=AsyncMock,
            return_value=False,
        ) as mock_is_published,
    ):
        mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
        mock_lib_prisma.return_value.find_first = AsyncMock(return_value=MagicMock())

        await validate_graph_execution_permissions(
            user_id=requester_id,
            graph_id=graph_id,
            graph_version=graph_version,
        )

        mock_is_published.assert_not_awaited()
        lib_where = mock_lib_prisma.return_value.find_first.call_args.kwargs["where"]
        assert lib_where["agentGraphVersion"] == graph_version


@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_owner_same_version_in_library_allowed() -> (
    None
):
    requester_id = "owner-user-id"
    graph_id = "graph-id"
    graph_version = 2
    mock_graph = MagicMock(userId=requester_id)

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
        patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
        patch(
            "backend.data.graph.is_graph_published_in_marketplace",
            new_callable=AsyncMock,
            return_value=False,
        ) as mock_is_published,
    ):
        mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
        mock_lib_prisma.return_value.find_first = AsyncMock(return_value=MagicMock())

        await validate_graph_execution_permissions(
            user_id=requester_id,
            graph_id=graph_id,
            graph_version=graph_version,
        )

        mock_is_published.assert_not_awaited()
        lib_where = mock_lib_prisma.return_value.find_first.call_args.kwargs["where"]
        assert lib_where["agentGraphVersion"] == graph_version


@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_owner_previous_library_version_allowed() -> (
    None
):
    requester_id = "owner-user-id"
    graph_id = "graph-id"
    graph_version = 2
    mock_graph = MagicMock(userId=requester_id)

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
        patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
        patch(
            "backend.data.graph.is_graph_published_in_marketplace",
            new_callable=AsyncMock,
            return_value=False,
        ) as mock_is_published,
    ):
        mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
        mock_lib_prisma.return_value.find_first = AsyncMock(
            side_effect=[None, MagicMock(name="PriorVersionLibraryAgent")]
        )

        await validate_graph_execution_permissions(
            user_id=requester_id,
            graph_id=graph_id,
            graph_version=graph_version,
        )

        mock_is_published.assert_not_awaited()
        assert mock_lib_prisma.return_value.find_first.await_count == 2
        first_where = mock_lib_prisma.return_value.find_first.await_args_list[0].kwargs[
            "where"
        ]
        second_where = mock_lib_prisma.return_value.find_first.await_args_list[1].kwargs[
            "where"
        ]
        assert first_where["agentGraphVersion"] == graph_version
        assert "agentGraphVersion" not in second_where


@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_owner_without_library_denied() -> (
    None
):
    requester_id = "owner-user-id"
    graph_id = "graph-id"
    graph_version = 2
    mock_graph = MagicMock(userId=requester_id)

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
        patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
        patch(
            "backend.data.graph.is_graph_published_in_marketplace",
            new_callable=AsyncMock,
            return_value=False,
        ) as mock_is_published,
    ):
        mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
        mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)

        with pytest.raises(GraphNotInLibraryError):
            await validate_graph_execution_permissions(
                user_id=requester_id,
                graph_id=graph_id,
                graph_version=graph_version,
            )

        mock_is_published.assert_not_awaited()
        assert mock_lib_prisma.return_value.find_first.await_count == 2
        first_where = mock_lib_prisma.return_value.find_first.await_args_list[0].kwargs[
            "where"
        ]
        second_where = mock_lib_prisma.return_value.find_first.await_args_list[1].kwargs[
            "where"
        ]
        assert first_where["agentGraphVersion"] == graph_version
        assert second_where == {
            "userId": requester_id,
            "agentGraphId": graph_id,
            "isDeleted": False,
            "isArchived": False,
        }


@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_owner_previous_archived_library_version_denied() -> (
    None
):
    requester_id = "owner-user-id"
    graph_id = "graph-id"
    graph_version = 2
    mock_graph = MagicMock(userId=requester_id)

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
        patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
        patch(
            "backend.data.graph.is_graph_published_in_marketplace",
            new_callable=AsyncMock,
            return_value=False,
        ) as mock_is_published,
    ):
        mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
        mock_lib_prisma.return_value.find_first = AsyncMock(side_effect=[None, None])

        with pytest.raises(GraphNotInLibraryError):
            await validate_graph_execution_permissions(
                user_id=requester_id,
                graph_id=graph_id,
                graph_version=graph_version,
            )

        mock_is_published.assert_not_awaited()
        assert mock_lib_prisma.return_value.find_first.await_count == 2
        first_where = mock_lib_prisma.return_value.find_first.await_args_list[0].kwargs[
            "where"
        ]
        second_where = mock_lib_prisma.return_value.find_first.await_args_list[1].kwargs[
            "where"
        ]
        assert first_where["agentGraphVersion"] == graph_version
        assert second_where == {
            "userId": requester_id,
            "agentGraphId": graph_id,
            "isDeleted": False,
            "isArchived": False,
        }


@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_marketplace_graph_not_in_library_denied() -> (
    None
):
    requester_id = "marketplace-user-id"
    graph_id = "graph-id"
    graph_version = 2
    mock_graph = MagicMock(userId="creator-user-id")

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
        patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
        patch(
            "backend.data.graph.is_graph_published_in_marketplace",
            new_callable=AsyncMock,
            return_value=True,
        ) as mock_is_published,
    ):
        mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
        mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)

        with pytest.raises(GraphNotInLibraryError):
            await validate_graph_execution_permissions(
                user_id=requester_id,
                graph_id=graph_id,
                graph_version=graph_version,
            )

        mock_is_published.assert_awaited_once_with(graph_id, graph_version)


@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_admin_without_library_or_marketplace_denied() -> (
    None
):
    requester_id = "admin-user-id"
    graph_id = "graph-id"
    graph_version = 2
    mock_graph = MagicMock(userId="creator-user-id")

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
        patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
        patch(
            "backend.data.graph.is_graph_published_in_marketplace",
            new_callable=AsyncMock,
            return_value=False,
        ) as mock_is_published,
    ):
        mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
        mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)

        with pytest.raises(GraphNotAccessibleError):
            await validate_graph_execution_permissions(
                user_id=requester_id,
                graph_id=graph_id,
                graph_version=graph_version,
            )

        mock_is_published.assert_awaited_once_with(graph_id, graph_version)


@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_unpublished_sub_graph_without_library_denied() -> (
    None
):
    requester_id = "marketplace-user-id"
    graph_id = "graph-id"
    graph_version = 2
    mock_graph = MagicMock(userId="creator-user-id")

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
        patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
        patch(
            "backend.data.graph.is_graph_published_in_marketplace",
            new_callable=AsyncMock,
            return_value=False,
        ) as mock_is_published,
    ):
        mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
        mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)

        with pytest.raises(GraphNotAccessibleError):
            await validate_graph_execution_permissions(
                user_id=requester_id,
                graph_id=graph_id,
                graph_version=graph_version,
                is_sub_graph=True,
            )

        mock_is_published.assert_awaited_once_with(graph_id, graph_version)


@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_marketplace_sub_graph_without_library_allowed() -> (
    None
):
    requester_id = "marketplace-user-id"
    graph_id = "graph-id"
    graph_version = 2
    mock_graph = MagicMock(userId="creator-user-id")

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
        patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
        patch(
            "backend.data.graph.is_graph_published_in_marketplace",
            new_callable=AsyncMock,
            return_value=True,
        ) as mock_is_published,
    ):
        mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
        mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)

        await validate_graph_execution_permissions(
            user_id=requester_id,
            graph_id=graph_id,
            graph_version=graph_version,
            is_sub_graph=True,
        )

        mock_is_published.assert_awaited_once_with(graph_id, graph_version)


@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_library_wrong_version_denied() -> (
    None
):
    requester_id = "library-user-id"
    graph_id = "graph-id"
    graph_version = 2
    mock_graph = MagicMock(userId="creator-user-id")

    with (
        patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
        patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
        patch(
            "backend.data.graph.is_graph_published_in_marketplace",
            new_callable=AsyncMock,
            return_value=False,
        ) as mock_is_published,
    ):
        mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
        mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)

        with pytest.raises(GraphNotAccessibleError):
            await validate_graph_execution_permissions(
                user_id=requester_id,
                graph_id=graph_id,
                graph_version=graph_version,
            )

        mock_is_published.assert_awaited_once_with(graph_id, graph_version)
        lib_where = mock_lib_prisma.return_value.find_first.call_args.kwargs["where"]
        assert lib_where["agentGraphVersion"] == graph_version

@@ -312,15 +312,6 @@ def SchemaField(
 )  # type: ignore
 
 
-# SDK default credentials use IDs like "{provider}-default" (set in sdk/builder.py).
-# They must never be exposed to users via the API.
-SDK_DEFAULT_SUFFIX = "-default"
-
-
-def is_sdk_default(cred_id: str) -> bool:
-    return cred_id.endswith(SDK_DEFAULT_SUFFIX)
-
-
 class _BaseCredentials(BaseModel):
     id: str = Field(default_factory=lambda: str(uuid4()))
     provider: str
@@ -898,10 +889,6 @@ class GraphExecutionStats(BaseModel):
         default=None,
         description="AI-generated score (0.0-1.0) indicating how well the execution achieved its intended purpose",
     )
-    is_dry_run: bool = Field(
-        default=False,
-        description="Whether this execution was a dry-run simulation",
-    )
 
 
 class UserExecutionSummaryStats(BaseModel):

@@ -23,29 +23,11 @@ def _cache_key(user_id: str) -> str:
 
 
 def _json_to_list(value: Any) -> list[str]:
-    """Convert Json field to list[str], handling None.
-
-    Also handles legacy dict-format rows (e.g. ``{"Learn": [...], "Create": [...]}``
-    from the reverted themed-prompts feature) by flattening all values into a single
-    list so existing personalised data isn't silently lost.
-    """
+    """Convert Json field to list[str], handling None."""
     if value is None:
         return []
     if isinstance(value, list):
         return cast(list[str], value)
-    if isinstance(value, dict):
-        # Legacy themed-prompt format: flatten all string values from all categories.
-        logger.debug(
-            "_json_to_list: flattening legacy dict-format value (keys=%s)",
-            list(value.keys()),
-        )
-        return [
-            item
-            for vals in value.values()
-            if isinstance(vals, list)
-            for item in vals
-            if isinstance(item, str)
-        ]
     return []

@@ -9,7 +9,6 @@ from datetime import datetime, timezone
 from typing import Optional
 
 import pydantic
-from prisma.errors import UniqueViolationError
 from prisma.models import UserWorkspace, UserWorkspaceFile
 from prisma.types import UserWorkspaceFileWhereInput

@@ -76,23 +75,22 @@ async def get_or_create_workspace(user_id: str) -> Workspace:
     """
     Get user's workspace, creating one if it doesn't exist.
 
+    Uses upsert to handle race conditions when multiple concurrent requests
+    attempt to create a workspace for the same user.
+
     Args:
         user_id: The user's ID
 
     Returns:
         Workspace instance
     """
-    workspace = await UserWorkspace.prisma().find_unique(where={"userId": user_id})
-    if workspace:
-        return Workspace.from_db(workspace)
-
-    try:
-        workspace = await UserWorkspace.prisma().create(data={"userId": user_id})
-    except UniqueViolationError:
-        # Concurrent request already created it
-        workspace = await UserWorkspace.prisma().find_unique(where={"userId": user_id})
-        if workspace is None:
-            raise
+    workspace = await UserWorkspace.prisma().upsert(
+        where={"userId": user_id},
+        data={
+            "create": {"userId": user_id},
+            "update": {},  # No updates needed if exists
+        },
+    )
 
     return Workspace.from_db(workspace)

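The swap from find-then-create to a single upsert closes the window where two concurrent requests both observe a missing workspace: with the old pattern the loser of the race hit the unique constraint and had to re-fetch, while the upsert makes creation idempotent in one round trip. A minimal sketch of the race being removed, using hypothetical `find_unique`/`create` stand-ins for the Prisma calls:

```python
import asyncio

# Hypothetical stand-ins for the Prisma calls, illustrating the TOCTOU
# window the upsert removes: both coroutines see "missing" before either
# create commits, so the second create violates the unique userId constraint.
async def get_or_create_racy(user_id: str) -> None:
    if await find_unique(user_id) is None:  # both requests see None here
        await create(user_id)  # the second create raises UniqueViolationError

async def main() -> None:
    await asyncio.gather(get_or_create_racy("user-1"), get_or_create_racy("user-1"))
```
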
@@ -81,7 +81,6 @@ from backend.util.settings import Settings
 from .activity_status_generator import generate_activity_status_for_execution
 from .automod.manager import automod_manager
 from .cluster_lock import ClusterLock
-from .simulator import simulate_block
 from .utils import (
     GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
     GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
@@ -223,11 +222,9 @@ async def execute_node(
         raise ValueError(f"Block {node_block.id} is disabled and cannot be executed")
 
     # Sanity check: validate the execution input.
-    input_data, error = validate_exec(
-        node, data.inputs, resolve_input=False, dry_run=execution_context.dry_run
-    )
+    input_data, error = validate_exec(node, data.inputs, resolve_input=False)
     if input_data is None:
-        log_metadata.warning(f"Skip execution, input validation error: {error}")
+        log_metadata.error(f"Skip execution, input validation error: {error}")
         yield "error", error
         return
@@ -375,12 +372,9 @@ async def execute_node(
             scope.set_tag(f"execution_context.{k}", v)
 
         try:
-            if execution_context.dry_run:
-                block_iter = simulate_block(node_block, input_data)
-            else:
-                block_iter = node_block.execute(input_data, **extra_exec_kwargs)
-
-            async for output_name, output_data in block_iter:
+            async for output_name, output_data in node_block.execute(
+                input_data, **extra_exec_kwargs
+            ):
                 output_data = json.to_dict(output_data)
                 output_size += len(json.dumps(output_data))
                 log_metadata.debug("Node produced output", **{output_name: output_data})
@@ -512,9 +506,7 @@ async def _enqueue_next_nodes(
             next_node_input.update(node_input_mask)
 
         # Validate the input data for the next node.
-        next_node_input, validation_msg = validate_exec(
-            next_node, next_node_input, dry_run=execution_context.dry_run
-        )
+        next_node_input, validation_msg = validate_exec(next_node, next_node_input)
         suffix = f"{next_output_name}>{next_input_name}~{next_node_exec_id}:{validation_msg}"
 
         # Incomplete input data, skip queueing the execution.
@@ -559,9 +551,7 @@ async def _enqueue_next_nodes(
                 if node_input_mask:
                     idata.update(node_input_mask)
 
-                idata, msg = validate_exec(
-                    next_node, idata, dry_run=execution_context.dry_run
-                )
+                idata, msg = validate_exec(next_node, idata)
                 suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
                 if not idata:
                     log_metadata.info(f"Enqueueing static-link skipped: {suffix}")
@@ -839,12 +829,9 @@ class ExecutionProcessor:
             return
 
        if exec_meta.stats is None:
-            exec_stats = GraphExecutionStats(
-                is_dry_run=graph_exec.execution_context.dry_run,
-            )
+            exec_stats = GraphExecutionStats()
         else:
             exec_stats = exec_meta.stats.to_db()
-        exec_stats.is_dry_run = graph_exec.execution_context.dry_run
 
         timing_info, status = self._on_graph_execution(
             graph_exec=graph_exec,
@@ -984,10 +971,7 @@ class ExecutionProcessor:
         running_node_evaluation = self.running_node_evaluation
 
         try:
-            if (
-                not graph_exec.execution_context.dry_run
-                and db_client.get_credits(graph_exec.user_id) <= 0
-            ):
+            if db_client.get_credits(graph_exec.user_id) <= 0:
                 raise InsufficientBalanceError(
                     user_id=graph_exec.user_id,
                     message="You have no credits left to run an agent.",
@@ -1058,24 +1042,21 @@ class ExecutionProcessor:
                         f"for node {queued_node_exec.node_id}",
                     )
 
-                # Charge usage (may raise) — skipped for dry runs
+                # Charge usage (may raise) ------------------------------
                 try:
-                    if not graph_exec.execution_context.dry_run:
-                        cost, remaining_balance = self._charge_usage(
-                            node_exec=queued_node_exec,
-                            execution_count=increment_execution_count(
-                                graph_exec.user_id
-                            ),
-                        )
-                        with execution_stats_lock:
-                            execution_stats.cost += cost
-                        # Check if we crossed the low balance threshold
-                        self._handle_low_balance(
-                            db_client=db_client,
-                            user_id=graph_exec.user_id,
-                            current_balance=remaining_balance,
-                            transaction_cost=cost,
-                        )
+                    cost, remaining_balance = self._charge_usage(
+                        node_exec=queued_node_exec,
+                        execution_count=increment_execution_count(graph_exec.user_id),
+                    )
+                    with execution_stats_lock:
+                        execution_stats.cost += cost
+                    # Check if we crossed the low balance threshold
+                    self._handle_low_balance(
+                        db_client=db_client,
+                        user_id=graph_exec.user_id,
+                        current_balance=remaining_balance,
+                        transaction_cost=cost,
+                    )
                 except InsufficientBalanceError as balance_error:
                     error = balance_error  # Set error to trigger FAILED status
                     node_exec_id = queued_node_exec.node_exec_id

@@ -1,218 +0,0 @@
|
||||
"""
|
||||
LLM-powered block simulator for dry-run execution.
|
||||
|
||||
When dry_run=True, instead of calling the real block, this module
|
||||
role-plays the block's execution using an LLM. No real API calls,
|
||||
no side effects. The LLM is grounded by:
|
||||
- Block name and description
|
||||
- Input/output schemas (from block.input_schema.jsonschema() / output_schema.jsonschema())
|
||||
- The actual input values
|
||||
|
||||
Inspired by https://github.com/Significant-Gravitas/agent-simulator
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from backend.util.clients import get_openai_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Use the same fast/cheap model the copilot uses for non-primary tasks.
|
||||
# Overridable via ChatConfig.title_model if ChatConfig is available.
|
||||
def _simulator_model() -> str:
|
||||
try:
|
||||
from backend.copilot.config import ChatConfig # noqa: PLC0415
|
||||
|
||||
model = ChatConfig().title_model
|
||||
except Exception:
|
||||
model = "openai/gpt-4o-mini"
|
||||
|
||||
# get_openai_client() may return a direct OpenAI client (not OpenRouter).
|
||||
# Direct OpenAI expects bare model names ("gpt-4o-mini"), not the
|
||||
# OpenRouter-prefixed form ("openai/gpt-4o-mini"). Strip the prefix when
|
||||
# the internal OpenAI key is configured (i.e. not going through OpenRouter).
|
||||
try:
|
||||
from backend.util.settings import Settings # noqa: PLC0415
|
||||
|
||||
secrets = Settings().secrets
|
||||
# get_openai_client() uses the direct OpenAI client whenever
|
||||
# openai_internal_api_key is set, regardless of open_router_api_key.
|
||||
# Strip the provider prefix (e.g. "openai/gpt-4o-mini" → "gpt-4o-mini")
|
||||
# so the model name is valid for the direct OpenAI API.
|
||||
if secrets.openai_internal_api_key and "/" in model:
|
||||
model = model.split("/", 1)[1]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return model
|
||||
|
||||
|
||||
_TEMPERATURE = 0.2
|
||||
_MAX_JSON_RETRIES = 5
|
||||
_MAX_INPUT_VALUE_CHARS = 20000
|
||||
|
||||
|
||||
def _truncate_value(value: Any) -> Any:
|
||||
"""Recursively truncate long strings anywhere in a value."""
|
||||
if isinstance(value, str):
|
||||
return (
|
||||
value[:_MAX_INPUT_VALUE_CHARS] + "... [TRUNCATED]"
|
||||
if len(value) > _MAX_INPUT_VALUE_CHARS
|
||||
else value
|
||||
)
|
||||
if isinstance(value, dict):
|
||||
return {k: _truncate_value(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_truncate_value(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def _truncate_input_values(input_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Recursively truncate long string values so the prompt doesn't blow up."""
|
||||
return {k: _truncate_value(v) for k, v in input_data.items()}
|
||||
|
||||
|
||||
def _describe_schema_pins(schema: dict[str, Any]) -> str:
|
||||
"""Format output pins as a bullet list for the prompt."""
|
||||
properties = schema.get("properties", {})
|
||||
required = set(schema.get("required", []))
|
||||
lines = []
|
||||
for pin_name, pin_schema in properties.items():
|
||||
pin_type = pin_schema.get("type", "any")
|
||||
req = "required" if pin_name in required else "optional"
|
||||
lines.append(f"- {pin_name}: {pin_type} ({req})")
|
||||
return "\n".join(lines) if lines else "(no output pins defined)"
|
||||
|
||||
|
||||
def build_simulation_prompt(block: Any, input_data: dict[str, Any]) -> tuple[str, str]:
|
||||
"""Build (system_prompt, user_prompt) for block simulation."""
|
||||
input_schema = block.input_schema.jsonschema()
|
||||
output_schema = block.output_schema.jsonschema()
|
||||
|
||||
input_pins = _describe_schema_pins(input_schema)
|
||||
output_pins = _describe_schema_pins(output_schema)
|
||||
output_properties = list(output_schema.get("properties", {}).keys())
|
||||
|
||||
block_name = getattr(block, "name", type(block).__name__)
|
||||
block_description = getattr(block, "description", "No description available.")
|
||||
|
||||
system_prompt = f"""You are simulating the execution of a software block called "{block_name}".
|
||||
|
||||
## Block Description
|
||||
{block_description}
|
||||
|
||||
## Input Schema
|
||||
{input_pins}
|
||||
|
||||
## Output Schema (what you must return)
|
||||
{output_pins}
|
||||
|
||||
Your task: given the current inputs, produce realistic simulated outputs for this block.
|
||||
|
||||
Rules:
|
||||
- Respond with a single JSON object whose keys are EXACTLY the output pin names listed above.
|
||||
- Assume all credentials and authentication are present and valid. Never simulate authentication failures.
|
||||
- Make the simulated outputs realistic and consistent with the inputs.
|
||||
- If there is an "error" pin, set it to "" (empty string) unless you are simulating a logical error.
|
||||
- Do not include any extra keys beyond the output pins.
|
||||
|
||||
Output pin names you MUST include: {json.dumps(output_properties)}
|
||||
"""
|
||||
|
||||
safe_inputs = _truncate_input_values(input_data)
|
||||
user_prompt = f"## Current Inputs\n{json.dumps(safe_inputs, indent=2)}"
|
||||
|
||||
return system_prompt, user_prompt
|
||||
|
||||
|
||||
async def simulate_block(
|
||||
block: Any,
|
||||
input_data: dict[str, Any],
|
||||
) -> AsyncIterator[tuple[str, Any]]:
|
||||
"""Simulate block execution using an LLM.
|
||||
|
||||
Yields (output_name, output_data) tuples matching the Block.execute() interface.
|
||||
On unrecoverable failure, yields a single ("error", "[SIMULATOR ERROR ...") tuple.
|
||||
"""
|
||||
client = get_openai_client()
|
||||
if client is None:
|
||||
yield (
|
||||
"error",
|
||||
"[SIMULATOR ERROR — NOT A BLOCK FAILURE] No LLM client available "
|
||||
"(missing OpenAI/OpenRouter API key).",
|
||||
)
|
||||
return
|
||||
|
||||
output_schema = block.output_schema.jsonschema()
|
||||
output_properties: dict[str, Any] = output_schema.get("properties", {})
|
||||
|
||||
system_prompt, user_prompt = build_simulation_prompt(block, input_data)
|
||||
|
||||
model = _simulator_model()
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(_MAX_JSON_RETRIES):
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
temperature=_TEMPERATURE,
|
||||
response_format={"type": "json_object"},
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
)
|
||||
if not response.choices:
|
||||
raise ValueError("LLM returned empty choices array")
|
||||
raw = response.choices[0].message.content or ""
|
||||
parsed = json.loads(raw)
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError(f"LLM returned non-object JSON: {raw[:200]}")
|
||||
|
||||
# Fill missing output pins with defaults
|
||||
result: dict[str, Any] = {}
|
||||
for pin_name in output_properties:
|
||||
if pin_name in parsed:
|
||||
result[pin_name] = parsed[pin_name]
|
||||
else:
|
||||
result[pin_name] = "" if pin_name == "error" else None
|
||||
|
||||
logger.debug(
|
||||
"simulate_block: block=%s attempt=%d tokens=%s/%s",
|
||||
getattr(block, "name", "?"),
|
||||
attempt + 1,
|
||||
getattr(getattr(response, "usage", None), "prompt_tokens", "?"),
|
||||
getattr(getattr(response, "usage", None), "completion_tokens", "?"),
|
||||
)
|
||||
|
||||
for pin_name, pin_value in result.items():
|
||||
yield pin_name, pin_value
|
||||
return
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
last_error = e
|
||||
logger.warning(
|
||||
"simulate_block: JSON parse error on attempt %d/%d: %s",
|
||||
attempt + 1,
|
||||
_MAX_JSON_RETRIES,
|
||||
e,
|
||||
)
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
logger.error("simulate_block: LLM call failed: %s", e, exc_info=True)
|
||||
break
|
||||
|
||||
logger.error(
|
||||
"simulate_block: all %d retries exhausted for block=%s; last_error=%s",
|
||||
_MAX_JSON_RETRIES,
|
||||
getattr(block, "name", "?"),
|
||||
last_error,
|
||||
)
|
||||
yield (
|
||||
"error",
|
||||
f"[SIMULATOR ERROR — NOT A BLOCK FAILURE] Failed after {_MAX_JSON_RETRIES} "
|
||||
f"attempts: {last_error}",
|
||||
)
|
||||
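
Since the file above is deleted wholesale, its only consumer was `execute_node()` (see the manager.py hunk earlier, where the `simulate_block` branch is also removed). For reference, the consumption pattern matched `Block.execute()`'s async-iterator shape. A sketch reconstructed from those hunks, not current code:

```python
# Reconstructed from the execute_node() hunk above: how the deleted
# simulate_block() was driven. Illustrative only.
async def collect_simulated_outputs(node_block, input_data: dict) -> dict:
    outputs: dict = {}
    async for output_name, output_data in simulate_block(node_block, input_data):
        outputs[output_name] = output_data  # e.g. {"error": "", "result": ...}
    return outputs
```
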
@@ -181,7 +181,6 @@ def validate_exec(
     node: Node,
     data: BlockInput,
     resolve_input: bool = True,
-    dry_run: bool = False,
 ) -> tuple[BlockInput | None, str]:
     """
     Validate the input data for a node execution.
@@ -190,9 +189,6 @@ def validate_exec(
         node: The node to execute.
         data: The input data for the node execution.
         resolve_input: Whether to resolve dynamic pins into dict/list/object.
-        dry_run: When True, credential fields are allowed to be missing — they
-            will be substituted with a sentinel so the node can be queued and
-            later executed via simulate_block.
 
     Returns:
         A tuple of the validated data and the block name.
@@ -211,14 +207,6 @@ def validate_exec(
     if missing_links := schema.get_missing_links(data, node.input_links):
         return None, f"{error_prefix} unpopulated links {missing_links}"
 
-    # For dry runs, supply sentinel values for any missing credential fields so
-    # the node can be queued — simulate_block never calls the real API anyway.
-    if dry_run:
-        cred_field_names = set(schema.get_credentials_fields().keys())
-        for field_name in cred_field_names:
-            if field_name not in data:
-                data = {**data, field_name: None}
-
     # Merge input data with default values and resolve dynamic dict/list/object pins.
     input_default = schema.get_input_defaults(node.input_default)
     data = {**input_default, **data}
@@ -230,21 +218,13 @@ def validate_exec(
 
     # Input data post-merge should contain all required fields from the schema.
     if missing_input := schema.get_missing_input(data):
-        if dry_run:
-            # In dry-run mode all missing inputs are tolerated — simulate_block()
-            # generates synthetic outputs without needing real input values.
-            pass
-        else:
-            return None, f"{error_prefix} missing input {missing_input}"
+        return None, f"{error_prefix} missing input {missing_input}"
 
     # Last validation: Validate the input values against the schema.
-    # Skip for dry runs — simulate_block doesn't use real inputs, and sentinel
-    # credential values (None) would fail JSON-schema type/required checks.
-    if not dry_run:
-        if error := schema.get_mismatch_error(data):
-            error_message = f"{error_prefix} {error}"
-            logger.warning(error_message)
-            return None, error_message
+    if error := schema.get_mismatch_error(data):
+        error_message = f"{error_prefix} {error}"
+        logger.warning(error_message)
+        return None, error_message
 
     return data, node_block.name

@@ -447,7 +427,6 @@ async def _construct_starting_node_execution_input(
     user_id: str,
     graph_inputs: GraphInput,
     nodes_input_masks: Optional[NodesInputMasks] = None,
-    dry_run: bool = False,
 ) -> tuple[list[tuple[str, BlockInput]], set[str]]:
     """
     Validates and prepares the input data for executing a graph.
@@ -460,7 +439,6 @@ async def _construct_starting_node_execution_input(
         user_id (str): The ID of the user executing the graph.
         data (GraphInput): The input data for the graph execution.
         node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
-        dry_run: When True, skip credential validation errors (simulation needs no real creds).
 
     Returns:
         tuple[
@@ -473,32 +451,6 @@
     validation_errors, nodes_to_skip = await validate_graph_with_credentials(
         graph, user_id, nodes_input_masks
     )
-    # Dry runs simulate every block — missing credentials are irrelevant.
-    # Strip credential-only errors so the graph can proceed.
-    if dry_run and validation_errors:
-
-        def _is_credential_error(msg: str) -> bool:
-            """Match errors produced by _validate_node_input_credentials."""
-            m = msg.lower()
-            return (
-                m == "these credentials are required"
-                or m.startswith("invalid credentials:")
-                or m.startswith("credentials not available:")
-                or m.startswith("unknown credentials #")
-            )
-
-        validation_errors = {
-            node_id: {
-                field: msg
-                for field, msg in errors.items()
-                if not _is_credential_error(msg)
-            }
-            for node_id, errors in validation_errors.items()
-        }
-        # Remove nodes that have no remaining errors
-        validation_errors = {
-            node_id: errors for node_id, errors in validation_errors.items() if errors
-        }
     n_error_nodes = len(validation_errors)
     n_errors = sum(len(errors) for errors in validation_errors.values())
     if validation_errors:
@@ -542,7 +494,7 @@
                     "Please use the appropriate trigger to run this agent."
                 )
 
-            input_data, error = validate_exec(node, input_data, dry_run=dry_run)
+            input_data, error = validate_exec(node, input_data)
             if input_data is None:
                 raise ValueError(error)
         else:
@@ -564,7 +516,6 @@ async def validate_and_construct_node_execution_input(
     graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
     nodes_input_masks: Optional[NodesInputMasks] = None,
     is_sub_graph: bool = False,
-    dry_run: bool = False,
 ) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]:
     """
     Public wrapper that handles graph fetching, credential mapping, and validation+construction.
@@ -630,7 +581,6 @@
             user_id=user_id,
             graph_inputs=graph_inputs,
             nodes_input_masks=nodes_input_masks,
-            dry_run=dry_run,
         )
     )

@@ -868,7 +818,6 @@ async def add_graph_execution(
     nodes_input_masks: Optional[NodesInputMasks] = None,
     execution_context: Optional[ExecutionContext] = None,
     graph_exec_id: Optional[str] = None,
-    dry_run: bool = False,
 ) -> GraphExecutionWithNodes:
     """
     Adds a graph execution to the queue and returns the execution entry.
@@ -933,7 +882,6 @@
             graph_credentials_inputs=graph_credentials_inputs,
             nodes_input_masks=nodes_input_masks,
             is_sub_graph=parent_exec_id is not None,
-            dry_run=dry_run,
         )
     )

@@ -947,7 +895,6 @@ async def add_graph_execution(
         starting_nodes_input=starting_nodes_input,
         preset_id=preset_id,
         parent_graph_exec_id=parent_exec_id,
-        is_dry_run=dry_run,
     )
 
     logger.info(
@@ -970,7 +917,6 @@
             # Safety settings
             human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
             sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
-            dry_run=dry_run,
             # User settings
             user_timezone=(
                 user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"

Some files were not shown because too many files have changed in this diff.