mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
66 Commits
feat/build
...
spare/16
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
80bfd64ffa | ||
|
|
0076ad2a1a | ||
|
|
edb3d322f0 | ||
|
|
9381057079 | ||
|
|
f21a36ca37 | ||
|
|
ee5382a064 | ||
|
|
b80e5ea987 | ||
|
|
3d4fcfacb6 | ||
|
|
32eac6d52e | ||
|
|
9762f4cde7 | ||
|
|
76901ba22f | ||
|
|
23b65939f3 | ||
|
|
1c27eaac53 | ||
|
|
923b164794 | ||
|
|
e86ac21c43 | ||
|
|
94224be841 | ||
|
|
da4bdc7ab9 | ||
|
|
7176cecf25 | ||
|
|
f35210761c | ||
|
|
1ebcf85669 | ||
|
|
ab7c38bda7 | ||
|
|
0f67e45d05 | ||
|
|
b9ce37600e | ||
|
|
3921deaef1 | ||
|
|
f01f668674 | ||
|
|
f7a3491f91 | ||
|
|
cbff3b53d3 | ||
|
|
5b9a4c52c9 | ||
|
|
0ce1c90b55 | ||
|
|
d4c6eb9adc | ||
|
|
1bb91b53b7 | ||
|
|
a5f9c43a41 | ||
|
|
1240f38f75 | ||
|
|
f617f50f0b | ||
|
|
943a1df815 | ||
|
|
593001e0c8 | ||
|
|
e1db8234a3 | ||
|
|
282173be9d | ||
|
|
5d9a169e04 | ||
|
|
6fd1050457 | ||
|
|
02708bcd00 | ||
|
|
156d61fe5c | ||
|
|
5a29de0e0e | ||
|
|
e657472162 | ||
|
|
4d00e0f179 | ||
|
|
1d7282b5f3 | ||
|
|
e3591fcaa3 | ||
|
|
876dc32e17 | ||
|
|
616e29f5e4 | ||
|
|
280a98ad38 | ||
|
|
c7f2a7dd03 | ||
|
|
6d0e2063ec | ||
|
|
8b577ae194 | ||
|
|
d8f5f783ae | ||
|
|
82d22f3680 | ||
|
|
50622333d1 | ||
|
|
27af5782a9 | ||
|
|
522f932e67 | ||
|
|
a6124b06d5 | ||
|
|
ae660ea04f | ||
|
|
2479f3a1c4 | ||
|
|
8153306384 | ||
|
|
9c3d100a22 | ||
|
|
fc3bf6c154 | ||
|
|
e32d258a7e | ||
|
|
3e86544bfe |
@@ -2,7 +2,7 @@
|
||||
name: pr-address
|
||||
description: Address PR review comments and loop until CI green and all comments resolved. TRIGGER when user asks to address comments, fix PR feedback, respond to reviewers, or babysit/monitor a PR.
|
||||
user-invocable: true
|
||||
args: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||
argument-hint: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
@@ -17,18 +17,70 @@ gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoG
|
||||
gh pr view {N}
|
||||
```
|
||||
|
||||
## Fetch comments (all sources)
|
||||
## Read the PR description
|
||||
|
||||
Understand the **Why / What / How** before addressing comments — you need context to make good fixes:
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews # top-level reviews
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments # inline review comments
|
||||
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments # PR conversation comments
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
```
|
||||
|
||||
**Bots to watch for:**
|
||||
- `autogpt-reviewer` — posts "Blockers", "Should Fix", "Nice to Have". Address ALL of them.
|
||||
- `sentry[bot]` — bug predictions. Fix real bugs, explain false positives.
|
||||
- `coderabbitai[bot]` — automated review. Address actionable items.
|
||||
## Fetch comments (all sources)
|
||||
|
||||
### 1. Inline review threads — GraphQL (primary source of actionable items)
|
||||
|
||||
Use GraphQL to fetch inline threads. It natively exposes `isResolved`, returns threads already grouped with all replies, and paginates via cursor — no manual thread reconstruction needed.
|
||||
|
||||
```bash
|
||||
gh api graphql -f query='
|
||||
{
|
||||
repository(owner: "Significant-Gravitas", name: "AutoGPT") {
|
||||
pullRequest(number: {N}) {
|
||||
reviewThreads(first: 100) {
|
||||
pageInfo { hasNextPage endCursor }
|
||||
nodes {
|
||||
id
|
||||
isResolved
|
||||
path
|
||||
comments(last: 1) {
|
||||
nodes { databaseId body author { login } createdAt }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
If `pageInfo.hasNextPage` is true, fetch subsequent pages by adding `after: "<endCursor>"` to `reviewThreads(first: 100, after: "...")` and repeat until `hasNextPage` is false.
|
||||
|
||||
**Filter to unresolved threads only** — skip any thread where `isResolved: true`. `comments(last: 1)` returns the most recent comment in the thread — act on that; it reflects the reviewer's final ask. Use the thread `id` (Relay global ID) to track threads across polls.
|
||||
|
||||
### 2. Top-level reviews — REST (MUST paginate)
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
|
||||
```
|
||||
|
||||
**CRITICAL — always `--paginate`.** Reviews default to 30 per page. PRs can have 80–170+ reviews (mostly empty resolution events). Without pagination you miss reviews past position 30 — including `autogpt-reviewer`'s structured review which is typically posted after several CI runs and sits well beyond the first page.
|
||||
|
||||
Two things to extract:
|
||||
- **Overall state**: look for `CHANGES_REQUESTED` or `APPROVED` reviews.
|
||||
- **Actionable feedback**: non-empty bodies only. Empty-body reviews are thread-resolution events — they indicate progress but have no feedback to act on.
|
||||
|
||||
**Where each reviewer posts:**
|
||||
- `autogpt-reviewer` — posts detailed structured reviews ("Blockers", "Should Fix", "Nice to Have") as **top-level reviews**. Not present on every PR. Address ALL items.
|
||||
- `sentry[bot]` — posts bug predictions as **inline threads**. Fix real bugs, explain false positives.
|
||||
- `coderabbitai[bot]` — posts summaries as **top-level reviews** AND actionable items as **inline threads**. Address actionable items.
|
||||
- Human reviewers — can post in any source. Address ALL non-empty feedback.
|
||||
|
||||
### 3. PR conversation comments — REST
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
|
||||
```
|
||||
|
||||
Mostly contains: bot summaries (`coderabbitai[bot]`), CI/conflict detection (`github-actions[bot]`), and author status updates. Scan for non-empty messages from non-bot human reviewers that aren't the PR author — those are the ones that need a response.
|
||||
|
||||
## For each unaddressed comment
|
||||
|
||||
@@ -40,8 +92,8 @@ Address comments **one at a time**: fix → commit → push → inline reply →
|
||||
|
||||
| Comment type | How to reply |
|
||||
|---|---|
|
||||
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="Fixed in <commit-sha>: <description>"` |
|
||||
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="Fixed in <commit-sha>: <description>"` |
|
||||
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
|
||||
## Format and commit
|
||||
|
||||
@@ -69,11 +121,88 @@ For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
|
||||
|
||||
```text
|
||||
address comments → format → commit → push
|
||||
→ re-check comments → fix new ones → push
|
||||
→ wait for CI → re-check comments after CI settles
|
||||
→ wait for CI (while addressing new comments) → fix failures → push
|
||||
→ re-check comments after CI settles
|
||||
→ repeat until: all comments addressed AND CI green AND no new comments arriving
|
||||
```
|
||||
|
||||
While CI runs, stay productive: run local tests, address remaining comments.
|
||||
### Polling for CI + new comments
|
||||
|
||||
**The loop ends when:** CI fully green + all comments addressed + no new comments since CI settled.
|
||||
After pushing, poll for **both** CI status and new comments in a single loop. Do not use `gh pr checks --watch` — it blocks the tool and prevents reacting to new comments while CI is running.
|
||||
|
||||
> **Note:** `gh pr checks --watch --fail-fast` is tempting but it blocks the entire Bash tool call, meaning the agent cannot check for or address new comments until CI fully completes. Always poll manually instead.
|
||||
|
||||
**Polling loop — repeat every 30 seconds:**
|
||||
|
||||
1. Check CI status:
|
||||
```bash
|
||||
gh pr checks {N} --repo Significant-Gravitas/AutoGPT --json bucket,name,link
|
||||
```
|
||||
Parse the results: if every check has `bucket` of `"pass"` or `"skipping"`, CI is green. If any has `"fail"`, CI has failed. Otherwise CI is still pending.
|
||||
|
||||
2. Check for merge conflicts:
|
||||
```bash
|
||||
gh pr view {N} --repo Significant-Gravitas/AutoGPT --json mergeable --jq '.mergeable'
|
||||
```
|
||||
If the result is `"CONFLICTING"`, the PR has a merge conflict — see "Resolving merge conflicts" below. If `"UNKNOWN"`, GitHub is still computing mergeability — wait and re-check next poll.
|
||||
|
||||
3. Check for new/changed comments (all three sources):
|
||||
|
||||
**Inline threads** — re-run the GraphQL query from "Fetch comments". For each unresolved thread, record `{thread_id, last_comment_databaseId}` as your baseline. On each poll, action is needed if:
|
||||
- A new thread `id` appears that wasn't in the baseline (new thread), OR
|
||||
- An existing thread's `last_comment_databaseId` has changed (new reply on existing thread)
|
||||
|
||||
**Conversation comments:**
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
|
||||
```
|
||||
Compare total count and newest `id` against baseline. Filter to non-empty, non-bot, non-author-update messages.
|
||||
|
||||
**Top-level reviews:**
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
|
||||
```
|
||||
Watch for new non-empty reviews (`CHANGES_REQUESTED` or `COMMENTED` with body). Compare total count and newest `id` against baseline.
|
||||
|
||||
4. **React in this precedence order (first match wins):**
|
||||
|
||||
| What happened | Action |
|
||||
|---|---|
|
||||
| Merge conflict detected | See "Resolving merge conflicts" below. |
|
||||
| Mergeability is `UNKNOWN` | GitHub is still computing mergeability. Sleep 30 seconds, then restart polling from the top. |
|
||||
| New comments detected | Address them (fix → commit → push → reply). After pushing, re-fetch all comments to update your baseline, then restart this polling loop from the top (new commits invalidate CI status). |
|
||||
| CI failed (bucket == "fail") | Get failed check links: `gh pr checks {N} --repo Significant-Gravitas/AutoGPT --json bucket,link --jq '.[] \| select(.bucket == "fail") \| .link'`. Extract run ID from link (format: `.../actions/runs/<run-id>/job/...`), read logs with `gh run view <run-id> --repo Significant-Gravitas/AutoGPT --log-failed`. Fix → commit → push → restart polling. |
|
||||
| CI green + no new comments | **Do not exit immediately.** Bots (coderabbitai, sentry) often post reviews shortly after CI settles. Continue polling for **2 more cycles (60s)** after CI goes green. Only exit after 2 consecutive green+quiet polls. |
|
||||
| CI pending + no new comments | Sleep 30 seconds, then poll again. |
|
||||
|
||||
**The loop ends when:** CI fully green + all comments addressed + **2 consecutive polls with no new comments after CI settled.**
|
||||
|
||||
### Resolving merge conflicts
|
||||
|
||||
1. Identify the PR's target branch and remote:
|
||||
```bash
|
||||
gh pr view {N} --repo Significant-Gravitas/AutoGPT --json baseRefName --jq '.baseRefName'
|
||||
git remote -v # find the remote pointing to Significant-Gravitas/AutoGPT (typically 'upstream' in forks, 'origin' for direct contributors)
|
||||
```
|
||||
|
||||
2. Pull the latest base branch with a 3-way merge:
|
||||
```bash
|
||||
git pull {base-remote} {base-branch} --no-rebase
|
||||
```
|
||||
|
||||
3. Resolve conflicting files, then verify no conflict markers remain:
|
||||
```bash
|
||||
if grep -R -n -E '^(<<<<<<<|=======|>>>>>>>)' <conflicted-files>; then
|
||||
echo "Unresolved conflict markers found — resolve before proceeding."
|
||||
exit 1
|
||||
fi
|
||||
```
|
||||
|
||||
4. Stage and push:
|
||||
```bash
|
||||
git add <conflicted-files>
|
||||
git commit -m "Resolve merge conflicts with {base-branch}"
|
||||
git push
|
||||
```
|
||||
|
||||
5. Restart the polling loop from the top — new commits reset CI status.
|
||||
|
||||
@@ -17,6 +17,16 @@ gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoG
|
||||
gh pr view {N}
|
||||
```
|
||||
|
||||
## Read the PR description
|
||||
|
||||
Before reading code, understand the **why**, **what**, and **how** from the PR description:
|
||||
|
||||
```bash
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
```
|
||||
|
||||
Every PR should have a Why / What / How structure. If any of these are missing, note it as feedback.
|
||||
|
||||
## Read the diff
|
||||
|
||||
```bash
|
||||
@@ -28,12 +38,14 @@ gh pr diff {N}
|
||||
Before posting anything, fetch existing inline comments to avoid duplicates:
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews
|
||||
```
|
||||
|
||||
## What to check
|
||||
|
||||
**Description quality:** Does the PR description cover Why (motivation/problem), What (summary of changes), and How (approach/implementation details)? If any are missing, request them — you can't judge the approach without understanding the problem and intent.
|
||||
|
||||
**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error handling gaps, async correctness (missing `await`, unclosed resources).
|
||||
|
||||
**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).
|
||||
|
||||
754
.claude/skills/pr-test/SKILL.md
Normal file
754
.claude/skills/pr-test/SKILL.md
Normal file
@@ -0,0 +1,754 @@
|
||||
---
|
||||
name: pr-test
|
||||
description: "E2E manual testing of PRs/branches using docker compose, agent-browser, and API calls. TRIGGER when user asks to manually test a PR, test a feature end-to-end, or run integration tests against a running system."
|
||||
user-invocable: true
|
||||
argument-hint: "[worktree path or PR number] — tests the PR in the given worktree. Optional flags: --fix (auto-fix issues found)"
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "2.0.0"
|
||||
---
|
||||
|
||||
# Manual E2E Test
|
||||
|
||||
Test a PR/branch end-to-end by building the full platform, interacting via browser and API, capturing screenshots, and reporting results.
|
||||
|
||||
## Critical Requirements
|
||||
|
||||
These are NON-NEGOTIABLE. Every test run MUST satisfy ALL the following:
|
||||
|
||||
### 1. Screenshots at Every Step
|
||||
- Take a screenshot at EVERY significant test step — not just at the end
|
||||
- Every test scenario MUST have at least one BEFORE and one AFTER screenshot
|
||||
- Name screenshots sequentially: `{NN}-{action}-{state}.png` (e.g., `01-credits-before.png`, `02-credits-after.png`)
|
||||
- If a screenshot is missing for a scenario, the test is INCOMPLETE — go back and take it
|
||||
|
||||
### 2. Screenshots MUST Be Posted to PR
|
||||
- Push ALL screenshots to a temp branch `test-screenshots/pr-{N}`
|
||||
- Post a PR comment with ALL screenshots embedded inline using GitHub raw URLs
|
||||
- This is NOT optional — every test run MUST end with a PR comment containing screenshots
|
||||
- If screenshot upload fails, retry. If it still fails, list failed files and require manual drag-and-drop/paste attachment in the PR comment
|
||||
|
||||
### 3. State Verification with Before/After Evidence
|
||||
- For EVERY state-changing operation (API call, user action), capture the state BEFORE and AFTER
|
||||
- Log the actual API response values (e.g., `credits_before=100, credits_after=95`)
|
||||
- Screenshot MUST show the relevant UI state change
|
||||
- Compare expected vs actual values explicitly — do not just eyeball it
|
||||
|
||||
### 4. Negative Test Cases Are Mandatory
|
||||
- Test at least ONE negative case per feature (e.g., insufficient credits, invalid input, unauthorized access)
|
||||
- Verify error messages are user-friendly and accurate
|
||||
- Verify the system state did NOT change after a rejected operation
|
||||
|
||||
### 5. Test Report Must Include Full Evidence
|
||||
Each test scenario in the report MUST have:
|
||||
- **Steps**: What was done (exact commands or UI actions)
|
||||
- **Expected**: What should happen
|
||||
- **Actual**: What actually happened
|
||||
- **API Evidence**: Before/after API response values for state-changing operations
|
||||
- **Screenshot Evidence**: Before/after screenshots with explanations
|
||||
|
||||
## State Manipulation for Realistic Testing
|
||||
|
||||
When testing features that depend on specific states (rate limits, credits, quotas):
|
||||
|
||||
1. **Use Redis CLI to set counters directly:**
|
||||
```bash
|
||||
# Find the Redis container
|
||||
REDIS_CONTAINER=$(docker ps --format '{{.Names}}' | grep redis | head -1)
|
||||
# Set a key with expiry
|
||||
docker exec $REDIS_CONTAINER redis-cli SET key value EX ttl
|
||||
# Example: Set rate limit counter to near-limit
|
||||
docker exec $REDIS_CONTAINER redis-cli SET "rate_limit:user:test@test.com" 99 EX 3600
|
||||
# Example: Check current value
|
||||
docker exec $REDIS_CONTAINER redis-cli GET "rate_limit:user:test@test.com"
|
||||
```
|
||||
|
||||
2. **Use API calls to check before/after state:**
|
||||
```bash
|
||||
# BEFORE: Record current state
|
||||
BEFORE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
|
||||
echo "Credits BEFORE: $BEFORE"
|
||||
|
||||
# Perform the action...
|
||||
|
||||
# AFTER: Record new state and compare
|
||||
AFTER=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
|
||||
echo "Credits AFTER: $AFTER"
|
||||
echo "Delta: $(( BEFORE - AFTER ))"
|
||||
```
|
||||
|
||||
3. **Take screenshots BEFORE and AFTER state changes** — the UI must reflect the backend state change
|
||||
|
||||
4. **Never rely on mocked/injected browser state** — always use real backend state. Do NOT use `agent-browser eval` to fake UI state. The backend must be the source of truth.
|
||||
|
||||
5. **Use direct DB queries when needed:**
|
||||
```bash
|
||||
# Query via Supabase's PostgREST or docker exec into the DB
|
||||
docker exec supabase-db psql -U supabase_admin -d postgres -c "SELECT credits FROM user_credits WHERE user_id = '...';"
|
||||
```
|
||||
|
||||
6. **After every API test, verify the state change actually persisted:**
|
||||
```bash
|
||||
# Example: After a credits purchase, verify DB matches API
|
||||
API_CREDITS=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
|
||||
DB_CREDITS=$(docker exec supabase-db psql -U supabase_admin -d postgres -t -c "SELECT credits FROM user_credits WHERE user_id = '...';" | tr -d ' ')
|
||||
[ "$API_CREDITS" = "$DB_CREDITS" ] && echo "CONSISTENT" || echo "MISMATCH: API=$API_CREDITS DB=$DB_CREDITS"
|
||||
```
|
||||
|
||||
## Arguments
|
||||
|
||||
- `$ARGUMENTS` — worktree path (e.g. `$REPO_ROOT`) or PR number
|
||||
- If `--fix` flag is present, auto-fix bugs found and push fixes (like pr-address loop)
|
||||
|
||||
## Step 0: Resolve the target
|
||||
|
||||
```bash
|
||||
# If argument is a PR number, find its worktree
|
||||
gh pr view {N} --json headRefName --jq '.headRefName'
|
||||
# If argument is a path, use it directly
|
||||
```
|
||||
|
||||
Determine:
|
||||
- `REPO_ROOT` — the root repo directory: `git -C "$WORKTREE_PATH" worktree list | head -1 | awk '{print $1}'` (or `git rev-parse --show-toplevel` if not a worktree)
|
||||
- `WORKTREE_PATH` — the worktree directory
|
||||
- `PLATFORM_DIR` — `$WORKTREE_PATH/autogpt_platform`
|
||||
- `BACKEND_DIR` — `$PLATFORM_DIR/backend`
|
||||
- `FRONTEND_DIR` — `$PLATFORM_DIR/frontend`
|
||||
- `PR_NUMBER` — the PR number (from `gh pr list --head $(git branch --show-current)`)
|
||||
- `PR_TITLE` — the PR title, slugified (e.g. "Add copilot permissions" → "add-copilot-permissions")
|
||||
- `RESULTS_DIR` — `$REPO_ROOT/test-results/PR-{PR_NUMBER}-{slugified-title}`
|
||||
|
||||
Create the results directory:
|
||||
```bash
|
||||
PR_NUMBER=$(cd $WORKTREE_PATH && gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT --json number --jq '.[0].number')
|
||||
PR_TITLE=$(cd $WORKTREE_PATH && gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT --json title --jq '.[0].title' | tr '[:upper:]' '[:lower:]' | sed 's/[^a-z0-9]/-/g' | sed 's/--*/-/g' | sed 's/^-//;s/-$//' | head -c 50)
|
||||
RESULTS_DIR="$REPO_ROOT/test-results/PR-${PR_NUMBER}-${PR_TITLE}"
|
||||
mkdir -p $RESULTS_DIR
|
||||
```
|
||||
|
||||
**Test user credentials** (for logging into the UI or verifying results manually):
|
||||
- Email: `test@test.com`
|
||||
- Password: `testtest123`
|
||||
|
||||
## Step 1: Understand the PR
|
||||
|
||||
Before testing, understand what changed:
|
||||
|
||||
```bash
|
||||
cd $WORKTREE_PATH
|
||||
|
||||
# Read PR description to understand the WHY
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
|
||||
git log --oneline dev..HEAD | head -20
|
||||
git diff dev --stat
|
||||
```
|
||||
|
||||
Read the PR description (Why / What / How) and changed files to understand:
|
||||
0. **Why** does this PR exist? What problem does it solve?
|
||||
1. **What** feature/fix does this PR implement?
|
||||
2. **How** does it work? What's the approach?
|
||||
3. What components are affected? (backend, frontend, copilot, executor, etc.)
|
||||
4. What are the key user-facing behaviors to test?
|
||||
|
||||
## Step 2: Write test scenarios
|
||||
|
||||
Based on the PR analysis, write a test plan to `$RESULTS_DIR/test-plan.md`:
|
||||
|
||||
```markdown
|
||||
# Test Plan: PR #{N} — {title}
|
||||
|
||||
## Scenarios
|
||||
1. [Scenario name] — [what to verify]
|
||||
2. ...
|
||||
|
||||
## API Tests (if applicable)
|
||||
1. [Endpoint] — [expected behavior]
|
||||
- Before state: [what to check before]
|
||||
- After state: [what to verify changed]
|
||||
|
||||
## UI Tests (if applicable)
|
||||
1. [Page/component] — [interaction to test]
|
||||
- Screenshot before: [what to capture]
|
||||
- Screenshot after: [what to capture]
|
||||
|
||||
## Negative Tests (REQUIRED — at least one per feature)
|
||||
1. [What should NOT happen] — [how to trigger it]
|
||||
- Expected error: [what error message/code]
|
||||
- State unchanged: [what to verify did NOT change]
|
||||
```
|
||||
|
||||
**Be critical** — include edge cases, error paths, and security checks. Every scenario MUST specify what screenshots to take and what state to verify.
|
||||
|
||||
## Step 3: Environment setup
|
||||
|
||||
### 3a. Copy .env files from the root worktree
|
||||
|
||||
The root worktree (`$REPO_ROOT`) has the canonical `.env` files with all API keys. Copy them to the target worktree:
|
||||
|
||||
```bash
|
||||
# CRITICAL: .env files are NOT checked into git. They must be copied manually.
|
||||
cp $REPO_ROOT/autogpt_platform/.env $PLATFORM_DIR/.env
|
||||
cp $REPO_ROOT/autogpt_platform/backend/.env $BACKEND_DIR/.env
|
||||
cp $REPO_ROOT/autogpt_platform/frontend/.env $FRONTEND_DIR/.env
|
||||
```
|
||||
|
||||
### 3b. Configure copilot authentication
|
||||
|
||||
The copilot needs an LLM API to function. Two approaches (try subscription first):
|
||||
|
||||
#### Option 1: Subscription mode (preferred — uses your Claude Max/Pro subscription)
|
||||
|
||||
The `claude_agent_sdk` Python package **bundles its own Claude CLI binary** — no need to install `@anthropic-ai/claude-code` via npm. The backend auto-provisions credentials from environment variables on startup.
|
||||
|
||||
Run the helper script to extract tokens from your host and auto-update `backend/.env` (works on macOS, Linux, and Windows/WSL):
|
||||
|
||||
```bash
|
||||
# Extracts OAuth tokens and writes CLAUDE_CODE_OAUTH_TOKEN + CLAUDE_CODE_REFRESH_TOKEN into .env
|
||||
bash $BACKEND_DIR/scripts/refresh_claude_token.sh --env-file $BACKEND_DIR/.env
|
||||
```
|
||||
|
||||
**How it works:** The script reads the OAuth token from:
|
||||
- **macOS**: system keychain (`"Claude Code-credentials"`)
|
||||
- **Linux/WSL**: `~/.claude/.credentials.json`
|
||||
- **Windows**: `%APPDATA%/claude/.credentials.json`
|
||||
|
||||
It sets `CLAUDE_CODE_OAUTH_TOKEN`, `CLAUDE_CODE_REFRESH_TOKEN`, and `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true` in the `.env` file. On container startup, the backend auto-provisions `~/.claude/.credentials.json` inside the container from these env vars. The SDK's bundled CLI then authenticates using that file. No `claude login`, no npm install needed.
|
||||
|
||||
**Note:** The OAuth token expires (~24h). If copilot returns auth errors, re-run the script and restart: `$BACKEND_DIR/scripts/refresh_claude_token.sh --env-file $BACKEND_DIR/.env && docker compose up -d copilot_executor`
|
||||
|
||||
#### Option 2: OpenRouter API key mode (fallback)
|
||||
|
||||
If subscription mode doesn't work, switch to API key mode using OpenRouter:
|
||||
|
||||
```bash
|
||||
# In $BACKEND_DIR/.env, ensure these are set:
|
||||
CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=false
|
||||
CHAT_API_KEY=<value of OPEN_ROUTER_API_KEY from the same .env>
|
||||
CHAT_BASE_URL=https://openrouter.ai/api/v1
|
||||
CHAT_USE_CLAUDE_AGENT_SDK=true
|
||||
```
|
||||
|
||||
Use `sed` to update these values:
|
||||
```bash
|
||||
ORKEY=$(grep "^OPEN_ROUTER_API_KEY=" $BACKEND_DIR/.env | cut -d= -f2)
|
||||
[ -n "$ORKEY" ] || { echo "ERROR: OPEN_ROUTER_API_KEY is missing in $BACKEND_DIR/.env"; exit 1; }
|
||||
perl -i -pe 's/CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true/CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=false/' $BACKEND_DIR/.env
|
||||
# Add or update CHAT_API_KEY and CHAT_BASE_URL
|
||||
grep -q "^CHAT_API_KEY=" $BACKEND_DIR/.env && perl -i -pe "s|^CHAT_API_KEY=.*|CHAT_API_KEY=$ORKEY|" $BACKEND_DIR/.env || echo "CHAT_API_KEY=$ORKEY" >> $BACKEND_DIR/.env
|
||||
grep -q "^CHAT_BASE_URL=" $BACKEND_DIR/.env && perl -i -pe 's|^CHAT_BASE_URL=.*|CHAT_BASE_URL=https://openrouter.ai/api/v1|' $BACKEND_DIR/.env || echo "CHAT_BASE_URL=https://openrouter.ai/api/v1" >> $BACKEND_DIR/.env
|
||||
```
|
||||
|
||||
### 3c. Stop conflicting containers
|
||||
|
||||
```bash
|
||||
# Stop any running app containers (keep infra: supabase, redis, rabbitmq, clamav)
|
||||
docker ps --format "{{.Names}}" | grep -E "rest_server|executor|copilot|websocket|database_manager|scheduler|notification|frontend|migrate" | while read name; do
|
||||
docker stop "$name" 2>/dev/null
|
||||
done
|
||||
```
|
||||
|
||||
### 3e. Build and start
|
||||
|
||||
```bash
|
||||
cd $PLATFORM_DIR && docker compose build --no-cache 2>&1 | tail -20
|
||||
if [ ${PIPESTATUS[0]} -ne 0 ]; then echo "ERROR: Docker build failed"; exit 1; fi
|
||||
|
||||
cd $PLATFORM_DIR && docker compose up -d 2>&1 | tail -20
|
||||
if [ ${PIPESTATUS[0]} -ne 0 ]; then echo "ERROR: Docker compose up failed"; exit 1; fi
|
||||
```
|
||||
|
||||
**Note:** If the container appears to be running old code (e.g. missing PR changes), use `docker compose build --no-cache` to force a full rebuild. Docker BuildKit may sometimes reuse cached `COPY` layers from a previous build on a different branch.
|
||||
|
||||
**Expected time: 3-8 minutes** for build, 5-10 minutes with `--no-cache`.
|
||||
|
||||
### 3f. Wait for services to be ready
|
||||
|
||||
```bash
|
||||
# Poll until backend and frontend respond
|
||||
for i in $(seq 1 60); do
|
||||
BACKEND=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8006/docs 2>/dev/null)
|
||||
FRONTEND=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:3000 2>/dev/null)
|
||||
if [ "$BACKEND" = "200" ] && [ "$FRONTEND" = "200" ]; then
|
||||
echo "Services ready"
|
||||
break
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
```
|
||||
|
||||
|
||||
### 3h. Create test user and get auth token
|
||||
|
||||
```bash
|
||||
ANON_KEY=$(grep "NEXT_PUBLIC_SUPABASE_ANON_KEY=" $FRONTEND_DIR/.env | sed 's/.*NEXT_PUBLIC_SUPABASE_ANON_KEY=//' | tr -d '[:space:]')
|
||||
|
||||
# Signup (idempotent — returns "User already registered" if exists)
|
||||
RESULT=$(curl -s -X POST 'http://localhost:8000/auth/v1/signup' \
|
||||
-H "apikey: $ANON_KEY" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"email":"test@test.com","password":"testtest123"}')
|
||||
|
||||
# If "Database error finding user", restart supabase-auth and retry
|
||||
if echo "$RESULT" | grep -q "Database error"; then
|
||||
docker restart supabase-auth && sleep 5
|
||||
curl -s -X POST 'http://localhost:8000/auth/v1/signup' \
|
||||
-H "apikey: $ANON_KEY" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"email":"test@test.com","password":"testtest123"}'
|
||||
fi
|
||||
|
||||
# Get auth token
|
||||
TOKEN=$(curl -s -X POST 'http://localhost:8000/auth/v1/token?grant_type=password' \
|
||||
-H "apikey: $ANON_KEY" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"email":"test@test.com","password":"testtest123"}' | jq -r '.access_token // ""')
|
||||
```
|
||||
|
||||
**Use this token for ALL API calls:**
|
||||
```bash
|
||||
curl -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/...
|
||||
```
|
||||
|
||||
## Step 4: Run tests
|
||||
|
||||
### Service ports reference
|
||||
|
||||
| Service | Port | URL |
|
||||
|---------|------|-----|
|
||||
| Frontend | 3000 | http://localhost:3000 |
|
||||
| Backend REST | 8006 | http://localhost:8006 |
|
||||
| Supabase Auth (via Kong) | 8000 | http://localhost:8000 |
|
||||
| Executor | 8002 | http://localhost:8002 |
|
||||
| Copilot Executor | 8008 | http://localhost:8008 |
|
||||
| WebSocket | 8001 | http://localhost:8001 |
|
||||
| Database Manager | 8005 | http://localhost:8005 |
|
||||
| Redis | 6379 | localhost:6379 |
|
||||
| RabbitMQ | 5672 | localhost:5672 |
|
||||
|
||||
### API testing
|
||||
|
||||
Use `curl` with the auth token for backend API tests. **For EVERY API call that changes state, record before/after values:**
|
||||
|
||||
```bash
|
||||
# Example: List agents
|
||||
curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/graphs | jq . | head -20
|
||||
|
||||
# Example: Create an agent
|
||||
curl -s -X POST http://localhost:8006/api/graphs \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{...}' | jq .
|
||||
|
||||
# Example: Run an agent
|
||||
curl -s -X POST "http://localhost:8006/api/graphs/{graph_id}/execute" \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"data": {...}}'
|
||||
|
||||
# Example: Get execution results
|
||||
curl -s -H "Authorization: Bearer $TOKEN" \
|
||||
"http://localhost:8006/api/graphs/{graph_id}/executions/{exec_id}" | jq .
|
||||
```
|
||||
|
||||
**State verification pattern (use for EVERY state-changing API call):**
|
||||
```bash
|
||||
# 1. Record BEFORE state
|
||||
BEFORE_STATE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/{resource} | jq '{relevant_fields}')
|
||||
echo "BEFORE: $BEFORE_STATE"
|
||||
|
||||
# 2. Perform the action
|
||||
ACTION_RESULT=$(curl -s -X POST ... | jq .)
|
||||
echo "ACTION RESULT: $ACTION_RESULT"
|
||||
|
||||
# 3. Record AFTER state
|
||||
AFTER_STATE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/{resource} | jq '{relevant_fields}')
|
||||
echo "AFTER: $AFTER_STATE"
|
||||
|
||||
# 4. Log the comparison
|
||||
echo "=== STATE CHANGE VERIFICATION ==="
|
||||
echo "Before: $BEFORE_STATE"
|
||||
echo "After: $AFTER_STATE"
|
||||
echo "Expected change: {describe what should have changed}"
|
||||
```
|
||||
|
||||
### Browser testing with agent-browser
|
||||
|
||||
```bash
|
||||
# Close any existing session
|
||||
agent-browser close 2>/dev/null || true
|
||||
|
||||
# Use --session-name to persist cookies across navigations
|
||||
# This means login only needs to happen once per test session
|
||||
agent-browser --session-name pr-test open 'http://localhost:3000/login' --timeout 15000
|
||||
|
||||
# Get interactive elements
|
||||
agent-browser --session-name pr-test snapshot | grep "textbox\|button"
|
||||
|
||||
# Login
|
||||
agent-browser --session-name pr-test fill {email_ref} "test@test.com"
|
||||
agent-browser --session-name pr-test fill {password_ref} "testtest123"
|
||||
agent-browser --session-name pr-test click {login_button_ref}
|
||||
sleep 5
|
||||
|
||||
# Dismiss cookie banner if present
|
||||
agent-browser --session-name pr-test click 'text=Accept All' 2>/dev/null || true
|
||||
|
||||
# Navigate — cookies are preserved so login persists
|
||||
agent-browser --session-name pr-test open 'http://localhost:3000/copilot' --timeout 10000
|
||||
|
||||
# Take screenshot
|
||||
agent-browser --session-name pr-test screenshot $RESULTS_DIR/01-page.png
|
||||
|
||||
# Interact with elements
|
||||
agent-browser --session-name pr-test fill {ref} "text"
|
||||
agent-browser --session-name pr-test press "Enter"
|
||||
agent-browser --session-name pr-test click {ref}
|
||||
agent-browser --session-name pr-test click 'text=Button Text'
|
||||
|
||||
# Read page content
|
||||
agent-browser --session-name pr-test snapshot | grep "text:"
|
||||
```
|
||||
|
||||
**Key pages:**
|
||||
- `/copilot` — CoPilot chat (for testing copilot features)
|
||||
- `/build` — Agent builder (for testing block/node features)
|
||||
- `/build?flowID={id}` — Specific agent in builder
|
||||
- `/library` — Agent library (for testing listing/import features)
|
||||
- `/library/agents/{id}` — Agent detail with run history
|
||||
- `/marketplace` — Marketplace
|
||||
|
||||
### Checking logs
|
||||
|
||||
```bash
|
||||
# Backend REST server
|
||||
docker logs autogpt_platform-rest_server-1 2>&1 | tail -30
|
||||
|
||||
# Executor (runs agent graphs)
|
||||
docker logs autogpt_platform-executor-1 2>&1 | tail -30
|
||||
|
||||
# Copilot executor (runs copilot chat sessions)
|
||||
docker logs autogpt_platform-copilot_executor-1 2>&1 | tail -30
|
||||
|
||||
# Frontend
|
||||
docker logs autogpt_platform-frontend-1 2>&1 | tail -30
|
||||
|
||||
# Filter for errors
|
||||
docker logs autogpt_platform-executor-1 2>&1 | grep -i "error\|exception\|traceback" | tail -20
|
||||
```
|
||||
|
||||
### Copilot chat testing
|
||||
|
||||
The copilot uses SSE streaming. To test via API:
|
||||
|
||||
```bash
|
||||
# Create a session
|
||||
SESSION_ID=$(curl -s -X POST 'http://localhost:8006/api/chat/sessions' \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{}' | jq -r '.id // .session_id // ""')
|
||||
|
||||
# Stream a message (SSE - will stream chunks)
|
||||
curl -N -X POST "http://localhost:8006/api/chat/sessions/$SESSION_ID/stream" \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"message": "Hello, what can you help me with?"}' \
|
||||
--max-time 60 2>/dev/null | head -50
|
||||
```
|
||||
|
||||
Or test via browser (preferred for UI verification):
|
||||
```bash
|
||||
agent-browser --session-name pr-test open 'http://localhost:3000/copilot' --timeout 10000
|
||||
# ... fill chat input and press Enter, wait 20-30s for response
|
||||
```
|
||||
|
||||
## Step 5: Record results and take screenshots
|
||||
|
||||
**Take a screenshot at EVERY significant test step** — before and after interactions, on success, and on failure. This is NON-NEGOTIABLE.
|
||||
|
||||
**Required screenshot pattern for each test scenario:**
|
||||
```bash
|
||||
# BEFORE the action
|
||||
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{scenario}-before.png
|
||||
|
||||
# Perform the action...
|
||||
|
||||
# AFTER the action
|
||||
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{scenario}-after.png
|
||||
```
|
||||
|
||||
**Naming convention:**
|
||||
```bash
|
||||
# Examples:
|
||||
# $RESULTS_DIR/01-login-page-before.png
|
||||
# $RESULTS_DIR/02-login-page-after.png
|
||||
# $RESULTS_DIR/03-credits-page-before.png
|
||||
# $RESULTS_DIR/04-credits-purchase-after.png
|
||||
# $RESULTS_DIR/05-negative-insufficient-credits.png
|
||||
# $RESULTS_DIR/06-error-state.png
|
||||
```
|
||||
|
||||
**Minimum requirements:**
|
||||
- At least TWO screenshots per test scenario (before + after)
|
||||
- At least ONE screenshot for each negative test case showing the error state
|
||||
- If a test fails, screenshot the failure state AND any error logs visible in the UI
|
||||
|
||||
## Step 6: Show results to user with screenshots
|
||||
|
||||
**CRITICAL: After all tests complete, you MUST show every screenshot to the user using the Read tool, with an explanation of what each screenshot shows.** This is the most important part of the test report — the user needs to visually verify the results.
|
||||
|
||||
For each screenshot:
|
||||
1. Use the `Read` tool to display the PNG file (Claude can read images)
|
||||
2. Write a 1-2 sentence explanation below it describing:
|
||||
- What page/state is being shown
|
||||
- What the screenshot proves (which test scenario it validates)
|
||||
- Any notable details visible in the UI
|
||||
|
||||
Format the output like this:
|
||||
|
||||
```markdown
|
||||
### Screenshot 1: {descriptive title}
|
||||
[Read the PNG file here]
|
||||
|
||||
**What it shows:** {1-2 sentence explanation of what this screenshot proves}
|
||||
|
||||
---
|
||||
```
|
||||
|
||||
After showing all screenshots, output a **detailed** summary table:
|
||||
|
||||
| # | Scenario | Result | API Evidence | Screenshot Evidence |
|
||||
|---|----------|--------|-------------|-------------------|
|
||||
| 1 | {name} | PASS/FAIL | Before: X, After: Y | 01-before.png, 02-after.png |
|
||||
| 2 | ... | ... | ... | ... |
|
||||
|
||||
**IMPORTANT:** As you show each screenshot and record test results, persist them in shell variables for Step 7:
|
||||
|
||||
```bash
|
||||
# Build these variables during Step 6 — they are required by Step 7's script
|
||||
# NOTE: declare -A requires Bash 4.0+. This is standard on modern systems (macOS ships zsh
|
||||
# but Homebrew bash is 5.x; Linux typically has bash 5.x). If running on Bash <4, use a
|
||||
# plain variable with a lookup function instead.
|
||||
declare -A SCREENSHOT_EXPLANATIONS=(
|
||||
["01-login-page.png"]="Shows the login page loaded successfully with SSO options visible."
|
||||
["02-builder-with-block.png"]="The builder canvas displays the newly added block connected to the trigger."
|
||||
# ... one entry per screenshot, using the same explanations you showed the user above
|
||||
)
|
||||
|
||||
TEST_RESULTS_TABLE="| 1 | Login flow | PASS | N/A | 01-login-before.png, 02-login-after.png |
|
||||
| 2 | Credits purchase | PASS | Before: 100, After: 95 | 03-credits-before.png, 04-credits-after.png |
|
||||
| 3 | Insufficient credits (negative) | PASS | Credits: 0, rejected | 05-insufficient-credits-error.png |"
|
||||
# ... one row per test scenario with actual results
|
||||
```
|
||||
|
||||
## Step 7: Post test report as PR comment with screenshots
|
||||
|
||||
Upload screenshots to the PR using the GitHub Git API (no local git operations — safe for worktrees), then post a comment with inline images and per-screenshot explanations.
|
||||
|
||||
**This step is MANDATORY. Every test run MUST post a PR comment with screenshots. No exceptions.**
|
||||
|
||||
```bash
|
||||
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
|
||||
REPO="Significant-Gravitas/AutoGPT"
|
||||
SCREENSHOTS_BRANCH="test-screenshots/pr-${PR_NUMBER}"
|
||||
SCREENSHOTS_DIR="test-screenshots/PR-${PR_NUMBER}"
|
||||
|
||||
# Step 1: Create blobs for each screenshot and build tree JSON
|
||||
# Retry each blob upload up to 3 times. If still failing, list them at end of report.
|
||||
shopt -s nullglob
|
||||
SCREENSHOT_FILES=("$RESULTS_DIR"/*.png)
|
||||
if [ ${#SCREENSHOT_FILES[@]} -eq 0 ]; then
|
||||
echo "ERROR: No screenshots found in $RESULTS_DIR. Test run is incomplete."
|
||||
exit 1
|
||||
fi
|
||||
TREE_JSON='['
|
||||
FIRST=true
|
||||
FAILED_UPLOADS=()
|
||||
for img in "${SCREENSHOT_FILES[@]}"; do
|
||||
BASENAME=$(basename "$img")
|
||||
B64=$(base64 < "$img")
|
||||
BLOB_SHA=""
|
||||
for attempt in 1 2 3; do
|
||||
BLOB_SHA=$(gh api "repos/${REPO}/git/blobs" -f content="$B64" -f encoding="base64" --jq '.sha' 2>/dev/null || true)
|
||||
[ -n "$BLOB_SHA" ] && break
|
||||
sleep 1
|
||||
done
|
||||
if [ -z "$BLOB_SHA" ]; then
|
||||
FAILED_UPLOADS+=("$img")
|
||||
continue
|
||||
fi
|
||||
if [ "$FIRST" = true ]; then FIRST=false; else TREE_JSON+=','; fi
|
||||
TREE_JSON+="{\"path\":\"${SCREENSHOTS_DIR}/${BASENAME}\",\"mode\":\"100644\",\"type\":\"blob\",\"sha\":\"${BLOB_SHA}\"}"
|
||||
done
|
||||
TREE_JSON+=']'
|
||||
|
||||
# Step 2: Create tree, commit, and branch ref
|
||||
TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')
|
||||
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
|
||||
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
|
||||
-f tree="$TREE_SHA" \
|
||||
--jq '.sha')
|
||||
gh api "repos/${REPO}/git/refs" \
|
||||
-f ref="refs/heads/${SCREENSHOTS_BRANCH}" \
|
||||
-f sha="$COMMIT_SHA" 2>/dev/null \
|
||||
|| gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" \
|
||||
-X PATCH -f sha="$COMMIT_SHA" -f force=true
|
||||
```
|
||||
|
||||
Then post the comment with **inline images AND explanations for each screenshot**:
|
||||
|
||||
```bash
|
||||
REPO_URL="https://raw.githubusercontent.com/${REPO}/${SCREENSHOTS_BRANCH}"
|
||||
|
||||
# Build image markdown using uploaded image URLs; skip FAILED_UPLOADS (listed separately)
|
||||
|
||||
IMAGE_MARKDOWN=""
|
||||
for img in "${SCREENSHOT_FILES[@]}"; do
|
||||
BASENAME=$(basename "$img")
|
||||
TITLE=$(echo "${BASENAME%.png}" | sed 's/^[0-9]*-//' | sed 's/-/ /g' | awk '{for(i=1;i<=NF;i++) $i=toupper(substr($i,1,1)) tolower(substr($i,2))}1')
|
||||
# Skip images that failed to upload — they will be listed at the end
|
||||
IS_FAILED=false
|
||||
for failed in "${FAILED_UPLOADS[@]}"; do
|
||||
[ "$(basename "$failed")" = "$BASENAME" ] && IS_FAILED=true && break
|
||||
done
|
||||
if [ "$IS_FAILED" = true ]; then
|
||||
continue
|
||||
fi
|
||||
EXPLANATION="${SCREENSHOT_EXPLANATIONS[$BASENAME]}"
|
||||
if [ -z "$EXPLANATION" ]; then
|
||||
echo "ERROR: Missing screenshot explanation for $BASENAME. Add it to SCREENSHOT_EXPLANATIONS in Step 6."
|
||||
exit 1
|
||||
fi
|
||||
IMAGE_MARKDOWN="${IMAGE_MARKDOWN}
|
||||
### ${TITLE}
|
||||

|
||||
${EXPLANATION}
|
||||
"
|
||||
done
|
||||
|
||||
# Write comment body to file to avoid shell interpretation issues with special characters
|
||||
COMMENT_FILE=$(mktemp)
|
||||
# If any uploads failed, append a section listing them with instructions
|
||||
FAILED_SECTION=""
|
||||
if [ ${#FAILED_UPLOADS[@]} -gt 0 ]; then
|
||||
FAILED_SECTION="
|
||||
## ⚠️ Failed Screenshot Uploads
|
||||
The following screenshots could not be uploaded via the GitHub API after 3 retries.
|
||||
**To add them:** drag-and-drop or paste these files into a PR comment manually:
|
||||
"
|
||||
for failed in "${FAILED_UPLOADS[@]}"; do
|
||||
FAILED_SECTION="${FAILED_SECTION}
|
||||
- \`$(basename "$failed")\` (local path: \`$failed\`)"
|
||||
done
|
||||
FAILED_SECTION="${FAILED_SECTION}
|
||||
|
||||
**Run status:** INCOMPLETE until the files above are manually attached and visible inline in the PR."
|
||||
fi
|
||||
|
||||
cat > "$COMMENT_FILE" <<INNEREOF
|
||||
## E2E Test Report
|
||||
|
||||
| # | Scenario | Result | API Evidence | Screenshot Evidence |
|
||||
|---|----------|--------|-------------|-------------------|
|
||||
${TEST_RESULTS_TABLE}
|
||||
|
||||
${IMAGE_MARKDOWN}
|
||||
${FAILED_SECTION}
|
||||
INNEREOF
|
||||
|
||||
gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE"
|
||||
rm -f "$COMMENT_FILE"
|
||||
```
|
||||
|
||||
**The PR comment MUST include:**
|
||||
1. A summary table of all scenarios with PASS/FAIL and before/after API evidence
|
||||
2. Every successfully uploaded screenshot rendered inline; any failed uploads listed with manual attachment instructions
|
||||
3. A 1-2 sentence explanation below each screenshot describing what it proves
|
||||
|
||||
This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch.
|
||||
|
||||
## Fix mode (--fix flag)
|
||||
|
||||
When `--fix` is present, the standard is HIGHER. Do not just note issues — FIX them immediately.
|
||||
|
||||
### Fix protocol for EVERY issue found (including UX issues):
|
||||
|
||||
1. **Identify** the root cause in the code — read the relevant source files
|
||||
2. **Write a failing test first** (TDD): For backend bugs, write a test marked with `pytest.mark.xfail(reason="...")`. For frontend/Playwright bugs, write a test with `.fixme` annotation. Run it to confirm it fails as expected.
|
||||
3. **Screenshot** the broken state: `agent-browser screenshot $RESULTS_DIR/{NN}-broken-{description}.png`
|
||||
4. **Fix** the code in the worktree
|
||||
5. **Rebuild** ONLY the affected service (not the whole stack):
|
||||
```bash
|
||||
cd $PLATFORM_DIR && docker compose up --build -d {service_name}
|
||||
# e.g., docker compose up --build -d rest_server
|
||||
# e.g., docker compose up --build -d frontend
|
||||
```
|
||||
6. **Wait** for the service to be ready (poll health endpoint)
|
||||
7. **Re-test** the same scenario
|
||||
8. **Screenshot** the fixed state: `agent-browser screenshot $RESULTS_DIR/{NN}-fixed-{description}.png`
|
||||
9. **Remove the xfail/fixme marker** from the test written in step 2, and verify it passes
|
||||
10. **Verify** the fix did not break other scenarios (run a quick smoke test)
|
||||
11. **Commit and push** immediately:
|
||||
```bash
|
||||
cd $WORKTREE_PATH
|
||||
git add -A
|
||||
git commit -m "fix: {description of fix}"
|
||||
git push
|
||||
```
|
||||
12. **Continue** to the next test scenario
|
||||
|
||||
### Fix loop (like pr-address)
|
||||
|
||||
```text
|
||||
test scenario → find issue (bug OR UX problem) → screenshot broken state
|
||||
→ fix code → rebuild affected service only → re-test → screenshot fixed state
|
||||
→ verify no regressions → commit + push
|
||||
→ repeat for next scenario
|
||||
→ after ALL scenarios pass, run full re-test to verify everything together
|
||||
```
|
||||
|
||||
**Key differences from non-fix mode:**
|
||||
- UX issues count as bugs — fix them (bad alignment, confusing labels, missing loading states)
|
||||
- Every fix MUST have a before/after screenshot pair proving it works
|
||||
- Commit after EACH fix, not in a batch at the end
|
||||
- The final re-test must produce a clean set of all-passing screenshots
|
||||
|
||||
## Known issues and workarounds
|
||||
|
||||
### Problem: "Database error finding user" on signup
|
||||
**Cause:** Supabase auth service schema cache is stale after migration.
|
||||
**Fix:** `docker restart supabase-auth && sleep 5` then retry signup.
|
||||
|
||||
### Problem: Copilot returns auth errors in subscription mode
|
||||
**Cause:** `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true` but `CLAUDE_CODE_OAUTH_TOKEN` is not set or expired.
|
||||
**Fix:** Re-extract the OAuth token from macOS keychain (see step 3b, Option 1) and recreate the container (`docker compose up -d copilot_executor`). The backend auto-provisions `~/.claude/.credentials.json` from the env var on startup. No `npm install` or `claude login` needed — the SDK bundles its own CLI binary.
|
||||
|
||||
### Problem: agent-browser can't find chromium
|
||||
**Cause:** The Dockerfile auto-provisions system chromium on all architectures (including ARM64). If your branch is behind `dev`, this may not be present yet.
|
||||
**Fix:** Check if chromium exists: `which chromium || which chromium-browser`. If missing, install it: `apt-get install -y chromium` and set `AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium` in the container environment.
|
||||
|
||||
### Problem: agent-browser selector matches multiple elements
|
||||
**Cause:** `text=X` matches all elements containing that text.
|
||||
**Fix:** Use `agent-browser snapshot` to get specific `ref=eNN` references, then use those: `agent-browser click eNN`.
|
||||
|
||||
### Problem: Frontend shows cookie banner blocking interaction
|
||||
**Fix:** `agent-browser click 'text=Accept All'` before other interactions.
|
||||
|
||||
### Problem: Container loses npm packages after rebuild
|
||||
**Cause:** `docker compose up --build` rebuilds the image, losing runtime installs.
|
||||
**Fix:** Add packages to the Dockerfile instead of installing at runtime.
|
||||
|
||||
### Problem: Services not starting after `docker compose up`
|
||||
**Fix:** Wait and check health: `docker compose ps`. Common cause: migration hasn't finished. Check: `docker logs autogpt_platform-migrate-1 2>&1 | tail -5`. If supabase-db isn't healthy: `docker restart supabase-db && sleep 10`.
|
||||
|
||||
### Problem: Docker uses cached layers with old code (PR changes not visible)
|
||||
**Cause:** `docker compose up --build` reuses cached `COPY` layers from previous builds. If the PR branch changes Python files but the previous build already cached that layer from `dev`, the container runs `dev` code.
|
||||
**Fix:** Always use `docker compose build --no-cache` for the first build of a PR branch. Subsequent rebuilds within the same branch can use `--build`.
|
||||
|
||||
### Problem: `agent-browser open` loses login session
|
||||
**Cause:** Without session persistence, `agent-browser open` starts fresh.
|
||||
**Fix:** Use `--session-name pr-test` on ALL agent-browser commands. This auto-saves/restores cookies and localStorage across navigations. Alternatively, use `agent-browser eval "window.location.href = '...'"` to navigate within the same context.
|
||||
|
||||
### Problem: Supabase auth returns "Database error querying schema"
|
||||
**Cause:** The database schema changed (migration ran) but supabase-auth has a stale schema cache.
|
||||
**Fix:** `docker restart supabase-db && sleep 10 && docker restart supabase-auth && sleep 8`. If user data was lost, re-signup.
|
||||
8
.github/PULL_REQUEST_TEMPLATE.md
vendored
8
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,8 +1,12 @@
|
||||
<!-- Clearly explain the need for these changes: -->
|
||||
### Why / What / How
|
||||
|
||||
<!-- Why: Why does this PR exist? What problem does it solve, or what's broken/missing without it? -->
|
||||
<!-- What: What does this PR change? Summarize the changes at a high level. -->
|
||||
<!-- How: How does it work? Describe the approach, key implementation details, or architecture decisions. -->
|
||||
|
||||
### Changes 🏗️
|
||||
|
||||
<!-- Concisely describe all of the changes made in this pull request: -->
|
||||
<!-- List the key changes. Keep it higher level than the diff but specific enough to highlight what's new/modified. -->
|
||||
|
||||
### Checklist 📋
|
||||
|
||||
|
||||
114
.github/workflows/platform-backend-ci.yml
vendored
114
.github/workflows/platform-backend-ci.yml
vendored
@@ -27,10 +27,91 @@ defaults:
|
||||
working-directory: autogpt_platform/backend
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 10
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-py3.12-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Using Poetry version ${HEAD_POETRY_VERSION}"
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run Linters
|
||||
run: poetry run lint --skip-pyright
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
|
||||
type-check:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11", "3.12", "3.13"]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Using Poetry version ${HEAD_POETRY_VERSION}"
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
- name: Run Pyright
|
||||
run: poetry run pyright --pythonversion ${{ matrix.python-version }}
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
timeout-minutes: 15
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -98,9 +179,9 @@ jobs:
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
key: poetry-${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock
|
||||
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
@@ -158,22 +239,22 @@ jobs:
|
||||
echo "Waiting for ClamAV daemon to start..."
|
||||
max_attempts=60
|
||||
attempt=0
|
||||
|
||||
|
||||
until nc -z localhost 3310 || [ $attempt -eq $max_attempts ]; do
|
||||
echo "ClamAV is unavailable - sleeping (attempt $((attempt+1))/$max_attempts)"
|
||||
sleep 5
|
||||
attempt=$((attempt+1))
|
||||
done
|
||||
|
||||
|
||||
if [ $attempt -eq $max_attempts ]; then
|
||||
echo "ClamAV failed to start after $((max_attempts*5)) seconds"
|
||||
echo "Checking ClamAV service logs..."
|
||||
docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
echo "ClamAV is ready!"
|
||||
|
||||
|
||||
# Verify ClamAV is responsive
|
||||
echo "Testing ClamAV connection..."
|
||||
timeout 10 bash -c 'echo "PING" | nc localhost 3310' || {
|
||||
@@ -188,18 +269,13 @@ jobs:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
- id: lint
|
||||
name: Run Linter
|
||||
run: poetry run lint
|
||||
|
||||
- name: Run pytest with coverage
|
||||
- name: Run pytest
|
||||
run: |
|
||||
if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
|
||||
else
|
||||
poetry run pytest -s -vv
|
||||
fi
|
||||
if: success() || (failure() && steps.lint.outcome == 'failure')
|
||||
env:
|
||||
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
@@ -211,6 +287,12 @@ jobs:
|
||||
REDIS_PORT: "6379"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# token: ${{ secrets.CODECOV_TOKEN }}
|
||||
# flags: backend,${{ runner.os }}
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
@@ -224,9 +306,3 @@ jobs:
|
||||
# the backend service, docker composes, and examples
|
||||
RABBITMQ_DEFAULT_USER: "rabbitmq_user_default"
|
||||
RABBITMQ_DEFAULT_PASS: "k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7"
|
||||
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# token: ${{ secrets.CODECOV_TOKEN }}
|
||||
# flags: backend,${{ runner.os }}
|
||||
|
||||
4
.github/workflows/platform-fullstack-ci.yml
vendored
4
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -294,7 +294,7 @@ jobs:
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
path: autogpt_platform/frontend/playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
@@ -303,7 +303,7 @@ jobs:
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: test-results
|
||||
path: autogpt_platform/frontend/test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
|
||||
@@ -55,16 +55,37 @@ AutoGPT Platform is a monorepo containing:
|
||||
- Create the PR against the `dev` branch of the repository.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
||||
- Use conventional commit messages (see below)
|
||||
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
||||
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
|
||||
```bash
|
||||
PR_BODY=$(mktemp)
|
||||
cat > "$PR_BODY" << 'PREOF'
|
||||
## Summary
|
||||
- use `backticks` freely here
|
||||
PREOF
|
||||
gh pr create --title "..." --body-file "$PR_BODY" --base dev
|
||||
rm "$PR_BODY"
|
||||
```
|
||||
- Run the github pre-commit hooks to ensure code quality.
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, follow a test-first approach:
|
||||
|
||||
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
|
||||
2. **Implement the fix/feature** — write the minimal code to make the test pass.
|
||||
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
|
||||
|
||||
This ensures every change is covered by a test and that the test actually validates the intended behavior.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
Use `/pr-review` to review a PR or `/pr-address` to address comments.
|
||||
|
||||
When fetching comments manually:
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` — top-level reviews
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` — inline review comments
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
|
||||
|
||||
### Conventional Commits
|
||||
|
||||
54
autogpt_platform/autogpt_libs/poetry.lock
generated
54
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "annotated-doc"
|
||||
@@ -67,7 +67,7 @@ description = "Backport of asyncio.Runner, a context manager that controls event
|
||||
optional = false
|
||||
python-versions = "<3.11,>=3.8"
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
|
||||
{file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
|
||||
@@ -541,7 +541,7 @@ description = "Backport of PEP 654 (exception groups)"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main", "dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"},
|
||||
{file = "exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88"},
|
||||
@@ -2181,14 +2181,14 @@ testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "7.0.0"
|
||||
version = "7.1.0"
|
||||
description = "Pytest plugin for measuring coverage."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861"},
|
||||
{file = "pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1"},
|
||||
{file = "pytest_cov-7.1.0-py3-none-any.whl", hash = "sha256:a0461110b7865f9a271aa1b51e516c9a95de9d696734a2f71e3e78f46e1d4678"},
|
||||
{file = "pytest_cov-7.1.0.tar.gz", hash = "sha256:30674f2b5f6351aa09702a9c8c364f6a01c27aae0c1366ae8016160d1efc56b2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2342,30 +2342,30 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.15.0"
|
||||
version = "0.15.7"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "ruff-0.15.0-py3-none-linux_armv6l.whl", hash = "sha256:aac4ebaa612a82b23d45964586f24ae9bc23ca101919f5590bdb368d74ad5455"},
|
||||
{file = "ruff-0.15.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:dcd4be7cc75cfbbca24a98d04d0b9b36a270d0833241f776b788d59f4142b14d"},
|
||||
{file = "ruff-0.15.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d747e3319b2bce179c7c1eaad3d884dc0a199b5f4d5187620530adf9105268ce"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:650bd9c56ae03102c51a5e4b554d74d825ff3abe4db22b90fd32d816c2e90621"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a6664b7eac559e3048223a2da77769c2f92b43a6dfd4720cef42654299a599c9"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f811f97b0f092b35320d1556f3353bf238763420ade5d9e62ebd2b73f2ff179"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:761ec0a66680fab6454236635a39abaf14198818c8cdf691e036f4bc0f406b2d"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:940f11c2604d317e797b289f4f9f3fa5555ffe4fb574b55ed006c3d9b6f0eb78"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcbca3d40558789126da91d7ef9a7c87772ee107033db7191edefa34e2c7f1b4"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9a121a96db1d75fa3eb39c4539e607f628920dd72ff1f7c5ee4f1b768ac62d6e"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5298d518e493061f2eabd4abd067c7e4fb89e2f63291c94332e35631c07c3662"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:afb6e603d6375ff0d6b0cee563fa21ab570fd15e65c852cb24922cef25050cf1"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:77e515f6b15f828b94dc17d2b4ace334c9ddb7d9468c54b2f9ed2b9c1593ef16"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:6f6e80850a01eb13b3e42ee0ebdf6e4497151b48c35051aab51c101266d187a3"},
|
||||
{file = "ruff-0.15.0-py3-none-win32.whl", hash = "sha256:238a717ef803e501b6d51e0bdd0d2c6e8513fe9eec14002445134d3907cd46c3"},
|
||||
{file = "ruff-0.15.0-py3-none-win_amd64.whl", hash = "sha256:dd5e4d3301dc01de614da3cdffc33d4b1b96fb89e45721f1598e5532ccf78b18"},
|
||||
{file = "ruff-0.15.0-py3-none-win_arm64.whl", hash = "sha256:c480d632cc0ca3f0727acac8b7d053542d9e114a462a145d0b00e7cd658c515a"},
|
||||
{file = "ruff-0.15.0.tar.gz", hash = "sha256:6bdea47cdbea30d40f8f8d7d69c0854ba7c15420ec75a26f463290949d7f7e9a"},
|
||||
{file = "ruff-0.15.7-py3-none-linux_armv6l.whl", hash = "sha256:a81cc5b6910fb7dfc7c32d20652e50fa05963f6e13ead3c5915c41ac5d16668e"},
|
||||
{file = "ruff-0.15.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:722d165bd52403f3bdabc0ce9e41fc47070ac56d7a91b4e0d097b516a53a3477"},
|
||||
{file = "ruff-0.15.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7fbc2448094262552146cbe1b9643a92f66559d3761f1ad0656d4991491af49e"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b39329b60eba44156d138275323cc726bbfbddcec3063da57caa8a8b1d50adf"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87768c151808505f2bfc93ae44e5f9e7c8518943e5074f76ac21558ef5627c85"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb0511670002c6c529ec66c0e30641c976c8963de26a113f3a30456b702468b0"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0d19644f801849229db8345180a71bee5407b429dd217f853ec515e968a6912"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4806d8e09ef5e84eb19ba833d0442f7e300b23fe3f0981cae159a248a10f0036"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dce0896488562f09a27b9c91b1f58a097457143931f3c4d519690dea54e624c5"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:1852ce241d2bc89e5dc823e03cff4ce73d816b5c6cdadd27dbfe7b03217d2a12"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5f3e4b221fb4bd293f79912fc5e93a9063ebd6d0dcbd528f91b89172a9b8436c"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b15e48602c9c1d9bdc504b472e90b90c97dc7d46c7028011ae67f3861ceba7b4"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1b4705e0e85cedc74b0a23cf6a179dbb3df184cb227761979cc76c0440b5ab0d"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:112c1fa316a558bb34319282c1200a8bf0495f1b735aeb78bfcb2991e6087580"},
|
||||
{file = "ruff-0.15.7-py3-none-win32.whl", hash = "sha256:6d39e2d3505b082323352f733599f28169d12e891f7dd407f2d4f54b4c2886de"},
|
||||
{file = "ruff-0.15.7-py3-none-win_amd64.whl", hash = "sha256:4d53d712ddebcd7dace1bc395367aec12c057aacfe9adbb6d832302575f4d3a1"},
|
||||
{file = "ruff-0.15.7-py3-none-win_arm64.whl", hash = "sha256:18e8d73f1c3fdf27931497972250340f92e8c861722161a9caeb89a58ead6ed2"},
|
||||
{file = "ruff-0.15.7.tar.gz", hash = "sha256:04f1ae61fc20fe0b148617c324d9d009b5f63412c0b16474f3d5f1a1a665f7ac"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2564,7 +2564,7 @@ description = "A lil' TOML parser"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
|
||||
{file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
|
||||
@@ -2912,4 +2912,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "9619cae908ad38fa2c48016a58bcf4241f6f5793aa0e6cc140276e91c433cbbb"
|
||||
content-hash = "e0936a065565550afed18f6298b7e04e814b44100def7049f1a0d68662624a39"
|
||||
|
||||
@@ -26,8 +26,8 @@ pyright = "^1.1.408"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.3.0"
|
||||
pytest-mock = "^3.15.1"
|
||||
pytest-cov = "^7.0.0"
|
||||
ruff = "^0.15.0"
|
||||
pytest-cov = "^7.1.0"
|
||||
ruff = "^0.15.7"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -37,10 +37,6 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
|
||||
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
|
||||
|
||||
## ===== SIGNUP / INVITE GATE ===== ##
|
||||
# Set to true to require an invite before users can sign up
|
||||
ENABLE_INVITE_GATE=false
|
||||
|
||||
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
|
||||
# Platform URLs (set these for webhooks and OAuth to work)
|
||||
PLATFORM_BASE_URL=http://localhost:8000
|
||||
|
||||
@@ -66,7 +66,7 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
|
||||
- **List comprehensions** over manual loop-and-append
|
||||
- **Early return** — guard clauses first, avoid deep nesting
|
||||
- **Lazy `%s` logging** — `logger.info("Processing %s items", count)` not `logger.info(f"Processing {count} items")`
|
||||
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
|
||||
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
|
||||
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
|
||||
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
|
||||
@@ -75,6 +75,7 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
|
||||
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
|
||||
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
|
||||
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
|
||||
|
||||
## Testing Approach
|
||||
|
||||
@@ -84,6 +85,30 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
- After refactoring, update mock targets to match new module paths
|
||||
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, write the test **before** the implementation:
|
||||
|
||||
```python
|
||||
# 1. Write a failing test marked xfail
|
||||
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
|
||||
# 2. Run it — confirm it fails (XFAIL)
|
||||
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
|
||||
|
||||
# 3. Implement the fix
|
||||
|
||||
# 4. Remove xfail, run again — confirm it passes
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
```
|
||||
|
||||
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
|
||||
|
||||
## Database Schema
|
||||
|
||||
Key models (defined in `schema.prisma`):
|
||||
|
||||
@@ -50,7 +50,7 @@ RUN poetry install --no-ansi --no-root
|
||||
# Generate Prisma client
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
COPY autogpt_platform/backend/scripts/gen_prisma_types_stub.py ./scripts/
|
||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# =============================== DB MIGRATOR =============================== #
|
||||
@@ -82,7 +82,7 @@ RUN pip3 install prisma>=0.15.0 --break-system-packages
|
||||
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
COPY autogpt_platform/backend/scripts/gen_prisma_types_stub.py ./scripts/
|
||||
COPY autogpt_platform/backend/migrations ./migrations
|
||||
|
||||
# ============================== BACKEND SERVER ============================== #
|
||||
@@ -121,19 +121,21 @@ RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
|
||||
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
|
||||
# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
|
||||
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
|
||||
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
|
||||
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
|
||||
fonts-liberation libfontconfig1 \
|
||||
# Install agent-browser (Copilot browser tool) using the system chromium package.
|
||||
# Chrome for Testing (the binary agent-browser downloads via `agent-browser install`)
|
||||
# has no ARM64 builds, so we use the distro-packaged chromium instead — verified to
|
||||
# work with agent-browser via Docker tests on arm64; amd64 is validated in CI.
|
||||
# Note: system chromium tracks the Debian package schedule rather than a pinned
|
||||
# Chrome for Testing release. If agent-browser requires a specific Chrome version,
|
||||
# verify compatibility against the chromium package version in the base image.
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends chromium fonts-liberation \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& npm install -g agent-browser \
|
||||
&& agent-browser install \
|
||||
&& rm -rf /tmp/* /root/.npm
|
||||
|
||||
ENV AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
# Copy only the .venv from builder (not the entire /app directory)
|
||||
|
||||
@@ -1,17 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
import prisma.enums
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
|
||||
|
||||
|
||||
class UserHistoryResponse(BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
@@ -23,70 +14,3 @@ class UserHistoryResponse(BaseModel):
|
||||
class AddUserCreditsResponse(BaseModel):
|
||||
new_balance: int
|
||||
transaction_key: str
|
||||
|
||||
|
||||
class CreateInvitedUserRequest(BaseModel):
|
||||
email: EmailStr
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class InvitedUserResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
status: prisma.enums.InvitedUserStatus
|
||||
auth_user_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tally_understanding: Optional[dict[str, Any]] = None
|
||||
tally_status: prisma.enums.TallyComputationStatus
|
||||
tally_computed_at: Optional[datetime] = None
|
||||
tally_error: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_record(cls, record: InvitedUserRecord) -> InvitedUserResponse:
|
||||
return cls.model_validate(record.model_dump())
|
||||
|
||||
|
||||
class InvitedUsersResponse(BaseModel):
|
||||
invited_users: list[InvitedUserResponse]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class BulkInvitedUserRowResponse(BaseModel):
|
||||
row_number: int
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||
message: str
|
||||
invited_user: Optional[InvitedUserResponse] = None
|
||||
|
||||
|
||||
class BulkInvitedUsersResponse(BaseModel):
|
||||
created_count: int
|
||||
skipped_count: int
|
||||
error_count: int
|
||||
results: list[BulkInvitedUserRowResponse]
|
||||
|
||||
@classmethod
|
||||
def from_result(cls, result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
|
||||
return cls(
|
||||
created_count=result.created_count,
|
||||
skipped_count=result.skipped_count,
|
||||
error_count=result.error_count,
|
||||
results=[
|
||||
BulkInvitedUserRowResponse(
|
||||
row_number=row.row_number,
|
||||
email=row.email,
|
||||
name=row.name,
|
||||
status=row.status,
|
||||
message=row.message,
|
||||
invited_user=(
|
||||
InvitedUserResponse.from_record(row.invited_user)
|
||||
if row.invited_user is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for row in result.results
|
||||
],
|
||||
)
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.graph import get_graph_as_admin
|
||||
|
||||
# Shared constants
|
||||
ADMIN_USER_ID = "admin-user-id"
|
||||
CREATOR_USER_ID = "other-creator-id"
|
||||
GRAPH_ID = "test-graph-id"
|
||||
GRAPH_VERSION = 3
|
||||
|
||||
|
||||
def _make_mock_graph(user_id: str = CREATOR_USER_ID) -> MagicMock:
|
||||
graph = MagicMock()
|
||||
graph.userId = user_id
|
||||
graph.id = GRAPH_ID
|
||||
graph.version = GRAPH_VERSION
|
||||
graph.Nodes = []
|
||||
return graph
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_access_pending_agent_not_owned() -> None:
|
||||
"""Admin must be able to access a graph they don't own even if it's not
|
||||
APPROVED in the marketplace. This is the core use case: reviewing a
|
||||
submitted-but-pending agent from the admin dashboard."""
|
||||
mock_graph = _make_mock_graph()
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.graph.AgentGraph.prisma",
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)
|
||||
|
||||
result = await get_graph_as_admin(
|
||||
graph_id=GRAPH_ID,
|
||||
version=GRAPH_VERSION,
|
||||
user_id=ADMIN_USER_ID,
|
||||
for_export=False,
|
||||
)
|
||||
|
||||
assert (
|
||||
result is not None
|
||||
), "Admin should be able to access a pending agent they don't own"
|
||||
assert result is mock_graph_model
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_download_pending_agent_with_subagents() -> None:
|
||||
"""Admin export (for_export=True) of a pending agent must include
|
||||
sub-graphs. This exercises the full export code path that the Download
|
||||
button uses."""
|
||||
mock_graph = _make_mock_graph()
|
||||
mock_sub_graph = MagicMock(name="SubGraph")
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.graph.AgentGraph.prisma",
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.data.graph.get_sub_graphs",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_sub_graph],
|
||||
) as mock_get_sub,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
) as mock_from_db,
|
||||
):
|
||||
mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)
|
||||
|
||||
result = await get_graph_as_admin(
|
||||
graph_id=GRAPH_ID,
|
||||
version=GRAPH_VERSION,
|
||||
user_id=ADMIN_USER_ID,
|
||||
for_export=True,
|
||||
)
|
||||
|
||||
assert result is not None, "Admin export of pending agent must succeed"
|
||||
mock_get_sub.assert_awaited_once_with(mock_graph)
|
||||
mock_from_db.assert_called_once_with(
|
||||
graph=mock_graph,
|
||||
sub_graphs=[mock_sub_graph],
|
||||
for_export=True,
|
||||
)
|
||||
@@ -1,137 +0,0 @@
|
||||
import logging
|
||||
import math
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, File, Query, Security, UploadFile
|
||||
|
||||
from backend.data.invited_user import (
|
||||
bulk_create_invited_users_from_file,
|
||||
create_invited_user,
|
||||
list_invited_users,
|
||||
retry_invited_user_tally,
|
||||
revoke_invited_user,
|
||||
)
|
||||
from backend.data.tally import mask_email
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .model import (
|
||||
BulkInvitedUsersResponse,
|
||||
CreateInvitedUserRequest,
|
||||
InvitedUserResponse,
|
||||
InvitedUsersResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["users", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/invited-users",
|
||||
response_model=InvitedUsersResponse,
|
||||
summary="List Invited Users",
|
||||
)
|
||||
async def get_invited_users(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
) -> InvitedUsersResponse:
|
||||
logger.info("Admin user %s requested invited users", admin_user_id)
|
||||
invited_users, total = await list_invited_users(page=page, page_size=page_size)
|
||||
return InvitedUsersResponse(
|
||||
invited_users=[InvitedUserResponse.from_record(iu) for iu in invited_users],
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=max(1, math.ceil(total / page_size)),
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Create Invited User",
|
||||
)
|
||||
async def create_invited_user_route(
|
||||
request: CreateInvitedUserRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s creating invited user for %s",
|
||||
admin_user_id,
|
||||
mask_email(request.email),
|
||||
)
|
||||
invited_user = await create_invited_user(request.email, request.name)
|
||||
logger.info(
|
||||
"Admin user %s created invited user %s",
|
||||
admin_user_id,
|
||||
invited_user.id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/bulk",
|
||||
response_model=BulkInvitedUsersResponse,
|
||||
summary="Bulk Create Invited Users",
|
||||
operation_id="postV2BulkCreateInvitedUsers",
|
||||
)
|
||||
async def bulk_create_invited_users_route(
|
||||
file: UploadFile = File(...),
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> BulkInvitedUsersResponse:
|
||||
logger.info(
|
||||
"Admin user %s bulk invited users from %s",
|
||||
admin_user_id,
|
||||
file.filename or "<unnamed>",
|
||||
)
|
||||
content = await file.read()
|
||||
result = await bulk_create_invited_users_from_file(file.filename, content)
|
||||
return BulkInvitedUsersResponse.from_result(result)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/revoke",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Revoke Invited User",
|
||||
)
|
||||
async def revoke_invited_user_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s revoking invited user %s", admin_user_id, invited_user_id
|
||||
)
|
||||
invited_user = await revoke_invited_user(invited_user_id)
|
||||
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/retry-tally",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Retry Invited User Tally",
|
||||
)
|
||||
async def retry_invited_user_tally_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s retrying Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
invited_user = await retry_invited_user_tally(invited_user_id)
|
||||
logger.info(
|
||||
"Admin user %s retried Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
@@ -1,168 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import prisma.enums
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.invited_user import (
|
||||
BulkInvitedUserRowResult,
|
||||
BulkInvitedUsersResult,
|
||||
InvitedUserRecord,
|
||||
)
|
||||
|
||||
from .user_admin_routes import router as user_admin_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(user_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _sample_invited_user() -> InvitedUserRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
return InvitedUserRecord(
|
||||
id="invite-1",
|
||||
email="invited@example.com",
|
||||
status=prisma.enums.InvitedUserStatus.INVITED,
|
||||
auth_user_id=None,
|
||||
name="Invited User",
|
||||
tally_understanding=None,
|
||||
tally_status=prisma.enums.TallyComputationStatus.PENDING,
|
||||
tally_computed_at=None,
|
||||
tally_error=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
|
||||
return BulkInvitedUsersResult(
|
||||
created_count=1,
|
||||
skipped_count=1,
|
||||
error_count=0,
|
||||
results=[
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=1,
|
||||
email="invited@example.com",
|
||||
name=None,
|
||||
status="CREATED",
|
||||
message="Invite created",
|
||||
invited_user=_sample_invited_user(),
|
||||
),
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=2,
|
||||
email="duplicate@example.com",
|
||||
name=None,
|
||||
status="SKIPPED",
|
||||
message="An invited user with this email already exists",
|
||||
invited_user=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_get_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.list_invited_users",
|
||||
AsyncMock(return_value=([_sample_invited_user()], 1)),
|
||||
)
|
||||
|
||||
response = client.get("/admin/invited-users")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["invited_users"]) == 1
|
||||
assert data["invited_users"][0]["email"] == "invited@example.com"
|
||||
assert data["invited_users"][0]["status"] == "INVITED"
|
||||
assert data["pagination"]["total_items"] == 1
|
||||
assert data["pagination"]["current_page"] == 1
|
||||
assert data["pagination"]["page_size"] == 50
|
||||
|
||||
|
||||
def test_create_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.create_invited_user",
|
||||
AsyncMock(return_value=_sample_invited_user()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users",
|
||||
json={"email": "invited@example.com", "name": "Invited User"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == "invited@example.com"
|
||||
assert data["name"] == "Invited User"
|
||||
|
||||
|
||||
def test_bulk_create_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
|
||||
AsyncMock(return_value=_sample_bulk_invited_users_result()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users/bulk",
|
||||
files={
|
||||
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["created_count"] == 1
|
||||
assert data["skipped_count"] == 1
|
||||
assert data["results"][0]["status"] == "CREATED"
|
||||
assert data["results"][1]["status"] == "SKIPPED"
|
||||
|
||||
|
||||
def test_revoke_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
revoked = _sample_invited_user().model_copy(
|
||||
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
|
||||
AsyncMock(return_value=revoked),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/revoke")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "REVOKED"
|
||||
|
||||
|
||||
def test_retry_invited_user_tally(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
retried = _sample_invited_user().model_copy(
|
||||
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
|
||||
AsyncMock(return_value=retried),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/retry-tally")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["tally_status"] == "RUNNING"
|
||||
@@ -4,14 +4,12 @@ from difflib import SequenceMatcher
|
||||
from typing import Any, Sequence, get_args, get_origin
|
||||
|
||||
import prisma
|
||||
from prisma.enums import ContentType
|
||||
from prisma.models import mv_suggested_blocks
|
||||
|
||||
import backend.api.features.library.db as library_db
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks._base import (
|
||||
AnyBlockSchema,
|
||||
@@ -24,6 +22,7 @@ from backend.blocks.llm import LlmModel
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.text import split_camelcase
|
||||
|
||||
from .model import (
|
||||
BlockCategoryResponse,
|
||||
@@ -271,7 +270,7 @@ async def _build_cached_search_results(
|
||||
|
||||
# Use hybrid search when query is present, otherwise list all blocks
|
||||
if (include_blocks or include_integrations) and normalized_query:
|
||||
block_results, block_total, integration_total = await _hybrid_search_blocks(
|
||||
block_results, block_total, integration_total = await _text_search_blocks(
|
||||
query=search_query,
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
@@ -383,117 +382,75 @@ def _collect_block_results(
|
||||
return results, block_count, integration_count
|
||||
|
||||
|
||||
async def _hybrid_search_blocks(
|
||||
async def _text_search_blocks(
|
||||
*,
|
||||
query: str,
|
||||
include_blocks: bool,
|
||||
include_integrations: bool,
|
||||
) -> tuple[list[_ScoredItem], int, int]:
|
||||
"""
|
||||
Search blocks using hybrid search with builder-specific filtering.
|
||||
Search blocks using in-memory text matching over the block registry.
|
||||
|
||||
Uses unified_hybrid_search for semantic + lexical search, then applies
|
||||
post-filtering for block/integration types and scoring adjustments.
|
||||
All blocks are already loaded in memory, so this is fast and reliable
|
||||
regardless of whether OpenAI embeddings are available.
|
||||
|
||||
Scoring:
|
||||
- Base: hybrid relevance score (0-1) scaled to 0-100, plus BLOCK_SCORE_BOOST
|
||||
- Base: text relevance via _score_primary_fields, plus BLOCK_SCORE_BOOST
|
||||
to prioritize blocks over marketplace agents in combined results
|
||||
- +30 for exact name match, +15 for prefix name match
|
||||
- +20 if the block has an LlmModel field and the query matches an LLM model name
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
include_blocks: Whether to include regular blocks
|
||||
include_integrations: Whether to include integration blocks
|
||||
|
||||
Returns:
|
||||
Tuple of (scored_items, block_count, integration_count)
|
||||
"""
|
||||
results: list[_ScoredItem] = []
|
||||
block_count = 0
|
||||
integration_count = 0
|
||||
|
||||
if not include_blocks and not include_integrations:
|
||||
return results, block_count, integration_count
|
||||
return results, 0, 0
|
||||
|
||||
normalized_query = query.strip().lower()
|
||||
|
||||
# Fetch more results to account for post-filtering
|
||||
search_results, _ = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=150,
|
||||
min_score=0.10,
|
||||
all_results, _, _ = _collect_block_results(
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
)
|
||||
|
||||
# Load all blocks for getting BlockInfo
|
||||
all_blocks = load_all_blocks()
|
||||
|
||||
for result in search_results:
|
||||
block_id = result["content_id"]
|
||||
for item in all_results:
|
||||
block_info = item.item
|
||||
assert isinstance(block_info, BlockInfo)
|
||||
name = split_camelcase(block_info.name).lower()
|
||||
|
||||
# Skip excluded blocks
|
||||
if block_id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
# Build rich description including input field descriptions,
|
||||
# matching the searchable text that the embedding pipeline uses
|
||||
desc_parts = [block_info.description or ""]
|
||||
block_cls = all_blocks.get(block_info.id)
|
||||
if block_cls is not None:
|
||||
block: AnyBlockSchema = block_cls()
|
||||
desc_parts += [
|
||||
f"{f}: {info.description}"
|
||||
for f, info in block.input_schema.model_fields.items()
|
||||
if info.description
|
||||
]
|
||||
description = " ".join(desc_parts).lower()
|
||||
|
||||
metadata = result.get("metadata", {})
|
||||
hybrid_score = result.get("relevance", 0.0)
|
||||
|
||||
# Get the actual block class
|
||||
if block_id not in all_blocks:
|
||||
continue
|
||||
|
||||
block_cls = all_blocks[block_id]
|
||||
block: AnyBlockSchema = block_cls()
|
||||
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
# Check block/integration filter using metadata
|
||||
is_integration = metadata.get("is_integration", False)
|
||||
|
||||
if is_integration and not include_integrations:
|
||||
continue
|
||||
if not is_integration and not include_blocks:
|
||||
continue
|
||||
|
||||
# Get block info
|
||||
block_info = block.get_info()
|
||||
|
||||
# Calculate final score: scale hybrid score and add builder-specific bonuses
|
||||
# Hybrid scores are 0-1, builder scores were 0-200+
|
||||
# Add BLOCK_SCORE_BOOST to prioritize blocks over marketplace agents
|
||||
final_score = hybrid_score * 100 + BLOCK_SCORE_BOOST
|
||||
score = _score_primary_fields(name, description, normalized_query)
|
||||
|
||||
# Add LLM model match bonus
|
||||
has_llm_field = metadata.get("has_llm_model_field", False)
|
||||
if has_llm_field and _matches_llm_model(block.input_schema, normalized_query):
|
||||
final_score += 20
|
||||
if block_cls is not None and _matches_llm_model(
|
||||
block_cls().input_schema, normalized_query
|
||||
):
|
||||
score += 20
|
||||
|
||||
# Add exact/prefix match bonus for deterministic tie-breaking
|
||||
name = block_info.name.lower()
|
||||
if name == normalized_query:
|
||||
final_score += 30
|
||||
elif name.startswith(normalized_query):
|
||||
final_score += 15
|
||||
|
||||
# Track counts
|
||||
filter_type: FilterType = "integrations" if is_integration else "blocks"
|
||||
if is_integration:
|
||||
integration_count += 1
|
||||
else:
|
||||
block_count += 1
|
||||
|
||||
results.append(
|
||||
_ScoredItem(
|
||||
item=block_info,
|
||||
filter_type=filter_type,
|
||||
score=final_score,
|
||||
sort_key=name,
|
||||
if score >= MIN_SCORE_FOR_FILTERED_RESULTS:
|
||||
results.append(
|
||||
_ScoredItem(
|
||||
item=block_info,
|
||||
filter_type=item.filter_type,
|
||||
score=score + BLOCK_SCORE_BOOST,
|
||||
sort_key=name,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
block_count = sum(1 for r in results if r.filter_type == "blocks")
|
||||
integration_count = sum(1 for r in results if r.filter_type == "integrations")
|
||||
return results, block_count, integration_count
|
||||
|
||||
|
||||
|
||||
@@ -60,7 +60,6 @@ from backend.copilot.tools.models import (
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import get_business_understanding
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
@@ -895,36 +894,6 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Suggested Prompts ==========
|
||||
|
||||
|
||||
class SuggestedPromptsResponse(BaseModel):
|
||||
"""Response model for user-specific suggested prompts."""
|
||||
|
||||
prompts: list[str]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/suggested-prompts",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def get_suggested_prompts(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> SuggestedPromptsResponse:
|
||||
"""
|
||||
Get LLM-generated suggested prompts for the authenticated user.
|
||||
|
||||
Returns personalized quick-action prompts based on the user's
|
||||
business understanding. Returns an empty list if no custom prompts
|
||||
are available.
|
||||
"""
|
||||
understanding = await get_business_understanding(user_id)
|
||||
if understanding is None:
|
||||
return SuggestedPromptsResponse(prompts=[])
|
||||
|
||||
return SuggestedPromptsResponse(prompts=understanding.suggested_prompts)
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Tests for chat API routes: session title update, file attachment validation, usage, rate limiting, and suggested prompts."""
|
||||
"""Tests for chat API routes: session title update, file attachment validation, usage, and rate limiting."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
@@ -400,62 +400,3 @@ def test_usage_rejects_unauthenticated_request() -> None:
|
||||
response = unauthenticated_client.get("/usage")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_get_business_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
return_value=None,
|
||||
):
|
||||
"""Mock get_business_understanding."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=return_value,
|
||||
)
|
||||
|
||||
|
||||
def test_suggested_prompts_returns_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding and prompts gets them back."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = ["Do X", "Do Y", "Do Z"]
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": ["Do X", "Do Y", "Do Z"]}
|
||||
|
||||
|
||||
def test_suggested_prompts_no_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with no understanding gets empty list."""
|
||||
_mock_get_business_understanding(mocker, return_value=None)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": []}
|
||||
|
||||
|
||||
def test_suggested_prompts_empty_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding but no prompts gets empty list."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = []
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": []}
|
||||
|
||||
@@ -9,7 +9,7 @@ import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.data.db import query_raw_with_schema, transaction
|
||||
from backend.data.graph import (
|
||||
GraphModel,
|
||||
GraphModelWithoutNodes,
|
||||
@@ -104,7 +104,8 @@ async def get_store_agents(
|
||||
# search_used_hybrid remains False, will use fallback path below
|
||||
|
||||
# Convert hybrid search results (dict format) if hybrid succeeded
|
||||
if search_used_hybrid:
|
||||
# Fall through to direct DB search if hybrid returned nothing
|
||||
if search_used_hybrid and agents:
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
@@ -130,52 +131,20 @@ async def get_store_agents(
|
||||
)
|
||||
continue
|
||||
|
||||
if not search_used_hybrid:
|
||||
# Fallback path - use basic search or no search
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
if creators:
|
||||
where_clause["creator_username"] = {"in": creators}
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
# Add basic text search if search_query provided but hybrid failed
|
||||
if search_query:
|
||||
where_clause["OR"] = [
|
||||
{"agent_name": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"sub_heading": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"description": {"contains": search_query, "mode": "insensitive"}},
|
||||
]
|
||||
|
||||
order_by = []
|
||||
if sorted_by == StoreAgentsSortOptions.RATING:
|
||||
order_by.append({"rating": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.RUNS:
|
||||
order_by.append({"runs": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.NAME:
|
||||
order_by.append({"agent_name": "asc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.UPDATED_AT:
|
||||
order_by.append({"updated_at": "desc"})
|
||||
|
||||
db_agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
if not search_used_hybrid or not agents:
|
||||
# Fallback path: direct DB query with optional tsvector search.
|
||||
# This mirrors the original pre-hybrid-search implementation.
|
||||
store_agents, total = await _fallback_store_agent_search(
|
||||
search_query=search_query,
|
||||
featured=featured,
|
||||
creators=creators,
|
||||
category=category,
|
||||
sorted_by=sorted_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in db_agents:
|
||||
try:
|
||||
store_agents.append(store_model.StoreAgent.from_db(agent))
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing StoreAgent from db: {e}")
|
||||
continue
|
||||
|
||||
logger.debug(f"Found {len(store_agents)} agents")
|
||||
return store_model.StoreAgentsResponse(
|
||||
agents=store_agents,
|
||||
@@ -195,6 +164,126 @@ async def get_store_agents(
|
||||
# await log_search_term(search_query=search_term)
|
||||
|
||||
|
||||
async def _fallback_store_agent_search(
|
||||
*,
|
||||
search_query: str | None,
|
||||
featured: bool,
|
||||
creators: list[str] | None,
|
||||
category: str | None,
|
||||
sorted_by: StoreAgentsSortOptions | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
) -> tuple[list[store_model.StoreAgent], int]:
|
||||
"""Direct DB search fallback when hybrid search is unavailable or empty.
|
||||
|
||||
Uses ad-hoc to_tsvector/plainto_tsquery with ts_rank_cd for text search,
|
||||
matching the quality of the original pre-hybrid-search implementation.
|
||||
Falls back to simple listing when no search query is provided.
|
||||
"""
|
||||
if not search_query:
|
||||
# No search query — use Prisma for simple filtered listing
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
if creators:
|
||||
where_clause["creator_username"] = {"in": creators}
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
order_by = []
|
||||
if sorted_by == StoreAgentsSortOptions.RATING:
|
||||
order_by.append({"rating": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.RUNS:
|
||||
order_by.append({"runs": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.NAME:
|
||||
order_by.append({"agent_name": "asc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.UPDATED_AT:
|
||||
order_by.append({"updated_at": "desc"})
|
||||
|
||||
db_agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
||||
return [store_model.StoreAgent.from_db(a) for a in db_agents], total
|
||||
|
||||
# Text search using ad-hoc tsvector on StoreAgent view fields
|
||||
params: list[Any] = [search_query]
|
||||
filters = ["sa.is_available = true"]
|
||||
param_idx = 2
|
||||
|
||||
if featured:
|
||||
filters.append("sa.featured = true")
|
||||
if creators:
|
||||
params.append(creators)
|
||||
filters.append(f"sa.creator_username = ANY(${param_idx})")
|
||||
param_idx += 1
|
||||
if category:
|
||||
params.append(category)
|
||||
filters.append(f"${param_idx} = ANY(sa.categories)")
|
||||
param_idx += 1
|
||||
|
||||
where_sql = " AND ".join(filters)
|
||||
|
||||
params.extend([page_size, (page - 1) * page_size])
|
||||
limit_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
offset_param = f"${param_idx}"
|
||||
|
||||
sql = f"""
|
||||
WITH ranked AS (
|
||||
SELECT sa.*,
|
||||
ts_rank_cd(
|
||||
to_tsvector('english',
|
||||
COALESCE(sa.agent_name, '') || ' ' ||
|
||||
COALESCE(sa.sub_heading, '') || ' ' ||
|
||||
COALESCE(sa.description, '')
|
||||
),
|
||||
plainto_tsquery('english', $1)
|
||||
) AS rank,
|
||||
COUNT(*) OVER () AS total_count
|
||||
FROM {{schema_prefix}}"StoreAgent" sa
|
||||
WHERE {where_sql}
|
||||
AND to_tsvector('english',
|
||||
COALESCE(sa.agent_name, '') || ' ' ||
|
||||
COALESCE(sa.sub_heading, '') || ' ' ||
|
||||
COALESCE(sa.description, '')
|
||||
) @@ plainto_tsquery('english', $1)
|
||||
)
|
||||
SELECT * FROM ranked
|
||||
ORDER BY rank DESC
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
results = await query_raw_with_schema(sql, *params)
|
||||
total = results[0]["total_count"] if results else 0
|
||||
|
||||
store_agents = []
|
||||
for row in results:
|
||||
try:
|
||||
store_agents.append(
|
||||
store_model.StoreAgent(
|
||||
slug=row["slug"],
|
||||
agent_name=row["agent_name"],
|
||||
agent_image=row["agent_image"][0] if row["agent_image"] else "",
|
||||
creator=row["creator_username"] or "Needs Profile",
|
||||
creator_avatar=row["creator_avatar"] or "",
|
||||
sub_heading=row["sub_heading"],
|
||||
description=row["description"],
|
||||
runs=row["runs"],
|
||||
rating=row["rating"],
|
||||
agent_graph_id=row.get("graph_id", ""),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing StoreAgent from fallback search: {e}")
|
||||
continue
|
||||
|
||||
return store_agents, total
|
||||
|
||||
|
||||
async def log_search_term(search_query: str):
|
||||
"""Log a search term to the database"""
|
||||
|
||||
@@ -1139,16 +1228,21 @@ async def review_store_submission(
|
||||
},
|
||||
)
|
||||
|
||||
# Generate embedding for approved listing (blocking - admin operation)
|
||||
# Inside transaction: if embedding fails, entire transaction rolls back
|
||||
await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=submission.name,
|
||||
description=submission.description,
|
||||
sub_heading=submission.subHeading,
|
||||
categories=submission.categories,
|
||||
tx=tx,
|
||||
)
|
||||
# Generate embedding for approved listing (best-effort)
|
||||
try:
|
||||
await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=submission.name,
|
||||
description=submission.description,
|
||||
sub_heading=submission.subHeading,
|
||||
categories=submission.categories,
|
||||
tx=tx,
|
||||
)
|
||||
except Exception as emb_err:
|
||||
logger.warning(
|
||||
f"Could not generate embedding for listing "
|
||||
f"{store_listing_version_id}: {emb_err}"
|
||||
)
|
||||
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": submission.storeListingId},
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import tempfile
|
||||
import urllib.parse
|
||||
|
||||
import autogpt_libs.auth
|
||||
@@ -259,21 +258,18 @@ async def get_graph_meta_by_store_listing_version_id(
|
||||
)
|
||||
async def download_agent_file(
|
||||
store_listing_version_id: str,
|
||||
) -> fastapi.responses.FileResponse:
|
||||
) -> fastapi.responses.Response:
|
||||
"""Download agent graph file for a specific marketplace listing version"""
|
||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
return fastapi.responses.Response(
|
||||
content=backend.util.json.dumps(graph_data),
|
||||
media_type="application/json",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{file_name}"',
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
|
||||
@@ -55,7 +55,6 @@ from backend.data.credit import (
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.graph import GraphSettings
|
||||
from backend.data.invited_user import get_or_activate_user
|
||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
@@ -71,6 +70,7 @@ from backend.data.onboarding import (
|
||||
update_user_onboarding,
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_or_create_user,
|
||||
get_user_by_id,
|
||||
get_user_notification_preference,
|
||||
update_user_email,
|
||||
@@ -136,10 +136,12 @@ _tally_background_tasks: set[asyncio.Task] = set()
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
user = await get_or_activate_user(user_data)
|
||||
user = await get_or_create_user(user_data)
|
||||
|
||||
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
|
||||
# not produce a stored result before first activation.
|
||||
# Fire-and-forget: populate business understanding from Tally form.
|
||||
# We use created_at proximity instead of an is_new flag because
|
||||
# get_or_create_user is cached — a separate is_new return value would be
|
||||
# unreliable on repeated calls within the cache TTL.
|
||||
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
|
||||
if age_seconds < 30:
|
||||
try:
|
||||
@@ -163,8 +165,7 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_email_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
email: str = Body(...),
|
||||
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
|
||||
) -> dict[str, str]:
|
||||
await update_user_email(user_id, email)
|
||||
|
||||
@@ -178,16 +179,10 @@ async def update_user_email_route(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
) -> TimezoneResponse:
|
||||
"""Get user timezone setting."""
|
||||
try:
|
||||
user = await get_user_by_id(user_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
detail="User not found. Please complete activation via /auth/user first.",
|
||||
)
|
||||
user = await get_or_create_user(user_data)
|
||||
return TimezoneResponse(timezone=user.timezone)
|
||||
|
||||
|
||||
@@ -198,8 +193,7 @@ async def get_user_timezone_route(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
request: UpdateTimezoneRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
|
||||
) -> TimezoneResponse:
|
||||
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
|
||||
user = await update_user_timezone(user_id, str(request.timezone))
|
||||
@@ -598,6 +592,11 @@ async def fulfill_checkout(user_id: Annotated[str, Security(get_user_id)]):
|
||||
async def configure_user_auto_top_up(
|
||||
request: AutoTopUpConfig, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> str:
|
||||
"""Configure auto top-up settings and perform an immediate top-up if needed.
|
||||
|
||||
Raises HTTPException(422) if the request parameters are invalid or if
|
||||
the credit top-up fails.
|
||||
"""
|
||||
if request.threshold < 0:
|
||||
raise HTTPException(status_code=422, detail="Threshold must be greater than 0")
|
||||
if request.amount < 500 and request.amount != 0:
|
||||
@@ -612,10 +611,20 @@ async def configure_user_auto_top_up(
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
|
||||
if current_balance < request.threshold:
|
||||
await user_credit_model.top_up_credits(user_id, request.amount)
|
||||
else:
|
||||
await user_credit_model.top_up_credits(user_id, 0)
|
||||
try:
|
||||
if current_balance < request.threshold:
|
||||
await user_credit_model.top_up_credits(user_id, request.amount)
|
||||
else:
|
||||
await user_credit_model.top_up_credits(user_id, 0)
|
||||
except ValueError as e:
|
||||
known_messages = (
|
||||
"must not be negative",
|
||||
"already exists for user",
|
||||
"No payment method found",
|
||||
)
|
||||
if any(msg in str(e) for msg in known_messages):
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
raise
|
||||
|
||||
await set_auto_top_up(
|
||||
user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
|
||||
|
||||
@@ -51,7 +51,7 @@ def test_get_or_create_user_route(
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_or_activate_user",
|
||||
"backend.api.features.v1.get_or_create_user",
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
|
||||
@@ -188,6 +188,7 @@ async def upload_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
file: UploadFile,
|
||||
session_id: str | None = Query(default=None),
|
||||
overwrite: bool = Query(default=False),
|
||||
) -> UploadFileResponse:
|
||||
"""
|
||||
Upload a file to the user's workspace.
|
||||
@@ -248,7 +249,9 @@ async def upload_file(
|
||||
# Write file via WorkspaceManager
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
try:
|
||||
workspace_file = await manager.write_file(content, filename)
|
||||
workspace_file = await manager.write_file(
|
||||
content, filename, overwrite=overwrite
|
||||
)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from prisma.errors import PrismaError
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.admin.user_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
import backend.api.features.chat.routes as chat_routes
|
||||
@@ -211,13 +210,22 @@ instrument_fastapi(
|
||||
def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
|
||||
def handler(request: fastapi.Request, exc: Exception):
|
||||
if log_error:
|
||||
logger.exception(
|
||||
"%s %s failed. Investigate and resolve the underlying issue: %s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
if status_code >= 500:
|
||||
logger.exception(
|
||||
"%s %s failed. Investigate and resolve the underlying issue: %s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"%s %s failed with %d: %s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
status_code,
|
||||
exc,
|
||||
)
|
||||
|
||||
hint = (
|
||||
"Adjust the request and retry."
|
||||
@@ -267,12 +275,10 @@ async def validation_error_handler(
|
||||
|
||||
|
||||
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
|
||||
app.add_exception_handler(
|
||||
FolderAlreadyExistsError, handle_internal_http_error(409, False)
|
||||
)
|
||||
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400, False))
|
||||
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
|
||||
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
|
||||
app.add_exception_handler(FolderAlreadyExistsError, handle_internal_http_error(409))
|
||||
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400))
|
||||
app.add_exception_handler(NotFoundError, handle_internal_http_error(404))
|
||||
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403))
|
||||
app.add_exception_handler(RequestValidationError, validation_error_handler)
|
||||
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
|
||||
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
|
||||
@@ -312,11 +318,6 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/executions",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.user_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/users",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
Shared configuration for all AgentMail blocks.
|
||||
"""
|
||||
|
||||
from agentmail import AsyncAgentMail
|
||||
|
||||
from backend.sdk import APIKeyCredentials, ProviderBuilder, SecretStr
|
||||
|
||||
agent_mail = (
|
||||
ProviderBuilder("agent_mail")
|
||||
.with_api_key("AGENTMAIL_API_KEY", "AgentMail API Key")
|
||||
.build()
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="agent_mail",
|
||||
title="Mock AgentMail API Key",
|
||||
api_key=SecretStr("mock-agentmail-api-key"),
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
def _client(credentials: APIKeyCredentials) -> AsyncAgentMail:
|
||||
"""Create an AsyncAgentMail client from credentials."""
|
||||
return AsyncAgentMail(api_key=credentials.api_key.get_secret_value())
|
||||
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
AgentMail Attachment blocks — download file attachments from messages and threads.
|
||||
|
||||
Attachments are files associated with messages (PDFs, CSVs, images, etc.).
|
||||
To send attachments, include them in the attachments parameter when using
|
||||
AgentMailSendMessageBlock or AgentMailReplyToMessageBlock.
|
||||
|
||||
To download, first get the attachment_id from a message's attachments array,
|
||||
then use these blocks to retrieve the file content as base64.
|
||||
"""
|
||||
|
||||
import base64
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailGetMessageAttachmentBlock(Block):
|
||||
"""
|
||||
Download a file attachment from a specific email message.
|
||||
|
||||
Retrieves the raw file content and returns it as base64-encoded data.
|
||||
First get the attachment_id from a message object's attachments array,
|
||||
then use this block to download the file.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the message belongs to"
|
||||
)
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID containing the attachment"
|
||||
)
|
||||
attachment_id: str = SchemaField(
|
||||
description="Attachment ID to download (from the message's attachments array)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
content_base64: str = SchemaField(
|
||||
description="File content encoded as a base64 string. Decode with base64.b64decode() to get raw bytes."
|
||||
)
|
||||
attachment_id: str = SchemaField(
|
||||
description="The attachment ID that was downloaded"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a283ffc4-8087-4c3d-9135-8f26b86742ec",
|
||||
description="Download a file attachment from an email message. Returns base64-encoded file content.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"attachment_id": "test-attach",
|
||||
},
|
||||
test_output=[
|
||||
("content_base64", "dGVzdA=="),
|
||||
("attachment_id", "test-attach"),
|
||||
],
|
||||
test_mock={
|
||||
"get_attachment": lambda *a, **kw: b"test",
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_attachment(
|
||||
credentials: APIKeyCredentials,
|
||||
inbox_id: str,
|
||||
message_id: str,
|
||||
attachment_id: str,
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.get_attachment(
|
||||
inbox_id=inbox_id,
|
||||
message_id=message_id,
|
||||
attachment_id=attachment_id,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.get_attachment(
|
||||
credentials=credentials,
|
||||
inbox_id=input_data.inbox_id,
|
||||
message_id=input_data.message_id,
|
||||
attachment_id=input_data.attachment_id,
|
||||
)
|
||||
if isinstance(data, bytes):
|
||||
encoded = base64.b64encode(data).decode()
|
||||
elif isinstance(data, str):
|
||||
encoded = base64.b64encode(data.encode("utf-8")).decode()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unexpected attachment data type: {type(data).__name__}"
|
||||
)
|
||||
|
||||
yield "content_base64", encoded
|
||||
yield "attachment_id", input_data.attachment_id
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetThreadAttachmentBlock(Block):
|
||||
"""
|
||||
Download a file attachment from a conversation thread.
|
||||
|
||||
Same as GetMessageAttachment but looks up by thread ID instead of
|
||||
message ID. Useful when you know the thread but not the specific
|
||||
message containing the attachment.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the thread belongs to"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID containing the attachment")
|
||||
attachment_id: str = SchemaField(
|
||||
description="Attachment ID to download (from a message's attachments array within the thread)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
content_base64: str = SchemaField(
|
||||
description="File content encoded as a base64 string. Decode with base64.b64decode() to get raw bytes."
|
||||
)
|
||||
attachment_id: str = SchemaField(
|
||||
description="The attachment ID that was downloaded"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="06b6a4c4-9d71-4992-9e9c-cf3b352763b5",
|
||||
description="Download a file attachment from a conversation thread. Returns base64-encoded file content.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"thread_id": "test-thread",
|
||||
"attachment_id": "test-attach",
|
||||
},
|
||||
test_output=[
|
||||
("content_base64", "dGVzdA=="),
|
||||
("attachment_id", "test-attach"),
|
||||
],
|
||||
test_mock={
|
||||
"get_attachment": lambda *a, **kw: b"test",
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_attachment(
|
||||
credentials: APIKeyCredentials,
|
||||
inbox_id: str,
|
||||
thread_id: str,
|
||||
attachment_id: str,
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.threads.get_attachment(
|
||||
inbox_id=inbox_id,
|
||||
thread_id=thread_id,
|
||||
attachment_id=attachment_id,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.get_attachment(
|
||||
credentials=credentials,
|
||||
inbox_id=input_data.inbox_id,
|
||||
thread_id=input_data.thread_id,
|
||||
attachment_id=input_data.attachment_id,
|
||||
)
|
||||
if isinstance(data, bytes):
|
||||
encoded = base64.b64encode(data).decode()
|
||||
elif isinstance(data, str):
|
||||
encoded = base64.b64encode(data.encode("utf-8")).decode()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unexpected attachment data type: {type(data).__name__}"
|
||||
)
|
||||
|
||||
yield "content_base64", encoded
|
||||
yield "attachment_id", input_data.attachment_id
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
678
autogpt_platform/backend/backend/blocks/agent_mail/drafts.py
Normal file
678
autogpt_platform/backend/backend/blocks/agent_mail/drafts.py
Normal file
@@ -0,0 +1,678 @@
|
||||
"""
|
||||
AgentMail Draft blocks — create, get, list, update, send, and delete drafts.
|
||||
|
||||
A Draft is an unsent message that can be reviewed, edited, and sent later.
|
||||
Drafts enable human-in-the-loop review, scheduled sending (via send_at),
|
||||
and complex multi-step email composition workflows.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailCreateDraftBlock(Block):
|
||||
"""
|
||||
Create a draft email in an AgentMail inbox for review or scheduled sending.
|
||||
|
||||
Drafts let agents prepare emails without sending immediately. Use send_at
|
||||
to schedule automatic sending at a future time (ISO 8601 format).
|
||||
Scheduled drafts are auto-labeled 'scheduled' and can be cancelled by
|
||||
deleting the draft.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to create the draft in"
|
||||
)
|
||||
to: list[str] = SchemaField(
|
||||
description="Recipient email addresses (e.g. ['user@example.com'])"
|
||||
)
|
||||
subject: str = SchemaField(description="Email subject line", default="")
|
||||
text: str = SchemaField(description="Plain text body of the draft", default="")
|
||||
html: str = SchemaField(
|
||||
description="Rich HTML body of the draft", default="", advanced=True
|
||||
)
|
||||
cc: list[str] = SchemaField(
|
||||
description="CC recipient email addresses",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
bcc: list[str] = SchemaField(
|
||||
description="BCC recipient email addresses",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
in_reply_to: str = SchemaField(
|
||||
description="Message ID this draft replies to, for threading follow-up drafts",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
send_at: str = SchemaField(
|
||||
description="Schedule automatic sending at this ISO 8601 datetime (e.g. '2025-01-15T09:00:00Z'). Leave empty for manual send.",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
draft_id: str = SchemaField(
|
||||
description="Unique identifier of the created draft"
|
||||
)
|
||||
send_status: str = SchemaField(
|
||||
description="'scheduled' if send_at was set, empty otherwise. Values: scheduled, sending, failed.",
|
||||
default="",
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete draft object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="25ac9086-69fd-48b8-b910-9dbe04b8f3bd",
|
||||
description="Create a draft email for review or scheduled sending. Use send_at for automatic future delivery.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"to": ["user@example.com"],
|
||||
},
|
||||
test_output=[
|
||||
("draft_id", "mock-draft-id"),
|
||||
("send_status", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_draft": lambda *a, **kw: type(
|
||||
"Draft",
|
||||
(),
|
||||
{
|
||||
"draft_id": "mock-draft-id",
|
||||
"send_status": "",
|
||||
"model_dump": lambda self: {"draft_id": "mock-draft-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_draft(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.create(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"to": input_data.to}
|
||||
if input_data.subject:
|
||||
params["subject"] = input_data.subject
|
||||
if input_data.text:
|
||||
params["text"] = input_data.text
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
if input_data.cc:
|
||||
params["cc"] = input_data.cc
|
||||
if input_data.bcc:
|
||||
params["bcc"] = input_data.bcc
|
||||
if input_data.in_reply_to:
|
||||
params["in_reply_to"] = input_data.in_reply_to
|
||||
if input_data.send_at:
|
||||
params["send_at"] = input_data.send_at
|
||||
|
||||
draft = await self.create_draft(credentials, input_data.inbox_id, **params)
|
||||
result = draft.model_dump()
|
||||
|
||||
yield "draft_id", draft.draft_id
|
||||
yield "send_status", draft.send_status or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetDraftBlock(Block):
|
||||
"""
|
||||
Retrieve a specific draft from an AgentMail inbox.
|
||||
|
||||
Returns the draft contents including recipients, subject, body, and
|
||||
scheduled send status. Use this to review a draft before approving it.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(description="Draft ID to retrieve")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
draft_id: str = SchemaField(description="Unique identifier of the draft")
|
||||
subject: str = SchemaField(description="Draft subject line", default="")
|
||||
send_status: str = SchemaField(
|
||||
description="Scheduled send status: 'scheduled', 'sending', 'failed', or empty",
|
||||
default="",
|
||||
)
|
||||
send_at: str = SchemaField(
|
||||
description="Scheduled send time (ISO 8601) if set", default=""
|
||||
)
|
||||
result: dict = SchemaField(description="Complete draft object with all fields")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8e57780d-dc25-43d4-a0f4-1f02877b09fb",
|
||||
description="Retrieve a draft email to review its contents, recipients, and scheduled send status.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[
|
||||
("draft_id", "test-draft"),
|
||||
("subject", ""),
|
||||
("send_status", ""),
|
||||
("send_at", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_draft": lambda *a, **kw: type(
|
||||
"Draft",
|
||||
(),
|
||||
{
|
||||
"draft_id": "test-draft",
|
||||
"subject": "",
|
||||
"send_status": "",
|
||||
"send_at": "",
|
||||
"model_dump": lambda self: {"draft_id": "test-draft"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_draft(credentials: APIKeyCredentials, inbox_id: str, draft_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.get(inbox_id=inbox_id, draft_id=draft_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
draft = await self.get_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id
|
||||
)
|
||||
result = draft.model_dump()
|
||||
|
||||
yield "draft_id", draft.draft_id
|
||||
yield "subject", draft.subject or ""
|
||||
yield "send_status", draft.send_status or ""
|
||||
yield "send_at", draft.send_at or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListDraftsBlock(Block):
|
||||
"""
|
||||
List all drafts in an AgentMail inbox with optional label filtering.
|
||||
|
||||
Use labels=['scheduled'] to find all drafts queued for future sending.
|
||||
Useful for building approval dashboards or monitoring pending outreach.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to list drafts from"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of drafts to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Filter drafts by labels (e.g. ['scheduled'] for pending sends)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
drafts: list[dict] = SchemaField(
|
||||
description="List of draft objects with subject, recipients, send_status, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of drafts returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e84883b7-7c39-4c5c-88e8-0a72b078ea63",
|
||||
description="List drafts in an AgentMail inbox. Filter by labels=['scheduled'] to find pending sends.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("drafts", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_drafts": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"drafts": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_drafts(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.list(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_drafts(
|
||||
credentials, input_data.inbox_id, **params
|
||||
)
|
||||
drafts = [d.model_dump() for d in response.drafts]
|
||||
|
||||
yield "drafts", drafts
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailUpdateDraftBlock(Block):
|
||||
"""
|
||||
Update an existing draft's content, recipients, or scheduled send time.
|
||||
|
||||
Use this to reschedule a draft (change send_at), modify recipients,
|
||||
or edit the subject/body before sending. To cancel a scheduled send,
|
||||
delete the draft instead.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(description="Draft ID to update")
|
||||
to: Optional[list[str]] = SchemaField(
|
||||
description="Updated recipient email addresses (replaces existing list). Omit to keep current value.",
|
||||
default=None,
|
||||
)
|
||||
subject: Optional[str] = SchemaField(
|
||||
description="Updated subject line. Omit to keep current value.",
|
||||
default=None,
|
||||
)
|
||||
text: Optional[str] = SchemaField(
|
||||
description="Updated plain text body. Omit to keep current value.",
|
||||
default=None,
|
||||
)
|
||||
html: Optional[str] = SchemaField(
|
||||
description="Updated HTML body. Omit to keep current value.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
send_at: Optional[str] = SchemaField(
|
||||
description="Reschedule: new ISO 8601 send time (e.g. '2025-01-20T14:00:00Z'). Omit to keep current value.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
draft_id: str = SchemaField(description="The updated draft ID")
|
||||
send_status: str = SchemaField(description="Updated send status", default="")
|
||||
result: dict = SchemaField(description="Complete updated draft object")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="351f6e51-695a-421a-9032-46a587b10336",
|
||||
description="Update a draft's content, recipients, or scheduled send time. Use to reschedule or edit before sending.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[
|
||||
("draft_id", "test-draft"),
|
||||
("send_status", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"update_draft": lambda *a, **kw: type(
|
||||
"Draft",
|
||||
(),
|
||||
{
|
||||
"draft_id": "test-draft",
|
||||
"send_status": "",
|
||||
"model_dump": lambda self: {"draft_id": "test-draft"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_draft(
|
||||
credentials: APIKeyCredentials, inbox_id: str, draft_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.update(
|
||||
inbox_id=inbox_id, draft_id=draft_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.to is not None:
|
||||
params["to"] = input_data.to
|
||||
if input_data.subject is not None:
|
||||
params["subject"] = input_data.subject
|
||||
if input_data.text is not None:
|
||||
params["text"] = input_data.text
|
||||
if input_data.html is not None:
|
||||
params["html"] = input_data.html
|
||||
if input_data.send_at is not None:
|
||||
params["send_at"] = input_data.send_at
|
||||
|
||||
draft = await self.update_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id, **params
|
||||
)
|
||||
result = draft.model_dump()
|
||||
|
||||
yield "draft_id", draft.draft_id
|
||||
yield "send_status", draft.send_status or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailSendDraftBlock(Block):
|
||||
"""
|
||||
Send a draft immediately, converting it into a delivered message.
|
||||
|
||||
The draft is deleted after successful sending and becomes a regular
|
||||
message with a message_id. Use this for human-in-the-loop approval
|
||||
workflows: agent creates draft, human reviews, then this block sends it.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(description="Draft ID to send now")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID of the now-sent email (draft is deleted)"
|
||||
)
|
||||
thread_id: str = SchemaField(
|
||||
description="Thread ID the sent message belongs to"
|
||||
)
|
||||
result: dict = SchemaField(description="Complete sent message object")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="37c39e83-475d-4b3d-843a-d923d001b85a",
|
||||
description="Send a draft immediately, converting it into a delivered message. The draft is deleted after sending.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-msg-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"send_draft": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-msg-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {"message_id": "mock-msg-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def send_draft(credentials: APIKeyCredentials, inbox_id: str, draft_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.send(inbox_id=inbox_id, draft_id=draft_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
msg = await self.send_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id
|
||||
)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "thread_id", msg.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteDraftBlock(Block):
|
||||
"""
|
||||
Delete a draft from an AgentMail inbox. Also cancels any scheduled send.
|
||||
|
||||
If the draft was scheduled with send_at, deleting it cancels the
|
||||
scheduled delivery. This is the way to cancel a scheduled email.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(
|
||||
description="Draft ID to delete (also cancels scheduled sends)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the draft was successfully deleted/cancelled"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9023eb99-3e2f-4def-808b-d9c584b3d9e7",
|
||||
description="Delete a draft or cancel a scheduled email. Removes the draft permanently.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_draft": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_draft(
|
||||
credentials: APIKeyCredentials, inbox_id: str, draft_id: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
await client.inboxes.drafts.delete(inbox_id=inbox_id, draft_id=draft_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id
|
||||
)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListOrgDraftsBlock(Block):
|
||||
"""
|
||||
List all drafts across every inbox in your organization.
|
||||
|
||||
Returns drafts from all inboxes in one query. Perfect for building
|
||||
a central approval dashboard where a human supervisor can review
|
||||
and approve any draft created by any agent.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of drafts to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
drafts: list[dict] = SchemaField(
|
||||
description="List of draft objects from all inboxes in the organization"
|
||||
)
|
||||
count: int = SchemaField(description="Number of drafts returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ed7558ae-3a07-45f5-af55-a25fe88c9971",
|
||||
description="List all drafts across every inbox in your organization. Use for central approval dashboards.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("drafts", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_org_drafts": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"drafts": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_org_drafts(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.drafts.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_org_drafts(credentials, **params)
|
||||
drafts = [d.model_dump() for d in response.drafts]
|
||||
|
||||
yield "drafts", drafts
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
414
autogpt_platform/backend/backend/blocks/agent_mail/inbox.py
Normal file
414
autogpt_platform/backend/backend/blocks/agent_mail/inbox.py
Normal file
@@ -0,0 +1,414 @@
|
||||
"""
|
||||
AgentMail Inbox blocks — create, get, list, update, and delete inboxes.
|
||||
|
||||
An Inbox is a fully programmable email account for AI agents. Each inbox gets
|
||||
a unique email address and can send, receive, and manage emails via the
|
||||
AgentMail API. You can create thousands of inboxes on demand.
|
||||
"""
|
||||
|
||||
from agentmail.inboxes.types import CreateInboxRequest
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailCreateInboxBlock(Block):
|
||||
"""
|
||||
Create a new email inbox for an AI agent via AgentMail.
|
||||
|
||||
Each inbox gets a unique email address (e.g. username@agentmail.to).
|
||||
If username and domain are not provided, AgentMail auto-generates them.
|
||||
Use custom domains by specifying the domain field.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
username: str = SchemaField(
|
||||
description="Local part of the email address (e.g. 'support' for support@domain.com). Leave empty to auto-generate.",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
domain: str = SchemaField(
|
||||
description="Email domain (e.g. 'mydomain.com'). Defaults to agentmail.to if empty.",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
display_name: str = SchemaField(
|
||||
description="Friendly name shown in the 'From' field of sent emails (e.g. 'Support Agent')",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(
|
||||
description="Unique identifier for the created inbox (also the email address)"
|
||||
)
|
||||
email_address: str = SchemaField(
|
||||
description="Full email address of the inbox (e.g. support@agentmail.to)"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7a8ac219-c6ec-4eec-a828-81af283ce04c",
|
||||
description="Create a new email inbox for an AI agent via AgentMail. Each inbox gets a unique address and can send/receive emails.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("inbox_id", "mock-inbox-id"),
|
||||
("email_address", "mock-inbox-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "mock-inbox-id",
|
||||
"model_dump": lambda self: {"inbox_id": "mock-inbox-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_inbox(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.create(request=CreateInboxRequest(**params))
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.username:
|
||||
params["username"] = input_data.username
|
||||
if input_data.domain:
|
||||
params["domain"] = input_data.domain
|
||||
if input_data.display_name:
|
||||
params["display_name"] = input_data.display_name
|
||||
|
||||
inbox = await self.create_inbox(credentials, **params)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "email_address", inbox.inbox_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetInboxBlock(Block):
|
||||
"""
|
||||
Retrieve details of an existing AgentMail inbox by its ID or email address.
|
||||
|
||||
Returns the inbox metadata including email address, display name, and
|
||||
configuration. Use this to check if an inbox exists or get its properties.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to look up (e.g. 'support@agentmail.to')"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(description="Unique identifier of the inbox")
|
||||
email_address: str = SchemaField(description="Full email address of the inbox")
|
||||
display_name: str = SchemaField(
|
||||
description="Friendly name shown in the 'From' field", default=""
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b858f62b-6c12-4736-aaf2-dbc5a9281320",
|
||||
description="Retrieve details of an existing AgentMail inbox including its email address, display name, and configuration.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("inbox_id", "test-inbox"),
|
||||
("email_address", "test-inbox"),
|
||||
("display_name", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "test-inbox",
|
||||
"display_name": "",
|
||||
"model_dump": lambda self: {"inbox_id": "test-inbox"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_inbox(credentials: APIKeyCredentials, inbox_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.get(inbox_id=inbox_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
inbox = await self.get_inbox(credentials, input_data.inbox_id)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "email_address", inbox.inbox_id
|
||||
yield "display_name", inbox.display_name or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListInboxesBlock(Block):
|
||||
"""
|
||||
List all email inboxes in your AgentMail organization.
|
||||
|
||||
Returns a paginated list of all inboxes with their metadata.
|
||||
Use page_token for pagination when you have many inboxes.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of inboxes to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page of results",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inboxes: list[dict] = SchemaField(
|
||||
description="List of inbox objects, each containing inbox_id, email_address, display_name, etc."
|
||||
)
|
||||
count: int = SchemaField(
|
||||
description="Total number of inboxes in your organization"
|
||||
)
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token to pass as page_token to get the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="cfd84a06-2121-4cef-8d14-8badf52d22f0",
|
||||
description="List all email inboxes in your AgentMail organization with pagination support.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("inboxes", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_inboxes": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"inboxes": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_inboxes(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_inboxes(credentials, **params)
|
||||
inboxes = [i.model_dump() for i in response.inboxes]
|
||||
|
||||
yield "inboxes", inboxes
|
||||
yield "count", (c if (c := response.count) is not None else len(inboxes))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailUpdateInboxBlock(Block):
|
||||
"""
|
||||
Update the display name of an existing AgentMail inbox.
|
||||
|
||||
Changes the friendly name shown in the 'From' field when emails are sent
|
||||
from this inbox. The email address itself cannot be changed.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to update (e.g. 'support@agentmail.to')"
|
||||
)
|
||||
display_name: str = SchemaField(
|
||||
description="New display name for the inbox (e.g. 'Customer Support Bot')"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(description="The updated inbox ID")
|
||||
result: dict = SchemaField(
|
||||
description="Complete updated inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="59b49f59-a6d1-4203-94c0-3908adac50b6",
|
||||
description="Update the display name of an AgentMail inbox. Changes the 'From' name shown when emails are sent.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"display_name": "Updated",
|
||||
},
|
||||
test_output=[
|
||||
("inbox_id", "test-inbox"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"update_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "test-inbox",
|
||||
"model_dump": lambda self: {"inbox_id": "test-inbox"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_inbox(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.update(inbox_id=inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
inbox = await self.update_inbox(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
display_name=input_data.display_name,
|
||||
)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteInboxBlock(Block):
|
||||
"""
|
||||
Permanently delete an AgentMail inbox and all its data.
|
||||
|
||||
This removes the inbox, all its messages, threads, and drafts.
|
||||
This action cannot be undone. The email address will no longer
|
||||
receive or send emails.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to permanently delete"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the inbox was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ade970ae-8428-4a7b-9278-b52054dbf535",
|
||||
description="Permanently delete an AgentMail inbox and all its messages, threads, and drafts. This action cannot be undone.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_inbox": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_inbox(credentials: APIKeyCredentials, inbox_id: str):
|
||||
client = _client(credentials)
|
||||
await client.inboxes.delete(inbox_id=inbox_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_inbox(credentials, input_data.inbox_id)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
384
autogpt_platform/backend/backend/blocks/agent_mail/lists.py
Normal file
384
autogpt_platform/backend/backend/blocks/agent_mail/lists.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
AgentMail List blocks — manage allow/block lists for email filtering.
|
||||
|
||||
Lists let you control which email addresses and domains your agents can
|
||||
send to or receive from. There are four list types based on two dimensions:
|
||||
direction (send/receive) and type (allow/block).
|
||||
|
||||
- receive + allow: Only accept emails from these addresses/domains
|
||||
- receive + block: Reject emails from these addresses/domains
|
||||
- send + allow: Only send emails to these addresses/domains
|
||||
- send + block: Prevent sending emails to these addresses/domains
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class ListDirection(str, Enum):
|
||||
SEND = "send"
|
||||
RECEIVE = "receive"
|
||||
|
||||
|
||||
class ListType(str, Enum):
|
||||
ALLOW = "allow"
|
||||
BLOCK = "block"
|
||||
|
||||
|
||||
class AgentMailListEntriesBlock(Block):
|
||||
"""
|
||||
List all entries in an AgentMail allow/block list.
|
||||
|
||||
Retrieves email addresses and domains that are currently allowed
|
||||
or blocked for sending or receiving. Use direction and list_type
|
||||
to select which of the four lists to query.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' to filter outgoing emails, 'receive' to filter incoming emails"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' for whitelist (only permit these), 'block' for blacklist (reject these)"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of entries to return per page",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
entries: list[dict] = SchemaField(
|
||||
description="List of entries, each with an email address or domain"
|
||||
)
|
||||
count: int = SchemaField(description="Number of entries returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="01489100-35da-45aa-8a01-9540ba0e9a21",
|
||||
description="List all entries in an AgentMail allow/block list. Choose send/receive direction and allow/block type.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
},
|
||||
test_output=[
|
||||
("entries", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_entries": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"entries": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_entries(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.lists.list(direction, list_type, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_entries(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
**params,
|
||||
)
|
||||
entries = [e.model_dump() for e in response.entries]
|
||||
|
||||
yield "entries", entries
|
||||
yield "count", (c if (c := response.count) is not None else len(entries))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailCreateListEntryBlock(Block):
|
||||
"""
|
||||
Add an email address or domain to an AgentMail allow/block list.
|
||||
|
||||
Entries can be full email addresses (e.g. 'partner@example.com') or
|
||||
entire domains (e.g. 'example.com'). For block lists, you can optionally
|
||||
provide a reason (e.g. 'spam', 'competitor').
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' for outgoing email rules, 'receive' for incoming email rules"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' to whitelist, 'block' to blacklist"
|
||||
)
|
||||
entry: str = SchemaField(
|
||||
description="Email address (user@example.com) or domain (example.com) to add"
|
||||
)
|
||||
reason: str = SchemaField(
|
||||
description="Reason for blocking (only used with block lists, e.g. 'spam', 'competitor')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
entry: str = SchemaField(
|
||||
description="The email address or domain that was added"
|
||||
)
|
||||
result: dict = SchemaField(description="Complete entry object")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b6650a0a-b113-40cf-8243-ff20f684f9b8",
|
||||
description="Add an email address or domain to an allow/block list. Block spam senders or whitelist trusted domains.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
"entry": "spam@example.com",
|
||||
},
|
||||
test_output=[
|
||||
("entry", "spam@example.com"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_entry": lambda *a, **kw: type(
|
||||
"Entry",
|
||||
(),
|
||||
{
|
||||
"model_dump": lambda self: {"entry": "spam@example.com"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_entry(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.lists.create(direction, list_type, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"entry": input_data.entry}
|
||||
if input_data.reason and input_data.list_type == ListType.BLOCK:
|
||||
params["reason"] = input_data.reason
|
||||
|
||||
result = await self.create_entry(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
**params,
|
||||
)
|
||||
result_dict = result.model_dump()
|
||||
|
||||
yield "entry", input_data.entry
|
||||
yield "result", result_dict
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetListEntryBlock(Block):
|
||||
"""
|
||||
Check if an email address or domain exists in an AgentMail allow/block list.
|
||||
|
||||
Returns the entry details if found. Use this to verify whether a specific
|
||||
address or domain is currently allowed or blocked.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' for outgoing rules, 'receive' for incoming rules"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' for whitelist, 'block' for blacklist"
|
||||
)
|
||||
entry: str = SchemaField(description="Email address or domain to look up")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
entry: str = SchemaField(
|
||||
description="The email address or domain that was found"
|
||||
)
|
||||
result: dict = SchemaField(description="Complete entry object with metadata")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fb117058-ab27-40d1-9231-eb1dd526fc7a",
|
||||
description="Check if an email address or domain is in an allow/block list. Verify filtering rules.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
"entry": "spam@example.com",
|
||||
},
|
||||
test_output=[
|
||||
("entry", "spam@example.com"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_entry": lambda *a, **kw: type(
|
||||
"Entry",
|
||||
(),
|
||||
{
|
||||
"model_dump": lambda self: {"entry": "spam@example.com"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_entry(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, entry: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.lists.get(direction, list_type, entry=entry)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
result = await self.get_entry(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
input_data.entry,
|
||||
)
|
||||
result_dict = result.model_dump()
|
||||
|
||||
yield "entry", input_data.entry
|
||||
yield "result", result_dict
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteListEntryBlock(Block):
|
||||
"""
|
||||
Remove an email address or domain from an AgentMail allow/block list.
|
||||
|
||||
After removal, the address/domain will no longer be filtered by this list.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' for outgoing rules, 'receive' for incoming rules"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' for whitelist, 'block' for blacklist"
|
||||
)
|
||||
entry: str = SchemaField(
|
||||
description="Email address or domain to remove from the list"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the entry was successfully removed"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2b8d57f1-1c9e-470f-a70b-5991c80fad5f",
|
||||
description="Remove an email address or domain from an allow/block list to stop filtering it.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
"entry": "spam@example.com",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_entry": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_entry(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, entry: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
await client.lists.delete(direction, list_type, entry=entry)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_entry(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
input_data.entry,
|
||||
)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
695
autogpt_platform/backend/backend/blocks/agent_mail/messages.py
Normal file
695
autogpt_platform/backend/backend/blocks/agent_mail/messages.py
Normal file
@@ -0,0 +1,695 @@
|
||||
"""
|
||||
AgentMail Message blocks — send, list, get, reply, forward, and update messages.
|
||||
|
||||
A Message is an individual email within a Thread. Agents can send new messages
|
||||
(which create threads), reply to existing messages, forward them, and manage
|
||||
labels for state tracking (e.g. read/unread, campaign tags).
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailSendMessageBlock(Block):
|
||||
"""
|
||||
Send a new email from an AgentMail inbox, automatically creating a new thread.
|
||||
|
||||
Supports plain text and HTML bodies, CC/BCC recipients, and labels for
|
||||
organizing messages (e.g. campaign tracking, state management).
|
||||
Max 50 combined recipients across to, cc, and bcc.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to send from (e.g. 'agent@agentmail.to')"
|
||||
)
|
||||
to: list[str] = SchemaField(
|
||||
description="Recipient email addresses (e.g. ['user@example.com'])"
|
||||
)
|
||||
subject: str = SchemaField(description="Email subject line")
|
||||
text: str = SchemaField(
|
||||
description="Plain text body of the email. Always provide this as a fallback for email clients that don't render HTML."
|
||||
)
|
||||
html: str = SchemaField(
|
||||
description="Rich HTML body of the email. Embed CSS in a <style> tag for best compatibility across email clients.",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
cc: list[str] = SchemaField(
|
||||
description="CC recipient email addresses for human-in-the-loop oversight",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
bcc: list[str] = SchemaField(
|
||||
description="BCC recipient email addresses (hidden from other recipients)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Labels to tag the message for filtering and state management (e.g. ['outreach', 'q4-campaign'])",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Unique identifier of the sent message"
|
||||
)
|
||||
thread_id: str = SchemaField(
|
||||
description="Thread ID grouping this message and any future replies"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete sent message object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b67469b2-7748-4d81-a223-4ebd332cca89",
|
||||
description="Send a new email from an AgentMail inbox. Creates a new conversation thread. Supports HTML, CC/BCC, and labels.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"to": ["user@example.com"],
|
||||
"subject": "Test",
|
||||
"text": "Hello",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-msg-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"send_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-msg-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {
|
||||
"message_id": "mock-msg-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def send_message(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.send(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
total = len(input_data.to) + len(input_data.cc) + len(input_data.bcc)
|
||||
if total > 50:
|
||||
raise ValueError(
|
||||
f"Max 50 combined recipients across to, cc, and bcc (got {total})"
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"to": input_data.to,
|
||||
"subject": input_data.subject,
|
||||
"text": input_data.text,
|
||||
}
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
if input_data.cc:
|
||||
params["cc"] = input_data.cc
|
||||
if input_data.bcc:
|
||||
params["bcc"] = input_data.bcc
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
msg = await self.send_message(credentials, input_data.inbox_id, **params)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "thread_id", msg.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListMessagesBlock(Block):
|
||||
"""
|
||||
List all messages in an AgentMail inbox with optional label filtering.
|
||||
|
||||
Returns a paginated list of messages. Use labels to filter (e.g.
|
||||
labels=['unread'] to only get unprocessed messages). Useful for
|
||||
polling workflows or building inbox views.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to list messages from"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of messages to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return messages with ALL of these labels (e.g. ['unread'] or ['q4-campaign', 'follow-up'])",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
messages: list[dict] = SchemaField(
|
||||
description="List of message objects with subject, sender, text, html, labels, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of messages returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="721234df-c7a2-4927-b205-744badbd5844",
|
||||
description="List messages in an AgentMail inbox. Filter by labels to find unread, campaign-tagged, or categorized messages.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("messages", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_messages": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"messages": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_messages(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.list(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_messages(
|
||||
credentials, input_data.inbox_id, **params
|
||||
)
|
||||
messages = [m.model_dump() for m in response.messages]
|
||||
|
||||
yield "messages", messages
|
||||
yield "count", (c if (c := response.count) is not None else len(messages))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetMessageBlock(Block):
|
||||
"""
|
||||
Retrieve a specific email message by ID from an AgentMail inbox.
|
||||
|
||||
Returns the full message including subject, body (text and HTML),
|
||||
sender, recipients, and attachments. Use extracted_text to get
|
||||
only the new reply content without quoted history.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the message belongs to"
|
||||
)
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID to retrieve (e.g. '<abc123@agentmail.to>')"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(description="Unique identifier of the message")
|
||||
thread_id: str = SchemaField(description="Thread this message belongs to")
|
||||
subject: str = SchemaField(description="Email subject line")
|
||||
text: str = SchemaField(
|
||||
description="Full plain text body (may include quoted reply history)"
|
||||
)
|
||||
extracted_text: str = SchemaField(
|
||||
description="Just the new reply content with quoted history stripped. Best for AI processing.",
|
||||
default="",
|
||||
)
|
||||
html: str = SchemaField(description="HTML body of the email", default="")
|
||||
result: dict = SchemaField(
|
||||
description="Complete message object with all fields including sender, recipients, attachments, labels"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2788bdfa-1527-4603-a5e4-a455c05c032f",
|
||||
description="Retrieve a specific email message by ID. Includes extracted_text for clean reply content without quoted history.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "test-msg"),
|
||||
("thread_id", "t1"),
|
||||
("subject", "Hi"),
|
||||
("text", "Hello"),
|
||||
("extracted_text", "Hello"),
|
||||
("html", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "test-msg",
|
||||
"thread_id": "t1",
|
||||
"subject": "Hi",
|
||||
"text": "Hello",
|
||||
"extracted_text": "Hello",
|
||||
"html": "",
|
||||
"model_dump": lambda self: {"message_id": "test-msg"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_message(
|
||||
credentials: APIKeyCredentials,
|
||||
inbox_id: str,
|
||||
message_id: str,
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.get(
|
||||
inbox_id=inbox_id, message_id=message_id
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
msg = await self.get_message(
|
||||
credentials, input_data.inbox_id, input_data.message_id
|
||||
)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "thread_id", msg.thread_id or ""
|
||||
yield "subject", msg.subject or ""
|
||||
yield "text", msg.text or ""
|
||||
yield "extracted_text", msg.extracted_text or ""
|
||||
yield "html", msg.html or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailReplyToMessageBlock(Block):
|
||||
"""
|
||||
Reply to an existing email message, keeping the reply in the same thread.
|
||||
|
||||
The reply is automatically added to the same conversation thread as the
|
||||
original message. Use this for multi-turn agent conversations.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to send the reply from"
|
||||
)
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID to reply to (e.g. '<abc123@agentmail.to>')"
|
||||
)
|
||||
text: str = SchemaField(description="Plain text body of the reply")
|
||||
html: str = SchemaField(
|
||||
description="Rich HTML body of the reply",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Unique identifier of the reply message"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID the reply was added to")
|
||||
result: dict = SchemaField(
|
||||
description="Complete reply message object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b9fe53fa-5026-4547-9570-b54ccb487229",
|
||||
description="Reply to an existing email in the same conversation thread. Use for multi-turn agent conversations.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"text": "Reply",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-reply-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"reply_to_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-reply-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {"message_id": "mock-reply-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def reply_to_message(
|
||||
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.reply(
|
||||
inbox_id=inbox_id, message_id=message_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"text": input_data.text}
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
|
||||
reply = await self.reply_to_message(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
input_data.message_id,
|
||||
**params,
|
||||
)
|
||||
result = reply.model_dump()
|
||||
|
||||
yield "message_id", reply.message_id
|
||||
yield "thread_id", reply.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailForwardMessageBlock(Block):
|
||||
"""
|
||||
Forward an existing email message to one or more recipients.
|
||||
|
||||
Sends the original message content to different email addresses.
|
||||
Optionally prepend additional text or override the subject line.
|
||||
Max 50 combined recipients across to, cc, and bcc.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to forward from"
|
||||
)
|
||||
message_id: str = SchemaField(description="Message ID to forward")
|
||||
to: list[str] = SchemaField(
|
||||
description="Recipient email addresses to forward the message to (e.g. ['user@example.com'])"
|
||||
)
|
||||
cc: list[str] = SchemaField(
|
||||
description="CC recipient email addresses",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
bcc: list[str] = SchemaField(
|
||||
description="BCC recipient email addresses (hidden from other recipients)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
subject: str = SchemaField(
|
||||
description="Override the subject line (defaults to 'Fwd: <original subject>')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
text: str = SchemaField(
|
||||
description="Additional plain text to prepend before the forwarded content",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
html: str = SchemaField(
|
||||
description="Additional HTML to prepend before the forwarded content",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Unique identifier of the forwarded message"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID of the forward")
|
||||
result: dict = SchemaField(
|
||||
description="Complete forwarded message object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b70c7e33-5d66-4f8e-897f-ac73a7bfce82",
|
||||
description="Forward an email message to one or more recipients. Supports CC/BCC and optional extra text or subject override.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"to": ["user@example.com"],
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-fwd-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"forward_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-fwd-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {"message_id": "mock-fwd-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def forward_message(
|
||||
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.forward(
|
||||
inbox_id=inbox_id, message_id=message_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
total = len(input_data.to) + len(input_data.cc) + len(input_data.bcc)
|
||||
if total > 50:
|
||||
raise ValueError(
|
||||
f"Max 50 combined recipients across to, cc, and bcc (got {total})"
|
||||
)
|
||||
|
||||
params: dict = {"to": input_data.to}
|
||||
if input_data.cc:
|
||||
params["cc"] = input_data.cc
|
||||
if input_data.bcc:
|
||||
params["bcc"] = input_data.bcc
|
||||
if input_data.subject:
|
||||
params["subject"] = input_data.subject
|
||||
if input_data.text:
|
||||
params["text"] = input_data.text
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
|
||||
fwd = await self.forward_message(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
input_data.message_id,
|
||||
**params,
|
||||
)
|
||||
result = fwd.model_dump()
|
||||
|
||||
yield "message_id", fwd.message_id
|
||||
yield "thread_id", fwd.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailUpdateMessageBlock(Block):
|
||||
"""
|
||||
Add or remove labels on an email message for state management.
|
||||
|
||||
Labels are string tags used to track message state (read/unread),
|
||||
categorize messages (billing, support), or tag campaigns (q4-outreach).
|
||||
Common pattern: add 'read' and remove 'unread' after processing a message.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the message belongs to"
|
||||
)
|
||||
message_id: str = SchemaField(description="Message ID to update labels on")
|
||||
add_labels: list[str] = SchemaField(
|
||||
description="Labels to add (e.g. ['read', 'processed', 'high-priority'])",
|
||||
default_factory=list,
|
||||
)
|
||||
remove_labels: list[str] = SchemaField(
|
||||
description="Labels to remove (e.g. ['unread', 'pending'])",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(description="The updated message ID")
|
||||
result: dict = SchemaField(
|
||||
description="Complete updated message object with current labels"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="694ff816-4c89-4a5e-a552-8c31be187735",
|
||||
description="Add or remove labels on an email message. Use for read/unread tracking, campaign tagging, or state management.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"add_labels": ["read"],
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "test-msg"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"update_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "test-msg",
|
||||
"model_dump": lambda self: {"message_id": "test-msg"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_message(
|
||||
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.update(
|
||||
inbox_id=inbox_id, message_id=message_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
if not input_data.add_labels and not input_data.remove_labels:
|
||||
raise ValueError(
|
||||
"Must specify at least one label operation: add_labels or remove_labels"
|
||||
)
|
||||
|
||||
params: dict = {}
|
||||
if input_data.add_labels:
|
||||
params["add_labels"] = input_data.add_labels
|
||||
if input_data.remove_labels:
|
||||
params["remove_labels"] = input_data.remove_labels
|
||||
|
||||
msg = await self.update_message(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
input_data.message_id,
|
||||
**params,
|
||||
)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
651
autogpt_platform/backend/backend/blocks/agent_mail/pods.py
Normal file
651
autogpt_platform/backend/backend/blocks/agent_mail/pods.py
Normal file
@@ -0,0 +1,651 @@
|
||||
"""
|
||||
AgentMail Pod blocks — create, get, list, delete pods and list pod-scoped resources.
|
||||
|
||||
Pods provide multi-tenant isolation between your customers. Each pod acts as
|
||||
an isolated workspace containing its own inboxes, domains, threads, and drafts.
|
||||
Use pods when building SaaS platforms, agency tools, or AI agent fleets that
|
||||
serve multiple customers.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailCreatePodBlock(Block):
|
||||
"""
|
||||
Create a new pod for multi-tenant customer isolation.
|
||||
|
||||
Each pod acts as an isolated workspace for one customer or tenant.
|
||||
Use client_id to map pods to your internal tenant IDs for idempotent
|
||||
creation (safe to retry without creating duplicates).
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
client_id: str = SchemaField(
|
||||
description="Your internal tenant/customer ID for idempotent mapping. Lets you access the pod by your own ID instead of AgentMail's pod_id.",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
pod_id: str = SchemaField(description="Unique identifier of the created pod")
|
||||
result: dict = SchemaField(description="Complete pod object with all metadata")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a2db9784-2d17-4f8f-9d6b-0214e6f22101",
|
||||
description="Create a new pod for multi-tenant customer isolation. Use client_id to map to your internal tenant IDs.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("pod_id", "mock-pod-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_pod": lambda *a, **kw: type(
|
||||
"Pod",
|
||||
(),
|
||||
{
|
||||
"pod_id": "mock-pod-id",
|
||||
"model_dump": lambda self: {"pod_id": "mock-pod-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_pod(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.create(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.client_id:
|
||||
params["client_id"] = input_data.client_id
|
||||
|
||||
pod = await self.create_pod(credentials, **params)
|
||||
result = pod.model_dump()
|
||||
|
||||
yield "pod_id", pod.pod_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetPodBlock(Block):
|
||||
"""
|
||||
Retrieve details of an existing pod by its ID.
|
||||
|
||||
Returns the pod metadata including its client_id mapping and
|
||||
creation timestamp.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to retrieve")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
pod_id: str = SchemaField(description="Unique identifier of the pod")
|
||||
result: dict = SchemaField(description="Complete pod object with all metadata")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="553361bc-bb1b-4322-9ad4-0c226200217e",
|
||||
description="Retrieve details of an existing pod including its client_id mapping and metadata.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("pod_id", "test-pod"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_pod": lambda *a, **kw: type(
|
||||
"Pod",
|
||||
(),
|
||||
{
|
||||
"pod_id": "test-pod",
|
||||
"model_dump": lambda self: {"pod_id": "test-pod"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_pod(credentials: APIKeyCredentials, pod_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.pods.get(pod_id=pod_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
pod = await self.get_pod(credentials, pod_id=input_data.pod_id)
|
||||
result = pod.model_dump()
|
||||
|
||||
yield "pod_id", pod.pod_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodsBlock(Block):
|
||||
"""
|
||||
List all pods in your AgentMail organization.
|
||||
|
||||
Returns a paginated list of all tenant pods with their metadata.
|
||||
Use this to see all customer workspaces at a glance.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of pods to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
pods: list[dict] = SchemaField(
|
||||
description="List of pod objects with pod_id, client_id, creation time, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of pods returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9d3725ee-2968-431a-a816-857ab41e1420",
|
||||
description="List all tenant pods in your organization. See all customer workspaces at a glance.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("pods", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pods": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"pods": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pods(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_pods(credentials, **params)
|
||||
pods = [p.model_dump() for p in response.pods]
|
||||
|
||||
yield "pods", pods
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeletePodBlock(Block):
|
||||
"""
|
||||
Permanently delete a pod. All inboxes and domains must be removed first.
|
||||
|
||||
You cannot delete a pod that still contains inboxes or domains.
|
||||
Delete all child resources first, then delete the pod.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(
|
||||
description="Pod ID to permanently delete (must have no inboxes or domains)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the pod was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f371f8cd-682d-4f5f-905c-529c74a8fb35",
|
||||
description="Permanently delete a pod. All inboxes and domains must be removed first.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_pod": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_pod(credentials: APIKeyCredentials, pod_id: str):
|
||||
client = _client(credentials)
|
||||
await client.pods.delete(pod_id=pod_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_pod(credentials, pod_id=input_data.pod_id)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodInboxesBlock(Block):
|
||||
"""
|
||||
List all inboxes within a specific pod (customer workspace).
|
||||
|
||||
Returns only the inboxes belonging to this pod, providing
|
||||
tenant-scoped visibility.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to list inboxes from")
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of inboxes to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inboxes: list[dict] = SchemaField(
|
||||
description="List of inbox objects within this pod"
|
||||
)
|
||||
count: int = SchemaField(description="Number of inboxes returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a8c17ce0-b7c1-4bc3-ae39-680e1952e5d0",
|
||||
description="List all inboxes within a pod. View email accounts scoped to a specific customer.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("inboxes", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pod_inboxes": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"inboxes": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pod_inboxes(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.inboxes.list(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_pod_inboxes(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
inboxes = [i.model_dump() for i in response.inboxes]
|
||||
|
||||
yield "inboxes", inboxes
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodThreadsBlock(Block):
|
||||
"""
|
||||
List all conversation threads across all inboxes within a pod.
|
||||
|
||||
Returns threads from every inbox in the pod. Use for building
|
||||
per-customer dashboards showing all email activity, or for
|
||||
supervisor agents monitoring a customer's conversations.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to list threads from")
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of threads to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return threads matching ALL of these labels",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
threads: list[dict] = SchemaField(
|
||||
description="List of thread objects from all inboxes in this pod"
|
||||
)
|
||||
count: int = SchemaField(description="Number of threads returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="80214f08-8b85-4533-a6b8-f8123bfcb410",
|
||||
description="List all conversation threads across all inboxes within a pod. View all email activity for a customer.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("threads", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pod_threads": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"threads": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pod_threads(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.threads.list(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_pod_threads(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
threads = [t.model_dump() for t in response.threads]
|
||||
|
||||
yield "threads", threads
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodDraftsBlock(Block):
|
||||
"""
|
||||
List all drafts across all inboxes within a pod.
|
||||
|
||||
Returns pending drafts from every inbox in the pod. Use for
|
||||
per-customer approval dashboards or monitoring scheduled sends.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to list drafts from")
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of drafts to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
drafts: list[dict] = SchemaField(
|
||||
description="List of draft objects from all inboxes in this pod"
|
||||
)
|
||||
count: int = SchemaField(description="Number of drafts returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="12fd7a3e-51ad-4b20-97c1-0391f207f517",
|
||||
description="List all drafts across all inboxes within a pod. View pending emails for a customer.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("drafts", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pod_drafts": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"drafts": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pod_drafts(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.drafts.list(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_pod_drafts(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
drafts = [d.model_dump() for d in response.drafts]
|
||||
|
||||
yield "drafts", drafts
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailCreatePodInboxBlock(Block):
|
||||
"""
|
||||
Create a new email inbox within a specific pod (customer workspace).
|
||||
|
||||
The inbox is automatically scoped to the pod and inherits its
|
||||
isolation guarantees. If username/domain are not provided,
|
||||
AgentMail auto-generates a unique address.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to create the inbox in")
|
||||
username: str = SchemaField(
|
||||
description="Local part of the email address (e.g. 'support'). Leave empty to auto-generate.",
|
||||
default="",
|
||||
)
|
||||
domain: str = SchemaField(
|
||||
description="Email domain (e.g. 'mydomain.com'). Defaults to agentmail.to if empty.",
|
||||
default="",
|
||||
)
|
||||
display_name: str = SchemaField(
|
||||
description="Friendly name shown in the 'From' field (e.g. 'Customer Support')",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(
|
||||
description="Unique identifier of the created inbox"
|
||||
)
|
||||
email_address: str = SchemaField(description="Full email address of the inbox")
|
||||
result: dict = SchemaField(
|
||||
description="Complete inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c6862373-1ac6-402e-89e6-7db1fea882af",
|
||||
description="Create a new email inbox within a pod. The inbox is scoped to the customer workspace.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("inbox_id", "mock-inbox-id"),
|
||||
("email_address", "mock-inbox-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_pod_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "mock-inbox-id",
|
||||
"model_dump": lambda self: {"inbox_id": "mock-inbox-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_pod_inbox(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.inboxes.create(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.username:
|
||||
params["username"] = input_data.username
|
||||
if input_data.domain:
|
||||
params["domain"] = input_data.domain
|
||||
if input_data.display_name:
|
||||
params["display_name"] = input_data.display_name
|
||||
|
||||
inbox = await self.create_pod_inbox(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "email_address", inbox.inbox_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
438
autogpt_platform/backend/backend/blocks/agent_mail/threads.py
Normal file
438
autogpt_platform/backend/backend/blocks/agent_mail/threads.py
Normal file
@@ -0,0 +1,438 @@
|
||||
"""
|
||||
AgentMail Thread blocks — list, get, and delete conversation threads.
|
||||
|
||||
A Thread groups related messages into a single conversation. Threads are
|
||||
created automatically when a new message is sent and grow as replies are added.
|
||||
Threads can be queried per-inbox or across the entire organization.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailListInboxThreadsBlock(Block):
|
||||
"""
|
||||
List all conversation threads within a specific AgentMail inbox.
|
||||
|
||||
Returns a paginated list of threads with optional label filtering.
|
||||
Use labels to find threads by campaign, status, or custom tags.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to list threads from"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of threads to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return threads matching ALL of these labels (e.g. ['q4-campaign', 'follow-up'])",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
threads: list[dict] = SchemaField(
|
||||
description="List of thread objects with thread_id, subject, message count, labels, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of threads returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="63dd9e2d-ef81-405c-b034-c031f0437334",
|
||||
description="List all conversation threads in an AgentMail inbox. Filter by labels for campaign tracking or status management.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("threads", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_threads": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"threads": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_threads(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.threads.list(inbox_id=inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_threads(
|
||||
credentials, input_data.inbox_id, **params
|
||||
)
|
||||
threads = [t.model_dump() for t in response.threads]
|
||||
|
||||
yield "threads", threads
|
||||
yield "count", (c if (c := response.count) is not None else len(threads))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetInboxThreadBlock(Block):
|
||||
"""
|
||||
Retrieve a single conversation thread from an AgentMail inbox.
|
||||
|
||||
Returns the thread with all its messages in chronological order.
|
||||
Use this to get the full conversation history for context when
|
||||
composing replies.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the thread belongs to"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID to retrieve")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
thread_id: str = SchemaField(description="Unique identifier of the thread")
|
||||
messages: list[dict] = SchemaField(
|
||||
description="All messages in the thread, in chronological order"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete thread object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="42866290-1479-4153-83e7-550b703e9da2",
|
||||
description="Retrieve a conversation thread with all its messages. Use for getting full conversation context before replying.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"thread_id": "test-thread",
|
||||
},
|
||||
test_output=[
|
||||
("thread_id", "test-thread"),
|
||||
("messages", []),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_thread": lambda *a, **kw: type(
|
||||
"Thread",
|
||||
(),
|
||||
{
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
"model_dump": lambda self: {
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_thread(credentials: APIKeyCredentials, inbox_id: str, thread_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.threads.get(inbox_id=inbox_id, thread_id=thread_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
thread = await self.get_thread(
|
||||
credentials, input_data.inbox_id, input_data.thread_id
|
||||
)
|
||||
messages = [m.model_dump() for m in thread.messages]
|
||||
result = thread.model_dump()
|
||||
result["messages"] = messages
|
||||
|
||||
yield "thread_id", thread.thread_id
|
||||
yield "messages", messages
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteInboxThreadBlock(Block):
|
||||
"""
|
||||
Permanently delete a conversation thread and all its messages from an inbox.
|
||||
|
||||
This removes the thread and every message within it. This action
|
||||
cannot be undone.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the thread belongs to"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID to permanently delete")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the thread was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="18cd5f6f-4ff6-45da-8300-25a50ea7fb75",
|
||||
description="Permanently delete a conversation thread and all its messages. This action cannot be undone.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"thread_id": "test-thread",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_thread": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_thread(
|
||||
credentials: APIKeyCredentials, inbox_id: str, thread_id: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
await client.inboxes.threads.delete(inbox_id=inbox_id, thread_id=thread_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_thread(
|
||||
credentials, input_data.inbox_id, input_data.thread_id
|
||||
)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListOrgThreadsBlock(Block):
|
||||
"""
|
||||
List conversation threads across ALL inboxes in your organization.
|
||||
|
||||
Unlike per-inbox listing, this returns threads from every inbox.
|
||||
Ideal for building supervisor agents that monitor all conversations,
|
||||
analytics dashboards, or cross-agent routing workflows.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of threads to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return threads matching ALL of these labels",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
threads: list[dict] = SchemaField(
|
||||
description="List of thread objects from all inboxes in the organization"
|
||||
)
|
||||
count: int = SchemaField(description="Number of threads returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d7a0657b-58ab-48b2-898b-7bd94f44a708",
|
||||
description="List threads across ALL inboxes in your organization. Use for supervisor agents, dashboards, or cross-agent monitoring.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("threads", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_org_threads": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"threads": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_org_threads(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.threads.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_org_threads(credentials, **params)
|
||||
threads = [t.model_dump() for t in response.threads]
|
||||
|
||||
yield "threads", threads
|
||||
yield "count", (c if (c := response.count) is not None else len(threads))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetOrgThreadBlock(Block):
|
||||
"""
|
||||
Retrieve a single conversation thread by ID from anywhere in the organization.
|
||||
|
||||
Works without needing to know which inbox the thread belongs to.
|
||||
Returns the thread with all its messages in chronological order.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
thread_id: str = SchemaField(
|
||||
description="Thread ID to retrieve (works across all inboxes)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
thread_id: str = SchemaField(description="Unique identifier of the thread")
|
||||
messages: list[dict] = SchemaField(
|
||||
description="All messages in the thread, in chronological order"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete thread object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="39aaae31-3eb1-44c6-9e37-5a44a4529649",
|
||||
description="Retrieve a conversation thread by ID from anywhere in the organization, without needing the inbox ID.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"thread_id": "test-thread",
|
||||
},
|
||||
test_output=[
|
||||
("thread_id", "test-thread"),
|
||||
("messages", []),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_org_thread": lambda *a, **kw: type(
|
||||
"Thread",
|
||||
(),
|
||||
{
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
"model_dump": lambda self: {
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_org_thread(credentials: APIKeyCredentials, thread_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.threads.get(thread_id=thread_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
thread = await self.get_org_thread(credentials, input_data.thread_id)
|
||||
messages = [m.model_dump() for m in thread.messages]
|
||||
result = thread.model_dump()
|
||||
result["messages"] = messages
|
||||
|
||||
yield "thread_id", thread.thread_id
|
||||
yield "messages", messages
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -27,6 +27,7 @@ from backend.util.file import MediaFileType, store_media_file
|
||||
class GeminiImageModel(str, Enum):
|
||||
NANO_BANANA = "google/nano-banana"
|
||||
NANO_BANANA_PRO = "google/nano-banana-pro"
|
||||
NANO_BANANA_2 = "google/nano-banana-2"
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
@@ -77,7 +78,7 @@ class AIImageCustomizerBlock(Block):
|
||||
)
|
||||
model: GeminiImageModel = SchemaField(
|
||||
description="The AI model to use for image generation and editing",
|
||||
default=GeminiImageModel.NANO_BANANA,
|
||||
default=GeminiImageModel.NANO_BANANA_2,
|
||||
title="Model",
|
||||
)
|
||||
images: list[MediaFileType] = SchemaField(
|
||||
@@ -103,7 +104,7 @@ class AIImageCustomizerBlock(Block):
|
||||
super().__init__(
|
||||
id="d76bbe4c-930e-4894-8469-b66775511f71",
|
||||
description=(
|
||||
"Generate and edit custom images using Google's Nano-Banana model from Gemini 2.5. "
|
||||
"Generate and edit custom images using Google's Nano-Banana models from Gemini. "
|
||||
"Provide a prompt and optional reference images to create or modify images."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
@@ -111,7 +112,7 @@ class AIImageCustomizerBlock(Block):
|
||||
output_schema=AIImageCustomizerBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Make the scene more vibrant and colorful",
|
||||
"model": GeminiImageModel.NANO_BANANA,
|
||||
"model": GeminiImageModel.NANO_BANANA_2,
|
||||
"images": [],
|
||||
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
|
||||
"output_format": OutputFormat.JPG,
|
||||
|
||||
@@ -115,6 +115,7 @@ class ImageGenModel(str, Enum):
|
||||
RECRAFT = "Recraft v3"
|
||||
SD3_5 = "Stable Diffusion 3.5 Medium"
|
||||
NANO_BANANA_PRO = "Nano Banana Pro"
|
||||
NANO_BANANA_2 = "Nano Banana 2"
|
||||
|
||||
|
||||
class AIImageGeneratorBlock(Block):
|
||||
@@ -131,7 +132,7 @@ class AIImageGeneratorBlock(Block):
|
||||
)
|
||||
model: ImageGenModel = SchemaField(
|
||||
description="The AI model to use for image generation",
|
||||
default=ImageGenModel.SD3_5,
|
||||
default=ImageGenModel.NANO_BANANA_2,
|
||||
title="Model",
|
||||
)
|
||||
size: ImageSize = SchemaField(
|
||||
@@ -165,7 +166,7 @@ class AIImageGeneratorBlock(Block):
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"prompt": "An octopus using a laptop in a snowy forest with 'AutoGPT' clearly visible on the screen",
|
||||
"model": ImageGenModel.RECRAFT,
|
||||
"model": ImageGenModel.NANO_BANANA_2,
|
||||
"size": ImageSize.SQUARE,
|
||||
"style": ImageStyle.REALISTIC,
|
||||
},
|
||||
@@ -179,7 +180,9 @@ class AIImageGeneratorBlock(Block):
|
||||
],
|
||||
test_mock={
|
||||
# Return a data URI directly so store_media_file doesn't need to download
|
||||
"_run_client": lambda *args, **kwargs: "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
|
||||
"_run_client": lambda *args, **kwargs: (
|
||||
"data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -280,17 +283,24 @@ class AIImageGeneratorBlock(Block):
|
||||
)
|
||||
return output
|
||||
|
||||
elif input_data.model == ImageGenModel.NANO_BANANA_PRO:
|
||||
# Use Nano Banana Pro (Google Gemini 3 Pro Image)
|
||||
elif input_data.model in (
|
||||
ImageGenModel.NANO_BANANA_PRO,
|
||||
ImageGenModel.NANO_BANANA_2,
|
||||
):
|
||||
# Use Nano Banana models (Google Gemini image variants)
|
||||
model_map = {
|
||||
ImageGenModel.NANO_BANANA_PRO: "google/nano-banana-pro",
|
||||
ImageGenModel.NANO_BANANA_2: "google/nano-banana-2",
|
||||
}
|
||||
input_params = {
|
||||
"prompt": modified_prompt,
|
||||
"aspect_ratio": SIZE_TO_NANO_BANANA_RATIO[input_data.size],
|
||||
"resolution": "2K", # Default to 2K for good quality/cost balance
|
||||
"resolution": "2K",
|
||||
"output_format": "jpg",
|
||||
"safety_filter_level": "block_only_high", # Most permissive
|
||||
"safety_filter_level": "block_only_high",
|
||||
}
|
||||
output = await self._run_client(
|
||||
credentials, "google/nano-banana-pro", input_params
|
||||
credentials, model_map[input_data.model], input_params
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
520
autogpt_platform/backend/backend/blocks/autopilot.py
Normal file
520
autogpt_platform/backend/backend/blocks/autopilot.py
Normal file
@@ -0,0 +1,520 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from typing_extensions import TypedDict # Needed for Python <3.12 compatibility
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.copilot.permissions import (
|
||||
CopilotPermissions,
|
||||
ToolName,
|
||||
all_known_tool_names,
|
||||
validate_block_identifiers,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Block ID shared between autopilot.py and copilot prompting.py.
|
||||
AUTOPILOT_BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
|
||||
|
||||
|
||||
class ToolCallEntry(TypedDict):
|
||||
"""A single tool invocation record from an autopilot execution."""
|
||||
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
input: Any
|
||||
output: Any | None
|
||||
success: bool | None
|
||||
|
||||
|
||||
class TokenUsage(TypedDict):
|
||||
"""Aggregated token counts from the autopilot stream."""
|
||||
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class AutoPilotBlock(Block):
|
||||
"""Execute tasks using AutoGPT AutoPilot with full access to platform tools.
|
||||
|
||||
The autopilot can manage agents, access workspace files, fetch web content,
|
||||
run blocks, and more. This block enables sub-agent patterns (autopilot calling
|
||||
autopilot) and scheduled autopilot execution via the agent executor.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
"""Input schema for the AutoPilot block."""
|
||||
|
||||
prompt: str = SchemaField(
|
||||
description=(
|
||||
"The task or instruction for the autopilot to execute. "
|
||||
"The autopilot has access to platform tools like agent management, "
|
||||
"workspace files, web fetch, block execution, and more."
|
||||
),
|
||||
placeholder="Find my agents and list them",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
system_context: str = SchemaField(
|
||||
description=(
|
||||
"Optional additional context prepended to the prompt. "
|
||||
"Use this to constrain autopilot behavior, provide domain "
|
||||
"context, or set output format requirements."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
session_id: str = SchemaField(
|
||||
description=(
|
||||
"Session ID to continue an existing autopilot conversation. "
|
||||
"Leave empty to start a new session. "
|
||||
"Use the session_id output from a previous run to continue."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
max_recursion_depth: int = SchemaField(
|
||||
description=(
|
||||
"Maximum nesting depth when the autopilot calls this block "
|
||||
"recursively (sub-agent pattern). Prevents infinite loops."
|
||||
),
|
||||
default=3,
|
||||
ge=1,
|
||||
le=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
tools: list[ToolName] = SchemaField(
|
||||
description=(
|
||||
"Tool names to filter. Works with tools_exclude to form an "
|
||||
"allow-list or deny-list. "
|
||||
"Leave empty to apply no tool filter."
|
||||
),
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
tools_exclude: bool = SchemaField(
|
||||
description=(
|
||||
"Controls how the 'tools' list is interpreted. "
|
||||
"True (default): 'tools' is a deny-list — listed tools are blocked, "
|
||||
"all others are allowed. An empty 'tools' list means allow everything. "
|
||||
"False: 'tools' is an allow-list — only listed tools are permitted."
|
||||
),
|
||||
default=True,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
blocks: list[str] = SchemaField(
|
||||
description=(
|
||||
"Block identifiers to filter when the copilot uses run_block. "
|
||||
"Each entry can be: a block name (e.g. 'HTTP Request'), "
|
||||
"a full block UUID, or the first 8 hex characters of the UUID "
|
||||
"(e.g. 'c069dc6b'). Works with blocks_exclude. "
|
||||
"Leave empty to apply no block filter."
|
||||
),
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
blocks_exclude: bool = SchemaField(
|
||||
description=(
|
||||
"Controls how the 'blocks' list is interpreted. "
|
||||
"True (default): 'blocks' is a deny-list — listed blocks are blocked, "
|
||||
"all others are allowed. An empty 'blocks' list means allow everything. "
|
||||
"False: 'blocks' is an allow-list — only listed blocks are permitted."
|
||||
),
|
||||
default=True,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# timeout_seconds removed: the SDK manages its own heartbeat-based
|
||||
# timeouts internally; wrapping with asyncio.timeout corrupts the
|
||||
# SDK's internal stream (see service.py CRITICAL comment).
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
"""Output schema for the AutoPilot block."""
|
||||
|
||||
response: str = SchemaField(
|
||||
description="The final text response from the autopilot."
|
||||
)
|
||||
tool_calls: list[ToolCallEntry] = SchemaField(
|
||||
description=(
|
||||
"List of tools called during execution. Each entry has "
|
||||
"tool_call_id, tool_name, input, output, and success fields."
|
||||
),
|
||||
)
|
||||
conversation_history: str = SchemaField(
|
||||
description=(
|
||||
"Current turn messages (user prompt + assistant reply) as JSON. "
|
||||
"It can be used for logging or analysis."
|
||||
),
|
||||
)
|
||||
session_id: str = SchemaField(
|
||||
description=(
|
||||
"Session ID for this conversation. "
|
||||
"Pass this back to continue the conversation in a future run."
|
||||
),
|
||||
)
|
||||
token_usage: TokenUsage = SchemaField(
|
||||
description=(
|
||||
"Token usage statistics: prompt_tokens, "
|
||||
"completion_tokens, total_tokens."
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id=AUTOPILOT_BLOCK_ID,
|
||||
description=(
|
||||
"Execute tasks using AutoGPT AutoPilot with full access to "
|
||||
"platform tools (agent management, workspace files, web fetch, "
|
||||
"block execution, and more). Enables sub-agent patterns and "
|
||||
"scheduled autopilot execution."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.AGENT},
|
||||
input_schema=AutoPilotBlock.Input,
|
||||
output_schema=AutoPilotBlock.Output,
|
||||
test_input={
|
||||
"prompt": "List my agents",
|
||||
"system_context": "",
|
||||
"session_id": "",
|
||||
"max_recursion_depth": 3,
|
||||
},
|
||||
test_output=[
|
||||
("response", "You have 2 agents: Agent A and Agent B."),
|
||||
("tool_calls", []),
|
||||
(
|
||||
"conversation_history",
|
||||
'[{"role": "user", "content": "List my agents"}]',
|
||||
),
|
||||
("session_id", "test-session-id"),
|
||||
(
|
||||
"token_usage",
|
||||
{
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"create_session": lambda *args, **kwargs: "test-session-id",
|
||||
"execute_copilot": lambda *args, **kwargs: (
|
||||
"You have 2 agents: Agent A and Agent B.",
|
||||
[],
|
||||
'[{"role": "user", "content": "List my agents"}]',
|
||||
"test-session-id",
|
||||
{
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def create_session(self, user_id: str) -> str:
|
||||
"""Create a new chat session and return its ID (mockable for tests)."""
|
||||
from backend.copilot.model import create_chat_session # avoid circular import
|
||||
|
||||
session = await create_chat_session(user_id)
|
||||
return session.session_id
|
||||
|
||||
async def execute_copilot(
|
||||
self,
|
||||
prompt: str,
|
||||
system_context: str,
|
||||
session_id: str,
|
||||
max_recursion_depth: int,
|
||||
user_id: str,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]:
|
||||
"""Invoke the copilot and collect all stream results.
|
||||
|
||||
Delegates to :func:`collect_copilot_response` — the shared helper that
|
||||
consumes ``stream_chat_completion_sdk`` without wrapping it in an
|
||||
``asyncio.timeout`` (the SDK manages its own heartbeat-based timeouts).
|
||||
|
||||
Args:
|
||||
prompt: The user task/instruction.
|
||||
system_context: Optional context prepended to the prompt.
|
||||
session_id: Chat session to use.
|
||||
max_recursion_depth: Maximum allowed recursion nesting.
|
||||
user_id: Authenticated user ID.
|
||||
permissions: Optional capability filter restricting tools/blocks.
|
||||
|
||||
Returns:
|
||||
A tuple of (response_text, tool_calls, history_json, session_id, usage).
|
||||
"""
|
||||
from backend.copilot.sdk.collect import (
|
||||
collect_copilot_response, # avoid circular import
|
||||
)
|
||||
|
||||
tokens = _check_recursion(max_recursion_depth)
|
||||
perm_token = None
|
||||
try:
|
||||
effective_permissions, perm_token = _merge_inherited_permissions(
|
||||
permissions
|
||||
)
|
||||
effective_prompt = prompt
|
||||
if system_context:
|
||||
effective_prompt = f"[System Context: {system_context}]\n\n{prompt}"
|
||||
|
||||
result = await collect_copilot_response(
|
||||
session_id=session_id,
|
||||
message=effective_prompt,
|
||||
user_id=user_id,
|
||||
permissions=effective_permissions,
|
||||
)
|
||||
|
||||
# Build a lightweight conversation summary from streamed data.
|
||||
turn_messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": effective_prompt},
|
||||
]
|
||||
if result.tool_calls:
|
||||
turn_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": result.response_text,
|
||||
"tool_calls": result.tool_calls,
|
||||
}
|
||||
)
|
||||
else:
|
||||
turn_messages.append(
|
||||
{"role": "assistant", "content": result.response_text}
|
||||
)
|
||||
history_json = json.dumps(turn_messages, default=str)
|
||||
|
||||
tool_calls: list[ToolCallEntry] = [
|
||||
{
|
||||
"tool_call_id": tc["tool_call_id"],
|
||||
"tool_name": tc["tool_name"],
|
||||
"input": tc["input"],
|
||||
"output": tc["output"],
|
||||
"success": tc["success"],
|
||||
}
|
||||
for tc in result.tool_calls
|
||||
]
|
||||
|
||||
usage: TokenUsage = {
|
||||
"prompt_tokens": result.prompt_tokens,
|
||||
"completion_tokens": result.completion_tokens,
|
||||
"total_tokens": result.total_tokens,
|
||||
}
|
||||
|
||||
return (
|
||||
result.response_text,
|
||||
tool_calls,
|
||||
history_json,
|
||||
session_id,
|
||||
usage,
|
||||
)
|
||||
finally:
|
||||
_reset_recursion(tokens)
|
||||
if perm_token is not None:
|
||||
_inherited_permissions.reset(perm_token)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Validate inputs, invoke the autopilot, and yield structured outputs.
|
||||
|
||||
Yields session_id even on failure so callers can inspect/resume the session.
|
||||
"""
|
||||
if not input_data.prompt.strip():
|
||||
yield "error", "Prompt cannot be empty."
|
||||
return
|
||||
|
||||
if not execution_context.user_id:
|
||||
yield "error", "Cannot run autopilot without an authenticated user."
|
||||
return
|
||||
|
||||
if input_data.max_recursion_depth < 1:
|
||||
yield "error", "max_recursion_depth must be at least 1."
|
||||
return
|
||||
|
||||
# Validate and build permissions eagerly — fail before creating a session.
|
||||
permissions = await _build_and_validate_permissions(input_data)
|
||||
if isinstance(permissions, str):
|
||||
# Validation error returned as a string message.
|
||||
yield "error", permissions
|
||||
return
|
||||
|
||||
# Create session eagerly so the user always gets the session_id,
|
||||
# even if the downstream stream fails (avoids orphaned sessions).
|
||||
sid = input_data.session_id
|
||||
if not sid:
|
||||
sid = await self.create_session(execution_context.user_id)
|
||||
|
||||
# NOTE: No asyncio.timeout() here — the SDK manages its own
|
||||
# heartbeat-based timeouts internally. Wrapping with asyncio.timeout
|
||||
# would cancel the task mid-flight, corrupting the SDK's internal
|
||||
# anyio memory stream (see service.py CRITICAL comment).
|
||||
try:
|
||||
response, tool_calls, history, _, usage = await self.execute_copilot(
|
||||
prompt=input_data.prompt,
|
||||
system_context=input_data.system_context,
|
||||
session_id=sid,
|
||||
max_recursion_depth=input_data.max_recursion_depth,
|
||||
user_id=execution_context.user_id,
|
||||
permissions=permissions,
|
||||
)
|
||||
|
||||
yield "response", response
|
||||
yield "tool_calls", tool_calls
|
||||
yield "conversation_history", history
|
||||
yield "session_id", sid
|
||||
yield "token_usage", usage
|
||||
except asyncio.CancelledError:
|
||||
yield "session_id", sid
|
||||
yield "error", "AutoPilot execution was cancelled."
|
||||
raise
|
||||
except Exception as exc:
|
||||
yield "session_id", sid
|
||||
yield "error", str(exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers – placed after the block class for top-down readability.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Task-scoped recursion depth counter & chain-wide limit.
|
||||
# contextvars are scoped to the current asyncio task, so concurrent
|
||||
# graph executions each get independent counters.
|
||||
_autopilot_recursion_depth: contextvars.ContextVar[int] = contextvars.ContextVar(
|
||||
"_autopilot_recursion_depth", default=0
|
||||
)
|
||||
_autopilot_recursion_limit: contextvars.ContextVar[int | None] = contextvars.ContextVar(
|
||||
"_autopilot_recursion_limit", default=None
|
||||
)
|
||||
|
||||
|
||||
def _check_recursion(
|
||||
max_depth: int,
|
||||
) -> tuple[contextvars.Token[int], contextvars.Token[int | None]]:
|
||||
"""Check and increment recursion depth.
|
||||
|
||||
Returns ContextVar tokens that must be passed to ``_reset_recursion``
|
||||
when the caller exits to restore the previous depth.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the current depth already meets or exceeds the limit.
|
||||
"""
|
||||
current = _autopilot_recursion_depth.get()
|
||||
inherited = _autopilot_recursion_limit.get()
|
||||
limit = max_depth if inherited is None else min(inherited, max_depth)
|
||||
if current >= limit:
|
||||
raise RuntimeError(
|
||||
f"AutoPilot recursion depth limit reached ({limit}). "
|
||||
"The autopilot has called itself too many times."
|
||||
)
|
||||
return (
|
||||
_autopilot_recursion_depth.set(current + 1),
|
||||
_autopilot_recursion_limit.set(limit),
|
||||
)
|
||||
|
||||
|
||||
def _reset_recursion(
|
||||
tokens: tuple[contextvars.Token[int], contextvars.Token[int | None]],
|
||||
) -> None:
|
||||
"""Restore recursion depth and limit to their previous values."""
|
||||
_autopilot_recursion_depth.reset(tokens[0])
|
||||
_autopilot_recursion_limit.reset(tokens[1])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Permission helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Inherited permissions from a parent AutoPilotBlock execution.
|
||||
# This acts as a ceiling: child executions can only be more restrictive.
|
||||
_inherited_permissions: contextvars.ContextVar["CopilotPermissions | None"] = (
|
||||
contextvars.ContextVar("_inherited_permissions", default=None)
|
||||
)
|
||||
|
||||
|
||||
async def _build_and_validate_permissions(
|
||||
input_data: "AutoPilotBlock.Input",
|
||||
) -> "CopilotPermissions | str":
|
||||
"""Build a :class:`CopilotPermissions` from block input and validate it.
|
||||
|
||||
Returns a :class:`CopilotPermissions` on success or a human-readable
|
||||
error string if validation fails.
|
||||
"""
|
||||
# Tool names are validated by Pydantic via the ToolName Literal type
|
||||
# at model construction time — no runtime check needed here.
|
||||
# Validate block identifiers against live block registry.
|
||||
if input_data.blocks:
|
||||
invalid_blocks = await validate_block_identifiers(input_data.blocks)
|
||||
if invalid_blocks:
|
||||
return (
|
||||
f"Unknown block identifier(s) in 'blocks': {invalid_blocks}. "
|
||||
"Use find_block to discover valid block names and IDs. "
|
||||
"You may also use the first 8 characters of a block UUID."
|
||||
)
|
||||
|
||||
return CopilotPermissions(
|
||||
tools=list(input_data.tools),
|
||||
tools_exclude=input_data.tools_exclude,
|
||||
blocks=input_data.blocks,
|
||||
blocks_exclude=input_data.blocks_exclude,
|
||||
)
|
||||
|
||||
|
||||
def _merge_inherited_permissions(
|
||||
permissions: "CopilotPermissions | None",
|
||||
) -> "tuple[CopilotPermissions | None, contextvars.Token[CopilotPermissions | None] | None]":
|
||||
"""Merge *permissions* with any inherited parent permissions.
|
||||
|
||||
The merged result is stored back into the contextvar so that any nested
|
||||
AutoPilotBlock invocation (sub-agent) inherits the merged ceiling.
|
||||
|
||||
Returns a tuple of (merged_permissions, reset_token). The caller MUST
|
||||
reset the contextvar via ``_inherited_permissions.reset(token)`` in a
|
||||
``finally`` block when ``reset_token`` is not None — this prevents
|
||||
permission leakage between sequential independent executions in the same
|
||||
asyncio task.
|
||||
"""
|
||||
parent = _inherited_permissions.get()
|
||||
|
||||
if permissions is None and parent is None:
|
||||
return None, None
|
||||
|
||||
all_tools = all_known_tool_names()
|
||||
|
||||
if permissions is None:
|
||||
permissions = CopilotPermissions() # allow-all; will be narrowed by parent
|
||||
|
||||
merged = (
|
||||
permissions.merged_with_parent(parent, all_tools)
|
||||
if parent is not None
|
||||
else permissions
|
||||
)
|
||||
|
||||
# Store merged permissions as the new inherited ceiling for nested calls.
|
||||
# Return the token so the caller can restore the previous value in finally.
|
||||
token = _inherited_permissions.set(merged)
|
||||
return merged, token
|
||||
@@ -0,0 +1,265 @@
|
||||
"""Tests for AutoPilotBlock permission fields and validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.blocks.autopilot import (
|
||||
AutoPilotBlock,
|
||||
_build_and_validate_permissions,
|
||||
_inherited_permissions,
|
||||
_merge_inherited_permissions,
|
||||
)
|
||||
from backend.copilot.permissions import CopilotPermissions, all_known_tool_names
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_input(**kwargs) -> AutoPilotBlock.Input:
|
||||
defaults = {
|
||||
"prompt": "Do something",
|
||||
"system_context": "",
|
||||
"session_id": "",
|
||||
"max_recursion_depth": 3,
|
||||
"tools": [],
|
||||
"tools_exclude": True,
|
||||
"blocks": [],
|
||||
"blocks_exclude": True,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return AutoPilotBlock.Input(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_and_validate_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestBuildAndValidatePermissions:
|
||||
async def test_empty_inputs_returns_empty_permissions(self):
|
||||
inp = _make_input()
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.is_empty()
|
||||
|
||||
async def test_valid_tool_names_accepted(self):
|
||||
inp = _make_input(tools=["run_block", "web_fetch"], tools_exclude=True)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.tools == ["run_block", "web_fetch"]
|
||||
assert result.tools_exclude is True
|
||||
|
||||
async def test_invalid_tool_rejected_by_pydantic(self):
|
||||
"""Invalid tool names are now caught at Pydantic validation time
|
||||
(Literal type), before ``_build_and_validate_permissions`` is called."""
|
||||
with pytest.raises(ValidationError, match="not_a_real_tool"):
|
||||
_make_input(tools=["not_a_real_tool"])
|
||||
|
||||
async def test_valid_block_name_accepted(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.blocks == ["HTTP Request"]
|
||||
|
||||
async def test_valid_partial_uuid_accepted(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["c069dc6b"], blocks_exclude=False)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
|
||||
async def test_invalid_block_identifier_returns_error(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["totally_fake_block"])
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, str)
|
||||
assert "totally_fake_block" in result
|
||||
assert "Unknown block identifier" in result
|
||||
|
||||
async def test_sdk_builtin_tool_names_accepted(self):
|
||||
inp = _make_input(tools=["Read", "Task", "WebSearch"], tools_exclude=False)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert not result.tools_exclude
|
||||
|
||||
async def test_empty_blocks_skips_validation(self):
|
||||
# Should not call validate_block_identifiers at all when blocks=[].
|
||||
with patch(
|
||||
"backend.copilot.permissions.validate_block_identifiers"
|
||||
) as mock_validate:
|
||||
inp = _make_input(blocks=[])
|
||||
await _build_and_validate_permissions(inp)
|
||||
mock_validate.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _merge_inherited_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMergeInheritedPermissions:
|
||||
def test_no_permissions_no_parent_returns_none(self):
|
||||
merged, token = _merge_inherited_permissions(None)
|
||||
assert merged is None
|
||||
assert token is None
|
||||
|
||||
def test_permissions_no_parent_returned_unchanged(self):
|
||||
perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
merged, token = _merge_inherited_permissions(perms)
|
||||
try:
|
||||
assert merged is perms
|
||||
assert token is not None
|
||||
finally:
|
||||
if token is not None:
|
||||
_inherited_permissions.reset(token)
|
||||
|
||||
def test_child_narrows_parent(self):
|
||||
parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
# Set parent as inherited
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
child = CopilotPermissions(tools=["web_fetch"], tools_exclude=True)
|
||||
merged, inner_token = _merge_inherited_permissions(child)
|
||||
try:
|
||||
assert merged is not None
|
||||
all_t = all_known_tool_names()
|
||||
effective = merged.effective_allowed_tools(all_t)
|
||||
assert "bash_exec" not in effective
|
||||
assert "web_fetch" not in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
def test_none_permissions_with_parent_uses_parent(self):
|
||||
parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
merged, inner_token = _merge_inherited_permissions(None)
|
||||
try:
|
||||
assert merged is not None
|
||||
# Merged should have parent's restrictions
|
||||
effective = merged.effective_allowed_tools(all_known_tool_names())
|
||||
assert "bash_exec" not in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
def test_child_cannot_expand_parent_whitelist(self):
|
||||
parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
# Child tries to allow more tools
|
||||
child = CopilotPermissions(
|
||||
tools=["run_block", "bash_exec"], tools_exclude=False
|
||||
)
|
||||
merged, inner_token = _merge_inherited_permissions(child)
|
||||
try:
|
||||
assert merged is not None
|
||||
effective = merged.effective_allowed_tools(all_known_tool_names())
|
||||
assert "bash_exec" not in effective
|
||||
assert "run_block" in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AutoPilotBlock.run — validation integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAutoPilotBlockRunPermissions:
|
||||
async def _collect_outputs(self, block, input_data, user_id="test-user"):
|
||||
"""Helper to collect all yields from block.run()."""
|
||||
ctx = ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id="g1",
|
||||
graph_exec_id="ge1",
|
||||
node_exec_id="ne1",
|
||||
node_id="n1",
|
||||
)
|
||||
outputs = {}
|
||||
async for key, val in block.run(input_data, execution_context=ctx):
|
||||
outputs[key] = val
|
||||
return outputs
|
||||
|
||||
async def test_invalid_tool_rejected_by_pydantic(self):
|
||||
"""Invalid tool names are caught at Pydantic validation (Literal type)."""
|
||||
with pytest.raises(ValidationError, match="not_a_tool"):
|
||||
_make_input(tools=["not_a_tool"])
|
||||
|
||||
async def test_invalid_block_yields_error(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
block = AutoPilotBlock()
|
||||
inp = _make_input(blocks=["nonexistent_block"])
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
assert "error" in outputs
|
||||
assert "nonexistent_block" in outputs["error"]
|
||||
|
||||
async def test_empty_prompt_yields_error_before_permission_check(self):
|
||||
block = AutoPilotBlock()
|
||||
inp = _make_input(prompt=" ", tools=["run_block"])
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
assert "error" in outputs
|
||||
assert "Prompt cannot be empty" in outputs["error"]
|
||||
|
||||
async def test_valid_permissions_passed_to_execute(self):
|
||||
"""Permissions are forwarded to execute_copilot when valid."""
|
||||
block = AutoPilotBlock()
|
||||
captured: dict = {}
|
||||
|
||||
async def fake_execute_copilot(self_inner, **kwargs):
|
||||
captured["permissions"] = kwargs.get("permissions")
|
||||
return (
|
||||
"ok",
|
||||
[],
|
||||
'[{"role":"user","content":"hi"}]',
|
||||
"test-sid",
|
||||
{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
AutoPilotBlock, "create_session", new=AsyncMock(return_value="test-sid")
|
||||
), patch.object(AutoPilotBlock, "execute_copilot", new=fake_execute_copilot):
|
||||
inp = _make_input(tools=["run_block"], tools_exclude=False)
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
|
||||
assert "error" not in outputs
|
||||
perms = captured.get("permissions")
|
||||
assert isinstance(perms, CopilotPermissions)
|
||||
assert perms.tools == ["run_block"]
|
||||
assert perms.tools_exclude is False
|
||||
@@ -472,7 +472,7 @@ class AddToListBlock(Block):
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
entries_added = input_data.entries.copy()
|
||||
if input_data.entry:
|
||||
if input_data.entry is not None:
|
||||
entries_added.append(input_data.entry)
|
||||
|
||||
updated_list = input_data.list.copy()
|
||||
|
||||
@@ -21,6 +21,7 @@ from backend.data.model import (
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import resolve_and_check_blocked
|
||||
|
||||
TEST_CREDENTIALS = UserPasswordCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
@@ -99,6 +100,8 @@ class SendEmailBlock(Block):
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
ALLOWED_SMTP_PORTS = {25, 465, 587, 2525}
|
||||
|
||||
@staticmethod
|
||||
def send_email(
|
||||
config: SMTPConfig,
|
||||
@@ -129,6 +132,17 @@ class SendEmailBlock(Block):
|
||||
self, input_data: Input, *, credentials: SMTPCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
# --- SSRF Protection ---
|
||||
smtp_port = input_data.config.smtp_port
|
||||
if smtp_port not in self.ALLOWED_SMTP_PORTS:
|
||||
yield "error", (
|
||||
f"SMTP port {smtp_port} is not allowed. "
|
||||
f"Allowed ports: {sorted(self.ALLOWED_SMTP_PORTS)}"
|
||||
)
|
||||
return
|
||||
|
||||
await resolve_and_check_blocked(input_data.config.smtp_server)
|
||||
|
||||
status = self.send_email(
|
||||
config=input_data.config,
|
||||
to_email=input_data.to_email,
|
||||
@@ -180,7 +194,19 @@ class SendEmailBlock(Block):
|
||||
"was rejected by the server. "
|
||||
"Please verify your account is authorized to send emails."
|
||||
)
|
||||
except smtplib.SMTPConnectError:
|
||||
yield "error", (
|
||||
f"Cannot connect to SMTP server '{input_data.config.smtp_server}' "
|
||||
f"on port {input_data.config.smtp_port}."
|
||||
)
|
||||
except smtplib.SMTPServerDisconnected:
|
||||
yield "error", (
|
||||
f"SMTP server '{input_data.config.smtp_server}' "
|
||||
"disconnected unexpectedly."
|
||||
)
|
||||
except smtplib.SMTPDataError as e:
|
||||
yield "error", f"Email data rejected by server: {str(e)}"
|
||||
except ValueError as e:
|
||||
yield "error", str(e)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@@ -34,17 +34,29 @@ TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class FluxKontextModelName(str, Enum):
|
||||
PRO = "Flux Kontext Pro"
|
||||
MAX = "Flux Kontext Max"
|
||||
class ImageEditorModel(str, Enum):
|
||||
FLUX_KONTEXT_PRO = "Flux Kontext Pro"
|
||||
FLUX_KONTEXT_MAX = "Flux Kontext Max"
|
||||
NANO_BANANA_PRO = "Nano Banana Pro"
|
||||
NANO_BANANA_2 = "Nano Banana 2"
|
||||
|
||||
@property
|
||||
def api_name(self) -> str:
|
||||
return f"black-forest-labs/flux-kontext-{self.name.lower()}"
|
||||
_map = {
|
||||
"FLUX_KONTEXT_PRO": "black-forest-labs/flux-kontext-pro",
|
||||
"FLUX_KONTEXT_MAX": "black-forest-labs/flux-kontext-max",
|
||||
"NANO_BANANA_PRO": "google/nano-banana-pro",
|
||||
"NANO_BANANA_2": "google/nano-banana-2",
|
||||
}
|
||||
return _map[self.name]
|
||||
|
||||
|
||||
# Keep old name as alias for backwards compatibility
|
||||
FluxKontextModelName = ImageEditorModel
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
@@ -69,7 +81,7 @@ class AIImageEditorBlock(Block):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REPLICATE], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Replicate API key with permissions for Flux Kontext models",
|
||||
description="Replicate API key with permissions for Flux Kontext and Nano Banana models",
|
||||
)
|
||||
prompt: str = SchemaField(
|
||||
description="Text instruction describing the desired edit",
|
||||
@@ -87,14 +99,14 @@ class AIImageEditorBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
seed: Optional[int] = SchemaField(
|
||||
description="Random seed. Set for reproducible generation",
|
||||
description="Random seed. Set for reproducible generation (Flux Kontext only; ignored by Nano Banana models)",
|
||||
default=None,
|
||||
title="Seed",
|
||||
advanced=True,
|
||||
)
|
||||
model: FluxKontextModelName = SchemaField(
|
||||
model: ImageEditorModel = SchemaField(
|
||||
description="Model variant to use",
|
||||
default=FluxKontextModelName.PRO,
|
||||
default=ImageEditorModel.NANO_BANANA_2,
|
||||
title="Model",
|
||||
)
|
||||
|
||||
@@ -107,7 +119,7 @@ class AIImageEditorBlock(Block):
|
||||
super().__init__(
|
||||
id="3fd9c73d-4370-4925-a1ff-1b86b99fabfa",
|
||||
description=(
|
||||
"Edit images using BlackForest Labs' Flux Kontext models. Provide a prompt "
|
||||
"Edit images using Flux Kontext or Google Nano Banana models. Provide a prompt "
|
||||
"and optional reference image to generate a modified image."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
@@ -118,7 +130,7 @@ class AIImageEditorBlock(Block):
|
||||
"input_image": "data:image/png;base64,MQ==",
|
||||
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
|
||||
"seed": None,
|
||||
"model": FluxKontextModelName.PRO,
|
||||
"model": ImageEditorModel.NANO_BANANA_2,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
@@ -127,7 +139,9 @@ class AIImageEditorBlock(Block):
|
||||
],
|
||||
test_mock={
|
||||
# Use data URI to avoid HTTP requests during tests
|
||||
"run_model": lambda *args, **kwargs: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
|
||||
"run_model": lambda *args, **kwargs: (
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
),
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
@@ -142,7 +156,7 @@ class AIImageEditorBlock(Block):
|
||||
) -> BlockOutput:
|
||||
result = await self.run_model(
|
||||
api_key=credentials.api_key,
|
||||
model_name=input_data.model.api_name,
|
||||
model=input_data.model,
|
||||
prompt=input_data.prompt,
|
||||
input_image_b64=(
|
||||
await store_media_file(
|
||||
@@ -169,7 +183,7 @@ class AIImageEditorBlock(Block):
|
||||
async def run_model(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
model_name: str,
|
||||
model: ImageEditorModel,
|
||||
prompt: str,
|
||||
input_image_b64: Optional[str],
|
||||
aspect_ratio: str,
|
||||
@@ -178,12 +192,29 @@ class AIImageEditorBlock(Block):
|
||||
graph_exec_id: str,
|
||||
) -> MediaFileType:
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"input_image": input_image_b64,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
**({"seed": seed} if seed is not None else {}),
|
||||
}
|
||||
model_name = model.api_name
|
||||
|
||||
is_nano_banana = model in (
|
||||
ImageEditorModel.NANO_BANANA_PRO,
|
||||
ImageEditorModel.NANO_BANANA_2,
|
||||
)
|
||||
if is_nano_banana:
|
||||
input_params: dict = {
|
||||
"prompt": prompt,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"output_format": "jpg",
|
||||
"safety_filter_level": "block_only_high",
|
||||
}
|
||||
# NB API expects "image_input" as a list, unlike Flux's single "input_image"
|
||||
if input_image_b64:
|
||||
input_params["image_input"] = [input_image_b64]
|
||||
else:
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"input_image": input_image_b64,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
**({"seed": seed} if seed is not None else {}),
|
||||
}
|
||||
|
||||
try:
|
||||
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
|
||||
|
||||
@@ -211,7 +211,7 @@ class AgentOutputBlock(Block):
|
||||
if input_data.format:
|
||||
try:
|
||||
formatter = TextFormatter(autoescape=input_data.escape_html)
|
||||
yield "output", formatter.format_string(
|
||||
yield "output", await formatter.format_string(
|
||||
input_data.format, {input_data.name: input_data.value}
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -33,6 +33,13 @@ from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.openai_responses import (
|
||||
convert_tools_to_responses_format,
|
||||
extract_responses_content,
|
||||
extract_responses_reasoning,
|
||||
extract_responses_tool_calls,
|
||||
extract_responses_usage,
|
||||
)
|
||||
from backend.util.prompt import compress_context, estimate_token_count
|
||||
from backend.util.request import validate_url_host
|
||||
from backend.util.settings import Settings
|
||||
@@ -42,6 +49,9 @@ settings = Settings()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||
fmt = TextFormatter(autoescape=False)
|
||||
|
||||
# HTTP status codes for user-caused errors that should not be reported to Sentry.
|
||||
USER_ERROR_STATUS_CODES = (401, 403, 429)
|
||||
|
||||
LLMProviderName = Literal[
|
||||
ProviderName.AIML_API,
|
||||
ProviderName.ANTHROPIC,
|
||||
@@ -111,7 +121,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
GPT4O_MINI = "gpt-4o-mini"
|
||||
GPT4O = "gpt-4o"
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
GPT3_5_TURBO = "gpt-3.5-turbo"
|
||||
# Anthropic models
|
||||
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
|
||||
CLAUDE_4_OPUS = "claude-opus-4-20250514"
|
||||
@@ -277,9 +286,6 @@ MODEL_METADATA = {
|
||||
LlmModel.GPT4_TURBO: ModelMetadata(
|
||||
"openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3
|
||||
), # gpt-4-turbo-2024-04-09
|
||||
LlmModel.GPT3_5_TURBO: ModelMetadata(
|
||||
"openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
|
||||
), # gpt-3.5-turbo-0125
|
||||
# https://docs.anthropic.com/en/docs/about-claude/models
|
||||
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3
|
||||
@@ -793,6 +799,19 @@ async def llm_call(
|
||||
)
|
||||
prompt = result.messages
|
||||
|
||||
# Sanitize unpaired surrogates in message content to prevent
|
||||
# UnicodeEncodeError when httpx encodes the JSON request body.
|
||||
for msg in prompt:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
content.encode("utf-8")
|
||||
except UnicodeEncodeError:
|
||||
logger.warning("Sanitized unpaired surrogates in LLM prompt content")
|
||||
msg["content"] = content.encode("utf-8", errors="surrogatepass").decode(
|
||||
"utf-8", errors="replace"
|
||||
)
|
||||
|
||||
# Calculate available tokens based on context window and input length
|
||||
estimated_input_tokens = estimate_token_count(prompt)
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
@@ -801,36 +820,53 @@ async def llm_call(
|
||||
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
|
||||
|
||||
if provider == "openai":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
oai_client = openai.AsyncOpenAI(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = None
|
||||
|
||||
parallel_tool_calls = get_parallel_tool_calls_param(
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
tools_param = convert_tools_to_responses_format(tools) if tools else openai.omit
|
||||
|
||||
text_config = openai.omit
|
||||
if force_json_output:
|
||||
response_format = {"type": "json_object"}
|
||||
text_config = {"format": {"type": "json_object"}} # type: ignore
|
||||
|
||||
response = await oai_client.chat.completions.create(
|
||||
response = await oai_client.responses.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_completion_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
input=prompt, # type: ignore[arg-type]
|
||||
tools=tools_param, # type: ignore[arg-type]
|
||||
max_output_tokens=max_tokens,
|
||||
parallel_tool_calls=get_parallel_tool_calls_param(
|
||||
llm_model, parallel_tool_calls
|
||||
),
|
||||
text=text_config, # type: ignore[arg-type]
|
||||
store=False,
|
||||
)
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
raw_tool_calls = extract_responses_tool_calls(response)
|
||||
tool_calls = (
|
||||
[
|
||||
ToolContentBlock(
|
||||
id=tc["id"],
|
||||
type=tc["type"],
|
||||
function=ToolCall(
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
),
|
||||
)
|
||||
for tc in raw_tool_calls
|
||||
]
|
||||
if raw_tool_calls
|
||||
else None
|
||||
)
|
||||
reasoning = extract_responses_reasoning(response)
|
||||
content = extract_responses_content(response)
|
||||
prompt_tokens, completion_tokens = extract_responses_usage(response)
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
raw_response=response,
|
||||
prompt=prompt,
|
||||
response=response.choices[0].message.content or "",
|
||||
response=content,
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
@@ -858,65 +894,60 @@ async def llm_call(
|
||||
client = anthropic.AsyncAnthropic(
|
||||
api_key=credentials.api_key.get_secret_value()
|
||||
)
|
||||
try:
|
||||
resp = await client.messages.create(
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=an_tools,
|
||||
timeout=600,
|
||||
)
|
||||
resp = await client.messages.create(
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=an_tools,
|
||||
timeout=600,
|
||||
)
|
||||
|
||||
if not resp.content:
|
||||
raise ValueError("No content returned from Anthropic.")
|
||||
if not resp.content:
|
||||
raise ValueError("No content returned from Anthropic.")
|
||||
|
||||
tool_calls = None
|
||||
for content_block in resp.content:
|
||||
# Antropic is different to openai, need to iterate through
|
||||
# the content blocks to find the tool calls
|
||||
if content_block.type == "tool_use":
|
||||
if tool_calls is None:
|
||||
tool_calls = []
|
||||
tool_calls.append(
|
||||
ToolContentBlock(
|
||||
id=content_block.id,
|
||||
type=content_block.type,
|
||||
function=ToolCall(
|
||||
name=content_block.name,
|
||||
arguments=json.dumps(content_block.input),
|
||||
),
|
||||
)
|
||||
tool_calls = None
|
||||
for content_block in resp.content:
|
||||
# Antropic is different to openai, need to iterate through
|
||||
# the content blocks to find the tool calls
|
||||
if content_block.type == "tool_use":
|
||||
if tool_calls is None:
|
||||
tool_calls = []
|
||||
tool_calls.append(
|
||||
ToolContentBlock(
|
||||
id=content_block.id,
|
||||
type=content_block.type,
|
||||
function=ToolCall(
|
||||
name=content_block.name,
|
||||
arguments=json.dumps(content_block.input),
|
||||
),
|
||||
)
|
||||
|
||||
if not tool_calls and resp.stop_reason == "tool_use":
|
||||
logger.warning(
|
||||
f"Tool use stop reason but no tool calls found in content. {resp}"
|
||||
)
|
||||
|
||||
reasoning = None
|
||||
for content_block in resp.content:
|
||||
if hasattr(content_block, "type") and content_block.type == "thinking":
|
||||
reasoning = content_block.thinking
|
||||
break
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=resp,
|
||||
prompt=prompt,
|
||||
response=(
|
||||
resp.content[0].name
|
||||
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
|
||||
else getattr(resp.content[0], "text", "")
|
||||
),
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=resp.usage.input_tokens,
|
||||
completion_tokens=resp.usage.output_tokens,
|
||||
reasoning=reasoning,
|
||||
if not tool_calls and resp.stop_reason == "tool_use":
|
||||
logger.warning(
|
||||
f"Tool use stop reason but no tool calls found in content. {resp}"
|
||||
)
|
||||
except anthropic.APIError as e:
|
||||
error_message = f"Anthropic API error: {str(e)}"
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
reasoning = None
|
||||
for content_block in resp.content:
|
||||
if hasattr(content_block, "type") and content_block.type == "thinking":
|
||||
reasoning = content_block.thinking
|
||||
break
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=resp,
|
||||
prompt=prompt,
|
||||
response=(
|
||||
resp.content[0].name
|
||||
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
|
||||
else getattr(resp.content[0], "text", "")
|
||||
),
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=resp.usage.input_tokens,
|
||||
completion_tokens=resp.usage.output_tokens,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "groq":
|
||||
if tools:
|
||||
raise ValueError("Groq does not support tools.")
|
||||
@@ -1276,8 +1307,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
input_data.prompt = fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = fmt.format_string(input_data.sys_prompt, values)
|
||||
input_data.prompt = await fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = await fmt.format_string(
|
||||
input_data.sys_prompt, values
|
||||
)
|
||||
|
||||
if input_data.sys_prompt:
|
||||
prompt.append({"role": "system", "content": input_data.sys_prompt})
|
||||
@@ -1427,7 +1460,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
yield "prompt", self.prompt
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Error calling LLM: {e}")
|
||||
is_user_error = (
|
||||
isinstance(e, (anthropic.APIStatusError, openai.APIStatusError))
|
||||
and e.status_code in USER_ERROR_STATUS_CODES
|
||||
)
|
||||
if is_user_error:
|
||||
logger.warning(f"Error calling LLM: {e}")
|
||||
error_feedback_message = f"Error calling LLM: {e}"
|
||||
break
|
||||
else:
|
||||
logger.exception(f"Error calling LLM: {e}")
|
||||
if (
|
||||
"maximum context length" in str(e).lower()
|
||||
or "token limit" in str(e).lower()
|
||||
|
||||
@@ -61,20 +61,27 @@ class ExecutionParams(BaseModel):
|
||||
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Return a list of tool_call_ids if the entry is a tool request.
|
||||
Supports both OpenAI and Anthropics formats.
|
||||
Supports OpenAI Chat Completions, Responses API, and Anthropic formats.
|
||||
"""
|
||||
tool_call_ids = []
|
||||
|
||||
# OpenAI Responses API: function_call items have type="function_call"
|
||||
if entry.get("type") == "function_call":
|
||||
if call_id := entry.get("call_id"):
|
||||
tool_call_ids.append(call_id)
|
||||
return tool_call_ids
|
||||
|
||||
if entry.get("role") != "assistant":
|
||||
return tool_call_ids
|
||||
|
||||
# OpenAI: check for tool_calls in the entry.
|
||||
# OpenAI Chat Completions: check for tool_calls in the entry.
|
||||
calls = entry.get("tool_calls")
|
||||
if isinstance(calls, list):
|
||||
for call in calls:
|
||||
if tool_id := call.get("id"):
|
||||
tool_call_ids.append(tool_id)
|
||||
|
||||
# Anthropics: check content items for tool_use type.
|
||||
# Anthropic: check content items for tool_use type.
|
||||
content = entry.get("content")
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
@@ -89,16 +96,22 @@ def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
|
||||
def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Return a list of tool_call_ids if the entry is a tool response.
|
||||
Supports both OpenAI and Anthropics formats.
|
||||
Supports OpenAI Chat Completions, Responses API, and Anthropic formats.
|
||||
"""
|
||||
tool_call_ids: list[str] = []
|
||||
|
||||
# OpenAI: a tool response message with role "tool" and key "tool_call_id".
|
||||
# OpenAI Responses API: function_call_output items
|
||||
if entry.get("type") == "function_call_output":
|
||||
if call_id := entry.get("call_id"):
|
||||
tool_call_ids.append(str(call_id))
|
||||
return tool_call_ids
|
||||
|
||||
# OpenAI Chat Completions: a tool response message with role "tool".
|
||||
if entry.get("role") == "tool":
|
||||
if tool_call_id := entry.get("tool_call_id"):
|
||||
tool_call_ids.append(str(tool_call_id))
|
||||
|
||||
# Anthropics: check content items for tool_result type.
|
||||
# Anthropic: check content items for tool_result type.
|
||||
if entry.get("role") == "user":
|
||||
content = entry.get("content")
|
||||
if isinstance(content, list):
|
||||
@@ -111,14 +124,16 @@ def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
|
||||
return tool_call_ids
|
||||
|
||||
|
||||
def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
|
||||
def _create_tool_response(
|
||||
call_id: str, output: Any, *, responses_api: bool = False
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create a tool response message for either OpenAI or Anthropics,
|
||||
based on the tool_id format.
|
||||
Create a tool response message for OpenAI, Anthropic, or OpenAI Responses API,
|
||||
based on the tool_id format and the responses_api flag.
|
||||
"""
|
||||
content = output if isinstance(output, str) else json.dumps(output)
|
||||
|
||||
# Anthropics format: tool IDs typically start with "toolu_"
|
||||
# Anthropic format: tool IDs typically start with "toolu_"
|
||||
if call_id.startswith("toolu_"):
|
||||
return {
|
||||
"role": "user",
|
||||
@@ -128,8 +143,11 @@ def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
|
||||
],
|
||||
}
|
||||
|
||||
# OpenAI format: tool IDs typically start with "call_".
|
||||
# Or default fallback (if the tool_id doesn't match any known prefix)
|
||||
# OpenAI Responses API format
|
||||
if responses_api:
|
||||
return {"type": "function_call_output", "call_id": call_id, "output": content}
|
||||
|
||||
# OpenAI Chat Completions format (default fallback)
|
||||
return {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
|
||||
|
||||
@@ -177,10 +195,19 @@ def _combine_tool_responses(tool_outputs: list[dict[str, Any]]) -> list[dict[str
|
||||
return tool_outputs
|
||||
|
||||
|
||||
def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
|
||||
def _convert_raw_response_to_dict(
|
||||
raw_response: Any,
|
||||
) -> dict[str, Any] | list[dict[str, Any]]:
|
||||
"""
|
||||
Safely convert raw_response to dictionary format for conversation history.
|
||||
Handles different response types from different LLM providers.
|
||||
|
||||
For the OpenAI Responses API, the raw_response is the entire Response
|
||||
object. Its ``output`` items (messages, function_calls) are extracted
|
||||
individually so they can be used as valid input items on the next call.
|
||||
Returns a **list** of dicts in that case.
|
||||
|
||||
For Chat Completions / Anthropic / Ollama, returns a single dict.
|
||||
"""
|
||||
if isinstance(raw_response, str):
|
||||
# Ollama returns a string, convert to dict format
|
||||
@@ -188,11 +215,28 @@ def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
|
||||
elif isinstance(raw_response, dict):
|
||||
# Already a dict (from tests or some providers)
|
||||
return raw_response
|
||||
elif _is_responses_api_object(raw_response):
|
||||
# OpenAI Responses API: extract individual output items
|
||||
items = [json.to_dict(item) for item in raw_response.output]
|
||||
return items if items else [{"role": "assistant", "content": ""}]
|
||||
else:
|
||||
# OpenAI/Anthropic return objects, convert with json.to_dict
|
||||
# Chat Completions / Anthropic return message objects
|
||||
return json.to_dict(raw_response)
|
||||
|
||||
|
||||
def _is_responses_api_object(obj: Any) -> bool:
|
||||
"""Detect an OpenAI Responses API Response object.
|
||||
|
||||
These have ``object == "response"`` and an ``output`` list, but no
|
||||
``role`` attribute (unlike ChatCompletionMessage).
|
||||
"""
|
||||
return (
|
||||
getattr(obj, "object", None) == "response"
|
||||
and hasattr(obj, "output")
|
||||
and not hasattr(obj, "role")
|
||||
)
|
||||
|
||||
|
||||
def get_pending_tool_calls(conversation_history: list[Any] | None) -> dict[str, int]:
|
||||
"""
|
||||
All the tool calls entry in the conversation history requires a response.
|
||||
@@ -214,9 +258,10 @@ def get_pending_tool_calls(conversation_history: list[Any] | None) -> dict[str,
|
||||
return {call_id: count for call_id, count in pending_calls.items() if count > 0}
|
||||
|
||||
|
||||
class SmartDecisionMakerBlock(Block):
|
||||
class OrchestratorBlock(Block):
|
||||
"""
|
||||
A block that uses a language model to make smart decisions based on a given prompt.
|
||||
A block that uses a language model to orchestrate tool calls, supporting both
|
||||
single-shot and iterative agent mode execution.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
@@ -357,8 +402,8 @@ class SmartDecisionMakerBlock(Block):
|
||||
description="Uses AI to intelligently decide what tool to use.",
|
||||
categories={BlockCategory.AI},
|
||||
block_type=BlockType.AI,
|
||||
input_schema=SmartDecisionMakerBlock.Input,
|
||||
output_schema=SmartDecisionMakerBlock.Output,
|
||||
input_schema=OrchestratorBlock.Input,
|
||||
output_schema=OrchestratorBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
@@ -396,7 +441,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
tool_name = custom_name if custom_name else block.name
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": SmartDecisionMakerBlock.cleanup(tool_name),
|
||||
"name": OrchestratorBlock.cleanup(tool_name),
|
||||
"description": block.description,
|
||||
}
|
||||
sink_block_input_schema = block.input_schema
|
||||
@@ -407,7 +452,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
field_name = link.sink_name
|
||||
is_dynamic = is_dynamic_field(field_name)
|
||||
# Clean property key to ensure Anthropic API compatibility for ALL fields
|
||||
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
|
||||
clean_field_name = OrchestratorBlock.cleanup(field_name)
|
||||
field_mapping[clean_field_name] = field_name
|
||||
|
||||
if is_dynamic:
|
||||
@@ -441,7 +486,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
field_name = link.sink_name
|
||||
is_dynamic = is_dynamic_field(field_name)
|
||||
# Always use cleaned field name for property key (Anthropic API compliance)
|
||||
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
|
||||
clean_field_name = OrchestratorBlock.cleanup(field_name)
|
||||
|
||||
if is_dynamic:
|
||||
base_name = extract_base_field_name(field_name)
|
||||
@@ -498,7 +543,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
tool_name = custom_name if custom_name else sink_graph_meta.name
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": SmartDecisionMakerBlock.cleanup(tool_name),
|
||||
"name": OrchestratorBlock.cleanup(tool_name),
|
||||
"description": sink_graph_meta.description,
|
||||
}
|
||||
|
||||
@@ -508,7 +553,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
for link in links:
|
||||
field_name = link.sink_name
|
||||
|
||||
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
|
||||
clean_field_name = OrchestratorBlock.cleanup(field_name)
|
||||
field_mapping[clean_field_name] = field_name
|
||||
|
||||
sink_block_input_schema = sink_node.input_default["input_schema"]
|
||||
@@ -574,17 +619,13 @@ class SmartDecisionMakerBlock(Block):
|
||||
raise ValueError(f"Sink node not found: {links[0].sink_id}")
|
||||
|
||||
if sink_node.block_id == AgentExecutorBlock().id:
|
||||
tool_func = (
|
||||
await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
tool_func = await OrchestratorBlock._create_agent_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
return_tool_functions.append(tool_func)
|
||||
else:
|
||||
tool_func = (
|
||||
await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
tool_func = await OrchestratorBlock._create_block_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
return_tool_functions.append(tool_func)
|
||||
|
||||
@@ -754,19 +795,34 @@ class SmartDecisionMakerBlock(Block):
|
||||
self, prompt: list[dict], response, tool_outputs: list | None = None
|
||||
):
|
||||
"""Update conversation history with response and tool outputs."""
|
||||
# Don't add separate reasoning message with tool calls (breaks Anthropic's tool_use->tool_result pairing)
|
||||
assistant_message = _convert_raw_response_to_dict(response.raw_response)
|
||||
has_tool_calls = isinstance(assistant_message.get("content"), list) and any(
|
||||
item.get("type") == "tool_use"
|
||||
for item in assistant_message.get("content", [])
|
||||
)
|
||||
converted = _convert_raw_response_to_dict(response.raw_response)
|
||||
|
||||
if response.reasoning and not has_tool_calls:
|
||||
prompt.append(
|
||||
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
|
||||
if isinstance(converted, list):
|
||||
# Responses API: output items are already individual dicts
|
||||
has_tool_calls = any(
|
||||
item.get("type") == "function_call" for item in converted
|
||||
)
|
||||
|
||||
prompt.append(assistant_message)
|
||||
if response.reasoning and not has_tool_calls:
|
||||
prompt.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"[Reasoning]: {response.reasoning}",
|
||||
}
|
||||
)
|
||||
prompt.extend(converted)
|
||||
else:
|
||||
# Chat Completions / Anthropic: single assistant message dict
|
||||
has_tool_calls = isinstance(converted.get("content"), list) and any(
|
||||
item.get("type") == "tool_use" for item in converted.get("content", [])
|
||||
)
|
||||
if response.reasoning and not has_tool_calls:
|
||||
prompt.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"[Reasoning]: {response.reasoning}",
|
||||
}
|
||||
)
|
||||
prompt.append(converted)
|
||||
|
||||
if tool_outputs:
|
||||
prompt.extend(tool_outputs)
|
||||
@@ -776,6 +832,8 @@ class SmartDecisionMakerBlock(Block):
|
||||
tool_info: ToolInfo,
|
||||
execution_params: ExecutionParams,
|
||||
execution_processor: "ExecutionProcessor",
|
||||
*,
|
||||
responses_api: bool = False,
|
||||
) -> dict:
|
||||
"""Execute a single tool using the execution manager for proper integration."""
|
||||
# Lazy imports to avoid circular dependencies
|
||||
@@ -847,7 +905,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
task=node_exec_future,
|
||||
)
|
||||
|
||||
# Execute the node directly since we're in the SmartDecisionMaker context
|
||||
# Execute the node directly since we're in the Orchestrator context
|
||||
node_exec_future.set_result(
|
||||
await execution_processor.on_node_execution(
|
||||
node_exec=node_exec_entry,
|
||||
@@ -868,13 +926,17 @@ class SmartDecisionMakerBlock(Block):
|
||||
if node_outputs
|
||||
else "Tool executed successfully"
|
||||
)
|
||||
return _create_tool_response(tool_call.id, tool_response_content)
|
||||
return _create_tool_response(
|
||||
tool_call.id, tool_response_content, responses_api=responses_api
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution with manager failed: {e}")
|
||||
logger.warning(f"Tool execution with manager failed: {e}")
|
||||
# Return error response
|
||||
return _create_tool_response(
|
||||
tool_call.id, f"Tool execution failed: {str(e)}"
|
||||
tool_call.id,
|
||||
f"Tool execution failed: {str(e)}",
|
||||
responses_api=responses_api,
|
||||
)
|
||||
|
||||
async def _execute_tools_agent_mode(
|
||||
@@ -895,6 +957,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
"""Execute tools in agent mode with a loop until finished."""
|
||||
max_iterations = input_data.agent_mode_max_iterations
|
||||
iteration = 0
|
||||
use_responses_api = input_data.model.metadata.provider == "openai"
|
||||
|
||||
# Execution parameters for tool execution
|
||||
execution_params = ExecutionParams(
|
||||
@@ -951,14 +1014,19 @@ class SmartDecisionMakerBlock(Block):
|
||||
for tool_info in processed_tools:
|
||||
try:
|
||||
tool_response = await self._execute_single_tool_with_manager(
|
||||
tool_info, execution_params, execution_processor
|
||||
tool_info,
|
||||
execution_params,
|
||||
execution_processor,
|
||||
responses_api=use_responses_api,
|
||||
)
|
||||
tool_outputs.append(tool_response)
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution failed: {e}")
|
||||
# Create error response for the tool
|
||||
error_response = _create_tool_response(
|
||||
tool_info.tool_call.id, f"Error: {str(e)}"
|
||||
tool_info.tool_call.id,
|
||||
f"Error: {str(e)}",
|
||||
responses_api=use_responses_api,
|
||||
)
|
||||
tool_outputs.append(error_response)
|
||||
|
||||
@@ -1020,11 +1088,17 @@ class SmartDecisionMakerBlock(Block):
|
||||
if pending_tool_calls and input_data.last_tool_output is None:
|
||||
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
|
||||
|
||||
use_responses_api = input_data.model.metadata.provider == "openai"
|
||||
|
||||
tool_output = []
|
||||
if pending_tool_calls and input_data.last_tool_output is not None:
|
||||
first_call_id = next(iter(pending_tool_calls.keys()))
|
||||
tool_output.append(
|
||||
_create_tool_response(first_call_id, input_data.last_tool_output)
|
||||
_create_tool_response(
|
||||
first_call_id,
|
||||
input_data.last_tool_output,
|
||||
responses_api=use_responses_api,
|
||||
)
|
||||
)
|
||||
|
||||
prompt.extend(tool_output)
|
||||
@@ -1035,7 +1109,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
return
|
||||
elif input_data.last_tool_output:
|
||||
logger.error(
|
||||
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
|
||||
f"[OrchestratorBlock-node_exec_id={node_exec_id}] "
|
||||
f"No pending tool calls found. This may indicate an issue with the "
|
||||
f"conversation history, or the tool giving response more than once."
|
||||
f"This should not happen! Please check the conversation history for any inconsistencies."
|
||||
@@ -1050,11 +1124,15 @@ class SmartDecisionMakerBlock(Block):
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)
|
||||
input_data.prompt = await llm.fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = await llm.fmt.format_string(
|
||||
input_data.sys_prompt, values
|
||||
)
|
||||
|
||||
if input_data.sys_prompt and not any(
|
||||
p["role"] == "system" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
|
||||
p.get("role") == "system"
|
||||
and isinstance(p.get("content"), str)
|
||||
and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
|
||||
for p in prompt
|
||||
):
|
||||
prompt.append(
|
||||
@@ -1065,7 +1143,9 @@ class SmartDecisionMakerBlock(Block):
|
||||
)
|
||||
|
||||
if input_data.prompt and not any(
|
||||
p["role"] == "user" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
|
||||
p.get("role") == "user"
|
||||
and isinstance(p.get("content"), str)
|
||||
and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
|
||||
for p in prompt
|
||||
):
|
||||
prompt.append(
|
||||
@@ -1166,18 +1246,33 @@ class SmartDecisionMakerBlock(Block):
|
||||
emit_key = f"tools_^_{sink_node_id}_~_{original_field_name}"
|
||||
|
||||
logger.debug(
|
||||
"[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",
|
||||
"[OrchestratorBlock|geid:%s|neid:%s] emit %s",
|
||||
graph_exec_id,
|
||||
node_exec_id,
|
||||
emit_key,
|
||||
)
|
||||
yield emit_key, arg_value
|
||||
|
||||
if response.reasoning:
|
||||
converted = _convert_raw_response_to_dict(response.raw_response)
|
||||
|
||||
# Check for tool calls to avoid inserting reasoning between tool pairs
|
||||
if isinstance(converted, list):
|
||||
has_tool_calls = any(
|
||||
item.get("type") == "function_call" for item in converted
|
||||
)
|
||||
else:
|
||||
has_tool_calls = isinstance(converted.get("content"), list) and any(
|
||||
item.get("type") == "tool_use" for item in converted.get("content", [])
|
||||
)
|
||||
|
||||
if response.reasoning and not has_tool_calls:
|
||||
prompt.append(
|
||||
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
|
||||
)
|
||||
|
||||
prompt.append(_convert_raw_response_to_dict(response.raw_response))
|
||||
if isinstance(converted, list):
|
||||
prompt.extend(converted)
|
||||
else:
|
||||
prompt.append(converted)
|
||||
|
||||
yield "conversations", prompt
|
||||
@@ -1,13 +1,8 @@
|
||||
import logging
|
||||
import signal
|
||||
import threading
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
|
||||
# Monkey patch Stagehands to prevent signal handling in worker threads
|
||||
import stagehand.main
|
||||
from stagehand import Stagehand
|
||||
from stagehand import AsyncStagehand
|
||||
from stagehand.types.session_act_params import Options as ActOptions
|
||||
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
@@ -28,46 +23,6 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
# Suppress false positive cleanup warning of litellm (a dependency of stagehand)
|
||||
warnings.filterwarnings("ignore", module="litellm.llms.custom_httpx")
|
||||
|
||||
# Store the original method
|
||||
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers
|
||||
|
||||
|
||||
def safe_register_signal_handlers(self):
|
||||
"""Only register signal handlers in the main thread"""
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
original_register_signal_handlers(self)
|
||||
else:
|
||||
# Skip signal handling in worker threads
|
||||
pass
|
||||
|
||||
|
||||
# Replace the method
|
||||
stagehand.main.Stagehand._register_signal_handlers = safe_register_signal_handlers
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_signal_handling():
|
||||
"""Context manager to temporarily disable signal handling"""
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
# In worker threads, temporarily replace signal.signal with a no-op
|
||||
original_signal = signal.signal
|
||||
|
||||
def noop_signal(*args, **kwargs):
|
||||
pass
|
||||
|
||||
signal.signal = noop_signal
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.signal = original_signal
|
||||
else:
|
||||
# In main thread, don't modify anything
|
||||
yield
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -148,13 +103,10 @@ class StagehandObserveBlock(Block):
|
||||
instruction: str = SchemaField(
|
||||
description="Natural language description of elements or actions to discover.",
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
dom_settle_timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms to wait for the DOM to settle after navigation.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
@@ -185,32 +137,28 @@ class StagehandObserveBlock(Block):
|
||||
|
||||
logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
async with AsyncStagehand(
|
||||
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
browserbase_project_id=input_data.browserbase_project_id,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
) as client:
|
||||
session = await client.sessions.start(
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
|
||||
)
|
||||
try:
|
||||
await session.navigate(url=input_data.url)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
|
||||
observe_results = await page.observe(
|
||||
input_data.instruction,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
)
|
||||
for result in observe_results:
|
||||
yield "selector", result.selector
|
||||
yield "description", result.description
|
||||
yield "method", result.method
|
||||
yield "arguments", result.arguments
|
||||
observe_response = await session.observe(
|
||||
instruction=input_data.instruction,
|
||||
)
|
||||
for result in observe_response.data.result:
|
||||
yield "selector", result.selector
|
||||
yield "description", result.description
|
||||
yield "method", result.method
|
||||
yield "arguments", result.arguments
|
||||
finally:
|
||||
await session.end()
|
||||
|
||||
|
||||
class StagehandActBlock(Block):
|
||||
@@ -242,24 +190,22 @@ class StagehandActBlock(Block):
|
||||
description="Variables to use in the action. Variables contains data you want the action to use.",
|
||||
default_factory=dict,
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
dom_settle_timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms to wait for the DOM to settle after navigation.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
)
|
||||
timeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM ready. Extended timeout for slow-loading forms",
|
||||
default=60000,
|
||||
timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms for each action.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the action was completed successfully"
|
||||
)
|
||||
message: str = SchemaField(description="Details about the action’s execution.")
|
||||
message: str = SchemaField(description="Details about the action's execution.")
|
||||
action: str = SchemaField(description="Action performed")
|
||||
|
||||
def __init__(self):
|
||||
@@ -282,32 +228,33 @@ class StagehandActBlock(Block):
|
||||
|
||||
logger.debug(f"ACT: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
async with AsyncStagehand(
|
||||
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
browserbase_project_id=input_data.browserbase_project_id,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
) as client:
|
||||
session = await client.sessions.start(
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
|
||||
)
|
||||
try:
|
||||
await session.navigate(url=input_data.url)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
for action in input_data.action:
|
||||
action_results = await page.act(
|
||||
action,
|
||||
variables=input_data.variables,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
timeoutMs=input_data.timeoutMs,
|
||||
)
|
||||
yield "success", action_results.success
|
||||
yield "message", action_results.message
|
||||
yield "action", action_results.action
|
||||
for action in input_data.action:
|
||||
act_options = ActOptions(
|
||||
variables={k: v for k, v in input_data.variables.items()},
|
||||
timeout=input_data.timeout_ms,
|
||||
)
|
||||
act_response = await session.act(
|
||||
input=action,
|
||||
options=act_options,
|
||||
)
|
||||
result = act_response.data.result
|
||||
yield "success", result.success
|
||||
yield "message", result.message
|
||||
yield "action", result.action_description
|
||||
finally:
|
||||
await session.end()
|
||||
|
||||
|
||||
class StagehandExtractBlock(Block):
|
||||
@@ -335,13 +282,10 @@ class StagehandExtractBlock(Block):
|
||||
instruction: str = SchemaField(
|
||||
description="Natural language description of elements or actions to discover.",
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
dom_settle_timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms to wait for the DOM to settle after navigation.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
@@ -367,24 +311,21 @@ class StagehandExtractBlock(Block):
|
||||
|
||||
logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
async with AsyncStagehand(
|
||||
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
browserbase_project_id=input_data.browserbase_project_id,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
) as client:
|
||||
session = await client.sessions.start(
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
|
||||
)
|
||||
try:
|
||||
await session.navigate(url=input_data.url)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
extraction = await page.extract(
|
||||
input_data.instruction,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
)
|
||||
yield "extraction", str(extraction.model_dump()["extraction"])
|
||||
extract_response = await session.extract(
|
||||
instruction=input_data.instruction,
|
||||
)
|
||||
yield "extraction", str(extract_response.data.result)
|
||||
finally:
|
||||
await session.end()
|
||||
|
||||
223
autogpt_platform/backend/backend/blocks/test/test_autopilot.py
Normal file
223
autogpt_platform/backend/backend/blocks/test/test_autopilot.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""Tests for AutoPilotBlock: recursion guard, streaming, validation, and error paths."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.autopilot import (
|
||||
AUTOPILOT_BLOCK_ID,
|
||||
AutoPilotBlock,
|
||||
_autopilot_recursion_depth,
|
||||
_autopilot_recursion_limit,
|
||||
_check_recursion,
|
||||
_reset_recursion,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
|
||||
def _make_context(user_id: str = "test-user-123") -> ExecutionContext:
|
||||
"""Helper to build an ExecutionContext for tests."""
|
||||
return ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id="graph-1",
|
||||
graph_exec_id="gexec-1",
|
||||
graph_version=1,
|
||||
node_id="node-1",
|
||||
node_exec_id="nexec-1",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recursion guard unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRecursion:
|
||||
"""Unit tests for _check_recursion / _reset_recursion."""
|
||||
|
||||
def test_first_call_increments_depth(self):
|
||||
tokens = _check_recursion(3)
|
||||
try:
|
||||
assert _autopilot_recursion_depth.get() == 1
|
||||
assert _autopilot_recursion_limit.get() == 3
|
||||
finally:
|
||||
_reset_recursion(tokens)
|
||||
|
||||
def test_reset_restores_previous_values(self):
|
||||
assert _autopilot_recursion_depth.get() == 0
|
||||
assert _autopilot_recursion_limit.get() is None
|
||||
tokens = _check_recursion(5)
|
||||
_reset_recursion(tokens)
|
||||
assert _autopilot_recursion_depth.get() == 0
|
||||
assert _autopilot_recursion_limit.get() is None
|
||||
|
||||
def test_exceeding_limit_raises(self):
|
||||
t1 = _check_recursion(2)
|
||||
try:
|
||||
t2 = _check_recursion(2)
|
||||
try:
|
||||
with pytest.raises(RuntimeError, match="recursion depth limit"):
|
||||
_check_recursion(2)
|
||||
finally:
|
||||
_reset_recursion(t2)
|
||||
finally:
|
||||
_reset_recursion(t1)
|
||||
|
||||
def test_nested_calls_respect_inherited_limit(self):
|
||||
"""Inner call with higher max_depth still respects outer limit."""
|
||||
t1 = _check_recursion(2) # sets limit=2
|
||||
try:
|
||||
t2 = _check_recursion(10) # inner wants 10, but inherited is 2
|
||||
try:
|
||||
# depth is now 2, limit is min(10, 2) = 2 → should raise
|
||||
with pytest.raises(RuntimeError, match="recursion depth limit"):
|
||||
_check_recursion(10)
|
||||
finally:
|
||||
_reset_recursion(t2)
|
||||
finally:
|
||||
_reset_recursion(t1)
|
||||
|
||||
def test_limit_of_one_blocks_immediately_on_second_call(self):
|
||||
t1 = _check_recursion(1)
|
||||
try:
|
||||
with pytest.raises(RuntimeError):
|
||||
_check_recursion(1)
|
||||
finally:
|
||||
_reset_recursion(t1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AutoPilotBlock.run() validation tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunValidation:
|
||||
"""Tests for input validation in AutoPilotBlock.run()."""
|
||||
|
||||
@pytest.fixture
|
||||
def block(self):
|
||||
return AutoPilotBlock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_prompt_yields_error(self, block):
|
||||
block.Input # ensure schema is accessible
|
||||
input_data = block.Input(prompt=" ", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
assert outputs.get("error") == "Prompt cannot be empty."
|
||||
assert "response" not in outputs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_user_id_yields_error(self, block):
|
||||
input_data = block.Input(prompt="hello", max_recursion_depth=3)
|
||||
ctx = _make_context(user_id="")
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
assert "authenticated user" in outputs.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_run_yields_all_outputs(self, block):
|
||||
"""With execute_copilot mocked, run() should yield all 5 success outputs."""
|
||||
mock_result = (
|
||||
"Hello world",
|
||||
[],
|
||||
'[{"role":"user","content":"hi"}]',
|
||||
"sess-abc",
|
||||
{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
block.execute_copilot = AsyncMock(return_value=mock_result)
|
||||
block.create_session = AsyncMock(return_value="sess-abc")
|
||||
|
||||
input_data = block.Input(prompt="hi", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["response"] == "Hello world"
|
||||
assert outputs["tool_calls"] == []
|
||||
assert outputs["session_id"] == "sess-abc"
|
||||
assert outputs["token_usage"]["total_tokens"] == 15
|
||||
assert "error" not in outputs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_yields_error(self, block):
|
||||
"""On unexpected failure, run() should yield an error output."""
|
||||
block.execute_copilot = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
block.create_session = AsyncMock(return_value="sess-fail")
|
||||
|
||||
input_data = block.Input(prompt="do something", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["session_id"] == "sess-fail"
|
||||
assert "boom" in outputs.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancelled_error_yields_error_and_reraises(self, block):
|
||||
"""CancelledError should yield error, then re-raise."""
|
||||
block.execute_copilot = AsyncMock(side_effect=asyncio.CancelledError())
|
||||
block.create_session = AsyncMock(return_value="sess-cancel")
|
||||
|
||||
input_data = block.Input(prompt="do something", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["session_id"] == "sess-cancel"
|
||||
assert "cancelled" in outputs.get("error", "").lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_existing_session_id_skips_create(self, block):
|
||||
"""When session_id is provided, create_session should not be called."""
|
||||
mock_result = (
|
||||
"ok",
|
||||
[],
|
||||
"[]",
|
||||
"existing-sid",
|
||||
{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
)
|
||||
block.execute_copilot = AsyncMock(return_value=mock_result)
|
||||
block.create_session = AsyncMock()
|
||||
|
||||
input_data = block.Input(
|
||||
prompt="test", session_id="existing-sid", max_recursion_depth=3
|
||||
)
|
||||
ctx = _make_context()
|
||||
async for _ in block.run(input_data, execution_context=ctx):
|
||||
pass
|
||||
|
||||
block.create_session.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Block registration / ID tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBlockRegistration:
|
||||
def test_block_id_matches_constant(self):
|
||||
block = AutoPilotBlock()
|
||||
assert block.id == AUTOPILOT_BLOCK_ID
|
||||
|
||||
def test_max_recursion_depth_has_upper_bound(self):
|
||||
"""Schema should enforce le=10."""
|
||||
schema = AutoPilotBlock.Input.model_json_schema()
|
||||
max_rec = schema["properties"]["max_recursion_depth"]
|
||||
assert (
|
||||
max_rec.get("maximum") == 10 or max_rec.get("exclusiveMaximum", 999) <= 11
|
||||
)
|
||||
|
||||
def test_output_schema_has_no_duplicate_error_field(self):
|
||||
"""Output should inherit error from BlockSchemaOutput, not redefine it."""
|
||||
# The field should exist (inherited) but there should be no explicit
|
||||
# redefinition. We verify by checking the class __annotations__ directly.
|
||||
assert "error" not in AutoPilotBlock.Output.__annotations__
|
||||
@@ -1,9 +1,18 @@
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import anthropic
|
||||
import httpx
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
import backend.blocks.llm as llm
|
||||
from backend.data.model import NodeExecutionStats
|
||||
|
||||
# TEST_CREDENTIALS_INPUT is a plain dict that satisfies AICredentials at runtime
|
||||
# but not at the type level. Cast once here to avoid per-test suppressors.
|
||||
_TEST_AI_CREDENTIALS = cast(llm.AICredentials, llm.TEST_CREDENTIALS_INPUT)
|
||||
|
||||
|
||||
class TestLLMStatsTracking:
|
||||
"""Test that LLM blocks correctly track token usage statistics."""
|
||||
@@ -13,18 +22,17 @@ class TestLLMStatsTracking:
|
||||
"""Test that llm_call returns proper token counts in LLMResponse."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
# Mock the OpenAI client
|
||||
# Mock the OpenAI Responses API response
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [
|
||||
MagicMock(message=MagicMock(content="Test response", tool_calls=None))
|
||||
]
|
||||
mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20)
|
||||
mock_response.output_text = "Test response"
|
||||
mock_response.output = []
|
||||
mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
|
||||
|
||||
# Test with mocked OpenAI response
|
||||
with patch("openai.AsyncOpenAI") as mock_openai:
|
||||
mock_client = AsyncMock()
|
||||
mock_openai.return_value = mock_client
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||
mock_client.responses.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
response = await llm.llm_call(
|
||||
credentials=llm.TEST_CREDENTIALS,
|
||||
@@ -271,30 +279,17 @@ class TestLLMStatsTracking:
|
||||
mock_response = MagicMock()
|
||||
# Return different responses for chunk summary vs final summary
|
||||
if call_count == 1:
|
||||
mock_response.choices = [
|
||||
MagicMock(
|
||||
message=MagicMock(
|
||||
content='<json_output id="test123456">{"summary": "Test chunk summary"}</json_output>',
|
||||
tool_calls=None,
|
||||
)
|
||||
)
|
||||
]
|
||||
mock_response.output_text = '<json_output id="test123456">{"summary": "Test chunk summary"}</json_output>'
|
||||
else:
|
||||
mock_response.choices = [
|
||||
MagicMock(
|
||||
message=MagicMock(
|
||||
content='<json_output id="test123456">{"final_summary": "Test final summary"}</json_output>',
|
||||
tool_calls=None,
|
||||
)
|
||||
)
|
||||
]
|
||||
mock_response.usage = MagicMock(prompt_tokens=50, completion_tokens=30)
|
||||
mock_response.output_text = '<json_output id="test123456">{"final_summary": "Test final summary"}</json_output>'
|
||||
mock_response.output = []
|
||||
mock_response.usage = MagicMock(input_tokens=50, output_tokens=30)
|
||||
return mock_response
|
||||
|
||||
with patch("openai.AsyncOpenAI") as mock_openai:
|
||||
mock_client = AsyncMock()
|
||||
mock_openai.return_value = mock_client
|
||||
mock_client.chat.completions.create = mock_create
|
||||
mock_client.responses.create = mock_create
|
||||
|
||||
# Test with very short text (should only need 1 chunk + 1 final summary)
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
@@ -669,3 +664,148 @@ class TestAITextSummarizerValidation:
|
||||
error_message = str(exc_info.value)
|
||||
assert "Expected a string summary" in error_message
|
||||
assert "received dict" in error_message
|
||||
|
||||
|
||||
def _make_anthropic_status_error(status_code: int) -> anthropic.APIStatusError:
|
||||
"""Create an anthropic.APIStatusError with the given status code."""
|
||||
request = httpx.Request("POST", "https://api.anthropic.com/v1/messages")
|
||||
response = httpx.Response(status_code, request=request)
|
||||
return anthropic.APIStatusError(
|
||||
f"Error code: {status_code}", response=response, body=None
|
||||
)
|
||||
|
||||
|
||||
def _make_openai_status_error(status_code: int) -> openai.APIStatusError:
|
||||
"""Create an openai.APIStatusError with the given status code."""
|
||||
response = httpx.Response(
|
||||
status_code, request=httpx.Request("POST", "https://api.openai.com/v1/chat")
|
||||
)
|
||||
return openai.APIStatusError(
|
||||
f"Error code: {status_code}", response=response, body=None
|
||||
)
|
||||
|
||||
|
||||
class TestUserErrorStatusCodeHandling:
|
||||
"""Test that user-caused LLM API errors (401/403/429) break the retry loop
|
||||
and are logged as warnings, while server errors (500) trigger retries."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("status_code", [401, 403, 429])
|
||||
async def test_anthropic_user_error_breaks_retry_loop(self, status_code: int):
|
||||
"""401/403/429 Anthropic errors should break immediately, not retry."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise _make_anthropic_status_error(status_code)
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"key": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
retry=3,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
assert (
|
||||
call_count == 1
|
||||
), f"Expected exactly 1 call for status {status_code}, got {call_count}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("status_code", [401, 403, 429])
|
||||
async def test_openai_user_error_breaks_retry_loop(self, status_code: int):
|
||||
"""401/403/429 OpenAI errors should break immediately, not retry."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise _make_openai_status_error(status_code)
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"key": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
retry=3,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
assert (
|
||||
call_count == 1
|
||||
), f"Expected exactly 1 call for status {status_code}, got {call_count}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_error_retries(self):
|
||||
"""500 errors should be retried (not break immediately)."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise _make_anthropic_status_error(500)
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"key": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
retry=3,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
assert (
|
||||
call_count > 1
|
||||
), f"Expected multiple retry attempts for 500, got {call_count}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_error_logs_warning_not_exception(self):
|
||||
"""User-caused errors should log with logger.warning, not logger.exception."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
raise _make_anthropic_status_error(401)
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"key": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(llm.logger, "warning") as mock_warning,
|
||||
patch.object(llm.logger, "exception") as mock_exception,
|
||||
pytest.raises(RuntimeError),
|
||||
):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
mock_warning.assert_called_once()
|
||||
mock_exception.assert_not_called()
|
||||
|
||||
@@ -57,7 +57,7 @@ async def execute_graph(
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data import graph
|
||||
|
||||
test_user = await create_test_user()
|
||||
@@ -66,7 +66,7 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
block_id=OrchestratorBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": creds,
|
||||
@@ -108,10 +108,10 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
async def test_orchestrator_function_signature(server: SpinTestServer):
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data import graph
|
||||
|
||||
test_user = await create_test_user()
|
||||
@@ -120,7 +120,7 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
block_id=OrchestratorBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": creds,
|
||||
@@ -169,7 +169,7 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
)
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
tool_functions = await SmartDecisionMakerBlock._create_tool_node_signatures(
|
||||
tool_functions = await OrchestratorBlock._create_tool_node_signatures(
|
||||
test_graph.nodes[0].id
|
||||
)
|
||||
assert tool_functions is not None, "Tool functions should not be None"
|
||||
@@ -198,12 +198,12 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_tracks_llm_stats():
|
||||
"""Test that SmartDecisionMakerBlock correctly tracks LLM usage stats."""
|
||||
async def test_orchestrator_tracks_llm_stats():
|
||||
"""Test that OrchestratorBlock correctly tracks LLM usage stats."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock the llm.llm_call function to return controlled data
|
||||
mock_response = MagicMock()
|
||||
@@ -224,14 +224,14 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
|
||||
# Create test input
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Should I continue with this task?",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -274,12 +274,12 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_parameter_validation():
|
||||
"""Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
|
||||
async def test_orchestrator_parameter_validation():
|
||||
"""Test that OrchestratorBlock correctly validates tool call parameters."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock tool functions with specific parameter schema
|
||||
mock_tool_functions = [
|
||||
@@ -327,13 +327,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_with_typo,
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -394,13 +394,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_missing_required,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -454,13 +454,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_valid,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -518,13 +518,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_all_params,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -562,12 +562,12 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_raw_response_conversion():
|
||||
"""Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism."""
|
||||
async def test_orchestrator_raw_response_conversion():
|
||||
"""Test that Orchestrator correctly handles different raw_response types with retry mechanism."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock tool functions
|
||||
mock_tool_functions = [
|
||||
@@ -637,7 +637,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
@@ -646,7 +646,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
# Second call returns successful response
|
||||
mock_llm_call.side_effect = [mock_response_retry, mock_response_success]
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -715,12 +715,12 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_ollama,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[], # No tools for this test
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Simple prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -771,12 +771,12 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_dict,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Another test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -811,12 +811,12 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_agent_mode():
|
||||
async def test_orchestrator_agent_mode():
|
||||
"""Test that agent mode executes tools directly and loops until finished."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock tool call that requires multiple iterations
|
||||
mock_tool_call_1 = MagicMock()
|
||||
@@ -893,7 +893,7 @@ async def test_smart_decision_maker_agent_mode():
|
||||
with patch("backend.blocks.llm.llm_call", llm_call_mock), patch.object(
|
||||
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
|
||||
), patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
), patch(
|
||||
"backend.executor.manager.async_update_node_execution_status",
|
||||
@@ -929,7 +929,7 @@ async def test_smart_decision_maker_agent_mode():
|
||||
}
|
||||
|
||||
# Test agent mode with max_iterations = 3
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Complete this task using tools",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -969,12 +969,12 @@ async def test_smart_decision_maker_agent_mode():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_traditional_mode_default():
|
||||
async def test_orchestrator_traditional_mode_default():
|
||||
"""Test that default behavior (agent_mode_max_iterations=0) works as traditional mode."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock tool call
|
||||
mock_tool_call = MagicMock()
|
||||
@@ -1018,7 +1018,7 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
):
|
||||
|
||||
# Test default behavior (traditional mode)
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -1060,12 +1060,12 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from node metadata for tool names."""
|
||||
async def test_orchestrator_uses_customized_name_for_blocks():
|
||||
"""Test that OrchestratorBlock uses customized_name from node metadata for tool names."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
@@ -1080,7 +1080,7 @@ async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
result = await OrchestratorBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -1091,12 +1091,12 @@ async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_falls_back_to_block_name():
|
||||
"""Test that SmartDecisionMakerBlock falls back to block.name when no customized_name."""
|
||||
async def test_orchestrator_falls_back_to_block_name():
|
||||
"""Test that OrchestratorBlock falls back to block.name when no customized_name."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
@@ -1111,7 +1111,7 @@ async def test_smart_decision_maker_falls_back_to_block_name():
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
result = await OrchestratorBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -1122,11 +1122,11 @@ async def test_smart_decision_maker_falls_back_to_block_name():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_uses_customized_name_for_agents():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from metadata for agent nodes."""
|
||||
async def test_orchestrator_uses_customized_name_for_agents():
|
||||
"""Test that OrchestratorBlock uses customized_name from metadata for agent nodes."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
@@ -1152,10 +1152,10 @@ async def test_smart_decision_maker_uses_customized_name_for_agents():
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
result = await OrchestratorBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -1166,11 +1166,11 @@ async def test_smart_decision_maker_uses_customized_name_for_agents():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_agent_falls_back_to_graph_name():
|
||||
async def test_orchestrator_agent_falls_back_to_graph_name():
|
||||
"""Test that agent node falls back to graph name when no customized_name."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
@@ -1196,10 +1196,10 @@ async def test_smart_decision_maker_agent_falls_back_to_graph_name():
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
result = await OrchestratorBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -3,12 +3,12 @@ from unittest.mock import Mock
|
||||
import pytest
|
||||
|
||||
from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
"""Test Smart Decision Maker can handle dynamic dictionary fields (_#_) for any block"""
|
||||
async def test_orchestrator_handles_dynamic_dict_fields():
|
||||
"""Test Orchestrator can handle dynamic dictionary fields (_#_) for any block"""
|
||||
|
||||
# Create a mock node for CreateDictionaryBlock
|
||||
mock_node = Mock()
|
||||
@@ -23,24 +23,24 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
source_name="tools_^_create_dict_~_name",
|
||||
sink_name="values_#_name", # Dynamic dict field
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_age",
|
||||
sink_name="values_#_age", # Dynamic dict field
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_city",
|
||||
sink_name="values_#_city", # Dynamic dict field
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
signature = await OrchestratorBlock._create_block_function_signature(
|
||||
mock_node, mock_links # type: ignore
|
||||
)
|
||||
|
||||
@@ -70,8 +70,8 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_handles_dynamic_list_fields():
|
||||
"""Test Smart Decision Maker can handle dynamic list fields (_$_) for any block"""
|
||||
async def test_orchestrator_handles_dynamic_list_fields():
|
||||
"""Test Orchestrator can handle dynamic list fields (_$_) for any block"""
|
||||
|
||||
# Create a mock node for AddToListBlock
|
||||
mock_node = Mock()
|
||||
@@ -86,18 +86,18 @@ async def test_smart_decision_maker_handles_dynamic_list_fields():
|
||||
source_name="tools_^_add_to_list_~_0",
|
||||
sink_name="entries_$_0", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_add_to_list_~_1",
|
||||
sink_name="entries_$_1", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
signature = await OrchestratorBlock._create_block_function_signature(
|
||||
mock_node, mock_links # type: ignore
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""
|
||||
"""Comprehensive tests for OrchestratorBlock dynamic field handling."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
@@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
import pytest
|
||||
|
||||
from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.blocks.text import MatchTextPatternBlock
|
||||
from backend.data.dynamic_fields import get_dynamic_field_description
|
||||
|
||||
@@ -37,7 +37,7 @@ async def test_dynamic_field_description_generation():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_dict_fields():
|
||||
"""Test that function signatures are created correctly for dictionary dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Create a mock node for CreateDictionaryBlock
|
||||
mock_node = Mock()
|
||||
@@ -52,19 +52,19 @@ async def test_create_block_function_signature_with_dict_fields():
|
||||
source_name="tools_^_create_dict_~_values___name", # Sanitized source
|
||||
sink_name="values_#_name", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_values___age", # Sanitized source
|
||||
sink_name="values_#_age", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_values___email", # Sanitized source
|
||||
sink_name="values_#_email", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -100,7 +100,7 @@ async def test_create_block_function_signature_with_dict_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_list_fields():
|
||||
"""Test that function signatures are created correctly for list dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Create a mock node for AddToListBlock
|
||||
mock_node = Mock()
|
||||
@@ -115,19 +115,19 @@ async def test_create_block_function_signature_with_list_fields():
|
||||
source_name="tools_^_add_list_~_0",
|
||||
sink_name="entries_$_0", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_add_list_~_1",
|
||||
sink_name="entries_$_1", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_add_list_~_2",
|
||||
sink_name="entries_$_2", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -154,7 +154,7 @@ async def test_create_block_function_signature_with_list_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_object_fields():
|
||||
"""Test that function signatures are created correctly for object dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Create a mock node for MatchTextPatternBlock (simulating object fields)
|
||||
mock_node = Mock()
|
||||
@@ -169,13 +169,13 @@ async def test_create_block_function_signature_with_object_fields():
|
||||
source_name="tools_^_extract_~_user_name",
|
||||
sink_name="data_@_user_name", # Dynamic object field
|
||||
sink_id="extract_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_extract_~_user_email",
|
||||
sink_name="data_@_user_email", # Dynamic object field
|
||||
sink_id="extract_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -197,11 +197,11 @@ async def test_create_block_function_signature_with_object_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_tool_node_signatures():
|
||||
"""Test that the mapping between sanitized and original field names is built correctly."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock the database client and connected nodes
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client"
|
||||
) as mock_db:
|
||||
mock_client = AsyncMock()
|
||||
mock_db.return_value = mock_client
|
||||
@@ -281,7 +281,7 @@ async def test_create_tool_node_signatures():
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_yielding_with_dynamic_fields():
|
||||
"""Test that outputs are yielded correctly with dynamic field names mapped back."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# No more sanitized mapping needed since we removed sanitization
|
||||
|
||||
@@ -309,13 +309,13 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
|
||||
# Mock the LLM call
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
|
||||
"backend.blocks.orchestrator.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm:
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
# Mock the database manager to avoid HTTP calls during tool execution
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client"
|
||||
) as mock_db_manager, patch.object(
|
||||
block, "_create_tool_node_signatures", new_callable=AsyncMock
|
||||
) as mock_sig:
|
||||
@@ -420,7 +420,7 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_regular_and_dynamic_fields():
|
||||
"""Test handling of blocks with both regular and dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Create a mock node
|
||||
mock_node = Mock()
|
||||
@@ -450,19 +450,19 @@ async def test_mixed_regular_and_dynamic_fields():
|
||||
source_name="tools_^_test_~_regular",
|
||||
sink_name="regular_field", # Regular field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_test_~_dict_key",
|
||||
sink_name="values_#_key1", # Dynamic dict field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_test_~_dict_key2",
|
||||
sink_name="values_#_key2", # Dynamic dict field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -488,7 +488,7 @@ async def test_mixed_regular_and_dynamic_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_errors_dont_pollute_conversation():
|
||||
"""Test that validation errors are only used during retries and don't pollute the conversation."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Track conversation history changes
|
||||
conversation_snapshots = []
|
||||
@@ -535,7 +535,7 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
|
||||
# Mock the LLM call
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
|
||||
"backend.blocks.orchestrator.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm:
|
||||
mock_llm.side_effect = mock_llm_call
|
||||
|
||||
@@ -565,7 +565,7 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
|
||||
# Mock the database manager to avoid HTTP calls during tool execution
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client"
|
||||
) as mock_db_manager:
|
||||
# Set up the mock database manager for agent mode
|
||||
mock_db_client = AsyncMock()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -290,7 +290,9 @@ class FillTextTemplateBlock(Block):
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
formatter = text.TextFormatter(autoescape=input_data.escape_html)
|
||||
yield "output", formatter.format_string(input_data.format, input_data.values)
|
||||
yield "output", await formatter.format_string(
|
||||
input_data.format, input_data.values
|
||||
)
|
||||
|
||||
|
||||
class CombineTextsBlock(Block):
|
||||
|
||||
@@ -115,10 +115,22 @@ class ChatConfig(BaseSettings):
|
||||
description="Use --resume for multi-turn conversations instead of "
|
||||
"history compression. Falls back to compression when unavailable.",
|
||||
)
|
||||
use_openrouter: bool = Field(
|
||||
default=True,
|
||||
description="Enable routing API calls through the OpenRouter proxy. "
|
||||
"The actual decision also requires ``api_key`` and ``base_url`` — "
|
||||
"use the ``openrouter_active`` property for the final answer.",
|
||||
)
|
||||
use_claude_code_subscription: bool = Field(
|
||||
default=False,
|
||||
description="For personal/dev use: use Claude Code CLI subscription auth instead of API keys. Requires `claude login` on the host. Only works with SDK mode.",
|
||||
)
|
||||
test_mode: bool = Field(
|
||||
default=False,
|
||||
description="Use dummy service instead of real LLM calls. "
|
||||
"Send __test_transient_error__, __test_fatal_error__, or "
|
||||
"__test_slow_response__ to trigger specific scenarios.",
|
||||
)
|
||||
|
||||
# E2B Sandbox Configuration
|
||||
use_e2b_sandbox: bool = Field(
|
||||
@@ -136,7 +148,7 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B sandbox template to use for copilot sessions.",
|
||||
)
|
||||
e2b_sandbox_timeout: int = Field(
|
||||
default=300, # 5 min safety net — explicit per-turn pause is the primary mechanism
|
||||
default=420, # 7 min safety net — allows headroom for compaction retries
|
||||
description="E2B sandbox running-time timeout (seconds). "
|
||||
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
|
||||
"mechanism; this is the safety net.",
|
||||
@@ -146,6 +158,21 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B lifecycle action on timeout: 'pause' (default, free) or 'kill'.",
|
||||
)
|
||||
|
||||
@property
|
||||
def openrouter_active(self) -> bool:
|
||||
"""True when OpenRouter is enabled AND credentials are usable.
|
||||
|
||||
Single source of truth for "will the SDK route through OpenRouter?".
|
||||
Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
|
||||
present — mirrors the fallback logic in ``_build_sdk_env``.
|
||||
"""
|
||||
if not self.use_openrouter:
|
||||
return False
|
||||
base = (self.base_url or "").rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
return bool(self.api_key and base and base.startswith("http"))
|
||||
|
||||
@property
|
||||
def e2b_active(self) -> bool:
|
||||
"""True when E2B is enabled and the API key is present.
|
||||
@@ -168,15 +195,6 @@ class ChatConfig(BaseSettings):
|
||||
"""
|
||||
return self.e2b_api_key if self.e2b_active else None
|
||||
|
||||
@field_validator("use_e2b_sandbox", mode="before")
|
||||
@classmethod
|
||||
def get_use_e2b_sandbox(cls, v):
|
||||
"""Get use_e2b_sandbox from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_E2B_SANDBOX", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("e2b_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_e2b_api_key(cls, v):
|
||||
@@ -219,26 +237,6 @@ class ChatConfig(BaseSettings):
|
||||
v = OPENROUTER_BASE_URL
|
||||
return v
|
||||
|
||||
@field_validator("use_claude_agent_sdk", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_agent_sdk(cls, v):
|
||||
"""Get use_claude_agent_sdk from environment if not provided."""
|
||||
# Check environment variable - default to True if not set
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
# Default to True (SDK enabled by default)
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("use_claude_code_subscription", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_code_subscription(cls, v):
|
||||
"""Get use_claude_code_subscription from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_CODE_SUBSCRIPTION", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return False if v is None else v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
@@ -248,6 +246,7 @@ class ChatConfig(BaseSettings):
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
env_prefix = "CHAT_"
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore" # Ignore extra environment variables
|
||||
|
||||
@@ -6,19 +6,70 @@ from .config import ChatConfig
|
||||
|
||||
# Env vars that the ChatConfig validators read — must be cleared so they don't
|
||||
# override the explicit constructor values we pass in each test.
|
||||
_E2B_ENV_VARS = (
|
||||
_ENV_VARS_TO_CLEAR = (
|
||||
"CHAT_USE_E2B_SANDBOX",
|
||||
"CHAT_E2B_API_KEY",
|
||||
"E2B_API_KEY",
|
||||
"CHAT_USE_OPENROUTER",
|
||||
"CHAT_API_KEY",
|
||||
"OPEN_ROUTER_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"CHAT_BASE_URL",
|
||||
"OPENROUTER_BASE_URL",
|
||||
"OPENAI_BASE_URL",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_e2b_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for var in _E2B_ENV_VARS:
|
||||
def _clean_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for var in _ENV_VARS_TO_CLEAR:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
class TestOpenrouterActive:
|
||||
"""Tests for the openrouter_active property."""
|
||||
|
||||
def test_enabled_with_credentials_returns_true(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is True
|
||||
|
||||
def test_enabled_but_missing_api_key_returns_false(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key=None,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is False
|
||||
|
||||
def test_disabled_returns_false_despite_credentials(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=False,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is False
|
||||
|
||||
def test_strips_v1_suffix_and_still_valid(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is True
|
||||
|
||||
def test_invalid_base_url_returns_false(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="not-a-url",
|
||||
)
|
||||
assert cfg.openrouter_active is False
|
||||
|
||||
|
||||
class TestE2BActive:
|
||||
"""Tests for the e2b_active property — single source of truth for E2B usage."""
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@
|
||||
# The hex suffix makes accidental LLM generation of these strings virtually
|
||||
# impossible, avoiding false-positive marker detection in normal conversation.
|
||||
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX = (
|
||||
"[__COPILOT_RETRYABLE_ERROR_a9c2__]" # ErrorCard + retry
|
||||
)
|
||||
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
|
||||
|
||||
# Prefix for all synthetic IDs generated by CoPilot block execution.
|
||||
@@ -35,3 +38,24 @@ def parse_node_id_from_exec_id(node_exec_id: str) -> str:
|
||||
Format: "{node_id}:{random_hex}" → returns "{node_id}".
|
||||
"""
|
||||
return node_exec_id.rsplit(COPILOT_NODE_EXEC_ID_SEPARATOR, 1)[0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transient Anthropic API error detection
|
||||
# ---------------------------------------------------------------------------
|
||||
# Patterns in error text that indicate a transient Anthropic API error
|
||||
# (ECONNRESET / dropped TCP connection) which is retryable.
|
||||
_TRANSIENT_ERROR_PATTERNS = (
|
||||
"socket connection was closed unexpectedly",
|
||||
"ECONNRESET",
|
||||
"connection was forcibly closed",
|
||||
"network socket disconnected",
|
||||
)
|
||||
|
||||
FRIENDLY_TRANSIENT_MSG = "Anthropic connection interrupted — please retry"
|
||||
|
||||
|
||||
def is_transient_api_error(error_text: str) -> bool:
|
||||
"""Return True if *error_text* matches a known transient Anthropic API error."""
|
||||
lower = error_text.lower()
|
||||
return any(pat.lower() in lower for pat in _TRANSIENT_ERROR_PATTERNS)
|
||||
|
||||
@@ -17,8 +17,20 @@ from backend.util.workspace import WorkspaceManager
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
# Allowed base directory for the Read tool.
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
|
||||
# Allowed base directory for the Read tool. Public so service.py can use it
|
||||
# for sweep operations without depending on a private implementation detail.
|
||||
# Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's
|
||||
# _projects_base() function.
|
||||
_config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects"))
|
||||
|
||||
# Compiled UUID pattern for validating conversation directory names.
|
||||
# Kept as a module-level constant so the security-relevant pattern is easy
|
||||
# to audit in one place and avoids recompilation on every call.
|
||||
_UUID_RE = re.compile(r"^[0-9a-f]{8}(?:-[0-9a-f]{4}){3}-[0-9a-f]{12}$", re.IGNORECASE)
|
||||
|
||||
# Encoded project-directory name for the current session (e.g.
|
||||
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
|
||||
@@ -34,17 +46,33 @@ _current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||
)
|
||||
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
||||
|
||||
# Current execution's capability filter. None means "no restrictions".
|
||||
# Set by set_execution_context(); read by run_block and service.py.
|
||||
_current_permissions: "ContextVar[CopilotPermissions | None]" = ContextVar(
|
||||
"_current_permissions", default=None
|
||||
)
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does."""
|
||||
|
||||
def encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
|
||||
The Claude CLI encodes the absolute cwd as a directory name by replacing
|
||||
every non-alphanumeric character with ``-``. For example
|
||||
``/tmp/copilot-abc`` becomes ``-tmp-copilot-abc``.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
# Keep the private alias for internal callers (backwards compat).
|
||||
_encode_cwd_for_cli = encode_cwd_for_cli
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
sandbox: "AsyncSandbox | None" = None,
|
||||
sdk_cwd: str | None = None,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> None:
|
||||
"""Set per-turn context variables used by file-resolution tool handlers."""
|
||||
_current_user_id.set(user_id)
|
||||
@@ -52,6 +80,7 @@ def set_execution_context(
|
||||
_current_sandbox.set(sandbox)
|
||||
_current_sdk_cwd.set(sdk_cwd or "")
|
||||
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
|
||||
_current_permissions.set(permissions)
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
@@ -59,6 +88,11 @@ def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
return _current_user_id.get(), _current_session.get()
|
||||
|
||||
|
||||
def get_current_permissions() -> "CopilotPermissions | None":
|
||||
"""Return the capability filter for the current execution, or None if unrestricted."""
|
||||
return _current_permissions.get()
|
||||
|
||||
|
||||
def get_current_sandbox() -> "AsyncSandbox | None":
|
||||
"""Return the E2B sandbox for the current session, or None if not active."""
|
||||
return _current_sandbox.get()
|
||||
@@ -100,7 +134,9 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
|
||||
Allowed:
|
||||
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``.
|
||||
The SDK nests tool-results under a conversation UUID directory;
|
||||
the UUID segment is validated with ``_UUID_RE``.
|
||||
"""
|
||||
if not path:
|
||||
return False
|
||||
@@ -119,10 +155,22 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
|
||||
encoded = _current_project_dir.get("")
|
||||
if encoded:
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
if resolved == tool_results_dir or resolved.startswith(
|
||||
tool_results_dir + os.sep
|
||||
):
|
||||
return True
|
||||
project_dir = os.path.realpath(os.path.join(SDK_PROJECTS_DIR, encoded))
|
||||
# Defence-in-depth: ensure project_dir didn't escape the base.
|
||||
if not project_dir.startswith(SDK_PROJECTS_DIR + os.sep):
|
||||
return False
|
||||
# Only allow: <encoded-cwd>/<uuid>/tool-results/<file>
|
||||
# The SDK always creates a conversation UUID directory between
|
||||
# the project dir and tool-results/.
|
||||
if resolved.startswith(project_dir + os.sep):
|
||||
relative = resolved[len(project_dir) + 1 :]
|
||||
parts = relative.split(os.sep)
|
||||
# Require exactly: [<uuid>, "tool-results", <file>, ...]
|
||||
if (
|
||||
len(parts) >= 3
|
||||
and _UUID_RE.match(parts[0])
|
||||
and parts[1] == "tool-results"
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -9,8 +9,9 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import (
|
||||
_SDK_PROJECTS_DIR,
|
||||
SDK_PROJECTS_DIR,
|
||||
_current_project_dir,
|
||||
get_current_permissions,
|
||||
get_current_sandbox,
|
||||
get_execution_context,
|
||||
get_sdk_cwd,
|
||||
@@ -18,6 +19,7 @@ from backend.copilot.context import (
|
||||
resolve_sandbox_path,
|
||||
set_execution_context,
|
||||
)
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
|
||||
def _make_session() -> MagicMock:
|
||||
@@ -61,6 +63,19 @@ def test_get_current_sandbox_returns_set_value():
|
||||
assert get_current_sandbox() is mock_sandbox
|
||||
|
||||
|
||||
def test_set_and_get_current_permissions():
|
||||
"""set_execution_context stores permissions; get_current_permissions returns it."""
|
||||
perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
set_execution_context("u1", _make_session(), permissions=perms)
|
||||
assert get_current_permissions() is perms
|
||||
|
||||
|
||||
def test_get_current_permissions_defaults_to_none():
|
||||
"""get_current_permissions returns None when no permissions have been set."""
|
||||
set_execution_context("u1", _make_session())
|
||||
assert get_current_permissions() is None
|
||||
|
||||
|
||||
def test_get_sdk_cwd_empty_when_not_set():
|
||||
"""get_sdk_cwd returns empty string when sdk_cwd is not set."""
|
||||
set_execution_context("u1", _make_session(), sdk_cwd=None)
|
||||
@@ -104,11 +119,13 @@ def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
|
||||
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_results_dir():
|
||||
"""Files under the tool-results directory for the current project are allowed."""
|
||||
def test_is_allowed_local_path_tool_results_with_uuid():
|
||||
"""Files under <encoded-cwd>/<uuid>/tool-results/ are allowed."""
|
||||
encoded = "test-encoded-dir"
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
path = os.path.join(tool_results_dir, "output.txt")
|
||||
conv_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
path = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-results", "output.txt"
|
||||
)
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
@@ -117,10 +134,22 @@ def test_is_allowed_local_path_tool_results_dir():
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_results_without_uuid_rejected():
|
||||
"""Direct <encoded-cwd>/tool-results/ (no UUID) is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
path = os.path.join(SDK_PROJECTS_DIR, encoded, "tool-results", "output.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert not is_allowed_local_path(path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||
"""A path adjacent to tool-results/ but not inside it is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
|
||||
sibling_path = os.path.join(SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
@@ -129,6 +158,21 @@ def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_valid_uuid_wrong_segment_name_rejected():
|
||||
"""A valid UUID dir but non-'tool-results' second segment is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
uuid_str = "12345678-1234-5678-9abc-def012345678"
|
||||
path = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, uuid_str, "not-tool-results", "output.txt"
|
||||
)
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert not is_allowed_local_path(path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -16,6 +16,7 @@ from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.response_model import StreamFinish
|
||||
from backend.copilot.sdk import service as sdk_service
|
||||
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
from backend.executor.cluster_lock import ClusterLock
|
||||
from backend.util.decorator import error_logged
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
@@ -246,17 +247,25 @@ class CoPilotProcessor:
|
||||
# Choose service based on LaunchDarkly flag.
|
||||
# Claude Code subscription forces SDK mode (CLI subprocess auth).
|
||||
config = ChatConfig()
|
||||
use_sdk = config.use_claude_code_subscription or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
)
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else stream_chat_completion_baseline
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
|
||||
if config.test_mode:
|
||||
stream_fn = stream_chat_completion_dummy
|
||||
log.warning("Using DUMMY service (CHAT_TEST_MODE=true)")
|
||||
else:
|
||||
use_sdk = (
|
||||
config.use_claude_code_subscription
|
||||
or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
)
|
||||
)
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else stream_chat_completion_baseline
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
|
||||
# Stream chat completion and publish chunks to Redis.
|
||||
async for chunk in stream_fn(
|
||||
|
||||
173
autogpt_platform/backend/backend/copilot/integration_creds.py
Normal file
173
autogpt_platform/backend/backend/copilot/integration_creds.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Integration credential lookup with per-process TTL cache.
|
||||
|
||||
Provides token retrieval for connected integrations so that copilot tools
|
||||
(e.g. bash_exec) can inject auth tokens into the execution environment without
|
||||
hitting the database on every command.
|
||||
|
||||
Cache semantics (handled automatically by TTLCache):
|
||||
- Token found → cached for _TOKEN_CACHE_TTL (5 min). Avoids repeated DB hits
|
||||
for users who have credentials and are running many bash commands.
|
||||
- No credentials found → cached for _NULL_CACHE_TTL (60 s). Avoids a DB hit
|
||||
on every E2B command for users who haven't connected an account yet, while
|
||||
still picking up a newly-connected account within one minute.
|
||||
|
||||
Both caches are bounded to _CACHE_MAX_SIZE entries; cachetools evicts the
|
||||
least-recently-used entry when the limit is reached.
|
||||
|
||||
Multi-worker note: both caches are in-process only. Each worker/replica
|
||||
maintains its own independent cache, so a credential fetch may be duplicated
|
||||
across processes. This is acceptable for the current goal (reduce DB hits per
|
||||
session per-process), but if cache efficiency across replicas becomes important
|
||||
a shared cache (e.g. Redis) should be used instead.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
from backend.copilot.providers import SUPPORTED_PROVIDERS
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
register_creds_changed_hook,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Derived from the single SUPPORTED_PROVIDERS registry for backward compat.
|
||||
PROVIDER_ENV_VARS: dict[str, list[str]] = {
|
||||
slug: entry["env_vars"] for slug, entry in SUPPORTED_PROVIDERS.items()
|
||||
}
|
||||
|
||||
_TOKEN_CACHE_TTL = 300.0 # seconds — for found tokens
|
||||
_NULL_CACHE_TTL = 60.0 # seconds — for "not connected" results
|
||||
_CACHE_MAX_SIZE = 10_000
|
||||
|
||||
# (user_id, provider) → token string. TTLCache handles expiry + eviction.
|
||||
# Thread-safety note: TTLCache is NOT thread-safe, but that is acceptable here
|
||||
# because all callers (get_provider_token, invalidate_user_provider_cache) run
|
||||
# exclusively on the asyncio event loop. There are no await points between a
|
||||
# cache read and its corresponding write within any function, so no concurrent
|
||||
# coroutine can interleave. If ThreadPoolExecutor workers are ever added to
|
||||
# this path, a threading.RLock should be wrapped around these caches.
|
||||
_token_cache: TTLCache[tuple[str, str], str] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_TOKEN_CACHE_TTL
|
||||
)
|
||||
# Separate cache for "no credentials" results with a shorter TTL.
|
||||
_null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
|
||||
)
|
||||
|
||||
|
||||
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
|
||||
"""Remove the cached entry for *user_id*/*provider* from both caches.
|
||||
|
||||
Call this after storing new credentials so that the next
|
||||
``get_provider_token()`` call performs a fresh DB lookup instead of
|
||||
serving a stale TTL-cached result.
|
||||
"""
|
||||
key = (user_id, provider)
|
||||
_token_cache.pop(key, None)
|
||||
_null_cache.pop(key, None)
|
||||
|
||||
|
||||
# Register this module's cache-bust function with the credentials manager so
|
||||
# that any create/update/delete operation immediately evicts stale cache
|
||||
# entries. This avoids a lazy import inside creds_manager and eliminates the
|
||||
# circular-import risk.
|
||||
try:
|
||||
register_creds_changed_hook(invalidate_user_provider_cache)
|
||||
except RuntimeError:
|
||||
# Hook already registered (e.g. module re-import in tests).
|
||||
pass
|
||||
|
||||
# Module-level singleton to avoid re-instantiating IntegrationCredentialsManager
|
||||
# on every cache-miss call to get_provider_token().
|
||||
_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
"""Return the user's access token for *provider*, or ``None`` if not connected.
|
||||
|
||||
OAuth2 tokens are preferred (refreshed if needed); API keys are the fallback.
|
||||
Found tokens are cached for _TOKEN_CACHE_TTL (5 min). "Not connected" results
|
||||
are cached for _NULL_CACHE_TTL (60 s) to avoid a DB hit on every bash_exec
|
||||
command for users who haven't connected yet, while still picking up a
|
||||
newly-connected account within one minute.
|
||||
"""
|
||||
cache_key = (user_id, provider)
|
||||
|
||||
if cache_key in _null_cache:
|
||||
return None
|
||||
if cached := _token_cache.get(cache_key):
|
||||
return cached
|
||||
|
||||
manager = _manager
|
||||
try:
|
||||
creds_list = await manager.store.get_creds_by_provider(user_id, provider)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch %s credentials for user %s",
|
||||
provider,
|
||||
user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
# Pass 1: prefer OAuth2 (carry scope info, refreshable via token endpoint).
|
||||
# Sort so broader-scoped tokens come first: a token with "repo" scope covers
|
||||
# full git access, while a public-data-only token lacks push/pull permission.
|
||||
# lock=False — background injection; not worth a distributed lock acquisition.
|
||||
oauth2_creds = sorted(
|
||||
[c for c in creds_list if c.type == "oauth2"],
|
||||
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
|
||||
)
|
||||
for creds in oauth2_creds:
|
||||
if creds.type == "oauth2":
|
||||
try:
|
||||
fresh = await manager.refresh_if_needed(
|
||||
user_id, cast(OAuth2Credentials, creds), lock=False
|
||||
)
|
||||
token = fresh.access_token.get_secret_value()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to refresh %s OAuth token for user %s; "
|
||||
"discarding stale token to force re-auth",
|
||||
provider,
|
||||
user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
# Do NOT fall back to the stale token — it is likely expired
|
||||
# or revoked. Returning None forces the caller to re-auth,
|
||||
# preventing the LLM from receiving a non-functional token.
|
||||
continue
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# Pass 2: fall back to API key (no expiry, no refresh needed).
|
||||
for creds in creds_list:
|
||||
if creds.type == "api_key":
|
||||
token = cast(APIKeyCredentials, creds).api_key.get_secret_value()
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# No credentials found — cache to avoid repeated DB hits.
|
||||
_null_cache[cache_key] = True
|
||||
return None
|
||||
|
||||
|
||||
async def get_integration_env_vars(user_id: str) -> dict[str, str]:
|
||||
"""Return env vars for all providers the user has connected.
|
||||
|
||||
Iterates :data:`PROVIDER_ENV_VARS`, fetches each token, and builds a flat
|
||||
``{env_var: token}`` dict ready to pass to a subprocess or E2B sandbox.
|
||||
Only providers with a stored credential contribute entries.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
for provider, var_names in PROVIDER_ENV_VARS.items():
|
||||
token = await get_provider_token(user_id, provider)
|
||||
if token:
|
||||
for var in var_names:
|
||||
env[var] = token
|
||||
return env
|
||||
@@ -0,0 +1,195 @@
|
||||
"""Tests for integration_creds — TTL cache and token lookup paths."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.copilot.integration_creds import (
|
||||
_NULL_CACHE_TTL,
|
||||
_TOKEN_CACHE_TTL,
|
||||
PROVIDER_ENV_VARS,
|
||||
_null_cache,
|
||||
_token_cache,
|
||||
get_integration_env_vars,
|
||||
get_provider_token,
|
||||
invalidate_user_provider_cache,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
|
||||
_USER = "user-integration-creds-test"
|
||||
_PROVIDER = "github"
|
||||
|
||||
|
||||
def _make_api_key_creds(key: str = "test-api-key") -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id="creds-api-key",
|
||||
provider=_PROVIDER,
|
||||
api_key=SecretStr(key),
|
||||
title="Test API Key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_oauth2_creds(token: str = "test-oauth-token") -> OAuth2Credentials:
|
||||
return OAuth2Credentials(
|
||||
id="creds-oauth2",
|
||||
provider=_PROVIDER,
|
||||
title="Test OAuth",
|
||||
access_token=SecretStr(token),
|
||||
refresh_token=SecretStr("test-refresh"),
|
||||
access_token_expires_at=None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_caches():
|
||||
"""Ensure clean caches before and after every test."""
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
yield
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
|
||||
|
||||
class TestInvalidateUserProviderCache:
|
||||
def test_removes_token_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_token_cache[key] = "tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _token_cache
|
||||
|
||||
def test_removes_null_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_null_cache[key] = True
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _null_cache
|
||||
|
||||
def test_noop_when_key_not_cached(self):
|
||||
# Should not raise even when there is no cache entry.
|
||||
invalidate_user_provider_cache("no-such-user", _PROVIDER)
|
||||
|
||||
def test_only_removes_targeted_key(self):
|
||||
other_key = ("other-user", _PROVIDER)
|
||||
_token_cache[other_key] = "other-tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert other_key in _token_cache
|
||||
|
||||
|
||||
class TestGetProviderToken:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_cached_token_without_db_hit(self):
|
||||
_token_cache[(_USER, _PROVIDER)] = "cached-tok"
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "cached-tok"
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_none_for_null_cached_provider(self):
|
||||
_null_cache[(_USER, _PROVIDER)] = True
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_api_key_creds_returned_and_cached(self):
|
||||
api_creds = _make_api_key_creds("my-api-key")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[api_creds])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "my-api-key"
|
||||
assert _token_cache.get((_USER, _PROVIDER)) == "my-api-key"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_preferred_over_api_key(self):
|
||||
oauth_creds = _make_oauth2_creds("oauth-tok")
|
||||
api_creds = _make_api_key_creds("api-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
return_value=[api_creds, oauth_creds]
|
||||
)
|
||||
mock_manager.refresh_if_needed = AsyncMock(return_value=oauth_creds)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_refresh_failure_returns_none(self):
|
||||
"""On refresh failure, return None instead of caching a stale token."""
|
||||
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
|
||||
mock_manager.refresh_if_needed = AsyncMock(side_effect=RuntimeError("network"))
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
# Stale tokens must NOT be returned — forces re-auth.
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_credentials_caches_null_entry(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
assert _null_cache.get((_USER, _PROVIDER)) is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_db_exception_returns_none_without_caching(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
side_effect=RuntimeError("db down")
|
||||
)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
# DB errors are not cached — next call will retry
|
||||
assert (_USER, _PROVIDER) not in _token_cache
|
||||
assert (_USER, _PROVIDER) not in _null_cache
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_null_cache_has_shorter_ttl_than_token_cache(self):
|
||||
"""Verify the TTL constants are set correctly for each cache."""
|
||||
assert _null_cache.ttl == _NULL_CACHE_TTL
|
||||
assert _token_cache.ttl == _TOKEN_CACHE_TTL
|
||||
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
|
||||
|
||||
|
||||
class TestGetIntegrationEnvVars:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_injects_all_env_vars_for_provider(self):
|
||||
_token_cache[(_USER, "github")] = "gh-tok"
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
for var in PROVIDER_ENV_VARS["github"]:
|
||||
assert result[var] == "gh-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_empty_dict_when_no_credentials(self):
|
||||
_null_cache[(_USER, "github")] = True
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
assert result == {}
|
||||
430
autogpt_platform/backend/backend/copilot/permissions.py
Normal file
430
autogpt_platform/backend/backend/copilot/permissions.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""Copilot execution permissions — tool and block allow/deny filtering.
|
||||
|
||||
:class:`CopilotPermissions` is the single model used everywhere:
|
||||
|
||||
- ``AutoPilotBlock`` reads four block-input fields and builds one instance.
|
||||
- ``stream_chat_completion_sdk`` applies it when constructing
|
||||
``ClaudeAgentOptions.allowed_tools`` / ``disallowed_tools``.
|
||||
- ``run_block`` reads it from the contextvar to gate block execution.
|
||||
- Recursive (sub-agent) invocations merge parent and child so children
|
||||
can only be *more* restrictive, never more permissive.
|
||||
|
||||
Tool names
|
||||
----------
|
||||
Users specify the **short name** as it appears in ``TOOL_REGISTRY`` (e.g.
|
||||
``run_block``, ``web_fetch``) or as an SDK built-in (e.g. ``Read``,
|
||||
``Task``, ``WebSearch``). Internally these are mapped to the full SDK
|
||||
format (``mcp__copilot__run_block``, ``Read``, …) by
|
||||
:func:`apply_tool_permissions`.
|
||||
|
||||
Block identifiers
|
||||
-----------------
|
||||
Each entry in ``blocks`` may be one of:
|
||||
|
||||
- A **full UUID** (``c069dc6b-c3ed-4c12-b6e5-d47361e64ce6``)
|
||||
- A **partial UUID** — the first 8-character hex segment (``c069dc6b``)
|
||||
- A **block name** (case-insensitive, e.g. ``"HTTP Request"``)
|
||||
|
||||
:func:`validate_block_identifiers` resolves all entries against the live
|
||||
block registry and returns any that could not be matched.
|
||||
|
||||
Semantics
|
||||
---------
|
||||
``tools_exclude=True`` (default) — ``tools`` is a **blacklist**; listed
|
||||
tools are denied and everything else is allowed. An empty list means
|
||||
"allow all" (no filtering).
|
||||
|
||||
``tools_exclude=False`` — ``tools`` is a **whitelist**; only listed tools
|
||||
are allowed.
|
||||
|
||||
``blocks_exclude`` follows the same pattern for ``blocks``.
|
||||
|
||||
Recursion inheritance
|
||||
---------------------
|
||||
:meth:`CopilotPermissions.merged_with_parent` produces a new instance that
|
||||
is at most as permissive as the parent:
|
||||
|
||||
- Tools: effective-allowed sets are intersected then stored as a whitelist.
|
||||
- Blocks: the parent is stored in ``_parent`` and consulted during every
|
||||
:meth:`is_block_allowed` call so both constraints must pass.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Literal, get_args
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants — single source of truth for all accepted tool names
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Literal type combining all valid tool names — used by AutoPilotBlock.Input
|
||||
# so the frontend renders a multi-select dropdown.
|
||||
# This is the SINGLE SOURCE OF TRUTH. All other name sets are derived from it.
|
||||
ToolName = Literal[
|
||||
# Platform tools (must match keys in TOOL_REGISTRY)
|
||||
"add_understanding",
|
||||
"bash_exec",
|
||||
"browser_act",
|
||||
"browser_navigate",
|
||||
"browser_screenshot",
|
||||
"connect_integration",
|
||||
"continue_run_block",
|
||||
"create_agent",
|
||||
"create_feature_request",
|
||||
"create_folder",
|
||||
"customize_agent",
|
||||
"delete_folder",
|
||||
"delete_workspace_file",
|
||||
"edit_agent",
|
||||
"find_agent",
|
||||
"find_block",
|
||||
"find_library_agent",
|
||||
"fix_agent_graph",
|
||||
"get_agent_building_guide",
|
||||
"get_doc_page",
|
||||
"get_mcp_guide",
|
||||
"list_folders",
|
||||
"list_workspace_files",
|
||||
"move_agents_to_folder",
|
||||
"move_folder",
|
||||
"read_workspace_file",
|
||||
"run_agent",
|
||||
"run_block",
|
||||
"run_mcp_tool",
|
||||
"search_docs",
|
||||
"search_feature_requests",
|
||||
"update_folder",
|
||||
"validate_agent_graph",
|
||||
"view_agent_output",
|
||||
"web_fetch",
|
||||
"write_workspace_file",
|
||||
# SDK built-ins
|
||||
"Edit",
|
||||
"Glob",
|
||||
"Grep",
|
||||
"Read",
|
||||
"Task",
|
||||
"TodoWrite",
|
||||
"WebSearch",
|
||||
"Write",
|
||||
]
|
||||
|
||||
# Frozen set of all valid tool names — derived from the Literal.
|
||||
ALL_TOOL_NAMES: frozenset[str] = frozenset(get_args(ToolName))
|
||||
|
||||
# SDK built-in tool names — uppercase-initial names are SDK built-ins.
|
||||
SDK_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
|
||||
n for n in ALL_TOOL_NAMES if n[0].isupper()
|
||||
)
|
||||
|
||||
# Platform tool names — everything that isn't an SDK built-in.
|
||||
PLATFORM_TOOL_NAMES: frozenset[str] = ALL_TOOL_NAMES - SDK_BUILTIN_TOOL_NAMES
|
||||
|
||||
# Compiled regex patterns for block identifier classification.
|
||||
_FULL_UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_PARTIAL_UUID_RE = re.compile(r"^[0-9a-f]{8}$", re.IGNORECASE)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper — block identifier matching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _block_matches(identifier: str, block_id: str, block_name: str) -> bool:
|
||||
"""Return True if *identifier* resolves to the given block.
|
||||
|
||||
Resolution order:
|
||||
1. Full UUID — exact case-insensitive match against *block_id*.
|
||||
2. Partial UUID (8 hex chars, first segment) — prefix match.
|
||||
3. Name — case-insensitive equality against *block_name*.
|
||||
"""
|
||||
ident = identifier.strip()
|
||||
if _FULL_UUID_RE.match(ident):
|
||||
return ident.lower() == block_id.lower()
|
||||
if _PARTIAL_UUID_RE.match(ident):
|
||||
return block_id.lower().startswith(ident.lower())
|
||||
return ident.lower() == block_name.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CopilotPermissions(BaseModel):
|
||||
"""Capability filter for a single copilot execution.
|
||||
|
||||
Attributes:
|
||||
tools: Tool names to filter (short names, e.g. ``run_block``).
|
||||
tools_exclude: When True (default) ``tools`` is a blacklist;
|
||||
when False it is a whitelist. Ignored when *tools* is empty.
|
||||
blocks: Block identifiers (name, full UUID, or 8-char partial UUID).
|
||||
blocks_exclude: Same semantics as *tools_exclude* but for blocks.
|
||||
"""
|
||||
|
||||
tools: list[str] = []
|
||||
tools_exclude: bool = True
|
||||
blocks: list[str] = []
|
||||
blocks_exclude: bool = True
|
||||
|
||||
# Private: parent permissions for recursion inheritance.
|
||||
# Set only by merged_with_parent(); never exposed in block input schema.
|
||||
_parent: CopilotPermissions | None = PrivateAttr(default=None)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def effective_allowed_tools(self, all_tools: frozenset[str]) -> frozenset[str]:
|
||||
"""Compute the set of short tool names that are permitted.
|
||||
|
||||
Args:
|
||||
all_tools: Universe of valid short tool names.
|
||||
|
||||
Returns:
|
||||
Subset of *all_tools* that pass the filter.
|
||||
"""
|
||||
if not self.tools:
|
||||
return frozenset(all_tools)
|
||||
tool_set = frozenset(self.tools)
|
||||
if self.tools_exclude:
|
||||
return all_tools - tool_set
|
||||
return all_tools & tool_set
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Block helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def is_block_allowed(self, block_id: str, block_name: str) -> bool:
|
||||
"""Return True if the block may be executed under these permissions.
|
||||
|
||||
Checks this instance first, then consults the parent (if any) so
|
||||
the entire inheritance chain is respected.
|
||||
"""
|
||||
if not self._check_block_locally(block_id, block_name):
|
||||
return False
|
||||
if self._parent is not None:
|
||||
return self._parent.is_block_allowed(block_id, block_name)
|
||||
return True
|
||||
|
||||
def _check_block_locally(self, block_id: str, block_name: str) -> bool:
|
||||
"""Check *only* this instance's block filter (ignores parent)."""
|
||||
if not self.blocks:
|
||||
return True # No filter → allow all
|
||||
matched = any(
|
||||
_block_matches(identifier, block_id, block_name)
|
||||
for identifier in self.blocks
|
||||
)
|
||||
return not matched if self.blocks_exclude else matched
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Recursion / merging
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def merged_with_parent(
|
||||
self,
|
||||
parent: CopilotPermissions,
|
||||
all_tools: frozenset[str],
|
||||
) -> CopilotPermissions:
|
||||
"""Return a new instance that is at most as permissive as *parent*.
|
||||
|
||||
- Tools: intersection of effective-allowed sets, stored as a whitelist.
|
||||
- Blocks: parent is stored internally; both constraints are applied
|
||||
during :meth:`is_block_allowed`.
|
||||
"""
|
||||
merged_tools = self.effective_allowed_tools(
|
||||
all_tools
|
||||
) & parent.effective_allowed_tools(all_tools)
|
||||
result = CopilotPermissions(
|
||||
tools=sorted(merged_tools),
|
||||
tools_exclude=False,
|
||||
blocks=self.blocks,
|
||||
blocks_exclude=self.blocks_exclude,
|
||||
)
|
||||
result._parent = parent
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Convenience
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Return True when no filtering is configured (allow-all passthrough)."""
|
||||
return not self.tools and not self.blocks and self._parent is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Validation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def all_known_tool_names() -> frozenset[str]:
|
||||
"""Return all short tool names accepted in *tools*.
|
||||
|
||||
Returns the pre-computed ``ALL_TOOL_NAMES`` set (derived from the
|
||||
``ToolName`` Literal). On first call, also verifies consistency with
|
||||
the live ``TOOL_REGISTRY``.
|
||||
"""
|
||||
_assert_tool_names_consistent()
|
||||
return ALL_TOOL_NAMES
|
||||
|
||||
|
||||
def validate_tool_names(tools: list[str]) -> list[str]:
|
||||
"""Return entries in *tools* that are not valid tool names.
|
||||
|
||||
Args:
|
||||
tools: List of short tool name strings to validate.
|
||||
|
||||
Returns:
|
||||
List of invalid names (empty if all are valid).
|
||||
"""
|
||||
return [t for t in tools if t not in ALL_TOOL_NAMES]
|
||||
|
||||
|
||||
_tool_names_checked = False
|
||||
|
||||
|
||||
def _assert_tool_names_consistent() -> None:
|
||||
"""Verify that ``PLATFORM_TOOL_NAMES`` matches ``TOOL_REGISTRY`` keys.
|
||||
|
||||
Called once lazily (TOOL_REGISTRY has heavy imports). Raises
|
||||
``AssertionError`` with a helpful diff if they diverge.
|
||||
"""
|
||||
global _tool_names_checked
|
||||
if _tool_names_checked:
|
||||
return
|
||||
_tool_names_checked = True
|
||||
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
registry_keys: frozenset[str] = frozenset(TOOL_REGISTRY.keys())
|
||||
declared: frozenset[str] = PLATFORM_TOOL_NAMES
|
||||
if registry_keys != declared:
|
||||
missing = registry_keys - declared
|
||||
extra = declared - registry_keys
|
||||
parts: list[str] = [
|
||||
"PLATFORM_TOOL_NAMES in permissions.py is out of sync with TOOL_REGISTRY."
|
||||
]
|
||||
if missing:
|
||||
parts.append(f" Missing from PLATFORM_TOOL_NAMES: {sorted(missing)}")
|
||||
if extra:
|
||||
parts.append(f" Extra in PLATFORM_TOOL_NAMES: {sorted(extra)}")
|
||||
parts.append(" Update the ToolName Literal to match.")
|
||||
raise AssertionError("\n".join(parts))
|
||||
|
||||
|
||||
async def validate_block_identifiers(
|
||||
identifiers: list[str],
|
||||
) -> list[str]:
|
||||
"""Resolve each block identifier and return those that could not be matched.
|
||||
|
||||
Args:
|
||||
identifiers: List of block identifiers (name, full UUID, or partial UUID).
|
||||
|
||||
Returns:
|
||||
List of identifiers that matched no known block.
|
||||
"""
|
||||
from backend.blocks import get_blocks
|
||||
|
||||
# get_blocks() returns dict[block_id_str, BlockClass]; instantiate once to get names.
|
||||
block_registry = get_blocks()
|
||||
block_info = {bid: cls().name for bid, cls in block_registry.items()}
|
||||
invalid: list[str] = []
|
||||
for ident in identifiers:
|
||||
matched = any(
|
||||
_block_matches(ident, bid, bname) for bid, bname in block_info.items()
|
||||
)
|
||||
if not matched:
|
||||
invalid.append(ident)
|
||||
return invalid
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SDK tool-list application
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def apply_tool_permissions(
|
||||
permissions: CopilotPermissions,
|
||||
*,
|
||||
use_e2b: bool = False,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Compute (allowed_tools, extra_disallowed) for :class:`ClaudeAgentOptions`.
|
||||
|
||||
Takes the base allowed/disallowed lists from
|
||||
:func:`~backend.copilot.sdk.tool_adapter.get_copilot_tool_names` /
|
||||
:func:`~backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools` and
|
||||
applies *permissions* on top.
|
||||
|
||||
Returns:
|
||||
``(allowed_tools, extra_disallowed)`` where *allowed_tools* is the
|
||||
possibly-narrowed list to pass to ``ClaudeAgentOptions.allowed_tools``
|
||||
and *extra_disallowed* is the list to pass to
|
||||
``ClaudeAgentOptions.disallowed_tools``.
|
||||
"""
|
||||
from backend.copilot.sdk.tool_adapter import (
|
||||
_READ_TOOL_NAME,
|
||||
MCP_TOOL_PREFIX,
|
||||
get_copilot_tool_names,
|
||||
get_sdk_disallowed_tools,
|
||||
)
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
base_allowed = get_copilot_tool_names(use_e2b=use_e2b)
|
||||
base_disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
|
||||
|
||||
if permissions.is_empty():
|
||||
return base_allowed, base_disallowed
|
||||
|
||||
all_tools = all_known_tool_names()
|
||||
effective = permissions.effective_allowed_tools(all_tools)
|
||||
|
||||
# In E2B mode, SDK built-in file tools (Read, Write, Edit, Glob, Grep)
|
||||
# are replaced by MCP equivalents (read_file, write_file, ...).
|
||||
# Map each SDK built-in name to its E2B MCP name so users can use the
|
||||
# familiar names in their permissions and the E2B tools are included.
|
||||
_SDK_TO_E2B: dict[str, str] = {}
|
||||
if use_e2b:
|
||||
from backend.copilot.sdk.e2b_file_tools import E2B_FILE_TOOL_NAMES
|
||||
|
||||
_SDK_TO_E2B = dict(
|
||||
zip(
|
||||
["Read", "Write", "Edit", "Glob", "Grep"],
|
||||
E2B_FILE_TOOL_NAMES,
|
||||
strict=False,
|
||||
)
|
||||
)
|
||||
|
||||
# Build an updated allowed list by mapping short names → SDK names and
|
||||
# keeping only those present in the original base_allowed list.
|
||||
def to_sdk_names(short: str) -> list[str]:
|
||||
names: list[str] = []
|
||||
if short in TOOL_REGISTRY:
|
||||
names.append(f"{MCP_TOOL_PREFIX}{short}")
|
||||
elif short in _SDK_TO_E2B:
|
||||
# E2B mode: map SDK built-in file tool to its MCP equivalent.
|
||||
names.append(f"{MCP_TOOL_PREFIX}{_SDK_TO_E2B[short]}")
|
||||
else:
|
||||
names.append(short) # SDK built-in — used as-is
|
||||
return names
|
||||
|
||||
# short names permitted by permissions
|
||||
permitted_sdk: set[str] = set()
|
||||
for s in effective:
|
||||
permitted_sdk.update(to_sdk_names(s))
|
||||
# Always include the internal Read tool (used by SDK for large/truncated outputs)
|
||||
permitted_sdk.add(f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}")
|
||||
|
||||
filtered_allowed = [t for t in base_allowed if t in permitted_sdk]
|
||||
|
||||
# Extra disallowed = tools that were in base_allowed but are now removed
|
||||
removed = set(base_allowed) - set(filtered_allowed)
|
||||
extra_disallowed = list(set(base_disallowed) | removed)
|
||||
|
||||
return filtered_allowed, extra_disallowed
|
||||
579
autogpt_platform/backend/backend/copilot/permissions_test.py
Normal file
579
autogpt_platform/backend/backend/copilot/permissions_test.py
Normal file
@@ -0,0 +1,579 @@
|
||||
"""Tests for CopilotPermissions — tool/block capability filtering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.permissions import (
|
||||
ALL_TOOL_NAMES,
|
||||
PLATFORM_TOOL_NAMES,
|
||||
SDK_BUILTIN_TOOL_NAMES,
|
||||
CopilotPermissions,
|
||||
_block_matches,
|
||||
all_known_tool_names,
|
||||
apply_tool_permissions,
|
||||
validate_block_identifiers,
|
||||
validate_tool_names,
|
||||
)
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _block_matches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBlockMatches:
|
||||
BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
|
||||
BLOCK_NAME = "HTTP Request"
|
||||
|
||||
def test_full_uuid_match(self):
|
||||
assert _block_matches(self.BLOCK_ID, self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_full_uuid_case_insensitive(self):
|
||||
assert _block_matches(self.BLOCK_ID.upper(), self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_full_uuid_no_match(self):
|
||||
other = "aaaaaaaa-0000-0000-0000-000000000000"
|
||||
assert not _block_matches(other, self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_partial_uuid_match(self):
|
||||
assert _block_matches("c069dc6b", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_partial_uuid_case_insensitive(self):
|
||||
assert _block_matches("C069DC6B", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_partial_uuid_no_match(self):
|
||||
assert not _block_matches("deadbeef", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_name_match(self):
|
||||
assert _block_matches("HTTP Request", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_name_case_insensitive(self):
|
||||
assert _block_matches("http request", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
assert _block_matches("HTTP REQUEST", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_name_no_match(self):
|
||||
assert not _block_matches("Unknown Block", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_partial_uuid_not_matching_as_name(self):
|
||||
# "c069dc6b" is 8 hex chars → treated as partial UUID, NOT name match
|
||||
assert not _block_matches(
|
||||
"c069dc6b", "ffffffff-0000-0000-0000-000000000000", "c069dc6b"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CopilotPermissions.effective_allowed_tools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
ALL_TOOLS = frozenset(
|
||||
["run_block", "web_fetch", "bash_exec", "find_agent", "Task", "Read"]
|
||||
)
|
||||
|
||||
|
||||
class TestEffectiveAllowedTools:
|
||||
def test_empty_list_allows_all(self):
|
||||
perms = CopilotPermissions(tools=[], tools_exclude=True)
|
||||
assert perms.effective_allowed_tools(ALL_TOOLS) == ALL_TOOLS
|
||||
|
||||
def test_empty_whitelist_allows_all(self):
|
||||
# edge: tools_exclude=False but empty list → allow all
|
||||
perms = CopilotPermissions(tools=[], tools_exclude=False)
|
||||
assert perms.effective_allowed_tools(ALL_TOOLS) == ALL_TOOLS
|
||||
|
||||
def test_blacklist_removes_listed(self):
|
||||
perms = CopilotPermissions(tools=["bash_exec", "web_fetch"], tools_exclude=True)
|
||||
result = perms.effective_allowed_tools(ALL_TOOLS)
|
||||
assert "bash_exec" not in result
|
||||
assert "web_fetch" not in result
|
||||
assert "run_block" in result
|
||||
assert "Task" in result
|
||||
|
||||
def test_whitelist_keeps_only_listed(self):
|
||||
perms = CopilotPermissions(tools=["run_block", "Task"], tools_exclude=False)
|
||||
result = perms.effective_allowed_tools(ALL_TOOLS)
|
||||
assert result == frozenset(["run_block", "Task"])
|
||||
|
||||
def test_whitelist_unknown_tool_yields_empty(self):
|
||||
perms = CopilotPermissions(tools=["nonexistent"], tools_exclude=False)
|
||||
result = perms.effective_allowed_tools(ALL_TOOLS)
|
||||
assert result == frozenset()
|
||||
|
||||
def test_blacklist_unknown_tool_ignored(self):
|
||||
perms = CopilotPermissions(tools=["nonexistent"], tools_exclude=True)
|
||||
result = perms.effective_allowed_tools(ALL_TOOLS)
|
||||
assert result == ALL_TOOLS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CopilotPermissions.is_block_allowed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
|
||||
BLOCK_NAME = "HTTP Request"
|
||||
|
||||
|
||||
class TestIsBlockAllowed:
|
||||
def test_empty_allows_everything(self):
|
||||
perms = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
assert perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_blacklist_blocks_listed(self):
|
||||
perms = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
assert not perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_blacklist_allows_unlisted(self):
|
||||
perms = CopilotPermissions(blocks=["Other Block"], blocks_exclude=True)
|
||||
assert perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_whitelist_allows_listed(self):
|
||||
perms = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=False)
|
||||
assert perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_whitelist_blocks_unlisted(self):
|
||||
perms = CopilotPermissions(blocks=["Other Block"], blocks_exclude=False)
|
||||
assert not perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_partial_uuid_blacklist(self):
|
||||
perms = CopilotPermissions(blocks=["c069dc6b"], blocks_exclude=True)
|
||||
assert not perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_full_uuid_whitelist(self):
|
||||
perms = CopilotPermissions(blocks=[BLOCK_ID], blocks_exclude=False)
|
||||
assert perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_parent_blocks_when_child_allows(self):
|
||||
parent = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
child = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
child._parent = parent
|
||||
assert not child.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_parent_allows_when_child_blocks(self):
|
||||
parent = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
child = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
child._parent = parent
|
||||
assert not child.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_both_must_allow(self):
|
||||
parent = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=False)
|
||||
child = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=False)
|
||||
child._parent = parent
|
||||
assert child.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_grandparent_blocks_propagate(self):
|
||||
grandparent = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
parent = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
parent._parent = grandparent
|
||||
child = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
child._parent = parent
|
||||
assert not child.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CopilotPermissions.merged_with_parent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMergedWithParent:
|
||||
def test_tool_intersection(self):
|
||||
all_t = frozenset(["run_block", "web_fetch", "bash_exec"])
|
||||
parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
child = CopilotPermissions(tools=["web_fetch"], tools_exclude=True)
|
||||
merged = child.merged_with_parent(parent, all_t)
|
||||
effective = merged.effective_allowed_tools(all_t)
|
||||
assert "bash_exec" not in effective
|
||||
assert "web_fetch" not in effective
|
||||
assert "run_block" in effective
|
||||
|
||||
def test_parent_whitelist_narrows_child(self):
|
||||
all_t = frozenset(["run_block", "web_fetch", "bash_exec"])
|
||||
parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
child = CopilotPermissions(tools=[], tools_exclude=True) # allow all
|
||||
merged = child.merged_with_parent(parent, all_t)
|
||||
effective = merged.effective_allowed_tools(all_t)
|
||||
assert effective == frozenset(["run_block"])
|
||||
|
||||
def test_child_cannot_expand_parent_whitelist(self):
|
||||
all_t = frozenset(["run_block", "web_fetch", "bash_exec"])
|
||||
parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
child = CopilotPermissions(
|
||||
tools=["run_block", "bash_exec"], tools_exclude=False
|
||||
)
|
||||
merged = child.merged_with_parent(parent, all_t)
|
||||
effective = merged.effective_allowed_tools(all_t)
|
||||
# bash_exec was not in parent's whitelist → must not appear
|
||||
assert "bash_exec" not in effective
|
||||
assert "run_block" in effective
|
||||
|
||||
def test_merged_stored_as_whitelist(self):
|
||||
all_t = frozenset(["run_block", "web_fetch"])
|
||||
parent = CopilotPermissions(tools=[], tools_exclude=True)
|
||||
child = CopilotPermissions(tools=[], tools_exclude=True)
|
||||
merged = child.merged_with_parent(parent, all_t)
|
||||
assert not merged.tools_exclude # stored as whitelist
|
||||
assert set(merged.tools) == {"run_block", "web_fetch"}
|
||||
|
||||
def test_block_parent_stored(self):
|
||||
all_t = frozenset(["run_block"])
|
||||
parent = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
child = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
merged = child.merged_with_parent(parent, all_t)
|
||||
# Parent restriction is preserved via _parent
|
||||
assert not merged.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CopilotPermissions.is_empty
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsEmpty:
|
||||
def test_default_is_empty(self):
|
||||
assert CopilotPermissions().is_empty()
|
||||
|
||||
def test_with_tools_not_empty(self):
|
||||
assert not CopilotPermissions(tools=["bash_exec"]).is_empty()
|
||||
|
||||
def test_with_blocks_not_empty(self):
|
||||
assert not CopilotPermissions(blocks=["HTTP Request"]).is_empty()
|
||||
|
||||
def test_with_parent_not_empty(self):
|
||||
perms = CopilotPermissions()
|
||||
perms._parent = CopilotPermissions(tools=["bash_exec"])
|
||||
assert not perms.is_empty()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_tool_names
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateToolNames:
|
||||
def test_valid_registry_tool(self):
|
||||
assert validate_tool_names(["run_block", "web_fetch"]) == []
|
||||
|
||||
def test_valid_sdk_builtin(self):
|
||||
assert validate_tool_names(["Read", "Task", "WebSearch"]) == []
|
||||
|
||||
def test_invalid_tool(self):
|
||||
result = validate_tool_names(["nonexistent_tool"])
|
||||
assert "nonexistent_tool" in result
|
||||
|
||||
def test_mixed(self):
|
||||
result = validate_tool_names(["run_block", "fake_tool"])
|
||||
assert "fake_tool" in result
|
||||
assert "run_block" not in result
|
||||
|
||||
def test_empty_list(self):
|
||||
assert validate_tool_names([]) == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_block_identifiers (async)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestValidateBlockIdentifiers:
|
||||
async def test_empty_list(self):
|
||||
result = await validate_block_identifiers([])
|
||||
assert result == []
|
||||
|
||||
async def test_valid_full_uuid(self, mocker):
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_block.return_value.name = "HTTP Request"
|
||||
mocker.patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block},
|
||||
)
|
||||
result = await validate_block_identifiers(
|
||||
["c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"]
|
||||
)
|
||||
assert result == []
|
||||
|
||||
async def test_invalid_identifier(self, mocker):
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_block.return_value.name = "HTTP Request"
|
||||
mocker.patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block},
|
||||
)
|
||||
result = await validate_block_identifiers(["totally_unknown"])
|
||||
assert "totally_unknown" in result
|
||||
|
||||
async def test_partial_uuid_match(self, mocker):
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_block.return_value.name = "HTTP Request"
|
||||
mocker.patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block},
|
||||
)
|
||||
result = await validate_block_identifiers(["c069dc6b"])
|
||||
assert result == []
|
||||
|
||||
async def test_name_match(self, mocker):
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_block.return_value.name = "HTTP Request"
|
||||
mocker.patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block},
|
||||
)
|
||||
result = await validate_block_identifiers(["http request"])
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# apply_tool_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestApplyToolPermissions:
|
||||
def test_empty_permissions_returns_base_unchanged(self, mocker):
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=["mcp__copilot__run_block", "mcp__copilot__web_fetch", "Task"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=["Bash"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object(), "web_fetch": object()},
|
||||
)
|
||||
perms = CopilotPermissions()
|
||||
allowed, disallowed = apply_tool_permissions(perms, use_e2b=False)
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
assert "mcp__copilot__web_fetch" in allowed
|
||||
|
||||
def test_blacklist_removes_tool(self, mocker):
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__web_fetch",
|
||||
"mcp__copilot__bash_exec",
|
||||
"Task",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=["Bash"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{
|
||||
"run_block": object(),
|
||||
"web_fetch": object(),
|
||||
"bash_exec": object(),
|
||||
},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "web_fetch", "bash_exec", "Task"]),
|
||||
)
|
||||
perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=False)
|
||||
assert "mcp__copilot__bash_exec" not in allowed
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
|
||||
def test_whitelist_keeps_only_listed(self, mocker):
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__web_fetch",
|
||||
"Task",
|
||||
"WebSearch",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=["Bash"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object(), "web_fetch": object()},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "web_fetch", "Task", "WebSearch"]),
|
||||
)
|
||||
perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=False)
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
assert "mcp__copilot__web_fetch" not in allowed
|
||||
assert "Task" not in allowed
|
||||
|
||||
def test_read_tool_always_included_even_when_blacklisted(self, mocker):
|
||||
"""mcp__copilot__Read must stay in allowed even if Read is explicitly blacklisted."""
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__Read",
|
||||
"Task",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=[],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object()},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "Read", "Task"]),
|
||||
)
|
||||
# Explicitly blacklist Read
|
||||
perms = CopilotPermissions(tools=["Read"], tools_exclude=True)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=False)
|
||||
assert "mcp__copilot__Read" in allowed # always preserved for SDK internals
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
assert "Task" in allowed
|
||||
|
||||
def test_read_tool_always_included_with_narrow_whitelist(self, mocker):
|
||||
"""mcp__copilot__Read must stay in allowed even when not in a whitelist."""
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__Read",
|
||||
"Task",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=[],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object()},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "Read", "Task"]),
|
||||
)
|
||||
# Whitelist only run_block — Read not listed
|
||||
perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=False)
|
||||
assert "mcp__copilot__Read" in allowed # always preserved for SDK internals
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
|
||||
def test_e2b_file_tools_included_when_sdk_builtin_whitelisted(self, mocker):
|
||||
"""In E2B mode, whitelisting 'Read' must include mcp__copilot__read_file."""
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__Read",
|
||||
"mcp__copilot__read_file",
|
||||
"mcp__copilot__write_file",
|
||||
"Task",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=["Bash", "Read", "Write", "Edit", "Glob", "Grep"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object()},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "Read", "Write", "Task"]),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.e2b_file_tools.E2B_FILE_TOOL_NAMES",
|
||||
["read_file", "write_file", "edit_file", "glob", "grep"],
|
||||
)
|
||||
# Whitelist Read and run_block — E2B read_file should be included
|
||||
perms = CopilotPermissions(tools=["Read", "run_block"], tools_exclude=False)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=True)
|
||||
assert "mcp__copilot__read_file" in allowed
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
# Write not whitelisted — write_file should NOT be included
|
||||
assert "mcp__copilot__write_file" not in allowed
|
||||
|
||||
def test_e2b_file_tools_excluded_when_sdk_builtin_blacklisted(self, mocker):
|
||||
"""In E2B mode, blacklisting 'Read' must also remove mcp__copilot__read_file."""
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__Read",
|
||||
"mcp__copilot__read_file",
|
||||
"Task",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=["Bash", "Read", "Write", "Edit", "Glob", "Grep"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object()},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "Read", "Task"]),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.e2b_file_tools.E2B_FILE_TOOL_NAMES",
|
||||
["read_file", "write_file", "edit_file", "glob", "grep"],
|
||||
)
|
||||
# Blacklist Read — E2B read_file should also be removed
|
||||
perms = CopilotPermissions(tools=["Read"], tools_exclude=True)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=True)
|
||||
assert "mcp__copilot__read_file" not in allowed
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
# mcp__copilot__Read is always preserved for SDK internals
|
||||
assert "mcp__copilot__Read" in allowed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SDK_BUILTIN_TOOL_NAMES sanity check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSdkBuiltinToolNames:
|
||||
def test_expected_builtins_present(self):
|
||||
expected = {
|
||||
"Read",
|
||||
"Write",
|
||||
"Edit",
|
||||
"Glob",
|
||||
"Grep",
|
||||
"Task",
|
||||
"WebSearch",
|
||||
"TodoWrite",
|
||||
}
|
||||
assert expected.issubset(SDK_BUILTIN_TOOL_NAMES)
|
||||
|
||||
def test_platform_names_match_tool_registry(self):
|
||||
"""PLATFORM_TOOL_NAMES (derived from ToolName Literal) must match TOOL_REGISTRY keys."""
|
||||
registry_keys = frozenset(TOOL_REGISTRY.keys())
|
||||
assert PLATFORM_TOOL_NAMES == registry_keys, (
|
||||
f"ToolName Literal is out of sync with TOOL_REGISTRY. "
|
||||
f"Missing: {registry_keys - PLATFORM_TOOL_NAMES}, "
|
||||
f"Extra: {PLATFORM_TOOL_NAMES - registry_keys}"
|
||||
)
|
||||
|
||||
def test_all_tool_names_is_union(self):
|
||||
"""ALL_TOOL_NAMES must equal PLATFORM_TOOL_NAMES | SDK_BUILTIN_TOOL_NAMES."""
|
||||
assert ALL_TOOL_NAMES == PLATFORM_TOOL_NAMES | SDK_BUILTIN_TOOL_NAMES
|
||||
|
||||
def test_no_overlap_between_platform_and_sdk(self):
|
||||
"""Platform and SDK built-in names must not overlap."""
|
||||
assert PLATFORM_TOOL_NAMES.isdisjoint(SDK_BUILTIN_TOOL_NAMES)
|
||||
|
||||
def test_known_tools_includes_registry_and_builtins(self):
|
||||
known = all_known_tool_names()
|
||||
assert "run_block" in known
|
||||
assert "Read" in known
|
||||
assert "Task" in known
|
||||
@@ -6,39 +6,24 @@ handling the distinction between:
|
||||
- Local mode vs E2B mode (storage/filesystem differences)
|
||||
"""
|
||||
|
||||
from backend.blocks.autopilot import AUTOPILOT_BLOCK_ID
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
# Shared technical notes that apply to both SDK and baseline modes
|
||||
_SHARED_TOOL_NOTES = """\
|
||||
_SHARED_TOOL_NOTES = f"""\
|
||||
|
||||
### Sharing files with the user
|
||||
After saving a file to the persistent workspace with `write_workspace_file`,
|
||||
share it with the user by embedding the `download_url` from the response in
|
||||
your message as a Markdown link or image:
|
||||
### Sharing files
|
||||
After `write_workspace_file`, embed the `download_url` in Markdown:
|
||||
- File: `[report.csv](workspace://file_id#text/csv)`
|
||||
- Image: ``
|
||||
- Video: ``
|
||||
|
||||
- **Any file** — shows as a clickable download link:
|
||||
`[report.csv](workspace://file_id#text/csv)`
|
||||
- **Image** — renders inline in chat:
|
||||
``
|
||||
- **Video** — renders inline in chat with player controls:
|
||||
``
|
||||
|
||||
The `download_url` field in the `write_workspace_file` response is already
|
||||
in the correct format — paste it directly after the `(` in the Markdown.
|
||||
|
||||
### Passing file content to tools — @@agptfile: references
|
||||
Instead of copying large file contents into a tool argument, pass a file
|
||||
reference and the platform will load the content for you.
|
||||
|
||||
Syntax: `@@agptfile:<uri>[<start>-<end>]`
|
||||
|
||||
- `<uri>` **must** start with `workspace://` or `/` (absolute path):
|
||||
- `workspace://<file_id>` — workspace file by ID
|
||||
- `workspace:///<path>` — workspace file by virtual path
|
||||
- `/absolute/local/path` — ephemeral or sdk_cwd file
|
||||
- E2B sandbox absolute path (e.g. `/home/user/script.py`)
|
||||
- `[<start>-<end>]` is an optional 1-indexed inclusive line range.
|
||||
- URIs that do not start with `workspace://` or `/` are **not** expanded.
|
||||
### File references — @@agptfile:
|
||||
Pass large file content to tools by reference: `@@agptfile:<uri>[<start>-<end>]`
|
||||
- `workspace://<file_id>` or `workspace:///<path>` — workspace files
|
||||
- `/absolute/path` — local/sandbox files
|
||||
- `[start-end]` — optional 1-indexed line range
|
||||
- Multiple refs per argument supported. Only `workspace://` and absolute paths are expanded.
|
||||
|
||||
Examples:
|
||||
```
|
||||
@@ -49,21 +34,9 @@ Examples:
|
||||
@@agptfile:/home/user/script.py
|
||||
```
|
||||
|
||||
You can embed a reference inside any string argument, or use it as the entire
|
||||
value. Multiple references in one argument are all expanded.
|
||||
**Structured data**: When the entire argument is a single file reference, the platform auto-parses by extension/MIME. Supported: JSON, JSONL, CSV, TSV, YAML, TOML, Parquet, Excel (.xlsx only; legacy `.xls` is NOT supported). Unrecognised formats return plain string.
|
||||
|
||||
**Structured data**: When the **entire** argument value is a single file
|
||||
reference (no surrounding text), the platform automatically parses the file
|
||||
content based on its extension or MIME type. Supported formats: JSON, JSONL,
|
||||
CSV, TSV, YAML, TOML, Parquet, and Excel (.xlsx — first sheet only).
|
||||
For example, pass `@@agptfile:workspace://<id>` where the file is a `.csv` and
|
||||
the rows will be parsed into `list[list[str]]` automatically. If the format is
|
||||
unrecognised or parsing fails, the content is returned as a plain string.
|
||||
Legacy `.xls` files are **not** supported — only the modern `.xlsx` format.
|
||||
|
||||
**Type coercion**: The platform also coerces expanded values to match the
|
||||
block's expected input types. For example, if a block expects `list[list[str]]`
|
||||
and the expanded value is a JSON string, it will be parsed into the correct type.
|
||||
**Type coercion**: The platform auto-coerces expanded string values to match block input types (e.g. JSON string → `list[list[str]]`).
|
||||
|
||||
### Media file inputs (format: "file")
|
||||
Some block inputs accept media files — their schema shows `"format": "file"`.
|
||||
@@ -81,18 +54,97 @@ that would be corrupted by text encoding.
|
||||
|
||||
Example — committing an image file to GitHub:
|
||||
```json
|
||||
{
|
||||
"files": [{
|
||||
{{
|
||||
"files": [{{
|
||||
"path": "docs/hero.png",
|
||||
"content": "workspace://abc123#image/png",
|
||||
"operation": "upsert"
|
||||
}]
|
||||
}
|
||||
}}]
|
||||
}}
|
||||
```
|
||||
|
||||
### Writing large files — CRITICAL
|
||||
**Never write an entire large document in a single tool call.** When the
|
||||
content you want to write exceeds ~2000 words the tool call's output token
|
||||
limit will silently truncate the arguments, producing an empty `{{}}` input
|
||||
that fails repeatedly.
|
||||
|
||||
**Preferred: compose from file references.** If the data is already in
|
||||
files (tool outputs, workspace files), compose the report in one call
|
||||
using `@@agptfile:` references — the system expands them inline:
|
||||
|
||||
```bash
|
||||
cat > report.md << 'EOF'
|
||||
# Research Report
|
||||
## Data from web research
|
||||
@@agptfile:/home/user/web_results.txt
|
||||
## Block execution output
|
||||
@@agptfile:workspace://<file_id>
|
||||
## Conclusion
|
||||
<brief synthesis>
|
||||
EOF
|
||||
```
|
||||
|
||||
**Fallback: write section-by-section.** When you must generate content
|
||||
from conversation context (no files to reference), split into multiple
|
||||
`bash_exec` calls — one section per call:
|
||||
|
||||
```bash
|
||||
cat > report.md << 'EOF'
|
||||
# Section 1
|
||||
<content from your earlier tool call results>
|
||||
EOF
|
||||
```
|
||||
```bash
|
||||
cat >> report.md << 'EOF'
|
||||
# Section 2
|
||||
<content from your earlier tool call results>
|
||||
EOF
|
||||
```
|
||||
Use `cat >` for the first chunk and `cat >>` to append subsequent chunks.
|
||||
Do not re-fetch or re-generate data you already have from prior tool calls.
|
||||
|
||||
After building the file, reference it with `@@agptfile:` in other tools:
|
||||
`@@agptfile:/home/user/report.md`
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
|
||||
### Delegating to another autopilot (sub-autopilot pattern)
|
||||
Use the **AutoPilotBlock** (`run_block` with block_id
|
||||
`{AUTOPILOT_BLOCK_ID}`) to delegate a task to a fresh
|
||||
autopilot instance. The sub-autopilot has its own full tool set and can
|
||||
perform multi-step work autonomously.
|
||||
|
||||
- **Input**: `prompt` (required) — the task description.
|
||||
Optional: `system_context` to constrain behavior, `session_id` to
|
||||
continue a previous conversation, `max_recursion_depth` (default 3).
|
||||
- **Output**: `response` (text), `tool_calls` (list), `session_id`
|
||||
(for continuation), `conversation_history`, `token_usage`.
|
||||
|
||||
Use this when a task is complex enough to benefit from a separate
|
||||
autopilot context, e.g. "research X and write a report" while the
|
||||
parent autopilot handles orchestration.
|
||||
"""
|
||||
|
||||
# E2B-only notes — E2B has full internet access so gh CLI works there.
|
||||
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
|
||||
_E2B_TOOL_NOTES = """
|
||||
### GitHub CLI (`gh`) and git
|
||||
- If the user has connected their GitHub account, both `gh` and `git` are
|
||||
pre-authenticated — use them directly without any manual login step.
|
||||
`git` HTTPS operations (clone, push, pull) work automatically.
|
||||
- If the token changes mid-session (e.g. user reconnects with a new token),
|
||||
run `gh auth setup-git` to re-register the credential helper.
|
||||
- If `gh` or `git` fails with an authentication error (e.g. "authentication
|
||||
required", "could not read Username", or exit code 128), call
|
||||
`connect_integration(provider="github")` to surface the GitHub credentials
|
||||
setup card so the user can connect their account. Once connected, retry
|
||||
the operation.
|
||||
- For operations that need broader access (e.g. private org repos, GitHub
|
||||
Actions), pass the required scopes: e.g.
|
||||
`connect_integration(provider="github", scopes=["repo", "read:org"])`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -105,6 +157,7 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: list[str],
|
||||
file_move_name_1_to_2: str,
|
||||
file_move_name_2_to_1: str,
|
||||
extra_notes: str = "",
|
||||
) -> str:
|
||||
"""Build storage/filesystem supplement for a specific environment.
|
||||
|
||||
@@ -119,6 +172,7 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: List of persistence behavior descriptions
|
||||
file_move_name_1_to_2: Direction label for primary→persistent
|
||||
file_move_name_2_to_1: Direction label for persistent→primary
|
||||
extra_notes: Environment-specific notes appended after shared notes
|
||||
"""
|
||||
# Format lists as bullet points with proper indentation
|
||||
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
|
||||
@@ -128,17 +182,12 @@ def _build_storage_supplement(
|
||||
|
||||
## Tool notes
|
||||
|
||||
### Shell commands
|
||||
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
||||
for shell commands — it runs {sandbox_type}.
|
||||
|
||||
### Working directory
|
||||
- Your working directory is: `{working_dir}`
|
||||
- All SDK file tools AND `bash_exec` operate on the same filesystem
|
||||
- Use relative paths or absolute paths under `{working_dir}` for all file operations
|
||||
### Shell & filesystem
|
||||
- The SDK built-in Bash tool is NOT available. Use `bash_exec` for shell commands ({sandbox_type}). Working dir: `{working_dir}`
|
||||
- SDK file tools (Read/Write/Edit/Glob/Grep) and `bash_exec` share one filesystem — use relative or absolute paths under this dir.
|
||||
- `read_workspace_file`/`write_workspace_file` operate on **persistent cloud workspace storage** (separate from the working dir).
|
||||
|
||||
### Two storage systems — CRITICAL to understand
|
||||
|
||||
1. **{storage_system_1_name}** (`{working_dir}`):
|
||||
{characteristics}
|
||||
{persistence}
|
||||
@@ -152,12 +201,23 @@ def _build_storage_supplement(
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
{_SHARED_TOOL_NOTES}"""
|
||||
|
||||
### SDK tool-result files
|
||||
When tool outputs are large, the SDK truncates them and saves the full output to
|
||||
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
|
||||
always use `read_file` or `Read` (NOT `read_workspace_file`).
|
||||
`read_workspace_file` reads from cloud workspace storage, where SDK
|
||||
tool-results are NOT stored.
|
||||
{_SHARED_TOOL_NOTES}{extra_notes}"""
|
||||
|
||||
|
||||
# Pre-built supplements for common environments
|
||||
def _get_local_storage_supplement(cwd: str) -> str:
|
||||
"""Local ephemeral storage (files lost between turns)."""
|
||||
"""Local ephemeral storage (files lost between turns).
|
||||
|
||||
Network is isolated (bubblewrap --unshare-net), so internet-dependent CLIs
|
||||
like gh will not work — no integration env-var notes are included.
|
||||
"""
|
||||
return _build_storage_supplement(
|
||||
working_dir=cwd,
|
||||
sandbox_type="in a network-isolated sandbox",
|
||||
@@ -175,7 +235,11 @@ def _get_local_storage_supplement(cwd: str) -> str:
|
||||
|
||||
|
||||
def _get_cloud_sandbox_supplement() -> str:
|
||||
"""Cloud persistent sandbox (files survive across turns in session)."""
|
||||
"""Cloud persistent sandbox (files survive across turns in session).
|
||||
|
||||
E2B has full internet access, so integration tokens (GH_TOKEN etc.) are
|
||||
injected per command in bash_exec — include the CLI guidance notes.
|
||||
"""
|
||||
return _build_storage_supplement(
|
||||
working_dir="/home/user",
|
||||
sandbox_type="in a cloud sandbox with full internet access",
|
||||
@@ -190,6 +254,7 @@ def _get_cloud_sandbox_supplement() -> str:
|
||||
],
|
||||
file_move_name_1_to_2="Sandbox → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Sandbox",
|
||||
extra_notes=_E2B_TOOL_NOTES,
|
||||
)
|
||||
|
||||
|
||||
|
||||
63
autogpt_platform/backend/backend/copilot/providers.py
Normal file
63
autogpt_platform/backend/backend/copilot/providers.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Single source of truth for copilot-supported integration providers.
|
||||
|
||||
Both :mod:`~backend.copilot.integration_creds` (env-var injection) and
|
||||
:mod:`~backend.copilot.tools.connect_integration` (UI setup card) import from
|
||||
here, eliminating the risk of the two registries drifting out of sync.
|
||||
"""
|
||||
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class ProviderEntry(TypedDict):
|
||||
"""Metadata for a supported integration provider.
|
||||
|
||||
Attributes:
|
||||
name: Human-readable display name (e.g. "GitHub").
|
||||
env_vars: Environment variable names injected when the provider is
|
||||
connected (e.g. ``["GH_TOKEN", "GITHUB_TOKEN"]``).
|
||||
default_scopes: Default OAuth scopes requested when the agent does not
|
||||
specify any.
|
||||
"""
|
||||
|
||||
name: str
|
||||
env_vars: list[str]
|
||||
default_scopes: list[str]
|
||||
|
||||
|
||||
def _is_github_oauth_configured() -> bool:
|
||||
"""Return True if GitHub OAuth env vars are set.
|
||||
|
||||
Uses a lazy import to avoid triggering ``Secrets()`` during module import,
|
||||
which can fail in environments where secrets are not yet loaded (e.g. tests,
|
||||
CLI tooling).
|
||||
"""
|
||||
from backend.blocks.github._auth import GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
return GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
|
||||
# -- Registry ----------------------------------------------------------------
|
||||
# Add new providers here. Both env-var injection and the setup-card tool read
|
||||
# from this single registry.
|
||||
|
||||
SUPPORTED_PROVIDERS: dict[str, ProviderEntry] = {
|
||||
"github": {
|
||||
"name": "GitHub",
|
||||
"env_vars": ["GH_TOKEN", "GITHUB_TOKEN"],
|
||||
"default_scopes": ["repo"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_provider_auth_types(provider: str) -> list[str]:
|
||||
"""Return the supported credential types for *provider* at runtime.
|
||||
|
||||
OAuth types are only offered when the corresponding OAuth client env vars
|
||||
are configured.
|
||||
"""
|
||||
if provider == "github":
|
||||
if _is_github_oauth_configured():
|
||||
return ["api_key", "oauth2"]
|
||||
return ["api_key"]
|
||||
# Default for unknown/future providers — API key only.
|
||||
return ["api_key"]
|
||||
@@ -43,6 +43,7 @@ class ResponseType(str, Enum):
|
||||
ERROR = "error"
|
||||
USAGE = "usage"
|
||||
HEARTBEAT = "heartbeat"
|
||||
STATUS = "status"
|
||||
|
||||
|
||||
class StreamBaseResponse(BaseModel):
|
||||
@@ -263,3 +264,19 @@ class StreamHeartbeat(StreamBaseResponse):
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE comment format to keep connection alive."""
|
||||
return ": heartbeat\n\n"
|
||||
|
||||
|
||||
class StreamStatus(StreamBaseResponse):
|
||||
"""Transient status notification shown to the user during long operations.
|
||||
|
||||
Used to provide feedback when the backend performs behind-the-scenes work
|
||||
(e.g., compacting conversation context on a retry) that would otherwise
|
||||
leave the user staring at an unexplained pause.
|
||||
|
||||
Sent as a proper ``data:`` event so the frontend can display it to the
|
||||
user. The AI SDK stream parser gracefully skips unknown chunk types
|
||||
(logs a console warning), so this does not break the stream.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.STATUS
|
||||
message: str = Field(..., description="Human-readable status message")
|
||||
|
||||
@@ -19,9 +19,19 @@ least invasive way to break the cycle while keeping module-level constants
|
||||
intact.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
# Static imports for type checkers so they can resolve __all__ entries
|
||||
# without executing the lazy-import machinery at runtime.
|
||||
if TYPE_CHECKING:
|
||||
from .collect import CopilotResult as CopilotResult
|
||||
from .collect import collect_copilot_response as collect_copilot_response
|
||||
from .service import stream_chat_completion_sdk as stream_chat_completion_sdk
|
||||
from .tool_adapter import create_copilot_mcp_server as create_copilot_mcp_server
|
||||
|
||||
__all__ = [
|
||||
"CopilotResult",
|
||||
"collect_copilot_response",
|
||||
"stream_chat_completion_sdk",
|
||||
"create_copilot_mcp_server",
|
||||
]
|
||||
@@ -29,6 +39,8 @@ __all__ = [
|
||||
# Dispatch table for PEP 562 lazy imports. Each entry is a (module, attr)
|
||||
# pair so new exports can be added without touching __getattr__ itself.
|
||||
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
|
||||
"CopilotResult": (".collect", "CopilotResult"),
|
||||
"collect_copilot_response": (".collect", "collect_copilot_response"),
|
||||
"stream_chat_completion_sdk": (".service", "stream_chat_completion_sdk"),
|
||||
"create_copilot_mcp_server": (".tool_adapter", "create_copilot_mcp_server"),
|
||||
}
|
||||
|
||||
@@ -143,6 +143,71 @@ To use an MCP (Model Context Protocol) tool as a node in the agent:
|
||||
tool_arguments.
|
||||
6. Output: `result` (the tool's return value) and `error` (error message)
|
||||
|
||||
### Using OrchestratorBlock (AI Orchestrator with Agent Mode)
|
||||
|
||||
To create an agent where AI autonomously decides which tools or sub-agents to
|
||||
call in a loop until the task is complete:
|
||||
1. Create a `OrchestratorBlock` node
|
||||
(ID: `3b191d9f-356f-482d-8238-ba04b6d18381`)
|
||||
2. Set `input_default`:
|
||||
- `agent_mode_max_iterations`: Choose based on task complexity:
|
||||
- `1` for single-step tool calls (AI picks one tool, calls it, done)
|
||||
- `3`–`10` for multi-step tasks (AI calls tools iteratively)
|
||||
- `-1` for open-ended orchestration (AI loops until it decides it's done).
|
||||
**Use with caution** — prefer bounded iterations (3–10) unless
|
||||
genuinely needed, as unbounded loops risk runaway cost and execution.
|
||||
Do NOT use `0` (traditional mode) — it requires complex external
|
||||
conversation-history loop wiring that the agent generator does not
|
||||
produce.
|
||||
- `conversation_compaction`: `true` (recommended to avoid context overflow)
|
||||
- `retry`: Number of retries on tool-call failure (default `3`).
|
||||
Set to `0` to disable retries.
|
||||
- `multiple_tool_calls`: Whether the AI can invoke multiple tools in a
|
||||
single turn (default `false`). Enable when tools are independent and
|
||||
can run concurrently.
|
||||
- Optional: `sys_prompt` for extra LLM context about how to orchestrate
|
||||
3. Wire the `prompt` input from an `AgentInputBlock` (the user's task)
|
||||
4. Create downstream tool blocks — regular blocks **or** `AgentExecutorBlock`
|
||||
nodes that call sub-agents
|
||||
5. Link each tool to the Orchestrator: set `source_name: "tools"` on
|
||||
the Orchestrator side and `sink_name: <input_field>` on each tool
|
||||
block's input. Create one link per input field the tool needs.
|
||||
6. Wire the `finished` output to an `AgentOutputBlock` for the final result
|
||||
7. Credentials (LLM API key) are configured by the user in the platform UI
|
||||
after saving — do NOT require them upfront
|
||||
|
||||
**Example — Orchestrator calling two sub-agents:**
|
||||
- Node 1: `AgentInputBlock` (input_default: `{"name": "task"}`)
|
||||
- Node 2: `OrchestratorBlock` (input_default:
|
||||
`{"agent_mode_max_iterations": 10, "conversation_compaction": true}`)
|
||||
- Node 3: `AgentExecutorBlock` (sub-agent A — set `graph_id`, `graph_version`,
|
||||
`input_schema`, `output_schema` from library agent)
|
||||
- Node 4: `AgentExecutorBlock` (sub-agent B — same pattern)
|
||||
- Node 5: `AgentOutputBlock` (input_default: `{"name": "result"}`)
|
||||
- Links:
|
||||
- Input→Orchestrator: `source_name: "result"`, `sink_name: "prompt"`
|
||||
- Orchestrator→Agent A (per input field): `source_name: "tools"`,
|
||||
`sink_name: "<agent_a_input_field>"`
|
||||
- Orchestrator→Agent B (per input field): `source_name: "tools"`,
|
||||
`sink_name: "<agent_b_input_field>"`
|
||||
- Orchestrator→Output: `source_name: "finished"`, `sink_name: "value"`
|
||||
|
||||
**Example — Orchestrator calling regular blocks as tools:**
|
||||
- Node 1: `AgentInputBlock` (input_default: `{"name": "task"}`)
|
||||
- Node 2: `OrchestratorBlock` (input_default:
|
||||
`{"agent_mode_max_iterations": 5, "conversation_compaction": true}`)
|
||||
- Node 3: `GetWebpageBlock` (regular block — the AI calls it as a tool)
|
||||
- Node 4: `AITextGeneratorBlock` (another regular block as a tool)
|
||||
- Node 5: `AgentOutputBlock` (input_default: `{"name": "result"}`)
|
||||
- Links:
|
||||
- Input→Orchestrator: `source_name: "result"`, `sink_name: "prompt"`
|
||||
- Orchestrator→GetWebpage: `source_name: "tools"`, `sink_name: "url"`
|
||||
- Orchestrator→AITextGenerator: `source_name: "tools"`, `sink_name: "prompt"`
|
||||
- Orchestrator→Output: `source_name: "finished"`, `sink_name: "value"`
|
||||
|
||||
Regular blocks work exactly like sub-agents as tools — wire each input
|
||||
field from `source_name: "tools"` on the Orchestrator side.
|
||||
|
||||
### Example: Simple AI Text Processor
|
||||
|
||||
A minimal agent with input, processing, and output:
|
||||
|
||||
115
autogpt_platform/backend/backend/copilot/sdk/collect.py
Normal file
115
autogpt_platform/backend/backend/copilot/sdk/collect.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Public helpers for consuming a copilot stream as a simple request-response.
|
||||
|
||||
This module exposes :class:`CopilotResult` and :func:`collect_copilot_response`
|
||||
so that callers (e.g. the AutoPilot block) can consume the copilot stream
|
||||
without implementing their own event loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
|
||||
from .service import stream_chat_completion_sdk
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
|
||||
class CopilotResult:
|
||||
"""Aggregated result from consuming a copilot stream.
|
||||
|
||||
Returned by :func:`collect_copilot_response` so callers don't need to
|
||||
implement their own event-loop over the raw stream events.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"response_text",
|
||||
"tool_calls",
|
||||
"prompt_tokens",
|
||||
"completion_tokens",
|
||||
"total_tokens",
|
||||
)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.response_text: str = ""
|
||||
self.tool_calls: list[dict[str, Any]] = []
|
||||
self.prompt_tokens: int = 0
|
||||
self.completion_tokens: int = 0
|
||||
self.total_tokens: int = 0
|
||||
|
||||
|
||||
async def collect_copilot_response(
|
||||
*,
|
||||
session_id: str,
|
||||
message: str,
|
||||
user_id: str,
|
||||
is_user_message: bool = True,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> CopilotResult:
|
||||
"""Consume :func:`stream_chat_completion_sdk` and return aggregated results.
|
||||
|
||||
This is the recommended entry-point for callers that need a simple
|
||||
request-response interface (e.g. the AutoPilot block) rather than
|
||||
streaming individual events. It avoids duplicating the event-collection
|
||||
logic and does NOT wrap the stream in ``asyncio.timeout`` — the SDK
|
||||
manages its own heartbeat-based timeouts internally.
|
||||
|
||||
Args:
|
||||
session_id: Chat session to use.
|
||||
message: The user message / prompt.
|
||||
user_id: Authenticated user ID.
|
||||
is_user_message: Whether this is a user-initiated message.
|
||||
permissions: Optional capability filter. When provided, restricts
|
||||
which tools and blocks the copilot may use during this execution.
|
||||
|
||||
Returns:
|
||||
A :class:`CopilotResult` with the aggregated response text,
|
||||
tool calls, and token usage.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the stream yields a ``StreamError`` event.
|
||||
"""
|
||||
result = CopilotResult()
|
||||
response_parts: list[str] = []
|
||||
tool_calls_by_id: dict[str, dict[str, Any]] = {}
|
||||
|
||||
async for event in stream_chat_completion_sdk(
|
||||
session_id=session_id,
|
||||
message=message,
|
||||
is_user_message=is_user_message,
|
||||
user_id=user_id,
|
||||
permissions=permissions,
|
||||
):
|
||||
if isinstance(event, StreamTextDelta):
|
||||
response_parts.append(event.delta)
|
||||
elif isinstance(event, StreamToolInputAvailable):
|
||||
entry: dict[str, Any] = {
|
||||
"tool_call_id": event.toolCallId,
|
||||
"tool_name": event.toolName,
|
||||
"input": event.input,
|
||||
"output": None,
|
||||
"success": None,
|
||||
}
|
||||
result.tool_calls.append(entry)
|
||||
tool_calls_by_id[event.toolCallId] = entry
|
||||
elif isinstance(event, StreamToolOutputAvailable):
|
||||
if tc := tool_calls_by_id.get(event.toolCallId):
|
||||
tc["output"] = event.output
|
||||
tc["success"] = event.success
|
||||
elif isinstance(event, StreamUsage):
|
||||
result.prompt_tokens += event.prompt_tokens
|
||||
result.completion_tokens += event.completion_tokens
|
||||
result.total_tokens += event.total_tokens
|
||||
elif isinstance(event, StreamError):
|
||||
raise RuntimeError(f"Copilot error: {event.errorText}")
|
||||
|
||||
result.response_text = "".join(response_parts)
|
||||
return result
|
||||
@@ -12,6 +12,7 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
|
||||
from ..model import ChatMessage, ChatSession
|
||||
@@ -119,14 +120,12 @@ def filter_compaction_messages(
|
||||
filtered: list[ChatMessage] = []
|
||||
for msg in messages:
|
||||
if msg.role == "assistant" and msg.tool_calls:
|
||||
real_calls: list[dict[str, Any]] = []
|
||||
for tc in msg.tool_calls:
|
||||
if tc.get("function", {}).get("name") == COMPACTION_TOOL_NAME:
|
||||
compaction_ids.add(tc.get("id", ""))
|
||||
real_calls = [
|
||||
tc
|
||||
for tc in msg.tool_calls
|
||||
if tc.get("function", {}).get("name") != COMPACTION_TOOL_NAME
|
||||
]
|
||||
else:
|
||||
real_calls.append(tc)
|
||||
if not real_calls and not msg.content:
|
||||
continue
|
||||
if msg.role == "tool" and msg.tool_call_id in compaction_ids:
|
||||
@@ -222,6 +221,7 @@ class CompactionTracker:
|
||||
|
||||
def reset_for_query(self) -> None:
|
||||
"""Reset per-query state before a new SDK query."""
|
||||
self._compact_start.clear()
|
||||
self._done = False
|
||||
self._start_emitted = False
|
||||
self._tool_call_id = ""
|
||||
|
||||
54
autogpt_platform/backend/backend/copilot/sdk/conftest.py
Normal file
54
autogpt_platform/backend/backend/copilot/sdk/conftest.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Shared test fixtures for copilot SDK tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_chat_config():
|
||||
"""Mock ChatConfig so compact_transcript tests skip real config lookup."""
|
||||
with patch(
|
||||
"backend.copilot.config.ChatConfig",
|
||||
return_value=type("Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"})(),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
def build_test_transcript(pairs: list[tuple[str, str]]) -> str:
|
||||
"""Build a minimal valid JSONL transcript from (role, content) pairs.
|
||||
|
||||
Use this helper in any copilot SDK test that needs a well-formed
|
||||
transcript without hitting the real storage layer.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
last_uuid: str | None = None
|
||||
for role, content in pairs:
|
||||
uid = str(uuid4())
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
msg: dict = {"role": role, "content": content}
|
||||
if role == "assistant":
|
||||
msg.update(
|
||||
{
|
||||
"model": "",
|
||||
"id": f"msg_{uid[:8]}",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
}
|
||||
)
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": last_uuid,
|
||||
"message": msg,
|
||||
}
|
||||
lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
last_uuid = uid
|
||||
return "\n".join(lines) + "\n"
|
||||
@@ -1,9 +1,17 @@
|
||||
"""Dummy SDK service for testing copilot streaming.
|
||||
|
||||
Returns mock streaming responses without calling Claude Agent SDK.
|
||||
Enable via COPILOT_TEST_MODE=true environment variable.
|
||||
Enable via CHAT_TEST_MODE=true in .env (ChatConfig.test_mode).
|
||||
|
||||
WARNING: This is for testing only. Do not use in production.
|
||||
|
||||
Magic keywords (case-insensitive, anywhere in message):
|
||||
__test_transient_error__ — Simulate a transient Anthropic API error
|
||||
(ECONNRESET). Streams partial text, then
|
||||
yields StreamError with retryable prefix.
|
||||
__test_fatal_error__ — Simulate a non-retryable SDK error.
|
||||
__test_slow_response__ — Simulate a slow response (2s per word).
|
||||
(no keyword) — Normal dummy response.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -12,12 +20,39 @@ import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from ..model import ChatSession
|
||||
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
|
||||
from ..constants import (
|
||||
COPILOT_ERROR_PREFIX,
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX,
|
||||
FRIENDLY_TRANSIENT_MSG,
|
||||
)
|
||||
from ..model import ChatMessage, ChatSession, get_chat_session, upsert_chat_session
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _safe_upsert(session: ChatSession) -> None:
|
||||
"""Best-effort session persist — skip silently if DB is unavailable."""
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception:
|
||||
logger.debug("[TEST MODE] Could not persist session (DB unavailable)")
|
||||
|
||||
|
||||
def _has_keyword(message: str | None, keyword: str) -> bool:
|
||||
return keyword in (message or "").lower()
|
||||
|
||||
|
||||
async def stream_chat_completion_dummy(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
@@ -36,24 +71,89 @@ async def stream_chat_completion_dummy(
|
||||
- No timeout occurs
|
||||
- Text arrives in chunks
|
||||
- StreamFinish is sent by mark_session_completed
|
||||
|
||||
See module docstring for magic keywords that trigger error scenarios.
|
||||
"""
|
||||
logger.warning(
|
||||
f"[TEST MODE] Using dummy copilot streaming for session {session_id}"
|
||||
)
|
||||
|
||||
# Load session from DB (matches SDK service behaviour) so error markers
|
||||
# and the assistant reply are persisted and survive page refresh.
|
||||
# Best-effort: skip if DB is unavailable (e.g. unit tests).
|
||||
if session is None:
|
||||
try:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
except Exception:
|
||||
logger.debug("[TEST MODE] Could not load session (DB unavailable)")
|
||||
session = None
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
# Start the stream
|
||||
# Start the stream (matches baseline: StreamStart → StreamStartStep)
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
yield StreamStartStep()
|
||||
|
||||
# Simulate streaming text response with delays
|
||||
# --- Magic keyword: transient error (retryable) -------------------------
|
||||
if _has_keyword(message, "__test_transient_error__"):
|
||||
# Stream some partial text first (simulates mid-stream failure)
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
for word in ["Working", "on", "it..."]:
|
||||
yield StreamTextDelta(id=text_block_id, delta=f"{word} ")
|
||||
await asyncio.sleep(0.1)
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
yield StreamFinishStep()
|
||||
# Persist retryable marker so "Try Again" button shows after refresh
|
||||
if session:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=f"{COPILOT_RETRYABLE_ERROR_PREFIX} {FRIENDLY_TRANSIENT_MSG}",
|
||||
)
|
||||
)
|
||||
await _safe_upsert(session)
|
||||
yield StreamError(
|
||||
errorText=FRIENDLY_TRANSIENT_MSG,
|
||||
code="transient_api_error",
|
||||
)
|
||||
return
|
||||
|
||||
# --- Magic keyword: fatal error (non-retryable) -------------------------
|
||||
if _has_keyword(message, "__test_fatal_error__"):
|
||||
yield StreamFinishStep()
|
||||
error_msg = "Internal SDK error: model refused to respond"
|
||||
# Persist non-retryable error marker
|
||||
if session:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=f"{COPILOT_ERROR_PREFIX} {error_msg}",
|
||||
)
|
||||
)
|
||||
await _safe_upsert(session)
|
||||
yield StreamError(errorText=error_msg, code="sdk_error")
|
||||
return
|
||||
|
||||
# --- Magic keyword: slow response ---------------------------------------
|
||||
delay = 2.0 if _has_keyword(message, "__test_slow_response__") else 0.1
|
||||
|
||||
# --- Normal dummy response ----------------------------------------------
|
||||
dummy_response = "I counted: 1... 2... 3. All done!"
|
||||
words = dummy_response.split()
|
||||
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
for i, word in enumerate(words):
|
||||
# Add space except for last word
|
||||
text = word if i == len(words) - 1 else f"{word} "
|
||||
yield StreamTextDelta(id=text_block_id, delta=text)
|
||||
# Small delay to simulate real streaming
|
||||
await asyncio.sleep(0.1)
|
||||
await asyncio.sleep(delay)
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
|
||||
# Persist the assistant reply so it survives page refresh
|
||||
if session:
|
||||
session.messages.append(ChatMessage(role="assistant", content=dummy_response))
|
||||
await _safe_upsert(session)
|
||||
|
||||
yield StreamFinishStep()
|
||||
yield StreamFinish()
|
||||
|
||||
@@ -26,6 +26,41 @@ from backend.copilot.context import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _check_sandbox_symlink_escape(
|
||||
sandbox: Any,
|
||||
parent: str,
|
||||
) -> str | None:
|
||||
"""Resolve the canonical parent path inside the sandbox to detect symlink escapes.
|
||||
|
||||
``normpath`` (used by ``resolve_sandbox_path``) only normalises the string;
|
||||
``readlink -f`` follows actual symlinks on the sandbox filesystem.
|
||||
|
||||
Returns the canonical parent path, or ``None`` if the path escapes
|
||||
``E2B_WORKDIR``.
|
||||
|
||||
Note: There is an inherent TOCTOU window between this check and the
|
||||
subsequent ``sandbox.files.write()``. A symlink could theoretically be
|
||||
replaced between the two operations. This is acceptable in the E2B
|
||||
sandbox model since the sandbox is single-user and ephemeral.
|
||||
"""
|
||||
canonical_res = await sandbox.commands.run(
|
||||
f"readlink -f {shlex.quote(parent or E2B_WORKDIR)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=5,
|
||||
)
|
||||
canonical_parent = (canonical_res.stdout or "").strip()
|
||||
if (
|
||||
canonical_res.exit_code != 0
|
||||
or not canonical_parent
|
||||
or (
|
||||
canonical_parent != E2B_WORKDIR
|
||||
and not canonical_parent.startswith(E2B_WORKDIR + "/")
|
||||
)
|
||||
):
|
||||
return None
|
||||
return canonical_parent
|
||||
|
||||
|
||||
def _get_sandbox():
|
||||
return get_current_sandbox()
|
||||
|
||||
@@ -106,6 +141,10 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
parent = os.path.dirname(remote)
|
||||
if parent and parent != E2B_WORKDIR:
|
||||
await sandbox.files.make_dir(parent)
|
||||
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
|
||||
if canonical_parent is None:
|
||||
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
|
||||
remote = os.path.join(canonical_parent, os.path.basename(remote))
|
||||
await sandbox.files.write(remote, content)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to write {remote}: {exc}", error=True)
|
||||
@@ -130,6 +169,12 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
return result
|
||||
sandbox, remote = result
|
||||
|
||||
parent = os.path.dirname(remote)
|
||||
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
|
||||
if canonical_parent is None:
|
||||
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
|
||||
remote = os.path.join(canonical_parent, os.path.basename(remote))
|
||||
|
||||
try:
|
||||
raw: bytes = await sandbox.files.read(remote, format="bytes")
|
||||
content = raw.decode("utf-8", errors="replace")
|
||||
|
||||
@@ -4,15 +4,19 @@ Pure unit tests with no external dependencies (no E2B, no sandbox).
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import _current_project_dir
|
||||
|
||||
from .e2b_file_tools import _read_local, resolve_sandbox_path
|
||||
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_project_dir
|
||||
|
||||
from .e2b_file_tools import (
|
||||
_check_sandbox_symlink_escape,
|
||||
_read_local,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
|
||||
@@ -21,46 +25,48 @@ _SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
class TestResolveSandboxPath:
|
||||
def test_relative_path_resolved(self):
|
||||
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
|
||||
assert resolve_sandbox_path("src/main.py") == f"{E2B_WORKDIR}/src/main.py"
|
||||
|
||||
def test_absolute_within_sandbox(self):
|
||||
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
|
||||
assert (
|
||||
resolve_sandbox_path(f"{E2B_WORKDIR}/file.txt") == f"{E2B_WORKDIR}/file.txt"
|
||||
)
|
||||
|
||||
def test_workdir_itself(self):
|
||||
assert resolve_sandbox_path("/home/user") == "/home/user"
|
||||
assert resolve_sandbox_path(E2B_WORKDIR) == E2B_WORKDIR
|
||||
|
||||
def test_relative_dotslash(self):
|
||||
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
|
||||
assert resolve_sandbox_path("./README.md") == f"{E2B_WORKDIR}/README.md"
|
||||
|
||||
def test_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
resolve_sandbox_path("../../etc/passwd")
|
||||
|
||||
def test_absolute_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
resolve_sandbox_path(f"{E2B_WORKDIR}/../../etc/passwd")
|
||||
|
||||
def test_absolute_outside_sandbox_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
resolve_sandbox_path("/etc/passwd")
|
||||
|
||||
def test_root_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
resolve_sandbox_path("/")
|
||||
|
||||
def test_home_other_user_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
resolve_sandbox_path("/home/other/file.txt")
|
||||
|
||||
def test_deep_nested_allowed(self):
|
||||
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
|
||||
assert resolve_sandbox_path("a/b/c/d/e.txt") == f"{E2B_WORKDIR}/a/b/c/d/e.txt"
|
||||
|
||||
def test_trailing_slash_normalised(self):
|
||||
assert resolve_sandbox_path("src/") == "/home/user/src"
|
||||
assert resolve_sandbox_path("src/") == f"{E2B_WORKDIR}/src"
|
||||
|
||||
def test_double_dots_within_sandbox_ok(self):
|
||||
"""Path that resolves back within /home/user is allowed."""
|
||||
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
|
||||
"""Path that resolves back within E2B_WORKDIR is allowed."""
|
||||
assert resolve_sandbox_path("a/b/../c.txt") == f"{E2B_WORKDIR}/a/c.txt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -73,9 +79,13 @@ class TestResolveSandboxPath:
|
||||
|
||||
|
||||
class TestReadLocal:
|
||||
_CONV_UUID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
|
||||
def _make_tool_results_file(self, encoded: str, filename: str, content: str) -> str:
|
||||
"""Create a tool-results file and return its path."""
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
"""Create a tool-results file under <encoded>/<uuid>/tool-results/."""
|
||||
tool_results_dir = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
|
||||
)
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, filename)
|
||||
with open(filepath, "w") as f:
|
||||
@@ -107,7 +117,9 @@ class TestReadLocal:
|
||||
def test_read_nonexistent_tool_results(self):
|
||||
"""A tool-results path that doesn't exist returns FileNotFoundError."""
|
||||
encoded = "-tmp-copilot-e2b-test-nofile"
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
tool_results_dir = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
|
||||
)
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, "nonexistent.txt")
|
||||
token = _current_project_dir.set(encoded)
|
||||
@@ -117,7 +129,7 @@ class TestReadLocal:
|
||||
assert "not found" in result["content"][0]["text"].lower()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
os.rmdir(tool_results_dir)
|
||||
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
|
||||
|
||||
def test_read_traversal_path_blocked(self):
|
||||
"""A traversal attempt that escapes allowed directories is blocked."""
|
||||
@@ -152,3 +164,66 @@ class TestReadLocal:
|
||||
"""Without _current_project_dir set, all paths are blocked."""
|
||||
result = _read_local("/tmp/anything.txt", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _check_sandbox_symlink_escape — symlink escape detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_sandbox(stdout: str, exit_code: int = 0) -> SimpleNamespace:
|
||||
"""Build a minimal sandbox mock whose commands.run returns a fixed result."""
|
||||
run_result = SimpleNamespace(stdout=stdout, exit_code=exit_code)
|
||||
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
|
||||
return SimpleNamespace(commands=commands)
|
||||
|
||||
|
||||
class TestCheckSandboxSymlinkEscape:
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonical_path_within_workdir_returns_path(self):
|
||||
"""When readlink -f resolves to a path inside E2B_WORKDIR, returns it."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/src\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result == f"{E2B_WORKDIR}/src"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workdir_itself_returns_workdir(self):
|
||||
"""When readlink -f resolves to E2B_WORKDIR exactly, returns E2B_WORKDIR."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, E2B_WORKDIR)
|
||||
assert result == E2B_WORKDIR
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_symlink_escape_returns_none(self):
|
||||
"""When readlink -f resolves outside E2B_WORKDIR (symlink escape), returns None."""
|
||||
sandbox = _make_sandbox(stdout="/etc\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/evil")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonzero_exit_code_returns_none(self):
|
||||
"""A non-zero exit code from readlink -f returns None."""
|
||||
sandbox = _make_sandbox(stdout="", exit_code=1)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_stdout_returns_none(self):
|
||||
"""Empty stdout from readlink (e.g. path doesn't exist yet) returns None."""
|
||||
sandbox = _make_sandbox(stdout="", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prefix_collision_returns_none(self):
|
||||
"""A path prefixed with E2B_WORKDIR but not within it is rejected."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}-evil\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}-evil")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deeply_nested_path_within_workdir(self):
|
||||
"""Deep nested paths inside E2B_WORKDIR are allowed."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/a/b/c/d\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/a/b/c/d")
|
||||
assert result == f"{E2B_WORKDIR}/a/b/c/d"
|
||||
|
||||
@@ -0,0 +1,651 @@
|
||||
"""Tests for retry logic and transcript compaction helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _friendly_error_text, _is_prompt_too_long
|
||||
from .transcript import (
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_run_compression,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_assistant_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenAssistantContent:
|
||||
def test_text_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "text", "text": "World"},
|
||||
]
|
||||
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
|
||||
|
||||
def test_tool_use_blocks(self):
|
||||
blocks = [{"type": "tool_use", "name": "read_file", "input": {}}]
|
||||
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
|
||||
|
||||
def test_mixed_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Let me read that."},
|
||||
{"type": "tool_use", "name": "Read", "input": {"path": "/foo"}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "Let me read that." in result
|
||||
assert "[tool_use: Read]" in result
|
||||
|
||||
def test_raw_strings(self):
|
||||
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
|
||||
|
||||
def test_unknown_block_type_preserved_as_placeholder(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "See this image:"},
|
||||
{"type": "image", "source": {"type": "base64", "data": "..."}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "See this image:" in result
|
||||
assert "[__image__]" in result
|
||||
|
||||
def test_empty(self):
|
||||
assert _flatten_assistant_content([]) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_tool_result_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenToolResultContent:
|
||||
def test_tool_result_with_text(self):
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "123",
|
||||
"content": [{"type": "text", "text": "file contents here"}],
|
||||
}
|
||||
]
|
||||
assert _flatten_tool_result_content(blocks) == "file contents here"
|
||||
|
||||
def test_tool_result_with_string_content(self):
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "123", "content": "ok"}]
|
||||
assert _flatten_tool_result_content(blocks) == "ok"
|
||||
|
||||
def test_text_block(self):
|
||||
blocks = [{"type": "text", "text": "plain text"}]
|
||||
assert _flatten_tool_result_content(blocks) == "plain text"
|
||||
|
||||
def test_raw_string(self):
|
||||
assert _flatten_tool_result_content(["raw"]) == "raw"
|
||||
|
||||
def test_tool_result_with_none_content(self):
|
||||
"""tool_result with content=None should produce empty string."""
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": None}]
|
||||
assert _flatten_tool_result_content(blocks) == ""
|
||||
|
||||
def test_tool_result_with_empty_list_content(self):
|
||||
"""tool_result with content=[] should produce empty string."""
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": []}]
|
||||
assert _flatten_tool_result_content(blocks) == ""
|
||||
|
||||
def test_empty(self):
|
||||
assert _flatten_tool_result_content([]) == ""
|
||||
|
||||
def test_nested_dict_without_text(self):
|
||||
"""Dict blocks without text key use json.dumps fallback."""
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "x",
|
||||
"content": [{"type": "image", "source": "data:..."}],
|
||||
}
|
||||
]
|
||||
result = _flatten_tool_result_content(blocks)
|
||||
assert "image" in result # json.dumps fallback
|
||||
|
||||
def test_unknown_block_type_preserved_as_placeholder(self):
|
||||
blocks = [{"type": "image", "source": {"type": "base64", "data": "..."}}]
|
||||
result = _flatten_tool_result_content(blocks)
|
||||
assert "[__image__]" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _transcript_to_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_entry(entry_type: str, role: str, content: str | list, **kwargs) -> str:
|
||||
"""Build a JSONL line for testing."""
|
||||
uid = str(uuid4())
|
||||
msg: dict = {"role": role, "content": content}
|
||||
msg.update(kwargs)
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": None,
|
||||
"message": msg,
|
||||
}
|
||||
return json.dumps(entry, separators=(",", ":"))
|
||||
|
||||
|
||||
class TestTranscriptToMessages:
|
||||
def test_basic_roundtrip(self):
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
assert messages[0] == {"role": "user", "content": "Hello"}
|
||||
assert messages[1] == {"role": "assistant", "content": "Hi"}
|
||||
|
||||
def test_skips_strippable_types(self):
|
||||
"""Progress and metadata entries are excluded."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "progress",
|
||||
"uuid": str(uuid4()),
|
||||
"parentUuid": None,
|
||||
"message": {"role": "assistant", "content": "..."},
|
||||
}
|
||||
),
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_empty_content(self):
|
||||
assert _transcript_to_messages("") == []
|
||||
|
||||
def test_tool_result_content(self):
|
||||
"""User entries with tool_result content blocks are flattened."""
|
||||
lines = [
|
||||
_make_entry(
|
||||
"user",
|
||||
"user",
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "123",
|
||||
"content": "tool output",
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "tool output"
|
||||
|
||||
def test_malformed_json_lines_skipped(self):
|
||||
"""Malformed JSON lines in transcript are silently skipped."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
"this is not valid json",
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_empty_lines_skipped(self):
|
||||
"""Empty lines and whitespace-only lines are skipped."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
"",
|
||||
" ",
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_unicode_content_preserved(self):
|
||||
"""Unicode characters survive transcript roundtrip."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello 你好 🌍"),
|
||||
_make_entry(
|
||||
"assistant",
|
||||
"assistant",
|
||||
[{"type": "text", "text": "Bonjour 日本語 émojis 🎉"}],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert messages[0]["content"] == "Hello 你好 🌍"
|
||||
assert messages[1]["content"] == "Bonjour 日本語 émojis 🎉"
|
||||
|
||||
def test_entry_without_role_skipped(self):
|
||||
"""Entries with missing role in message are skipped."""
|
||||
entry_no_role = json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": str(uuid4()),
|
||||
"parentUuid": None,
|
||||
"message": {"content": "no role here"},
|
||||
}
|
||||
)
|
||||
lines = [
|
||||
entry_no_role,
|
||||
_make_entry("user", "user", "Hello"),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "Hello"
|
||||
|
||||
def test_tool_use_and_result_pairs(self):
|
||||
"""Tool use + tool result pairs are properly flattened."""
|
||||
lines = [
|
||||
_make_entry(
|
||||
"assistant",
|
||||
"assistant",
|
||||
[
|
||||
{"type": "text", "text": "Let me check."},
|
||||
{"type": "tool_use", "name": "read_file", "input": {"path": "/x"}},
|
||||
],
|
||||
),
|
||||
_make_entry(
|
||||
"user",
|
||||
"user",
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "abc",
|
||||
"content": [{"type": "text", "text": "file contents"}],
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
assert "Let me check." in messages[0]["content"]
|
||||
assert "[tool_use: read_file]" in messages[0]["content"]
|
||||
assert messages[1]["content"] == "file contents"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _messages_to_transcript
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessagesToTranscript:
|
||||
def test_produces_valid_jsonl(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
lines = result.strip().split("\n")
|
||||
assert len(lines) == 2
|
||||
for line in lines:
|
||||
parsed = json.loads(line)
|
||||
assert "type" in parsed
|
||||
assert "uuid" in parsed
|
||||
assert "message" in parsed
|
||||
|
||||
def test_assistant_has_proper_structure(self):
|
||||
messages = [{"role": "assistant", "content": "Hello"}]
|
||||
result = _messages_to_transcript(messages)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["type"] == "assistant"
|
||||
msg = entry["message"]
|
||||
assert msg["role"] == "assistant"
|
||||
assert msg["type"] == "message"
|
||||
assert msg["stop_reason"] == "end_turn"
|
||||
assert isinstance(msg["content"], list)
|
||||
assert msg["content"][0]["type"] == "text"
|
||||
|
||||
def test_user_has_plain_content(self):
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = _messages_to_transcript(messages)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["type"] == "user"
|
||||
assert entry["message"]["content"] == "Hi"
|
||||
|
||||
def test_parent_uuid_chain(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "A"},
|
||||
{"role": "assistant", "content": "B"},
|
||||
{"role": "user", "content": "C"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
lines = result.strip().split("\n")
|
||||
entries = [json.loads(line) for line in lines]
|
||||
assert entries[0]["parentUuid"] == ""
|
||||
assert entries[1]["parentUuid"] == entries[0]["uuid"]
|
||||
assert entries[2]["parentUuid"] == entries[1]["uuid"]
|
||||
|
||||
def test_empty_messages(self):
|
||||
assert _messages_to_transcript([]) == ""
|
||||
|
||||
def test_output_is_valid_transcript(self):
|
||||
"""Output should pass validate_transcript if it has assistant entries."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
assert validate_transcript(result)
|
||||
|
||||
def test_roundtrip_to_messages(self):
|
||||
"""Messages → transcript → messages preserves structure."""
|
||||
original = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
transcript = _messages_to_transcript(original)
|
||||
restored = _transcript_to_messages(transcript)
|
||||
assert len(restored) == len(original)
|
||||
for orig, rest in zip(original, restored):
|
||||
assert orig["role"] == rest["role"]
|
||||
assert orig["content"] == rest["content"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compact_transcript
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactTranscript:
|
||||
@pytest.mark.asyncio
|
||||
async def test_too_few_messages_returns_none(self, mock_chat_config):
|
||||
"""compact_transcript returns None when transcript has < 2 messages."""
|
||||
transcript = _build_transcript([("user", "Hello")])
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_not_compacted(self, mock_chat_config):
|
||||
"""When compress_context says no compaction needed, returns None.
|
||||
The compressor couldn't reduce it, so retrying with the same
|
||||
content would fail identically."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi there"),
|
||||
]
|
||||
)
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": False,
|
||||
"messages": [],
|
||||
"original_token_count": 100,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 0,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_compacted_transcript(self, mock_chat_config):
|
||||
"""When compaction succeeds, returns a valid compacted transcript."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi"),
|
||||
("user", "More"),
|
||||
("assistant", "Details"),
|
||||
]
|
||||
)
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "[summary]"},
|
||||
{"role": "assistant", "content": "Summarized response"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 500,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 2,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[1]["content"] == "Summarized response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_compression_failure(self, mock_chat_config):
|
||||
"""When _run_compression raises, returns None."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi"),
|
||||
]
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("LLM unavailable"),
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_prompt_too_long
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsPromptTooLong:
|
||||
"""Unit tests for _is_prompt_too_long pattern matching."""
|
||||
|
||||
def test_prompt_is_too_long(self):
|
||||
err = RuntimeError("prompt is too long for model context")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_request_too_large(self):
|
||||
err = Exception("request too large: 250000 tokens")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_maximum_context_length(self):
|
||||
err = ValueError("maximum context length exceeded")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_context_length_exceeded(self):
|
||||
err = Exception("context_length_exceeded")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_input_tokens_exceed(self):
|
||||
err = Exception("input tokens exceed the max_tokens limit")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_input_is_too_long(self):
|
||||
err = Exception("input is too long for the model")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_content_length_exceeds(self):
|
||||
err = Exception("content length exceeds maximum")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_unrelated_error_returns_false(self):
|
||||
err = RuntimeError("network timeout")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
def test_auth_error_returns_false(self):
|
||||
err = Exception("authentication failed: invalid API key")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
def test_chained_exception_detected(self):
|
||||
"""Prompt-too-long error wrapped in another exception is detected."""
|
||||
inner = RuntimeError("prompt is too long")
|
||||
outer = Exception("SDK error")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_case_insensitive(self):
|
||||
err = Exception("PROMPT IS TOO LONG")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_old_max_tokens_exceeded_not_matched(self):
|
||||
"""The old broad 'max_tokens_exceeded' pattern was removed.
|
||||
Only 'input tokens exceed' should match now."""
|
||||
err = Exception("max_tokens_exceeded")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_compression timeout fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunCompressionTimeout:
|
||||
"""Verify _run_compression falls back to truncation when LLM times out."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_falls_back_to_truncation(self):
|
||||
"""When compress_context with LLM client times out,
|
||||
_run_compression falls back to truncation (client=None)."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
truncation_result = CompressResult(
|
||||
messages=messages,
|
||||
was_compacted=False,
|
||||
original_token_count=50,
|
||||
token_count=50,
|
||||
messages_summarized=0,
|
||||
messages_dropped=0,
|
||||
)
|
||||
|
||||
call_args: list[dict] = []
|
||||
|
||||
async def _mock_compress(**kwargs):
|
||||
call_args.append(kwargs)
|
||||
if kwargs.get("client") is not None:
|
||||
# Simulate timeout by raising asyncio.TimeoutError
|
||||
raise asyncio.TimeoutError("LLM compaction timed out")
|
||||
return truncation_result
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value="fake-client",
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
side_effect=_mock_compress,
|
||||
),
|
||||
):
|
||||
result = await _run_compression(messages, "test-model", "[test]")
|
||||
|
||||
assert result == truncation_result
|
||||
# Should have been called twice: once with client, once without
|
||||
assert len(call_args) == 2
|
||||
assert call_args[0]["client"] is not None # LLM attempt
|
||||
assert call_args[1]["client"] is None # truncation fallback
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_client_uses_truncation_directly(self):
|
||||
"""When no OpenAI client is configured, goes straight to truncation."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
truncation_result = CompressResult(
|
||||
messages=messages,
|
||||
was_compacted=False,
|
||||
original_token_count=50,
|
||||
token_count=50,
|
||||
messages_summarized=0,
|
||||
messages_dropped=0,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=truncation_result,
|
||||
) as mock_compress,
|
||||
):
|
||||
result = await _run_compression(messages, "test-model", "[test]")
|
||||
|
||||
assert result == truncation_result
|
||||
mock_compress.assert_called_once()
|
||||
# When no client, compress_context is called with client=None
|
||||
assert mock_compress.call_args.kwargs.get("client") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _friendly_error_text
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFriendlyErrorText:
|
||||
"""Verify user-friendly error message mapping."""
|
||||
|
||||
def test_authentication_error(self):
|
||||
result = _friendly_error_text("authentication failed: invalid API key")
|
||||
assert "Authentication" in result
|
||||
assert "API key" in result
|
||||
|
||||
def test_rate_limit_error(self):
|
||||
result = _friendly_error_text("rate limit exceeded")
|
||||
assert "Rate limit" in result
|
||||
|
||||
def test_overloaded_error(self):
|
||||
result = _friendly_error_text("API is overloaded")
|
||||
assert "overloaded" in result
|
||||
|
||||
def test_timeout_error(self):
|
||||
result = _friendly_error_text("Request timeout after 30s")
|
||||
assert "timed out" in result
|
||||
|
||||
def test_connection_error(self):
|
||||
result = _friendly_error_text("Connection refused")
|
||||
assert "Connection" in result or "connection" in result
|
||||
|
||||
def test_unknown_error_passthrough(self):
|
||||
result = _friendly_error_text("some unknown error XYZ")
|
||||
assert "SDK stream error:" in result
|
||||
assert "XYZ" in result
|
||||
|
||||
def test_unauthorized_error(self):
|
||||
result = _friendly_error_text("401 Unauthorized")
|
||||
assert "Authentication" in result
|
||||
@@ -20,6 +20,7 @@ from claude_agent_sdk import (
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from backend.copilot.constants import FRIENDLY_TRANSIENT_MSG, is_transient_api_error
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
@@ -214,10 +215,12 @@ class SDKResponseAdapter:
|
||||
if sdk_message.subtype == "success":
|
||||
responses.append(StreamFinish())
|
||||
elif sdk_message.subtype in ("error", "error_during_execution"):
|
||||
error_msg = sdk_message.result or "Unknown error"
|
||||
responses.append(
|
||||
StreamError(errorText=str(error_msg), code="sdk_error")
|
||||
)
|
||||
raw_error = str(sdk_message.result or "Unknown error")
|
||||
if is_transient_api_error(raw_error):
|
||||
error_text, code = FRIENDLY_TRANSIENT_MSG, "transient_api_error"
|
||||
else:
|
||||
error_text, code = raw_error, "sdk_error"
|
||||
responses.append(StreamError(errorText=error_text, code=code))
|
||||
responses.append(StreamFinish())
|
||||
else:
|
||||
logger.warning(
|
||||
|
||||
1410
autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py
Normal file
1410
autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -42,7 +42,7 @@ def _validate_workspace_path(
|
||||
Delegates to :func:`is_allowed_local_path` which permits:
|
||||
- The SDK working directory (``/tmp/copilot-<session>/``)
|
||||
- The current session's tool-results directory
|
||||
(``~/.claude/projects/<encoded-cwd>/tool-results/``)
|
||||
(``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/``)
|
||||
"""
|
||||
path = tool_input.get("file_path") or tool_input.get("path") or ""
|
||||
if not path:
|
||||
@@ -302,7 +302,11 @@ def create_security_hooks(
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
trigger = input_data.get("trigger", "auto")
|
||||
# Sanitize untrusted input before logging to prevent log injection
|
||||
# Sanitize untrusted input: strip control chars for logging AND
|
||||
# for the value passed downstream. read_compacted_entries()
|
||||
# validates against _projects_base() as defence-in-depth, but
|
||||
# sanitizing here prevents log injection and rejects obviously
|
||||
# malformed paths early.
|
||||
transcript_path = (
|
||||
str(input_data.get("transcript_path", ""))
|
||||
.replace("\n", "")
|
||||
|
||||
@@ -122,7 +122,7 @@ def test_read_no_cwd_denies_absolute():
|
||||
|
||||
def test_read_tool_results_allowed():
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/a1b2c3d4-e5f6-7890-abcd-ef1234567890/tool-results/12345.txt"
|
||||
# is_allowed_local_path requires the session's encoded cwd to be set
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
try:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,337 @@
|
||||
"""Unit tests for extracted service helpers.
|
||||
|
||||
Covers ``_is_prompt_too_long``, ``_reduce_context``, ``_iter_sdk_messages``,
|
||||
``ReducedContext``, and the ``is_parallel_continuation`` logic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from claude_agent_sdk import AssistantMessage, TextBlock, ToolUseBlock
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import (
|
||||
ReducedContext,
|
||||
_is_prompt_too_long,
|
||||
_is_tool_only_message,
|
||||
_iter_sdk_messages,
|
||||
_reduce_context,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_prompt_too_long
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsPromptTooLong:
|
||||
def test_direct_match(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("prompt is too long")) is True
|
||||
|
||||
def test_case_insensitive(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("PROMPT IS TOO LONG")) is True
|
||||
|
||||
def test_no_match(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("network timeout")) is False
|
||||
|
||||
def test_request_too_large(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("request too large for model")) is True
|
||||
|
||||
def test_context_length_exceeded(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("context_length_exceeded")) is True
|
||||
|
||||
def test_max_tokens_exceeded_not_matched(self) -> None:
|
||||
"""'max_tokens_exceeded' is intentionally excluded (too broad)."""
|
||||
assert _is_prompt_too_long(Exception("max_tokens_exceeded")) is False
|
||||
|
||||
def test_max_tokens_config_error_no_match(self) -> None:
|
||||
"""'max_tokens must be at least 1' should NOT match."""
|
||||
assert _is_prompt_too_long(Exception("max_tokens must be at least 1")) is False
|
||||
|
||||
def test_chained_cause(self) -> None:
|
||||
inner = Exception("prompt is too long")
|
||||
outer = RuntimeError("SDK error")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_chained_context(self) -> None:
|
||||
inner = Exception("request too large")
|
||||
outer = RuntimeError("wrapped")
|
||||
outer.__context__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_deep_chain(self) -> None:
|
||||
bottom = Exception("maximum context length")
|
||||
middle = RuntimeError("middle")
|
||||
middle.__cause__ = bottom
|
||||
top = ValueError("top")
|
||||
top.__cause__ = middle
|
||||
assert _is_prompt_too_long(top) is True
|
||||
|
||||
def test_chain_no_match(self) -> None:
|
||||
inner = Exception("rate limit exceeded")
|
||||
outer = RuntimeError("wrapped")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is False
|
||||
|
||||
def test_cycle_detection(self) -> None:
|
||||
"""Exception chain with a cycle should not infinite-loop."""
|
||||
a = Exception("error a")
|
||||
b = Exception("error b")
|
||||
a.__cause__ = b
|
||||
b.__cause__ = a # cycle
|
||||
assert _is_prompt_too_long(a) is False
|
||||
|
||||
def test_all_patterns(self) -> None:
|
||||
patterns = [
|
||||
"prompt is too long",
|
||||
"request too large",
|
||||
"maximum context length",
|
||||
"context_length_exceeded",
|
||||
"input tokens exceed",
|
||||
"input is too long",
|
||||
"content length exceeds",
|
||||
]
|
||||
for pattern in patterns:
|
||||
assert _is_prompt_too_long(Exception(pattern)) is True, pattern
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _reduce_context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReduceContext:
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_retry_compaction_success(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=compacted,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.validate_transcript",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.write_transcript_to_tempfile",
|
||||
return_value="/tmp/resume.jsonl",
|
||||
),
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert isinstance(ctx, ReducedContext)
|
||||
assert ctx.use_resume is True
|
||||
assert ctx.resume_file == "/tmp/resume.jsonl"
|
||||
assert ctx.transcript_lost is False
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_fails_drops_transcript(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.resume_file is None
|
||||
assert ctx.transcript_lost is True
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_tried_compaction_skips(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
ctx = await _reduce_context(transcript, True, "sess-123", "/tmp/cwd", "[test]")
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.transcript_lost is True
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_transcript_drops(self) -> None:
|
||||
ctx = await _reduce_context("", False, "sess-123", "/tmp/cwd", "[test]")
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_returns_same_content_drops(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=transcript, # same content
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_tempfile_fails_drops(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=compacted,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.validate_transcript",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.write_transcript_to_tempfile",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _iter_sdk_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIterSdkMessages:
|
||||
@pytest.mark.asyncio
|
||||
async def test_yields_messages(self) -> None:
|
||||
messages = ["msg1", "msg2", "msg3"]
|
||||
client = AsyncMock()
|
||||
|
||||
async def _fake_receive() -> AsyncGenerator[str]:
|
||||
for m in messages:
|
||||
yield m
|
||||
|
||||
client.receive_response = _fake_receive
|
||||
result = [msg async for msg in _iter_sdk_messages(client)]
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heartbeat_on_timeout(self) -> None:
|
||||
"""Yields None when asyncio.wait times out."""
|
||||
client = AsyncMock()
|
||||
received: list = []
|
||||
|
||||
async def _slow_receive() -> AsyncGenerator[str]:
|
||||
await asyncio.sleep(100) # never completes
|
||||
yield "never" # pragma: no cover — unreachable, yield makes this an async generator
|
||||
|
||||
client.receive_response = _slow_receive
|
||||
|
||||
with patch("backend.copilot.sdk.service._HEARTBEAT_INTERVAL", 0.01):
|
||||
count = 0
|
||||
async for msg in _iter_sdk_messages(client):
|
||||
received.append(msg)
|
||||
count += 1
|
||||
if count >= 3:
|
||||
break
|
||||
|
||||
assert all(m is None for m in received)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_propagates(self) -> None:
|
||||
client = AsyncMock()
|
||||
|
||||
async def _error_receive() -> AsyncGenerator[str]:
|
||||
raise RuntimeError("SDK crash")
|
||||
yield # pragma: no cover — unreachable, yield makes this an async generator
|
||||
|
||||
client.receive_response = _error_receive
|
||||
|
||||
with pytest.raises(RuntimeError, match="SDK crash"):
|
||||
async for _ in _iter_sdk_messages(client):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_cleanup_on_break(self) -> None:
|
||||
"""Pending task is cancelled when generator is closed."""
|
||||
client = AsyncMock()
|
||||
|
||||
async def _slow_receive() -> AsyncGenerator[str]:
|
||||
yield "first"
|
||||
await asyncio.sleep(100)
|
||||
yield "second"
|
||||
|
||||
client.receive_response = _slow_receive
|
||||
|
||||
gen = _iter_sdk_messages(client)
|
||||
first = await gen.__anext__()
|
||||
assert first == "first"
|
||||
await gen.aclose() # should cancel pending task cleanly
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_parallel_continuation logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsParallelContinuation:
|
||||
"""Unit tests for the is_parallel_continuation expression in the streaming loop.
|
||||
|
||||
Verifies the vacuous-truth guard (empty content must return False) and the
|
||||
boundary cases for mixed TextBlock+ToolUseBlock messages.
|
||||
"""
|
||||
|
||||
def _make_tool_block(self) -> MagicMock:
|
||||
block = MagicMock(spec=ToolUseBlock)
|
||||
return block
|
||||
|
||||
def test_all_tool_use_blocks_is_parallel(self):
|
||||
"""AssistantMessage with only ToolUseBlocks is a parallel continuation."""
|
||||
msg = MagicMock(spec=AssistantMessage)
|
||||
msg.content = [self._make_tool_block(), self._make_tool_block()]
|
||||
assert _is_tool_only_message(msg) is True
|
||||
|
||||
def test_empty_content_is_not_parallel(self):
|
||||
"""AssistantMessage with empty content must NOT be treated as parallel.
|
||||
|
||||
Without the bool(sdk_msg.content) guard, all() on an empty iterable
|
||||
returns True via vacuous truth — this test ensures the guard is present.
|
||||
"""
|
||||
msg = MagicMock(spec=AssistantMessage)
|
||||
msg.content = []
|
||||
assert _is_tool_only_message(msg) is False
|
||||
|
||||
def test_mixed_text_and_tool_blocks_not_parallel(self):
|
||||
"""AssistantMessage with text + tool blocks is NOT a parallel continuation."""
|
||||
msg = MagicMock(spec=AssistantMessage)
|
||||
text_block = MagicMock(spec=TextBlock)
|
||||
msg.content = [text_block, self._make_tool_block()]
|
||||
assert _is_tool_only_message(msg) is False
|
||||
|
||||
def test_non_assistant_message_not_parallel(self):
|
||||
"""Non-AssistantMessage types are never parallel continuations."""
|
||||
assert _is_tool_only_message("not a message") is False
|
||||
assert _is_tool_only_message(None) is False
|
||||
assert _is_tool_only_message(42) is False
|
||||
|
||||
def test_single_tool_block_is_parallel(self):
|
||||
"""Single ToolUseBlock AssistantMessage is a parallel continuation."""
|
||||
msg = MagicMock(spec=AssistantMessage)
|
||||
msg.content = [self._make_tool_block()]
|
||||
assert _is_tool_only_message(msg) is True
|
||||
@@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .service import _prepare_file_attachments
|
||||
from .service import _prepare_file_attachments, _resolve_sdk_model
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -288,3 +288,214 @@ class TestPromptSupplement:
|
||||
# Count how many times this tool appears as a bullet point
|
||||
count = docs.count(f"- **`{tool_name}`**")
|
||||
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cleanup_sdk_tool_results — orchestration + rate-limiting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanupSdkToolResults:
|
||||
"""Tests for _cleanup_sdk_tool_results orchestration and sweep rate-limiting."""
|
||||
|
||||
# All valid cwds must start with /tmp/copilot- (the _SDK_CWD_PREFIX).
|
||||
_CWD_PREFIX = "/tmp/copilot-"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_removes_cwd_directory(self):
|
||||
"""Cleanup removes the session working directory."""
|
||||
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
cwd = "/tmp/copilot-test-cleanup-remove"
|
||||
os.makedirs(cwd, exist_ok=True)
|
||||
|
||||
with patch("backend.copilot.sdk.service.cleanup_stale_project_dirs"):
|
||||
import backend.copilot.sdk.service as svc_mod
|
||||
|
||||
svc_mod._last_sweep_time = 0.0
|
||||
await _cleanup_sdk_tool_results(cwd)
|
||||
|
||||
assert not os.path.exists(cwd)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sweep_runs_when_interval_elapsed(self):
|
||||
"""cleanup_stale_project_dirs is called when 5-minute interval has elapsed."""
|
||||
|
||||
import backend.copilot.sdk.service as svc_mod
|
||||
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
cwd = "/tmp/copilot-test-sweep-elapsed"
|
||||
os.makedirs(cwd, exist_ok=True)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
|
||||
) as mock_sweep:
|
||||
# Set last sweep to a time far in the past
|
||||
svc_mod._last_sweep_time = 0.0
|
||||
await _cleanup_sdk_tool_results(cwd)
|
||||
|
||||
mock_sweep.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sweep_skipped_within_interval(self):
|
||||
"""cleanup_stale_project_dirs is NOT called when within 5-minute interval."""
|
||||
import time
|
||||
|
||||
import backend.copilot.sdk.service as svc_mod
|
||||
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
cwd = "/tmp/copilot-test-sweep-ratelimit"
|
||||
os.makedirs(cwd, exist_ok=True)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
|
||||
) as mock_sweep:
|
||||
# Set last sweep to now — interval not elapsed
|
||||
svc_mod._last_sweep_time = time.time()
|
||||
await _cleanup_sdk_tool_results(cwd)
|
||||
|
||||
mock_sweep.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_path_outside_prefix(self, tmp_path):
|
||||
"""Cleanup rejects a cwd that does not start with the expected prefix."""
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
evil_cwd = str(tmp_path / "evil-path")
|
||||
os.makedirs(evil_cwd, exist_ok=True)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
|
||||
) as mock_sweep:
|
||||
await _cleanup_sdk_tool_results(evil_cwd)
|
||||
|
||||
# Directory should NOT have been removed (rejected early)
|
||||
assert os.path.exists(evil_cwd)
|
||||
mock_sweep.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Env vars that ChatConfig validators read — must be cleared so explicit
|
||||
# constructor values are used.
|
||||
# ---------------------------------------------------------------------------
|
||||
_CONFIG_ENV_VARS = (
|
||||
"CHAT_USE_OPENROUTER",
|
||||
"CHAT_API_KEY",
|
||||
"OPEN_ROUTER_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"CHAT_BASE_URL",
|
||||
"OPENROUTER_BASE_URL",
|
||||
"OPENAI_BASE_URL",
|
||||
"CHAT_USE_CLAUDE_CODE_SUBSCRIPTION",
|
||||
"CHAT_USE_CLAUDE_AGENT_SDK",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _clean_config_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for var in _CONFIG_ENV_VARS:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
class TestResolveSdkModel:
|
||||
"""Tests for _resolve_sdk_model — model ID resolution for the SDK CLI."""
|
||||
|
||||
def test_openrouter_active_keeps_dots(self, monkeypatch, _clean_config_env):
|
||||
"""When OpenRouter is fully active, model keeps dot-separated version."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-opus-4.6"
|
||||
|
||||
def test_openrouter_disabled_normalizes_to_hyphens(
|
||||
self, monkeypatch, _clean_config_env
|
||||
):
|
||||
"""When OpenRouter is disabled, dots are replaced with hyphens."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-opus-4-6"
|
||||
|
||||
def test_openrouter_enabled_but_missing_key_normalizes(
|
||||
self, monkeypatch, _clean_config_env
|
||||
):
|
||||
"""When OpenRouter is enabled but api_key is missing, falls back to
|
||||
direct Anthropic and normalizes dots to hyphens."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=True,
|
||||
api_key=None,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-opus-4-6"
|
||||
|
||||
def test_explicit_claude_agent_model_takes_precedence(
|
||||
self, monkeypatch, _clean_config_env
|
||||
):
|
||||
"""When claude_agent_model is explicitly set, it is returned as-is."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model="claude-sonnet-4-5-20250514",
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-sonnet-4-5-20250514"
|
||||
|
||||
def test_subscription_mode_returns_none(self, monkeypatch, _clean_config_env):
|
||||
"""When using Claude Code subscription, returns None (CLI picks model)."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=True,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() is None
|
||||
|
||||
def test_model_without_provider_prefix(self, monkeypatch, _clean_config_env):
|
||||
"""When model has no provider prefix, it still normalizes correctly."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-opus-4-6"
|
||||
|
||||
144
autogpt_platform/backend/backend/copilot/sdk/subscription.py
Normal file
144
autogpt_platform/backend/backend/copilot/sdk/subscription.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Claude Code subscription auth helpers.
|
||||
|
||||
Handles locating the SDK-bundled CLI binary, provisioning credentials from
|
||||
environment variables, and validating that subscription auth is functional.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_bundled_cli() -> str:
|
||||
"""Locate the Claude CLI binary bundled inside ``claude_agent_sdk``.
|
||||
|
||||
Falls back to ``shutil.which("claude")`` if the SDK bundle is absent.
|
||||
"""
|
||||
try:
|
||||
from claude_agent_sdk._internal.transport.subprocess_cli import (
|
||||
SubprocessCLITransport,
|
||||
)
|
||||
|
||||
path = SubprocessCLITransport._find_bundled_cli(None) # type: ignore[arg-type]
|
||||
if path:
|
||||
return str(path)
|
||||
except Exception:
|
||||
pass
|
||||
system_path = shutil.which("claude")
|
||||
if system_path:
|
||||
return system_path
|
||||
raise RuntimeError(
|
||||
"Claude CLI not found — neither the SDK-bundled binary nor a "
|
||||
"system-installed `claude` could be located."
|
||||
)
|
||||
|
||||
|
||||
def provision_credentials_file() -> None:
|
||||
"""Write ``~/.claude/.credentials.json`` from env when running headless.
|
||||
|
||||
If ``CLAUDE_CODE_OAUTH_TOKEN`` is set (an OAuth *access* token obtained
|
||||
from ``claude auth status`` or extracted from the macOS keychain), this
|
||||
helper writes a minimal credentials file so the bundled CLI can
|
||||
authenticate without an interactive ``claude login``.
|
||||
|
||||
A ``CLAUDE_CODE_REFRESH_TOKEN`` env var is optional but recommended —
|
||||
it lets the CLI silently refresh an expired access token.
|
||||
"""
|
||||
access_token = os.environ.get("CLAUDE_CODE_OAUTH_TOKEN", "").strip()
|
||||
if not access_token:
|
||||
return
|
||||
|
||||
creds_dir = os.path.expanduser("~/.claude")
|
||||
creds_path = os.path.join(creds_dir, ".credentials.json")
|
||||
|
||||
# Don't overwrite an existing credentials file (e.g. from a volume mount).
|
||||
if os.path.exists(creds_path):
|
||||
logger.debug("Credentials file already exists at %s — skipping", creds_path)
|
||||
return
|
||||
|
||||
os.makedirs(creds_dir, exist_ok=True)
|
||||
|
||||
creds = {
|
||||
"claudeAiOauth": {
|
||||
"accessToken": access_token,
|
||||
"refreshToken": os.environ.get("CLAUDE_CODE_REFRESH_TOKEN", "").strip(),
|
||||
"expiresAt": 0,
|
||||
"scopes": [
|
||||
"user:inference",
|
||||
"user:profile",
|
||||
"user:sessions:claude_code",
|
||||
],
|
||||
}
|
||||
}
|
||||
with open(creds_path, "w") as f:
|
||||
json.dump(creds, f)
|
||||
logger.info("Provisioned Claude credentials file at %s", creds_path)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def validate_subscription() -> None:
|
||||
"""Validate the bundled Claude CLI is reachable and authenticated.
|
||||
|
||||
Cached so the blocking subprocess check runs at most once per process
|
||||
lifetime. On first call, also provisions ``~/.claude/.credentials.json``
|
||||
from the ``CLAUDE_CODE_OAUTH_TOKEN`` env var when available.
|
||||
"""
|
||||
provision_credentials_file()
|
||||
|
||||
cli = find_bundled_cli()
|
||||
result = subprocess.run(
|
||||
[cli, "--version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Claude CLI check failed (exit {result.returncode}): "
|
||||
f"{result.stderr.strip()}"
|
||||
)
|
||||
logger.info(
|
||||
"Claude Code subscription mode: CLI version %s",
|
||||
result.stdout.strip(),
|
||||
)
|
||||
|
||||
# Verify the CLI is actually authenticated.
|
||||
auth_result = subprocess.run(
|
||||
[cli, "auth", "status"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
env={
|
||||
**os.environ,
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
},
|
||||
)
|
||||
if auth_result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
"Claude CLI is not authenticated. Either:\n"
|
||||
" • Set CLAUDE_CODE_OAUTH_TOKEN env var (from `claude auth status` "
|
||||
"or macOS keychain), or\n"
|
||||
" • Mount ~/.claude/.credentials.json into the container, or\n"
|
||||
" • Run `claude login` inside the container."
|
||||
)
|
||||
try:
|
||||
status = json.loads(auth_result.stdout)
|
||||
if not status.get("loggedIn"):
|
||||
raise RuntimeError(
|
||||
"Claude CLI reports loggedIn=false. Set CLAUDE_CODE_OAUTH_TOKEN "
|
||||
"or run `claude login`."
|
||||
)
|
||||
logger.info(
|
||||
"Claude subscription auth: method=%s, email=%s",
|
||||
status.get("authMethod"),
|
||||
status.get("email"),
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Could not parse `claude auth status` output")
|
||||
@@ -0,0 +1,96 @@
|
||||
"""Tests for the tool call circuit breaker in tool_adapter.py."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.sdk.tool_adapter import (
|
||||
_MAX_CONSECUTIVE_TOOL_FAILURES,
|
||||
_check_circuit_breaker,
|
||||
_clear_tool_failures,
|
||||
_consecutive_tool_failures,
|
||||
_record_tool_failure,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_tracker():
|
||||
"""Reset the circuit breaker tracker for each test."""
|
||||
token = _consecutive_tool_failures.set({})
|
||||
yield
|
||||
_consecutive_tool_failures.reset(token)
|
||||
|
||||
|
||||
class TestCircuitBreaker:
|
||||
def test_no_trip_below_threshold(self):
|
||||
"""Circuit breaker should not trip before reaching the limit."""
|
||||
args = {"file_path": "/tmp/test.txt"}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES - 1):
|
||||
assert _check_circuit_breaker("write_file", args) is None
|
||||
_record_tool_failure("write_file", args)
|
||||
# Still under the limit
|
||||
assert _check_circuit_breaker("write_file", args) is None
|
||||
|
||||
def test_trips_at_threshold(self):
|
||||
"""Circuit breaker should trip after reaching the failure limit."""
|
||||
args = {"file_path": "/tmp/test.txt"}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
|
||||
assert _check_circuit_breaker("write_file", args) is None
|
||||
_record_tool_failure("write_file", args)
|
||||
# Now it should trip
|
||||
result = _check_circuit_breaker("write_file", args)
|
||||
assert result is not None
|
||||
assert "STOP" in result
|
||||
assert "write_file" in result
|
||||
|
||||
def test_different_args_tracked_separately(self):
|
||||
"""Different args should have separate failure counters."""
|
||||
args_a = {"file_path": "/tmp/a.txt"}
|
||||
args_b = {"file_path": "/tmp/b.txt"}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
|
||||
_record_tool_failure("write_file", args_a)
|
||||
# args_a should trip
|
||||
assert _check_circuit_breaker("write_file", args_a) is not None
|
||||
# args_b should NOT trip
|
||||
assert _check_circuit_breaker("write_file", args_b) is None
|
||||
|
||||
def test_different_tools_tracked_separately(self):
|
||||
"""Different tools should have separate failure counters."""
|
||||
args = {"file_path": "/tmp/test.txt"}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
|
||||
_record_tool_failure("tool_a", args)
|
||||
# tool_a should trip
|
||||
assert _check_circuit_breaker("tool_a", args) is not None
|
||||
# tool_b with same args should NOT trip
|
||||
assert _check_circuit_breaker("tool_b", args) is None
|
||||
|
||||
def test_empty_args_tracked(self):
|
||||
"""Empty args ({}) — the exact failure pattern from the bug — should be tracked."""
|
||||
args = {}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
|
||||
_record_tool_failure("write_file", args)
|
||||
assert _check_circuit_breaker("write_file", args) is not None
|
||||
|
||||
def test_clear_resets_counter(self):
|
||||
"""Clearing failures should reset the counter."""
|
||||
args = {}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
|
||||
_record_tool_failure("write_file", args)
|
||||
_clear_tool_failures("write_file")
|
||||
assert _check_circuit_breaker("write_file", args) is None
|
||||
|
||||
def test_success_clears_failures(self):
|
||||
"""A successful call should reset the failure counter."""
|
||||
args = {}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES - 1):
|
||||
_record_tool_failure("write_file", args)
|
||||
# Success clears failures
|
||||
_clear_tool_failures("write_file")
|
||||
# Should be able to fail again without tripping
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES - 1):
|
||||
_record_tool_failure("write_file", args)
|
||||
assert _check_circuit_breaker("write_file", args) is None
|
||||
|
||||
def test_no_tracker_returns_none(self):
|
||||
"""If tracker is not initialized, circuit breaker should not trip."""
|
||||
_consecutive_tool_failures.set(None) # type: ignore[arg-type]
|
||||
_record_tool_failure("write_file", {}) # should not raise
|
||||
assert _check_circuit_breaker("write_file", {}) is None
|
||||
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
from backend.copilot.context import (
|
||||
_current_permissions,
|
||||
_current_project_dir,
|
||||
_current_sandbox,
|
||||
_current_sdk_cwd,
|
||||
@@ -41,6 +42,8 @@ from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
|
||||
@@ -50,6 +53,14 @@ _MCP_MAX_CHARS = 500_000
|
||||
MCP_SERVER_NAME = "copilot"
|
||||
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
|
||||
|
||||
# Map from tool_name -> Queue of pre-launched (task, args) pairs.
|
||||
# Initialised per-session in set_execution_context() so concurrent sessions
|
||||
# never share the same dict.
|
||||
_TaskQueueItem = tuple[asyncio.Task[dict[str, Any]], dict[str, Any]]
|
||||
_tool_task_queues: ContextVar[dict[str, asyncio.Queue[_TaskQueueItem]] | None] = (
|
||||
ContextVar("_tool_task_queues", default=None)
|
||||
)
|
||||
|
||||
# Stash for MCP tool outputs before the SDK potentially truncates them.
|
||||
# Keyed by tool_name → full output string. Consumed (popped) by the
|
||||
# response adapter when it builds StreamToolOutputAvailable.
|
||||
@@ -66,12 +77,23 @@ _stash_event: ContextVar[asyncio.Event | None] = ContextVar(
|
||||
"_stash_event", default=None
|
||||
)
|
||||
|
||||
# Circuit breaker: tracks consecutive tool failures to detect infinite retry loops.
|
||||
# When a tool is called repeatedly with empty/identical args and keeps failing,
|
||||
# this counter is incremented. After _MAX_CONSECUTIVE_TOOL_FAILURES identical
|
||||
# failures the tool handler returns a hard-stop message instead of the raw error.
|
||||
_MAX_CONSECUTIVE_TOOL_FAILURES = 3
|
||||
_consecutive_tool_failures: ContextVar[dict[str, int]] = ContextVar(
|
||||
"_consecutive_tool_failures",
|
||||
default=None, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
sandbox: "AsyncSandbox | None" = None,
|
||||
sdk_cwd: str | None = None,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> None:
|
||||
"""Set the execution context for tool calls.
|
||||
|
||||
@@ -83,14 +105,83 @@ def set_execution_context(
|
||||
session: Current chat session.
|
||||
sandbox: Optional E2B sandbox; when set, bash_exec routes commands there.
|
||||
sdk_cwd: SDK working directory; used to scope tool-results reads.
|
||||
permissions: Optional capability filter restricting tools/blocks.
|
||||
"""
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_current_sandbox.set(sandbox)
|
||||
_current_sdk_cwd.set(sdk_cwd or "")
|
||||
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
|
||||
_current_permissions.set(permissions)
|
||||
_pending_tool_outputs.set({})
|
||||
_stash_event.set(asyncio.Event())
|
||||
_tool_task_queues.set({})
|
||||
_consecutive_tool_failures.set({})
|
||||
|
||||
|
||||
def reset_stash_event() -> None:
|
||||
"""Clear any stale stash signal left over from a previous stream attempt.
|
||||
|
||||
``_stash_event`` is set once per session in ``set_execution_context`` and
|
||||
reused across retry attempts. A PostToolUse hook from a failed attempt may
|
||||
leave the event set; calling this at the start of each retry prevents
|
||||
``wait_for_stash`` from returning prematurely on a stale signal.
|
||||
"""
|
||||
event = _stash_event.get(None)
|
||||
if event is not None:
|
||||
event.clear()
|
||||
|
||||
|
||||
async def cancel_pending_tool_tasks() -> None:
|
||||
"""Cancel all queued pre-launched tasks for the current execution context.
|
||||
|
||||
Call this when a stream attempt aborts (error, cancellation) to prevent
|
||||
pre-launched tasks from continuing to execute against a rolled-back session.
|
||||
Tasks that are already done are skipped; in-flight tasks are cancelled and
|
||||
awaited so that any cleanup (``finally`` blocks, DB rollbacks) completes
|
||||
before the next retry starts.
|
||||
"""
|
||||
queues = _tool_task_queues.get()
|
||||
if not queues:
|
||||
return
|
||||
cancelled_tasks: list[asyncio.Task] = []
|
||||
for tool_name, queue in list(queues.items()):
|
||||
cancelled = 0
|
||||
while not queue.empty():
|
||||
task, _args = queue.get_nowait()
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
cancelled_tasks.append(task)
|
||||
cancelled += 1
|
||||
if cancelled:
|
||||
logger.debug(
|
||||
"Cancelled %d pre-launched task(s) for tool '%s'", cancelled, tool_name
|
||||
)
|
||||
queues.clear()
|
||||
# Await all cancelled tasks so their cleanup (finally blocks, DB rollbacks)
|
||||
# completes before the next retry attempt starts new pre-launches.
|
||||
# Use a timeout to prevent hanging indefinitely if a task's cleanup is stuck.
|
||||
if cancelled_tasks:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.gather(*cancelled_tasks, return_exceptions=True),
|
||||
timeout=5.0,
|
||||
)
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"Timed out waiting for %d cancelled task(s) to clean up",
|
||||
len(cancelled_tasks),
|
||||
)
|
||||
|
||||
|
||||
def reset_tool_failure_counters() -> None:
|
||||
"""Reset all tool-level circuit breaker counters.
|
||||
|
||||
Called at the start of each SDK retry attempt so that failure counts
|
||||
from a previous (rolled-back) attempt do not carry over and prematurely
|
||||
trip the breaker on a fresh attempt with different context.
|
||||
"""
|
||||
_consecutive_tool_failures.set({})
|
||||
|
||||
|
||||
def pop_pending_tool_output(tool_name: str) -> str | None:
|
||||
@@ -146,7 +237,7 @@ def stash_pending_tool_output(tool_name: str, output: Any) -> None:
|
||||
event.set()
|
||||
|
||||
|
||||
async def wait_for_stash(timeout: float = 0.5) -> bool:
|
||||
async def wait_for_stash(timeout: float = 2.0) -> bool:
|
||||
"""Wait for a PostToolUse hook to stash tool output.
|
||||
|
||||
The SDK fires PostToolUse hooks asynchronously via ``start_soon()`` —
|
||||
@@ -155,12 +246,13 @@ async def wait_for_stash(timeout: float = 0.5) -> bool:
|
||||
by waiting on the ``_stash_event``, which is signaled by
|
||||
:func:`stash_pending_tool_output`.
|
||||
|
||||
After the event fires, callers should ``await asyncio.sleep(0)`` to
|
||||
give any remaining concurrent hooks a chance to complete.
|
||||
Uses ``asyncio.Event.wait()`` so it returns the instant the hook signals —
|
||||
the timeout is purely a safety net for the case where the hook never fires.
|
||||
Returns ``True`` if the stash signal was received, ``False`` on timeout.
|
||||
|
||||
Returns ``True`` if a stash signal was received, ``False`` on timeout.
|
||||
The timeout is a safety net — normally the stash happens within
|
||||
microseconds of yielding to the event loop.
|
||||
The 2.0 s default was chosen to accommodate slower tool startup in cloud
|
||||
sandboxes while still failing fast when the hook genuinely will not fire.
|
||||
With the parallel pre-launch path, hooks typically fire well under 1 ms.
|
||||
"""
|
||||
event = _stash_event.get(None)
|
||||
if event is None:
|
||||
@@ -169,7 +261,7 @@ async def wait_for_stash(timeout: float = 0.5) -> bool:
|
||||
if event.is_set():
|
||||
event.clear()
|
||||
return True
|
||||
# Slow path: wait for the hook to signal.
|
||||
# Slow path: block until the hook signals or the safety timeout expires.
|
||||
try:
|
||||
async with asyncio.timeout(timeout):
|
||||
await event.wait()
|
||||
@@ -179,6 +271,82 @@ async def wait_for_stash(timeout: float = 0.5) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
async def pre_launch_tool_call(tool_name: str, args: dict[str, Any]) -> None:
|
||||
"""Pre-launch a tool as a background task so parallel calls run concurrently.
|
||||
|
||||
Called when an AssistantMessage with ToolUseBlocks is received, before the
|
||||
SDK dispatches the MCP tool/call requests. The tool_handler will await the
|
||||
pre-launched task instead of executing fresh.
|
||||
|
||||
The tool_name may include an MCP prefix (e.g. ``mcp__copilot__run_block``);
|
||||
the prefix is stripped automatically before looking up the tool.
|
||||
|
||||
Ordering guarantee: the Claude Agent SDK dispatches MCP ``tools/call`` requests
|
||||
in the same order as the ToolUseBlocks appear in the AssistantMessage.
|
||||
Pre-launched tasks are queued FIFO per tool name, so the N-th handler for a
|
||||
given tool name dequeues the N-th pre-launched task — result and args always
|
||||
correspond when the SDK preserves order (which it does in the current SDK).
|
||||
"""
|
||||
queues = _tool_task_queues.get()
|
||||
if queues is None:
|
||||
return
|
||||
|
||||
# Strip the MCP server prefix (e.g. "mcp__copilot__") to get the bare tool name.
|
||||
# Use removeprefix so tool names that themselves contain "__" are handled correctly.
|
||||
bare_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
|
||||
|
||||
base_tool = TOOL_REGISTRY.get(bare_name)
|
||||
if base_tool is None:
|
||||
return
|
||||
|
||||
user_id, session = get_execution_context()
|
||||
if session is None:
|
||||
return
|
||||
|
||||
# Expand @@agptfile: references before launching the task.
|
||||
# The _truncating wrapper (which normally handles expansion) runs AFTER
|
||||
# pre_launch_tool_call — the pre-launched task would otherwise receive raw
|
||||
# @@agptfile: tokens and fail to resolve them inside _execute_tool_sync.
|
||||
# Use _build_input_schema (same path as _truncating) for schema-aware expansion.
|
||||
input_schema: dict[str, Any] | None
|
||||
try:
|
||||
input_schema = _build_input_schema(base_tool)
|
||||
except Exception:
|
||||
input_schema = None # schema unavailable — skip schema-aware expansion
|
||||
try:
|
||||
args = await expand_file_refs_in_args(
|
||||
args, user_id, session, input_schema=input_schema
|
||||
)
|
||||
except FileRefExpansionError as exc:
|
||||
logger.warning(
|
||||
"pre_launch_tool_call: @@agptfile expansion failed for %s: %s — skipping pre-launch",
|
||||
bare_name,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
task = asyncio.create_task(_execute_tool_sync(base_tool, user_id, session, args))
|
||||
# Log unhandled exceptions so "Task exception was never retrieved" warnings
|
||||
# do not pollute stderr when a task is pre-launched but never dequeued.
|
||||
task.add_done_callback(
|
||||
lambda t, name=bare_name: (
|
||||
logger.warning(
|
||||
"Pre-launched task for %s raised unhandled: %s",
|
||||
name,
|
||||
t.exception(),
|
||||
)
|
||||
if not t.cancelled() and t.exception()
|
||||
else None
|
||||
)
|
||||
)
|
||||
|
||||
if bare_name not in queues:
|
||||
queues[bare_name] = asyncio.Queue[_TaskQueueItem]()
|
||||
# Store (task, args) so the handler can log a warning if the SDK dispatches
|
||||
# calls in a different order than the ToolUseBlocks appeared in the message.
|
||||
queues[bare_name].put_nowait((task, args))
|
||||
|
||||
|
||||
async def _execute_tool_sync(
|
||||
base_tool: BaseTool,
|
||||
user_id: str | None,
|
||||
@@ -187,8 +355,10 @@ async def _execute_tool_sync(
|
||||
) -> dict[str, Any]:
|
||||
"""Execute a tool synchronously and return MCP-formatted response.
|
||||
|
||||
Note: ``@@agptfile:`` expansion is handled upstream in the ``_truncating`` wrapper
|
||||
so all registered handlers (BaseTool, E2B, Read) expand uniformly.
|
||||
Note: ``@@agptfile:`` expansion should be performed by the caller before
|
||||
invoking this function. For the normal (non-parallel) path it is handled
|
||||
by the ``_truncating`` wrapper; for the pre-launched parallel path it is
|
||||
handled in :func:`pre_launch_tool_call` before the task is created.
|
||||
"""
|
||||
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
||||
result = await base_tool.execute(
|
||||
@@ -217,6 +387,66 @@ def _mcp_error(message: str) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def _failure_key(tool_name: str, args: dict[str, Any]) -> str:
|
||||
"""Compute a stable fingerprint for (tool_name, args) used by the circuit breaker."""
|
||||
args_key = json.dumps(args, sort_keys=True, default=str)
|
||||
return f"{tool_name}:{args_key}"
|
||||
|
||||
|
||||
def _check_circuit_breaker(tool_name: str, args: dict[str, Any]) -> str | None:
|
||||
"""Check if a tool has hit the consecutive failure limit.
|
||||
|
||||
Tracks failures keyed by (tool_name, args_fingerprint). Returns an error
|
||||
message if the circuit breaker has tripped, or None if the call should proceed.
|
||||
"""
|
||||
tracker = _consecutive_tool_failures.get(None)
|
||||
if tracker is None:
|
||||
return None
|
||||
|
||||
key = _failure_key(tool_name, args)
|
||||
count = tracker.get(key, 0)
|
||||
if count >= _MAX_CONSECUTIVE_TOOL_FAILURES:
|
||||
logger.warning(
|
||||
"Circuit breaker tripped for tool %s after %d consecutive "
|
||||
"identical failures (args=%s)",
|
||||
tool_name,
|
||||
count,
|
||||
key[len(tool_name) + 1 :][:200],
|
||||
)
|
||||
return (
|
||||
f"STOP: Tool '{tool_name}' has failed {count} consecutive times with "
|
||||
f"the same arguments. Do NOT retry this tool call. "
|
||||
f"If you were trying to write content to a file, instead respond with "
|
||||
f"the content directly as a text message to the user."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _record_tool_failure(tool_name: str, args: dict[str, Any]) -> None:
|
||||
"""Record a tool failure for circuit breaker tracking."""
|
||||
tracker = _consecutive_tool_failures.get(None)
|
||||
if tracker is None:
|
||||
return
|
||||
key = _failure_key(tool_name, args)
|
||||
tracker[key] = tracker.get(key, 0) + 1
|
||||
|
||||
|
||||
def _clear_tool_failures(tool_name: str) -> None:
|
||||
"""Clear failure tracking for a tool on success.
|
||||
|
||||
Clears ALL args variants for the tool, not just the successful call's args.
|
||||
This gives the tool a "fresh start" on any success, which is appropriate for
|
||||
the primary use case (detecting infinite loops with identical failing args).
|
||||
"""
|
||||
tracker = _consecutive_tool_failures.get(None)
|
||||
if tracker is None:
|
||||
return
|
||||
# Clear all entries for this tool name
|
||||
keys_to_remove = [k for k in tracker if k.startswith(f"{tool_name}:")]
|
||||
for k in keys_to_remove:
|
||||
del tracker[k]
|
||||
|
||||
|
||||
def create_tool_handler(base_tool: BaseTool):
|
||||
"""Create an async handler function for a BaseTool.
|
||||
|
||||
@@ -225,7 +455,83 @@ def create_tool_handler(base_tool: BaseTool):
|
||||
"""
|
||||
|
||||
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Execute the wrapped tool and return MCP-formatted response."""
|
||||
"""Execute the wrapped tool and return MCP-formatted response.
|
||||
|
||||
If a pre-launched task exists (from parallel tool pre-launch in the
|
||||
message loop), await it instead of executing fresh.
|
||||
"""
|
||||
queues = _tool_task_queues.get()
|
||||
if queues and base_tool.name in queues:
|
||||
queue = queues[base_tool.name]
|
||||
if not queue.empty():
|
||||
task, launch_args = queue.get_nowait()
|
||||
# Sanity-check: warn if the args don't match — this can happen
|
||||
# if the SDK dispatches tool calls in a different order than the
|
||||
# ToolUseBlocks appeared in the AssistantMessage (unlikely but
|
||||
# could occur in future SDK versions or with SDK bugs).
|
||||
# We compare full values (not just keys) so that two run_block
|
||||
# calls with different block_id values are caught even though
|
||||
# both have the same key set.
|
||||
if launch_args != args:
|
||||
logger.warning(
|
||||
"Pre-launched task for %s: arg mismatch "
|
||||
"(launch_keys=%s, call_keys=%s) — cancelling "
|
||||
"pre-launched task and falling back to direct execution",
|
||||
base_tool.name,
|
||||
(
|
||||
sorted(launch_args.keys())
|
||||
if isinstance(launch_args, dict)
|
||||
else type(launch_args).__name__
|
||||
),
|
||||
(
|
||||
sorted(args.keys())
|
||||
if isinstance(args, dict)
|
||||
else type(args).__name__
|
||||
),
|
||||
)
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
# Await cancellation to prevent duplicate concurrent
|
||||
# execution for blocks with side effects.
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
# Fall through to the direct-execution path below.
|
||||
else:
|
||||
# Args match — await the pre-launched task.
|
||||
try:
|
||||
result = await task
|
||||
except asyncio.CancelledError:
|
||||
# Re-raise: CancelledError may be propagating from the
|
||||
# outer streaming loop being cancelled — swallowing it
|
||||
# would mask the cancellation and prevent proper cleanup.
|
||||
logger.warning(
|
||||
"Pre-launched tool %s was cancelled — re-raising",
|
||||
base_tool.name,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Pre-launched tool %s failed: %s",
|
||||
base_tool.name,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
return _mcp_error(
|
||||
f"Failed to execute {base_tool.name}. "
|
||||
"Check server logs for details."
|
||||
)
|
||||
|
||||
# Pre-truncate the result so the _truncating wrapper (which
|
||||
# wraps this handler) receives an already-within-budget
|
||||
# value. _truncating handles stashing — we must NOT stash
|
||||
# here or the output will be appended twice to the FIFO
|
||||
# queue and pop_pending_tool_output would return a duplicate
|
||||
# entry on the second call for the same tool.
|
||||
return truncate(result, _MCP_MAX_CHARS)
|
||||
|
||||
# No pre-launched task — execute directly (fallback for non-parallel calls).
|
||||
user_id, session = get_execution_context()
|
||||
|
||||
if session is None:
|
||||
@@ -234,8 +540,12 @@ def create_tool_handler(base_tool: BaseTool):
|
||||
try:
|
||||
return await _execute_tool_sync(base_tool, user_id, session, args)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
||||
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
|
||||
logger.error(
|
||||
"Error executing tool %s: %s", base_tool.name, e, exc_info=True
|
||||
)
|
||||
return _mcp_error(
|
||||
f"Failed to execute {base_tool.name}. Check server logs for details."
|
||||
)
|
||||
|
||||
return tool_handler
|
||||
|
||||
@@ -285,7 +595,7 @@ async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
resolved = os.path.realpath(os.path.expanduser(file_path))
|
||||
try:
|
||||
with open(resolved) as f:
|
||||
with open(resolved, encoding="utf-8", errors="replace") as f:
|
||||
selected = list(itertools.islice(f, offset, offset + limit))
|
||||
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
|
||||
# don't delete here — the SDK may read in multiple chunks.
|
||||
@@ -358,6 +668,15 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
Applied once to every registered tool."""
|
||||
|
||||
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
|
||||
# Circuit breaker: stop infinite retry loops with identical args.
|
||||
# Use the original (pre-expansion) args for fingerprinting so
|
||||
# check and record always use the same key — @@agptfile:
|
||||
# expansion mutates args, which would cause a key mismatch.
|
||||
original_args = args
|
||||
stop_msg = _check_circuit_breaker(tool_name, original_args)
|
||||
if stop_msg:
|
||||
return _mcp_error(stop_msg)
|
||||
|
||||
user_id, session = get_execution_context()
|
||||
if session is not None:
|
||||
try:
|
||||
@@ -365,6 +684,7 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
args, user_id, session, input_schema=input_schema
|
||||
)
|
||||
except FileRefExpansionError as exc:
|
||||
_record_tool_failure(tool_name, original_args)
|
||||
return _mcp_error(
|
||||
f"@@agptfile: reference could not be resolved: {exc}. "
|
||||
"Ensure the file exists before referencing it. "
|
||||
@@ -374,6 +694,12 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
result = await fn(args)
|
||||
truncated = truncate(result, _MCP_MAX_CHARS)
|
||||
|
||||
# Track consecutive failures for circuit breaker
|
||||
if truncated.get("isError"):
|
||||
_record_tool_failure(tool_name, original_args)
|
||||
else:
|
||||
_clear_tool_failures(tool_name)
|
||||
|
||||
# Stash the text so the response adapter can forward our
|
||||
# middle-out truncated version to the frontend instead of the
|
||||
# SDK's head-truncated version (for outputs >~100 KB the SDK
|
||||
|
||||
@@ -1,16 +1,26 @@
|
||||
"""Tests for tool_adapter helpers: truncation, stash, context vars."""
|
||||
"""Tests for tool_adapter helpers: truncation, stash, context vars, parallel pre-launch."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import get_sdk_cwd
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
from backend.copilot.sdk.file_ref import FileRefExpansionError
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
from .tool_adapter import (
|
||||
_MCP_MAX_CHARS,
|
||||
_text_from_mcp_result,
|
||||
cancel_pending_tool_tasks,
|
||||
create_tool_handler,
|
||||
pop_pending_tool_output,
|
||||
pre_launch_tool_call,
|
||||
reset_stash_event,
|
||||
set_execution_context,
|
||||
stash_pending_tool_output,
|
||||
wait_for_stash,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -120,6 +130,69 @@ class TestToolOutputStash:
|
||||
assert pop_pending_tool_output("a") == "alpha"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reset_stash_event / wait_for_stash
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResetStashEvent:
|
||||
"""Tests for reset_stash_event — the stale-signal fix for retry attempts."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init_context(self):
|
||||
set_execution_context(
|
||||
user_id="test",
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_clears_stale_signal(self):
|
||||
"""After reset, wait_for_stash does NOT return immediately (blocks until timeout)."""
|
||||
# Simulate a stale signal left by a failed attempt's PostToolUse hook.
|
||||
stash_pending_tool_output("some_tool", "stale output")
|
||||
# The stash_pending_tool_output call sets the event.
|
||||
# Now reset it — simulating start of a new retry attempt.
|
||||
reset_stash_event()
|
||||
# wait_for_stash should block and time out since the event was cleared.
|
||||
result = await wait_for_stash(timeout=0.05)
|
||||
assert result is False, (
|
||||
"wait_for_stash should have timed out after reset_stash_event, "
|
||||
"but it returned True — stale signal was not cleared"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_returns_true_when_signaled_after_reset(self):
|
||||
"""After reset, a new stash signal is correctly detected."""
|
||||
reset_stash_event()
|
||||
|
||||
async def _signal_after_delay():
|
||||
await asyncio.sleep(0.01)
|
||||
stash_pending_tool_output("tool", "fresh output")
|
||||
|
||||
asyncio.create_task(_signal_after_delay())
|
||||
result = await wait_for_stash(timeout=1.0)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_scenario_stale_event_does_not_fire_prematurely(self):
|
||||
"""Simulates: attempt 1 leaves event set → reset → attempt 2 waits correctly."""
|
||||
# Attempt 1: hook fires and sets the event
|
||||
stash_pending_tool_output("t", "attempt-1-output")
|
||||
# Pop it so the stash is empty (simulating normal consumption)
|
||||
pop_pending_tool_output("t")
|
||||
|
||||
# Between attempts: reset (as service.py does before each retry)
|
||||
reset_stash_event()
|
||||
|
||||
# Attempt 2: wait_for_stash should NOT return True immediately
|
||||
result = await wait_for_stash(timeout=0.05)
|
||||
assert result is False, (
|
||||
"Stale event from attempt 1 caused wait_for_stash to return "
|
||||
"prematurely in attempt 2"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _truncating wrapper (integration via create_copilot_mcp_server)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -168,3 +241,534 @@ class TestTruncationAndStashIntegration:
|
||||
text = _text_from_mcp_result(truncated)
|
||||
assert len(text) < len(big_text)
|
||||
assert len(str(truncated)) <= _MCP_MAX_CHARS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parallel pre-launch infrastructure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_tool(name: str, output: str = "result") -> MagicMock:
|
||||
"""Return a BaseTool mock that returns a successful StreamToolOutputAvailable."""
|
||||
tool = MagicMock()
|
||||
tool.name = name
|
||||
tool.parameters = {"properties": {}, "required": []}
|
||||
tool.execute = AsyncMock(
|
||||
return_value=StreamToolOutputAvailable(
|
||||
toolCallId="test-id",
|
||||
output=output,
|
||||
toolName=name,
|
||||
success=True,
|
||||
)
|
||||
)
|
||||
return tool
|
||||
|
||||
|
||||
def _make_mock_session() -> MagicMock:
|
||||
"""Return a minimal ChatSession mock."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
def _init_ctx(session=None):
|
||||
set_execution_context(
|
||||
user_id="user-1",
|
||||
session=session, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
)
|
||||
|
||||
|
||||
class TestPreLaunchToolCall:
|
||||
"""Tests for pre_launch_tool_call and the queue-based parallel dispatch."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_tool_is_silently_ignored(self):
|
||||
"""pre_launch_tool_call does nothing for tools not in TOOL_REGISTRY."""
|
||||
# Should not raise even if the tool name is completely unknown
|
||||
await pre_launch_tool_call("nonexistent_tool", {})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_prefix_stripped_before_registry_lookup(self):
|
||||
"""mcp__copilot__run_block is looked up as 'run_block'."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("mcp__copilot__run_block", {"block_id": "b1"})
|
||||
|
||||
# The task was enqueued — mock_tool.execute should be called once
|
||||
# (may not complete immediately but should start)
|
||||
await asyncio.sleep(0) # yield to event loop
|
||||
mock_tool.execute.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_tool_name_without_prefix(self):
|
||||
"""Tool names without __ separator are looked up as-is."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
|
||||
await asyncio.sleep(0)
|
||||
mock_tool.execute.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_enqueued_fifo_for_same_tool(self):
|
||||
"""Two pre-launched calls for the same tool name are enqueued FIFO."""
|
||||
results = []
|
||||
|
||||
async def slow_execute(*args, **kwargs):
|
||||
results.append(len(results))
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId="id",
|
||||
output=str(len(results) - 1),
|
||||
toolName="t",
|
||||
success=True,
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("t")
|
||||
mock_tool.execute = AsyncMock(side_effect=slow_execute)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"t": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("t", {"n": 1})
|
||||
await pre_launch_tool_call("t", {"n": 2})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert mock_tool.execute.await_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_ref_expansion_failure_skips_pre_launch(self):
|
||||
"""When @@agptfile: expansion fails, pre_launch_tool_call skips the task.
|
||||
|
||||
The handler should then fall back to direct execution (which will also
|
||||
fail with a proper MCP error via _truncating's own expansion).
|
||||
"""
|
||||
mock_tool = _make_mock_tool("run_block", output="should-not-execute")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.tool_adapter.expand_file_refs_in_args",
|
||||
AsyncMock(side_effect=FileRefExpansionError("@@agptfile:missing.txt")),
|
||||
),
|
||||
):
|
||||
# Should not raise — expansion failure is handled gracefully
|
||||
await pre_launch_tool_call("run_block", {"text": "@@agptfile:missing.txt"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# No task was pre-launched — execute was not called
|
||||
mock_tool.execute.assert_not_awaited()
|
||||
|
||||
|
||||
class TestCreateToolHandlerParallel:
|
||||
"""Tests for create_tool_handler using pre-launched tasks."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_uses_prelaunched_task(self):
|
||||
"""Handler pops and awaits the pre-launched task rather than re-executing."""
|
||||
mock_tool = _make_mock_tool("run_block", output="pre-launched result")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0) # let task start
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is False
|
||||
text = result["content"][0]["text"]
|
||||
assert "pre-launched result" in text
|
||||
# Should only have been called once (the pre-launched task), not twice
|
||||
mock_tool.execute.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_does_not_double_stash_for_prelaunched_task(self):
|
||||
"""Pre-launched task result must NOT be stashed by tool_handler directly.
|
||||
|
||||
The _truncating wrapper wraps tool_handler and handles stashing after
|
||||
tool_handler returns. If tool_handler also stashed, the output would be
|
||||
appended twice to the FIFO queue and pop_pending_tool_output would return
|
||||
a duplicate on the second call.
|
||||
|
||||
This test calls tool_handler directly (without _truncating) and asserts
|
||||
that nothing was stashed — confirming stashing is deferred to _truncating.
|
||||
"""
|
||||
mock_tool = _make_mock_tool("run_block", output="stash-me")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is False
|
||||
assert "stash-me" in result["content"][0]["text"]
|
||||
# tool_handler must NOT stash — _truncating (which wraps handler) does it.
|
||||
# Calling pop here (without going through _truncating) should return None.
|
||||
not_stashed = pop_pending_tool_output("run_block")
|
||||
assert not_stashed is None, (
|
||||
"tool_handler must not stash directly — _truncating handles stashing "
|
||||
"to prevent double-stash in the FIFO queue"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_falls_back_when_queue_empty(self):
|
||||
"""When no pre-launched task exists, handler executes directly."""
|
||||
mock_tool = _make_mock_tool("run_block", output="direct result")
|
||||
|
||||
# Don't call pre_launch_tool_call — queue is empty
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is False
|
||||
text = result["content"][0]["text"]
|
||||
assert "direct result" in text
|
||||
mock_tool.execute.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_cancelled_error_propagates(self):
|
||||
"""CancelledError from a pre-launched task is re-raised to preserve cancellation semantics."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=asyncio.CancelledError())
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await handler({"block_id": "b1"})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_exception_returns_mcp_error(self):
|
||||
"""Exception from a pre-launched task is caught and returned as MCP error."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=RuntimeError("block exploded"))
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is True
|
||||
assert "Failed to execute run_block" in result["content"][0]["text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_two_same_tool_calls_dispatched_in_order(self):
|
||||
"""Two pre-launched tasks for the same tool are consumed in FIFO order."""
|
||||
call_order = []
|
||||
|
||||
async def execute_with_tag(*args, **kwargs):
|
||||
tag = kwargs.get("block_id", "?")
|
||||
call_order.append(tag)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId="id", output=f"out-{tag}", toolName="run_block", success=True
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=execute_with_tag)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "first"})
|
||||
await pre_launch_tool_call("run_block", {"block_id": "second"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
r1 = await handler({"block_id": "first"})
|
||||
r2 = await handler({"block_id": "second"})
|
||||
|
||||
assert "out-first" in r1["content"][0]["text"]
|
||||
assert "out-second" in r2["content"][0]["text"]
|
||||
assert call_order == [
|
||||
"first",
|
||||
"second",
|
||||
], f"Expected FIFO dispatch order but got {call_order}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arg_mismatch_falls_back_to_direct_execution(self):
|
||||
"""When pre-launched args differ from SDK args, handler cancels pre-launched
|
||||
task and falls back to direct execution with the correct args."""
|
||||
mock_tool = _make_mock_tool("run_block", output="direct-result")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
# Pre-launch with args {"block_id": "wrong"}
|
||||
await pre_launch_tool_call("run_block", {"block_id": "wrong"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# SDK dispatches with different args
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "correct"})
|
||||
|
||||
assert result["isError"] is False
|
||||
# The tool was called twice: once by pre-launch (wrong args), once by
|
||||
# direct fallback (correct args). The result should come from the
|
||||
# direct execution path.
|
||||
assert mock_tool.execute.await_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_session_falls_back_gracefully(self):
|
||||
"""When session is None and no pre-launched task, handler returns MCP error."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
# session=None means get_execution_context returns (user_id, None)
|
||||
set_execution_context(user_id="u", session=None, sandbox=None) # type: ignore[arg-type]
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is True
|
||||
assert "session" in result["content"][0]["text"].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cancel_pending_tool_tasks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCancelPendingToolTasks:
|
||||
"""Tests for cancel_pending_tool_tasks — the stream-abort cleanup helper."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancels_queued_tasks(self):
|
||||
"""Queued tasks are cancelled and the queue is cleared."""
|
||||
ran = False
|
||||
|
||||
async def never_run(*_args, **_kwargs):
|
||||
nonlocal ran
|
||||
await asyncio.sleep(10) # long enough to still be pending
|
||||
ran = True
|
||||
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=never_run)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0) # let task start
|
||||
await cancel_pending_tool_tasks()
|
||||
await asyncio.sleep(0) # let cancellation propagate
|
||||
|
||||
assert not ran, "Task should have been cancelled before completing"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noop_when_no_tasks_queued(self):
|
||||
"""cancel_pending_tool_tasks does not raise when queues are empty."""
|
||||
await cancel_pending_tool_tasks() # should not raise
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_does_not_find_cancelled_task(self):
|
||||
"""After cancel, tool_handler falls back to direct execution."""
|
||||
mock_tool = _make_mock_tool("run_block", output="direct-fallback")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0)
|
||||
await cancel_pending_tool_tasks()
|
||||
|
||||
# Queue is now empty — handler should execute directly
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is False
|
||||
assert "direct-fallback" in result["content"][0]["text"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concurrent / parallel pre-launch scenarios
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAllParallelToolsPrelaunchedIndependently:
|
||||
"""Simulate SDK sending N separate AssistantMessages for the same tool concurrently."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_parallel_tools_prelaunched_independently(self):
|
||||
"""5 pre-launches for the same tool all enqueue independently and run concurrently.
|
||||
|
||||
Each task sleeps for PER_TASK_S seconds. If they ran sequentially the total
|
||||
wall time would be ~5*PER_TASK_S. Running concurrently it should finish in
|
||||
roughly PER_TASK_S (plus scheduling overhead).
|
||||
"""
|
||||
PER_TASK_S = 0.05
|
||||
N = 5
|
||||
started: list[int] = []
|
||||
finished: list[int] = []
|
||||
|
||||
async def slow_execute(*args, **kwargs):
|
||||
idx = len(started)
|
||||
started.append(idx)
|
||||
await asyncio.sleep(PER_TASK_S)
|
||||
finished.append(idx)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=f"id-{idx}",
|
||||
output=f"result-{idx}",
|
||||
toolName="bash_exec",
|
||||
success=True,
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("bash_exec")
|
||||
mock_tool.execute = AsyncMock(side_effect=slow_execute)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"bash_exec": mock_tool},
|
||||
):
|
||||
for i in range(N):
|
||||
await pre_launch_tool_call("bash_exec", {"cmd": f"echo {i}"})
|
||||
|
||||
# Measure only the concurrent execution window, not pre-launch overhead.
|
||||
# Starting the timer here avoids false failures on slow CI runners where
|
||||
# the pre_launch_tool_call setup takes longer than the concurrent sleep.
|
||||
t0 = asyncio.get_running_loop().time()
|
||||
await asyncio.sleep(PER_TASK_S * 2)
|
||||
elapsed = asyncio.get_running_loop().time() - t0
|
||||
|
||||
assert mock_tool.execute.await_count == N
|
||||
assert len(finished) == N
|
||||
# Wall time of the sleep window should be well under N * PER_TASK_S
|
||||
# (sequential would be ~0.25s; concurrent finishes in ~PER_TASK_S = 0.05s)
|
||||
assert elapsed < N * PER_TASK_S, (
|
||||
f"Expected concurrent execution (<{N * PER_TASK_S:.2f}s) "
|
||||
f"but sleep window took {elapsed:.2f}s"
|
||||
)
|
||||
|
||||
|
||||
class TestHandlerReturnsResultFromCorrectPrelaunchedTask:
|
||||
"""Pop pre-launched tasks in order and verify each returns its own result."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_returns_result_from_correct_prelaunched_task(self):
|
||||
"""Two pre-launches for the same tool: first handler gets first result, second gets second."""
|
||||
|
||||
async def execute_with_cmd(*args, **kwargs):
|
||||
cmd = kwargs.get("cmd", "?")
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId="id",
|
||||
output=f"output-for-{cmd}",
|
||||
toolName="bash_exec",
|
||||
success=True,
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("bash_exec")
|
||||
mock_tool.execute = AsyncMock(side_effect=execute_with_cmd)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"bash_exec": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("bash_exec", {"cmd": "alpha"})
|
||||
await pre_launch_tool_call("bash_exec", {"cmd": "beta"})
|
||||
await asyncio.sleep(0) # let both tasks start
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
r1 = await handler({"cmd": "alpha"})
|
||||
r2 = await handler({"cmd": "beta"})
|
||||
|
||||
text1 = r1["content"][0]["text"]
|
||||
text2 = r2["content"][0]["text"]
|
||||
assert "output-for-alpha" in text1, f"Expected alpha result, got: {text1}"
|
||||
assert "output-for-beta" in text2, f"Expected beta result, got: {text2}"
|
||||
assert mock_tool.execute.await_count == 2
|
||||
|
||||
|
||||
class TestFiveConcurrentPrelaunchAllComplete:
|
||||
"""Pre-launch 5 tasks; consume all 5 via handlers; assert all succeed."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_five_concurrent_prelaunch_all_complete(self):
|
||||
"""All 5 pre-launched tasks complete and return successful results."""
|
||||
N = 5
|
||||
call_count = 0
|
||||
|
||||
async def counting_execute(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
n = call_count
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=f"id-{n}",
|
||||
output=f"done-{n}",
|
||||
toolName="bash_exec",
|
||||
success=True,
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("bash_exec")
|
||||
mock_tool.execute = AsyncMock(side_effect=counting_execute)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"bash_exec": mock_tool},
|
||||
):
|
||||
for i in range(N):
|
||||
await pre_launch_tool_call("bash_exec", {"cmd": f"task-{i}"})
|
||||
|
||||
await asyncio.sleep(0) # let all tasks start
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
results = []
|
||||
for i in range(N):
|
||||
results.append(await handler({"cmd": f"task-{i}"}))
|
||||
|
||||
assert (
|
||||
mock_tool.execute.await_count == N
|
||||
), f"Expected {N} execute calls, got {mock_tool.execute.await_count}"
|
||||
for i, result in enumerate(results):
|
||||
assert result["isError"] is False, f"Result {i} should not be an error"
|
||||
text = result["content"][0]["text"]
|
||||
assert "done-" in text, f"Result {i} missing expected output: {text}"
|
||||
|
||||
@@ -10,6 +10,9 @@ Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
|
||||
filesystem for self-hosted) — no DB column needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@@ -17,8 +20,12 @@ import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.clients import get_openai_client
|
||||
from backend.util.prompt import CompressResult, compress_context
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -99,7 +106,14 @@ def strip_progress_entries(content: str) -> str:
|
||||
continue
|
||||
parent = entry.get("parentUuid", "")
|
||||
original_parent = parent
|
||||
while parent in stripped_uuids:
|
||||
# seen_parents is local per-entry (not shared across iterations) so
|
||||
# it can only detect cycles within a single ancestry walk, not across
|
||||
# entries. This is intentional: each entry's parent chain is
|
||||
# independent, and reusing a global set would incorrectly short-circuit
|
||||
# valid re-use of the same UUID as a parent in different subtrees.
|
||||
seen_parents: set[str] = set()
|
||||
while parent in stripped_uuids and parent not in seen_parents:
|
||||
seen_parents.add(parent)
|
||||
parent = uuid_to_parent.get(parent, "")
|
||||
if parent != original_parent:
|
||||
entry["parentUuid"] = parent
|
||||
@@ -151,44 +165,110 @@ def _projects_base() -> str:
|
||||
return os.path.realpath(os.path.join(config_dir, "projects"))
|
||||
|
||||
|
||||
def _cli_project_dir(sdk_cwd: str) -> str | None:
|
||||
"""Return the CLI's project directory for a given working directory.
|
||||
_STALE_PROJECT_DIR_SECONDS = 12 * 3600 # 12 hours — matches max session lifetime
|
||||
_MAX_PROJECT_DIRS_TO_SWEEP = 50 # limit per sweep to avoid long pauses
|
||||
|
||||
Returns ``None`` if the path would escape the projects base.
|
||||
|
||||
def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
|
||||
"""Remove CLI project directories older than ``_STALE_PROJECT_DIR_SECONDS``.
|
||||
|
||||
Each CoPilot SDK turn creates a unique ``~/.claude/projects/<encoded-cwd>/``
|
||||
directory. These are intentionally kept across turns so the model can read
|
||||
tool-result files via ``--resume``. However, after a session ends they
|
||||
become stale. This function sweeps old ones to prevent unbounded disk
|
||||
growth.
|
||||
|
||||
When *encoded_cwd* is provided the sweep is scoped to that single
|
||||
directory, making the operation safe in multi-tenant environments where
|
||||
multiple copilot sessions share the same host. Without it the function
|
||||
falls back to sweeping all directories matching the copilot naming pattern
|
||||
(``-tmp-copilot-``), which is only safe for single-tenant deployments.
|
||||
|
||||
Returns the number of directories removed.
|
||||
"""
|
||||
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
projects_base = _projects_base()
|
||||
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
||||
if not os.path.isdir(projects_base):
|
||||
return 0
|
||||
|
||||
if not project_dir.startswith(projects_base + os.sep):
|
||||
logger.warning(
|
||||
"[Transcript] Project dir escaped projects base: %s", project_dir
|
||||
)
|
||||
return None
|
||||
return project_dir
|
||||
now = time.time()
|
||||
removed = 0
|
||||
|
||||
|
||||
def _safe_glob_jsonl(project_dir: str) -> list[Path]:
|
||||
"""Glob ``*.jsonl`` files, filtering out symlinks that escape the directory."""
|
||||
try:
|
||||
resolved_base = Path(project_dir).resolve()
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to resolve project dir: %s", e)
|
||||
return []
|
||||
|
||||
result: list[Path] = []
|
||||
for candidate in Path(project_dir).glob("*.jsonl"):
|
||||
try:
|
||||
resolved = candidate.resolve()
|
||||
if resolved.is_relative_to(resolved_base):
|
||||
result.append(resolved)
|
||||
except (OSError, RuntimeError) as e:
|
||||
logger.debug(
|
||||
"[Transcript] Skipping invalid CLI session candidate %s: %s",
|
||||
candidate,
|
||||
e,
|
||||
# Scoped mode: only clean up the one directory for the current session.
|
||||
if encoded_cwd:
|
||||
target = Path(projects_base) / encoded_cwd
|
||||
if not target.is_dir():
|
||||
return 0
|
||||
# Guard: only sweep copilot-generated dirs.
|
||||
if "-tmp-copilot-" not in target.name:
|
||||
logger.warning(
|
||||
"[Transcript] Refusing to sweep non-copilot dir: %s", target.name
|
||||
)
|
||||
return result
|
||||
return 0
|
||||
try:
|
||||
# st_mtime is used as a proxy for session activity. Claude CLI writes
|
||||
# its JSONL transcript into this directory during each turn, so mtime
|
||||
# advances on every turn. A directory whose mtime is older than
|
||||
# _STALE_PROJECT_DIR_SECONDS has not had an active turn in that window
|
||||
# and is safe to remove (the session cannot --resume after cleanup).
|
||||
age = now - target.stat().st_mtime
|
||||
except OSError:
|
||||
return 0
|
||||
if age < _STALE_PROJECT_DIR_SECONDS:
|
||||
return 0
|
||||
try:
|
||||
shutil.rmtree(target, ignore_errors=True)
|
||||
removed = 1
|
||||
except OSError:
|
||||
pass
|
||||
if removed:
|
||||
logger.info(
|
||||
"[Transcript] Swept stale CLI project dir %s (age %ds > %ds)",
|
||||
target.name,
|
||||
int(age),
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
)
|
||||
return removed
|
||||
|
||||
# Unscoped fallback: sweep all copilot dirs across the projects base.
|
||||
# Only safe for single-tenant deployments; callers should prefer the
|
||||
# scoped variant by passing encoded_cwd.
|
||||
try:
|
||||
entries = Path(projects_base).iterdir()
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to list projects dir: %s", e)
|
||||
return 0
|
||||
|
||||
for entry in entries:
|
||||
if removed >= _MAX_PROJECT_DIRS_TO_SWEEP:
|
||||
break
|
||||
# Only sweep copilot-generated dirs (pattern: -tmp-copilot- or
|
||||
# -private-tmp-copilot-).
|
||||
if "-tmp-copilot-" not in entry.name:
|
||||
continue
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
try:
|
||||
# See the scoped-mode comment above: st_mtime advances on every turn,
|
||||
# so a stale mtime reliably indicates an inactive session.
|
||||
age = now - entry.stat().st_mtime
|
||||
except OSError:
|
||||
continue
|
||||
if age < _STALE_PROJECT_DIR_SECONDS:
|
||||
continue
|
||||
|
||||
try:
|
||||
shutil.rmtree(entry, ignore_errors=True)
|
||||
removed += 1
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if removed:
|
||||
logger.info(
|
||||
"[Transcript] Swept %d stale CLI project dirs (older than %ds)",
|
||||
removed,
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
)
|
||||
return removed
|
||||
|
||||
|
||||
def read_compacted_entries(transcript_path: str) -> list[dict] | None:
|
||||
@@ -255,63 +335,6 @@ def read_compacted_entries(transcript_path: str) -> list[dict] | None:
|
||||
return entries
|
||||
|
||||
|
||||
def read_cli_session_file(sdk_cwd: str) -> str | None:
|
||||
"""Read the CLI's own session file, which reflects any compaction.
|
||||
|
||||
The CLI writes its session transcript to
|
||||
``~/.claude/projects/<encoded_cwd>/<session_id>.jsonl``.
|
||||
Since each SDK turn uses a unique ``sdk_cwd``, there should be
|
||||
exactly one ``.jsonl`` file in that directory.
|
||||
|
||||
Returns the file content, or ``None`` if not found.
|
||||
"""
|
||||
project_dir = _cli_project_dir(sdk_cwd)
|
||||
if not project_dir or not os.path.isdir(project_dir):
|
||||
return None
|
||||
|
||||
jsonl_files = _safe_glob_jsonl(project_dir)
|
||||
if not jsonl_files:
|
||||
logger.debug("[Transcript] No CLI session file found in %s", project_dir)
|
||||
return None
|
||||
|
||||
# Pick the most recently modified file (should be only one per turn).
|
||||
try:
|
||||
session_file = max(jsonl_files, key=lambda p: p.stat().st_mtime)
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to inspect CLI session files: %s", e)
|
||||
return None
|
||||
|
||||
try:
|
||||
content = session_file.read_text()
|
||||
logger.info(
|
||||
"[Transcript] Read CLI session file: %s (%d bytes)",
|
||||
session_file,
|
||||
len(content),
|
||||
)
|
||||
return content
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to read CLI session file: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""Remove the CLI's project directory for a specific working directory.
|
||||
|
||||
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
|
||||
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
|
||||
safe to remove entirely after the transcript has been uploaded.
|
||||
"""
|
||||
project_dir = _cli_project_dir(sdk_cwd)
|
||||
if not project_dir:
|
||||
return
|
||||
|
||||
if os.path.isdir(project_dir):
|
||||
shutil.rmtree(project_dir, ignore_errors=True)
|
||||
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
|
||||
else:
|
||||
logger.debug("[Transcript] Project dir not found: %s", project_dir)
|
||||
|
||||
|
||||
def write_transcript_to_tempfile(
|
||||
transcript_content: str,
|
||||
session_id: str,
|
||||
@@ -327,7 +350,7 @@ def write_transcript_to_tempfile(
|
||||
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
|
||||
real_cwd = os.path.realpath(cwd)
|
||||
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
|
||||
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
|
||||
logger.warning("[Transcript] cwd outside sandbox: %s", cwd)
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -337,17 +360,17 @@ def write_transcript_to_tempfile(
|
||||
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
|
||||
)
|
||||
if not jsonl_path.startswith(real_cwd):
|
||||
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
|
||||
logger.warning("[Transcript] Path escaped cwd: %s", jsonl_path)
|
||||
return None
|
||||
|
||||
with open(jsonl_path, "w") as f:
|
||||
f.write(transcript_content)
|
||||
|
||||
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
|
||||
logger.info("[Transcript] Wrote resume file: %s", jsonl_path)
|
||||
return jsonl_path
|
||||
|
||||
except OSError as e:
|
||||
logger.warning(f"[Transcript] Failed to write resume file: {e}")
|
||||
logger.warning("[Transcript] Failed to write resume file: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
@@ -408,8 +431,6 @@ def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, s
|
||||
|
||||
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
|
||||
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage
|
||||
|
||||
wid, fid, fname = parts
|
||||
if isinstance(backend, GCSWorkspaceStorage):
|
||||
blob = f"workspaces/{wid}/{fid}/{fname}"
|
||||
@@ -448,17 +469,15 @@ async def upload_transcript(
|
||||
content: Complete JSONL transcript (from TranscriptBuilder).
|
||||
message_count: ``len(session.messages)`` at upload time.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
# Strip metadata entries (progress, file-history-snapshot, etc.)
|
||||
# Note: SDK-built transcripts shouldn't have these, but strip for safety
|
||||
stripped = strip_progress_entries(content)
|
||||
if not validate_transcript(stripped):
|
||||
# Log entry types for debugging — helps identify why validation failed
|
||||
entry_types: list[str] = []
|
||||
for line in stripped.strip().split("\n"):
|
||||
entry = json.loads(line, fallback={"type": "INVALID_JSON"})
|
||||
entry_types.append(entry.get("type", "?"))
|
||||
entry_types = [
|
||||
json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?")
|
||||
for line in stripped.strip().split("\n")
|
||||
]
|
||||
logger.warning(
|
||||
"%s Skipping upload — stripped content not valid "
|
||||
"(types=%s, stripped_len=%d, raw_len=%d)",
|
||||
@@ -494,11 +513,14 @@ async def upload_transcript(
|
||||
content=json.dumps(meta).encode("utf-8"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
|
||||
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
|
||||
|
||||
logger.info(
|
||||
f"{log_prefix} Uploaded {len(encoded)}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count})"
|
||||
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
|
||||
log_prefix,
|
||||
len(encoded),
|
||||
len(content),
|
||||
message_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -512,8 +534,6 @@ async def download_transcript(
|
||||
Returns a ``TranscriptDownload`` with the JSONL content and the
|
||||
``message_count`` watermark from the upload, or ``None`` if not found.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
@@ -521,10 +541,10 @@ async def download_transcript(
|
||||
data = await storage.retrieve(path)
|
||||
content = data.decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
logger.debug(f"{log_prefix} No transcript in storage")
|
||||
logger.debug("%s No transcript in storage", log_prefix)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
|
||||
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
|
||||
return None
|
||||
|
||||
# Try to load metadata (best-effort — old transcripts won't have it)
|
||||
@@ -536,10 +556,14 @@ async def download_transcript(
|
||||
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
except (FileNotFoundError, Exception):
|
||||
except FileNotFoundError:
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
except Exception as e:
|
||||
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
|
||||
|
||||
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
|
||||
logger.info(
|
||||
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
|
||||
)
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
@@ -553,8 +577,6 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
|
||||
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
@@ -571,3 +593,280 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
logger.info("[Transcript] Deleted metadata for session %s", session_id)
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete metadata: %s", e)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transcript compaction — LLM summarization for prompt-too-long recovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# JSONL protocol values used in transcript serialization.
|
||||
STOP_REASON_END_TURN = "end_turn"
|
||||
COMPACT_MSG_ID_PREFIX = "msg_compact_"
|
||||
ENTRY_TYPE_MESSAGE = "message"
|
||||
|
||||
|
||||
def _flatten_assistant_content(blocks: list) -> str:
|
||||
"""Flatten assistant content blocks into a single plain-text string.
|
||||
|
||||
Structured ``tool_use`` blocks are converted to ``[tool_use: name]``
|
||||
placeholders. This is intentional: ``compress_context`` requires plain
|
||||
text for token counting and LLM summarization. The structural loss is
|
||||
acceptable because compaction only runs when the original transcript was
|
||||
already too large for the model — a summarized plain-text version is
|
||||
better than no context at all.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict):
|
||||
btype = block.get("type", "")
|
||||
if btype == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif btype == "tool_use":
|
||||
parts.append(f"[tool_use: {block.get('name', '?')}]")
|
||||
else:
|
||||
# Preserve non-text blocks (e.g. image) as placeholders.
|
||||
# Use __prefix__ to distinguish from literal user text.
|
||||
parts.append(f"[__{btype}__]")
|
||||
elif isinstance(block, str):
|
||||
parts.append(block)
|
||||
return "\n".join(parts) if parts else ""
|
||||
|
||||
|
||||
def _flatten_tool_result_content(blocks: list) -> str:
|
||||
"""Flatten tool_result and other content blocks into plain text.
|
||||
|
||||
Handles nested tool_result structures, text blocks, and raw strings.
|
||||
Uses ``json.dumps`` as fallback for dict blocks without a ``text`` key
|
||||
or where ``text`` is ``None``.
|
||||
|
||||
Like ``_flatten_assistant_content``, structured blocks (images, nested
|
||||
tool results) are reduced to text representations for compression.
|
||||
"""
|
||||
str_parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||
inner = block.get("content") or ""
|
||||
if isinstance(inner, list):
|
||||
for sub in inner:
|
||||
if isinstance(sub, dict):
|
||||
sub_type = sub.get("type")
|
||||
if sub_type in ("image", "document"):
|
||||
# Avoid serializing base64 binary data into
|
||||
# the compaction input — use a placeholder.
|
||||
str_parts.append(f"[__{sub_type}__]")
|
||||
elif sub_type == "text" or sub.get("text") is not None:
|
||||
str_parts.append(str(sub.get("text", "")))
|
||||
else:
|
||||
str_parts.append(json.dumps(sub))
|
||||
else:
|
||||
str_parts.append(str(sub))
|
||||
else:
|
||||
str_parts.append(str(inner))
|
||||
elif isinstance(block, dict) and block.get("type") == "text":
|
||||
str_parts.append(str(block.get("text", "")))
|
||||
elif isinstance(block, dict):
|
||||
# Preserve non-text/non-tool_result blocks (e.g. image) as placeholders.
|
||||
# Use __prefix__ to distinguish from literal user text.
|
||||
btype = block.get("type", "unknown")
|
||||
str_parts.append(f"[__{btype}__]")
|
||||
elif isinstance(block, str):
|
||||
str_parts.append(block)
|
||||
return "\n".join(str_parts) if str_parts else ""
|
||||
|
||||
|
||||
def _transcript_to_messages(content: str) -> list[dict]:
|
||||
"""Convert JSONL transcript entries to plain message dicts for compression.
|
||||
|
||||
Parses each line of the JSONL *content*, skips strippable metadata entries
|
||||
(progress, file-history-snapshot, etc.), and extracts the ``role`` and
|
||||
flattened ``content`` from the ``message`` field of each remaining entry.
|
||||
|
||||
Structured content blocks (``tool_use``, ``tool_result``, images) are
|
||||
flattened to plain text via ``_flatten_assistant_content`` and
|
||||
``_flatten_tool_result_content`` so that ``compress_context`` can
|
||||
perform token counting and LLM summarization on uniform strings.
|
||||
|
||||
Returns:
|
||||
A list of ``{"role": str, "content": str}`` dicts suitable for
|
||||
``compress_context``.
|
||||
"""
|
||||
messages: list[dict] = []
|
||||
for line in content.strip().split("\n"):
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
|
||||
"isCompactSummary"
|
||||
):
|
||||
continue
|
||||
msg = entry.get("message", {})
|
||||
role = msg.get("role", "")
|
||||
if not role:
|
||||
continue
|
||||
msg_dict: dict = {"role": role}
|
||||
raw_content = msg.get("content")
|
||||
if role == "assistant" and isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_assistant_content(raw_content)
|
||||
elif isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_tool_result_content(raw_content)
|
||||
else:
|
||||
msg_dict["content"] = raw_content or ""
|
||||
messages.append(msg_dict)
|
||||
return messages
|
||||
|
||||
|
||||
def _messages_to_transcript(messages: list[dict]) -> str:
|
||||
"""Convert compressed message dicts back to JSONL transcript format.
|
||||
|
||||
Rebuilds a minimal JSONL transcript from the ``{"role", "content"}``
|
||||
dicts returned by ``compress_context``. Each message becomes one JSONL
|
||||
line with a fresh ``uuid`` / ``parentUuid`` chain so the CLI's
|
||||
``--resume`` flag can reconstruct a valid conversation tree.
|
||||
|
||||
Assistant messages are wrapped in the full ``message`` envelope
|
||||
(``id``, ``model``, ``stop_reason``, structured ``content`` blocks)
|
||||
that the CLI expects. User messages use the simpler ``{role, content}``
|
||||
form.
|
||||
|
||||
Returns:
|
||||
A newline-terminated JSONL string, or an empty string if *messages*
|
||||
is empty.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
last_uuid: str = "" # root entry uses empty string, not null
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
uid = str(uuid4())
|
||||
content = msg.get("content", "")
|
||||
if role == "assistant":
|
||||
message: dict = {
|
||||
"role": "assistant",
|
||||
"model": "",
|
||||
"id": f"{COMPACT_MSG_ID_PREFIX}{uuid4().hex[:24]}",
|
||||
"type": ENTRY_TYPE_MESSAGE,
|
||||
"content": [{"type": "text", "text": content}] if content else [],
|
||||
"stop_reason": STOP_REASON_END_TURN,
|
||||
"stop_sequence": None,
|
||||
}
|
||||
else:
|
||||
message = {"role": role, "content": content}
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": last_uuid,
|
||||
"message": message,
|
||||
}
|
||||
lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
last_uuid = uid
|
||||
return "\n".join(lines) + "\n" if lines else ""
|
||||
|
||||
|
||||
_COMPACTION_TIMEOUT_SECONDS = 60
|
||||
_TRUNCATION_TIMEOUT_SECONDS = 30
|
||||
|
||||
|
||||
async def _run_compression(
|
||||
messages: list[dict],
|
||||
model: str,
|
||||
log_prefix: str,
|
||||
) -> CompressResult:
|
||||
"""Run LLM-based compression with truncation fallback.
|
||||
|
||||
Uses the shared OpenAI client from ``get_openai_client()``.
|
||||
If no client is configured or the LLM call fails, falls back to
|
||||
truncation-based compression which drops older messages without
|
||||
summarization.
|
||||
|
||||
A 60-second timeout prevents a hung LLM call from blocking the
|
||||
retry path indefinitely. The truncation fallback also has a
|
||||
30-second timeout to guard against slow tokenization on very large
|
||||
transcripts.
|
||||
"""
|
||||
client = get_openai_client()
|
||||
if client is None:
|
||||
logger.warning("%s No OpenAI client configured, using truncation", log_prefix)
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=None),
|
||||
timeout=_TRUNCATION_TIMEOUT_SECONDS,
|
||||
)
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=client),
|
||||
timeout=_COMPACTION_TIMEOUT_SECONDS,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=None),
|
||||
timeout=_TRUNCATION_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
async def compact_transcript(
|
||||
content: str,
|
||||
*,
|
||||
model: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> str | None:
|
||||
"""Compact an oversized JSONL transcript using LLM summarization.
|
||||
|
||||
Converts transcript entries to plain messages, runs ``compress_context``
|
||||
(the same compressor used for pre-query history), and rebuilds JSONL.
|
||||
|
||||
Structured content (``tool_use`` blocks, ``tool_result`` nesting, images)
|
||||
is flattened to plain text for compression. This matches the fidelity of
|
||||
the Plan C (DB compression) fallback path, where
|
||||
``_format_conversation_context`` similarly renders tool calls as
|
||||
``You called tool: name(args)`` and results as ``Tool result: ...``.
|
||||
Neither path preserves structured API content blocks — the compacted
|
||||
context serves as text history for the LLM, which creates proper
|
||||
structured tool calls going forward.
|
||||
|
||||
Images are per-turn attachments loaded from workspace storage by file ID
|
||||
(via ``_prepare_file_attachments``), not part of the conversation history.
|
||||
They are re-attached each turn and are unaffected by compaction.
|
||||
|
||||
Returns the compacted JSONL string, or ``None`` on failure.
|
||||
|
||||
See also:
|
||||
``_compress_messages`` in ``service.py`` — compresses ``ChatMessage``
|
||||
lists for pre-query DB history. Both share ``compress_context()``
|
||||
but operate on different input formats (JSONL transcript entries
|
||||
here vs. ChatMessage dicts there).
|
||||
"""
|
||||
messages = _transcript_to_messages(content)
|
||||
if len(messages) < 2:
|
||||
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
|
||||
return None
|
||||
try:
|
||||
result = await _run_compression(messages, model, log_prefix)
|
||||
if not result.was_compacted:
|
||||
# Compressor says it's within budget, but the SDK rejected it.
|
||||
# Return None so the caller falls through to DB fallback.
|
||||
logger.warning(
|
||||
"%s Compressor reports within budget but SDK rejected — "
|
||||
"signalling failure",
|
||||
log_prefix,
|
||||
)
|
||||
return None
|
||||
logger.info(
|
||||
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
|
||||
log_prefix,
|
||||
result.original_token_count,
|
||||
result.token_count,
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
compacted = _messages_to_transcript(result.messages)
|
||||
if not validate_transcript(compacted):
|
||||
logger.warning("%s Compacted transcript failed validation", log_prefix)
|
||||
return None
|
||||
return compacted
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -68,7 +68,7 @@ class TranscriptBuilder:
|
||||
type=entry_type,
|
||||
uuid=data.get("uuid") or str(uuid4()),
|
||||
parentUuid=data.get("parentUuid"),
|
||||
isCompactSummary=data.get("isCompactSummary") or None,
|
||||
isCompactSummary=data.get("isCompactSummary"),
|
||||
message=data.get("message", {}),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Unit tests for JSONL transcript management utilities."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -9,9 +10,7 @@ from backend.util import json
|
||||
|
||||
from .transcript import (
|
||||
STRIPPABLE_TYPES,
|
||||
_cli_project_dir,
|
||||
delete_transcript,
|
||||
read_cli_session_file,
|
||||
read_compacted_entries,
|
||||
strip_progress_entries,
|
||||
validate_transcript,
|
||||
@@ -292,85 +291,6 @@ class TestStripProgressEntries:
|
||||
assert asst_entry["parentUuid"] == "u1" # reparented
|
||||
|
||||
|
||||
# --- read_cli_session_file ---
|
||||
|
||||
|
||||
class TestReadCliSessionFile:
|
||||
def test_no_matching_files_returns_none(self, tmp_path, monkeypatch):
|
||||
"""read_cli_session_file returns None when no .jsonl files exist."""
|
||||
# Create a project dir with no jsonl files
|
||||
project_dir = tmp_path / "projects" / "encoded-cwd"
|
||||
project_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._cli_project_dir",
|
||||
lambda sdk_cwd: str(project_dir),
|
||||
)
|
||||
assert read_cli_session_file("/fake/cwd") is None
|
||||
|
||||
def test_one_jsonl_file_returns_content(self, tmp_path, monkeypatch):
|
||||
"""read_cli_session_file returns the content of a single .jsonl file."""
|
||||
project_dir = tmp_path / "projects" / "encoded-cwd"
|
||||
project_dir.mkdir(parents=True)
|
||||
jsonl_file = project_dir / "session.jsonl"
|
||||
jsonl_file.write_text("line1\nline2\n")
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._cli_project_dir",
|
||||
lambda sdk_cwd: str(project_dir),
|
||||
)
|
||||
result = read_cli_session_file("/fake/cwd")
|
||||
assert result == "line1\nline2\n"
|
||||
|
||||
def test_symlink_escaping_project_dir_is_skipped(self, tmp_path, monkeypatch):
|
||||
"""read_cli_session_file skips symlinks that escape the project dir."""
|
||||
project_dir = tmp_path / "projects" / "encoded-cwd"
|
||||
project_dir.mkdir(parents=True)
|
||||
|
||||
# Create a file outside the project dir
|
||||
outside = tmp_path / "outside"
|
||||
outside.mkdir()
|
||||
outside_file = outside / "evil.jsonl"
|
||||
outside_file.write_text("should not be read\n")
|
||||
|
||||
# Symlink from inside project_dir to outside file
|
||||
symlink = project_dir / "evil.jsonl"
|
||||
symlink.symlink_to(outside_file)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._cli_project_dir",
|
||||
lambda sdk_cwd: str(project_dir),
|
||||
)
|
||||
# The symlink target resolves outside project_dir, so it should be skipped
|
||||
result = read_cli_session_file("/fake/cwd")
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- _cli_project_dir ---
|
||||
|
||||
|
||||
class TestCliProjectDir:
|
||||
def test_returns_none_for_path_traversal(self, tmp_path, monkeypatch):
|
||||
"""_cli_project_dir returns None when the project dir symlink escapes projects base."""
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir()
|
||||
projects_dir = config_dir / "projects"
|
||||
projects_dir.mkdir()
|
||||
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
# Create a symlink inside projects/ that points outside of it.
|
||||
# _cli_project_dir encodes the cwd as all-alnum-hyphens, so use a
|
||||
# cwd whose encoded form matches the symlink name we create.
|
||||
evil_target = tmp_path / "escaped"
|
||||
evil_target.mkdir()
|
||||
|
||||
# The encoded form of "/evil/cwd" is "-evil-cwd"
|
||||
symlink_path = projects_dir / "-evil-cwd"
|
||||
symlink_path.symlink_to(evil_target)
|
||||
|
||||
result = _cli_project_dir("/evil/cwd")
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- delete_transcript ---
|
||||
|
||||
|
||||
@@ -382,7 +302,7 @@ class TestDeleteTranscript:
|
||||
mock_storage.delete = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -402,7 +322,7 @@ class TestDeleteTranscript:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -420,7 +340,7 @@ class TestDeleteTranscript:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -897,3 +817,386 @@ class TestCompactionFlowIntegration:
|
||||
output2 = builder2.to_jsonl()
|
||||
lines2 = [json.loads(line) for line in output2.strip().split("\n")]
|
||||
assert lines2[-1]["parentUuid"] == "a2"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_compression (direct tests for the 3 code paths)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunCompression:
|
||||
"""Direct tests for ``_run_compression`` covering all 3 code paths.
|
||||
|
||||
Paths:
|
||||
(a) No OpenAI client configured → truncation fallback immediately.
|
||||
(b) LLM success → returns LLM-compressed result.
|
||||
(c) LLM call raises → truncation fallback.
|
||||
"""
|
||||
|
||||
def _make_compress_result(self, was_compacted: bool, msgs=None):
|
||||
"""Build a minimal CompressResult-like object."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
return SimpleNamespace(
|
||||
was_compacted=was_compacted,
|
||||
messages=msgs or [{"role": "user", "content": "summary"}],
|
||||
original_token_count=500,
|
||||
token_count=100 if was_compacted else 500,
|
||||
messages_summarized=2 if was_compacted else 0,
|
||||
messages_dropped=0,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_client_uses_truncation(self):
|
||||
"""Path (a): ``get_openai_client()`` returns None → truncation only."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated"}]
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=truncation_result,
|
||||
) as mock_compress,
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "hello"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called with client=None (truncation mode)
|
||||
call_kwargs = mock_compress.call_args
|
||||
assert (
|
||||
call_kwargs.kwargs.get("client") is None
|
||||
or (call_kwargs.args and call_kwargs.args[2] is None)
|
||||
or mock_compress.call_args[1].get("client") is None
|
||||
)
|
||||
assert result is truncation_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_success_returns_llm_result(self):
|
||||
"""Path (b): ``get_openai_client()`` returns a client → LLM compresses."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
llm_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "LLM summary"}]
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=llm_result,
|
||||
) as mock_compress,
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "long conversation"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called with the real client
|
||||
assert mock_compress.called
|
||||
assert result is llm_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_failure_falls_back_to_truncation(self):
|
||||
"""Path (c): LLM call raises → truncation fallback used instead."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated fallback"}]
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
call_count = [0]
|
||||
|
||||
async def _compress_side_effect(**kwargs):
|
||||
call_count[0] += 1
|
||||
if kwargs.get("client") is not None:
|
||||
raise RuntimeError("LLM timeout")
|
||||
return truncation_result
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
side_effect=_compress_side_effect,
|
||||
),
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "long conversation"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called twice: once for LLM (raises), once for truncation
|
||||
assert call_count[0] == 2
|
||||
assert result is truncation_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_timeout_falls_back_to_truncation(self):
|
||||
"""Path (d): LLM call exceeds timeout → truncation fallback used."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated after timeout"}]
|
||||
)
|
||||
call_count = [0]
|
||||
|
||||
async def _compress_side_effect(*, messages, model, client):
|
||||
call_count[0] += 1
|
||||
if client is not None:
|
||||
# Simulate a hang that exceeds the timeout
|
||||
await asyncio.sleep(9999)
|
||||
return truncation_result
|
||||
|
||||
fake_client = MagicMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=fake_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
side_effect=_compress_side_effect,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._COMPACTION_TIMEOUT_SECONDS",
|
||||
0.05,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._TRUNCATION_TIMEOUT_SECONDS",
|
||||
5,
|
||||
),
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "long conversation"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called twice: once for LLM (times out), once truncation
|
||||
assert call_count[0] == 2
|
||||
assert result is truncation_result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cleanup_stale_project_dirs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanupStaleProjectDirs:
|
||||
"""Tests for cleanup_stale_project_dirs (disk leak prevention)."""
|
||||
|
||||
def test_removes_old_copilot_dirs(self, tmp_path, monkeypatch):
|
||||
"""Directories matching copilot pattern older than threshold are removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
# Create a stale dir
|
||||
stale = projects_dir / "-tmp-copilot-old-session"
|
||||
stale.mkdir()
|
||||
# Set mtime to past the threshold
|
||||
import time
|
||||
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
os.utime(stale, (old_time, old_time))
|
||||
|
||||
# Create a fresh dir
|
||||
fresh = projects_dir / "-tmp-copilot-new-session"
|
||||
fresh.mkdir()
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 1
|
||||
assert not stale.exists()
|
||||
assert fresh.exists()
|
||||
|
||||
def test_ignores_non_copilot_dirs(self, tmp_path, monkeypatch):
|
||||
"""Directories not matching copilot pattern are left alone."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
# Non-copilot dir that's old
|
||||
import time
|
||||
|
||||
other = projects_dir / "some-other-project"
|
||||
other.mkdir()
|
||||
old_time = time.time() - 999999
|
||||
os.utime(other, (old_time, old_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
assert other.exists()
|
||||
|
||||
def test_ttl_boundary_not_removed(self, tmp_path, monkeypatch):
|
||||
"""A directory exactly at the TTL boundary should NOT be removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
import time
|
||||
|
||||
# Dir that's exactly at the TTL (age == threshold, not >) — should survive
|
||||
boundary = projects_dir / "-tmp-copilot-boundary"
|
||||
boundary.mkdir()
|
||||
boundary_time = time.time() - _STALE_PROJECT_DIR_SECONDS + 1
|
||||
os.utime(boundary, (boundary_time, boundary_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
assert boundary.exists()
|
||||
|
||||
def test_skips_non_directory_entries(self, tmp_path, monkeypatch):
|
||||
"""Regular files matching the copilot pattern are not removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
import time
|
||||
|
||||
# Create a regular FILE (not a dir) with the copilot pattern name
|
||||
stale_file = projects_dir / "-tmp-copilot-stale-file"
|
||||
stale_file.write_text("not a dir")
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
os.utime(stale_file, (old_time, old_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
assert stale_file.exists()
|
||||
|
||||
def test_missing_base_dir_returns_zero(self, tmp_path, monkeypatch):
|
||||
"""If the projects base directory doesn't exist, return 0 gracefully."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
|
||||
nonexistent = str(tmp_path / "does-not-exist" / "projects")
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: nonexistent,
|
||||
)
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
|
||||
def test_scoped_removes_only_target_dir(self, tmp_path, monkeypatch):
|
||||
"""When encoded_cwd is supplied only that directory is swept."""
|
||||
import time
|
||||
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
|
||||
# Two stale copilot dirs
|
||||
target = projects_dir / "-tmp-copilot-session-abc"
|
||||
target.mkdir()
|
||||
os.utime(target, (old_time, old_time))
|
||||
|
||||
other = projects_dir / "-tmp-copilot-session-xyz"
|
||||
other.mkdir()
|
||||
os.utime(other, (old_time, old_time))
|
||||
|
||||
# Only the target dir should be removed
|
||||
removed = cleanup_stale_project_dirs(encoded_cwd="-tmp-copilot-session-abc")
|
||||
assert removed == 1
|
||||
assert not target.exists()
|
||||
assert other.exists() # untouched — not the current session
|
||||
|
||||
def test_scoped_fresh_dir_not_removed(self, tmp_path, monkeypatch):
|
||||
"""Scoped sweep leaves a fresh directory alone."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
fresh = projects_dir / "-tmp-copilot-session-new"
|
||||
fresh.mkdir()
|
||||
# mtime is now — well within TTL
|
||||
|
||||
removed = cleanup_stale_project_dirs(encoded_cwd="-tmp-copilot-session-new")
|
||||
assert removed == 0
|
||||
assert fresh.exists()
|
||||
|
||||
def test_scoped_non_copilot_dir_not_removed(self, tmp_path, monkeypatch):
|
||||
"""Scoped sweep refuses to remove a non-copilot directory."""
|
||||
import time
|
||||
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
non_copilot = projects_dir / "some-other-project"
|
||||
non_copilot.mkdir()
|
||||
os.utime(non_copilot, (old_time, old_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs(encoded_cwd="some-other-project")
|
||||
assert removed == 0
|
||||
assert non_copilot.exists()
|
||||
|
||||
@@ -4,11 +4,12 @@ These tests verify the complete copilot flow using dummy implementations
|
||||
for agent generator and SDK service, allowing automated testing without
|
||||
external LLM calls.
|
||||
|
||||
Enable test mode with COPILOT_TEST_MODE=true environment variable.
|
||||
Enable test mode with CHAT_TEST_MODE=true environment variable (or in .env).
|
||||
|
||||
Note: StreamFinish is NOT emitted by the dummy service — it is published
|
||||
by mark_session_completed in the processor layer. These tests only cover
|
||||
the service-level streaming output (StreamStart + StreamTextDelta).
|
||||
The dummy service emits the full AI SDK protocol event sequence:
|
||||
StreamStart → StreamStartStep → StreamTextStart → StreamTextDelta(s) →
|
||||
StreamTextEnd → StreamFinishStep → StreamFinish.
|
||||
The processor skips StreamFinish and publishes its own via mark_session_completed.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -20,9 +21,14 @@ import pytest
|
||||
from backend.copilot.model import ChatMessage, ChatSession, upsert_chat_session
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
)
|
||||
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
|
||||
@@ -30,9 +36,9 @@ from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_test_mode():
|
||||
"""Enable test mode for all tests in this module."""
|
||||
os.environ["COPILOT_TEST_MODE"] = "true"
|
||||
os.environ["CHAT_TEST_MODE"] = "true"
|
||||
yield
|
||||
os.environ.pop("COPILOT_TEST_MODE", None)
|
||||
os.environ.pop("CHAT_TEST_MODE", None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -110,9 +116,14 @@ async def test_streaming_event_types():
|
||||
):
|
||||
event_types.add(type(event).__name__)
|
||||
|
||||
# Required event types (StreamFinish is published by processor, not service)
|
||||
# Required event types for full AI SDK protocol
|
||||
assert "StreamStart" in event_types, "Missing StreamStart"
|
||||
assert "StreamStartStep" in event_types, "Missing StreamStartStep"
|
||||
assert "StreamTextStart" in event_types, "Missing StreamTextStart"
|
||||
assert "StreamTextDelta" in event_types, "Missing StreamTextDelta"
|
||||
assert "StreamTextEnd" in event_types, "Missing StreamTextEnd"
|
||||
assert "StreamFinishStep" in event_types, "Missing StreamFinishStep"
|
||||
assert "StreamFinish" in event_types, "Missing StreamFinish"
|
||||
|
||||
print(f"✅ Event types: {sorted(event_types)}")
|
||||
|
||||
@@ -175,16 +186,17 @@ async def test_streaming_heartbeat_timing():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling():
|
||||
"""Test that errors are properly formatted and sent."""
|
||||
# This would require a dummy that can trigger errors
|
||||
# For now, just verify error event structure
|
||||
|
||||
"""Test that error events have correct SSE structure."""
|
||||
error = StreamError(errorText="Test error", code="test_error")
|
||||
assert error.errorText == "Test error"
|
||||
assert error.code == "test_error"
|
||||
assert str(error.type.value) in ["error", "error"]
|
||||
|
||||
print("✅ Error structure verified")
|
||||
# Verify to_sse() strips code (AI SDK protocol compliance)
|
||||
sse = error.to_sse()
|
||||
assert '"errorText"' in sse
|
||||
assert '"code"' not in sse, "to_sse() must strip code field for AI SDK"
|
||||
|
||||
print("✅ Error structure verified (code stripped in SSE)")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -326,20 +338,85 @@ async def test_stream_completeness():
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Check for required events (StreamFinish is published by processor)
|
||||
has_start = any(isinstance(e, StreamStart) for e in events)
|
||||
has_text = any(isinstance(e, StreamTextDelta) for e in events)
|
||||
|
||||
assert has_start, "Stream must include StreamStart"
|
||||
assert has_text, "Stream must include text deltas"
|
||||
# Check for all required event types
|
||||
assert any(isinstance(e, StreamStart) for e in events), "Missing StreamStart"
|
||||
assert any(
|
||||
isinstance(e, StreamStartStep) for e in events
|
||||
), "Missing StreamStartStep"
|
||||
assert any(
|
||||
isinstance(e, StreamTextStart) for e in events
|
||||
), "Missing StreamTextStart"
|
||||
assert any(
|
||||
isinstance(e, StreamTextDelta) for e in events
|
||||
), "Missing StreamTextDelta"
|
||||
assert any(isinstance(e, StreamTextEnd) for e in events), "Missing StreamTextEnd"
|
||||
assert any(
|
||||
isinstance(e, StreamFinishStep) for e in events
|
||||
), "Missing StreamFinishStep"
|
||||
assert any(isinstance(e, StreamFinish) for e in events), "Missing StreamFinish"
|
||||
|
||||
# Verify exactly one start
|
||||
start_count = sum(1 for e in events if isinstance(e, StreamStart))
|
||||
assert start_count == 1, f"Should have exactly 1 StreamStart, got {start_count}"
|
||||
|
||||
print(
|
||||
f"✅ Completeness: 1 start, {sum(1 for e in events if isinstance(e, StreamTextDelta))} text deltas"
|
||||
)
|
||||
print(f"✅ Completeness: {len(events)} events, full protocol sequence")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transient_error_shows_retryable():
|
||||
"""Test __test_transient_error__ yields partial text then retryable StreamError."""
|
||||
events = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-transient",
|
||||
message="please fail __test_transient_error__",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Should start with StreamStart
|
||||
assert isinstance(events[0], StreamStart)
|
||||
|
||||
# Should have some partial text before the error
|
||||
text_events = [e for e in events if isinstance(e, StreamTextDelta)]
|
||||
assert len(text_events) > 0, "Should stream partial text before error"
|
||||
|
||||
# Should end with StreamError
|
||||
error_events = [e for e in events if isinstance(e, StreamError)]
|
||||
assert len(error_events) == 1, "Should have exactly one StreamError"
|
||||
assert error_events[0].code == "transient_api_error"
|
||||
assert "connection interrupted" in error_events[0].errorText.lower()
|
||||
|
||||
print(f"✅ Transient error: {len(text_events)} partial deltas + retryable error")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fatal_error_not_retryable():
|
||||
"""Test __test_fatal_error__ yields StreamError without retryable code."""
|
||||
events = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-fatal",
|
||||
message="__test_fatal_error__",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
assert isinstance(events[0], StreamStart)
|
||||
|
||||
# Should have StreamError with sdk_error code (not transient)
|
||||
error_events = [e for e in events if isinstance(e, StreamError)]
|
||||
assert len(error_events) == 1
|
||||
assert error_events[0].code == "sdk_error"
|
||||
assert "transient" not in error_events[0].code
|
||||
|
||||
# Should NOT have any text deltas (fatal errors fail immediately)
|
||||
text_events = [e for e in events if isinstance(e, StreamTextDelta)]
|
||||
assert len(text_events) == 0, "Fatal error should not stream any text"
|
||||
|
||||
print("✅ Fatal error: immediate error, no partial text")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -395,6 +472,8 @@ if __name__ == "__main__":
|
||||
asyncio.run(test_message_deduplication())
|
||||
asyncio.run(test_event_ordering())
|
||||
asyncio.run(test_stream_completeness())
|
||||
asyncio.run(test_transient_error_shows_retryable())
|
||||
asyncio.run(test_fatal_error_not_retryable())
|
||||
asyncio.run(test_text_delta_consistency())
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
@@ -12,6 +12,7 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .connect_integration import ConnectIntegrationTool
|
||||
from .continue_run_block import ContinueRunBlockTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
@@ -84,6 +85,7 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"browser_screenshot": BrowserScreenshotTool(),
|
||||
# Sandboxed code execution (bubblewrap)
|
||||
"bash_exec": BashExecTool(),
|
||||
"connect_integration": ConnectIntegrationTool(),
|
||||
# Persistent workspace tools (cloud storage, survives across sessions)
|
||||
# Feature request tools
|
||||
"search_feature_requests": SearchFeatureRequestsTool(),
|
||||
|
||||
@@ -22,13 +22,12 @@ class AddUnderstandingTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Capture and store information about the user's business context,
|
||||
workflows, pain points, and automation goals. Call this tool whenever the user
|
||||
shares information about their business. Each call incrementally adds to the
|
||||
existing understanding - you don't need to provide all fields at once.
|
||||
|
||||
Use this to build a comprehensive profile that helps recommend better agents
|
||||
and automations for the user's specific needs."""
|
||||
return (
|
||||
"Store user's business context, workflows, pain points, and automation goals. "
|
||||
"Call whenever the user shares business info. Each call incrementally merges "
|
||||
"with existing data — provide only the fields you have. "
|
||||
"Builds a profile that helps recommend better agents for the user's needs."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
|
||||
@@ -20,7 +20,9 @@ SSRF protection:
|
||||
|
||||
Requires:
|
||||
npm install -g agent-browser
|
||||
agent-browser install (downloads Chromium, one-time per machine)
|
||||
In Docker: system chromium package with AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium
|
||||
(set automatically — no `agent-browser install` needed).
|
||||
Locally: run `agent-browser install` to download Chromium.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -408,18 +410,11 @@ class BrowserNavigateTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Navigate to a URL using a real browser. Returns an accessibility "
|
||||
"tree snapshot listing the page's interactive elements with @ref IDs "
|
||||
"(e.g. @e3) that can be used with browser_act. "
|
||||
"Session persists — cookies and login state carry over between calls. "
|
||||
"Use this (with browser_act) for multi-step interaction: login flows, "
|
||||
"form filling, button clicks, or anything requiring page interaction. "
|
||||
"For plain static pages, prefer web_fetch — no browser overhead. "
|
||||
"For authenticated pages: navigate to the login page first, use browser_act "
|
||||
"to fill credentials and submit, then navigate to the target page. "
|
||||
"Note: for slow SPAs, the returned snapshot may reflect a partially-loaded "
|
||||
"state. If elements seem missing, use browser_act with action='wait' and a "
|
||||
"CSS selector or millisecond delay, then take a browser_screenshot to verify."
|
||||
"Navigate to a URL in a real browser. Returns accessibility tree with @ref IDs "
|
||||
"for browser_act. Session persists (cookies/auth carry over). "
|
||||
"For static pages, prefer web_fetch. "
|
||||
"For SPAs, elements may load late — use browser_act with wait + browser_screenshot to verify. "
|
||||
"For auth: navigate to login, fill creds and submit with browser_act, then navigate to target."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -429,13 +424,13 @@ class BrowserNavigateTool(BaseTool):
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The HTTP/HTTPS URL to navigate to.",
|
||||
"description": "HTTP/HTTPS URL to navigate to.",
|
||||
},
|
||||
"wait_for": {
|
||||
"type": "string",
|
||||
"enum": ["networkidle", "load", "domcontentloaded"],
|
||||
"default": "networkidle",
|
||||
"description": "When to consider navigation complete. Use 'networkidle' for SPAs (default).",
|
||||
"description": "Navigation completion strategy (default: networkidle).",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
@@ -554,14 +549,12 @@ class BrowserActTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Interact with the current browser page. Use @ref IDs from the "
|
||||
"snapshot (e.g. '@e3') to target elements. Returns an updated snapshot. "
|
||||
"Supported actions: click, dblclick, fill, type, scroll, hover, press, "
|
||||
"Interact with the current browser page using @ref IDs from the snapshot. "
|
||||
"Actions: click, dblclick, fill, type, scroll, hover, press, "
|
||||
"check, uncheck, select, wait, back, forward, reload. "
|
||||
"fill clears the field before typing; type appends without clearing. "
|
||||
"wait accepts a CSS selector (waits for element) or milliseconds string (e.g. '1000'). "
|
||||
"Example login flow: fill @e1 with email → fill @e2 with password → "
|
||||
"click @e3 (submit) → browser_navigate to the target page."
|
||||
"fill clears field first; type appends. "
|
||||
"wait accepts CSS selector or milliseconds (e.g. '1000'). "
|
||||
"Returns updated snapshot."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -587,30 +580,21 @@ class BrowserActTool(BaseTool):
|
||||
"forward",
|
||||
"reload",
|
||||
],
|
||||
"description": "The action to perform.",
|
||||
"description": "Action to perform.",
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Element to target. Use @ref from snapshot (e.g. '@e3'), "
|
||||
"a CSS selector, or a text description. "
|
||||
"Required for: click, dblclick, fill, type, hover, check, uncheck, select. "
|
||||
"For wait: a CSS selector to wait for, or milliseconds as a string (e.g. '1000')."
|
||||
),
|
||||
"description": "@ref ID (e.g. '@e3'), CSS selector, or text. Required for: click, dblclick, fill, type, hover, check, uncheck, select. For wait: CSS selector or milliseconds string (e.g. '1000').",
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"For fill/type: the text to enter. "
|
||||
"For press: key name (e.g. 'Enter', 'Tab', 'Control+a'). "
|
||||
"For select: the option value to select."
|
||||
),
|
||||
"description": "Text for fill/type, key for press (e.g. 'Enter'), option for select.",
|
||||
},
|
||||
"direction": {
|
||||
"type": "string",
|
||||
"enum": ["up", "down", "left", "right"],
|
||||
"default": "down",
|
||||
"description": "For scroll: direction to scroll.",
|
||||
"description": "Scroll direction (default: down).",
|
||||
},
|
||||
},
|
||||
"required": ["action"],
|
||||
@@ -757,12 +741,10 @@ class BrowserScreenshotTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Take a screenshot of the current browser page and save it to the workspace. "
|
||||
"IMPORTANT: After calling this tool, immediately call read_workspace_file "
|
||||
"with the returned file_id to display the image inline to the user — "
|
||||
"the screenshot is not visible until you do this. "
|
||||
"With annotate=true (default), @ref labels are overlaid on interactive "
|
||||
"elements, making it easy to see which @ref ID maps to which element on screen."
|
||||
"Screenshot the current browser page and save to workspace. "
|
||||
"annotate=true overlays @ref labels on elements. "
|
||||
"IMPORTANT: After calling, you MUST immediately call read_workspace_file with the "
|
||||
"returned file_id to display the image inline."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -773,12 +755,12 @@ class BrowserScreenshotTool(BaseTool):
|
||||
"annotate": {
|
||||
"type": "boolean",
|
||||
"default": True,
|
||||
"description": "Overlay @ref labels on interactive elements (default: true).",
|
||||
"description": "Overlay @ref labels (default: true).",
|
||||
},
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"default": "screenshot.png",
|
||||
"description": "Filename to save in the workspace.",
|
||||
"description": "Workspace filename (default: screenshot.png).",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -0,0 +1,351 @@
|
||||
"""Integration tests for agent-browser + system chromium.
|
||||
|
||||
These tests actually invoke the agent-browser binary via subprocess and require:
|
||||
- agent-browser installed (npm install -g agent-browser)
|
||||
- AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium (set in Docker)
|
||||
|
||||
Run with:
|
||||
poetry run test
|
||||
|
||||
Or to run only this file:
|
||||
poetry run pytest backend/copilot/tools/agent_browser_integration_test.py -v -p no:autogpt_platform
|
||||
|
||||
Skipped automatically when agent-browser binary is not found.
|
||||
Tests that hit external sites are marked ``integration`` and skipped by default
|
||||
in CI (use ``-m integration`` to include them).
|
||||
|
||||
Two test tiers:
|
||||
- CLI tests: call agent-browser subprocess directly (no backend imports needed)
|
||||
- Tool class tests: call BrowserNavigateTool/BrowserActTool._execute() directly
|
||||
with user_id=None (skips workspace/DB interactions — no Postgres/RabbitMQ needed)
|
||||
"""
|
||||
|
||||
import concurrent.futures
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.agent_browser import BrowserActTool, BrowserNavigateTool
|
||||
from backend.copilot.tools.models import (
|
||||
BrowserActResponse,
|
||||
BrowserNavigateResponse,
|
||||
ErrorResponse,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
shutil.which("agent-browser") is None,
|
||||
reason="agent-browser binary not found",
|
||||
)
|
||||
|
||||
_SESSION = "integration-test-session"
|
||||
|
||||
|
||||
def _agent_browser(
|
||||
*args: str, session: str = _SESSION, timeout: int = 30
|
||||
) -> tuple[int, str, str]:
|
||||
"""Run agent-browser for the given session, return (rc, stdout, stderr)."""
|
||||
result = subprocess.run(
|
||||
["agent-browser", "--session", session, "--session-name", session, *args],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
return result.returncode, result.stdout, result.stderr
|
||||
|
||||
|
||||
def _close_session(session: str, timeout: int = 5) -> None:
|
||||
"""Best-effort close for a browser session; never raises on failure."""
|
||||
try:
|
||||
subprocess.run(
|
||||
["agent-browser", "--session", session, "--session-name", session, "close"],
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _teardown():
|
||||
"""Close the shared test session after each test (best-effort)."""
|
||||
yield
|
||||
_close_session(_SESSION)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_chromium_executable_env_is_set():
|
||||
"""AGENT_BROWSER_EXECUTABLE_PATH must be set and point to an executable binary."""
|
||||
exe = os.environ.get("AGENT_BROWSER_EXECUTABLE_PATH", "")
|
||||
assert exe, "AGENT_BROWSER_EXECUTABLE_PATH is not set"
|
||||
assert os.path.isfile(exe), f"Chromium binary not found at {exe}"
|
||||
assert os.access(exe, os.X_OK), f"Chromium binary at {exe} is not executable"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_navigate_returns_success():
|
||||
"""agent-browser can open a public URL using system chromium."""
|
||||
rc, _, stderr = _agent_browser("open", "https://example.com")
|
||||
assert rc == 0, f"open failed (rc={rc}): {stderr}"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_title_after_navigate():
|
||||
"""get title returns the page title after navigation."""
|
||||
rc, _, _ = _agent_browser("open", "https://example.com")
|
||||
assert rc == 0
|
||||
|
||||
rc, stdout, stderr = _agent_browser("get", "title", timeout=10)
|
||||
assert rc == 0, f"get title failed: {stderr}"
|
||||
assert "example" in stdout.lower()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_url_after_navigate():
|
||||
"""get url returns the navigated URL."""
|
||||
rc, _, _ = _agent_browser("open", "https://example.com")
|
||||
assert rc == 0
|
||||
|
||||
rc, stdout, stderr = _agent_browser("get", "url", timeout=10)
|
||||
assert rc == 0, f"get url failed: {stderr}"
|
||||
assert urlparse(stdout.strip()).netloc == "example.com"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_snapshot_returns_interactive_elements():
|
||||
"""snapshot -i -c lists interactive elements on the page."""
|
||||
rc, _, _ = _agent_browser("open", "https://example.com")
|
||||
assert rc == 0
|
||||
|
||||
rc, stdout, stderr = _agent_browser("snapshot", "-i", "-c", timeout=15)
|
||||
assert rc == 0, f"snapshot failed: {stderr}"
|
||||
assert len(stdout.strip()) > 0, "snapshot returned empty output"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_screenshot_produces_valid_png():
|
||||
"""screenshot saves a non-empty, valid PNG file."""
|
||||
rc, _, _ = _agent_browser("open", "https://example.com")
|
||||
assert rc == 0
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
tmp = f.name
|
||||
try:
|
||||
rc, _, stderr = _agent_browser("screenshot", tmp, timeout=15)
|
||||
assert rc == 0, f"screenshot failed: {stderr}"
|
||||
size = os.path.getsize(tmp)
|
||||
assert size > 1000, f"PNG too small ({size} bytes) — likely blank or corrupt"
|
||||
with open(tmp, "rb") as f:
|
||||
assert f.read(4) == b"\x89PNG", "Output is not a valid PNG"
|
||||
finally:
|
||||
os.unlink(tmp)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_scroll_down():
|
||||
"""scroll down succeeds without error."""
|
||||
rc, _, _ = _agent_browser("open", "https://example.com")
|
||||
assert rc == 0
|
||||
|
||||
rc, _, stderr = _agent_browser("scroll", "down", timeout=10)
|
||||
assert rc == 0, f"scroll failed: {stderr}"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_fill_form_field():
|
||||
"""fill writes text into an input field."""
|
||||
rc, _, _ = _agent_browser("open", "https://httpbin.org/forms/post")
|
||||
assert rc == 0
|
||||
|
||||
rc, _, stderr = _agent_browser(
|
||||
"fill", "input[name=custname]", "IntegrationTestUser", timeout=10
|
||||
)
|
||||
assert rc == 0, f"fill failed: {stderr}"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_concurrent_independent_sessions():
|
||||
"""Two independent sessions can navigate in parallel without interference."""
|
||||
session_a = "integration-concurrent-a"
|
||||
session_b = "integration-concurrent-b"
|
||||
|
||||
try:
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
|
||||
fut_a = pool.submit(
|
||||
_agent_browser, "open", "https://example.com", session=session_a
|
||||
)
|
||||
fut_b = pool.submit(
|
||||
_agent_browser, "open", "https://httpbin.org/html", session=session_b
|
||||
)
|
||||
rc_a, _, err_a = fut_a.result(timeout=40)
|
||||
rc_b, _, err_b = fut_b.result(timeout=40)
|
||||
assert rc_a == 0, f"session_a open failed: {err_a}"
|
||||
assert rc_b == 0, f"session_b open failed: {err_b}"
|
||||
|
||||
rc_ua, url_a, err_ua = _agent_browser(
|
||||
"get", "url", session=session_a, timeout=10
|
||||
)
|
||||
rc_ub, url_b, err_ub = _agent_browser(
|
||||
"get", "url", session=session_b, timeout=10
|
||||
)
|
||||
assert rc_ua == 0, f"session_a get url failed: {err_ua}"
|
||||
assert rc_ub == 0, f"session_b get url failed: {err_ub}"
|
||||
assert urlparse(url_a.strip()).netloc == "example.com"
|
||||
assert urlparse(url_b.strip()).netloc == "httpbin.org"
|
||||
finally:
|
||||
_close_session(session_a)
|
||||
_close_session(session_b)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_close_session():
|
||||
"""close shuts down the browser daemon cleanly."""
|
||||
rc, _, _ = _agent_browser("open", "https://example.com")
|
||||
assert rc == 0
|
||||
|
||||
rc, _, stderr = _agent_browser("close", timeout=10)
|
||||
assert rc == 0, f"close failed: {stderr}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Python tool class integration tests
|
||||
#
|
||||
# These tests exercise the actual BrowserNavigateTool / BrowserActTool Python
|
||||
# classes (not just the CLI binary) to verify the full call path — URL
|
||||
# validation, subprocess dispatch, response parsing — works with system
|
||||
# chromium. user_id=None skips workspace/DB interactions so no Postgres or
|
||||
# RabbitMQ is needed.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_TOOL_SESSION_ID = "integration-tool-test-session"
|
||||
_TEST_SESSION = ChatSession(
|
||||
session_id=_TOOL_SESSION_ID,
|
||||
user_id="test-user",
|
||||
messages=[],
|
||||
usage=[],
|
||||
started_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=False)
|
||||
def _close_tool_session():
|
||||
"""Tear down the tool-test browser session after each tool test."""
|
||||
yield
|
||||
_close_session(_TOOL_SESSION_ID)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_navigate_returns_response(_close_tool_session):
|
||||
"""BrowserNavigateTool._execute returns a BrowserNavigateResponse with real content."""
|
||||
tool = BrowserNavigateTool()
|
||||
resp = await tool._execute(
|
||||
user_id=None, session=_TEST_SESSION, url="https://example.com"
|
||||
)
|
||||
assert isinstance(
|
||||
resp, BrowserNavigateResponse
|
||||
), f"Expected BrowserNavigateResponse, got: {resp}"
|
||||
assert urlparse(resp.url).netloc == "example.com"
|
||||
assert resp.title, "Expected non-empty page title"
|
||||
assert resp.snapshot, "Expected non-empty accessibility snapshot"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"ssrf_url",
|
||||
[
|
||||
"http://169.254.169.254/", # AWS/GCP/Azure metadata endpoint
|
||||
"http://127.0.0.1/", # IPv4 loopback
|
||||
"http://10.0.0.1/", # RFC-1918 private range
|
||||
"http://[::1]/", # IPv6 loopback
|
||||
"http://0.0.0.0/", # Wildcard / INADDR_ANY
|
||||
],
|
||||
)
|
||||
async def test_tool_navigate_blocked_url(ssrf_url: str, _close_tool_session):
|
||||
"""BrowserNavigateTool._execute rejects internal/private URLs (SSRF guard)."""
|
||||
tool = BrowserNavigateTool()
|
||||
resp = await tool._execute(user_id=None, session=_TEST_SESSION, url=ssrf_url)
|
||||
assert isinstance(
|
||||
resp, ErrorResponse
|
||||
), f"Expected ErrorResponse for SSRF URL {ssrf_url!r}, got: {resp}"
|
||||
assert resp.error == "blocked_url"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_navigate_missing_url(_close_tool_session):
|
||||
"""BrowserNavigateTool._execute returns an error when url is empty."""
|
||||
tool = BrowserNavigateTool()
|
||||
resp = await tool._execute(user_id=None, session=_TEST_SESSION, url="")
|
||||
assert isinstance(resp, ErrorResponse)
|
||||
assert resp.error == "missing_url"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_act_scroll(_close_tool_session):
|
||||
"""BrowserActTool._execute can scroll after a navigate."""
|
||||
nav = BrowserNavigateTool()
|
||||
nav_resp = await nav._execute(
|
||||
user_id=None, session=_TEST_SESSION, url="https://example.com"
|
||||
)
|
||||
assert isinstance(nav_resp, BrowserNavigateResponse)
|
||||
|
||||
act = BrowserActTool()
|
||||
resp = await act._execute(
|
||||
user_id=None, session=_TEST_SESSION, action="scroll", direction="down"
|
||||
)
|
||||
assert isinstance(
|
||||
resp, BrowserActResponse
|
||||
), f"Expected BrowserActResponse, got: {resp}"
|
||||
assert resp.action == "scroll"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_act_fill_and_click(_close_tool_session):
|
||||
"""BrowserActTool._execute can fill a form field."""
|
||||
nav = BrowserNavigateTool()
|
||||
nav_resp = await nav._execute(
|
||||
user_id=None, session=_TEST_SESSION, url="https://httpbin.org/forms/post"
|
||||
)
|
||||
assert isinstance(nav_resp, BrowserNavigateResponse)
|
||||
|
||||
act = BrowserActTool()
|
||||
resp = await act._execute(
|
||||
user_id=None,
|
||||
session=_TEST_SESSION,
|
||||
action="fill",
|
||||
target="input[name=custname]",
|
||||
value="ToolIntegrationTest",
|
||||
)
|
||||
assert isinstance(resp, BrowserActResponse), f"fill failed: {resp}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_act_missing_action(_close_tool_session):
|
||||
"""BrowserActTool._execute returns an error when action is missing."""
|
||||
act = BrowserActTool()
|
||||
resp = await act._execute(user_id=None, session=_TEST_SESSION, action="")
|
||||
assert isinstance(resp, ErrorResponse)
|
||||
assert resp.error == "missing_action"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_act_missing_target(_close_tool_session):
|
||||
"""BrowserActTool._execute returns an error when click target is missing."""
|
||||
act = BrowserActTool()
|
||||
resp = await act._execute(
|
||||
user_id=None, session=_TEST_SESSION, action="click", target=""
|
||||
)
|
||||
assert isinstance(resp, ErrorResponse)
|
||||
assert resp.error == "missing_target"
|
||||
@@ -7,6 +7,7 @@ from typing import Any
|
||||
from .helpers import (
|
||||
AGENT_EXECUTOR_BLOCK_ID,
|
||||
MCP_TOOL_BLOCK_ID,
|
||||
TOOL_ORCHESTRATOR_BLOCK_ID,
|
||||
AgentDict,
|
||||
are_types_compatible,
|
||||
generate_uuid,
|
||||
@@ -30,6 +31,14 @@ _GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
|
||||
_GMAIL_SEND_BLOCK_ID = "6c27abc2-e51d-499e-a85f-5a0041ba94f0"
|
||||
_TEXT_REPLACE_BLOCK_ID = "7e7c87ab-3469-4bcc-9abe-67705091b713"
|
||||
|
||||
# Defaults applied to OrchestratorBlock nodes by the fixer.
|
||||
_SDM_DEFAULTS: dict[str, int | bool] = {
|
||||
"agent_mode_max_iterations": 10,
|
||||
"conversation_compaction": True,
|
||||
"retry": 3,
|
||||
"multiple_tool_calls": False,
|
||||
}
|
||||
|
||||
|
||||
class AgentFixer:
|
||||
"""
|
||||
@@ -1630,6 +1639,43 @@ class AgentFixer:
|
||||
|
||||
return agent
|
||||
|
||||
def fix_orchestrator_blocks(self, agent: AgentDict) -> AgentDict:
|
||||
"""Fix OrchestratorBlock nodes to ensure agent-mode defaults.
|
||||
|
||||
Ensures:
|
||||
1. ``agent_mode_max_iterations`` defaults to ``10`` (bounded agent mode)
|
||||
2. ``conversation_compaction`` defaults to ``True``
|
||||
3. ``retry`` defaults to ``3``
|
||||
4. ``multiple_tool_calls`` defaults to ``False``
|
||||
|
||||
Args:
|
||||
agent: The agent dictionary to fix
|
||||
|
||||
Returns:
|
||||
The fixed agent dictionary
|
||||
"""
|
||||
nodes = agent.get("nodes", [])
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") != TOOL_ORCHESTRATOR_BLOCK_ID:
|
||||
continue
|
||||
|
||||
node_id = node.get("id", "unknown")
|
||||
input_default = node.get("input_default")
|
||||
if not isinstance(input_default, dict):
|
||||
input_default = {}
|
||||
node["input_default"] = input_default
|
||||
|
||||
for field, default_value in _SDM_DEFAULTS.items():
|
||||
if field not in input_default or input_default[field] is None:
|
||||
input_default[field] = default_value
|
||||
self.add_fix_log(
|
||||
f"OrchestratorBlock {node_id}: "
|
||||
f"Set {field}={default_value!r}"
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
def fix_dynamic_block_sink_names(self, agent: AgentDict) -> AgentDict:
|
||||
"""Fix links that use _#_ notation for dynamic block sink names.
|
||||
|
||||
@@ -1717,6 +1763,9 @@ class AgentFixer:
|
||||
# Apply fixes for MCPToolBlock nodes
|
||||
agent = self.fix_mcp_tool_blocks(agent)
|
||||
|
||||
# Apply fixes for OrchestratorBlock nodes (agent-mode defaults)
|
||||
agent = self.fix_orchestrator_blocks(agent)
|
||||
|
||||
# Apply fixes for AgentExecutorBlock nodes (sub-agents)
|
||||
if library_agents:
|
||||
agent = self.fix_agent_executor_blocks(agent, library_agents)
|
||||
|
||||
@@ -12,6 +12,7 @@ __all__ = [
|
||||
"AGENT_OUTPUT_BLOCK_ID",
|
||||
"AgentDict",
|
||||
"MCP_TOOL_BLOCK_ID",
|
||||
"TOOL_ORCHESTRATOR_BLOCK_ID",
|
||||
"UUID_REGEX",
|
||||
"are_types_compatible",
|
||||
"generate_uuid",
|
||||
@@ -33,6 +34,7 @@ UUID_REGEX = re.compile(r"^" + UUID_RE_STR + r"$")
|
||||
|
||||
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
|
||||
MCP_TOOL_BLOCK_ID = "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
|
||||
TOOL_ORCHESTRATOR_BLOCK_ID = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
|
||||
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from .helpers import (
|
||||
AGENT_INPUT_BLOCK_ID,
|
||||
AGENT_OUTPUT_BLOCK_ID,
|
||||
MCP_TOOL_BLOCK_ID,
|
||||
TOOL_ORCHESTRATOR_BLOCK_ID,
|
||||
AgentDict,
|
||||
are_types_compatible,
|
||||
get_defined_property_type,
|
||||
@@ -181,15 +182,23 @@ class AgentValidator:
|
||||
|
||||
return valid
|
||||
|
||||
def _build_node_lookup(self, agent: AgentDict) -> dict[str, dict[str, Any]]:
|
||||
"""Build a node-id → node dict from the agent's nodes."""
|
||||
return {node.get("id", ""): node for node in agent.get("nodes", [])}
|
||||
|
||||
def validate_data_type_compatibility(
|
||||
self, agent: AgentDict, blocks: list[dict[str, Any]]
|
||||
self,
|
||||
agent: AgentDict,
|
||||
blocks: list[dict[str, Any]],
|
||||
node_lookup: dict[str, dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that linked data types are compatible between source and sink.
|
||||
Returns True if all data types are compatible, False otherwise.
|
||||
"""
|
||||
valid = True
|
||||
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
|
||||
if node_lookup is None:
|
||||
node_lookup = self._build_node_lookup(agent)
|
||||
block_lookup = {block.get("id", ""): block for block in blocks}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
@@ -209,8 +218,8 @@ class AgentValidator:
|
||||
valid = False
|
||||
continue
|
||||
|
||||
source_node = node_lookup.get(source_id, "")
|
||||
sink_node = node_lookup.get(sink_id, "")
|
||||
source_node = node_lookup.get(source_id)
|
||||
sink_node = node_lookup.get(sink_id)
|
||||
|
||||
if not source_node or not sink_node:
|
||||
continue
|
||||
@@ -248,7 +257,10 @@ class AgentValidator:
|
||||
return valid
|
||||
|
||||
def validate_nested_sink_links(
|
||||
self, agent: AgentDict, blocks: list[dict[str, Any]]
|
||||
self,
|
||||
agent: AgentDict,
|
||||
blocks: list[dict[str, Any]],
|
||||
node_lookup: dict[str, dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate nested sink links (links with _#_ notation).
|
||||
@@ -262,7 +274,8 @@ class AgentValidator:
|
||||
block_names = {
|
||||
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
|
||||
}
|
||||
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
|
||||
if node_lookup is None:
|
||||
node_lookup = self._build_node_lookup(agent)
|
||||
|
||||
for link in agent.get("links", []):
|
||||
sink_name = link.get("sink_name", "")
|
||||
@@ -388,7 +401,10 @@ class AgentValidator:
|
||||
return valid
|
||||
|
||||
def validate_source_output_existence(
|
||||
self, agent: AgentDict, blocks: list[dict[str, Any]]
|
||||
self,
|
||||
agent: AgentDict,
|
||||
blocks: list[dict[str, Any]],
|
||||
node_lookup: dict[str, dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that all source_names in links exist in the corresponding
|
||||
@@ -401,6 +417,7 @@ class AgentValidator:
|
||||
Args:
|
||||
agent: The agent dictionary to validate
|
||||
blocks: List of available blocks with their schemas
|
||||
node_lookup: Optional pre-built node-id → node dict
|
||||
|
||||
Returns:
|
||||
True if all source output fields exist, False otherwise
|
||||
@@ -415,7 +432,8 @@ class AgentValidator:
|
||||
block_names = {
|
||||
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
|
||||
}
|
||||
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
|
||||
if node_lookup is None:
|
||||
node_lookup = self._build_node_lookup(agent)
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_id = link.get("source_id")
|
||||
@@ -809,6 +827,96 @@ class AgentValidator:
|
||||
|
||||
return valid
|
||||
|
||||
def validate_orchestrator_blocks(
|
||||
self,
|
||||
agent: AgentDict,
|
||||
node_lookup: dict[str, dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""Validate that OrchestratorBlock nodes have downstream tools.
|
||||
|
||||
Checks that each OrchestratorBlock node has at least one link
|
||||
with ``source_name == "tools"`` connecting to a downstream block.
|
||||
Without tools, the block has nothing to call and will error at runtime.
|
||||
|
||||
Returns True if all OrchestratorBlock nodes are valid.
|
||||
"""
|
||||
valid = True
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
if node_lookup is None:
|
||||
node_lookup = self._build_node_lookup(agent)
|
||||
non_tool_block_ids = {AGENT_INPUT_BLOCK_ID, AGENT_OUTPUT_BLOCK_ID}
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") != TOOL_ORCHESTRATOR_BLOCK_ID:
|
||||
continue
|
||||
|
||||
node_id = node.get("id", "unknown")
|
||||
customized_name = (node.get("metadata") or {}).get(
|
||||
"customized_name", node_id
|
||||
)
|
||||
|
||||
# Warn if agent_mode_max_iterations is 0 (traditional mode) —
|
||||
# requires complex external conversation-history loop wiring
|
||||
# that the agent generator does not produce.
|
||||
input_default = node.get("input_default", {})
|
||||
max_iter = input_default.get("agent_mode_max_iterations")
|
||||
if max_iter is not None and not isinstance(max_iter, int):
|
||||
self.add_error(
|
||||
f"OrchestratorBlock node '{customized_name}' "
|
||||
f"({node_id}) has non-integer "
|
||||
f"agent_mode_max_iterations={max_iter!r}. "
|
||||
f"This field must be an integer."
|
||||
)
|
||||
valid = False
|
||||
elif isinstance(max_iter, int) and max_iter < -1:
|
||||
self.add_error(
|
||||
f"OrchestratorBlock node '{customized_name}' "
|
||||
f"({node_id}) has invalid "
|
||||
f"agent_mode_max_iterations={max_iter}. "
|
||||
f"Use -1 for infinite or a positive number for "
|
||||
f"bounded iterations."
|
||||
)
|
||||
valid = False
|
||||
elif isinstance(max_iter, int) and max_iter > 100:
|
||||
self.add_error(
|
||||
f"OrchestratorBlock node '{customized_name}' "
|
||||
f"({node_id}) has agent_mode_max_iterations="
|
||||
f"{max_iter} which is unusually high. Values above "
|
||||
f"100 risk excessive cost and long execution times. "
|
||||
f"Consider using a lower value (3-10) or -1 for "
|
||||
f"genuinely open-ended tasks."
|
||||
)
|
||||
valid = False
|
||||
elif max_iter == 0:
|
||||
self.add_error(
|
||||
f"OrchestratorBlock node '{customized_name}' "
|
||||
f"({node_id}) has agent_mode_max_iterations=0 "
|
||||
f"(traditional mode). The agent generator only supports "
|
||||
f"agent mode (set to -1 for infinite or a positive "
|
||||
f"number for bounded iterations)."
|
||||
)
|
||||
valid = False
|
||||
|
||||
has_tools = any(
|
||||
link.get("source_id") == node_id
|
||||
and link.get("source_name") == "tools"
|
||||
and node_lookup.get(link.get("sink_id", ""), {}).get("block_id")
|
||||
not in non_tool_block_ids
|
||||
for link in links
|
||||
)
|
||||
|
||||
if not has_tools:
|
||||
self.add_error(
|
||||
f"OrchestratorBlock node '{customized_name}' "
|
||||
f"({node_id}) has no downstream tool blocks connected. "
|
||||
f"Connect at least one block to its 'tools' output so "
|
||||
f"the AI has tools to call."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_mcp_tool_blocks(self, agent: AgentDict) -> bool:
|
||||
"""Validate that MCPToolBlock nodes have required fields.
|
||||
|
||||
@@ -870,6 +978,9 @@ class AgentValidator:
|
||||
logger.info("Validating agent...")
|
||||
self.errors = []
|
||||
|
||||
# Build node lookup once and share across validation methods
|
||||
node_lookup = self._build_node_lookup(agent)
|
||||
|
||||
checks = [
|
||||
(
|
||||
"Block existence",
|
||||
@@ -885,15 +996,15 @@ class AgentValidator:
|
||||
),
|
||||
(
|
||||
"Data type compatibility",
|
||||
self.validate_data_type_compatibility(agent, blocks),
|
||||
self.validate_data_type_compatibility(agent, blocks, node_lookup),
|
||||
),
|
||||
(
|
||||
"Nested sink links",
|
||||
self.validate_nested_sink_links(agent, blocks),
|
||||
self.validate_nested_sink_links(agent, blocks, node_lookup),
|
||||
),
|
||||
(
|
||||
"Source output existence",
|
||||
self.validate_source_output_existence(agent, blocks),
|
||||
self.validate_source_output_existence(agent, blocks, node_lookup),
|
||||
),
|
||||
(
|
||||
"Prompt double curly braces spaces",
|
||||
@@ -913,6 +1024,10 @@ class AgentValidator:
|
||||
"MCP tool blocks",
|
||||
self.validate_mcp_tool_blocks(agent),
|
||||
),
|
||||
(
|
||||
"Orchestrator blocks",
|
||||
self.validate_orchestrator_blocks(agent, node_lookup),
|
||||
),
|
||||
]
|
||||
|
||||
# Add AgentExecutorBlock detailed validation if library_agents
|
||||
|
||||
@@ -108,22 +108,12 @@ class AgentOutputTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Retrieve execution outputs from agents in the user's library.
|
||||
|
||||
Identify the agent using one of:
|
||||
- agent_name: Fuzzy search in user's library
|
||||
- library_agent_id: Exact library agent ID
|
||||
- store_slug: Marketplace format 'username/agent-name'
|
||||
|
||||
Select which run to retrieve using:
|
||||
- execution_id: Specific execution ID
|
||||
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
|
||||
|
||||
Wait for completion (optional):
|
||||
- wait_if_running: Max seconds to wait if execution is still running (0-300).
|
||||
If the execution is running/queued, waits up to this many seconds for completion.
|
||||
Returns current status on timeout. If already finished, returns immediately.
|
||||
"""
|
||||
return (
|
||||
"Retrieve execution outputs from a library agent. "
|
||||
"Identify by agent_name, library_agent_id, or store_slug. "
|
||||
"Filter by execution_id or run_time. "
|
||||
"Optionally wait for running executions."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
@@ -132,32 +122,29 @@ class AgentOutputTool(BaseTool):
|
||||
"properties": {
|
||||
"agent_name": {
|
||||
"type": "string",
|
||||
"description": "Agent name to search for in user's library (fuzzy match)",
|
||||
"description": "Agent name (fuzzy match).",
|
||||
},
|
||||
"library_agent_id": {
|
||||
"type": "string",
|
||||
"description": "Exact library agent ID",
|
||||
"description": "Library agent ID.",
|
||||
},
|
||||
"store_slug": {
|
||||
"type": "string",
|
||||
"description": "Marketplace identifier: 'username/agent-slug'",
|
||||
"description": "Marketplace 'username/agent-name'.",
|
||||
},
|
||||
"execution_id": {
|
||||
"type": "string",
|
||||
"description": "Specific execution ID to retrieve",
|
||||
"description": "Specific execution ID.",
|
||||
},
|
||||
"run_time": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
|
||||
),
|
||||
"description": "Time filter: 'latest', 'today', 'yesterday', 'last week', 'last 7 days', 'last month', 'last 30 days', 'YYYY-MM-DD', or ISO datetime.",
|
||||
},
|
||||
"wait_if_running": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Max seconds to wait if execution is still running (0-300). "
|
||||
"If running, waits for completion. Returns current state on timeout."
|
||||
),
|
||||
"description": "Max seconds to wait if still running (0-300). Returns current state on timeout.",
|
||||
"minimum": 0,
|
||||
"maximum": 300,
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
from backend.api.features.store.model import StoreAgent, StoreAgentDetails
|
||||
|
||||
from backend.data.db_accessors import library_db, store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
@@ -19,16 +19,12 @@ from .models import (
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .utils import is_creator_slug, is_uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SearchSource = Literal["marketplace", "library"]
|
||||
|
||||
_UUID_PATTERN = re.compile(
|
||||
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Keywords that should be treated as "list all" rather than a literal search
|
||||
_LIST_ALL_KEYWORDS = frozenset({"all", "*", "everything", "any", ""})
|
||||
|
||||
@@ -39,149 +35,160 @@ async def search_agents(
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""
|
||||
Search for agents in marketplace or user library.
|
||||
"""Search for agents in marketplace or user library."""
|
||||
if source == "marketplace":
|
||||
return await _search_marketplace(query, session_id)
|
||||
else:
|
||||
return await _search_library(query, session_id, user_id)
|
||||
|
||||
For library searches, keywords like "all", "*", "everything", or an empty
|
||||
query will list all agents without filtering.
|
||||
|
||||
Args:
|
||||
query: Search query string. Special keywords list all library agents.
|
||||
source: "marketplace" or "library"
|
||||
session_id: Chat session ID
|
||||
user_id: User ID (required for library search)
|
||||
|
||||
Returns:
|
||||
AgentsFoundResponse, NoResultsResponse, or ErrorResponse
|
||||
"""
|
||||
# Normalize list-all keywords to empty string for library searches
|
||||
if source == "library" and query.lower().strip() in _LIST_ALL_KEYWORDS:
|
||||
query = ""
|
||||
|
||||
if source == "marketplace" and not query:
|
||||
async def _search_marketplace(query: str, session_id: str | None) -> ToolResponseBase:
|
||||
"""Search marketplace agents, with direct creator/slug lookup fallback."""
|
||||
query = query.strip()
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query", session_id=session_id
|
||||
)
|
||||
|
||||
if source == "library" and not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required to search library",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agents: list[AgentInfo] = []
|
||||
try:
|
||||
if source == "marketplace":
|
||||
# Direct lookup if query matches "creator/slug" pattern
|
||||
if is_creator_slug(query):
|
||||
logger.info(f"Query looks like creator/slug, trying direct lookup: {query}")
|
||||
creator, slug = query.split("/", 1)
|
||||
agent_info = await _get_marketplace_agent_by_slug(creator, slug)
|
||||
if agent_info:
|
||||
agents.append(agent_info)
|
||||
|
||||
if not agents:
|
||||
logger.info(f"Searching marketplace for: {query}")
|
||||
results = await store_db().get_store_agents(search_query=query, page_size=5)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=f"{agent.creator}/{agent.slug}",
|
||||
name=agent.agent_name,
|
||||
description=agent.description or "",
|
||||
source="marketplace",
|
||||
in_library=False,
|
||||
creator=agent.creator,
|
||||
category="general",
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
is_featured=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if _is_uuid(query):
|
||||
logger.info(f"Query looks like UUID, trying direct lookup: {query}")
|
||||
agent = await _get_library_agent_by_id(user_id, query) # type: ignore[arg-type]
|
||||
if agent:
|
||||
agents.append(agent)
|
||||
logger.info(f"Found agent by direct ID lookup: {agent.name}")
|
||||
|
||||
if not agents:
|
||||
search_term = query or None
|
||||
logger.info(
|
||||
f"{'Listing all agents in' if not query else 'Searching'} "
|
||||
f"user library{'' if not query else f' for: {query}'}"
|
||||
)
|
||||
results = await library_db().list_library_agents(
|
||||
user_id=user_id, # type: ignore[arg-type]
|
||||
search_term=search_term,
|
||||
page_size=50 if not query else 10,
|
||||
)
|
||||
for agent in results.agents:
|
||||
agents.append(_library_agent_to_info(agent))
|
||||
logger.info(f"Found {len(agents)} agents in {source}")
|
||||
agents.append(_marketplace_agent_to_info(agent))
|
||||
except NotFoundError:
|
||||
pass
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Error searching {source}: {e}", exc_info=True)
|
||||
logger.error(f"Error searching marketplace: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to search {source}. Please try again.",
|
||||
message="Failed to search marketplace. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agents:
|
||||
if source == "marketplace":
|
||||
suggestions = [
|
||||
"Try more general terms",
|
||||
"Browse categories in the marketplace",
|
||||
"Check spelling",
|
||||
]
|
||||
no_results_msg = (
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"No agents found matching '{query}'. Let the user know they can "
|
||||
"try different keywords or browse the marketplace. Also let them "
|
||||
"know you can create a custom agent for them based on their needs."
|
||||
),
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Browse categories in the marketplace",
|
||||
"Check spelling",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return AgentsFoundResponse(
|
||||
message=(
|
||||
"Now you have found some options for the user to choose from. "
|
||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||
"Please ask the user if they would like to use any of these agents. "
|
||||
"Let the user know we can create a custom agent for them based on their needs."
|
||||
),
|
||||
title=f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'",
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
async def _search_library(
|
||||
query: str, session_id: str | None, user_id: str | None
|
||||
) -> ToolResponseBase:
|
||||
"""Search user's library agents, with direct UUID lookup fallback."""
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required to search library",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
query = query.strip()
|
||||
# Normalize list-all keywords to empty string
|
||||
if query.lower() in _LIST_ALL_KEYWORDS:
|
||||
query = ""
|
||||
|
||||
agents: list[AgentInfo] = []
|
||||
try:
|
||||
if is_uuid(query):
|
||||
logger.info(f"Query looks like UUID, trying direct lookup: {query}")
|
||||
agent = await _get_library_agent_by_id(user_id, query)
|
||||
if agent:
|
||||
agents.append(agent)
|
||||
|
||||
if not agents:
|
||||
logger.info(
|
||||
f"{'Listing all agents in' if not query else 'Searching'} "
|
||||
f"user library{'' if not query else f' for: {query}'}"
|
||||
)
|
||||
elif not query:
|
||||
# User asked to list all but library is empty
|
||||
suggestions = [
|
||||
"Browse the marketplace to find and add agents",
|
||||
"Use find_agent to search the marketplace",
|
||||
]
|
||||
no_results_msg = (
|
||||
"Your library is empty. Let the user know they can browse the "
|
||||
"marketplace to find agents, or you can create a custom agent "
|
||||
"for them based on their needs."
|
||||
results = await library_db().list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=query or None,
|
||||
page_size=50 if not query else 10,
|
||||
)
|
||||
else:
|
||||
suggestions = [
|
||||
"Try different keywords",
|
||||
"Use find_agent to search the marketplace",
|
||||
"Check your library at /library",
|
||||
]
|
||||
no_results_msg = (
|
||||
for agent in results.agents:
|
||||
agents.append(_library_agent_to_info(agent))
|
||||
except NotFoundError:
|
||||
pass
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Error searching library: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search library. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agents:
|
||||
if not query:
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
"Your library is empty. Let the user know they can browse the "
|
||||
"marketplace to find agents, or you can create a custom agent "
|
||||
"for them based on their needs."
|
||||
),
|
||||
suggestions=[
|
||||
"Browse the marketplace to find and add agents",
|
||||
"Use find_agent to search the marketplace",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"No agents matching '{query}' found in your library. Let the "
|
||||
"user know you can create a custom agent for them based on "
|
||||
"their needs."
|
||||
)
|
||||
return NoResultsResponse(
|
||||
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
||||
),
|
||||
suggestions=[
|
||||
"Try different keywords",
|
||||
"Use find_agent to search the marketplace",
|
||||
"Check your library at /library",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if source == "marketplace":
|
||||
title = (
|
||||
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'"
|
||||
)
|
||||
elif not query:
|
||||
if not query:
|
||||
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library"
|
||||
else:
|
||||
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library for '{query}'"
|
||||
|
||||
message = (
|
||||
"Now you have found some options for the user to choose from. "
|
||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||
"Please ask the user if they would like to use any of these agents. "
|
||||
"Let the user know we can create a custom agent for them based on their needs."
|
||||
if source == "marketplace"
|
||||
else "Found agents in the user's library. You can provide a link to view "
|
||||
"an agent at: /library/agents/{agent_id}. Use agent_output to get "
|
||||
"execution results, or run_agent to execute. Let the user know we can "
|
||||
"create a custom agent for them based on their needs."
|
||||
)
|
||||
|
||||
return AgentsFoundResponse(
|
||||
message=message,
|
||||
message=(
|
||||
"Found agents in the user's library. You can provide a link to view "
|
||||
"an agent at: /library/agents/{agent_id}. Use agent_output to get "
|
||||
"execution results, or run_agent to execute. Let the user know we can "
|
||||
"create a custom agent for them based on their needs."
|
||||
),
|
||||
title=title,
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
@@ -189,9 +196,20 @@ async def search_agents(
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(text: str) -> bool:
|
||||
"""Check if text is a valid UUID v4."""
|
||||
return bool(_UUID_PATTERN.match(text.strip()))
|
||||
def _marketplace_agent_to_info(agent: StoreAgent | StoreAgentDetails) -> AgentInfo:
|
||||
"""Convert a marketplace agent (StoreAgent or StoreAgentDetails) to an AgentInfo."""
|
||||
return AgentInfo(
|
||||
id=f"{agent.creator}/{agent.slug}",
|
||||
name=agent.agent_name,
|
||||
description=agent.description or "",
|
||||
source="marketplace",
|
||||
in_library=False,
|
||||
creator=agent.creator,
|
||||
category="general",
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
is_featured=False,
|
||||
)
|
||||
|
||||
|
||||
def _library_agent_to_info(agent: LibraryAgent) -> AgentInfo:
|
||||
@@ -214,6 +232,23 @@ def _library_agent_to_info(agent: LibraryAgent) -> AgentInfo:
|
||||
)
|
||||
|
||||
|
||||
async def _get_marketplace_agent_by_slug(creator: str, slug: str) -> AgentInfo | None:
|
||||
"""Fetch a marketplace agent by creator/slug identifier."""
|
||||
try:
|
||||
details = await store_db().get_store_agent_details(creator, slug)
|
||||
return _marketplace_agent_to_info(details)
|
||||
except NotFoundError:
|
||||
pass
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch marketplace agent {creator}/{slug}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
|
||||
"""Fetch a library agent by ID (library agent ID or graph_id).
|
||||
|
||||
@@ -226,10 +261,9 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
|
||||
try:
|
||||
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||
return _library_agent_to_info(agent)
|
||||
except NotFoundError:
|
||||
logger.debug(f"Library agent not found by graph_id: {agent_id}")
|
||||
pass
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -241,10 +275,9 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
|
||||
try:
|
||||
agent = await lib_db.get_library_agent(agent_id, user_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||
return _library_agent_to_info(agent)
|
||||
except NotFoundError:
|
||||
logger.debug(f"Library agent not found by library_id: {agent_id}")
|
||||
pass
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
"""Tests for agent search direct lookup functionality."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .agent_search import search_agents
|
||||
from .models import AgentsFoundResponse, NoResultsResponse
|
||||
|
||||
_TEST_USER_ID = "test-user-agent-search"
|
||||
|
||||
|
||||
class TestMarketplaceSlugLookup:
|
||||
"""Tests for creator/slug direct lookup in marketplace search."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_slug_lookup_found(self):
|
||||
"""creator/slug query returns the agent directly."""
|
||||
mock_details = MagicMock()
|
||||
mock_details.creator = "testuser"
|
||||
mock_details.slug = "my-agent"
|
||||
mock_details.agent_name = "My Agent"
|
||||
mock_details.description = "A test agent"
|
||||
mock_details.rating = 4.5
|
||||
mock_details.runs = 100
|
||||
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_store_agent_details = AsyncMock(return_value=mock_details)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.agent_search.store_db",
|
||||
return_value=mock_store,
|
||||
):
|
||||
response = await search_agents(
|
||||
query="testuser/my-agent",
|
||||
source="marketplace",
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert isinstance(response, AgentsFoundResponse)
|
||||
assert response.count == 1
|
||||
assert response.agents[0].id == "testuser/my-agent"
|
||||
assert response.agents[0].name == "My Agent"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_slug_lookup_not_found_falls_back_to_search(self):
|
||||
"""creator/slug not found falls back to general search."""
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_store_agent_details = AsyncMock(side_effect=NotFoundError(""))
|
||||
|
||||
# Fallback search returns results
|
||||
mock_search_results = MagicMock()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.creator = "other"
|
||||
mock_agent.slug = "similar-agent"
|
||||
mock_agent.agent_name = "Similar Agent"
|
||||
mock_agent.description = "A similar agent"
|
||||
mock_agent.rating = 3.0
|
||||
mock_agent.runs = 50
|
||||
mock_search_results.agents = [mock_agent]
|
||||
|
||||
mock_store.get_store_agents = AsyncMock(return_value=mock_search_results)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.agent_search.store_db",
|
||||
return_value=mock_store,
|
||||
):
|
||||
response = await search_agents(
|
||||
query="testuser/my-agent",
|
||||
source="marketplace",
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert isinstance(response, AgentsFoundResponse)
|
||||
assert response.count == 1
|
||||
assert response.agents[0].id == "other/similar-agent"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_slug_lookup_not_found_no_search_results(self):
|
||||
"""creator/slug not found and search returns nothing."""
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_store_agent_details = AsyncMock(side_effect=NotFoundError(""))
|
||||
mock_search_results = MagicMock()
|
||||
mock_search_results.agents = []
|
||||
mock_store.get_store_agents = AsyncMock(return_value=mock_search_results)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.agent_search.store_db",
|
||||
return_value=mock_store,
|
||||
):
|
||||
response = await search_agents(
|
||||
query="testuser/nonexistent",
|
||||
source="marketplace",
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert isinstance(response, NoResultsResponse)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_non_slug_query_goes_to_search(self):
|
||||
"""Regular keyword query skips slug lookup and goes to search."""
|
||||
mock_store = MagicMock()
|
||||
mock_search_results = MagicMock()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.creator = "creator1"
|
||||
mock_agent.slug = "email-agent"
|
||||
mock_agent.agent_name = "Email Agent"
|
||||
mock_agent.description = "Sends emails"
|
||||
mock_agent.rating = 4.0
|
||||
mock_agent.runs = 200
|
||||
mock_search_results.agents = [mock_agent]
|
||||
mock_store.get_store_agents = AsyncMock(return_value=mock_search_results)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.agent_search.store_db",
|
||||
return_value=mock_store,
|
||||
):
|
||||
response = await search_agents(
|
||||
query="email",
|
||||
source="marketplace",
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert isinstance(response, AgentsFoundResponse)
|
||||
# get_store_agent_details should NOT have been called
|
||||
mock_store.get_store_agent_details.assert_not_called()
|
||||
|
||||
|
||||
class TestLibraryUUIDLookup:
|
||||
"""Tests for UUID direct lookup in library search."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_uuid_lookup_found_by_graph_id(self):
|
||||
"""UUID query matching a graph_id returns the agent directly."""
|
||||
agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.id = "lib-agent-id"
|
||||
mock_agent.name = "My Library Agent"
|
||||
mock_agent.description = "A library agent"
|
||||
mock_agent.creator_name = "testuser"
|
||||
mock_agent.status.value = "HEALTHY"
|
||||
mock_agent.can_access_graph = True
|
||||
mock_agent.has_external_trigger = False
|
||||
mock_agent.new_output = False
|
||||
mock_agent.graph_id = agent_id
|
||||
mock_agent.graph_version = 1
|
||||
mock_agent.input_schema = {}
|
||||
mock_agent.output_schema = {}
|
||||
|
||||
mock_lib_db = MagicMock()
|
||||
mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.agent_search.library_db",
|
||||
return_value=mock_lib_db,
|
||||
):
|
||||
response = await search_agents(
|
||||
query=agent_id,
|
||||
source="library",
|
||||
session_id="test-session",
|
||||
user_id=_TEST_USER_ID,
|
||||
)
|
||||
|
||||
assert isinstance(response, AgentsFoundResponse)
|
||||
assert response.count == 1
|
||||
assert response.agents[0].name == "My Library Agent"
|
||||
@@ -164,8 +164,9 @@ class BaseTool:
|
||||
|
||||
"""
|
||||
if self.requires_auth and not user_id:
|
||||
logger.error(
|
||||
f"Attempted tool call for {self.name} but user not authenticated"
|
||||
logger.warning(
|
||||
"Attempted tool call for %s but user not authenticated",
|
||||
self.name,
|
||||
)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
@@ -196,7 +197,7 @@ class BaseTool:
|
||||
output=raw_output,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in {self.name}: {e}", exc_info=True)
|
||||
logger.warning("Error in %s", self.name, exc_info=True)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
|
||||
@@ -22,6 +22,7 @@ from e2b import AsyncSandbox
|
||||
from e2b.exceptions import TimeoutException
|
||||
|
||||
from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
|
||||
from backend.copilot.integration_creds import get_integration_env_vars
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
@@ -41,15 +42,9 @@ class BashExecTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Execute a Bash command or script. "
|
||||
"Full Bash scripting is supported (loops, conditionals, pipes, "
|
||||
"functions, etc.). "
|
||||
"The working directory is shared with the SDK Read/Write/Edit/Glob/Grep "
|
||||
"tools — files created by either are immediately visible to both. "
|
||||
"Execution is killed after the timeout (default 30s, max 120s). "
|
||||
"Returns stdout and stderr. "
|
||||
"Useful for file manipulation, data processing, running scripts, "
|
||||
"and installing packages."
|
||||
"Execute a Bash command or script. Shares filesystem with SDK file tools. "
|
||||
"Useful for scripts, data processing, and package installation. "
|
||||
"Killed after timeout (default 30s, max 120s)."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -59,13 +54,11 @@ class BashExecTool(BaseTool):
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Bash command or script to execute.",
|
||||
"description": "Bash command or script.",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Max execution time in seconds (default 30, max 120)."
|
||||
),
|
||||
"description": "Max seconds (default 30, max 120).",
|
||||
"default": 30,
|
||||
},
|
||||
},
|
||||
@@ -74,7 +67,10 @@ class BashExecTool(BaseTool):
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
# True because _execute_on_e2b injects user tokens (GH_TOKEN etc.)
|
||||
# when user_id is present. Defense-in-depth: ensures only authenticated
|
||||
# users reach the token injection path.
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
@@ -82,6 +78,14 @@ class BashExecTool(BaseTool):
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
"""Run a bash command on E2B (if available) or in a bubblewrap sandbox.
|
||||
|
||||
Dispatches to :meth:`_execute_on_e2b` when a sandbox is present in the
|
||||
current execution context, otherwise falls back to the local bubblewrap
|
||||
sandbox. Returns a :class:`BashExecResponse` on success or an
|
||||
:class:`ErrorResponse` when the sandbox is unavailable or the command
|
||||
is empty.
|
||||
"""
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
command: str = (kwargs.get("command") or "").strip()
|
||||
@@ -96,7 +100,9 @@ class BashExecTool(BaseTool):
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
|
||||
return await self._execute_on_e2b(
|
||||
sandbox, command, timeout, session_id, user_id
|
||||
)
|
||||
|
||||
# Bubblewrap fallback: local isolated execution.
|
||||
if not has_full_sandbox():
|
||||
@@ -133,19 +139,42 @@ class BashExecTool(BaseTool):
|
||||
command: str,
|
||||
timeout: int,
|
||||
session_id: str | None,
|
||||
user_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute *command* on the E2B sandbox via commands.run()."""
|
||||
"""Execute *command* on the E2B sandbox via commands.run().
|
||||
|
||||
Integration tokens (e.g. GH_TOKEN) are injected into the sandbox env
|
||||
for any user with connected accounts. E2B has full internet access, so
|
||||
CLI tools like ``gh`` work without manual authentication.
|
||||
"""
|
||||
envs: dict[str, str] = {
|
||||
"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin",
|
||||
}
|
||||
# Collect injected secret values so we can scrub them from output.
|
||||
secret_values: list[str] = []
|
||||
if user_id is not None:
|
||||
integration_env = await get_integration_env_vars(user_id)
|
||||
secret_values = [v for v in integration_env.values() if v]
|
||||
envs.update(integration_env)
|
||||
|
||||
try:
|
||||
result = await sandbox.commands.run(
|
||||
f"bash -c {shlex.quote(command)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=timeout,
|
||||
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
|
||||
envs=envs,
|
||||
)
|
||||
stdout = result.stdout or ""
|
||||
stderr = result.stderr or ""
|
||||
# Scrub injected tokens from command output to prevent exfiltration
|
||||
# via `echo $GH_TOKEN`, `env`, `printenv`, etc.
|
||||
for secret in secret_values:
|
||||
stdout = stdout.replace(secret, "[REDACTED]")
|
||||
stderr = stderr.replace(secret, "[REDACTED]")
|
||||
return BashExecResponse(
|
||||
message=f"Command executed on E2B (exit {result.exit_code})",
|
||||
stdout=result.stdout or "",
|
||||
stderr=result.stderr or "",
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
exit_code=result.exit_code,
|
||||
timed_out=False,
|
||||
session_id=session_id,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user